1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.sampling.distribution.InternalUtils.FactorialLog;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39 public class LargeMeanPoissonSampler
40 implements DiscreteSampler {
41
42 private static final double MAX_MEAN = 0.5 * Integer.MAX_VALUE;
43
44 private static final InternalUtils.FactorialLog NO_CACHE_FACTORIAL_LOG;
45
46 private static final DiscreteSampler NO_SMALL_MEAN_POISSON_SAMPLER = null;
47
48 static {
49
50 NO_CACHE_FACTORIAL_LOG = FactorialLog.create();
51 }
52
53
54 private final UniformRandomProvider rng;
55
56 private final ContinuousSampler exponential;
57
58 private final ContinuousSampler gaussian;
59
60 private final InternalUtils.FactorialLog factorialLog;
61
62
63
64
65 private final double lambda;
66
67 private final double logLambda;
68
69 private final double logLambdaFactorial;
70
71 private final double delta;
72
73 private final double halfDelta;
74
75 private final double twolpd;
76
77
78
79
80
81
82
83 private final double p1;
84
85
86
87
88
89
90
91 private final double p2;
92
93 private final double c1;
94
95
96 private final DiscreteSampler smallMeanPoissonSampler;
97
98
99
100
101
102
103
104 public LargeMeanPoissonSampler(UniformRandomProvider rng,
105 double mean) {
106 if (mean <= 0) {
107 throw new IllegalArgumentException(mean + " <= " + 0);
108 }
109
110 if (mean > MAX_MEAN) {
111 throw new IllegalArgumentException(mean + " > " + MAX_MEAN);
112 }
113 this.rng = rng;
114
115 gaussian = new ZigguratNormalizedGaussianSampler(rng);
116 exponential = new AhrensDieterExponentialSampler(rng, 1);
117
118 factorialLog = NO_CACHE_FACTORIAL_LOG;
119
120
121 lambda = Math.floor(mean);
122 logLambda = Math.log(lambda);
123 logLambdaFactorial = factorialLog((int) lambda);
124 delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
125 halfDelta = delta / 2;
126 twolpd = 2 * lambda + delta;
127 c1 = 1 / (8 * lambda);
128 final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
129 final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd);
130 final double aSum = a1 + a2 + 1;
131 p1 = a1 / aSum;
132 p2 = a2 / aSum;
133
134
135 final double lambdaFractional = mean - lambda;
136 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
137 NO_SMALL_MEAN_POISSON_SAMPLER :
138 new SmallMeanPoissonSampler(rng, lambdaFractional);
139 }
140
141
142
143
144
145
146
147
148
149
150
151 LargeMeanPoissonSampler(UniformRandomProvider rng,
152 LargeMeanPoissonSamplerState state,
153 double lambdaFractional) {
154 if (lambdaFractional < 0 || lambdaFractional >= 1) {
155 throw new IllegalArgumentException(
156 "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional);
157 }
158 this.rng = rng;
159
160 gaussian = new ZigguratNormalizedGaussianSampler(rng);
161 exponential = new AhrensDieterExponentialSampler(rng, 1);
162
163 factorialLog = NO_CACHE_FACTORIAL_LOG;
164
165
166 lambda = state.getLambdaRaw();
167 logLambda = state.getLogLambda();
168 logLambdaFactorial = state.getLogLambdaFactorial();
169 delta = state.getDelta();
170 halfDelta = state.getHalfDelta();
171 twolpd = state.getTwolpd();
172 p1 = state.getP1();
173 p2 = state.getP2();
174 c1 = state.getC1();
175
176
177 smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
178 NO_SMALL_MEAN_POISSON_SAMPLER :
179 new SmallMeanPoissonSampler(rng, lambdaFractional);
180 }
181
182
183 @Override
184 public int sample() {
185
186 final int y2 = (smallMeanPoissonSampler == null) ?
187 0 :
188 smallMeanPoissonSampler.sample();
189
190 double x = 0;
191 double y = 0;
192 double v = 0;
193 int a = 0;
194 double t = 0;
195 double qr = 0;
196 double qa = 0;
197 while (true) {
198 final double u = rng.nextDouble();
199 if (u <= p1) {
200 final double n = gaussian.sample();
201 x = n * Math.sqrt(lambda + halfDelta) - 0.5d;
202 if (x > delta || x < -lambda) {
203 continue;
204 }
205 y = x < 0 ? Math.floor(x) : Math.ceil(x);
206 final double e = exponential.sample();
207 v = -e - 0.5 * n * n + c1;
208 } else {
209 if (u > p1 + p2) {
210 y = lambda;
211 break;
212 }
213 x = delta + (twolpd / delta) * exponential.sample();
214 y = Math.ceil(x);
215 v = -exponential.sample() - delta * (x + 1) / twolpd;
216 }
217 a = x < 0 ? 1 : 0;
218 t = y * (y + 1) / (2 * lambda);
219 if (v < -t && a == 0) {
220 y = lambda + y;
221 break;
222 }
223 qr = t * ((2 * y + 1) / (6 * lambda) - 1);
224 qa = qr - (t * t) / (3 * (lambda + a * (y + 1)));
225 if (v < qa) {
226 y = lambda + y;
227 break;
228 }
229 if (v > qr) {
230 continue;
231 }
232 if (v < y * logLambda - factorialLog((int) (y + lambda)) + logLambdaFactorial) {
233 y = lambda + y;
234 break;
235 }
236 }
237
238 return (int) Math.min(y2 + (long) y, Integer.MAX_VALUE);
239 }
240
241
242
243
244
245
246
247
248 private double factorialLog(int n) {
249 return factorialLog.value(n);
250 }
251
252
253 @Override
254 public String toString() {
255 return "Large Mean Poisson deviate [" + rng.toString() + "]";
256 }
257
258
259
260
261
262
263
264
265
266
267
268
269
270 LargeMeanPoissonSamplerState getState() {
271 return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
272 delta, halfDelta, twolpd, p1, p2, c1);
273 }
274
275
276
277
278
279
280
281
282
283 static final class LargeMeanPoissonSamplerState {
284
285 private final double lambda;
286
287 private final double logLambda;
288
289 private final double logLambdaFactorial;
290
291 private final double delta;
292
293 private final double halfDelta;
294
295 private final double twolpd;
296
297 private final double p1;
298
299 private final double p2;
300
301 private final double c1;
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319 private LargeMeanPoissonSamplerState(double lambda, double logLambda,
320 double logLambdaFactorial, double delta, double halfDelta, double twolpd,
321 double p1, double p2, double c1) {
322 this.lambda = lambda;
323 this.logLambda = logLambda;
324 this.logLambdaFactorial = logLambdaFactorial;
325 this.delta = delta;
326 this.halfDelta = halfDelta;
327 this.twolpd = twolpd;
328 this.p1 = p1;
329 this.p2 = p2;
330 this.c1 = c1;
331 }
332
333
334
335
336
337
338
339 int getLambda() {
340 return (int) getLambdaRaw();
341 }
342
343
344
345
346 double getLambdaRaw() {
347 return lambda;
348 }
349
350
351
352
353 double getLogLambda() {
354 return logLambda;
355 }
356
357
358
359
360 double getLogLambdaFactorial() {
361 return logLambdaFactorial;
362 }
363
364
365
366
367 double getDelta() {
368 return delta;
369 }
370
371
372
373
374 double getHalfDelta() {
375 return halfDelta;
376 }
377
378
379
380
381 double getTwolpd() {
382 return twolpd;
383 }
384
385
386
387
388 double getP1() {
389 return p1;
390 }
391
392
393
394
395 double getP2() {
396 return p2;
397 }
398
399
400
401
402 double getC1() {
403 return c1;
404 }
405 }
406 }