1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.examples.jmh.sampling.distribution;
19
20 import org.apache.commons.math3.distribution.BinomialDistribution;
21 import org.apache.commons.math3.distribution.IntegerDistribution;
22 import org.apache.commons.math3.distribution.PoissonDistribution;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
25 import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
26 import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
27 import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
28 import org.apache.commons.rng.simple.RandomSource;
29
30 import org.openjdk.jmh.annotations.Benchmark;
31 import org.openjdk.jmh.annotations.BenchmarkMode;
32 import org.openjdk.jmh.annotations.Fork;
33 import org.openjdk.jmh.annotations.Level;
34 import org.openjdk.jmh.annotations.Measurement;
35 import org.openjdk.jmh.annotations.Mode;
36 import org.openjdk.jmh.annotations.OutputTimeUnit;
37 import org.openjdk.jmh.annotations.Param;
38 import org.openjdk.jmh.annotations.Scope;
39 import org.openjdk.jmh.annotations.Setup;
40 import org.openjdk.jmh.annotations.State;
41 import org.openjdk.jmh.annotations.Warmup;
42
43 import java.util.Arrays;
44 import java.util.concurrent.ThreadLocalRandom;
45 import java.util.concurrent.TimeUnit;
46
47
48
49
50
51 @BenchmarkMode(Mode.AverageTime)
52 @OutputTimeUnit(TimeUnit.NANOSECONDS)
53 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
54 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
55 @State(Scope.Benchmark)
56 @Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
57 public class EnumeratedDistributionSamplersPerformance {
58
59
60
61
62
/**
 * Value returned by {@link #baselineInt()}.
 *
 * <p>It is never assigned in this class, so it holds the default {@code 0}.
 * NOTE(review): presumably kept as a non-final instance field so the baseline
 * benchmark performs a real field read - confirm.</p>
 */
private int value;
64
65
66
67
68
69 @State(Scope.Benchmark)
70 public static class LocalRandomSources {
71
72
73
74
75
76
77
78
79 @Param({"WELL_44497_B",
80 "ISAAC",
81 "XO_RO_SHI_RO_128_PLUS",
82 })
83 private String randomSourceName;
84
85
86 private UniformRandomProvider generator;
87
88
89
90
91 public UniformRandomProvider getGenerator() {
92 return generator;
93 }
94
95
96 @Setup
97 public void setup() {
98 final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
99 generator = RandomSource.create(randomSource);
100 }
101 }
102
103
104
105
106
107
108
109 @State(Scope.Benchmark)
110 public abstract static class SamplerSources extends LocalRandomSources {
111
112
113
114 @Param({"BinarySearchDiscreteSampler",
115 "AliasMethodDiscreteSampler",
116 "GuideTableDiscreteSampler",
117 "MarsagliaTsangWangDiscreteSampler",
118
119
120
121
122
123
124
125
126
127
128 })
129 private String samplerType;
130
131
132 private DiscreteSamplerFactory factory;
133
134
135 private DiscreteSampler sampler;
136
137
138
139
140 interface DiscreteSamplerFactory {
141
142
143
144
145
146 DiscreteSampler create();
147 }
148
149
150
151
152
153
154 public DiscreteSampler getSampler() {
155 return sampler;
156 }
157
158
159 @Override
160 @Setup(Level.Iteration)
161 public void setup() {
162 super.setup();
163
164 final double[] probabilities = createProbabilities();
165 createSamplerFactory(getGenerator(), probabilities);
166 sampler = factory.create();
167 }
168
169
170
171
172
173
174 protected abstract double[] createProbabilities();
175
176
177
178
179
180
181
182 private void createSamplerFactory(final UniformRandomProvider rng,
183 final double[] probabilities) {
184
185 if ("BinarySearchDiscreteSampler".equals(samplerType)) {
186 factory = new DiscreteSamplerFactory() {
187 @Override
188 public DiscreteSampler create() {
189 return new BinarySearchDiscreteSampler(rng, probabilities);
190 }
191 };
192 } else if ("AliasMethodDiscreteSampler".equals(samplerType)) {
193 factory = new DiscreteSamplerFactory() {
194 @Override
195 public DiscreteSampler create() {
196 return AliasMethodDiscreteSampler.of(rng, probabilities);
197 }
198 };
199 } else if ("AliasMethodDiscreteSamplerNoPad".equals(samplerType)) {
200 factory = new DiscreteSamplerFactory() {
201 @Override
202 public DiscreteSampler create() {
203 return AliasMethodDiscreteSampler.of(rng, probabilities, -1);
204 }
205 };
206 } else if ("AliasMethodDiscreteSamplerAlpha1".equals(samplerType)) {
207 factory = new DiscreteSamplerFactory() {
208 @Override
209 public DiscreteSampler create() {
210 return AliasMethodDiscreteSampler.of(rng, probabilities, 1);
211 }
212 };
213 } else if ("AliasMethodDiscreteSamplerAlpha2".equals(samplerType)) {
214 factory = new DiscreteSamplerFactory() {
215 @Override
216 public DiscreteSampler create() {
217 return AliasMethodDiscreteSampler.of(rng, probabilities, 2);
218 }
219 };
220 } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
221 factory = new DiscreteSamplerFactory() {
222 @Override
223 public DiscreteSampler create() {
224 return GuideTableDiscreteSampler.of(rng, probabilities);
225 }
226 };
227 } else if ("GuideTableDiscreteSamplerAlpha2".equals(samplerType)) {
228 factory = new DiscreteSamplerFactory() {
229 @Override
230 public DiscreteSampler create() {
231 return GuideTableDiscreteSampler.of(rng, probabilities, 2);
232 }
233 };
234 } else if ("GuideTableDiscreteSamplerAlpha8".equals(samplerType)) {
235 factory = new DiscreteSamplerFactory() {
236 @Override
237 public DiscreteSampler create() {
238 return GuideTableDiscreteSampler.of(rng, probabilities, 8);
239 }
240 };
241 } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
242 factory = new DiscreteSamplerFactory() {
243 @Override
244 public DiscreteSampler create() {
245 return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
246 }
247 };
248 } else {
249 throw new IllegalStateException();
250 }
251 }
252
253
254
255
256
257
258 public DiscreteSampler createSampler() {
259 return factory.create();
260 }
261 }
262
263
264
265
266
267 @State(Scope.Benchmark)
268 public static class KnownDistributionSources extends SamplerSources {
269
270 private static final double CUMULATIVE_PROBABILITY_LIMIT = 1 - 1e-9;
271
272
273
274
275 @Param({"Binomial_N67_P0.7",
276 "Geometric_P0.2",
277 "4SidedLoadedDie",
278 "Poisson_Mean3.14",
279 "Poisson_Mean10_Mean20",
280 })
281 private String distribution;
282
283
284 @Override
285 protected double[] createProbabilities() {
286 if ("Binomial_N67_P0.7".equals(distribution)) {
287 final int trials = 67;
288 final double probabilityOfSuccess = 0.7;
289 final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
290 return createProbabilities(dist, 0, trials);
291 } else if ("Geometric_P0.2".equals(distribution)) {
292 final double probabilityOfSuccess = 0.2;
293 final double probabilityOfFailure = 1 - probabilityOfSuccess;
294
295
296
297 double p = 1.0;
298
299 double[] probabilities = new double[100];
300 double sum = 0;
301 int k = 0;
302 while (k < probabilities.length) {
303 probabilities[k] = p * probabilityOfSuccess;
304 sum += probabilities[k++];
305 if (sum > CUMULATIVE_PROBABILITY_LIMIT) {
306 break;
307 }
308
309 p *= probabilityOfFailure;
310 }
311 return Arrays.copyOf(probabilities, k);
312 } else if ("4SidedLoadedDie".equals(distribution)) {
313 return new double[] {1.0 / 2, 1.0 / 3, 1.0 / 12, 1.0 / 12};
314 } else if ("Poisson_Mean3.14".equals(distribution)) {
315 final double mean = 3.14;
316 final IntegerDistribution dist = createPoissonDistribution(mean);
317 final int max = dist.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
318 return createProbabilities(dist, 0, max);
319 } else if ("Poisson_Mean10_Mean20".equals(distribution)) {
320
321 final double mean1 = 10;
322 final double mean2 = 20;
323 final IntegerDistribution dist1 = createPoissonDistribution(mean2);
324 final int max = dist1.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
325 final double[] p1 = createProbabilities(dist1, 0, max);
326 final double[] p2 = createProbabilities(createPoissonDistribution(mean1), 0, max);
327 for (int i = 0; i < p1.length; i++) {
328 p1[i] += p2[i];
329 }
330
331 return p1;
332 }
333 throw new IllegalStateException();
334 }
335
336
337
338
339
340
341
342 private static IntegerDistribution createPoissonDistribution(double mean) {
343 return new PoissonDistribution(null, mean,
344 PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
345 }
346
347
348
349
350
351
352
353
354
355 private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
356 double[] probabilities = new double[upper - lower + 1];
357 int index = 0;
358 for (int x = lower; x <= upper; x++) {
359 probabilities[index++] = dist.probability(x);
360 }
361 return probabilities;
362 }
363 }
364
365
366
367
368
369
370 @State(Scope.Benchmark)
371 public static class RandomDistributionSources extends SamplerSources {
372
373
374
375
376
377 @Param({"6",
378
379
380
381 "96",
382
383
384
385 "3072"
386 })
387 private int randomNonUniformSize;
388
389
390 @Override
391 protected double[] createProbabilities() {
392 final double[] probabilities = new double[randomNonUniformSize];
393 final ThreadLocalRandom rng = ThreadLocalRandom.current();
394 for (int i = 0; i < probabilities.length; i++) {
395 probabilities[i] = rng.nextDouble();
396 }
397 return probabilities;
398 }
399 }
400
401
402
403
404 static final class BinarySearchDiscreteSampler
405 implements DiscreteSampler {
406
407 private final UniformRandomProvider rng;
408
409
410
411 private final double[] cumulativeProbabilities;
412
413
414
415
416
417
418
419
420 BinarySearchDiscreteSampler(UniformRandomProvider rng,
421 double[] probabilities) {
422
423 if (probabilities == null || probabilities.length == 0) {
424 throw new IllegalArgumentException("Probabilities must not be empty.");
425 }
426
427 final int size = probabilities.length;
428 cumulativeProbabilities = new double[size];
429
430 double sumProb = 0;
431 int count = 0;
432 for (final double prob : probabilities) {
433 if (prob < 0 ||
434 Double.isInfinite(prob) ||
435 Double.isNaN(prob)) {
436 throw new IllegalArgumentException("Invalid probability: " +
437 prob);
438 }
439
440
441 sumProb += prob;
442 cumulativeProbabilities[count++] = sumProb;
443 }
444
445 if (Double.isInfinite(sumProb) || sumProb <= 0) {
446 throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
447 }
448
449 this.rng = rng;
450
451
452 for (int i = 0; i < size; i++) {
453 final double norm = cumulativeProbabilities[i] / sumProb;
454 cumulativeProbabilities[i] = (norm < 1) ? norm : 1.0;
455 }
456 }
457
458
459 @Override
460 public int sample() {
461 final double u = rng.nextDouble();
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482 int lower = 0;
483 int upper = cumulativeProbabilities.length - 1;
484 while (lower < upper) {
485 final int mid = (lower + upper) >>> 1;
486 final double midVal = cumulativeProbabilities[mid];
487 if (u > midVal) {
488
489
490 lower = mid + 1;
491 } else {
492
493
494 upper = mid;
495 }
496 }
497 return upper;
498 }
499 }
500
501
502
503
504
505
506
507
508 @Benchmark
509 public int baselineInt() {
510 return value;
511 }
512
513
514
515
516
517
518
519
520 @Benchmark
521 public int baselineNextDouble(LocalRandomSources sources) {
522 return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
523 }
524
525
526
527
528
529
530
531 @Benchmark
532 public int sampleKnown(KnownDistributionSources sources) {
533 return sources.getSampler().sample();
534 }
535
536
537
538
539
540
541
542 @Benchmark
543 public int singleSampleKnown(KnownDistributionSources sources) {
544 return sources.createSampler().sample();
545 }
546
547
548
549
550
551
552
553 @Benchmark
554 public int sampleRandom(RandomDistributionSources sources) {
555 return sources.getSampler().sample();
556 }
557
558
559
560
561
562
563
564 @Benchmark
565 public int singleSampleRandom(RandomDistributionSources sources) {
566 return sources.createSampler().sample();
567 }
568 }