1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.math3.stat.inference.ChiSquareTest;
20 import org.apache.commons.rng.UniformRandomProvider;
21 import org.apache.commons.rng.core.source32.IntProvider;
22 import org.apache.commons.rng.core.source64.SplitMix64;
23 import org.apache.commons.rng.sampling.RandomAssert;
24 import org.apache.commons.rng.simple.RandomSource;
25 import org.junit.Assert;
26 import org.junit.Test;
27
28
29
30
31
32
33
34 public class MarsagliaTsangWangDiscreteSamplerTest {
35 @Test(expected = IllegalArgumentException.class)
36 public void testCreateDiscreteDistributionThrowsWithNullProbabilites() {
37 createDiscreteDistributionSampler(null);
38 }
39
40 @Test(expected = IllegalArgumentException.class)
41 public void testCreateDiscreteDistributionThrowsWithZeroLengthProbabilites() {
42 createDiscreteDistributionSampler(new double[0]);
43 }
44
45 @Test(expected = IllegalArgumentException.class)
46 public void testCreateDiscreteDistributionThrowsWithNegativeProbabilites() {
47 createDiscreteDistributionSampler(new double[] {-1, 0.1, 0.2});
48 }
49
50 @Test(expected = IllegalArgumentException.class)
51 public void testCreateDiscreteDistributionThrowsWithNaNProbabilites() {
52 createDiscreteDistributionSampler(new double[] {0.1, Double.NaN, 0.2});
53 }
54
55 @Test(expected = IllegalArgumentException.class)
56 public void testCreateDiscreteDistributionThrowsWithInfiniteProbabilites() {
57 createDiscreteDistributionSampler(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2});
58 }
59
60 @Test(expected = IllegalArgumentException.class)
61 public void testCreateDiscreteDistributionThrowsWithInfiniteSumProbabilites() {
62 createDiscreteDistributionSampler(new double[] {Double.MAX_VALUE, Double.MAX_VALUE});
63 }
64
65 @Test(expected = IllegalArgumentException.class)
66 public void testCreateDiscreteDistributionThrowsWithZeroSumProbabilites() {
67 createDiscreteDistributionSampler(new double[4]);
68 }
69
70
71
72
73 @Test
74 public void testToString() {
75 final DiscreteSampler sampler = createDiscreteDistributionSampler(new double[] {0.5, 0.5});
76 String text = sampler.toString();
77 for (String item : new String[] {"Marsaglia", "Tsang", "Wang"}) {
78 Assert.assertTrue("toString missing: " + item, text.contains(item));
79 }
80 }
81
82
83
84
85
86
87
88 private static SharedStateDiscreteSampler createDiscreteDistributionSampler(double[] probabilities) {
89 final UniformRandomProvider rng = new SplitMix64(0L);
90 return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
91 }
92
93
94
95
96
97
98
99 @Test
100 public void testOffsetSamples() {
101
102
103 final int[] prob = new int[6];
104 prob[0] = 1;
105 prob[1] = 1 + 1 << 6;
106 prob[2] = 1 + 1 << 12;
107 prob[3] = 1 + 1 << 18;
108 prob[4] = 1 + 1 << 24;
109
110 prob[5] = (1 << 30) - (prob[0] + prob[1] + prob[2] + prob[3] + prob[4]);
111
112
113
114 int n1 = 0;
115 int n2 = 0;
116 int n3 = 0;
117 int n4 = 0;
118 for (final int m : prob) {
119 n1 += getBase64Digit(m, 1);
120 n2 += getBase64Digit(m, 2);
121 n3 += getBase64Digit(m, 3);
122 n4 += getBase64Digit(m, 4);
123 }
124
125 final int t1 = n1 << 24;
126 final int t2 = t1 + (n2 << 18);
127 final int t3 = t2 + (n3 << 12);
128 final int t4 = t3 + (n4 << 6);
129
130
131 final int[] values = new int[] {0, t1, t2, t3, t4, 0xffffffff};
132 for (int i = 0; i < values.length; i++) {
133 values[i] <<= 2;
134 }
135
136 final UniformRandomProvider rng1 = new FixedSequenceIntProvider(values);
137 final UniformRandomProvider rng2 = new FixedSequenceIntProvider(values);
138 final UniformRandomProvider rng3 = new FixedSequenceIntProvider(values);
139
140
141 final int offset1 = 1;
142 final int offset2 = 1 << 8;
143 final int offset3 = 1 << 16;
144
145 final double[] p1 = createProbabilities(offset1, prob);
146 final double[] p2 = createProbabilities(offset2, prob);
147 final double[] p3 = createProbabilities(offset3, prob);
148
149 final SharedStateDiscreteSampler sampler1 = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng1, p1);
150 final SharedStateDiscreteSampler sampler2 = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng2, p2);
151 final SharedStateDiscreteSampler sampler3 = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng3, p3);
152
153 for (int i = 0; i < values.length; i++) {
154
155 final int s1 = sampler1.sample() - offset1;
156 final int s2 = sampler2.sample() - offset2;
157 final int s3 = sampler3.sample() - offset3;
158 Assert.assertEquals("Offset sample 1 and 2 do not match", s1, s2);
159 Assert.assertEquals("Offset Sample 1 and 3 do not match", s1, s3);
160 }
161 }
162
163
164
165
166
167
168
169
170 private static double[] createProbabilities(int offset, int[] prob) {
171 double[] probabilities = new double[offset + prob.length];
172 for (int i = 0; i < prob.length; i++) {
173 probabilities[i + offset] = prob[i];
174 }
175 return probabilities;
176 }
177
178
179
180
181 @Test
182 public void testRealProbabilityDistributionSamples() {
183
184 final double[] probabilities = new double[11];
185 final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
186 for (int i = 0; i < probabilities.length; i++) {
187 probabilities[i] = rng.nextDouble();
188 }
189
190
191 final UniformRandomProvider dummyRng = new FixedSequenceIntProvider(new int[] {0xffffffff});
192 final SharedStateDiscreteSampler dummySampler = MarsagliaTsangWangDiscreteSampler.Enumerated.of(dummyRng, probabilities);
193
194 dummySampler.sample();
195
196
197 final SharedStateDiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
198
199 final int numberOfSamples = 10000;
200 final long[] samples = new long[probabilities.length];
201 for (int i = 0; i < numberOfSamples; i++) {
202 samples[sampler.sample()]++;
203 }
204
205 final ChiSquareTest chiSquareTest = new ChiSquareTest();
206
207 Assert.assertFalse(chiSquareTest.chiSquareTest(probabilities, samples, 0.001));
208 }
209
210
211
212
213
214 @Test
215 public void testStorageRequirements8() {
216
217
218
219
220
221
222 checkStorageRequirements(8, 0.06);
223 }
224
225
226
227
228
229 @Test
230 public void testStorageRequirements16() {
231
232
233
234
235
236
237 checkStorageRequirements(16, 17.0);
238 }
239
240
241
242
243
244
245
246
247 private static void checkStorageRequirements(int k, double expectedLimitMB) {
248
249
250
251 final int maxSamples = 1 << k;
252
253
254
255
256 final int m = (1 << (30 - k)) - 1;
257
258
259 final long sum = (long) maxSamples * m;
260 final int total = 1 << 30;
261 Assert.assertTrue("Worst case uniform distribution is above 2^30", sum < total);
262
263
264 final int d1 = getBase64Digit(m, 1);
265 final int d2 = getBase64Digit(m, 2);
266 final int d3 = getBase64Digit(m, 3);
267 final int d4 = getBase64Digit(m, 4);
268 final int d5 = getBase64Digit(m, 5);
269
270 int bytes;
271 if (k <= 8) {
272 bytes = 1;
273 } else if (k <= 16) {
274 bytes = 2;
275 } else {
276 bytes = 4;
277 }
278 final double storageMB = bytes * 1e-6 * (d1 + d2 + d3 + d4 + d5) * maxSamples;
279 Assert.assertTrue(
280 "Worst case uniform distribution storage " + storageMB + "MB is above expected limit: " + expectedLimitMB,
281 storageMB < expectedLimitMB);
282 }
283
284
285
286
287
288
289
290
291 private static int getBase64Digit(int m, int k) {
292 return (m >>> (30 - 6 * k)) & 63;
293 }
294
295
296
297
298 @Test(expected = IllegalArgumentException.class)
299 public void testCreatePoissonDistributionThrowsWithMeanLargerThanUpperBound() {
300 final UniformRandomProvider rng = new FixedRNG();
301 final double mean = 1025;
302 @SuppressWarnings("unused")
303 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
304 }
305
306
307
308
309 @Test(expected = IllegalArgumentException.class)
310 public void testCreatePoissonDistributionThrowsWithZeroMean() {
311 final UniformRandomProvider rng = new FixedRNG();
312 final double mean = 0;
313 @SuppressWarnings("unused")
314 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
315 }
316
317
318
319
320 @Test
321 public void testCreatePoissonDistributionWithMaximumMean() {
322 final UniformRandomProvider rng = new FixedRNG();
323 final double mean = 1024;
324 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
325
326
327 sampler.sample();
328 }
329
330
331
332
333
334 @Test
335 public void testCreatePoissonDistributionWithSmallMean() {
336 final UniformRandomProvider rng = new FixedRNG();
337 final double mean = 0.25;
338 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
339
340
341 sampler.sample();
342 }
343
344
345
346
347
348
349 @Test
350 public void testCreatePoissonDistributionWithMediumMean() {
351 final UniformRandomProvider rng = new FixedRNG();
352 final double mean = 21.4;
353 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Poisson.of(rng, mean);
354
355
356 sampler.sample();
357 }
358
359
360
361
362 @Test(expected = IllegalArgumentException.class)
363 public void testCreateBinomialDistributionThrowsWithTrialsBelow0() {
364 final UniformRandomProvider rng = new FixedRNG();
365 final int trials = -1;
366 final double p = 0.5;
367 @SuppressWarnings("unused")
368 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
369 }
370
371
372
373
374 @Test(expected = IllegalArgumentException.class)
375 public void testCreateBinomialDistributionThrowsWithTrialsAboveMax() {
376 final UniformRandomProvider rng = new FixedRNG();
377 final int trials = 1 << 16;
378 final double p = 0.5;
379 @SuppressWarnings("unused")
380 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
381 }
382
383
384
385
386 @Test(expected = IllegalArgumentException.class)
387 public void testCreateBinomialDistributionThrowsWithProbabilityBelow0() {
388 final UniformRandomProvider rng = new FixedRNG();
389 final int trials = 1;
390 final double p = -0.5;
391 @SuppressWarnings("unused")
392 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
393 }
394
395
396
397
398 @Test(expected = IllegalArgumentException.class)
399 public void testCreateBinomialDistributionThrowsWithProbabilityAbove1() {
400 final UniformRandomProvider rng = new FixedRNG();
401 final int trials = 1;
402 final double p = 1.5;
403 @SuppressWarnings("unused")
404 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
405 }
406
407
408
409
410
411 @Test
412 public void testCreateBinomialDistributionWithSmallestP0ValueAndHighestProbabilityOfSuccess() {
413 final UniformRandomProvider rng = new FixedRNG();
414
415
416
417
418
419
420
421
422 final int trials = (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
423 final double p = 0.5;
424
425 Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getBinomialP0(trials, p), 0);
426 Assert.assertEquals("Invalid test set-up for p(0)", 0, getBinomialP0(trials + 1, p), 0);
427
428
429 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
430 sampler.sample();
431 }
432
433
434
435
436
437 @Test(expected = IllegalArgumentException.class)
438 public void testCreateBinomialDistributionThrowsWhenP0IsZero() {
439 final UniformRandomProvider rng = new FixedRNG();
440
441 final int trials = 1 + (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
442 final double p = 0.5;
443
444 Assert.assertEquals("Invalid test set-up for p(0)", 0, getBinomialP0(trials, p), 0);
445 @SuppressWarnings("unused")
446 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
447 }
448
449
450
451
452
453 @Test
454 public void testCreateBinomialDistributionWithLargestTrialsAndSmallestProbabilityOfSuccess() {
455 final UniformRandomProvider rng = new FixedRNG();
456
457
458
459
460
461
462
463
464 final int trials = (1 << 16) - 1;
465 double p = 1 - Math.exp(Math.log(Double.MIN_VALUE) / trials);
466
467
468 Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getBinomialP0(trials, p), 0);
469
470
471 double upper = p * 2;
472 Assert.assertEquals("Invalid test set-up for p(0)", 0, getBinomialP0(trials, upper), 0);
473
474 double lower = p;
475 while (Double.doubleToRawLongBits(lower) + 1 < Double.doubleToRawLongBits(upper)) {
476 final double mid = (upper + lower) / 2;
477 if (getBinomialP0(trials, mid) == 0) {
478 upper = mid;
479 } else {
480 lower = mid;
481 }
482 }
483 p = lower;
484
485
486 Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getBinomialP0(trials, p), 0);
487 Assert.assertEquals("Invalid test set-up for p(0)", 0, getBinomialP0(trials, Math.nextAfter(p, 1)), 0);
488
489 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
490
491 sampler.sample();
492 }
493
494
495
496
497
498
499
500
501 private static double getBinomialP0(int trials, double probabilityOfSuccess) {
502 return Math.exp(trials * Math.log(1 - probabilityOfSuccess));
503 }
504
505
506
507
508 @Test
509 public void testCreateBinomialDistributionWithProbability0() {
510 final UniformRandomProvider rng = new FixedRNG();
511 final int trials = 1000000;
512 final double p = 0;
513 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
514 for (int i = 0; i < 5; i++) {
515 Assert.assertEquals(0, sampler.sample());
516 }
517
518 Assert.assertTrue(sampler.toString().contains("Binomial"));
519 }
520
521
522
523
524
525 @Test
526 public void testCreateBinomialDistributionWithProbability1() {
527 final UniformRandomProvider rng = new FixedRNG();
528 final int trials = 1000000;
529 final double p = 1;
530 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
531 for (int i = 0; i < 5; i++) {
532 Assert.assertEquals(trials, sampler.sample());
533 }
534
535 Assert.assertTrue(sampler.toString().contains("Binomial"));
536 }
537
538
539
540
541
542
543 @Test
544 public void testCreateBinomialDistributionWithLargeNumberOfTrials() {
545 final UniformRandomProvider rng = new FixedRNG();
546 final int trials = 65000;
547 final double p = 0.01;
548 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
549
550
551 sampler.sample();
552 }
553
554
555
556
557
558 @Test
559 public void testCreateBinomialDistributionWithProbability50Percent() {
560 final UniformRandomProvider rng = new FixedRNG();
561 final int trials = 10;
562 final double p = 0.5;
563 final DiscreteSampler sampler = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p);
564
565
566 sampler.sample();
567 }
568
569
570
571
572
573 @Test
574 public void testBinomialSamplerToString() {
575 final UniformRandomProvider rng = new FixedRNG();
576 final int trials = 10;
577 final double p1 = 0.4;
578 final double p2 = 1 - p1;
579 final DiscreteSampler sampler1 = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p1);
580 final DiscreteSampler sampler2 = MarsagliaTsangWangDiscreteSampler.Binomial.of(rng, trials, p2);
581 Assert.assertEquals(sampler1.toString(), sampler2.toString());
582 }
583
584
585
586
587 @Test
588 public void testSharedStateSamplerWith8bitStorage() {
589 testSharedStateSampler(0, new int[] {1, 2, 3, 4, 5});
590 }
591
592
593
594
595 @Test
596 public void testSharedStateSamplerWith16bitStorage() {
597 testSharedStateSampler(1 << 8, new int[] {1, 2, 3, 4, 5});
598 }
599
600
601
602
603 @Test
604 public void testSharedStateSamplerWith32bitStorage() {
605 testSharedStateSampler(1 << 16, new int[] {1, 2, 3, 4, 5});
606 }
607
608
609
610
611
612
613
614
615 private static void testSharedStateSampler(int offset, int[] prob) {
616 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
617 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
618 double[] probabilities = createProbabilities(offset, prob);
619 final SharedStateDiscreteSampler sampler1 =
620 MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng1, probabilities);
621 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
622 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
623 }
624
625
626
627
628 @Test
629 public void testSharedStateSamplerWithFixedBinomialDistribution() {
630 testSharedStateSampler(10, 1.0);
631 }
632
633
634
635
636
637 @Test
638 public void testSharedStateSamplerWithInvertedBinomialDistribution() {
639 testSharedStateSampler(10, 0.999);
640 }
641
642
643
644
645
646
647
648
649 private static void testSharedStateSampler(int trials, double probabilityOfSuccess) {
650 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
651 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
652 final SharedStateDiscreteSampler sampler1 =
653 MarsagliaTsangWangDiscreteSampler.Binomial.of(rng1, trials, probabilityOfSuccess);
654 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
655 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
656 }
657
658
659
660
661 private static class FixedSequenceIntProvider extends IntProvider {
662
663 private int count;
664
665 private final int[] values;
666
667
668
669
670
671
672 FixedSequenceIntProvider(int[] values) {
673 this.values = values;
674 }
675
676 @Override
677 public int next() {
678
679 return values[count++ % values.length];
680 }
681 }
682
683
684
685
686 private static class FixedRNG extends IntProvider {
687 @Override
688 public int next() {
689 return 0xffffffff;
690 }
691 }
692 }