View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.examples.jmh.sampling.distribution;
19  
20  import org.apache.commons.math3.distribution.BinomialDistribution;
21  import org.apache.commons.math3.distribution.IntegerDistribution;
22  import org.apache.commons.math3.distribution.PoissonDistribution;
23  import org.apache.commons.rng.UniformRandomProvider;
24  import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
25  import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
26  import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
27  import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
28  import org.apache.commons.rng.simple.RandomSource;
29  
30  import org.openjdk.jmh.annotations.Benchmark;
31  import org.openjdk.jmh.annotations.BenchmarkMode;
32  import org.openjdk.jmh.annotations.Fork;
33  import org.openjdk.jmh.annotations.Level;
34  import org.openjdk.jmh.annotations.Measurement;
35  import org.openjdk.jmh.annotations.Mode;
36  import org.openjdk.jmh.annotations.OutputTimeUnit;
37  import org.openjdk.jmh.annotations.Param;
38  import org.openjdk.jmh.annotations.Scope;
39  import org.openjdk.jmh.annotations.Setup;
40  import org.openjdk.jmh.annotations.State;
41  import org.openjdk.jmh.annotations.Warmup;
42  
43  import java.util.Arrays;
44  import java.util.concurrent.ThreadLocalRandom;
45  import java.util.concurrent.TimeUnit;
46  
47  /**
48   * Executes benchmark to compare the speed of generation of random numbers from an enumerated
49   * discrete probability distribution.
50   */
51  @BenchmarkMode(Mode.AverageTime)
52  @OutputTimeUnit(TimeUnit.NANOSECONDS)
53  @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
54  @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
55  @State(Scope.Benchmark)
56  @Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
57  public class EnumeratedDistributionSamplersPerformance {
58      /**
59       * The value for the baseline generation of an {@code int} value.
60       *
61       * <p>This must NOT be final!</p>
62       */
63      private int value;
64  
65      /**
66       * The random sources to use for testing. This is a smaller list than all the possible
67       * random sources; the list is composed of generators of different speeds.
68       */
69      @State(Scope.Benchmark)
70      public static class LocalRandomSources {
71          /**
72           * RNG providers.
73           *
74           * <p>Use different speeds.</p>
75           *
76           * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
77           *      Commons RNG user guide</a>
78           */
79          @Param({"WELL_44497_B",
80                  "ISAAC",
81                  "XO_RO_SHI_RO_128_PLUS",
82                  })
83          private String randomSourceName;
84  
85          /** RNG. */
86          private UniformRandomProvider generator;
87  
88          /**
89           * @return the RNG.
90           */
91          public UniformRandomProvider getGenerator() {
92              return generator;
93          }
94  
95          /** Create the random source. */
96          @Setup
97          public void setup() {
98              final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
99              generator = RandomSource.create(randomSource);
100         }
101     }
102 
103     /**
104      * The {@link DiscreteSampler} samplers to use for testing. Creates the sampler for each
105      * random source.
106      *
107      * <p>This class is abstract. The probability distribution is created by implementations.</p>
108      */
109     @State(Scope.Benchmark)
110     public abstract static class SamplerSources extends LocalRandomSources {
111         /**
112          * The sampler type.
113          */
114         @Param({"BinarySearchDiscreteSampler",
115                 "AliasMethodDiscreteSampler",
116                 "GuideTableDiscreteSampler",
117                 "MarsagliaTsangWangDiscreteSampler",
118 
119                 // Uncomment to test non-default parameters
120                 //"AliasMethodDiscreteSamplerNoPad", // Not optimal for sampling
121                 //"AliasMethodDiscreteSamplerAlpha1",
122                 //"AliasMethodDiscreteSamplerAlpha2",
123 
124                 // The AliasMethod memory requirement doubles for each alpha increment.
125                 // A fair comparison is to use 2^alpha for the equivalent guide table method.
126                 //"GuideTableDiscreteSamplerAlpha2",
127                 //"GuideTableDiscreteSamplerAlpha4",
128                 })
129         private String samplerType;
130 
131         /** The factory. */
132         private DiscreteSamplerFactory factory;
133 
134         /** The sampler. */
135         private DiscreteSampler sampler;
136 
137         /**
138          * A factory for creating DiscreteSampler objects.
139          */
140         interface DiscreteSamplerFactory {
141             /**
142              * Creates the sampler.
143              *
144              * @return the sampler
145              */
146             DiscreteSampler create();
147         }
148 
149         /**
150          * Gets the sampler.
151          *
152          * @return the sampler.
153          */
154         public DiscreteSampler getSampler() {
155             return sampler;
156         }
157 
158         /** Create the distribution (per iteration as it may vary) and instantiates sampler. */
159         @Override
160         @Setup(Level.Iteration)
161         public void setup() {
162             super.setup();
163 
164             final double[] probabilities = createProbabilities();
165             createSamplerFactory(getGenerator(), probabilities);
166             sampler = factory.create();
167         }
168 
169         /**
170          * Creates the probabilities for the distribution.
171          *
172          * @return The probabilities.
173          */
174         protected abstract double[] createProbabilities();
175 
176         /**
177          * Creates the sampler factory.
178          *
179          * @param rng The random generator.
180          * @param probabilities The probabilities.
181          */
182         private void createSamplerFactory(final UniformRandomProvider rng,
183             final double[] probabilities) {
184             // This would benefit from Java 8 lambda functions
185             if ("BinarySearchDiscreteSampler".equals(samplerType)) {
186                 factory = new DiscreteSamplerFactory() {
187                     @Override
188                     public DiscreteSampler create() {
189                         return new BinarySearchDiscreteSampler(rng, probabilities);
190                     }
191                 };
192             } else if ("AliasMethodDiscreteSampler".equals(samplerType)) {
193                 factory = new DiscreteSamplerFactory() {
194                     @Override
195                     public DiscreteSampler create() {
196                         return AliasMethodDiscreteSampler.of(rng, probabilities);
197                     }
198                 };
199             } else if ("AliasMethodDiscreteSamplerNoPad".equals(samplerType)) {
200                 factory = new DiscreteSamplerFactory() {
201                     @Override
202                     public DiscreteSampler create() {
203                         return AliasMethodDiscreteSampler.of(rng, probabilities, -1);
204                     }
205                 };
206             } else if ("AliasMethodDiscreteSamplerAlpha1".equals(samplerType)) {
207                 factory = new DiscreteSamplerFactory() {
208                     @Override
209                     public DiscreteSampler create() {
210                         return AliasMethodDiscreteSampler.of(rng, probabilities, 1);
211                     }
212                 };
213             } else if ("AliasMethodDiscreteSamplerAlpha2".equals(samplerType)) {
214                 factory = new DiscreteSamplerFactory() {
215                     @Override
216                     public DiscreteSampler create() {
217                         return AliasMethodDiscreteSampler.of(rng, probabilities, 2);
218                     }
219                 };
220             } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
221                 factory = new DiscreteSamplerFactory() {
222                     @Override
223                     public DiscreteSampler create() {
224                         return GuideTableDiscreteSampler.of(rng, probabilities);
225                     }
226                 };
227             } else if ("GuideTableDiscreteSamplerAlpha2".equals(samplerType)) {
228                 factory = new DiscreteSamplerFactory() {
229                     @Override
230                     public DiscreteSampler create() {
231                         return GuideTableDiscreteSampler.of(rng, probabilities, 2);
232                     }
233                 };
234             } else if ("GuideTableDiscreteSamplerAlpha8".equals(samplerType)) {
235                 factory = new DiscreteSamplerFactory() {
236                     @Override
237                     public DiscreteSampler create() {
238                         return GuideTableDiscreteSampler.of(rng, probabilities, 8);
239                     }
240                 };
241             } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
242                 factory = new DiscreteSamplerFactory() {
243                     @Override
244                     public DiscreteSampler create() {
245                         return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
246                     }
247                 };
248             } else {
249                 throw new IllegalStateException();
250             }
251         }
252 
253         /**
254          * Creates a new instance of the sampler.
255          *
256          * @return The sampler.
257          */
258         public DiscreteSampler createSampler() {
259             return factory.create();
260         }
261     }
262 
263     /**
264      * Define known probability distributions for testing. These are expected to have well
265      * behaved cumulative probability functions.
266      */
267     @State(Scope.Benchmark)
268     public static class KnownDistributionSources extends SamplerSources {
269         /** The cumulative probability limit for unbounded distributions. */
270         private static final double CUMULATIVE_PROBABILITY_LIMIT = 1 - 1e-9;
271 
272         /**
273          * The distribution.
274          */
275         @Param({"Binomial_N67_P0.7",
276                 "Geometric_P0.2",
277                 "4SidedLoadedDie",
278                 "Poisson_Mean3.14",
279                 "Poisson_Mean10_Mean20",
280                 })
281         private String distribution;
282 
283         /** {@inheritDoc} */
284         @Override
285         protected double[] createProbabilities() {
286             if ("Binomial_N67_P0.7".equals(distribution)) {
287                 final int trials = 67;
288                 final double probabilityOfSuccess = 0.7;
289                 final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
290                 return createProbabilities(dist, 0, trials);
291             } else if ("Geometric_P0.2".equals(distribution)) {
292                 final double probabilityOfSuccess = 0.2;
293                 final double probabilityOfFailure = 1 - probabilityOfSuccess;
294                 // https://en.wikipedia.org/wiki/Geometric_distribution
295                 // PMF = (1-p)^k * p
296                 // k is number of failures before a success
297                 double p = 1.0; // (1-p)^0
298                 // Build until the cumulative function is big
299                 double[] probabilities = new double[100];
300                 double sum = 0;
301                 int k = 0;
302                 while (k < probabilities.length) {
303                     probabilities[k] = p * probabilityOfSuccess;
304                     sum += probabilities[k++];
305                     if (sum > CUMULATIVE_PROBABILITY_LIMIT) {
306                         break;
307                     }
308                     // For the next PMF
309                     p *= probabilityOfFailure;
310                 }
311                 return Arrays.copyOf(probabilities, k);
312             } else if ("4SidedLoadedDie".equals(distribution)) {
313                 return new double[] {1.0 / 2, 1.0 / 3, 1.0 / 12, 1.0 / 12};
314             } else if ("Poisson_Mean3.14".equals(distribution)) {
315                 final double mean = 3.14;
316                 final IntegerDistribution dist = createPoissonDistribution(mean);
317                 final int max = dist.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
318                 return createProbabilities(dist, 0, max);
319             } else if ("Poisson_Mean10_Mean20".equals(distribution)) {
320                 // Create a Bimodel using two Poisson distributions
321                 final double mean1 = 10;
322                 final double mean2 = 20;
323                 final IntegerDistribution dist1 = createPoissonDistribution(mean2);
324                 final int max = dist1.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
325                 final double[] p1 = createProbabilities(dist1, 0, max);
326                 final double[] p2 = createProbabilities(createPoissonDistribution(mean1), 0, max);
327                 for (int i = 0; i < p1.length; i++) {
328                     p1[i] += p2[i];
329                 }
330                 // Leave to the distribution to normalise the sum
331                 return p1;
332             }
333             throw new IllegalStateException();
334         }
335 
336         /**
337          * Creates the poisson distribution.
338          *
339          * @param mean the mean
340          * @return the distribution
341          */
342         private static IntegerDistribution createPoissonDistribution(double mean) {
343             return new PoissonDistribution(null, mean,
344                 PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
345         }
346 
347         /**
348          * Creates the probabilities from the distribution.
349          *
350          * @param dist the distribution
351          * @param lower the lower bounds (inclusive)
352          * @param upper the upper bounds (inclusive)
353          * @return the probabilities
354          */
355         private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
356             double[] probabilities = new double[upper - lower + 1];
357             int index = 0;
358             for (int x = lower; x <= upper; x++) {
359                 probabilities[index++] = dist.probability(x);
360             }
361             return probabilities;
362         }
363     }
364 
365     /**
366      * Define random probability distributions of known size for testing. These are random but
367      * the average cumulative probability function will be a straight line given the increment
368      * average is 0.5.
369      */
370     @State(Scope.Benchmark)
371     public static class RandomDistributionSources extends SamplerSources {
372         /**
373          * The distribution size.
374          * These are spaced half-way between powers-of-2 to minimise the advantage of
375          * padding by the Alias method sampler.
376          */
377         @Param({"6",
378                 //"12",
379                 //"24",
380                 //"48",
381                 "96",
382                 //"192",
383                 //"384",
384                 // Above 2048 forces the Alias method to use more than 64-bits for sampling
385                 "3072"
386                 })
387         private int randomNonUniformSize;
388 
389         /** {@inheritDoc} */
390         @Override
391         protected double[] createProbabilities() {
392             final double[] probabilities = new double[randomNonUniformSize];
393             final ThreadLocalRandom rng = ThreadLocalRandom.current();
394             for (int i = 0; i < probabilities.length; i++) {
395                 probabilities[i] = rng.nextDouble();
396             }
397             return probabilities;
398         }
399     }
400 
401     /**
402      * Compute a sample by binary search of the cumulative probability distribution.
403      */
404     static final class BinarySearchDiscreteSampler
405         implements DiscreteSampler {
406         /** Underlying source of randomness. */
407         private final UniformRandomProvider rng;
408         /**
409          * The cumulative probability table.
410          */
411         private final double[] cumulativeProbabilities;
412 
413         /**
414          * @param rng Generator of uniformly distributed random numbers.
415          * @param probabilities The probabilities.
416          * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
417          * probability is negative, infinite or {@code NaN}, or the sum of all
418          * probabilities is not strictly positive.
419          */
420         BinarySearchDiscreteSampler(UniformRandomProvider rng,
421                                     double[] probabilities) {
422             // Minimal set-up validation
423             if (probabilities == null || probabilities.length == 0) {
424                 throw new IllegalArgumentException("Probabilities must not be empty.");
425             }
426 
427             final int size = probabilities.length;
428             cumulativeProbabilities = new double[size];
429 
430             double sumProb = 0;
431             int count = 0;
432             for (final double prob : probabilities) {
433                 if (prob < 0 ||
434                     Double.isInfinite(prob) ||
435                     Double.isNaN(prob)) {
436                     throw new IllegalArgumentException("Invalid probability: " +
437                                                        prob);
438                 }
439 
440                 // Compute and store cumulative probability.
441                 sumProb += prob;
442                 cumulativeProbabilities[count++] = sumProb;
443             }
444 
445             if (Double.isInfinite(sumProb) || sumProb <= 0) {
446                 throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
447             }
448 
449             this.rng = rng;
450 
451             // Normalise cumulative probability.
452             for (int i = 0; i < size; i++) {
453                 final double norm = cumulativeProbabilities[i] / sumProb;
454                 cumulativeProbabilities[i] = (norm < 1) ? norm : 1.0;
455             }
456         }
457 
458         /** {@inheritDoc} */
459         @Override
460         public int sample() {
461             final double u = rng.nextDouble();
462 
463             // Java binary search
464             //int index = Arrays.binarySearch(cumulativeProbabilities, u);
465             //if (index < 0) {
466             //    index = -index - 1;
467             //}
468             //
469             //return index < cumulativeProbabilities.length ?
470             //    index :
471             //    cumulativeProbabilities.length - 1;
472 
473             // Binary search within known cumulative probability table.
474             // Find x so that u > f[x-1] and u <= f[x].
475             // This is a looser search than Arrays.binarySearch:
476             // - The output is x = upper.
477             // - The table stores probabilities where f[0] is >= 0 and the max == 1.0.
478             // - u should be >= 0 and <= 1 (or the random generator is broken).
479             // - It avoids comparisons using Double.doubleToLongBits.
480             // - It avoids the low likelihood of equality between two doubles for fast exit
481             //   so uses only 1 compare per loop.
482             int lower = 0;
483             int upper = cumulativeProbabilities.length - 1;
484             while (lower < upper) {
485                 final int mid = (lower + upper) >>> 1;
486                 final double midVal = cumulativeProbabilities[mid];
487                 if (u > midVal) {
488                     // Change lower such that
489                     // u > f[lower - 1]
490                     lower = mid + 1;
491                 } else {
492                     // Change upper such that
493                     // u <= f[upper]
494                     upper = mid;
495                 }
496             }
497             return upper;
498         }
499     }
500 
501     // Benchmarks methods below.
502 
503     /**
504      * Baseline for the JMH timing overhead for production of an {@code int} value.
505      *
506      * @return the {@code int} value
507      */
508     @Benchmark
509     public int baselineInt() {
510         return value;
511     }
512 
513     /**
514      * Baseline for the production of a {@code double} value.
515      * This is used to assess the performance of the underlying random source.
516      *
517      * @param sources Source of randomness.
518      * @return the {@code int} value
519      */
520     @Benchmark
521     public int baselineNextDouble(LocalRandomSources sources) {
522         return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
523     }
524 
525     /**
526      * Run the sampler.
527      *
528      * @param sources Source of randomness.
529      * @return the sample value
530      */
531     @Benchmark
532     public int sampleKnown(KnownDistributionSources sources) {
533         return sources.getSampler().sample();
534     }
535 
536     /**
537      * Run the sampler.
538      *
539      * @param sources Source of randomness.
540      * @return the sample value
541      */
542     @Benchmark
543     public int singleSampleKnown(KnownDistributionSources sources) {
544         return sources.createSampler().sample();
545     }
546 
547     /**
548      * Run the sampler.
549      *
550      * @param sources Source of randomness.
551      * @return the sample value
552      */
553     @Benchmark
554     public int sampleRandom(RandomDistributionSources sources) {
555         return sources.getSampler().sample();
556     }
557 
558     /**
559      * Run the sampler.
560      *
561      * @param sources Source of randomness.
562      * @return the sample value
563      */
564     @Benchmark
565     public int singleSampleRandom(RandomDistributionSources sources) {
566         return sources.createSampler().sample();
567     }
568 }