1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.optimization;
19
20 import org.apache.commons.math.optimization.ConvergenceChecker;
21 import org.apache.commons.math.optimization.CostException;
22 import org.apache.commons.math.optimization.CostFunction;
23 import org.apache.commons.math.optimization.MultiDirectional;
24 import org.apache.commons.math.ConvergenceException;
25 import org.apache.commons.math.optimization.PointCostPair;
26
27 import junit.framework.*;
28
29 public class MultiDirectionalTest
30 extends TestCase {
31
32 public MultiDirectionalTest(String name) {
33 super(name);
34 }
35
36 public void testCostExceptions() throws ConvergenceException {
37 CostFunction wrong =
38 new CostFunction() {
39 public double cost(double[] x) throws CostException {
40 if (x[0] < 0) {
41 throw new CostException("{0}", new Object[] { "oops"});
42 } else if (x[0] > 1) {
43 throw new CostException(new RuntimeException("oops"));
44 } else {
45 return x[0] * (1 - x[0]);
46 }
47 }
48 };
49 try {
50 new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
51 new double[] { -0.5 }, new double[] { 0.5 });
52 fail("an exception should have been thrown");
53 } catch (CostException ce) {
54
55 assertNull(ce.getCause());
56 } catch (Exception e) {
57 fail("wrong exception caught: " + e.getMessage());
58 }
59 try {
60 new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
61 new double[] { 0.5 }, new double[] { 1.5 });
62 fail("an exception should have been thrown");
63 } catch (CostException ce) {
64
65 assertNotNull(ce.getCause());
66 } catch (Exception e) {
67 fail("wrong exception caught: " + e.getMessage());
68 }
69 }
70
71 public void testRosenbrock()
72 throws CostException, ConvergenceException {
73
74 CostFunction rosenbrock =
75 new CostFunction() {
76 public double cost(double[] x) {
77 ++count;
78 double a = x[1] - x[0] * x[0];
79 double b = 1.0 - x[0];
80 return 100 * a * a + b * b;
81 }
82 };
83
84 count = 0;
85 PointCostPair optimum =
86 new MultiDirectional().minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
87 new double[][] {
88 { -1.2, 1.0 }, { 0.9, 1.2 } , { 3.5, -2.3 }
89 });
90
91 assertTrue(count > 60);
92 assertTrue(optimum.getCost() > 0.01);
93
94 }
95
96 public void testPowell()
97 throws CostException, ConvergenceException {
98
99 CostFunction powell =
100 new CostFunction() {
101 public double cost(double[] x) {
102 ++count;
103 double a = x[0] + 10 * x[1];
104 double b = x[2] - x[3];
105 double c = x[1] - 2 * x[2];
106 double d = x[0] - x[3];
107 return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
108 }
109 };
110
111 count = 0;
112 PointCostPair optimum =
113 new MultiDirectional().minimize(powell, 1000, new ValueChecker(1.0e-3),
114 new double[] { 3.0, -1.0, 0.0, 1.0 },
115 new double[] { 4.0, 0.0, 1.0, 2.0 });
116 assertTrue(count > 850);
117 assertTrue(optimum.getCost() > 0.015);
118
119 }
120
121 private static class ValueChecker implements ConvergenceChecker {
122
123 public ValueChecker(double threshold) {
124 this.threshold = threshold;
125 }
126
127 public boolean converged(PointCostPair[] simplex) {
128 PointCostPair smallest = simplex[0];
129 PointCostPair largest = simplex[simplex.length - 1];
130 return (largest.getCost() - smallest.getCost()) < threshold;
131 }
132
133 private double threshold;
134
135 };
136
137 public static Test suite() {
138 return new TestSuite(MultiDirectionalTest.class);
139 }
140
141 private int count;
142
143 }