001 /* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 package org.apache.commons.math3.optim.nonlinear.vector; 018 019 import java.util.Collections; 020 import java.util.List; 021 import java.util.ArrayList; 022 import java.util.Comparator; 023 import org.apache.commons.math3.exception.NotStrictlyPositiveException; 024 import org.apache.commons.math3.exception.NullArgumentException; 025 import org.apache.commons.math3.linear.RealMatrix; 026 import org.apache.commons.math3.linear.RealVector; 027 import org.apache.commons.math3.linear.ArrayRealVector; 028 import org.apache.commons.math3.random.RandomVectorGenerator; 029 import org.apache.commons.math3.optim.BaseMultiStartMultivariateOptimizer; 030 import org.apache.commons.math3.optim.PointVectorValuePair; 031 032 /** 033 * Multi-start optimizer for a (vector) model function. 034 * 035 * This class wraps an optimizer in order to use it several times in 036 * turn with different starting points (trying to avoid being trapped 037 * in a local extremum when looking for a global one). 038 * 039 * @version $Id$ 040 * @since 3.0 041 */ 042 public class MultiStartMultivariateVectorOptimizer 043 extends BaseMultiStartMultivariateOptimizer<PointVectorValuePair> { 044 /** Underlying optimizer. */ 045 private final MultivariateVectorOptimizer optimizer; 046 /** Found optima. */ 047 private final List<PointVectorValuePair> optima = new ArrayList<PointVectorValuePair>(); 048 049 /** 050 * Create a multi-start optimizer from a single-start optimizer. 051 * 052 * @param optimizer Single-start optimizer to wrap. 053 * @param starts Number of starts to perform. 054 * If {@code starts == 1}, the result will be same as if {@code optimizer} 055 * is called directly. 056 * @param generator Random vector generator to use for restarts. 057 * @throws NullArgumentException if {@code optimizer} or {@code generator} 058 * is {@code null}. 059 * @throws NotStrictlyPositiveException if {@code starts < 1}. 060 */ 061 public MultiStartMultivariateVectorOptimizer(final MultivariateVectorOptimizer optimizer, 062 final int starts, 063 final RandomVectorGenerator generator) 064 throws NullArgumentException, 065 NotStrictlyPositiveException { 066 super(optimizer, starts, generator); 067 this.optimizer = optimizer; 068 } 069 070 /** 071 * {@inheritDoc} 072 */ 073 @Override 074 public PointVectorValuePair[] getOptima() { 075 Collections.sort(optima, getPairComparator()); 076 return optima.toArray(new PointVectorValuePair[0]); 077 } 078 079 /** 080 * {@inheritDoc} 081 */ 082 @Override 083 protected void store(PointVectorValuePair optimum) { 084 optima.add(optimum); 085 } 086 087 /** 088 * {@inheritDoc} 089 */ 090 @Override 091 protected void clear() { 092 optima.clear(); 093 } 094 095 /** 096 * @return a comparator for sorting the optima. 097 */ 098 private Comparator<PointVectorValuePair> getPairComparator() { 099 return new Comparator<PointVectorValuePair>() { 100 private final RealVector target = new ArrayRealVector(optimizer.getTarget(), false); 101 private final RealMatrix weight = optimizer.getWeight(); 102 103 public int compare(final PointVectorValuePair o1, 104 final PointVectorValuePair o2) { 105 if (o1 == null) { 106 return (o2 == null) ? 0 : 1; 107 } else if (o2 == null) { 108 return -1; 109 } 110 return Double.compare(weightedResidual(o1), 111 weightedResidual(o2)); 112 } 113 114 private double weightedResidual(final PointVectorValuePair pv) { 115 final RealVector v = new ArrayRealVector(pv.getValueRef(), false); 116 final RealVector r = target.subtract(v); 117 return r.dotProduct(weight.operate(r)); 118 } 119 }; 120 } 121 }