/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
017    package org.apache.commons.math3.optim.univariate;
018    
019    import org.apache.commons.math3.util.Precision;
020    import org.apache.commons.math3.util.FastMath;
021    import org.apache.commons.math3.exception.NumberIsTooSmallException;
022    import org.apache.commons.math3.exception.NotStrictlyPositiveException;
023    import org.apache.commons.math3.optim.ConvergenceChecker;
024    import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
025    
026    /**
027     * For a function defined on some interval {@code (lo, hi)}, this class
028     * finds an approximation {@code x} to the point at which the function
029     * attains its minimum.
030     * It implements Richard Brent's algorithm (from his book "Algorithms for
031     * Minimization without Derivatives", p. 79) for finding minima of real
032     * univariate functions.
033     * <br/>
034     * This code is an adaptation, partly based on the Python code from SciPy
035     * (module "optimize.py" v0.5); the original algorithm is also modified
036     * <ul>
037     *  <li>to use an initial guess provided by the user,</li>
038     *  <li>to ensure that the best point encountered is the one returned.</li>
039     * </ul>
040     *
041     * @version $Id: BrentOptimizer.java 1416643 2012-12-03 19:37:14Z tn $
042     * @since 2.0
043     */
044    public class BrentOptimizer extends UnivariateOptimizer {
045        /**
046         * Golden section.
047         */
048        private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
049        /**
050         * Minimum relative tolerance.
051         */
052        private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
053        /**
054         * Relative threshold.
055         */
056        private final double relativeThreshold;
057        /**
058         * Absolute threshold.
059         */
060        private final double absoluteThreshold;
061    
062        /**
063         * The arguments are used implement the original stopping criterion
064         * of Brent's algorithm.
065         * {@code abs} and {@code rel} define a tolerance
066         * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
067         * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
068         * where <em>macheps</em> is the relative machine precision. {@code abs} must
069         * be positive.
070         *
071         * @param rel Relative threshold.
072         * @param abs Absolute threshold.
073         * @param checker Additional, user-defined, convergence checking
074         * procedure.
075         * @throws NotStrictlyPositiveException if {@code abs <= 0}.
076         * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
077         */
078        public BrentOptimizer(double rel,
079                              double abs,
080                              ConvergenceChecker<UnivariatePointValuePair> checker) {
081            super(checker);
082    
083            if (rel < MIN_RELATIVE_TOLERANCE) {
084                throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
085            }
086            if (abs <= 0) {
087                throw new NotStrictlyPositiveException(abs);
088            }
089    
090            relativeThreshold = rel;
091            absoluteThreshold = abs;
092        }
093    
094        /**
095         * The arguments are used for implementing the original stopping criterion
096         * of Brent's algorithm.
097         * {@code abs} and {@code rel} define a tolerance
098         * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
099         * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
100         * where <em>macheps</em> is the relative machine precision. {@code abs} must
101         * be positive.
102         *
103         * @param rel Relative threshold.
104         * @param abs Absolute threshold.
105         * @throws NotStrictlyPositiveException if {@code abs <= 0}.
106         * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
107         */
108        public BrentOptimizer(double rel,
109                              double abs) {
110            this(rel, abs, null);
111        }
112    
113        /** {@inheritDoc} */
114        @Override
115        protected UnivariatePointValuePair doOptimize() {
116            final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
117            final double lo = getMin();
118            final double mid = getStartValue();
119            final double hi = getMax();
120    
121            // Optional additional convergence criteria.
122            final ConvergenceChecker<UnivariatePointValuePair> checker
123                = getConvergenceChecker();
124    
125            double a;
126            double b;
127            if (lo < hi) {
128                a = lo;
129                b = hi;
130            } else {
131                a = hi;
132                b = lo;
133            }
134    
135            double x = mid;
136            double v = x;
137            double w = x;
138            double d = 0;
139            double e = 0;
140            double fx = computeObjectiveValue(x);
141            if (!isMinim) {
142                fx = -fx;
143            }
144            double fv = fx;
145            double fw = fx;
146    
147            UnivariatePointValuePair previous = null;
148            UnivariatePointValuePair current
149                = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
150            // Best point encountered so far (which is the initial guess).
151            UnivariatePointValuePair best = current;
152    
153            int iter = 0;
154            while (true) {
155                final double m = 0.5 * (a + b);
156                final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
157                final double tol2 = 2 * tol1;
158    
159                // Default stopping criterion.
160                final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
161                if (!stop) {
162                    double p = 0;
163                    double q = 0;
164                    double r = 0;
165                    double u = 0;
166    
167                    if (FastMath.abs(e) > tol1) { // Fit parabola.
168                        r = (x - w) * (fx - fv);
169                        q = (x - v) * (fx - fw);
170                        p = (x - v) * q - (x - w) * r;
171                        q = 2 * (q - r);
172    
173                        if (q > 0) {
174                            p = -p;
175                        } else {
176                            q = -q;
177                        }
178    
179                        r = e;
180                        e = d;
181    
182                        if (p > q * (a - x) &&
183                            p < q * (b - x) &&
184                            FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
185                            // Parabolic interpolation step.
186                            d = p / q;
187                            u = x + d;
188    
189                            // f must not be evaluated too close to a or b.
190                            if (u - a < tol2 || b - u < tol2) {
191                                if (x <= m) {
192                                    d = tol1;
193                                } else {
194                                    d = -tol1;
195                                }
196                            }
197                        } else {
198                            // Golden section step.
199                            if (x < m) {
200                                e = b - x;
201                            } else {
202                                e = a - x;
203                            }
204                            d = GOLDEN_SECTION * e;
205                        }
206                    } else {
207                        // Golden section step.
208                        if (x < m) {
209                            e = b - x;
210                        } else {
211                            e = a - x;
212                        }
213                        d = GOLDEN_SECTION * e;
214                    }
215    
216                    // Update by at least "tol1".
217                    if (FastMath.abs(d) < tol1) {
218                        if (d >= 0) {
219                            u = x + tol1;
220                        } else {
221                            u = x - tol1;
222                        }
223                    } else {
224                        u = x + d;
225                    }
226    
227                    double fu = computeObjectiveValue(u);
228                    if (!isMinim) {
229                        fu = -fu;
230                    }
231    
232                    // User-defined convergence checker.
233                    previous = current;
234                    current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
235                    best = best(best,
236                                best(previous,
237                                     current,
238                                     isMinim),
239                                isMinim);
240    
241                    if (checker != null) {
242                        if (checker.converged(iter, previous, current)) {
243                            return best;
244                        }
245                    }
246    
247                    // Update a, b, v, w and x.
248                    if (fu <= fx) {
249                        if (u < x) {
250                            b = x;
251                        } else {
252                            a = x;
253                        }
254                        v = w;
255                        fv = fw;
256                        w = x;
257                        fw = fx;
258                        x = u;
259                        fx = fu;
260                    } else {
261                        if (u < x) {
262                            a = u;
263                        } else {
264                            b = u;
265                        }
266                        if (fu <= fw ||
267                            Precision.equals(w, x)) {
268                            v = w;
269                            fv = fw;
270                            w = u;
271                            fw = fu;
272                        } else if (fu <= fv ||
273                                   Precision.equals(v, x) ||
274                                   Precision.equals(v, w)) {
275                            v = u;
276                            fv = fu;
277                        }
278                    }
279                } else { // Default termination (Brent's criterion).
280                    return best(best,
281                                best(previous,
282                                     current,
283                                     isMinim),
284                                isMinim);
285                }
286                ++iter;
287            }
288        }
289    
290        /**
291         * Selects the best of two points.
292         *
293         * @param a Point and value.
294         * @param b Point and value.
295         * @param isMinim {@code true} if the selected point must be the one with
296         * the lowest value.
297         * @return the best point, or {@code null} if {@code a} and {@code b} are
298         * both {@code null}. When {@code a} and {@code b} have the same function
299         * value, {@code a} is returned.
300         */
301        private UnivariatePointValuePair best(UnivariatePointValuePair a,
302                                              UnivariatePointValuePair b,
303                                              boolean isMinim) {
304            if (a == null) {
305                return b;
306            }
307            if (b == null) {
308                return a;
309            }
310    
311            if (isMinim) {
312                return a.getValue() <= b.getValue() ? a : b;
313            } else {
314                return a.getValue() >= b.getValue() ? a : b;
315            }
316        }
317    }