1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
import java.util.Arrays;
import java.util.Locale;

import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
27
28
29
30
31 public class GuideTableDiscreteSamplerTest {
32 @Test(expected = IllegalArgumentException.class)
33 public void testConstructorThrowsWithNullProbabilites() {
34 createSampler(null, 1.0);
35 }
36
37 @Test(expected = IllegalArgumentException.class)
38 public void testConstructorThrowsWithZeroLengthProbabilites() {
39 createSampler(new double[0], 1.0);
40 }
41
42 @Test(expected = IllegalArgumentException.class)
43 public void testConstructorThrowsWithNegativeProbabilites() {
44 createSampler(new double[] {-1, 0.1, 0.2}, 1.0);
45 }
46
47 @Test(expected = IllegalArgumentException.class)
48 public void testConstructorThrowsWithNaNProbabilites() {
49 createSampler(new double[] {0.1, Double.NaN, 0.2}, 1.0);
50 }
51
52 @Test(expected = IllegalArgumentException.class)
53 public void testConstructorThrowsWithInfiniteProbabilites() {
54 createSampler(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2}, 1.0);
55 }
56
57 @Test(expected = IllegalArgumentException.class)
58 public void testConstructorThrowsWithInfiniteSumProbabilites() {
59 createSampler(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, 1.0);
60 }
61
62 @Test(expected = IllegalArgumentException.class)
63 public void testConstructorThrowsWithZeroSumProbabilites() {
64 createSampler(new double[4], 1.0);
65 }
66
67 @Test(expected = IllegalArgumentException.class)
68 public void testConstructorThrowsWithZeroAlpha() {
69 createSampler(new double[] {0.5, 0.5}, 0.0);
70 }
71
72 @Test(expected = IllegalArgumentException.class)
73 public void testConstructorThrowsWithNegativeAlpha() {
74 createSampler(new double[] {0.5, 0.5}, -1.0);
75 }
76
77 @Test
78 public void testToString() {
79 final SharedStateDiscreteSampler sampler = createSampler(new double[] {0.5, 0.5}, 1.0);
80 Assert.assertTrue(sampler.toString().toLowerCase().contains("guide table"));
81 }
82
83
84
85
86
87
88
89 private static SharedStateDiscreteSampler createSampler(double[] probabilities, double alpha) {
90 final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
91 return GuideTableDiscreteSampler.of(rng, probabilities, alpha);
92 }
93
94
95
96
97 @Test
98 public void testBinomialSamples() {
99 final int trials = 67;
100 final double probabilityOfSuccess = 0.345;
101 final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
102 final double[] expected = new double[trials + 1];
103 for (int i = 0; i < expected.length; i++) {
104 expected[i] = dist.probability(i);
105 }
106 checkSamples(expected, 1.0);
107 }
108
109
110
111
112 @Test
113 public void testPoissonSamples() {
114 final double mean = 3.14;
115 final PoissonDistribution dist = new PoissonDistribution(null, mean,
116 PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
117 final int maxN = dist.inverseCumulativeProbability(1 - 1e-6);
118 final double[] expected = new double[maxN];
119 for (int i = 0; i < expected.length; i++) {
120 expected[i] = dist.probability(i);
121 }
122 checkSamples(expected, 1.0);
123 }
124
125
126
127
128 @Test
129 public void testNonUniformSamplesWithProbabilities() {
130 final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
131 checkSamples(expected, 1.0);
132 }
133
134
135
136
137
138 @Test
139 public void testNonUniformSamplesWithProbabilitiesWithSmallAlpha() {
140 final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
141 checkSamples(expected, 0.1);
142 }
143
144
145
146
147
148 @Test
149 public void testNonUniformSamplesWithProbabilitiesWithLargeAlpha() {
150 final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
151 checkSamples(expected, 10.0);
152 }
153
154
155
156
157
158 @Test
159 public void testNonUniformSamplesWithObservations() {
160 final double[] expected = {1, 2, 3, 1, 3};
161 checkSamples(expected, 1.0);
162 }
163
164
165
166
167
168 @Test
169 public void testNonUniformSamplesWithZeroProbabilities() {
170 final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
171 checkSamples(expected, 1.0);
172 }
173
174
175
176
177
178 @Test
179 public void testNonUniformSamplesWithZeroObservations() {
180 final double[] expected = {1, 2, 3, 0, 1, 3, 0};
181 checkSamples(expected, 1.0);
182 }
183
184
185
186
187
188 @Test
189 public void testUniformSamplesWithNoObservationLessThanTheMean() {
190 final double[] expected = {2, 2, 2, 2, 2, 2};
191 checkSamples(expected, 1.0);
192 }
193
194
195
196
197
198
199
200
201
202
203 private static void checkSamples(double[] probabilies, double alpha) {
204 final SharedStateDiscreteSampler sampler = createSampler(probabilies, alpha);
205
206 final int numberOfSamples = 10000;
207 final long[] samples = new long[probabilies.length];
208 for (int i = 0; i < numberOfSamples; i++) {
209 samples[sampler.sample()]++;
210 }
211
212
213
214 int mapSize = 0;
215 for (int i = 0; i < probabilies.length; i++) {
216 if (probabilies[i] != 0) {
217 mapSize++;
218 }
219 }
220
221 final double[] expected = new double[mapSize];
222 final long[] observed = new long[mapSize];
223 for (int i = 0; i < probabilies.length; i++) {
224 if (probabilies[i] == 0) {
225 Assert.assertEquals("No samples expected from zero probability", 0, samples[i]);
226 } else {
227
228 --mapSize;
229 expected[mapSize] = probabilies[i];
230 observed[mapSize] = samples[i];
231 }
232 }
233
234 final ChiSquareTest chiSquareTest = new ChiSquareTest();
235
236 Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
237 }
238
239
240
241
242 @Test
243 public void testSharedStateSampler() {
244 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
245 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
246 final double[] probabilities = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
247 final SharedStateDiscreteSampler sampler1 =
248 GuideTableDiscreteSampler.of(rng1, probabilities);
249 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
250 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
251 }
252 }