1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 package org.apache.commons.math.distribution;
19
20 import java.io.Serializable;
21
22 import org.apache.commons.math.MathException;
23 import org.apache.commons.math.MaxIterationsExceededException;
24 import org.apache.commons.math.special.Erf;
25
26 /**
27 * Default implementation of
28 * {@link org.apache.commons.math.distribution.NormalDistribution}.
29 *
30 * @version $Revision: 617953 $ $Date: 2008-02-02 22:54:00 -0700 (Sat, 02 Feb 2008) $
31 */
32 public class NormalDistributionImpl extends AbstractContinuousDistribution
33 implements NormalDistribution, Serializable {
34
35 /** Serializable version identifier */
36 private static final long serialVersionUID = 8589540077390120676L;
37
38 /** The mean of this distribution. */
39 private double mean = 0;
40
41 /** The standard deviation of this distribution. */
42 private double standardDeviation = 1;
43
44 /**
45 * Create a normal distribution using the given mean and standard deviation.
46 * @param mean mean for this distribution
47 * @param sd standard deviation for this distribution
48 */
49 public NormalDistributionImpl(double mean, double sd){
50 super();
51 setMean(mean);
52 setStandardDeviation(sd);
53 }
54
55 /**
56 * Creates normal distribution with the mean equal to zero and standard
57 * deviation equal to one.
58 */
59 public NormalDistributionImpl(){
60 this(0.0, 1.0);
61 }
62
63 /**
64 * Access the mean.
65 * @return mean for this distribution
66 */
67 public double getMean() {
68 return mean;
69 }
70
71 /**
72 * Modify the mean.
73 * @param mean for this distribution
74 */
75 public void setMean(double mean) {
76 this.mean = mean;
77 }
78
79 /**
80 * Access the standard deviation.
81 * @return standard deviation for this distribution
82 */
83 public double getStandardDeviation() {
84 return standardDeviation;
85 }
86
87 /**
88 * Modify the standard deviation.
89 * @param sd standard deviation for this distribution
90 * @throws IllegalArgumentException if <code>sd</code> is not positive.
91 */
92 public void setStandardDeviation(double sd) {
93 if (sd <= 0.0) {
94 throw new IllegalArgumentException(
95 "Standard deviation must be positive.");
96 }
97 standardDeviation = sd;
98 }
99
100 /**
101 * For this disbution, X, this method returns P(X < <code>x</code>).
102 * @param x the value at which the CDF is evaluated.
103 * @return CDF evaluted at <code>x</code>.
104 * @throws MathException if the algorithm fails to converge; unless
105 * x is more than 20 standard deviations from the mean, in which case the
106 * convergence exception is caught and 0 or 1 is returned.
107 */
108 public double cumulativeProbability(double x) throws MathException {
109 try {
110 return 0.5 * (1.0 + Erf.erf((x - mean) /
111 (standardDeviation * Math.sqrt(2.0))));
112 } catch (MaxIterationsExceededException ex) {
113 if (x < (mean - 20 * standardDeviation)) { // JDK 1.5 blows at 38
114 return 0.0d;
115 } else if (x > (mean + 20 * standardDeviation)) {
116 return 1.0d;
117 } else {
118 throw ex;
119 }
120 }
121 }
122
123 /**
124 * For this distribution, X, this method returns the critical point x, such
125 * that P(X < x) = <code>p</code>.
126 * <p>
127 * Returns <code>Double.NEGATIVE_INFINITY</code> for p=0 and
128 * <code>Double.POSITIVE_INFINITY</code> for p=1.</p>
129 *
130 * @param p the desired probability
131 * @return x, such that P(X < x) = <code>p</code>
132 * @throws MathException if the inverse cumulative probability can not be
133 * computed due to convergence or other numerical errors.
134 * @throws IllegalArgumentException if <code>p</code> is not a valid
135 * probability.
136 */
137 public double inverseCumulativeProbability(final double p)
138 throws MathException {
139 if (p == 0) {
140 return Double.NEGATIVE_INFINITY;
141 }
142 if (p == 1) {
143 return Double.POSITIVE_INFINITY;
144 }
145 return super.inverseCumulativeProbability(p);
146 }
147
148 /**
149 * Access the domain value lower bound, based on <code>p</code>, used to
150 * bracket a CDF root. This method is used by
151 * {@link #inverseCumulativeProbability(double)} to find critical values.
152 *
153 * @param p the desired probability for the critical value
154 * @return domain value lower bound, i.e.
155 * P(X < <i>lower bound</i>) < <code>p</code>
156 */
157 protected double getDomainLowerBound(double p) {
158 double ret;
159
160 if (p < .5) {
161 ret = -Double.MAX_VALUE;
162 } else {
163 ret = getMean();
164 }
165
166 return ret;
167 }
168
169 /**
170 * Access the domain value upper bound, based on <code>p</code>, used to
171 * bracket a CDF root. This method is used by
172 * {@link #inverseCumulativeProbability(double)} to find critical values.
173 *
174 * @param p the desired probability for the critical value
175 * @return domain value upper bound, i.e.
176 * P(X < <i>upper bound</i>) > <code>p</code>
177 */
178 protected double getDomainUpperBound(double p) {
179 double ret;
180
181 if (p < .5) {
182 ret = getMean();
183 } else {
184 ret = Double.MAX_VALUE;
185 }
186
187 return ret;
188 }
189
190 /**
191 * Access the initial domain value, based on <code>p</code>, used to
192 * bracket a CDF root. This method is used by
193 * {@link #inverseCumulativeProbability(double)} to find critical values.
194 *
195 * @param p the desired probability for the critical value
196 * @return initial domain value
197 */
198 protected double getInitialDomain(double p) {
199 double ret;
200
201 if (p < .5) {
202 ret = getMean() - getStandardDeviation();
203 } else if (p > .5) {
204 ret = getMean() + getStandardDeviation();
205 } else {
206 ret = getMean();
207 }
208
209 return ret;
210 }
211 }