1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.optimization;
19
20
21
22
23
24
25
26
27 public class NelderMead
28 extends DirectSearchOptimizer {
29
30
31
32
33
34 public NelderMead() {
35 super();
36 this.rho = 1.0;
37 this.khi = 2.0;
38 this.gamma = 0.5;
39 this.sigma = 0.5;
40 }
41
42
43
44
45
46
47
48 public NelderMead(double rho, double khi, double gamma, double sigma) {
49 super();
50 this.rho = rho;
51 this.khi = khi;
52 this.gamma = gamma;
53 this.sigma = sigma;
54 }
55
56
57
58
59
60 protected void iterateSimplex()
61 throws CostException {
62
63
64 int n = simplex.length - 1;
65
66
67 double smallest = simplex[0].getCost();
68 double secondLargest = simplex[n-1].getCost();
69 double largest = simplex[n].getCost();
70 double[] xLargest = simplex[n].getPoint();
71
72
73
74 double[] centroid = new double[n];
75 for (int i = 0; i < n; ++i) {
76 double[] x = simplex[i].getPoint();
77 for (int j = 0; j < n; ++j) {
78 centroid[j] += x[j];
79 }
80 }
81 double scaling = 1.0 / n;
82 for (int j = 0; j < n; ++j) {
83 centroid[j] *= scaling;
84 }
85
86
87 double[] xR = new double[n];
88 for (int j = 0; j < n; ++j) {
89 xR[j] = centroid[j] + rho * (centroid[j] - xLargest[j]);
90 }
91 double costR = evaluateCost(xR);
92
93 if ((smallest <= costR) && (costR < secondLargest)) {
94
95
96 replaceWorstPoint(new PointCostPair(xR, costR));
97
98 } else if (costR < smallest) {
99
100
101 double[] xE = new double[n];
102 for (int j = 0; j < n; ++j) {
103 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
104 }
105 double costE = evaluateCost(xE);
106
107 if (costE < costR) {
108
109 replaceWorstPoint(new PointCostPair(xE, costE));
110 } else {
111
112 replaceWorstPoint(new PointCostPair(xR, costR));
113 }
114
115 } else {
116
117 if (costR < largest) {
118
119
120 double[] xC = new double[n];
121 for (int j = 0; j < n; ++j) {
122 xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
123 }
124 double costC = evaluateCost(xC);
125
126 if (costC <= costR) {
127
128 replaceWorstPoint(new PointCostPair(xC, costC));
129 return;
130 }
131
132 } else {
133
134
135 double[] xC = new double[n];
136 for (int j = 0; j < n; ++j) {
137 xC[j] = centroid[j] - gamma * (centroid[j] - xLargest[j]);
138 }
139 double costC = evaluateCost(xC);
140
141 if (costC < largest) {
142
143 replaceWorstPoint(new PointCostPair(xC, costC));
144 return;
145 }
146
147 }
148
149
150 double[] xSmallest = simplex[0].getPoint();
151 for (int i = 1; i < simplex.length; ++i) {
152 double[] x = simplex[i].getPoint();
153 for (int j = 0; j < n; ++j) {
154 x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
155 }
156 simplex[i] = new PointCostPair(x, Double.NaN);
157 }
158 evaluateSimplex();
159
160 }
161
162 }
163
164
165 private double rho;
166
167
168 private double khi;
169
170
171 private double gamma;
172
173
174 private double sigma;
175
176 }