1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization;
19  
20  import org.apache.commons.math.optimization.ConvergenceChecker;
21  import org.apache.commons.math.optimization.CostException;
22  import org.apache.commons.math.optimization.CostFunction;
23  import org.apache.commons.math.optimization.MultiDirectional;
24  import org.apache.commons.math.ConvergenceException;
25  import org.apache.commons.math.optimization.PointCostPair;
26  
27  import junit.framework.*;
28  
29  public class MultiDirectionalTest
30    extends TestCase {
31  
32    public MultiDirectionalTest(String name) {
33      super(name);
34    }
35  
36    public void testCostExceptions() throws ConvergenceException {
37        CostFunction wrong =
38            new CostFunction() {
39              public double cost(double[] x) throws CostException {
40                  if (x[0] < 0) {
41                      throw new CostException("{0}", new Object[] { "oops"});
42                  } else if (x[0] > 1) {
43                      throw new CostException(new RuntimeException("oops"));
44                  } else {
45                      return x[0] * (1 - x[0]);
46                  }
47              }
48        };
49        try {
50            new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
51                                                    new double[] { -0.5 }, new double[] { 0.5 });
52            fail("an exception should have been thrown");
53        } catch (CostException ce) {
54            // expected behavior
55            assertNull(ce.getCause());
56        } catch (Exception e) {
57            fail("wrong exception caught: " + e.getMessage());
58        } 
59        try {
60            new MultiDirectional(1.9, 0.4).minimize(wrong, 10, new ValueChecker(1.0e-3),
61                                                    new double[] { 0.5 }, new double[] { 1.5 });
62            fail("an exception should have been thrown");
63        } catch (CostException ce) {
64            // expected behavior
65            assertNotNull(ce.getCause());
66        } catch (Exception e) {
67            fail("wrong exception caught: " + e.getMessage());
68        } 
69    }
70  
71    public void testRosenbrock()
72      throws CostException, ConvergenceException {
73  
74      CostFunction rosenbrock =
75        new CostFunction() {
76          public double cost(double[] x) {
77            ++count;
78            double a = x[1] - x[0] * x[0];
79            double b = 1.0 - x[0];
80            return 100 * a * a + b * b;
81          }
82        };
83  
84      count = 0;
85      PointCostPair optimum =
86        new MultiDirectional().minimize(rosenbrock, 100, new ValueChecker(1.0e-3),
87                                        new double[][] {
88                                          { -1.2,  1.0 }, { 0.9, 1.2 } , {  3.5, -2.3 }
89                                        });
90  
91      assertTrue(count > 60);
92      assertTrue(optimum.getCost() > 0.01);
93  
94    }
95  
96    public void testPowell()
97      throws CostException, ConvergenceException {
98  
99      CostFunction powell =
100       new CostFunction() {
101         public double cost(double[] x) {
102           ++count;
103           double a = x[0] + 10 * x[1];
104           double b = x[2] - x[3];
105           double c = x[1] - 2 * x[2];
106           double d = x[0] - x[3];
107           return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
108         }
109       };
110 
111     count = 0;
112     PointCostPair optimum =
113       new MultiDirectional().minimize(powell, 1000, new ValueChecker(1.0e-3),
114                                       new double[] {  3.0, -1.0, 0.0, 1.0 },
115                                       new double[] {  4.0,  0.0, 1.0, 2.0 });
116     assertTrue(count > 850);
117     assertTrue(optimum.getCost() > 0.015);
118 
119   }
120 
121   private static class ValueChecker implements ConvergenceChecker {
122 
123     public ValueChecker(double threshold) {
124       this.threshold = threshold;
125     }
126 
127     public boolean converged(PointCostPair[] simplex) {
128       PointCostPair smallest = simplex[0];
129       PointCostPair largest  = simplex[simplex.length - 1];
130       return (largest.getCost() - smallest.getCost()) < threshold;
131     }
132 
133     private double threshold;
134 
135   };
136 
137   public static Test suite() {
138     return new TestSuite(MultiDirectionalTest.class);
139   }
140 
141   private int count;
142 
143 }