1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.List;
20 import java.util.ArrayList;
21 import java.util.Collections;
22
23 import org.apache.commons.math3.util.MathArrays;
24
25 import org.apache.commons.rng.UniformRandomProvider;
26 import org.apache.commons.rng.simple.RandomSource;
27
28
29
30
31 public final class DiscreteSamplersList {
32
33 private static final List<DiscreteSamplerTestData[]> LIST =
34 new ArrayList<DiscreteSamplerTestData[]>();
35
36 static {
37 try {
38
39
40
41
42
43 org.apache.commons.math3.random.RandomGenerator unusedRng = null;
44
45
46
47
48 final int trialsBinomial = 20;
49 final double probSuccessBinomial = 0.67;
50 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
51 MathArrays.sequence(8, 9, 1),
52 RandomSource.create(RandomSource.KISS));
53 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
54
55 MathArrays.sequence(8, 9, 1),
56 MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomSource.create(RandomSource.WELL_19937_A), trialsBinomial, probSuccessBinomial));
57
58 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial),
59
60 MathArrays.sequence(8, 4, 1),
61 MarsagliaTsangWangDiscreteSampler.Binomial.of(RandomSource.create(RandomSource.WELL_19937_C), trialsBinomial, 1 - probSuccessBinomial));
62
63
64 final double probSuccessGeometric = 0.21;
65 add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
66 MathArrays.sequence(10, 0, 1),
67 RandomSource.create(RandomSource.ISAAC));
68
69 add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(unusedRng, probSuccessGeometric),
70 MathArrays.sequence(10, 0, 1),
71 GeometricSampler.of(RandomSource.create(RandomSource.XOR_SHIFT_1024_S), probSuccessGeometric));
72
73
74 final int popSizeHyper = 34;
75 final int numSuccessesHyper = 11;
76 final int sampleSizeHyper = 12;
77 add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(unusedRng, popSizeHyper, numSuccessesHyper, sampleSizeHyper),
78 MathArrays.sequence(10, 0, 1),
79 RandomSource.create(RandomSource.MT));
80
81
82 final int numSuccessesPascal = 6;
83 final double probSuccessPascal = 0.2;
84 add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(unusedRng, numSuccessesPascal, probSuccessPascal),
85 MathArrays.sequence(18, 1, 1),
86 RandomSource.create(RandomSource.TWO_CMRES));
87
88
89 final int loUniform = -3;
90 final int hiUniform = 4;
91 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
92 MathArrays.sequence(8, -3, 1),
93 RandomSource.create(RandomSource.SPLIT_MIX_64));
94
95 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
96 MathArrays.sequence(8, -3, 1),
97 DiscreteUniformSampler.of(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform));
98
99 final int halfMax = Integer.MAX_VALUE / 2;
100 final int hiLargeUniform = halfMax + 10;
101 final int loLargeUniform = -hiLargeUniform;
102 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
103 MathArrays.sequence(20, -halfMax, halfMax / 10),
104 DiscreteUniformSampler.of(RandomSource.create(RandomSource.WELL_1024_A), loLargeUniform, hiLargeUniform));
105
106 final int rangeNonPowerOf2Uniform = 11;
107 final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
108 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
109 MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
110 DiscreteUniformSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_256_SS), loUniform, hiNonPowerOf2Uniform));
111
112
113 final int numElementsZipf = 5;
114 final double exponentZipf = 2.345;
115 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
116 MathArrays.sequence(5, 1, 1),
117 RandomSource.create(RandomSource.XOR_SHIFT_1024_S));
118
119 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentZipf),
120 MathArrays.sequence(5, 1, 1),
121 RejectionInversionZipfSampler.of(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentZipf));
122
123 final double exponentCloseToOneZipf = 1 - 1e-10;
124 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(unusedRng, numElementsZipf, exponentCloseToOneZipf),
125 MathArrays.sequence(5, 1, 1),
126 RejectionInversionZipfSampler.of(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentCloseToOneZipf));
127
128
129 final double epsilonPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_EPSILON;
130 final int maxIterationsPoisson = org.apache.commons.math3.distribution.PoissonDistribution.DEFAULT_MAX_ITERATIONS;
131 final double meanPoisson = 3.21;
132 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
133 MathArrays.sequence(10, 0, 1),
134 RandomSource.create(RandomSource.MWC_256));
135
136 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
137 MathArrays.sequence(10, 0, 1),
138 PoissonSampler.of(RandomSource.create(RandomSource.KISS), meanPoisson));
139
140 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
141 MathArrays.sequence(10, 0, 1),
142 SmallMeanPoissonSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_256_PLUS), meanPoisson));
143 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
144 MathArrays.sequence(10, 0, 1),
145 KempSmallMeanPoissonSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_128_PLUS), meanPoisson));
146 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
147 MathArrays.sequence(10, 0, 1),
148 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.create(RandomSource.XO_SHI_RO_128_PLUS), meanPoisson));
149
150
151 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, meanPoisson, epsilonPoisson, maxIterationsPoisson),
152 MathArrays.sequence(10, 0, 1),
153 LargeMeanPoissonSampler.of(RandomSource.create(RandomSource.PCG_MCG_XSH_RR_32), meanPoisson));
154
155 final double largeMeanPoisson = 67.89;
156 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
157 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
158 PoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
159
160 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
161 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
162 LargeMeanPoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
163 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, largeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
164 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
165 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.create(RandomSource.XO_RO_SHI_RO_128_PLUS), largeMeanPoisson));
166
167 final double veryLargeMeanPoisson = 543.21;
168 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
169 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
170 PoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
171
172 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
173 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
174 LargeMeanPoissonSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
175 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
176 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
177 MarsagliaTsangWangDiscreteSampler.Poisson.of(RandomSource.create(RandomSource.XO_RO_SHI_RO_64_SS), veryLargeMeanPoisson));
178
179
180 final double[] discreteProbabilities = new double[] {0.1, 0.2, 0.3, 0.4, 0.5};
181 add(LIST, discreteProbabilities,
182 MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomSource.create(RandomSource.XO_SHI_RO_512_PLUS), discreteProbabilities));
183 add(LIST, discreteProbabilities,
184 GuideTableDiscreteSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_512_SS), discreteProbabilities));
185 add(LIST, discreteProbabilities,
186 AliasMethodDiscreteSampler.of(RandomSource.create(RandomSource.KISS), discreteProbabilities));
187 } catch (Exception e) {
188
189 System.err.println("Unexpected exception while creating the list of samplers: " + e);
190 e.printStackTrace(System.err);
191
192 throw new RuntimeException(e);
193 }
194 }
195
196
197
198
199 private DiscreteSamplersList() {}
200
201
202
203
204
205
206
207 private static void add(List<DiscreteSamplerTestData[]> list,
208 final org.apache.commons.math3.distribution.IntegerDistribution dist,
209 int[] points,
210 UniformRandomProvider rng) {
211 final DiscreteSampler inverseMethodSampler =
212 InverseTransformDiscreteSampler.of(rng,
213 new DiscreteInverseCumulativeProbabilityFunction() {
214 @Override
215 public int inverseCumulativeProbability(double p) {
216 return dist.inverseCumulativeProbability(p);
217 }
218 @Override
219 public String toString() {
220 return dist.toString();
221 }
222 });
223 list.add(new DiscreteSamplerTestData[] {new DiscreteSamplerTestData(inverseMethodSampler,
224 points,
225 getProbabilities(dist, points))});
226 }
227
228
229
230
231
232
233
234 private static void add(List<DiscreteSamplerTestData[]> list,
235 final org.apache.commons.math3.distribution.IntegerDistribution dist,
236 int[] points,
237 final DiscreteSampler sampler) {
238 list.add(new DiscreteSamplerTestData[] {new DiscreteSamplerTestData(sampler,
239 points,
240 getProbabilities(dist, points))});
241 }
242
243
244
245
246
247
248 private static void add(List<DiscreteSamplerTestData[]> list,
249 final double[] probabilities,
250 final DiscreteSampler sampler) {
251 list.add(new DiscreteSamplerTestData[] {new DiscreteSamplerTestData(sampler,
252 MathArrays.natural(probabilities.length),
253 probabilities)});
254 }
255
256
257
258
259
260
261
262 public static Iterable<DiscreteSamplerTestData[]> list() {
263 return Collections.unmodifiableList(LIST);
264 }
265
266
267
268
269
270
271 private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
272 int[] points) {
273 final int len = points.length;
274 final double[] prob = new double[len];
275 for (int i = 0; i < len; i++) {
276 prob[i] = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution ?
277 getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
278 dist.probability(points[i]);
279
280 if (prob[i] < 0) {
281 throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + prob[i]);
282 }
283 }
284 return prob;
285 }
286
287
288
289
290 private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
291 return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
292 }
293 }