1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46 public class LargeMeanPoissonSampler
47 implements SharedStateDiscreteSampler {
48
49 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
50
51 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
52
53 private static final SharedStateDiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER =
54 new SharedStateDiscreteSampler() {
55 @Override
56 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
57
58 return this;
59 }
60
61 @Override
62 public int sample() {
63
64 return 0;
65 }
66 };
67
68 static {
69
70 NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
71 }
72
73
74 private final UniformRandomProvider rng;
75
76 private final SharedStateContinuousSampler exponential;
77
78 private final SharedStateContinuousSampler gaussian;
79
80 private final InternalUtils.FactorialLog factorialLog;
81
82
83
84
85 private final double lambda;
86
87 private final double logLambda;
88
89 private final double logLambdaFactorial;
90
91 private final double delta;
92
93 private final double halfDelta;
94
95 private final double twolpd;
96
97
98
99
100
101
102
103 private final double p1;
104
105
106
107
108
109
110
111 private final double p2;
112
113 private final double c1;
114
115
116 private final SharedStateDiscreteSampler smallMeanPoissonSampler;
117
118
119
120
121
122
123
124 public LargeMeanPoissonSampler(UniformRandomProvider rng,
125 double mean) {
126 if (mean < 1) {
127 throw new IllegalArgumentException("mean is not >= 1: " + mean);
128 }
129
130 if (mean > MAX_MEAN) {
131 throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
132 }
133 this.rng = rng;
134
135 gaussian = new ZigguratNormalizedGaussianSampler(rng);
136 exponential = AhrensDieterExponentialSampler.of(rng, 1);
137
138 factorialLog = NO_CACHE_FACTORIAL_LOG;
139
140
141 lambda = Math.floor(mean);
142 logLambda = Math.log(lambda);
143 logLambdaFactorial = getFactorialLog((int) lambda);
144 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
145 halfDelta = delta / 2;
146 twolpd = 2 * lambda + delta;
147 c1 = 1 / (8 * lambda);
148 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
149 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
150 final double aSum = a1 + a2 + 1;
151 p1 = a1 / aSum;
152 p2 = a2 / aSum;
153
154
155 final double lambdaFractional = mean - lambda;
156 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
157 NO_SMALL_MEAN_POISSON_SAMPLER :
158 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
159 }
160
161
162
163
164
165
166
167
168
169
170
171 LargeMeanPoissonSampler(UniformRandomProvider rng,
172 LargeMeanPoissonSamplerState state,
173 double lambdaFractional) {
174 if (lambdaFractional < 0 || lambdaFractional >= 1) {
175 throw new IllegalArgumentException(
176 "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
177 }
178 this.rng = rng;
179
180 gaussian = new ZigguratNormalizedGaussianSampler(rng);
181 exponential = AhrensDieterExponentialSampler.of(rng, 1);
182
183 factorialLog = NO_CACHE_FACTORIAL_LOG;
184
185
186 lambda = state.getLambdaRaw();
187 logLambda = state.getLogLambda();
188 logLambdaFactorial = state.getLogLambdaFactorial();
189 delta = state.getDelta();
190 halfDelta = state.getHalfDelta();
191 twolpd = state.getTwolpd();
192 p1 = state.getP1();
193 p2 = state.getP2();
194 c1 = state.getC1();
195
196
197 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
198 NO_SMALL_MEAN_POISSON_SAMPLER :
199 KempSmallMeanPoissonSampler.of(rng, lambdaFractional);
200 }
201
202
203
204
205
206 private LargeMeanPoissonSampler(UniformRandomProvider rng,
207 LargeMeanPoissonSampler source) {
208 this.rng = rng;
209
210 gaussian = source.gaussian.withUniformRandomProvider(rng);
211 exponential = source.exponential.withUniformRandomProvider(rng);
212
213 factorialLog = source.factorialLog;
214
215 lambda = source.lambda;
216 logLambda = source.logLambda;
217 logLambdaFactorial = source.logLambdaFactorial;
218 delta = source.delta;
219 halfDelta = source.halfDelta;
220 twolpd = source.twolpd;
221 p1 = source.p1;
222 p2 = source.p2;
223 c1 = source.c1;
224
225
226 smallMeanPoissonSampler = source.smallMeanPoissonSampler.withUniformRandomProvider(rng);
227 }
228
229
230 @Override
231 public int sample() {
232
233 final int y2 = smallMeanPoissonSampler.sample();
234
235 double x;
236 double y;
237 double v;
238 int a;
239 double t;
240 double qr;
241 double qa;
242 while (true) {
243
244 final double u = rng.nextDouble();
245 if (u <= p1) {
246
247 final double n = gaussian.sample();
248 x = n * Math.sqrt(lambda + halfDelta) - 0.5d;
249 if (x > delta || x < -lambda) {
250 continue;
251 }
252 y = x < 0 ? Math.floor(x) : Math.ceil(x);
253 final double e = exponential.sample();
254 v = -e - 0.5 * n * n + c1;
255 } else {
256
257 if (u > p1 + p2) {
258 y = lambda;
259 break;
260 }
261 x = delta + (twolpd / delta) * exponential.sample();
262 y = Math.ceil(x);
263 v = -exponential.sample() - delta * (x + 1) / twolpd;
264 }
265
266
267 a = x < 0 ? 1 : 0;
268 t = y * (y + 1) / (2 * lambda);
269
270 if (v < -t && a == 0) {
271 y = lambda + y;
272 break;
273 }
274
275 qr = t * ((2 * y + 1) / (6 * lambda) - 1);
276 qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
277
278 if (v < qa) {
279 y = lambda + y;
280 break;
281 }
282
283 if (v > qr) {
284 continue;
285 }
286
287 if (v < y * logLambda - getFactorialLog((int) (y + lambda)) + logLambdaFactorial) {
288 y = lambda + y;
289 break;
290 }
291 }
292
293 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
294 }
295
296
297
298
299
300
301
302
303 private double getFactorialLog(int n) {
304 return factorialLog.value(n);
305 }
306
307
308 @Override
309 public String toString() {
310 return "Large Mean Poisson deviate [" + rng.toString() + "]";
311 }
312
313
314
315
316
317
318 @Override
319 public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
320 return new LargeMeanPoissonSampler(rng, this);
321 }
322
323
324
325
326
327
328
329
330
331
332
333 public static SharedStateDiscreteSampler of(UniformRandomProvider rng,
334 double mean) {
335 return new LargeMeanPoissonSampler(rng, mean);
336 }
337
338
339
340
341
342
343
344
345
346
347
348
349 LargeMeanPoissonSamplerState getState() {
350 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
351 delta, halfDelta, twolpd, p1, p2, c1);
352 }
353
354
355
356
357
358
359
360
361
362 static final class LargeMeanPoissonSamplerState {
363
364 private final double lambda;
365
366 private final double logLambda;
367
368 private final double logLambdaFactorial;
369
370 private final double delta;
371
372 private final double halfDelta;
373
374 private final double twolpd;
375
376 private final double p1;
377
378 private final double p2;
379
380 private final double c1;
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398 LargeMeanPoissonSamplerState(double lambda, double logLambda,
399 double logLambdaFactorial, double delta, double halfDelta, double twolpd,
400 double p1, double p2, double c1) {
401 this.lambda = lambda;
402 this.logLambda = logLambda;
403 this.logLambdaFactorial = logLambdaFactorial;
404 this.delta = delta;
405 this.halfDelta = halfDelta;
406 this.twolpd = twolpd;
407 this.p1 = p1;
408 this.p2 = p2;
409 this.c1 = c1;
410 }
411
412
413
414
415
416
417
418 int getLambda() {
419 return (int) getLambdaRaw();
420 }
421
422
423
424
425 double getLambdaRaw() {
426 return lambda;
427 }
428
429
430
431
432 double getLogLambda() {
433 return logLambda;
434 }
435
436
437
438
439 double getLogLambdaFactorial() {
440 return logLambdaFactorial;
441 }
442
443
444
445
446 double getDelta() {
447 return delta;
448 }
449
450
451
452
453 double getHalfDelta() {
454 return halfDelta;
455 }
456
457
458
459
460 double getTwolpd() {
461 return twolpd;
462 }
463
464
465
466
467 double getP1() {
468 return p1;
469 }
470
471
472
473
474 double getP2() {
475 return p2;
476 }
477
478
479
480
481 double getC1() {
482 return c1;
483 }
484 }
485 }