View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization;
19  
20  import java.util.Arrays;
21  import java.util.Comparator;
22  
23  import org.apache.commons.math.ConvergenceException;
24  import org.apache.commons.math.DimensionMismatchException;
25  import org.apache.commons.math.linear.RealMatrix;
26  import org.apache.commons.math.random.CorrelatedRandomVectorGenerator;
27  import org.apache.commons.math.random.JDKRandomGenerator;
28  import org.apache.commons.math.random.NotPositiveDefiniteMatrixException;
29  import org.apache.commons.math.random.RandomGenerator;
30  import org.apache.commons.math.random.RandomVectorGenerator;
31  import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
32  import org.apache.commons.math.random.UniformRandomGenerator;
33  import org.apache.commons.math.stat.descriptive.moment.VectorialCovariance;
34  import org.apache.commons.math.stat.descriptive.moment.VectorialMean;
35  
36  /** 
37   * This class implements simplex-based direct search optimization
38   * algorithms.
39   *
40   * <p>Direct search methods only use cost function values, they don't
41   * need derivatives and don't either try to compute approximation of
42   * the derivatives. According to a 1996 paper by Margaret H. Wright
43   * (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct
44   * Search Methods: Once Scorned, Now Respectable</a>), they are used
45   * when either the computation of the derivative is impossible (noisy
46   * functions, unpredictable dicontinuities) or difficult (complexity,
47   * computation cost). In the first cases, rather than an optimum, a
48   * <em>not too bad</em> point is desired. In the latter cases, an
49   * optimum is desired but cannot be reasonably found. In all cases
50   * direct search methods can be useful.</p>
51   *
52   * <p>Simplex-based direct search methods are based on comparison of
53   * the cost function values at the vertices of a simplex (which is a
54   * set of n+1 points in dimension n) that is updated by the algorithms
55   * steps.</p>
56   *
57   * <p>Minimization can be attempted either in single-start or in
58   * multi-start mode. Multi-start is a traditional way to try to avoid
59   * being trapped in a local minimum and miss the global minimum of a
60   * function. It can also be used to verify the convergence of an
61   * algorithm. The various multi-start-enabled <code>minimize</code>
62   * methods return the best minimum found after all starts, and the
63   * {@link #getMinima getMinima} method can be used to retrieve all
64   * minima from all starts (including the one already provided by the
65   * {@link #minimize(CostFunction, int, ConvergenceChecker, double[],
66   * double[]) minimize} method).</p>
67   *
68   * <p>This class is the base class performing the boilerplate simplex
69   * initialization and handling. The simplex update by itself is
70   * performed by the derived classes according to the implemented
71   * algorithms.</p>
72   *
73   * @see CostFunction
74   * @see NelderMead
75   * @see MultiDirectional
76   * @version $Revision: 628000 $ $Date: 2008-02-15 03:31:48 -0700 (Fri, 15 Feb 2008) $
77   * @since 1.2
78   */
79  public abstract class DirectSearchOptimizer {
80  
81      /** Simple constructor.
82       */
83      protected DirectSearchOptimizer() {
84      }
85  
86      /** Minimizes a cost function.
87       * <p>The initial simplex is built from two vertices that are
88       * considered to represent two opposite vertices of a box parallel
89       * to the canonical axes of the space. The simplex is the subset of
90       * vertices encountered while going from vertexA to vertexB
91       * traveling along the box edges only. This can be seen as a scaled
92       * regular simplex using the projected separation between the given
93       * points as the scaling factor along each coordinate axis.</p>
94       * <p>The optimization is performed in single-start mode.</p>
95       * @param f cost function
96       * @param maxEvaluations maximal number of function calls for each
97       * start (note that the number will be checked <em>after</em>
98       * complete simplices have been evaluated, this means that in some
99       * cases this number will be exceeded by a few units, depending on
100      * the dimension of the problem)
101      * @param checker object to use to check for convergence
102      * @param vertexA first vertex
103      * @param vertexB last vertex
104      * @return the point/cost pairs giving the minimal cost
105      * @exception CostException if the cost function throws one during
106      * the search
107      * @exception ConvergenceException if none of the starts did
108      * converge (it is not thrown if at least one start did converge)
109      */
110     public PointCostPair minimize(CostFunction f, int maxEvaluations,
111                                   ConvergenceChecker checker,
112                                   double[] vertexA, double[] vertexB)
113     throws CostException, ConvergenceException {
114 
115         // set up optimizer
116         buildSimplex(vertexA, vertexB);
117         setSingleStart();
118 
119         // compute minimum
120         return minimize(f, maxEvaluations, checker);
121 
122     }
123 
124     /** Minimizes a cost function.
125      * <p>The initial simplex is built from two vertices that are
126      * considered to represent two opposite vertices of a box parallel
127      * to the canonical axes of the space. The simplex is the subset of
128      * vertices encountered while going from vertexA to vertexB
129      * traveling along the box edges only. This can be seen as a scaled
130      * regular simplex using the projected separation between the given
131      * points as the scaling factor along each coordinate axis.</p>
132      * <p>The optimization is performed in multi-start mode.</p>
133      * @param f cost function
134      * @param maxEvaluations maximal number of function calls for each
135      * start (note that the number will be checked <em>after</em>
136      * complete simplices have been evaluated, this means that in some
137      * cases this number will be exceeded by a few units, depending on
138      * the dimension of the problem)
139      * @param checker object to use to check for convergence
140      * @param vertexA first vertex
141      * @param vertexB last vertex
142      * @param starts number of starts to perform (including the
143      * first one), multi-start is disabled if value is less than or
144      * equal to 1
145      * @param seed seed for the random vector generator
146      * @return the point/cost pairs giving the minimal cost
147      * @exception CostException if the cost function throws one during
148      * the search
149      * @exception ConvergenceException if none of the starts did
150      * converge (it is not thrown if at least one start did converge)
151      */
152     public PointCostPair minimize(CostFunction f, int maxEvaluations,
153                                   ConvergenceChecker checker,
154                                   double[] vertexA, double[] vertexB,
155                                   int starts, long seed)
156     throws CostException, ConvergenceException {
157 
158         // set up the simplex traveling around the box
159         buildSimplex(vertexA, vertexB);
160 
161         // we consider the simplex could have been produced by a generator
162         // having its mean value at the center of the box, the standard
163         // deviation along each axe being the corresponding half size
164         double[] mean              = new double[vertexA.length];
165         double[] standardDeviation = new double[vertexA.length];
166         for (int i = 0; i < vertexA.length; ++i) {
167             mean[i]              = 0.5 * (vertexA[i] + vertexB[i]);
168             standardDeviation[i] = 0.5 * Math.abs(vertexA[i] - vertexB[i]);
169         }
170 
171         RandomGenerator rg = new JDKRandomGenerator();
172         rg.setSeed(seed);
173         UniformRandomGenerator urg = new UniformRandomGenerator(rg);
174         RandomVectorGenerator rvg =
175             new UncorrelatedRandomVectorGenerator(mean, standardDeviation, urg);
176         setMultiStart(starts, rvg);
177 
178         // compute minimum
179         return minimize(f, maxEvaluations, checker);
180 
181     }
182 
183     /** Minimizes a cost function.
184      * <p>The simplex is built from all its vertices.</p>
185      * <p>The optimization is performed in single-start mode.</p>
186      * @param f cost function
187      * @param maxEvaluations maximal number of function calls for each
188      * start (note that the number will be checked <em>after</em>
189      * complete simplices have been evaluated, this means that in some
190      * cases this number will be exceeded by a few units, depending on
191      * the dimension of the problem)
192      * @param checker object to use to check for convergence
193      * @param vertices array containing all vertices of the simplex
194      * @return the point/cost pairs giving the minimal cost
195      * @exception CostException if the cost function throws one during
196      * the search
197      * @exception ConvergenceException if none of the starts did
198      * converge (it is not thrown if at least one start did converge)
199      */
200     public PointCostPair minimize(CostFunction f, int maxEvaluations,
201                                   ConvergenceChecker checker,
202                                   double[][] vertices)
203     throws CostException, ConvergenceException {
204 
205         // set up optimizer
206         buildSimplex(vertices);
207         setSingleStart();
208 
209         // compute minimum
210         return minimize(f, maxEvaluations, checker);
211 
212     }
213 
214     /** Minimizes a cost function.
215      * <p>The simplex is built from all its vertices.</p>
216      * <p>The optimization is performed in multi-start mode.</p>
217      * @param f cost function
218      * @param maxEvaluations maximal number of function calls for each
219      * start (note that the number will be checked <em>after</em>
220      * complete simplices have been evaluated, this means that in some
221      * cases this number will be exceeded by a few units, depending on
222      * the dimension of the problem)
223      * @param checker object to use to check for convergence
224      * @param vertices array containing all vertices of the simplex
225      * @param starts number of starts to perform (including the
226      * first one), multi-start is disabled if value is less than or
227      * equal to 1
228      * @param seed seed for the random vector generator
229      * @return the point/cost pairs giving the minimal cost
230      * @exception NotPositiveDefiniteMatrixException if the vertices
231      * array is degenerated
232      * @exception CostException if the cost function throws one during
233      * the search
234      * @exception ConvergenceException if none of the starts did
235      * converge (it is not thrown if at least one start did converge)
236      */
237     public PointCostPair minimize(CostFunction f, int maxEvaluations,
238                                   ConvergenceChecker checker,
239                                   double[][] vertices,
240                                   int starts, long seed)
241     throws NotPositiveDefiniteMatrixException,
242     CostException, ConvergenceException {
243 
244         try {
245             // store the points into the simplex
246             buildSimplex(vertices);
247 
248             // compute the statistical properties of the simplex points
249             VectorialMean meanStat = new VectorialMean(vertices[0].length);
250             VectorialCovariance covStat = new VectorialCovariance(vertices[0].length, true);
251             for (int i = 0; i < vertices.length; ++i) {
252                 meanStat.increment(vertices[i]);
253                 covStat.increment(vertices[i]);
254             }
255             double[] mean = meanStat.getResult();
256             RealMatrix covariance = covStat.getResult();
257             
258 
259             RandomGenerator rg = new JDKRandomGenerator();
260             rg.setSeed(seed);
261             RandomVectorGenerator rvg =
262                 new CorrelatedRandomVectorGenerator(mean,
263                                                     covariance, 1.0e-12 * covariance.getNorm(),
264                                                     new UniformRandomGenerator(rg));
265             setMultiStart(starts, rvg);
266 
267             // compute minimum
268             return minimize(f, maxEvaluations, checker);
269 
270         } catch (DimensionMismatchException dme) {
271             // this should not happen
272             throw new RuntimeException("internal error");
273         }
274 
275     }
276 
277     /** Minimizes a cost function.
278      * <p>The simplex is built randomly.</p>
279      * <p>The optimization is performed in single-start mode.</p>
280      * @param f cost function
281      * @param maxEvaluations maximal number of function calls for each
282      * start (note that the number will be checked <em>after</em>
283      * complete simplices have been evaluated, this means that in some
284      * cases this number will be exceeded by a few units, depending on
285      * the dimension of the problem)
286      * @param checker object to use to check for convergence
287      * @param generator random vector generator
288      * @return the point/cost pairs giving the minimal cost
289      * @exception CostException if the cost function throws one during
290      * the search
291      * @exception ConvergenceException if none of the starts did
292      * converge (it is not thrown if at least one start did converge)
293      */
294     public PointCostPair minimize(CostFunction f, int maxEvaluations,
295                                   ConvergenceChecker checker,
296                                   RandomVectorGenerator generator)
297     throws CostException, ConvergenceException {
298 
299         // set up optimizer
300         buildSimplex(generator);
301         setSingleStart();
302 
303         // compute minimum
304         return minimize(f, maxEvaluations, checker);
305 
306     }
307 
308     /** Minimizes a cost function.
309      * <p>The simplex is built randomly.</p>
310      * <p>The optimization is performed in multi-start mode.</p>
311      * @param f cost function
312      * @param maxEvaluations maximal number of function calls for each
313      * start (note that the number will be checked <em>after</em>
314      * complete simplices have been evaluated, this means that in some
315      * cases this number will be exceeded by a few units, depending on
316      * the dimension of the problem)
317      * @param checker object to use to check for convergence
318      * @param generator random vector generator
319      * @param starts number of starts to perform (including the
320      * first one), multi-start is disabled if value is less than or
321      * equal to 1
322      * @return the point/cost pairs giving the minimal cost
323      * @exception CostException if the cost function throws one during
324      * the search
325      * @exception ConvergenceException if none of the starts did
326      * converge (it is not thrown if at least one start did converge)
327      */
328     public PointCostPair minimize(CostFunction f, int maxEvaluations,
329                                   ConvergenceChecker checker,
330                                   RandomVectorGenerator generator,
331                                   int starts)
332     throws CostException, ConvergenceException {
333 
334         // set up optimizer
335         buildSimplex(generator);
336         setMultiStart(starts, generator);
337 
338         // compute minimum
339         return minimize(f, maxEvaluations, checker);
340 
341     }
342 
343     /** Build a simplex from two extreme vertices.
344      * <p>The two vertices are considered to represent two opposite
345      * vertices of a box parallel to the canonical axes of the
346      * space. The simplex is the subset of vertices encountered while
347      * going from vertexA to vertexB traveling along the box edges
348      * only. This can be seen as a scaled regular simplex using the
349      * projected separation between the given points as the scaling
350      * factor along each coordinate axis.</p>
351      * @param vertexA first vertex
352      * @param vertexB last vertex
353      */
354     private void buildSimplex(double[] vertexA, double[] vertexB) {
355 
356         int n = vertexA.length;
357         simplex = new PointCostPair[n + 1];
358 
359         // set up the simplex traveling around the box
360         for (int i = 0; i <= n; ++i) {
361             double[] vertex = new double[n];
362             if (i > 0) {
363                 System.arraycopy(vertexB, 0, vertex, 0, i);
364             }
365             if (i < n) {
366                 System.arraycopy(vertexA, i, vertex, i, n - i);
367             }
368             simplex[i] = new PointCostPair(vertex, Double.NaN);
369         }
370 
371     }
372 
373     /** Build a simplex from all its points.
374      * @param vertices array containing all vertices of the simplex
375      */
376     private void buildSimplex(double[][] vertices) {
377         int n = vertices.length - 1;
378         simplex = new PointCostPair[n + 1];
379         for (int i = 0; i <= n; ++i) {
380             simplex[i] = new PointCostPair(vertices[i], Double.NaN);
381         }
382     }
383 
384     /** Build a simplex randomly.
385      * @param generator random vector generator
386      */
387     private void buildSimplex(RandomVectorGenerator generator) {
388 
389         // use first vector size to compute the number of points
390         double[] vertex = generator.nextVector();
391         int n = vertex.length;
392         simplex = new PointCostPair[n + 1];
393         simplex[0] = new PointCostPair(vertex, Double.NaN);
394 
395         // fill up the vertex
396         for (int i = 1; i <= n; ++i) {
397             simplex[i] = new PointCostPair(generator.nextVector(), Double.NaN);
398         }
399 
400     }
401 
402     /** Set up single-start mode.
403      */
404     private void setSingleStart() {
405         starts    = 1;
406         generator = null;
407         minima    = null;
408     }
409 
410     /** Set up multi-start mode.
411      * @param starts number of starts to perform (including the
412      * first one), multi-start is disabled if value is less than or
413      * equal to 1
414      * @param generator random vector generator to use for restarts
415      */
416     private void setMultiStart(int starts, RandomVectorGenerator generator) {
417         if (starts < 2) {
418             this.starts    = 1;
419             this.generator = null;
420             minima         = null;
421         } else {
422             this.starts    = starts;
423             this.generator = generator;
424             minima         = null;
425         }
426     }
427 
428     /** Get all the minima found during the last call to {@link
429      * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
430      * minimize}.
431      * <p>The optimizer stores all the minima found during a set of
432      * restarts when multi-start mode is enabled. The {@link
433      * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
434      * minimize} method returns the best point only. This method
435      * returns all the points found at the end of each starts, including
436      * the best one already returned by the {@link #minimize(CostFunction,
437      * int, ConvergenceChecker, double[], double[]) minimize} method.
438      * The array as one element for each start as specified in the constructor
439      * (it has one element only if optimizer has been set up for single-start).</p>
440      * <p>The array containing the minima is ordered with the results
441      * from the runs that did converge first, sorted from lowest to
442      * highest minimum cost, and null elements corresponding to the runs
443      * that did not converge (all elements will be null if the {@link
444      * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
445      * minimize} method did throw a {@link ConvergenceException
446      * ConvergenceException}).</p>
447      * @return array containing the minima, or null if {@link
448      * #minimize(CostFunction, int, ConvergenceChecker, double[], double[])
449      * minimize} has not been called
450      */
451     public PointCostPair[] getMinima() {
452         return (PointCostPair[]) minima.clone();
453     }
454 
455     /** Minimizes a cost function.
456      * @param f cost function
457      * @param maxEvaluations maximal number of function calls for each
458      * start (note that the number will be checked <em>after</em>
459      * complete simplices have been evaluated, this means that in some
460      * cases this number will be exceeded by a few units, depending on
461      * the dimension of the problem)
462      * @param checker object to use to check for convergence
463      * @return the point/cost pairs giving the minimal cost
464      * @exception CostException if the cost function throws one during
465      * the search
466      * @exception ConvergenceException if none of the starts did
467      * converge (it is not thrown if at least one start did converge)
468      */
469     private PointCostPair minimize(CostFunction f, int maxEvaluations,
470                                     ConvergenceChecker checker)
471     throws CostException, ConvergenceException {
472 
473         this.f = f;
474         minima = new PointCostPair[starts];
475 
476         // multi-start loop
477         for (int i = 0; i < starts; ++i) {
478 
479             evaluations = 0;
480             evaluateSimplex();
481 
482             for (boolean loop = true; loop;) {
483                 if (checker.converged(simplex)) {
484                     // we have found a minimum
485                     minima[i] = simplex[0];
486                     loop = false;
487                 } else if (evaluations >= maxEvaluations) {
488                     // this start did not converge, try a new one
489                     minima[i] = null;
490                     loop = false;
491                 } else {
492                     iterateSimplex();
493                 }
494             }
495 
496             if (i < (starts - 1)) {
497                 // restart
498                 buildSimplex(generator);
499             }
500 
501         }
502 
503         // sort the minima from lowest cost to highest cost, followed by
504         // null elements
505         Arrays.sort(minima, pointCostPairComparator);
506 
507         // return the found point given the lowest cost
508         if (minima[0] == null) {
509             throw new ConvergenceException("none of the {0} start points" +
510                                            " lead to convergence",
511                                            new Object[] {
512                                              Integer.toString(starts)
513                                            });
514         }
515         return minima[0];
516 
517     }
518 
519     /** Compute the next simplex of the algorithm.
520      * @exception CostException if the function cannot be evaluated at
521      * some point
522      */
523     protected abstract void iterateSimplex()
524     throws CostException;
525 
526     /** Evaluate the cost on one point.
527      * <p>A side effect of this method is to count the number of
528      * function evaluations</p>
529      * @param x point on which the cost function should be evaluated
530      * @return cost at the given point
531      * @exception CostException if no cost can be computed for the parameters
532      */
533     protected double evaluateCost(double[] x)
534     throws CostException {
535         evaluations++;
536         return f.cost(x);
537     }
538 
539     /** Evaluate all the non-evaluated points of the simplex.
540      * @exception CostException if no cost can be computed for the parameters
541      */
542     protected void evaluateSimplex()
543     throws CostException {
544 
545         // evaluate the cost at all non-evaluated simplex points
546         for (int i = 0; i < simplex.length; ++i) {
547             PointCostPair pair = simplex[i];
548             if (Double.isNaN(pair.getCost())) {
549                 simplex[i] = new PointCostPair(pair.getPoint(), evaluateCost(pair.getPoint()));
550             }
551         }
552 
553         // sort the simplex from lowest cost to highest cost
554         Arrays.sort(simplex, pointCostPairComparator);
555 
556     }
557 
558     /** Replace the worst point of the simplex by a new point.
559      * @param pointCostPair point to insert
560      */
561     protected void replaceWorstPoint(PointCostPair pointCostPair) {
562         int n = simplex.length - 1;
563         for (int i = 0; i < n; ++i) {
564             if (simplex[i].getCost() > pointCostPair.getCost()) {
565                 PointCostPair tmp = simplex[i];
566                 simplex[i]        = pointCostPair;
567                 pointCostPair     = tmp;
568             }
569         }
570         simplex[n] = pointCostPair;
571     }
572 
573     /** Comparator for {@link PointCostPair PointCostPair} objects. */
574     private static Comparator pointCostPairComparator = new Comparator() {
575         public int compare(Object o1, Object o2) {
576             if (o1 == null) {
577                 return (o2 == null) ? 0 : +1;
578             } else if (o2 == null) {
579                 return -1;
580             }
581             double cost1 = ((PointCostPair) o1).getCost();
582             double cost2 = ((PointCostPair) o2).getCost();
583             return (cost1 < cost2) ? -1 : ((o1 == o2) ? 0 : +1);
584         }
585     };
586 
587     /** Simplex. */
588     protected PointCostPair[] simplex;
589 
590     /** Cost function. */
591     private CostFunction f;
592 
593     /** Number of evaluations already performed. */
594     private int evaluations;
595 
596     /** Number of starts to go. */
597     private int starts;
598 
599     /** Random generator for multi-start. */
600     private RandomVectorGenerator generator;
601 
602     /** Found minima. */
603     private PointCostPair[] minima;
604 
605 }