1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling;
19
20 import java.util.Arrays;
21 import java.util.Collections;
22 import java.util.HashMap;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.TreeMap;
26
27 import org.junit.Assert;
28 import org.junit.Test;
29 import org.apache.commons.rng.UniformRandomProvider;
30 import org.apache.commons.rng.simple.RandomSource;
31
32
33
34
35 public class DiscreteProbabilityCollectionSamplerTest {
36
37 private final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
38
39 @Test(expected = IllegalArgumentException.class)
40 public void testPrecondition1() {
41
42 new DiscreteProbabilityCollectionSampler<Double>(rng,
43 Arrays.asList(new Double[] {1d, 2d}),
44 new double[] {0d});
45 }
46 @Test(expected = IllegalArgumentException.class)
47 public void testPrecondition2() {
48
49 new DiscreteProbabilityCollectionSampler<Double>(rng,
50 Arrays.asList(new Double[] {1d, 2d}),
51 new double[] {0d, -1d});
52 }
53 @Test(expected = IllegalArgumentException.class)
54 public void testPrecondition3() {
55
56 new DiscreteProbabilityCollectionSampler<Double>(rng,
57 Arrays.asList(new Double[] {1d, 2d}),
58 new double[] {0d, 0d});
59 }
60 @Test(expected = IllegalArgumentException.class)
61 public void testPrecondition4() {
62
63 new DiscreteProbabilityCollectionSampler<Double>(rng,
64 Arrays.asList(new Double[] {1d, 2d}),
65 new double[] {0d, Double.NaN});
66 }
67 @Test(expected = IllegalArgumentException.class)
68 public void testPrecondition5() {
69
70 new DiscreteProbabilityCollectionSampler<Double>(rng,
71 Arrays.asList(new Double[] {1d, 2d}),
72 new double[] {0d, Double.POSITIVE_INFINITY});
73 }
74 @Test(expected = IllegalArgumentException.class)
75 public void testPrecondition6() {
76
77 new DiscreteProbabilityCollectionSampler<Double>(rng,
78 new HashMap<Double, Double>());
79 }
80 @Test(expected = IllegalArgumentException.class)
81 public void testPrecondition7() {
82
83 new DiscreteProbabilityCollectionSampler<Double>(rng,
84 Collections.<Double>emptyList(),
85 new double[0]);
86 }
87
88 @Test
89 public void testSample() {
90 final DiscreteProbabilityCollectionSampler<Double> sampler =
91 new DiscreteProbabilityCollectionSampler<Double>(rng,
92 Arrays.asList(new Double[] {3d, -1d, 3d, 7d, -2d, 8d}),
93 new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
94 final double expectedMean = 3.4;
95 final double expectedVariance = 7.84;
96
97 final int n = 100000000;
98 double sum = 0;
99 double sumOfSquares = 0;
100 for (int i = 0; i < n; i++) {
101 final double rand = sampler.sample();
102 sum += rand;
103 sumOfSquares += rand * rand;
104 }
105
106 final double mean = sum / n;
107 Assert.assertEquals(expectedMean, mean, 1e-3);
108 final double variance = sumOfSquares / n - mean * mean;
109 Assert.assertEquals(expectedVariance, variance, 2e-3);
110 }
111
112
113 @Test
114 public void testSampleUsingMap() {
115 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
116 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
117 final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
118 final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
119 final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
120 new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities);
121
122
123 final Map<Integer, Double> map = new TreeMap<Integer, Double>();
124 for (int i = 0; i < probabilities.length; i++) {
125 map.put(items.get(i), probabilities[i]);
126 }
127 final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
128 new DiscreteProbabilityCollectionSampler<Integer>(rng2, map);
129
130 for (int i = 0; i < 50; i++) {
131 Assert.assertEquals(sampler1.sample(), sampler2.sample());
132 }
133 }
134
135
136
137
138
139
140 @Test
141 public void testSampleWithProbabilityAtLastItem() {
142
143
144 final UniformRandomProvider dummyRng = new UniformRandomProvider() {
145 private int count;
146
147 public long nextLong(long n) { return 0; }
148 public long nextLong() { return 0; }
149 public int nextInt(int n) { return 0; }
150 public int nextInt() { return 0; }
151 public float nextFloat() { return 0; }
152
153 public double nextDouble() { return (count++ == 0) ? 0 : 1.0; }
154 public void nextBytes(byte[] bytes, int start, int len) {}
155 public void nextBytes(byte[] bytes) {}
156 public boolean nextBoolean() { return false; }
157
158 };
159
160 final List<Double> items = Arrays.asList(new Double[] {1d, 2d});
161 final DiscreteProbabilityCollectionSampler<Double> sampler =
162 new DiscreteProbabilityCollectionSampler<Double>(dummyRng,
163 items,
164 new double[] {0.5, 0.5});
165 final Double item1 = sampler.sample();
166 final Double item2 = sampler.sample();
167
168 Assert.assertTrue("Sample item1 is not from the list", items.contains(item1));
169 Assert.assertTrue("Sample item2 is not from the list", items.contains(item2));
170
171 Assert.assertNotSame("Item1 and 2 should be different", item1, item2);
172 }
173
174
175
176
177 @Test
178 public void testSharedStateSampler() {
179 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
180 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
181 final List<Double> items = Arrays.asList(new Double[] {1d, 2d, 3d, 4d});
182 final DiscreteProbabilityCollectionSampler<Double> sampler1 =
183 new DiscreteProbabilityCollectionSampler<Double>(rng1,
184 items,
185 new double[] {0.1, 0.2, 0.3, 0.4});
186 final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
187 RandomAssert.assertProduceSameSequence(
188 new RandomAssert.Sampler<Double>() {
189 @Override
190 public Double sample() {
191 return sampler1.sample();
192 }
193 },
194 new RandomAssert.Sampler<Double>() {
195 @Override
196 public Double sample() {
197 return sampler2.sample();
198 }
199 });
200 }
201 }