View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21  
22  /**
23   * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
24   *
25   * <ul>
26   *  <li>
27   *   For large means, we use the rejection algorithm described in
28   *   <blockquote>
29   *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br>
30   *    <strong>Computing</strong> vol. 26 pp. 197-207.
31   *   </blockquote>
32   *  </li>
33   * </ul>
34   *
35   * @since 1.1
36   *
37   * This sampler is suitable for {@code mean >= 40}.
38   */
39  public class LargeMeanPoissonSampler
40      implements DiscreteSampler {
41      /** Upper bound to avoid truncation. */
42      private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
43      /** Class to compute {@code log(n!)}. This has no cached values. */
44      private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
45      /** Used when there is no requirement for a small mean Poisson sampler. */
46      private static final DiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = null;
47  
48      static {
49          // Create without a cache.
50          NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
51      }
52  
53      /** Underlying source of randomness. */
54      private final UniformRandomProvider rng;
55      /** Exponential. */
56      private final ContinuousSampler exponential;
57      /** Gaussian. */
58      private final ContinuousSampler gaussian;
59      /** Local class to compute {@code log(n!)}. This may have cached values. */
60      private final InternalUtils.FactorialLog factorialLog;
61  
62      // Working values
63  
64      /** Algorithm constant: {@code Math.floor(mean)}. */
65      private final double lambda;
66      /** Algorithm constant: {@code Math.log(lambda)}. */
67      private final double logLambda;
68      /** Algorithm constant: {@code factorialLog((int) lambda)}. */
69      private final double logLambdaFactorial;
70      /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
71      private final double delta;
72      /** Algorithm constant: {@code delta / 2}. */
73      private final double halfDelta;
74      /** Algorithm constant: {@code 2 * lambda + delta}. */
75      private final double twolpd;
76      /**
77       * Algorithm constant: {@code a1 / aSum} with
78       * <ul>
79       *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
80       *  <li>{@code aSum = a1 + a2 + 1}</li>
81       * </ul>
82       */
83      private final double p1;
84      /**
85       * Algorithm constant: {@code a2 / aSum} with
86       * <ul>
87       *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
88       *  <li>{@code aSum = a1 + a2 + 1}</li>
89       * </ul>
90       */
91      private final double p2;
92      /** Algorithm constant: {@code 1 / (8 * lambda)}. */
93      private final double c1;
94  
95      /** The internal Poisson sampler for the lambda fraction. */
96      private final DiscreteSampler smallMeanPoissonSampler;
97  
98      /**
99       * @param rng Generator of uniformly distributed random numbers.
100      * @param mean Mean.
101      * @throws IllegalArgumentException if {@code mean <= 0} or
102      * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
103      */
104     public LargeMeanPoissonSampler(UniformRandomProvider rng,
105                                    double mean) {
106         if (mean <= 0) {
107           throw new IllegalArgumentException(mean + " <= " + 0);
108         }
109         // The algorithm is not valid if Math.floor(mean) is not an integer.
110         if (mean > MAX_MEAN) {
111             throw new IllegalArgumentException(mean + " > " + MAX_MEAN);
112         }
113         this.rng = rng;
114 
115         gaussian = new ZigguratNormalizedGaussianSampler(rng);
116         exponential = new AhrensDieterExponentialSampler(rng, 1);
117         // Plain constructor uses the uncached function.
118         factorialLog = NO_CACHE_FACTORIAL_LOG;
119 
120         // Cache values used in the algorithm
121         lambda = Math.floor(mean);
122         logLambda = Math.log(lambda);
123         logLambdaFactorial = factorialLog((int) lambda);
124         delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
125         halfDelta = delta / 2;
126         twolpd = 2 * lambda + delta;
127         c1 = 1 / (8 * lambda);
128         final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
129         final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
130         final double aSum = a1 + a2 + 1;
131         p1 = a1 / aSum;
132         p2 = a2 / aSum;
133 
134         // The algorithm requires a Poisson sample from the remaining lambda fraction.
135         final double lambdaFractional = mean - lambda;
136         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
137             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
138             new SmallMeanPoissonSampler(rng, lambdaFractional);
139     }
140 
141     /**
142      * Instantiates a sampler using a precomputed state.
143      *
144      * @param rng              Generator of uniformly distributed random numbers.
145      * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
146      * @param lambdaFractional The lambda fractional value
147      *                         ({@code mean - (int)Math.floor(mean))}.
148      * @throws IllegalArgumentException
149      *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
150      */
151     LargeMeanPoissonSampler(UniformRandomProvider rng,
152                             LargeMeanPoissonSamplerState state,
153                             double lambdaFractional) {
154         if (lambdaFractional < 0 || lambdaFractional >= 1) {
155             throw new IllegalArgumentException(
156                     "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
157         }
158         this.rng = rng;
159 
160         gaussian = new ZigguratNormalizedGaussianSampler(rng);
161         exponential = new AhrensDieterExponentialSampler(rng, 1);
162         // Plain constructor uses the uncached function.
163         factorialLog = NO_CACHE_FACTORIAL_LOG;
164 
165         // Use the state to initialise the algorithm
166         lambda = state.getLambdaRaw();
167         logLambda = state.getLogLambda();
168         logLambdaFactorial = state.getLogLambdaFactorial();
169         delta = state.getDelta();
170         halfDelta = state.getHalfDelta();
171         twolpd = state.getTwolpd();
172         p1 = state.getP1();
173         p2 = state.getP2();
174         c1 = state.getC1();
175 
176         // The algorithm requires a Poisson sample from the remaining lambda fraction.
177         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
178             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
179             new SmallMeanPoissonSampler(rng, lambdaFractional);
180     }
181 
182     /** {@inheritDoc} */
183     @Override
184     public int sample() {
185 
186         final int y2 = (smallMeanPoissonSampler == null) ?
187             0 : // No lambda fraction
188             smallMeanPoissonSampler.sample();
189 
190         double x = 0;
191         double y = 0;
192         double v = 0;
193         int a = 0;
194         double t = 0;
195         double qr = 0;
196         double qa = 0;
197         while (true) {
198             final double u = rng.nextDouble();
199             if (u <= p1) {
200                 final double n = gaussian.sample();
201                 x = n * Math.sqrt(lambda + halfDelta) - 0.5d;
202                 if (x > delta || x < -lambda) {
203                     continue;
204                 }
205                 y = x < 0 ? Math.floor(x) : Math.ceil(x);
206                 final double e = exponential.sample();
207                 v = -e - 0.5 * n * n + c1;
208             } else {
209                 if (u > p1 + p2) {
210                     y = lambda;
211                     break;
212                 }
213                 x = delta + (twolpd / delta) * exponential.sample();
214                 y = Math.ceil(x);
215                 v = -exponential.sample() - delta * (x + 1) / twolpd;
216             }
217             a = x < 0 ? 1 : 0;
218             t = y * (y + 1) / (2 * lambda);
219             if (v < -t && a == 0) {
220                 y = lambda + y;
221                 break;
222             }
223             qr = t * ((2 * y + 1) / (6 * lambda) - 1);
224             qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
225             if (v < qa) {
226                 y = lambda + y;
227                 break;
228             }
229             if (v > qr) {
230                 continue;
231             }
232             if (v < y * logLambda - factorialLog((int) (y + lambda)) + logLambdaFactorial) {
233                 y = lambda + y;
234                 break;
235             }
236         }
237 
238         return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
239     }
240 
241     /**
242      * Compute the natural logarithm of the factorial of {@code n}.
243      *
244      * @param n Argument.
245      * @return {@code log(n!)}
246      * @throws IllegalArgumentException if {@code n < 0}.
247      */
248     private double factorialLog(int n) {
249         return factorialLog.value(n);
250     }
251 
252     /** {@inheritDoc} */
253     @Override
254     public String toString() {
255         return "Large Mean Poisson deviate [" + rng.toString() + "]";
256     }
257 
258     /**
259      * Gets the initialisation state of the sampler.
260      *
261      * <p>The state is computed using an integer {@code lambda} value of
262      * {@code lambda = (int)Math.floor(mean)}.
263      *
264      * <p>The state will be suitable for reconstructing a new sampler with a mean
265      * in the range {@code lambda <= mean < lambda+1} using
266      * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
267      *
268      * @return the state
269      */
270     LargeMeanPoissonSamplerState getState() {
271         return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
272                 delta, halfDelta, twolpd, p1, p2, c1);
273     }
274 
275     /**
276      * Encapsulate the state of the sampler. The state is valid for construction of
277      * a sampler in the range {@code lambda <= mean < lambda+1}.
278      *
279      * <p>This class is immutable.
280      *
281      * @see #getLambda()
282      */
283     static final class LargeMeanPoissonSamplerState {
284         /** Algorithm constant {@code lambda}. */
285         private final double lambda;
286         /** Algorithm constant {@code logLambda}. */
287         private final double logLambda;
288         /** Algorithm constant {@code logLambdaFactorial}. */
289         private final double logLambdaFactorial;
290         /** Algorithm constant {@code delta}. */
291         private final double delta;
292         /** Algorithm constant {@code halfDelta}. */
293         private final double halfDelta;
294         /** Algorithm constant {@code twolpd}. */
295         private final double twolpd;
296         /** Algorithm constant {@code p1}. */
297         private final double p1;
298         /** Algorithm constant {@code p2}. */
299         private final double p2;
300         /** Algorithm constant {@code c1}. */
301         private final double c1;
302 
303         /**
304          * Creates the state.
305          *
306          * <p>The state is valid for construction of a sampler in the range
307          * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
308          *
309          * @param lambda the lambda
310          * @param logLambda the log lambda
311          * @param logLambdaFactorial the log lambda factorial
312          * @param delta the delta
313          * @param halfDelta the half delta
314          * @param twolpd the two lambda plus delta
315          * @param p1 the p1 constant
316          * @param p2 the p2 constant
317          * @param c1 the c1 constant
318          */
319         private LargeMeanPoissonSamplerState(double lambda, double logLambda,
320                 double logLambdaFactorial, double delta, double halfDelta, double twolpd,
321                 double p1, double p2, double c1) {
322           this.lambda = lambda;
323           this.logLambda = logLambda;
324           this.logLambdaFactorial = logLambdaFactorial;
325           this.delta = delta;
326           this.halfDelta = halfDelta;
327           this.twolpd = twolpd;
328           this.p1 = p1;
329           this.p2 = p2;
330           this.c1 = c1;
331         }
332 
333         /**
334          * Get the lambda value for the state.
335          *
336          * <p>Equal to {@code floor(mean)} for a Poisson sampler.
337          * @return the lambda value
338          */
339         int getLambda() {
340             return (int) getLambdaRaw();
341         }
342 
343         /**
344          * @return algorithm constant {@code lambda}
345          */
346         double getLambdaRaw() {
347           return lambda;
348         }
349 
350         /**
351          * @return algorithm constant {@code logLambda}
352          */
353         double getLogLambda() {
354           return logLambda;
355         }
356 
357         /**
358          * @return algorithm constant {@code logLambdaFactorial}
359          */
360         double getLogLambdaFactorial() {
361           return logLambdaFactorial;
362         }
363 
364         /**
365          * @return algorithm constant {@code delta}
366          */
367         double getDelta() {
368           return delta;
369         }
370 
371         /**
372          * @return algorithm constant {@code halfDelta}
373          */
374         double getHalfDelta() {
375           return halfDelta;
376         }
377 
378         /**
379          * @return algorithm constant {@code twolpd}
380          */
381         double getTwolpd() {
382           return twolpd;
383         }
384 
385         /**
386          * @return algorithm constant {@code p1}
387          */
388         double getP1() {
389           return p1;
390         }
391 
392         /**
393          * @return algorithm constant {@code p2}
394          */
395         double getP2() {
396           return p2;
397         }
398 
399         /**
400          * @return algorithm constant {@code c1}
401          */
402         double getC1() {
403           return c1;
404         }
405     }
406 }