View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.examples.jmh.distribution;
19  
20  import java.util.concurrent.TimeUnit;
21  
22  import org.apache.commons.rng.RandomProviderState;
23  import org.apache.commons.rng.RestorableUniformRandomProvider;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.sampling.PermutationSampler;
26  import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
27  import org.apache.commons.rng.sampling.distribution.PoissonSampler;
28  import org.apache.commons.rng.sampling.distribution.PoissonSamplerCache;
29  import org.apache.commons.rng.simple.RandomSource;
30  import org.openjdk.jmh.annotations.Benchmark;
31  import org.openjdk.jmh.annotations.BenchmarkMode;
32  import org.openjdk.jmh.annotations.Fork;
33  import org.openjdk.jmh.annotations.Measurement;
34  import org.openjdk.jmh.annotations.Mode;
35  import org.openjdk.jmh.annotations.OutputTimeUnit;
36  import org.openjdk.jmh.annotations.Param;
37  import org.openjdk.jmh.annotations.Scope;
38  import org.openjdk.jmh.annotations.Setup;
39  import org.openjdk.jmh.annotations.State;
40  import org.openjdk.jmh.annotations.Warmup;
41  import org.openjdk.jmh.infra.Blackhole;
42  
43  /**
44   * Executes benchmark to compare the speed of generation of Poisson random numbers when using a
45   * cache.
46   *
47   * <p>The benchmark is designed for a worse case scenario of Poisson means that are uniformly spread
48   * over a range and non-integer. A single sample is required per mean, E.g.
49   *
50   * <pre>
51   * int min = 40;
52   * int max = 1000;
53   * int range = max - min;
54   * UniformRandomProvider rng = ...;
55   *
56   * // Compare ...
57   * for (int i = 0; i < 1000; i++) {
58   *   new PoissonSampler(rng, min + rng.nextDouble() * range).sample();
59   * }
60   *
61   * // To ...
62   * PoissonSamplerCache cache = new PoissonSamplerCache(min, max);
63   * for (int i = 0; i < 1000; i++) {
64   *   PoissonSamplerCache.createPoissonSampler(rng, min + rng.nextDouble() * range).sample();
65   * }
66   * </pre>
67   *
68   * <p>The alternative scenario where the means are integer is not considered as this could be easily
69   * handled by creating an array to hold the PoissonSamplers for each mean. This does not require any
70   * specialised caching of state and is simple enough to perform for single threaded applications:
71   *
72   * <pre>
73   * public class SimpleUnsafePoissonSamplerCache {
74   *   int min = 50;
75   *   int max = 100;
76   *   PoissonSampler[] samplers = new PoissonSampler[max - min + 1];
77   *
78   *   public PoissonSampler createPoissonSampler(UniformRandomProvider rng, int mean) {
79   *     if (mean < min || mean > max) {
80   *       return new PoissonSampler(rng, mean);
81   *     }
82   *     int index = mean - min;
83   *     PoissonSampler sample = samplers[index];
84   *     if (sampler == null) {
85   *       sampler = new PoissonSampler(rng, mean);
86   *       samplers[index] = sampler;
87   *     }
88   *     return sampler;
89   *   }
90   * }
91   * </pre>
92   *
93   * Note that in this example the UniformRandomProvider is also cached and so this is only
94   * applicable to a single threaded application.
95   *
96   * <p>Re-written to use the PoissonSamplerCache would provide a new PoissonSampler per call in a
97   * thread-safe manner:
98   *
99   * <pre>
100  * public class SimplePoissonSamplerCache {
101  *   int min = 50;
102  *   int max = 100;
103  *   PoissonSamplerCache samplers = new PoissonSamplerCache(min, max);
104  *
105  *   public PoissonSampler createPoissonSampler(UniformRandomProvider rng, int mean) {
106  *       return samplers.createPoissonSampler(rng, mean);
107  *   }
108  * }
109  * </pre>
110 */
111 @BenchmarkMode(Mode.AverageTime)
112 @OutputTimeUnit(TimeUnit.MICROSECONDS)
113 @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
114 @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
115 @State(Scope.Benchmark)
116 @Fork(value = 1, jvmArgs = { "-server", "-Xms128M", "-Xmx128M" })
117 public class PoissonSamplerCachePerformance {
118     /** Number of samples per run. */
119     private static final int NUM_SAMPLES = 100000;
120     /**
121      * Number of range samples.
122      *
123      * <p>Note: The LargeMeanPoissonSampler will not use a SmallMeanPoissonSampler
124      * if the mean is an integer. This will occur if the [range sample] * range is
125      * an integer.
126      *
127      * <p>If the SmallMeanPoissonSampler is not used then the cache has more
128      * advantage over the uncached version as relatively more time is spent in
129      * initialising the algorithm.
130      *
131      * <p>To avoid this use a prime number above the maximum range
132      * (currently 4096). Any number (n/RANGE_SAMPLES) * range will not be integer
133      * with n<RANGE_SAMPLES and range<RANGE_SAMPLES (unless n==0).
134      */
135     private static final int RANGE_SAMPLE_SIZE = 4099;
136     /** The size of the seed. */
137     private static final int SEED_SIZE = 128;
138 
139     /**
140      * Seed used to ensure the tests are the same. This can be different per
141      * benchmark, but should be the same within the benchmark.
142      */
143     private static final int[] SEED;
144 
145     /**
146      * The range sample. Should contain doubles in the range 0 inclusive to 1 exclusive.
147      *
148      * <p>The range sample is used to create a mean using:
149      * rangeMin + sample * (rangeMax - rangeMin).
150      *
151      * <p>Ideally this should be large enough to fully sample the
152      * range when expressed as discrete integers, i.e. no sparseness, and random.
153      */
154     private static final double[] RANGE_SAMPLE;
155 
156     static {
157         // Build a random seed for all the tests
158         SEED = new int[SEED_SIZE];
159         final UniformRandomProvider rng = RandomSource.create(RandomSource.MWC_256);
160         for (int i = 0; i < SEED.length; i++) {
161             SEED[i] = rng.nextInt();
162         }
163 
164         final int size = RANGE_SAMPLE_SIZE;
165         final int[] sample = PermutationSampler.natural(size);
166         PermutationSampler.shuffle(rng, sample);
167 
168         RANGE_SAMPLE = new double[size];
169         for (int i = 0; i < size; i++) {
170             // Note: This will have one occurrence of zero in the range.
171             // This will create at least one LargeMeanPoissonSampler that will
172             // not use a SmallMeanPoissonSampler. The different performance of this
173             // will be lost among the other samples.
174             RANGE_SAMPLE[i] = (double) sample[i] / size;
175         }
176     }
177 
178     /**
179      * The benchmark state (retrieve the various "RandomSource"s).
180      */
181     @State(Scope.Benchmark)
182     public static class Sources {
183         /**
184          * RNG providers.
185          *
186          * <p>Use different speeds.
187          *
188          * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
189          *      Commons RNG user guide</a>
190          */
191         @Param({ "SPLIT_MIX_64",
192             // Comment in for slower generators
193             //"MWC_256", "KISS", "WELL_1024_A", "WELL_44497_B"
194             })
195         private String randomSourceName;
196 
197         /** RNG. */
198         private RestorableUniformRandomProvider generator;
199 
200         /**
201          * The state of the generator at the start of the test (for reproducible
202          * results).
203          */
204         private RandomProviderState state;
205 
206         /**
207          * @return the RNG.
208          */
209         public UniformRandomProvider getGenerator() {
210             generator.restoreState(state);
211             return generator;
212         }
213 
214         /** Instantiates generator. */
215         @Setup
216         public void setup() {
217             final RandomSource randomSource = RandomSource
218                     .valueOf(randomSourceName);
219             // Use the same seed
220             generator = RandomSource.create(randomSource, SEED.clone());
221             state = generator.saveState();
222         }
223     }
224 
225     /**
226      * The range of mean values for testing the cache.
227      */
228     @State(Scope.Benchmark)
229     public static class MeanRange {
230         /**
231          * Test range.
232          *
233          * <p>The covers the best case scenario of caching everything (range=1) and upwards
234          * in powers of 4.
235          */
236         @Param({ "1", "4", "16", "64", "256", "1024", "4096"})
237         private double range;
238 
239         /**
240          * Gets the mean.
241          *
242          * @param i the index
243          * @return the mean
244          */
245         public double getMean(int i) {
246             return getMin() + RANGE_SAMPLE[i % RANGE_SAMPLE.length] * range;
247         }
248 
249         /**
250          * Gets the min of the range.
251          *
252          * @return the min
253          */
254         public double getMin() {
255             return PoissonSamplerCache.getMinimumCachedMean();
256         }
257 
258         /**
259          * Gets the max of the range.
260          *
261          * @return the max
262          */
263         public double getMax() {
264             return getMin() + range;
265         }
266     }
267 
268     /**
269      * A factory for creating Poisson sampler objects.
270      */
271     private interface PoissonSamplerFactory {
272         /**
273          * Creates a new Poisson sampler object.
274          *
275          * @param mean the mean
276          * @return The sampler
277          */
278         DiscreteSampler createPoissonSampler(double mean);
279     }
280 
281     /**
282      * Exercises a poisson sampler created for a single use with a range of means.
283      *
284      * @param factory The factory.
285      * @param range   The range of means.
286      * @param bh      Data sink.
287      */
288     private static void runSample(PoissonSamplerFactory factory,
289                                   MeanRange range,
290                                   Blackhole bh) {
291         for (int i = 0; i < NUM_SAMPLES; i++) {
292             bh.consume(factory.createPoissonSampler(range.getMean(i)).sample());
293         }
294     }
295 
296     // Benchmarks methods below.
297 
298     /**
299      * @param sources Source of randomness.
300      * @param range   The range.
301      * @param bh      Data sink.
302      */
303     @Benchmark
304     public void runPoissonSampler(Sources sources,
305                                   MeanRange range,
306                                   Blackhole bh) {
307         final UniformRandomProvider r = sources.getGenerator();
308         final PoissonSamplerFactory factory = new PoissonSamplerFactory() {
309             @Override
310             public DiscreteSampler createPoissonSampler(double mean) {
311                 return new PoissonSampler(r, mean);
312             }
313         };
314         runSample(factory, range, bh);
315     }
316 
317     /**
318      * @param sources Source of randomness.
319      * @param range   The range.
320      * @param bh      Data sink.
321      */
322     @Benchmark
323     public void runPoissonSamplerCacheWhenEmpty(Sources sources,
324                                                 MeanRange range,
325                                                 Blackhole bh) {
326         final UniformRandomProvider r = sources.getGenerator();
327         final PoissonSamplerCache cache = new PoissonSamplerCache(0, 0);
328         final PoissonSamplerFactory factory = new PoissonSamplerFactory() {
329             @Override
330             public DiscreteSampler createPoissonSampler(double mean) {
331                 return cache.createPoissonSampler(r, mean);
332             }
333         };
334         runSample(factory, range, bh);
335     }
336 
337     /**
338      * @param sources Source of randomness.
339      * @param range   The range.
340      * @param bh      Data sink.
341      */
342     @Benchmark
343     public void runPoissonSamplerCache(Sources sources,
344                                        MeanRange range,
345                                        Blackhole bh) {
346         final UniformRandomProvider r = sources.getGenerator();
347         final PoissonSamplerCache cache = new PoissonSamplerCache(
348                 range.getMin(), range.getMax());
349         final PoissonSamplerFactory factory = new PoissonSamplerFactory() {
350             @Override
351             public DiscreteSampler createPoissonSampler(double mean) {
352                 return cache.createPoissonSampler(r, mean);
353             }
354         };
355         runSample(factory, range, bh);
356     }
357 }