View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.UniformRandomProvider;
20  import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21  
22  /**
23   * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson distribution</a>.
24   *
25   * <ul>
26   *  <li>
27   *   For large means, we use the rejection algorithm described in
28   *   <blockquote>
29   *    Devroye, Luc. (1981).<i>The Computer Generation of Poisson Random Variables</i><br>
30   *    <strong>Computing</strong> vol. 26 pp. 197-207.
31   *   </blockquote>
32   *  </li>
33   * </ul>
34   *
35   * <p>This sampler is suitable for {@code mean >= 40}.</p>
36   *
37   * <p>Sampling uses:</p>
38   *
39   * <ul>
40   *   <li>{@link UniformRandomProvider#nextLong()}
41   *   <li>{@link UniformRandomProvider#nextDouble()}
42   * </ul>
43   *
44   * @since 1.1
45   */
46  public class LargeMeanPoissonSampler
47      implements SharedStateDiscreteSampler {
48      /** Upper bound to avoid truncation. */
49      private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50      /** Class to compute {@code log(n!)}. This has no cached values. */
51      private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52      /** Used when there is no requirement for a small mean Poisson sampler. */
53      private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54          new SharedStateDiscreteSampler() {
55              @Override
56              public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57                  // No requirement for RNG
58                  return this;
59              }
60  
61              @Override
62              public int sample() {
63                  // No Poisson sample
64                  return 0;
65              }
66          };
67  
68      static {
69          // Create without a cache.
70          NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71      }
72  
73      /** Underlying source of randomness. */
74      private final UniformRandomProvider rng;
75      /** Exponential. */
76      private final SharedStateContinuousSampler exponential;
77      /** Gaussian. */
78      private final SharedStateContinuousSampler gaussian;
79      /** Local class to compute {@code log(n!)}. This may have cached values. */
80      private final InternalUtils.FactorialLog factorialLog;
81  
82      // Working values
83  
84      /** Algorithm constant: {@code Math.floor(mean)}. */
85      private final double lambda;
86      /** Algorithm constant: {@code Math.log(lambda)}. */
87      private final double logLambda;
88      /** Algorithm constant: {@code factorialLog((int) lambda)}. */
89      private final double logLambdaFactorial;
90      /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */
91      private final double delta;
92      /** Algorithm constant: {@code delta / 2}. */
93      private final double halfDelta;
94      /** Algorithm constant: {@code 2 * lambda + delta}. */
95      private final double twolpd;
96      /**
97       * Algorithm constant: {@code a1 / aSum}.
98       * <ul>
99       *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
100      *  <li>{@code aSum = a1 + a2 + 1}</li>
101      * </ul>
102      */
103     private final double p1;
104     /**
105      * Algorithm constant: {@code a2 / aSum}.
106      * <ul>
107      *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
108      *  <li>{@code aSum = a1 + a2 + 1}</li>
109      * </ul>
110      */
111     private final double p2;
112     /** Algorithm constant: {@code 1 / (8 * lambda)}. */
113     private final double c1;
114 
115     /** The internal Poisson sampler for the lambda fraction. */
116     private final SharedStateDiscreteSampler smallMeanPoissonSampler;
117 
118     /**
119      * @param rng Generator of uniformly distributed random numbers.
120      * @param mean Mean.
121      * @throws IllegalArgumentException if {@code mean < 1} or
122      * {@code mean > 0.5 *} {@link Integer#MAX_VALUE}.
123      */
124     public LargeMeanPoissonSampler(UniformRandomProvider rng,
125                                    double mean) {
126         if (mean < 1) {
127             throw new IllegalArgumentException("mean is not >= 1: " + mean);
128         }
129         // The algorithm is not valid if Math.floor(mean) is not an integer.
130         if (mean > MAX_MEAN) {
131             throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
132         }
133         this.rng = rng;
134 
135         gaussian = new ZigguratNormalizedGaussianSampler(rng);
136         exponential = AhrensDieterExponentialSampler.of(rng, 1);
137         // Plain constructor uses the uncached function.
138         factorialLog = NO_CACHE_FACTORIAL_LOG;
139 
140         // Cache values used in the algorithm
141         lambda = Math.floor(mean);
142         logLambda = Math.log(lambda);
143         logLambdaFactorial = getFactorialLog((int) lambda);
144         delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
145         halfDelta = delta / 2;
146         twolpd = 2 * lambda + delta;
147         c1 = 1 / (8 * lambda);
148         final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
149         final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
150         final double aSum = a1 + a2 + 1;
151         p1 = a1 / aSum;
152         p2 = a2 / aSum;
153 
154         // The algorithm requires a Poisson sample from the remaining lambda fraction.
155         final double lambdaFractional = mean - lambda;
156         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
157             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
158             KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
159     }
160 
161     /**
162      * Instantiates a sampler using a precomputed state.
163      *
164      * @param rng              Generator of uniformly distributed random numbers.
165      * @param state            The state for {@code lambda = (int)Math.floor(mean)}.
166      * @param lambdaFractional The lambda fractional value
167      *                         ({@code mean - (int)Math.floor(mean))}.
168      * @throws IllegalArgumentException
169      *                         if {@code lambdaFractional < 0 || lambdaFractional >= 1}.
170      */
171     LargeMeanPoissonSampler(UniformRandomProvider rng,
172                             LargeMeanPoissonSamplerState state,
173                             double lambdaFractional) {
174         if (lambdaFractional < 0 || lambdaFractional >= 1) {
175             throw new IllegalArgumentException(
176                     "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
177         }
178         this.rng = rng;
179 
180         gaussian = new ZigguratNormalizedGaussianSampler(rng);
181         exponential = AhrensDieterExponentialSampler.of(rng, 1);
182         // Plain constructor uses the uncached function.
183         factorialLog = NO_CACHE_FACTORIAL_LOG;
184 
185         // Use the state to initialise the algorithm
186         lambda = state.getLambdaRaw();
187         logLambda = state.getLogLambda();
188         logLambdaFactorial = state.getLogLambdaFactorial();
189         delta = state.getDelta();
190         halfDelta = state.getHalfDelta();
191         twolpd = state.getTwolpd();
192         p1 = state.getP1();
193         p2 = state.getP2();
194         c1 = state.getC1();
195 
196         // The algorithm requires a Poisson sample from the remaining lambda fraction.
197         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
198             NO_SMALL_MEAN_POISSON_SAMPLER : // Not used.
199             KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
200     }
201 
202     /**
203      * @param rng Generator of uniformly distributed random numbers.
204      * @param source Source to copy.
205      */
206     private LargeMeanPoissonSampler(UniformRandomProvider rng,
207                                     LargeMeanPoissonSampler source) {
208         this.rng = rng;
209 
210         gaussian = source.gaussian.withUniformRandomProvider(rng);
211         exponential = source.exponential.withUniformRandomProvider(rng);
212         // Reuse the cache
213         factorialLog = source.factorialLog;
214 
215         lambda = source.lambda;
216         logLambda = source.logLambda;
217         logLambdaFactorial = source.logLambdaFactorial;
218         delta = source.delta;
219         halfDelta = source.halfDelta;
220         twolpd = source.twolpd;
221         p1 = source.p1;
222         p2 = source.p2;
223         c1 = source.c1;
224 
225         // Share the state of the small sampler
226         smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
227     }
228 
229     /** {@inheritDoc} */
230     @Override
231     public int sample() {
232         // This will never be null. It may be a no-op delegate that returns zero.
233         final int y2 = smallMeanPoissonSampler.sample();
234 
235         double x;
236         double y;
237         double v;
238         int a;
239         double t;
240         double qr;
241         double qa;
242         while (true) {
243             // Step 1:
244             final double u = rng.nextDouble();
245             if (u <= p1) {
246                 // Step 2:
247                 final double n = gaussian.sample();
248                 x = n * Math.sqrt(lambda + halfDelta) - 0.5d;
249                 if (x > delta || x < -lambda) {
250                     continue;
251                 }
252                 y = x < 0 ? Math.floor(x) : Math.ceil(x);
253                 final double e = exponential.sample();
254                 v = -e - 0.5 * n * n + c1;
255             } else {
256                 // Step 3:
257                 if (u > p1 + p2) {
258                     y = lambda;
259                     break;
260                 }
261                 x = delta + (twolpd / delta) * exponential.sample();
262                 y = Math.ceil(x);
263                 v = -exponential.sample() - delta * (x + 1) / twolpd;
264             }
265             // The Squeeze Principle
266             // Step 4.1:
267             a = x < 0 ? 1 : 0;
268             t = y * (y + 1) / (2 * lambda);
269             // Step 4.2
270             if (v < -t && a == 0) {
271                 y = lambda + y;
272                 break;
273             }
274             // Step 4.3:
275             qr = t * ((2 * y + 1) / (6 * lambda) - 1);
276             qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
277             // Step 4.4:
278             if (v < qa) {
279                 y = lambda + y;
280                 break;
281             }
282             // Step 4.5:
283             if (v > qr) {
284                 continue;
285             }
286             // Step 4.6:
287             if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
288                 y = lambda + y;
289                 break;
290             }
291         }
292 
293         return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
294     }
295 
296     /**
297      * Compute the natural logarithm of the factorial of {@code n}.
298      *
299      * @param n Argument.
300      * @return {@code log(n!)}
301      * @throws IllegalArgumentException if {@code n < 0}.
302      */
303     private double getFactorialLog(int n) {
304         return factorialLog.value(n);
305     }
306 
307     /** {@inheritDoc} */
308     @Override
309     public String toString() {
310         return "Large Mean Poisson deviate [" + rng.toString() + "]";
311     }
312 
313     /**
314      * {@inheritDoc}
315      *
316      * @since 1.3
317      */
318     @Override
319     public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
320         return new LargeMeanPoissonSampler(rng, this);
321     }
322 
323     /**
324      * Creates a new Poisson distribution sampler.
325      *
326      * @param rng Generator of uniformly distributed random numbers.
327      * @param mean Mean.
328      * @return the sampler
329      * @throws IllegalArgumentException if {@code mean < 1} or {@code mean > 0.5 *}
330      * {@link Integer#MAX_VALUE}.
331      * @since 1.3
332      */
333     public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
334                                                 double mean) {
335         return new LargeMeanPoissonSampler(rng, mean);
336     }
337     /**
338      * Gets the initialisation state of the sampler.
339      *
340      * <p>The state is computed using an integer {@code lambda} value of
341      * {@code lambda = (int)Math.floor(mean)}.
342      *
343      * <p>The state will be suitable for reconstructing a new sampler with a mean
344      * in the range {@code lambda <= mean < lambda+1} using
345      * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
346      *
347      * @return the state
348      */
349     LargeMeanPoissonSamplerState getState() {
350         return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
351                 delta, halfDelta, twolpd, p1, p2, c1);
352     }
353 
354     /**
355      * Encapsulate the state of the sampler. The state is valid for construction of
356      * a sampler in the range {@code lambda <= mean < lambda+1}.
357      *
358      * <p>This class is immutable.
359      *
360      * @see #getLambda()
361      */
362     static final class LargeMeanPoissonSamplerState {
363         /** Algorithm constant {@code lambda}. */
364         private final double lambda;
365         /** Algorithm constant {@code logLambda}. */
366         private final double logLambda;
367         /** Algorithm constant {@code logLambdaFactorial}. */
368         private final double logLambdaFactorial;
369         /** Algorithm constant {@code delta}. */
370         private final double delta;
371         /** Algorithm constant {@code halfDelta}. */
372         private final double halfDelta;
373         /** Algorithm constant {@code twolpd}. */
374         private final double twolpd;
375         /** Algorithm constant {@code p1}. */
376         private final double p1;
377         /** Algorithm constant {@code p2}. */
378         private final double p2;
379         /** Algorithm constant {@code c1}. */
380         private final double c1;
381 
382         /**
383          * Creates the state.
384          *
385          * <p>The state is valid for construction of a sampler in the range
386          * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
387          *
388          * @param lambda the lambda
389          * @param logLambda the log lambda
390          * @param logLambdaFactorial the log lambda factorial
391          * @param delta the delta
392          * @param halfDelta the half delta
393          * @param twolpd the two lambda plus delta
394          * @param p1 the p1 constant
395          * @param p2 the p2 constant
396          * @param c1 the c1 constant
397          */
398         LargeMeanPoissonSamplerState(double lambda, double logLambda,
399                 double logLambdaFactorial, double delta, double halfDelta, double twolpd,
400                 double p1, double p2, double c1) {
401             this.lambda = lambda;
402             this.logLambda = logLambda;
403             this.logLambdaFactorial = logLambdaFactorial;
404             this.delta = delta;
405             this.halfDelta = halfDelta;
406             this.twolpd = twolpd;
407             this.p1 = p1;
408             this.p2 = p2;
409             this.c1 = c1;
410         }
411 
412         /**
413          * Get the lambda value for the state.
414          *
415          * <p>Equal to {@code floor(mean)} for a Poisson sampler.
416          * @return the lambda value
417          */
418         int getLambda() {
419             return (int) getLambdaRaw();
420         }
421 
422         /**
423          * @return algorithm constant {@code lambda}
424          */
425         double getLambdaRaw() {
426             return lambda;
427         }
428 
429         /**
430          * @return algorithm constant {@code logLambda}
431          */
432         double getLogLambda() {
433             return logLambda;
434         }
435 
436         /**
437          * @return algorithm constant {@code logLambdaFactorial}
438          */
439         double getLogLambdaFactorial() {
440             return logLambdaFactorial;
441         }
442 
443         /**
444          * @return algorithm constant {@code delta}
445          */
446         double getDelta() {
447             return delta;
448         }
449 
450         /**
451          * @return algorithm constant {@code halfDelta}
452          */
453         double getHalfDelta() {
454             return halfDelta;
455         }
456 
457         /**
458          * @return algorithm constant {@code twolpd}
459          */
460         double getTwolpd() {
461             return twolpd;
462         }
463 
464         /**
465          * @return algorithm constant {@code p1}
466          */
467         double getP1() {
468             return p1;
469         }
470 
471         /**
472          * @return algorithm constant {@code p2}
473          */
474         double getP2() {
475             return p2;
476         }
477 
478         /**
479          * @return algorithm constant {@code c1}
480          */
481         double getC1() {
482             return c1;
483         }
484     }
485 }