001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math3.analysis.function;
019    
020    import java.util.Arrays;
021    
022    import org.apache.commons.math3.analysis.FunctionUtils;
023    import org.apache.commons.math3.analysis.UnivariateFunction;
024    import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
025    import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
026    import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
027    import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
028    import org.apache.commons.math3.exception.NotStrictlyPositiveException;
029    import org.apache.commons.math3.exception.NullArgumentException;
030    import org.apache.commons.math3.exception.DimensionMismatchException;
031    import org.apache.commons.math3.util.FastMath;
032    import org.apache.commons.math3.util.Precision;
033    
034    /**
035     * <a href="http://en.wikipedia.org/wiki/Gaussian_function">
036     *  Gaussian</a> function.
037     *
038     * @since 3.0
039     * @version $Id: Gaussian.java 1383441 2012-09-11 14:56:39Z luc $
040     */
041    public class Gaussian implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
042        /** Mean. */
043        private final double mean;
044        /** Inverse of the standard deviation. */
045        private final double is;
046        /** Inverse of twice the square of the standard deviation. */
047        private final double i2s2;
048        /** Normalization factor. */
049        private final double norm;
050    
051        /**
052         * Gaussian with given normalization factor, mean and standard deviation.
053         *
054         * @param norm Normalization factor.
055         * @param mean Mean.
056         * @param sigma Standard deviation.
057         * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
058         */
059        public Gaussian(double norm,
060                        double mean,
061                        double sigma)
062            throws NotStrictlyPositiveException {
063            if (sigma <= 0) {
064                throw new NotStrictlyPositiveException(sigma);
065            }
066    
067            this.norm = norm;
068            this.mean = mean;
069            this.is   = 1 / sigma;
070            this.i2s2 = 0.5 * is * is;
071        }
072    
073        /**
074         * Normalized gaussian with given mean and standard deviation.
075         *
076         * @param mean Mean.
077         * @param sigma Standard deviation.
078         * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
079         */
080        public Gaussian(double mean,
081                        double sigma)
082            throws NotStrictlyPositiveException {
083            this(1 / (sigma * FastMath.sqrt(2 * Math.PI)), mean, sigma);
084        }
085    
086        /**
087         * Normalized gaussian with zero mean and unit standard deviation.
088         */
089        public Gaussian() {
090            this(0, 1);
091        }
092    
093        /** {@inheritDoc} */
094        public double value(double x) {
095            return value(x - mean, norm, i2s2);
096        }
097    
098        /** {@inheritDoc}
099         * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
100         */
101        @Deprecated
102        public UnivariateFunction derivative() {
103            return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
104        }
105    
106        /**
107         * Parametric function where the input array contains the parameters of
108         * the Gaussian, ordered as follows:
109         * <ul>
110         *  <li>Norm</li>
111         *  <li>Mean</li>
112         *  <li>Standard deviation</li>
113         * </ul>
114         */
115        public static class Parametric implements ParametricUnivariateFunction {
116            /**
117             * Computes the value of the Gaussian at {@code x}.
118             *
119             * @param x Value for which the function must be computed.
120             * @param param Values of norm, mean and standard deviation.
121             * @return the value of the function.
122             * @throws NullArgumentException if {@code param} is {@code null}.
123             * @throws DimensionMismatchException if the size of {@code param} is
124             * not 3.
125             * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
126             */
127            public double value(double x, double ... param)
128                throws NullArgumentException,
129                       DimensionMismatchException,
130                       NotStrictlyPositiveException {
131                validateParameters(param);
132    
133                final double diff = x - param[1];
134                final double i2s2 = 1 / (2 * param[2] * param[2]);
135                return Gaussian.value(diff, param[0], i2s2);
136            }
137    
138            /**
139             * Computes the value of the gradient at {@code x}.
140             * The components of the gradient vector are the partial
141             * derivatives of the function with respect to each of the
142             * <em>parameters</em> (norm, mean and standard deviation).
143             *
144             * @param x Value at which the gradient must be computed.
145             * @param param Values of norm, mean and standard deviation.
146             * @return the gradient vector at {@code x}.
147             * @throws NullArgumentException if {@code param} is {@code null}.
148             * @throws DimensionMismatchException if the size of {@code param} is
149             * not 3.
150             * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
151             */
152            public double[] gradient(double x, double ... param)
153                throws NullArgumentException,
154                       DimensionMismatchException,
155                       NotStrictlyPositiveException {
156                validateParameters(param);
157    
158                final double norm = param[0];
159                final double diff = x - param[1];
160                final double sigma = param[2];
161                final double i2s2 = 1 / (2 * sigma * sigma);
162    
163                final double n = Gaussian.value(diff, 1, i2s2);
164                final double m = norm * n * 2 * i2s2 * diff;
165                final double s = m * diff / sigma;
166    
167                return new double[] { n, m, s };
168            }
169    
170            /**
171             * Validates parameters to ensure they are appropriate for the evaluation of
172             * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
173             * methods.
174             *
175             * @param param Values of norm, mean and standard deviation.
176             * @throws NullArgumentException if {@code param} is {@code null}.
177             * @throws DimensionMismatchException if the size of {@code param} is
178             * not 3.
179             * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
180             */
181            private void validateParameters(double[] param)
182                throws NullArgumentException,
183                       DimensionMismatchException,
184                       NotStrictlyPositiveException {
185                if (param == null) {
186                    throw new NullArgumentException();
187                }
188                if (param.length != 3) {
189                    throw new DimensionMismatchException(param.length, 3);
190                }
191                if (param[2] <= 0) {
192                    throw new NotStrictlyPositiveException(param[2]);
193                }
194            }
195        }
196    
197        /**
198         * @param xMinusMean {@code x - mean}.
199         * @param norm Normalization factor.
200         * @param i2s2 Inverse of twice the square of the standard deviation.
201         * @return the value of the Gaussian at {@code x}.
202         */
203        private static double value(double xMinusMean,
204                                    double norm,
205                                    double i2s2) {
206            return norm * FastMath.exp(-xMinusMean * xMinusMean * i2s2);
207        }
208    
209        /** {@inheritDoc}
210         * @since 3.1
211         */
212        public DerivativeStructure value(final DerivativeStructure t) {
213    
214            final double u = is * (t.getValue() - mean);
215            double[] f = new double[t.getOrder() + 1];
216    
217            // the nth order derivative of the Gaussian has the form:
218            // dn(g(x)/dxn = (norm / s^n) P_n(u) exp(-u^2/2) with u=(x-m)/s
219            // where P_n(u) is a degree n polynomial with same parity as n
220            // P_0(u) = 1, P_1(u) = -u, P_2(u) = u^2 - 1, P_3(u) = -u^3 + 3 u...
221            // the general recurrence relation for P_n is:
222            // P_n(u) = P_(n-1)'(u) - u P_(n-1)(u)
223            // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array
224            final double[] p = new double[f.length];
225            p[0] = 1;
226            final double u2 = u * u;
227            double coeff = norm * FastMath.exp(-0.5 * u2);
228            if (coeff <= Precision.SAFE_MIN) {
229                Arrays.fill(f, 0.0);
230            } else {
231                f[0] = coeff;
232                for (int n = 1; n < f.length; ++n) {
233    
234                    // update and evaluate polynomial P_n(x)
235                    double v = 0;
236                    p[n] = -p[n - 1];
237                    for (int k = n; k >= 0; k -= 2) {
238                        v = v * u2 + p[k];
239                        if (k > 2) {
240                            p[k - 2] = (k - 1) * p[k - 1] - p[k - 3];
241                        } else if (k == 2) {
242                            p[0] = p[1];
243                        }
244                    }
245                    if ((n & 0x1) == 1) {
246                        v *= u;
247                    }
248    
249                    coeff *= is;
250                    f[n] = coeff * v;
251    
252                }
253            }
254    
255            return t.compose(f);
256    
257        }
258    
259    }