View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.estimation;
19  
20  import java.io.Serializable;
21  
22  import org.apache.commons.math.linear.InvalidMatrixException;
23  import org.apache.commons.math.linear.RealMatrix;
24  import org.apache.commons.math.linear.RealMatrixImpl;
25  
26  /** 
27   * This class implements a solver for estimation problems.
28   *
29   * <p>This class solves estimation problems using a weighted least
30   * squares criterion on the measurement residuals. It uses a
31   * Gauss-Newton algorithm.</p>
32   *
33   * @version $Revision: 627987 $ $Date: 2008-02-15 03:01:26 -0700 (Fri, 15 Feb 2008) $
34   * @since 1.2
35   *
36   */
37  
38  public class GaussNewtonEstimator extends AbstractEstimator implements Serializable {
39  
40      /** 
41       * Simple constructor.
42       *
43       * <p>This constructor builds an estimator and stores its convergence
44       * characteristics.</p>
45       *
46       * <p>An estimator is considered to have converged whenever either
47       * the criterion goes below a physical threshold under which
48       * improvements are considered useless or when the algorithm is
49       * unable to improve it (even if it is still high). The first
50       * condition that is met stops the iterations.</p>
51       *
52       * <p>The fact an estimator has converged does not mean that the
53       * model accurately fits the measurements. It only means no better
54       * solution can be found, it does not mean this one is good. Such an
55       * analysis is left to the caller.</p>
56       *
57       * <p>If neither conditions are fulfilled before a given number of
58       * iterations, the algorithm is considered to have failed and an
59       * {@link EstimationException} is thrown.</p>
60       *
61       * @param maxCostEval maximal number of cost evaluations allowed
62       * @param convergence criterion threshold below which we do not need
63       * to improve the criterion anymore
64       * @param steadyStateThreshold steady state detection threshold, the
65       * problem has converged has reached a steady state if
66       * <code>Math.abs (Jn - Jn-1) < Jn * convergence</code>, where
67       * <code>Jn</code> and <code>Jn-1</code> are the current and
68       * preceding criterion value (square sum of the weighted residuals
69       * of considered measurements).
70       */
71      public GaussNewtonEstimator(int maxCostEval,
72              double convergence,
73              double steadyStateThreshold) {
74          setMaxCostEval(maxCostEval);
75          this.steadyStateThreshold = steadyStateThreshold;
76          this.convergence          = convergence;
77      }
78  
79      /** 
80       * Solve an estimation problem using a least squares criterion.
81       *
82       * <p>This method set the unbound parameters of the given problem
83       * starting from their current values through several iterations. At
84       * each step, the unbound parameters are changed in order to
85       * minimize a weighted least square criterion based on the
86       * measurements of the problem.</p>
87       *
88       * <p>The iterations are stopped either when the criterion goes
89       * below a physical threshold under which improvement are considered
90       * useless or when the algorithm is unable to improve it (even if it
91       * is still high). The first condition that is met stops the
92       * iterations. If the convergence it nos reached before the maximum
93       * number of iterations, an {@link EstimationException} is
94       * thrown.</p>
95       *
96       * @param problem estimation problem to solve
97       * @exception EstimationException if the problem cannot be solved
98       *
99       * @see EstimationProblem
100      *
101      */
102     public void estimate(EstimationProblem problem)
103     throws EstimationException {
104 
105         initializeEstimate(problem);
106 
107         // work matrices
108         double[] grad             = new double[parameters.length];
109         RealMatrixImpl bDecrement = new RealMatrixImpl(parameters.length, 1);
110         double[][] bDecrementData = bDecrement.getDataRef();
111         RealMatrixImpl wGradGradT = new RealMatrixImpl(parameters.length, parameters.length);
112         double[][] wggData        = wGradGradT.getDataRef();
113 
114         // iterate until convergence is reached
115         double previous = Double.POSITIVE_INFINITY;
116         do {
117 
118             // build the linear problem
119             incrementJacobianEvaluationsCounter();
120             RealMatrix b = new RealMatrixImpl(parameters.length, 1);
121             RealMatrix a = new RealMatrixImpl(parameters.length, parameters.length);
122             for (int i = 0; i < measurements.length; ++i) {
123                 if (! measurements [i].isIgnored()) {
124 
125                     double weight   = measurements[i].getWeight();
126                     double residual = measurements[i].getResidual();
127 
128                     // compute the normal equation
129                     for (int j = 0; j < parameters.length; ++j) {
130                         grad[j] = measurements[i].getPartial(parameters[j]);
131                         bDecrementData[j][0] = weight * residual * grad[j];
132                     }
133 
134                     // build the contribution matrix for measurement i
135                     for (int k = 0; k < parameters.length; ++k) {
136                         double[] wggRow = wggData[k];
137                         double gk = grad[k];
138                         for (int l = 0; l < parameters.length; ++l) {
139                             wggRow[l] =  weight * gk * grad[l];
140                         }
141                     }
142 
143                     // update the matrices
144                     a = a.add(wGradGradT);
145                     b = b.add(bDecrement);
146 
147                 }
148             }
149 
150             try {
151 
152                 // solve the linearized least squares problem
153                 RealMatrix dX = a.solve(b);
154 
155                 // update the estimated parameters
156                 for (int i = 0; i < parameters.length; ++i) {
157                     parameters[i].setEstimate(parameters[i].getEstimate() + dX.getEntry(i, 0));
158                 }
159 
160             } catch(InvalidMatrixException e) {
161                 throw new EstimationException("unable to solve: singular problem", new Object[0]);
162             }
163 
164 
165             previous = cost;
166             updateResidualsAndCost();
167 
168         } while ((getCostEvaluations() < 2) ||
169                  (Math.abs(previous - cost) > (cost * steadyStateThreshold) &&
170                   (Math.abs(cost) > convergence)));
171 
172     }
173 
174     /** Threshold for cost steady state detection. */
175     private double steadyStateThreshold;
176 
177     /** Threshold for cost convergence. */
178     private double convergence;
179 
180     /** Serializable version identifier */
181      private static final long serialVersionUID = 5485001826076289109L;
182 
183 }