1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.optimization;
19
20 import org.apache.commons.math.optimization.ConvergenceChecker;
21 import org.apache.commons.math.optimization.CostException;
22 import org.apache.commons.math.optimization.CostFunction;
23 import org.apache.commons.math.optimization.NelderMead;
24 import org.apache.commons.math.ConvergenceException;
25 import org.apache.commons.math.optimization.PointCostPair;
26 import org.apache.commons.math.random.JDKRandomGenerator;
27 import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
28 import org.apache.commons.math.random.RandomGenerator;
29 import org.apache.commons.math.random.RandomVectorGenerator;
30 import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
31 import org.apache.commons.math.random.UniformRandomGenerator;
32
33 import junit.framework.*;
34
35 public class NelderMeadTest
36 extends TestCase {
37
38 public NelderMeadTest(String name) {
39 super(name);
40 }
41
42 public void testCostExceptions() throws ConvergenceException {
43 CostFunction wrong =
44 new CostFunction() {
45 public double cost(double[] x) throws CostException {
46 if (x[0] < 0) {
47 throw new CostException("{0}", new Object[] { "oops"});
48 } else if (x[0] > 1) {
49 throw new CostException(new RuntimeException("oops"));
50 } else {
51 return x[0] * (1 - x[0]);
52 }
53 }
54 };
55 try {
56 new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
57 new double[] { -0.5 }, new double[] { 0.5 });
58 fail("an exception should have been thrown");
59 } catch (CostException ce) {
60
61 assertNull(ce.getCause());
62 } catch (Exception e) {
63 fail("wrong exception caught: " + e.getMessage());
64 }
65 try {
66 new NelderMead(0.9, 1.9, 0.4, 0.6).minimize(wrong, 10, new ValueChecker(1.0e-3),
67 new double[] { 0.5 }, new double[] { 1.5 });
68 fail("an exception should have been thrown");
69 } catch (CostException ce) {
70
71 assertNotNull(ce.getCause());
72 } catch (Exception e) {
73 fail("wrong exception caught: " + e.getMessage());
74 }
75 }
76
77 public void testRosenbrock()
78 throws CostException, ConvergenceException, NotPositiveDefiniteMatrixException {
79
80 CostFunction rosenbrock =
81 new CostFunction() {
82 public double cost(double[] x) {
83 ++count;
84 double a = x[1] - x[0] * x[0];
85 double b = 1.0 - x[0];
86 return 100 * a * a + b * b;
87 }
88 };
89
90 count = 0;
91 NelderMead nm = new NelderMead();
92 try {
93 nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
94 new double[][] {
95 { -1.2, 1.0 }, { 3.5, -2.3 }, { 0.4, 1.5 }
96 }, 1, 5384353l);
97 fail("an exception should have been thrown");
98 } catch (ConvergenceException ce) {
99
100 } catch (Exception e) {
101 fail("wrong exception caught: " + e.getMessage());
102 }
103
104 count = 0;
105 PointCostPair optimum =
106 nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
107 new double[][] {
108 { -1.2, 1.0 }, { 0.9, 1.2 }, { 3.5, -2.3 }
109 }, 10, 1642738l);
110
111 assertTrue(count > 700);
112 assertTrue(count < 800);
113 assertEquals(0.0, optimum.getCost(), 5.0e-5);
114 assertEquals(1.0, optimum.getPoint()[0], 0.01);
115 assertEquals(1.0, optimum.getPoint()[1], 0.01);
116
117 PointCostPair[] minima = nm.getMinima();
118 assertEquals(10, minima.length);
119 assertNotNull(minima[0]);
120 assertNull(minima[minima.length - 1]);
121 for (int i = 0; i < minima.length; ++i) {
122 if (minima[i] == null) {
123 if ((i + 1) < minima.length) {
124 assertTrue(minima[i+1] == null);
125 }
126 } else {
127 if (i > 0) {
128 assertTrue(minima[i-1].getCost() <= minima[i].getCost());
129 }
130 }
131 }
132
133 RandomGenerator rg = new JDKRandomGenerator();
134 rg.setSeed(64453353l);
135 RandomVectorGenerator rvg =
136 new UncorrelatedRandomVectorGenerator(new double[] { 0.9, 1.1 },
137 new double[] { 0.2, 0.2 },
138 new UniformRandomGenerator(rg));
139 optimum =
140 nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg);
141 assertEquals(0.0, optimum.getCost(), 2.0e-4);
142 optimum =
143 nm.minimize(rosenbrock, 100, new ValueChecker(1.0e-3), rvg, 3);
144 assertEquals(0.0, optimum.getCost(), 3.0e-5);
145
146 }
147
148 public void testPowell()
149 throws CostException, ConvergenceException {
150
151 CostFunction powell =
152 new CostFunction() {
153 public double cost(double[] x) {
154 ++count;
155 double a = x[0] + 10 * x[1];
156 double b = x[2] - x[3];
157 double c = x[1] - 2 * x[2];
158 double d = x[0] - x[3];
159 return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
160 }
161 };
162
163 count = 0;
164 NelderMead nm = new NelderMead();
165 PointCostPair optimum =
166 nm.minimize(powell, 200, new ValueChecker(1.0e-3),
167 new double[] { 3.0, -1.0, 0.0, 1.0 },
168 new double[] { 4.0, 0.0, 1.0, 2.0 },
169 1, 1642738l);
170 assertTrue(count < 150);
171 assertEquals(0.0, optimum.getCost(), 6.0e-4);
172 assertEquals(0.0, optimum.getPoint()[0], 0.07);
173 assertEquals(0.0, optimum.getPoint()[1], 0.07);
174 assertEquals(0.0, optimum.getPoint()[2], 0.07);
175 assertEquals(0.0, optimum.getPoint()[3], 0.07);
176
177 }
178
179 private static class ValueChecker implements ConvergenceChecker {
180
181 public ValueChecker(double threshold) {
182 this.threshold = threshold;
183 }
184
185 public boolean converged(PointCostPair[] simplex) {
186 PointCostPair smallest = simplex[0];
187 PointCostPair largest = simplex[simplex.length - 1];
188 return (largest.getCost() - smallest.getCost()) < threshold;
189 }
190
191 private double threshold;
192
193 };
194
195 public static Test suite() {
196 return new TestSuite(NelderMeadTest.class);
197 }
198
199 private int count;
200
201 }