1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math.optimization; 19 20 import java.util.Arrays; 21 import java.util.Comparator; 22 23 import org.apache.commons.math.ConvergenceException; 24 import org.apache.commons.math.DimensionMismatchException; 25 import org.apache.commons.math.linear.RealMatrix; 26 import org.apache.commons.math.random.CorrelatedRandomVectorGenerator; 27 import org.apache.commons.math.random.JDKRandomGenerator; 28 import org.apache.commons.math.random.NotPositiveDefiniteMatrixException; 29 import org.apache.commons.math.random.RandomGenerator; 30 import org.apache.commons.math.random.RandomVectorGenerator; 31 import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator; 32 import org.apache.commons.math.random.UniformRandomGenerator; 33 import org.apache.commons.math.stat.descriptive.moment.VectorialCovariance; 34 import org.apache.commons.math.stat.descriptive.moment.VectorialMean; 35 36 /** 37 * This class implements simplex-based direct search optimization 38 * algorithms. 39 * 40 * <p>Direct search methods only use cost function values, they don't 41 * need derivatives and don't either try to compute approximation of 42 * the derivatives. According to a 1996 paper by Margaret H. Wright 43 * (<a href="http://cm.bell-labs.com/cm/cs/doc/96/4-02.ps.gz">Direct 44 * Search Methods: Once Scorned, Now Respectable</a>), they are used 45 * when either the computation of the derivative is impossible (noisy 46 * functions, unpredictable dicontinuities) or difficult (complexity, 47 * computation cost). In the first cases, rather than an optimum, a 48 * <em>not too bad</em> point is desired. In the latter cases, an 49 * optimum is desired but cannot be reasonably found. In all cases 50 * direct search methods can be useful.</p> 51 * 52 * <p>Simplex-based direct search methods are based on comparison of 53 * the cost function values at the vertices of a simplex (which is a 54 * set of n+1 points in dimension n) that is updated by the algorithms 55 * steps.</p> 56 * 57 * <p>Minimization can be attempted either in single-start or in 58 * multi-start mode. Multi-start is a traditional way to try to avoid 59 * being trapped in a local minimum and miss the global minimum of a 60 * function. It can also be used to verify the convergence of an 61 * algorithm. The various multi-start-enabled <code>minimize</code> 62 * methods return the best minimum found after all starts, and the 63 * {@link #getMinima getMinima} method can be used to retrieve all 64 * minima from all starts (including the one already provided by the 65 * {@link #minimize(CostFunction, int, ConvergenceChecker, double[], 66 * double[]) minimize} method).</p> 67 * 68 * <p>This class is the base class performing the boilerplate simplex 69 * initialization and handling. The simplex update by itself is 70 * performed by the derived classes according to the implemented 71 * algorithms.</p> 72 * 73 * @see CostFunction 74 * @see NelderMead 75 * @see MultiDirectional 76 * @version $Revision: 628000 $ $Date: 2008-02-15 03:31:48 -0700 (Fri, 15 Feb 2008) $ 77 * @since 1.2 78 */ 79 public abstract class DirectSearchOptimizer { 80 81 /** Simple constructor. 82 */ 83 protected DirectSearchOptimizer() { 84 } 85 86 /** Minimizes a cost function. 87 * <p>The initial simplex is built from two vertices that are 88 * considered to represent two opposite vertices of a box parallel 89 * to the canonical axes of the space. The simplex is the subset of 90 * vertices encountered while going from vertexA to vertexB 91 * traveling along the box edges only. This can be seen as a scaled 92 * regular simplex using the projected separation between the given 93 * points as the scaling factor along each coordinate axis.</p> 94 * <p>The optimization is performed in single-start mode.</p> 95 * @param f cost function 96 * @param maxEvaluations maximal number of function calls for each 97 * start (note that the number will be checked <em>after</em> 98 * complete simplices have been evaluated, this means that in some 99 * cases this number will be exceeded by a few units, depending on 100 * the dimension of the problem) 101 * @param checker object to use to check for convergence 102 * @param vertexA first vertex 103 * @param vertexB last vertex 104 * @return the point/cost pairs giving the minimal cost 105 * @exception CostException if the cost function throws one during 106 * the search 107 * @exception ConvergenceException if none of the starts did 108 * converge (it is not thrown if at least one start did converge) 109 */ 110 public PointCostPair minimize(CostFunction f, int maxEvaluations, 111 ConvergenceChecker checker, 112 double[] vertexA, double[] vertexB) 113 throws CostException, ConvergenceException { 114 115 // set up optimizer 116 buildSimplex(vertexA, vertexB); 117 setSingleStart(); 118 119 // compute minimum 120 return minimize(f, maxEvaluations, checker); 121 122 } 123 124 /** Minimizes a cost function. 125 * <p>The initial simplex is built from two vertices that are 126 * considered to represent two opposite vertices of a box parallel 127 * to the canonical axes of the space. The simplex is the subset of 128 * vertices encountered while going from vertexA to vertexB 129 * traveling along the box edges only. This can be seen as a scaled 130 * regular simplex using the projected separation between the given 131 * points as the scaling factor along each coordinate axis.</p> 132 * <p>The optimization is performed in multi-start mode.</p> 133 * @param f cost function 134 * @param maxEvaluations maximal number of function calls for each 135 * start (note that the number will be checked <em>after</em> 136 * complete simplices have been evaluated, this means that in some 137 * cases this number will be exceeded by a few units, depending on 138 * the dimension of the problem) 139 * @param checker object to use to check for convergence 140 * @param vertexA first vertex 141 * @param vertexB last vertex 142 * @param starts number of starts to perform (including the 143 * first one), multi-start is disabled if value is less than or 144 * equal to 1 145 * @param seed seed for the random vector generator 146 * @return the point/cost pairs giving the minimal cost 147 * @exception CostException if the cost function throws one during 148 * the search 149 * @exception ConvergenceException if none of the starts did 150 * converge (it is not thrown if at least one start did converge) 151 */ 152 public PointCostPair minimize(CostFunction f, int maxEvaluations, 153 ConvergenceChecker checker, 154 double[] vertexA, double[] vertexB, 155 int starts, long seed) 156 throws CostException, ConvergenceException { 157 158 // set up the simplex traveling around the box 159 buildSimplex(vertexA, vertexB); 160 161 // we consider the simplex could have been produced by a generator 162 // having its mean value at the center of the box, the standard 163 // deviation along each axe being the corresponding half size 164 double[] mean = new double[vertexA.length]; 165 double[] standardDeviation = new double[vertexA.length]; 166 for (int i = 0; i < vertexA.length; ++i) { 167 mean[i] = 0.5 * (vertexA[i] + vertexB[i]); 168 standardDeviation[i] = 0.5 * Math.abs(vertexA[i] - vertexB[i]); 169 } 170 171 RandomGenerator rg = new JDKRandomGenerator(); 172 rg.setSeed(seed); 173 UniformRandomGenerator urg = new UniformRandomGenerator(rg); 174 RandomVectorGenerator rvg = 175 new UncorrelatedRandomVectorGenerator(mean, standardDeviation, urg); 176 setMultiStart(starts, rvg); 177 178 // compute minimum 179 return minimize(f, maxEvaluations, checker); 180 181 } 182 183 /** Minimizes a cost function. 184 * <p>The simplex is built from all its vertices.</p> 185 * <p>The optimization is performed in single-start mode.</p> 186 * @param f cost function 187 * @param maxEvaluations maximal number of function calls for each 188 * start (note that the number will be checked <em>after</em> 189 * complete simplices have been evaluated, this means that in some 190 * cases this number will be exceeded by a few units, depending on 191 * the dimension of the problem) 192 * @param checker object to use to check for convergence 193 * @param vertices array containing all vertices of the simplex 194 * @return the point/cost pairs giving the minimal cost 195 * @exception CostException if the cost function throws one during 196 * the search 197 * @exception ConvergenceException if none of the starts did 198 * converge (it is not thrown if at least one start did converge) 199 */ 200 public PointCostPair minimize(CostFunction f, int maxEvaluations, 201 ConvergenceChecker checker, 202 double[][] vertices) 203 throws CostException, ConvergenceException { 204 205 // set up optimizer 206 buildSimplex(vertices); 207 setSingleStart(); 208 209 // compute minimum 210 return minimize(f, maxEvaluations, checker); 211 212 } 213 214 /** Minimizes a cost function. 215 * <p>The simplex is built from all its vertices.</p> 216 * <p>The optimization is performed in multi-start mode.</p> 217 * @param f cost function 218 * @param maxEvaluations maximal number of function calls for each 219 * start (note that the number will be checked <em>after</em> 220 * complete simplices have been evaluated, this means that in some 221 * cases this number will be exceeded by a few units, depending on 222 * the dimension of the problem) 223 * @param checker object to use to check for convergence 224 * @param vertices array containing all vertices of the simplex 225 * @param starts number of starts to perform (including the 226 * first one), multi-start is disabled if value is less than or 227 * equal to 1 228 * @param seed seed for the random vector generator 229 * @return the point/cost pairs giving the minimal cost 230 * @exception NotPositiveDefiniteMatrixException if the vertices 231 * array is degenerated 232 * @exception CostException if the cost function throws one during 233 * the search 234 * @exception ConvergenceException if none of the starts did 235 * converge (it is not thrown if at least one start did converge) 236 */ 237 public PointCostPair minimize(CostFunction f, int maxEvaluations, 238 ConvergenceChecker checker, 239 double[][] vertices, 240 int starts, long seed) 241 throws NotPositiveDefiniteMatrixException, 242 CostException, ConvergenceException { 243 244 try { 245 // store the points into the simplex 246 buildSimplex(vertices); 247 248 // compute the statistical properties of the simplex points 249 VectorialMean meanStat = new VectorialMean(vertices[0].length); 250 VectorialCovariance covStat = new VectorialCovariance(vertices[0].length, true); 251 for (int i = 0; i < vertices.length; ++i) { 252 meanStat.increment(vertices[i]); 253 covStat.increment(vertices[i]); 254 } 255 double[] mean = meanStat.getResult(); 256 RealMatrix covariance = covStat.getResult(); 257 258 259 RandomGenerator rg = new JDKRandomGenerator(); 260 rg.setSeed(seed); 261 RandomVectorGenerator rvg = 262 new CorrelatedRandomVectorGenerator(mean, 263 covariance, 1.0e-12 * covariance.getNorm(), 264 new UniformRandomGenerator(rg)); 265 setMultiStart(starts, rvg); 266 267 // compute minimum 268 return minimize(f, maxEvaluations, checker); 269 270 } catch (DimensionMismatchException dme) { 271 // this should not happen 272 throw new RuntimeException("internal error"); 273 } 274 275 } 276 277 /** Minimizes a cost function. 278 * <p>The simplex is built randomly.</p> 279 * <p>The optimization is performed in single-start mode.</p> 280 * @param f cost function 281 * @param maxEvaluations maximal number of function calls for each 282 * start (note that the number will be checked <em>after</em> 283 * complete simplices have been evaluated, this means that in some 284 * cases this number will be exceeded by a few units, depending on 285 * the dimension of the problem) 286 * @param checker object to use to check for convergence 287 * @param generator random vector generator 288 * @return the point/cost pairs giving the minimal cost 289 * @exception CostException if the cost function throws one during 290 * the search 291 * @exception ConvergenceException if none of the starts did 292 * converge (it is not thrown if at least one start did converge) 293 */ 294 public PointCostPair minimize(CostFunction f, int maxEvaluations, 295 ConvergenceChecker checker, 296 RandomVectorGenerator generator) 297 throws CostException, ConvergenceException { 298 299 // set up optimizer 300 buildSimplex(generator); 301 setSingleStart(); 302 303 // compute minimum 304 return minimize(f, maxEvaluations, checker); 305 306 } 307 308 /** Minimizes a cost function. 309 * <p>The simplex is built randomly.</p> 310 * <p>The optimization is performed in multi-start mode.</p> 311 * @param f cost function 312 * @param maxEvaluations maximal number of function calls for each 313 * start (note that the number will be checked <em>after</em> 314 * complete simplices have been evaluated, this means that in some 315 * cases this number will be exceeded by a few units, depending on 316 * the dimension of the problem) 317 * @param checker object to use to check for convergence 318 * @param generator random vector generator 319 * @param starts number of starts to perform (including the 320 * first one), multi-start is disabled if value is less than or 321 * equal to 1 322 * @return the point/cost pairs giving the minimal cost 323 * @exception CostException if the cost function throws one during 324 * the search 325 * @exception ConvergenceException if none of the starts did 326 * converge (it is not thrown if at least one start did converge) 327 */ 328 public PointCostPair minimize(CostFunction f, int maxEvaluations, 329 ConvergenceChecker checker, 330 RandomVectorGenerator generator, 331 int starts) 332 throws CostException, ConvergenceException { 333 334 // set up optimizer 335 buildSimplex(generator); 336 setMultiStart(starts, generator); 337 338 // compute minimum 339 return minimize(f, maxEvaluations, checker); 340 341 } 342 343 /** Build a simplex from two extreme vertices. 344 * <p>The two vertices are considered to represent two opposite 345 * vertices of a box parallel to the canonical axes of the 346 * space. The simplex is the subset of vertices encountered while 347 * going from vertexA to vertexB traveling along the box edges 348 * only. This can be seen as a scaled regular simplex using the 349 * projected separation between the given points as the scaling 350 * factor along each coordinate axis.</p> 351 * @param vertexA first vertex 352 * @param vertexB last vertex 353 */ 354 private void buildSimplex(double[] vertexA, double[] vertexB) { 355 356 int n = vertexA.length; 357 simplex = new PointCostPair[n + 1]; 358 359 // set up the simplex traveling around the box 360 for (int i = 0; i <= n; ++i) { 361 double[] vertex = new double[n]; 362 if (i > 0) { 363 System.arraycopy(vertexB, 0, vertex, 0, i); 364 } 365 if (i < n) { 366 System.arraycopy(vertexA, i, vertex, i, n - i); 367 } 368 simplex[i] = new PointCostPair(vertex, Double.NaN); 369 } 370 371 } 372 373 /** Build a simplex from all its points. 374 * @param vertices array containing all vertices of the simplex 375 */ 376 private void buildSimplex(double[][] vertices) { 377 int n = vertices.length - 1; 378 simplex = new PointCostPair[n + 1]; 379 for (int i = 0; i <= n; ++i) { 380 simplex[i] = new PointCostPair(vertices[i], Double.NaN); 381 } 382 } 383 384 /** Build a simplex randomly. 385 * @param generator random vector generator 386 */ 387 private void buildSimplex(RandomVectorGenerator generator) { 388 389 // use first vector size to compute the number of points 390 double[] vertex = generator.nextVector(); 391 int n = vertex.length; 392 simplex = new PointCostPair[n + 1]; 393 simplex[0] = new PointCostPair(vertex, Double.NaN); 394 395 // fill up the vertex 396 for (int i = 1; i <= n; ++i) { 397 simplex[i] = new PointCostPair(generator.nextVector(), Double.NaN); 398 } 399 400 } 401 402 /** Set up single-start mode. 403 */ 404 private void setSingleStart() { 405 starts = 1; 406 generator = null; 407 minima = null; 408 } 409 410 /** Set up multi-start mode. 411 * @param starts number of starts to perform (including the 412 * first one), multi-start is disabled if value is less than or 413 * equal to 1 414 * @param generator random vector generator to use for restarts 415 */ 416 private void setMultiStart(int starts, RandomVectorGenerator generator) { 417 if (starts < 2) { 418 this.starts = 1; 419 this.generator = null; 420 minima = null; 421 } else { 422 this.starts = starts; 423 this.generator = generator; 424 minima = null; 425 } 426 } 427 428 /** Get all the minima found during the last call to {@link 429 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[]) 430 * minimize}. 431 * <p>The optimizer stores all the minima found during a set of 432 * restarts when multi-start mode is enabled. The {@link 433 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[]) 434 * minimize} method returns the best point only. This method 435 * returns all the points found at the end of each starts, including 436 * the best one already returned by the {@link #minimize(CostFunction, 437 * int, ConvergenceChecker, double[], double[]) minimize} method. 438 * The array as one element for each start as specified in the constructor 439 * (it has one element only if optimizer has been set up for single-start).</p> 440 * <p>The array containing the minima is ordered with the results 441 * from the runs that did converge first, sorted from lowest to 442 * highest minimum cost, and null elements corresponding to the runs 443 * that did not converge (all elements will be null if the {@link 444 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[]) 445 * minimize} method did throw a {@link ConvergenceException 446 * ConvergenceException}).</p> 447 * @return array containing the minima, or null if {@link 448 * #minimize(CostFunction, int, ConvergenceChecker, double[], double[]) 449 * minimize} has not been called 450 */ 451 public PointCostPair[] getMinima() { 452 return (PointCostPair[]) minima.clone(); 453 } 454 455 /** Minimizes a cost function. 456 * @param f cost function 457 * @param maxEvaluations maximal number of function calls for each 458 * start (note that the number will be checked <em>after</em> 459 * complete simplices have been evaluated, this means that in some 460 * cases this number will be exceeded by a few units, depending on 461 * the dimension of the problem) 462 * @param checker object to use to check for convergence 463 * @return the point/cost pairs giving the minimal cost 464 * @exception CostException if the cost function throws one during 465 * the search 466 * @exception ConvergenceException if none of the starts did 467 * converge (it is not thrown if at least one start did converge) 468 */ 469 private PointCostPair minimize(CostFunction f, int maxEvaluations, 470 ConvergenceChecker checker) 471 throws CostException, ConvergenceException { 472 473 this.f = f; 474 minima = new PointCostPair[starts]; 475 476 // multi-start loop 477 for (int i = 0; i < starts; ++i) { 478 479 evaluations = 0; 480 evaluateSimplex(); 481 482 for (boolean loop = true; loop;) { 483 if (checker.converged(simplex)) { 484 // we have found a minimum 485 minima[i] = simplex[0]; 486 loop = false; 487 } else if (evaluations >= maxEvaluations) { 488 // this start did not converge, try a new one 489 minima[i] = null; 490 loop = false; 491 } else { 492 iterateSimplex(); 493 } 494 } 495 496 if (i < (starts - 1)) { 497 // restart 498 buildSimplex(generator); 499 } 500 501 } 502 503 // sort the minima from lowest cost to highest cost, followed by 504 // null elements 505 Arrays.sort(minima, pointCostPairComparator); 506 507 // return the found point given the lowest cost 508 if (minima[0] == null) { 509 throw new ConvergenceException("none of the {0} start points" + 510 " lead to convergence", 511 new Object[] { 512 Integer.toString(starts) 513 }); 514 } 515 return minima[0]; 516 517 } 518 519 /** Compute the next simplex of the algorithm. 520 * @exception CostException if the function cannot be evaluated at 521 * some point 522 */ 523 protected abstract void iterateSimplex() 524 throws CostException; 525 526 /** Evaluate the cost on one point. 527 * <p>A side effect of this method is to count the number of 528 * function evaluations</p> 529 * @param x point on which the cost function should be evaluated 530 * @return cost at the given point 531 * @exception CostException if no cost can be computed for the parameters 532 */ 533 protected double evaluateCost(double[] x) 534 throws CostException { 535 evaluations++; 536 return f.cost(x); 537 } 538 539 /** Evaluate all the non-evaluated points of the simplex. 540 * @exception CostException if no cost can be computed for the parameters 541 */ 542 protected void evaluateSimplex() 543 throws CostException { 544 545 // evaluate the cost at all non-evaluated simplex points 546 for (int i = 0; i < simplex.length; ++i) { 547 PointCostPair pair = simplex[i]; 548 if (Double.isNaN(pair.getCost())) { 549 simplex[i] = new PointCostPair(pair.getPoint(), evaluateCost(pair.getPoint())); 550 } 551 } 552 553 // sort the simplex from lowest cost to highest cost 554 Arrays.sort(simplex, pointCostPairComparator); 555 556 } 557 558 /** Replace the worst point of the simplex by a new point. 559 * @param pointCostPair point to insert 560 */ 561 protected void replaceWorstPoint(PointCostPair pointCostPair) { 562 int n = simplex.length - 1; 563 for (int i = 0; i < n; ++i) { 564 if (simplex[i].getCost() > pointCostPair.getCost()) { 565 PointCostPair tmp = simplex[i]; 566 simplex[i] = pointCostPair; 567 pointCostPair = tmp; 568 } 569 } 570 simplex[n] = pointCostPair; 571 } 572 573 /** Comparator for {@link PointCostPair PointCostPair} objects. */ 574 private static Comparator pointCostPairComparator = new Comparator() { 575 public int compare(Object o1, Object o2) { 576 if (o1 == null) { 577 return (o2 == null) ? 0 : +1; 578 } else if (o2 == null) { 579 return -1; 580 } 581 double cost1 = ((PointCostPair) o1).getCost(); 582 double cost2 = ((PointCostPair) o2).getCost(); 583 return (cost1 < cost2) ? -1 : ((o1 == o2) ? 0 : +1); 584 } 585 }; 586 587 /** Simplex. */ 588 protected PointCostPair[] simplex; 589 590 /** Cost function. */ 591 private CostFunction f; 592 593 /** Number of evaluations already performed. */ 594 private int evaluations; 595 596 /** Number of starts to go. */ 597 private int starts; 598 599 /** Random generator for multi-start. */ 600 private RandomVectorGenerator generator; 601 602 /** Found minima. */ 603 private PointCostPair[] minima; 604 605 }