View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization;
19  
20  /** 
21   * This class implements the Nelder-Mead direct search method.
22   *
23   * @version $Revision: 620312 $ $Date: 2008-02-10 12:28:59 -0700 (Sun, 10 Feb 2008) $
24   * @see MultiDirectional
25   * @since 1.2
26   */
27  public class NelderMead
28    extends DirectSearchOptimizer {
29  
30    /** Build a Nelder-Mead optimizer with default coefficients.
31     * <p>The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
32     * for both gamma and sigma.</p>
33     */
34    public NelderMead() {
35      super();
36      this.rho   = 1.0;
37      this.khi   = 2.0;
38      this.gamma = 0.5;
39      this.sigma = 0.5;
40    }
41  
42    /** Build a Nelder-Mead optimizer with specified coefficients.
43     * @param rho reflection coefficient
44     * @param khi expansion coefficient
45     * @param gamma contraction coefficient
46     * @param sigma shrinkage coefficient
47     */
48    public NelderMead(double rho, double khi, double gamma, double sigma) {
49      super();
50      this.rho   = rho;
51      this.khi   = khi;
52      this.gamma = gamma;
53      this.sigma = sigma;
54    }
55  
56    /** Compute the next simplex of the algorithm.
57     * @exception CostException if the function cannot be evaluated at
58     * some point
59     */
60    protected void iterateSimplex()
61      throws CostException {
62  
63      // the simplex has n+1 point if dimension is n
64      int n = simplex.length - 1;
65  
66      // interesting costs
67      double   smallest      = simplex[0].getCost();
68      double   secondLargest = simplex[n-1].getCost();
69      double   largest       = simplex[n].getCost();
70      double[] xLargest      = simplex[n].getPoint();
71  
72      // compute the centroid of the best vertices
73      // (dismissing the worst point at index n)
74      double[] centroid = new double[n];
75      for (int i = 0; i < n; ++i) {
76        double[] x = simplex[i].getPoint();
77        for (int j = 0; j < n; ++j) {
78          centroid[j] += x[j];
79        }
80      }
81      double scaling = 1.0 / n;
82      for (int j = 0; j < n; ++j) {
83        centroid[j] *= scaling;
84      }
85  
86      // compute the reflection point
87      double[] xR       = new double[n];
88      for (int j = 0; j < n; ++j) {
89        xR[j] = centroid[j] + rho * (centroid[j] - xLargest[j]);
90      }
91      double costR = evaluateCost(xR);
92  
93      if ((smallest <= costR) && (costR < secondLargest)) {
94  
95        // accept the reflected point
96        replaceWorstPoint(new PointCostPair(xR, costR));
97  
98      } else if (costR < smallest) {
99  
100       // compute the expansion point
101       double[] xE = new double[n];
102       for (int j = 0; j < n; ++j) {
103         xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
104       }
105       double costE = evaluateCost(xE);
106 
107       if (costE < costR) {
108         // accept the expansion point
109         replaceWorstPoint(new PointCostPair(xE, costE));
110       } else {
111         // accept the reflected point
112         replaceWorstPoint(new PointCostPair(xR, costR));
113       }
114 
115     } else {
116 
117       if (costR < largest) {
118 
119         // perform an outside contraction
120         double[] xC = new double[n];
121         for (int j = 0; j < n; ++j) {
122           xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
123         }
124         double costC = evaluateCost(xC);
125 
126         if (costC <= costR) {
127           // accept the contraction point
128           replaceWorstPoint(new PointCostPair(xC, costC));
129           return;
130         }
131 
132       } else {
133 
134         // perform an inside contraction
135         double[] xC = new double[n];
136         for (int j = 0; j < n; ++j) {
137           xC[j] = centroid[j] - gamma * (centroid[j] - xLargest[j]);
138         }
139         double costC = evaluateCost(xC);
140 
141         if (costC < largest) {
142           // accept the contraction point
143           replaceWorstPoint(new PointCostPair(xC, costC));
144           return;
145         }
146 
147       }
148 
149       // perform a shrink
150       double[] xSmallest = simplex[0].getPoint();
151       for (int i = 1; i < simplex.length; ++i) {
152         double[] x = simplex[i].getPoint();
153         for (int j = 0; j < n; ++j) {
154           x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
155         }
156         simplex[i] = new PointCostPair(x, Double.NaN);
157       }
158       evaluateSimplex();
159 
160     }
161 
162   }
163 
164   /** Reflection coefficient. */
165   private double rho;
166 
167   /** Expansion coefficient. */
168   private double khi;
169 
170   /** Contraction coefficient. */
171   private double gamma;
172 
173   /** Shrinkage coefficient. */
174   private double sigma;
175 
176 }