View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.util.List;
20  import java.util.ArrayList;
21  import java.util.Collections;
22  
23  import org.apache.commons.math3.util.MathArrays;
24  
25  import org.apache.commons.rng.UniformRandomProvider;
26  import org.apache.commons.rng.simple.RandomSource;
27  
28  /**
29   * List of samplers.
30   */
31  public class DiscreteSamplersList {
32      /** List of all RNGs implemented in the library. */
33      private static final List<DiscreteSamplerTestData[]> LIST =
34          new ArrayList<DiscreteSamplerTestData[]>();
35  
36      static {
37          try {
38              // List of distributions to test.
39  
40              // Binomial ("inverse method").
41              final int trialsBinomial = 20;
42              final double probSuccessBinomial = 0.67;
43              add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(trialsBinomial, probSuccessBinomial),
44                  MathArrays.sequence(8, 9, 1),
45                  RandomSource.create(RandomSource.KISS));
46  
47              // Geometric ("inverse method").
48              final double probSuccessGeometric = 0.21;
49              add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(probSuccessGeometric),
50                  MathArrays.sequence(10, 0, 1),
51                  RandomSource.create(RandomSource.ISAAC));
52  
53              // Hypergeometric ("inverse method").
54              final int popSizeHyper = 34;
55              final int numSuccessesHyper = 11;
56              final int sampleSizeHyper = 12;
57              add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(popSizeHyper, numSuccessesHyper, sampleSizeHyper),
58                  MathArrays.sequence(10, 0, 1),
59                  RandomSource.create(RandomSource.MT));
60  
61              // Pascal ("inverse method").
62              final int numSuccessesPascal = 6;
63              final double probSuccessPascal = 0.2;
64              add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(numSuccessesPascal, probSuccessPascal),
65                  MathArrays.sequence(18, 1, 1),
66                  RandomSource.create(RandomSource.TWO_CMRES));
67  
68              // Uniform ("inverse method").
69              final int loUniform = -3;
70              final int hiUniform = 4;
71              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loUniform, hiUniform),
72                  MathArrays.sequence(8, -3, 1),
73                  RandomSource.create(RandomSource.SPLIT_MIX_64));
74              // Uniform.
75              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loUniform, hiUniform),
76                  MathArrays.sequence(8, -3, 1),
77                  new DiscreteUniformSampler(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform));
78              // Uniform (large range).
79              final int halfMax = Integer.MAX_VALUE / 2;
80              final int hiLargeUniform = halfMax + 10;
81              final int loLargeUniform = -hiLargeUniform;
82              add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loLargeUniform, hiLargeUniform),
83                  MathArrays.sequence(20, -halfMax, halfMax / 10),
84                  new DiscreteUniformSampler(RandomSource.create(RandomSource.WELL_1024_A), loLargeUniform, hiLargeUniform));
85  
86              // Zipf ("inverse method").
87              final int numElementsZipf = 5;
88              final double exponentZipf = 2.345;
89              add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentZipf),
90                  MathArrays.sequence(5, 1, 1),
91                  RandomSource.create(RandomSource.XOR_SHIFT_1024_S));
92              // Zipf.
93              add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentZipf),
94                  MathArrays.sequence(5, 1, 1),
95                  new RejectionInversionZipfSampler(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentZipf));
96              // Zipf (exponent close to 1).
97              final double exponentCloseToOneZipf = 1 - 1e-10;
98              add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentCloseToOneZipf),
99                  MathArrays.sequence(5, 1, 1),
100                 new RejectionInversionZipfSampler(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentCloseToOneZipf));
101 
102             // Poisson ("inverse method").
103             final double meanPoisson = 3.21;
104             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
105                 MathArrays.sequence(10, 0, 1),
106                 RandomSource.create(RandomSource.MWC_256));
107             // Poisson.
108             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
109                 MathArrays.sequence(10, 0, 1),
110                 new PoissonSampler(RandomSource.create(RandomSource.KISS), meanPoisson));
111             // Dedicated small mean poisson sampler
112             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
113                 MathArrays.sequence(10, 0, 1),
114                 new SmallMeanPoissonSampler(RandomSource.create(RandomSource.KISS), meanPoisson));
115             // Poisson (40 < mean < 80).
116             final double largeMeanPoisson = 67.89;
117             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson),
118                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
119                 new PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
120             // Dedicated large mean poisson sampler
121             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson),
122                 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
123                 new LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
124             // Poisson (mean >> 40).
125             final double veryLargeMeanPoisson = 543.21;
126             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(veryLargeMeanPoisson),
127                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
128                 new PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
129             // Dedicated large mean poisson sampler
130             add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(veryLargeMeanPoisson),
131                 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
132                 new LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
133         } catch (Exception e) {
134             System.err.println("Unexpected exception while creating the list of samplers: " + e);
135             e.printStackTrace(System.err);
136             throw new RuntimeException(e);
137         }
138     }
139 
140     /**
141      * Class contains only static methods.
142      */
143     private DiscreteSamplersList() {}
144 
145     /**
146      * @param list List of data (one the "parameters" tested by the Junit parametric test).
147      * @param dist Distribution to which the samples are supposed to conform.
148      * @param points Outcomes selection.
149      * @param rng Generator of uniformly distributed sequences.
150      */
151     private static void add(List<DiscreteSamplerTestData[]> list,
152                             final org.apache.commons.math3.distribution.IntegerDistribution dist,
153                             int[] points,
154                             UniformRandomProvider rng) {
155         final DiscreteSampler inverseMethodSampler =
156             new InverseTransformDiscreteSampler(rng,
157                                                 new DiscreteInverseCumulativeProbabilityFunction() {
158                                                     @Override
159                                                     public int inverseCumulativeProbability(double p) {
160                                                         return dist.inverseCumulativeProbability(p);
161                                                     }
162                                                     @Override
163                                                     public String toString() {
164                                                         return dist.toString();
165                                                     }
166                                                 });
167         list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(inverseMethodSampler,
168                                                                              points,
169                                                                              getProbabilities(dist, points)) });
170      }
171 
172     /**
173      * @param list List of data (one the "parameters" tested by the Junit parametric test).
174      * @param dist Distribution to which the samples are supposed to conform.
175      * @param points Outcomes selection.
176      * @param sampler Sampler.
177      */
178     private static void add(List<DiscreteSamplerTestData[]> list,
179                             final org.apache.commons.math3.distribution.IntegerDistribution dist,
180                             int[] points,
181                             final DiscreteSampler sampler) {
182         list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(sampler,
183                                                                              points,
184                                                                              getProbabilities(dist, points)) });
185     }
186 
187     /**
188      * Subclasses that are "parametric" tests can forward the call to
189      * the "@Parameters"-annotated method to this method.
190      *
191      * @return the list of all generators.
192      */
193     public static Iterable<DiscreteSamplerTestData[]> list() {
194         return Collections.unmodifiableList(LIST);
195     }
196 
197     /**
198      * @param dist Distribution.
199      * @param points Points.
200      * @return the probabilities of the given points according to the distribution. 
201      */
202     private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
203                                              int[] points) {
204         final int len = points.length;
205         final double[] prob = new double[len];
206         for (int i = 0; i < len; i++) {
207             prob[i] = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution ? // XXX Workaround.
208                 getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
209                 dist.probability(points[i]);
210 
211             if (prob[i] < 0) {
212                 throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + prob[i]);
213             }
214         }
215         return prob;
216     }
217 
218     /**
219      * Workaround bugs in Commons Math's "UniformIntegerDistribution" (cf. MATH-1396).
220      */
221     private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
222         return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
223     }
224 }