001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.ode.nonstiff;
019    
020    import org.apache.commons.math3.exception.DimensionMismatchException;
021    import org.apache.commons.math3.exception.MaxCountExceededException;
022    import org.apache.commons.math3.exception.NoBracketingException;
023    import org.apache.commons.math3.exception.NumberIsTooSmallException;
024    import org.apache.commons.math3.ode.ExpandableStatefulODE;
025    import org.apache.commons.math3.util.FastMath;
026    
027    /**
028     * This class implements the common part of all embedded Runge-Kutta
029     * integrators for Ordinary Differential Equations.
030     *
031     * <p>These methods are embedded explicit Runge-Kutta methods with two
032     * sets of coefficients allowing to estimate the error, their Butcher
033     * arrays are as follows :
034     * <pre>
035     *    0  |
036     *   c2  | a21
037     *   c3  | a31  a32
038     *   ... |        ...
039     *   cs  | as1  as2  ...  ass-1
040     *       |--------------------------
041     *       |  b1   b2  ...   bs-1  bs
042     *       |  b'1  b'2 ...   b's-1 b's
043     * </pre>
044     * </p>
045     *
046     * <p>In fact, we rather use the array defined by ej = bj - b'j to
047     * compute directly the error rather than computing two estimates and
048     * then comparing them.</p>
049     *
050     * <p>Some methods are qualified as <i>fsal</i> (first same as last)
051     * methods. This means the last evaluation of the derivatives in one
052     * step is the same as the first in the next step. Then, this
053     * evaluation can be reused from one step to the next one and the cost
054     * of such a method is really s-1 evaluations despite the method still
055     * has s stages. This behaviour is true only for successful steps, if
056     * the step is rejected after the error estimation phase, no
057     * evaluation is saved. For an <i>fsal</i> method, we have cs = 1 and
058     * asi = bi for all i.</p>
059     *
060     * @version $Id: EmbeddedRungeKuttaIntegrator.java 1416643 2012-12-03 19:37:14Z tn $
061     * @since 1.2
062     */
063    
064    public abstract class EmbeddedRungeKuttaIntegrator
065      extends AdaptiveStepsizeIntegrator {
066    
067        /** Indicator for <i>fsal</i> methods. */
068        private final boolean fsal;
069    
070        /** Time steps from Butcher array (without the first zero). */
071        private final double[] c;
072    
073        /** Internal weights from Butcher array (without the first empty row). */
074        private final double[][] a;
075    
076        /** External weights for the high order method from Butcher array. */
077        private final double[] b;
078    
079        /** Prototype of the step interpolator. */
080        private final RungeKuttaStepInterpolator prototype;
081    
082        /** Stepsize control exponent. */
083        private final double exp;
084    
085        /** Safety factor for stepsize control. */
086        private double safety;
087    
088        /** Minimal reduction factor for stepsize control. */
089        private double minReduction;
090    
091        /** Maximal growth factor for stepsize control. */
092        private double maxGrowth;
093    
094      /** Build a Runge-Kutta integrator with the given Butcher array.
095       * @param name name of the method
096       * @param fsal indicate that the method is an <i>fsal</i>
097       * @param c time steps from Butcher array (without the first zero)
098       * @param a internal weights from Butcher array (without the first empty row)
099       * @param b propagation weights for the high order method from Butcher array
100       * @param prototype prototype of the step interpolator to use
101       * @param minStep minimal step (sign is irrelevant, regardless of
102       * integration direction, forward or backward), the last step can
103       * be smaller than this
104       * @param maxStep maximal step (sign is irrelevant, regardless of
105       * integration direction, forward or backward), the last step can
106       * be smaller than this
107       * @param scalAbsoluteTolerance allowed absolute error
108       * @param scalRelativeTolerance allowed relative error
109       */
110      protected EmbeddedRungeKuttaIntegrator(final String name, final boolean fsal,
111                                             final double[] c, final double[][] a, final double[] b,
112                                             final RungeKuttaStepInterpolator prototype,
113                                             final double minStep, final double maxStep,
114                                             final double scalAbsoluteTolerance,
115                                             final double scalRelativeTolerance) {
116    
117        super(name, minStep, maxStep, scalAbsoluteTolerance, scalRelativeTolerance);
118    
119        this.fsal      = fsal;
120        this.c         = c;
121        this.a         = a;
122        this.b         = b;
123        this.prototype = prototype;
124    
125        exp = -1.0 / getOrder();
126    
127        // set the default values of the algorithm control parameters
128        setSafety(0.9);
129        setMinReduction(0.2);
130        setMaxGrowth(10.0);
131    
132      }
133    
134      /** Build a Runge-Kutta integrator with the given Butcher array.
135       * @param name name of the method
136       * @param fsal indicate that the method is an <i>fsal</i>
137       * @param c time steps from Butcher array (without the first zero)
138       * @param a internal weights from Butcher array (without the first empty row)
139       * @param b propagation weights for the high order method from Butcher array
140       * @param prototype prototype of the step interpolator to use
141       * @param minStep minimal step (must be positive even for backward
142       * integration), the last step can be smaller than this
143       * @param maxStep maximal step (must be positive even for backward
144       * integration)
145       * @param vecAbsoluteTolerance allowed absolute error
146       * @param vecRelativeTolerance allowed relative error
147       */
148      protected EmbeddedRungeKuttaIntegrator(final String name, final boolean fsal,
149                                             final double[] c, final double[][] a, final double[] b,
150                                             final RungeKuttaStepInterpolator prototype,
151                                             final double   minStep, final double maxStep,
152                                             final double[] vecAbsoluteTolerance,
153                                             final double[] vecRelativeTolerance) {
154    
155        super(name, minStep, maxStep, vecAbsoluteTolerance, vecRelativeTolerance);
156    
157        this.fsal      = fsal;
158        this.c         = c;
159        this.a         = a;
160        this.b         = b;
161        this.prototype = prototype;
162    
163        exp = -1.0 / getOrder();
164    
165        // set the default values of the algorithm control parameters
166        setSafety(0.9);
167        setMinReduction(0.2);
168        setMaxGrowth(10.0);
169    
170      }
171    
172      /** Get the order of the method.
173       * @return order of the method
174       */
175      public abstract int getOrder();
176    
177      /** Get the safety factor for stepsize control.
178       * @return safety factor
179       */
180      public double getSafety() {
181        return safety;
182      }
183    
184      /** Set the safety factor for stepsize control.
185       * @param safety safety factor
186       */
187      public void setSafety(final double safety) {
188        this.safety = safety;
189      }
190    
191      /** {@inheritDoc} */
192      @Override
193      public void integrate(final ExpandableStatefulODE equations, final double t)
194          throws NumberIsTooSmallException, DimensionMismatchException,
195                 MaxCountExceededException, NoBracketingException {
196    
197        sanityChecks(equations, t);
198        setEquations(equations);
199        final boolean forward = t > equations.getTime();
200    
201        // create some internal working arrays
202        final double[] y0  = equations.getCompleteState();
203        final double[] y = y0.clone();
204        final int stages = c.length + 1;
205        final double[][] yDotK = new double[stages][y.length];
206        final double[] yTmp    = y0.clone();
207        final double[] yDotTmp = new double[y.length];
208    
209        // set up an interpolator sharing the integrator arrays
210        final RungeKuttaStepInterpolator interpolator = (RungeKuttaStepInterpolator) prototype.copy();
211        interpolator.reinitialize(this, yTmp, yDotK, forward,
212                                  equations.getPrimaryMapper(), equations.getSecondaryMappers());
213        interpolator.storeTime(equations.getTime());
214    
215        // set up integration control objects
216        stepStart         = equations.getTime();
217        double  hNew      = 0;
218        boolean firstTime = true;
219        initIntegration(equations.getTime(), y0, t);
220    
221        // main integration loop
222        isLastStep = false;
223        do {
224    
225          interpolator.shift();
226    
227          // iterate over step size, ensuring local normalized error is smaller than 1
228          double error = 10;
229          while (error >= 1.0) {
230    
231            if (firstTime || !fsal) {
232              // first stage
233              computeDerivatives(stepStart, y, yDotK[0]);
234            }
235    
236            if (firstTime) {
237              final double[] scale = new double[mainSetDimension];
238              if (vecAbsoluteTolerance == null) {
239                  for (int i = 0; i < scale.length; ++i) {
240                    scale[i] = scalAbsoluteTolerance + scalRelativeTolerance * FastMath.abs(y[i]);
241                  }
242              } else {
243                  for (int i = 0; i < scale.length; ++i) {
244                    scale[i] = vecAbsoluteTolerance[i] + vecRelativeTolerance[i] * FastMath.abs(y[i]);
245                  }
246              }
247              hNew = initializeStep(forward, getOrder(), scale,
248                                    stepStart, y, yDotK[0], yTmp, yDotK[1]);
249              firstTime = false;
250            }
251    
252            stepSize = hNew;
253            if (forward) {
254                if (stepStart + stepSize >= t) {
255                    stepSize = t - stepStart;
256                }
257            } else {
258                if (stepStart + stepSize <= t) {
259                    stepSize = t - stepStart;
260                }
261            }
262    
263            // next stages
264            for (int k = 1; k < stages; ++k) {
265    
266              for (int j = 0; j < y0.length; ++j) {
267                double sum = a[k-1][0] * yDotK[0][j];
268                for (int l = 1; l < k; ++l) {
269                  sum += a[k-1][l] * yDotK[l][j];
270                }
271                yTmp[j] = y[j] + stepSize * sum;
272              }
273    
274              computeDerivatives(stepStart + c[k-1] * stepSize, yTmp, yDotK[k]);
275    
276            }
277    
278            // estimate the state at the end of the step
279            for (int j = 0; j < y0.length; ++j) {
280              double sum    = b[0] * yDotK[0][j];
281              for (int l = 1; l < stages; ++l) {
282                sum    += b[l] * yDotK[l][j];
283              }
284              yTmp[j] = y[j] + stepSize * sum;
285            }
286    
287            // estimate the error at the end of the step
288            error = estimateError(yDotK, y, yTmp, stepSize);
289            if (error >= 1.0) {
290              // reject the step and attempt to reduce error by stepsize control
291              final double factor =
292                  FastMath.min(maxGrowth,
293                               FastMath.max(minReduction, safety * FastMath.pow(error, exp)));
294              hNew = filterStep(stepSize * factor, forward, false);
295            }
296    
297          }
298    
299          // local error is small enough: accept the step, trigger events and step handlers
300          interpolator.storeTime(stepStart + stepSize);
301          System.arraycopy(yTmp, 0, y, 0, y0.length);
302          System.arraycopy(yDotK[stages - 1], 0, yDotTmp, 0, y0.length);
303          stepStart = acceptStep(interpolator, y, yDotTmp, t);
304          System.arraycopy(y, 0, yTmp, 0, y.length);
305    
306          if (!isLastStep) {
307    
308              // prepare next step
309              interpolator.storeTime(stepStart);
310    
311              if (fsal) {
312                  // save the last evaluation for the next step
313                  System.arraycopy(yDotTmp, 0, yDotK[0], 0, y0.length);
314              }
315    
316              // stepsize control for next step
317              final double factor =
318                  FastMath.min(maxGrowth, FastMath.max(minReduction, safety * FastMath.pow(error, exp)));
319              final double  scaledH    = stepSize * factor;
320              final double  nextT      = stepStart + scaledH;
321              final boolean nextIsLast = forward ? (nextT >= t) : (nextT <= t);
322              hNew = filterStep(scaledH, forward, nextIsLast);
323    
324              final double  filteredNextT      = stepStart + hNew;
325              final boolean filteredNextIsLast = forward ? (filteredNextT >= t) : (filteredNextT <= t);
326              if (filteredNextIsLast) {
327                  hNew = t - stepStart;
328              }
329    
330          }
331    
332        } while (!isLastStep);
333    
334        // dispatch results
335        equations.setTime(stepStart);
336        equations.setCompleteState(y);
337    
338        resetInternalState();
339    
340      }
341    
342      /** Get the minimal reduction factor for stepsize control.
343       * @return minimal reduction factor
344       */
345      public double getMinReduction() {
346        return minReduction;
347      }
348    
349      /** Set the minimal reduction factor for stepsize control.
350       * @param minReduction minimal reduction factor
351       */
352      public void setMinReduction(final double minReduction) {
353        this.minReduction = minReduction;
354      }
355    
356      /** Get the maximal growth factor for stepsize control.
357       * @return maximal growth factor
358       */
359      public double getMaxGrowth() {
360        return maxGrowth;
361      }
362    
363      /** Set the maximal growth factor for stepsize control.
364       * @param maxGrowth maximal growth factor
365       */
366      public void setMaxGrowth(final double maxGrowth) {
367        this.maxGrowth = maxGrowth;
368      }
369    
370      /** Compute the error ratio.
371       * @param yDotK derivatives computed during the first stages
372       * @param y0 estimate of the step at the start of the step
373       * @param y1 estimate of the step at the end of the step
374       * @param h  current step
375       * @return error ratio, greater than 1 if step should be rejected
376       */
377      protected abstract double estimateError(double[][] yDotK,
378                                              double[] y0, double[] y1,
379                                              double h);
380    
381    }