View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.util.List;
20  import java.util.ArrayList;
21  import java.util.Collections;
22  
23  import org.apache.commons.math3.util.MathArrays;
24  
25  import org.apache.commons.rng.UniformRandomProvider;
26  import org.apache.commons.rng.simple.RandomSource;
27  
28  /**
29   * List of samplers.
30   */
31  public final class DiscreteSamplersList {
32      /** List of all RNGs implemented in the library. */
33      private static final List<DiscreteSamplerTestData[]> LIST =
34          new ArrayList<DiscreteSamplerTestData[]>();
35  
36      static {
37          try {
38              // This test uses reference distributions from commons-math3 to compute the expected
39              // PMF. These distributions have a dual functionality to compute the PMF and perform
40              // sampling. When no sampling is needed for the created distribution, it is advised
41              // to pass null as the random generator via the appropriate constructors to avoid the
42              // additional initialisation overhead.
43              org.apache.commons.math3.random.RandomGenerator unusedRng = null;
44  
45              // List of distributions to test.
46  
47              // Binomial ("inverse method").
48              final int trialsBinomial = 20;
49              final double probSuccessBinomial = 0.67;
50              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
51                  MathArrays.sequence(8, 9, 1),
52                  RandomSource.create(RandomSource.KISS));
53              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
54                  // range [9,16]
55                  MathArrays.sequence(8, 9, 1),
56                  MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomSource.create(RandomSource.WELL_19937_A), trialsBinomial, probSuccessBinomial));
57              // Inverted
58              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial),
59                  // range [4,11] = [20-16, 20-9]
60                  MathArrays.sequence(8, 4, 1),
61                  MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomSource.create(RandomSource.WELL_19937_C), trialsBinomial, 1 - probSuccessBinomial));
62  
63              // Geometric ("inverse method").
64              final double probSuccessGeometric = 0.21;
65              add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
66                  MathArrays.sequence(10, 0, 1),
67                  RandomSource.create(RandomSource.ISAAC));
68              // Geometric.
69              add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
70                  MathArrays.sequence(10, 0, 1),
71                  GeometricSampler.of(RandomSource.create(RandomSource.XOR_SHIFT_1024_S), probSuccessGeometric));
72  
73              // Hypergeometric ("inverse method").
74              final int popSizeHyper = 34;
75              final int numSuccessesHyper = 11;
76              final int sampleSizeHyper = 12;
77              add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(unusedRng, popSizeHyper, numSuccessesHyper, sampleSizeHyper),
78                  MathArrays.sequence(10, 0, 1),
79                  RandomSource.create(RandomSource.MT));
80  
81              // Pascal ("inverse method").
82              final int numSuccessesPascal = 6;
83              final double probSuccessPascal = 0.2;
84              add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(unusedRng, numSuccessesPascal, probSuccessPascal),
85                  MathArrays.sequence(18, 1, 1),
86                  RandomSource.create(RandomSource.TWO_CMRES));
87  
88              // Uniform ("inverse method").
89              final int loUniform = -3;
90              final int hiUniform = 4;
91              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
92                  MathArrays.sequence(8, -3, 1),
93                  RandomSource.create(RandomSource.SPLIT_MIX_64));
94              // Uniform (power of 2 range).
95              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
96                  MathArrays.sequence(8, -3, 1),
97                  DiscreteUniformSampler.of(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform));
98              // Uniform (large range).
99              final int halfMax = Integer.MAX_VALUE / 2;
100             final int hiLargeUniform = halfMax + 10;
101             final int loLargeUniform = -hiLargeUniform;
102             add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
103                 MathArrays.sequence(20, -halfMax, halfMax / 10),
104                 DiscreteUniformSampler.of(RandomSource.create(RandomSource.WELL_1024_A), loLargeUniform, hiLargeUniform));
105             // Uniform (non-power of 2 range).
106             final int rangeNonPowerOf2Uniform = 11;
107             final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
108             add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
109                 MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
110                 DiscreteUniformSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_256_SS), loUniform, hiNonPowerOf2Uniform));
111 
112             // Zipf ("inverse method").
113             final int numElementsZipf = 5;
114             final double exponentZipf = 2.345;
115             add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
116                 MathArrays.sequence(5, 1, 1),
117                 RandomSource.create(RandomSource.XOR_SHIFT_1024_S));
118             // Zipf.
119             add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
120                 MathArrays.sequence(5, 1, 1),
121                 RejectionInversionZipfSampler.of(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentZipf));
122             // Zipf (exponent close to 1).
123             final double exponentCloseToOneZipf = 1 - 1e-10;
124             add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentCloseToOneZipf),
125                 MathArrays.sequence(5, 1, 1),
126                 RejectionInversionZipfSampler.of(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentCloseToOneZipf));
127 
128             // Poisson ("inverse method").
129             final double epsilonPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_EPSILON;
130             final int maxIterationsPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_MAX_ITERATIONS;
131             final double meanPoisson = 3.21;
132             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
133                 MathArrays.sequence(10, 0, 1),
134                 RandomSource.create(RandomSource.MWC_256));
135             // Poisson.
136             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
137                 MathArrays.sequence(10, 0, 1),
138                 PoissonSampler.of(RandomSource.create(RandomSource.KISS), meanPoisson));
139             // Dedicated small mean poisson samplers
140             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
141                 MathArrays.sequence(10, 0, 1),
142                 SmallMeanPoissonSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_256_PLUS), meanPoisson));
143             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
144                 MathArrays.sequence(10, 0, 1),
145                 KempSmallMeanPoissonSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_128_PLUS), meanPoisson));
146             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
147                 MathArrays.sequence(10, 0, 1),
148                 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.create(RandomSource.XO_SHI_RO_128_PLUS), meanPoisson));
149             // LargeMeanPoissonSampler should work at small mean.
150             // Note: This hits a code path where the sample from the normal distribution is rejected.
151             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
152                 MathArrays.sequence(10, 0, 1),
153                 LargeMeanPoissonSampler.of(RandomSource.create(RandomSource.PCG_MCG_XSH_RR_32), meanPoisson));
154             // Poisson (40 < mean < 80).
155             final double largeMeanPoisson = 67.89;
156             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
157                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
158                 PoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
159             // Dedicated large mean poisson sampler
160             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
161                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
162                 LargeMeanPoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
163             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
164                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
165                 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.create(RandomSource.XO_RO_SHI_RO_128_PLUS), largeMeanPoisson));
166             // Poisson (mean >> 40).
167             final double veryLargeMeanPoisson = 543.21;
168             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
169                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
170                 PoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
171             // Dedicated large mean poisson sampler
172             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
173                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
174                 LargeMeanPoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
175             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
176                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
177                 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.create(RandomSource.XO_RO_SHI_RO_64_SS), veryLargeMeanPoisson));
178 
179             // Any discrete distribution
180             final double[] discreteProbabilities = new double[] {0.1, 0.2, 0.3, 0.4, 0.5};
181             add(LIST, discreteProbabilities,
182                 MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomSource.create(RandomSource.XO_SHI_RO_512_PLUS), discreteProbabilities));
183             add(LIST, discreteProbabilities,
184                 GuideTableDiscreteSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_512_SS), discreteProbabilities));
185             add(LIST, discreteProbabilities,
186                     AliasMethodDiscreteSampler.of(RandomSource.create(RandomSource.KISS), discreteProbabilities));
187         } catch (Exception e) {
188             // CHECKSTYLE: stop Regexp
189             System.err.println("Unexpected exception while creating the list of samplers: " + e);
190             e.printStackTrace(System.err);
191             // CHECKSTYLE: resume Regexp
192             throw new RuntimeException(e);
193         }
194     }
195 
196     /**
197      * Class contains only static methods.
198      */
199     private DiscreteSamplersList() {}
200 
201     /**
202      * @param list List of data (one the "parameters" tested by the Junit parametric test).
203      * @param dist Distribution to which the samples are supposed to conform.
204      * @param points Outcomes selection.
205      * @param rng Generator of uniformly distributed sequences.
206      */
207     private static void add(List<DiscreteSamplerTestData[]> list,
208                             final org.apache.commons.math3.distribution.IntegerDistribution dist,
209                             int[] points,
210                             UniformRandomProvider rng) {
211         final DiscreteSampler inverseMethodSampler =
212             InverseTransformDiscreteSampler.of(rng,
213                 new DiscreteInverseCumulativeProbabilityFunction() {
214                     @Override
215                     public int inverseCumulativeProbability(double p) {
216                         return dist.inverseCumulativeProbability(p);
217                     }
218                     @Override
219                     public String toString() {
220                         return dist.toString();
221                     }
222                 });
223         list.add(new DiscreteSamplerTestData[] {new DiscreteSamplerTestData(inverseMethodSampler,
224                                                                             points,
225                                                                             getProbabilities(dist, points))});
226     }
227 
228     /**
229      * @param list List of data (one the "parameters" tested by the Junit parametric test).
230      * @param dist Distribution to which the samples are supposed to conform.
231      * @param points Outcomes selection.
232      * @param sampler Sampler.
233      */
234     private static void add(List<DiscreteSamplerTestData[]> list,
235                             final org.apache.commons.math3.distribution.IntegerDistribution dist,
236                             int[] points,
237                             final DiscreteSampler sampler) {
238         list.add(new DiscreteSamplerTestData[] {new DiscreteSamplerTestData(sampler,
239                                                                             points,
240                                                                             getProbabilities(dist, points))});
241     }
242 
243     /**
244      * @param list List of data (one the "parameters" tested by the Junit parametric test).
245      * @param probabilities Probability distribution to which the samples are supposed to conform.
246      * @param sampler Sampler.
247      */
248     private static void add(List<DiscreteSamplerTestData[]> list,
249                             final double[] probabilities,
250                             final DiscreteSampler sampler) {
251         list.add(new DiscreteSamplerTestData[] {new DiscreteSamplerTestData(sampler,
252                                                                             MathArrays.natural(probabilities.length),
253                                                                             probabilities)});
254     }
255 
256     /**
257      * Subclasses that are "parametric" tests can forward the call to
258      * the "@Parameters"-annotated method to this method.
259      *
260      * @return the list of all generators.
261      */
262     public static Iterable<DiscreteSamplerTestData[]> list() {
263         return Collections.unmodifiableList(LIST);
264     }
265 
266     /**
267      * @param dist Distribution.
268      * @param points Points.
269      * @return the probabilities of the given points according to the distribution.
270      */
271     private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
272                                              int[] points) {
273         final int len = points.length;
274         final double[] prob = new double[len];
275         for (int i = 0; i < len; i++) {
276             prob[i] = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution ? // XXX Workaround.
277                 getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
278                 dist.probability(points[i]);
279 
280             if (prob[i] < 0) {
281                 throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + prob[i]);
282             }
283         }
284         return prob;
285     }
286 
287     /**
288      * Workaround bugs in Commons Math's "UniformIntegerDistribution" (cf. MATH-1396).
289      */
290     private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
291         return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
292     }
293 }