1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20 import org.apache.commons.rng.core.source32.IntProvider;
21 import org.apache.commons.rng.sampling.RandomAssert;
22 import org.apache.commons.rng.simple.RandomSource;
23 import org.junit.Assert;
24 import org.junit.Test;
25
26 import java.util.Locale;
27
28
29
30
31
32 public class DiscreteUniformSamplerTest {
33
34
35
36 @Test(expected = IllegalArgumentException.class)
37 public void testConstructorThrowsWithLowerAboveUpper() {
38 final int upper = 55;
39 final int lower = upper + 1;
40 final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
41 DiscreteUniformSampler.of(rng, lower, upper);
42 }
43
44 @Test
45 public void testSamplesWithRangeOf1() {
46 final int upper = 99;
47 final int lower = upper;
48 final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
49 final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
50 for (int i = 0; i < 5; i++) {
51 Assert.assertEquals(lower, sampler.sample());
52 }
53 }
54
55
56
57
58
59 @Test
60 public void testSamplesWithFullRange() {
61 final int upper = Integer.MAX_VALUE;
62 final int lower = Integer.MIN_VALUE;
63 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
64 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
65 final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng2, lower, upper);
66 for (int i = 0; i < 5; i++) {
67 Assert.assertEquals(rng1.nextInt(), sampler.sample());
68 }
69 }
70
71 @Test
72 public void testSamplesWithPowerOf2Range() {
73 final UniformRandomProvider rngZeroBits = new IntProvider() {
74 @Override
75 public int next() {
76 return 0;
77 }
78 };
79 final UniformRandomProvider rngAllBits = new IntProvider() {
80 @Override
81 public int next() {
82 return 0xffffffff;
83 }
84 };
85
86 final int lower = -3;
87 DiscreteUniformSampler sampler;
88
89
90
91 for (int i = 0; i < 32; i++) {
92 final int range = 1 << i;
93 final int upper = lower + range - 1;
94 sampler = new DiscreteUniformSampler(rngZeroBits, lower, upper);
95 Assert.assertEquals("Zero bits sample", lower, sampler.sample());
96 sampler = new DiscreteUniformSampler(rngAllBits, lower, upper);
97 Assert.assertEquals("All bits sample", upper, sampler.sample());
98 }
99 }
100
101 @Test
102 public void testOffsetSamplesWithNonPowerOf2Range() {
103 assertOffsetSamples(257);
104 }
105
106 @Test
107 public void testOffsetSamplesWithPowerOf2Range() {
108 assertOffsetSamples(256);
109 }
110
111 @Test
112 public void testOffsetSamplesWithRangeOf1() {
113 assertOffsetSamples(1);
114 }
115
116 private static void assertOffsetSamples(int range) {
117 final Long seed = RandomSource.createLong();
118 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
119 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
120 final UniformRandomProvider rng3 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
121
122
123 range = range - 1;
124 final int offsetLo = -13;
125 final int offsetHi = 42;
126 final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng1, 0, range);
127 final SharedStateDiscreteSampler samplerLo = DiscreteUniformSampler.of(rng2, offsetLo, offsetLo + range);
128 final SharedStateDiscreteSampler samplerHi = DiscreteUniformSampler.of(rng3, offsetHi, offsetHi + range);
129 for (int i = 0; i < 10; i++) {
130 final int sample1 = sampler.sample();
131 final int sample2 = samplerLo.sample();
132 final int sample3 = samplerHi.sample();
133 Assert.assertEquals("Incorrect negative offset sample", sample1 + offsetLo, sample2);
134 Assert.assertEquals("Incorrect positive offset sample", sample1 + offsetHi, sample3);
135 }
136 }
137
138
139
140
141 @Test
142 public void testSampleUniformityWithNonPowerOf2Range() {
143
144
145
146
147 final UniformRandomProvider rng = new IntProvider() {
148 private final int increment = 362437;
149
150 private final int start = Integer.MIN_VALUE - increment;
151
152 private int bits = start;
153
154 @Override
155 public int next() {
156
157
158
159 int result = bits += increment;
160 if (result < start) {
161 return result;
162 }
163 throw new IllegalStateException("end of sequence");
164 }
165 };
166
167
168 final int n = 37;
169 final int[] histogram = new int[n];
170
171 final int lower = 0;
172 final int upper = n - 1;
173
174 final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
175
176 try {
177 while (true) {
178 histogram[sampler.sample()]++;
179 }
180 } catch (IllegalStateException ex) {
181
182 }
183
184
185 int min = histogram[0];
186 int max = histogram[0];
187 for (int value : histogram) {
188 min = Math.min(min, value);
189 max = Math.max(max, value);
190 }
191 Assert.assertTrue("Not uniform, max = " + max + ", min=" + min, max - min <= 1);
192 }
193
194
195
196
197 @Test
198 public void testSampleUniformityWithPowerOf2Range() {
199
200 final UniformRandomProvider rng = new IntProvider() {
201 private int bits = 0;
202
203 @Override
204 public int next() {
205
206 return Integer.reverse(bits++);
207 }
208 };
209
210
211 final int n = 32;
212 final int[] histogram = new int[n];
213
214 final int lower = 0;
215 final int upper = n - 1;
216
217 final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
218
219 final int expected = 2;
220 for (int i = expected * n; i-- > 0;) {
221 histogram[sampler.sample()]++;
222 }
223
224
225 for (int value : histogram) {
226 Assert.assertEquals(expected, value);
227 }
228 }
229
230
231
232
233
234
235
236 @Test
237 public void testSampleRejectionWithNonPowerOf2Range() {
238
239
240 final int[] value = new int[1];
241 final UniformRandomProvider rng = new IntProvider() {
242 @Override
243 public int next() {
244 return value[0]++;
245 }
246 };
247
248
249
250 final int n = 37;
251 final int lower = 0;
252 final int upper = n - 1;
253
254 final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
255
256 final int sample = sampler.sample();
257
258 Assert.assertEquals("Sample is incorrect", 0, sample);
259 Assert.assertEquals("Sample should be produced from 2nd value", 2, value[0]);
260 }
261
262
263
264
265 @Test
266 public void testSharedStateSamplerWithSmallRange() {
267 testSharedStateSampler(5, 67);
268 }
269
270
271
272
273 @Test
274 public void testSharedStateSamplerWithLargeRange() {
275
276 testSharedStateSampler(Integer.MIN_VALUE / 2 - 1, Integer.MAX_VALUE / 2 + 1);
277 }
278
279
280
281
282 @Test
283 public void testSharedStateSamplerWithPowerOf2Range() {
284 testSharedStateSampler(0, 31);
285 }
286
287
288
289
290 @Test
291 public void testSharedStateSamplerWithRangeOf1() {
292 testSharedStateSampler(9, 9);
293 }
294
295
296
297
298
299
300
301 private static void testSharedStateSampler(int lower, int upper) {
302 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
303 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
304
305 final SharedStateDiscreteSampler sampler1 =
306 new DiscreteUniformSampler(rng1, lower, upper);
307 final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
308 RandomAssert.assertProduceSameSequence(sampler1, sampler2);
309 }
310
311 @Test
312 public void testToStringWithSmallRange() {
313 assertToString(5, 67);
314 }
315
316 @Test
317 public void testToStringWithLargeRange() {
318 assertToString(-99999999, Integer.MAX_VALUE);
319 }
320
321 @Test
322 public void testToStringWithPowerOf2Range() {
323
324 assertToString(0, 31);
325 }
326
327 @Test
328 public void testToStringWithRangeOf1() {
329 assertToString(9, 9);
330 }
331
332
333
334
335
336
337
338
339 private static void assertToString(int lower, int upper) {
340 final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
341 final DiscreteUniformSampler sampler =
342 new DiscreteUniformSampler(rng, lower, upper);
343 Assert.assertTrue(sampler.toString().toLowerCase(Locale.US).contains("uniform"));
344 }
345 }