1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.List;
20 import java.util.ArrayList;
21 import java.util.Collections;
22
23 import org.apache.commons.math3.util.MathArrays;
24
25 import org.apache.commons.rng.UniformRandomProvider;
26 import org.apache.commons.rng.simple.RandomSource;
27
28
29
30
31 public class DiscreteSamplersList {
32
33 private static final List<DiscreteSamplerTestData[]> LIST =
34 new ArrayList<DiscreteSamplerTestData[]>();
35
36 static {
37 try {
38
39
40
41 final int trialsBinomial = 20;
42 final double probSuccessBinomial = 0.67;
43 add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(trialsBinomial, probSuccessBinomial),
44 MathArrays.sequence(8, 9, 1),
45 RandomSource.create(RandomSource.KISS));
46
47
48 final double probSuccessGeometric = 0.21;
49 add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(probSuccessGeometric),
50 MathArrays.sequence(10, 0, 1),
51 RandomSource.create(RandomSource.ISAAC));
52
53
54 final int popSizeHyper = 34;
55 final int numSuccessesHyper = 11;
56 final int sampleSizeHyper = 12;
57 add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(popSizeHyper, numSuccessesHyper, sampleSizeHyper),
58 MathArrays.sequence(10, 0, 1),
59 RandomSource.create(RandomSource.MT));
60
61
62 final int numSuccessesPascal = 6;
63 final double probSuccessPascal = 0.2;
64 add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(numSuccessesPascal, probSuccessPascal),
65 MathArrays.sequence(18, 1, 1),
66 RandomSource.create(RandomSource.TWO_CMRES));
67
68
69 final int loUniform = -3;
70 final int hiUniform = 4;
71 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loUniform, hiUniform),
72 MathArrays.sequence(8, -3, 1),
73 RandomSource.create(RandomSource.SPLIT_MIX_64));
74
75 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loUniform, hiUniform),
76 MathArrays.sequence(8, -3, 1),
77 new DiscreteUniformSampler(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform));
78
79 final int halfMax = Integer.MAX_VALUE / 2;
80 final int hiLargeUniform = halfMax + 10;
81 final int loLargeUniform = -hiLargeUniform;
82 add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loLargeUniform, hiLargeUniform),
83 MathArrays.sequence(20, -halfMax, halfMax / 10),
84 new DiscreteUniformSampler(RandomSource.create(RandomSource.WELL_1024_A), loLargeUniform, hiLargeUniform));
85
86
87 final int numElementsZipf = 5;
88 final double exponentZipf = 2.345;
89 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentZipf),
90 MathArrays.sequence(5, 1, 1),
91 RandomSource.create(RandomSource.XOR_SHIFT_1024_S));
92
93 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentZipf),
94 MathArrays.sequence(5, 1, 1),
95 new RejectionInversionZipfSampler(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentZipf));
96
97 final double exponentCloseToOneZipf = 1 - 1e-10;
98 add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentCloseToOneZipf),
99 MathArrays.sequence(5, 1, 1),
100 new RejectionInversionZipfSampler(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentCloseToOneZipf));
101
102
103 final double meanPoisson = 3.21;
104 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
105 MathArrays.sequence(10, 0, 1),
106 RandomSource.create(RandomSource.MWC_256));
107
108 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
109 MathArrays.sequence(10, 0, 1),
110 new PoissonSampler(RandomSource.create(RandomSource.KISS), meanPoisson));
111
112 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson),
113 MathArrays.sequence(10, 0, 1),
114 new SmallMeanPoissonSampler(RandomSource.create(RandomSource.KISS), meanPoisson));
115
116 final double largeMeanPoisson = 67.89;
117 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson),
118 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
119 new PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
120
121 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson),
122 MathArrays.sequence(50, (int) (largeMeanPoisson - 25), 1),
123 new LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson));
124
125 final double veryLargeMeanPoisson = 543.21;
126 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(veryLargeMeanPoisson),
127 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
128 new PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
129
130 add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(veryLargeMeanPoisson),
131 MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
132 new LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
133 } catch (Exception e) {
134 System.err.println("Unexpected exception while creating the list of samplers: " + e);
135 e.printStackTrace(System.err);
136 throw new RuntimeException(e);
137 }
138 }
139
140
141
142
143 private DiscreteSamplersList() {}
144
145
146
147
148
149
150
151 private static void add(List<DiscreteSamplerTestData[]> list,
152 final org.apache.commons.math3.distribution.IntegerDistribution dist,
153 int[] points,
154 UniformRandomProvider rng) {
155 final DiscreteSampler inverseMethodSampler =
156 new InverseTransformDiscreteSampler(rng,
157 new DiscreteInverseCumulativeProbabilityFunction() {
158 @Override
159 public int inverseCumulativeProbability(double p) {
160 return dist.inverseCumulativeProbability(p);
161 }
162 @Override
163 public String toString() {
164 return dist.toString();
165 }
166 });
167 list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(inverseMethodSampler,
168 points,
169 getProbabilities(dist, points)) });
170 }
171
172
173
174
175
176
177
178 private static void add(List<DiscreteSamplerTestData[]> list,
179 final org.apache.commons.math3.distribution.IntegerDistribution dist,
180 int[] points,
181 final DiscreteSampler sampler) {
182 list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(sampler,
183 points,
184 getProbabilities(dist, points)) });
185 }
186
187
188
189
190
191
192
193 public static Iterable<DiscreteSamplerTestData[]> list() {
194 return Collections.unmodifiableList(LIST);
195 }
196
197
198
199
200
201
202 private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
203 int[] points) {
204 final int len = points.length;
205 final double[] prob = new double[len];
206 for (int i = 0; i < len; i++) {
207 prob[i] = dist instanceof org.apache.commons.math3.distribution.UniformIntegerDistribution ?
208 getProbability((org.apache.commons.math3.distribution.UniformIntegerDistribution) dist) :
209 dist.probability(points[i]);
210
211 if (prob[i] < 0) {
212 throw new IllegalStateException(dist + ": p < 0 (at " + points[i] + ", p=" + prob[i]);
213 }
214 }
215 return prob;
216 }
217
218
219
220
221 private static double getProbability(org.apache.commons.math3.distribution.UniformIntegerDistribution dist) {
222 return 1 / ((double) dist.getSupportUpperBound() - (double) dist.getSupportLowerBound() + 1);
223 }
224 }