1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization;
19  
20  import org.apache.commons.math.optimization.ConvergenceChecker;
21  import org.apache.commons.math.optimization.CostException;
22  import org.apache.commons.math.optimization.CostFunction;
23  import org.apache.commons.math.optimization.NelderMead;
24  import org.apache.commons.math.ConvergenceException;
25  import org.apache.commons.math.optimization.PointCostPair;
26  import org.apache.commons.math.random.JDKRandomGenerator;
27  import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
28  import org.apache.commons.math.random.RandomGenerator;
29  import org.apache.commons.math.random.RandomVectorGenerator;
30  import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
31  import org.apache.commons.math.random.UniformRandomGenerator;
32  
33  import junit.framework.*;
34  
35  public class NelderMeadTest
36    extends TestCase {
37  
38    public NelderMeadTest(String name) {
39      super(name);
40    }
41  
42    public void testCostExceptions() throws ConvergenceException {
43        CostFunction wrong =
44            new CostFunction() {
45              public double cost(double[] x) throws CostException {
46                  if (x[0] < 0) {
47                      throw new CostException("{0}", new Object[] { "oops"});
48                  } else if (x[0] > 1) {
49                      throw new CostException(new RuntimeException("oops"));
50                  } else {
51                      return x[0] * (1 - x[0]);
52                  }
53              }
54        };
55        try {
56            new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
57                                                        new double[] { -0.5 }, new double[] { 0.5 });
58            fail("an exception should have been thrown");
59        } catch (CostException ce) {
60            // expected behavior
61            assertNull(ce.getCause());
62        } catch (Exception e) {
63            fail("wrong exception caught: " + e.getMessage());
64        } 
65        try {
66            new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
67                                                        new double[] { 0.5 }, new double[] { 1.5 });
68            fail("an exception should have been thrown");
69        } catch (CostException ce) {
70            // expected behavior
71            assertNotNull(ce.getCause());
72        } catch (Exception e) {
73            fail("wrong exception caught: " + e.getMessage());
74        } 
75    }
76  
77    public void testRosenbrock()
78      throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {
79  
80      CostFunction rosenbrock =
81        new CostFunction() {
82          public double cost(double[] x) {
83            ++count;
84            double a = x[1] - x[0] * x[0];
85            double b = 1.0 - x[0];
86            return 100 * a * a + b * b;
87          }
88        };
89  
90      count = 0;
91      NelderMead nm = new NelderMead();
92      try {
93        nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
94                    new double[][] {
95                      { -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
96                    }, 1, 5384353l);
97        fail("an exception should have been thrown");
98      } catch (ConvergenceException ce) {
99          // expected behavior
100     } catch (Exception e) {
101         fail("wrong exception caught: " + e.getMessage());
102     }
103 
104     count = 0;
105     PointCostPair optimum =
106         nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
107                     new double[][] {
108                       { -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
109                     }, 10, 1642738l);
110 
111     assertTrue(count > 700);
112     assertTrue(count < 800);
113     assertEquals(0.0, optimum.getCost(), 5.0e-5);
114     assertEquals(1.0, optimum.getPoint()[0], 0.01);
115     assertEquals(1.0, optimum.getPoint()[1], 0.01);
116 
117     PointCostPair[] minima = nm.getMinima();
118     assertEquals(10, minima.length);
119     assertNotNull(minima[0]);
120     assertNull(minima[minima.length - 1]);
121     for (int i = 0; i < minima.length; ++i) {
122         if (minima[i] == null) {
123             if ((i + 1) < minima.length) {
124                 assertTrue(minima[i+1] == null);
125             }
126         } else {
127             if (i > 0) {
128                 assertTrue(minima[i-1].getCost() <= minima[i].getCost());
129             }
130         }
131     }
132 
133     RandomGenerator rg = new JDKRandomGenerator();
134     rg.setSeed(64453353l);
135     RandomVectorGenerator rvg =
136         new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
137                                               new double[] { 0.2, 0.2 },
138                                               new UniformRandomGenerator(rg));
139     optimum =
140         nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
141     assertEquals(0.0, optimum.getCost(), 2.0e-4);
142     optimum =
143         nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
144     assertEquals(0.0, optimum.getCost(), 3.0e-5);
145 
146   }
147 
148   public void testPowell()
149     throws CostException, ConvergenceException {
150 
151     CostFunction powell =
152       new CostFunction() {
153         public double cost(double[] x) {
154           ++count;
155           double a = x[0] + 10 * x[1];
156           double b = x[2] - x[3];
157           double c = x[1] - 2 * x[2];
158           double d = x[0] - x[3];
159           return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
160         }
161       };
162 
163     count = 0;
164     NelderMead nm = new NelderMead();
165     PointCostPair optimum =
166       nm.minimize(powell, 200, new ValueChecker(1.0e-3),
167                   new double[] {  3.0, -1.0, 0.0, 1.0 },
168                   new double[] {  4.0,  0.0, 1.0, 2.0 },
169                   1, 1642738l);
170     assertTrue(count < 150);
171     assertEquals(0.0, optimum.getCost(), 6.0e-4);
172     assertEquals(0.0, optimum.getPoint()[0], 0.07);
173     assertEquals(0.0, optimum.getPoint()[1], 0.07);
174     assertEquals(0.0, optimum.getPoint()[2], 0.07);
175     assertEquals(0.0, optimum.getPoint()[3], 0.07);
176 
177   }
178 
179   private static class ValueChecker implements ConvergenceChecker {
180 
181     public ValueChecker(double threshold) {
182       this.threshold = threshold;
183     }
184 
185     public boolean converged(PointCostPair[] simplex) {
186       PointCostPair smallest = simplex[0];
187       PointCostPair largest  = simplex[simplex.length - 1];
188       return (largest.getCost() - smallest.getCost()) < threshold;
189     }
190 
191     private double threshold;
192 
193   };
194 
195   public static Test suite() {
196     return new TestSuite(NelderMeadTest.class);
197   }
198 
199   private int count;
200 
201 }