1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Set;
20 import java.util.HashSet;
21 import java.util.LinkedList;
22 import java.util.List;
23 import java.util.ListIterator;
24 import java.util.ArrayList;
25 import java.util.Collection;
26
27 import org.junit.Assert;
28 import org.junit.Test;
29
30 import org.apache.commons.math3.stat.inference.ChiSquareTest;
31
32 import org.apache.commons.rng.UniformRandomProvider;
33 import org.apache.commons.rng.simple.RandomSource;
34
35
36
37
38 public class ListSamplerTest {
39 private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 6543432321L);
40 private final ChiSquareTest chiSquareTest = new ChiSquareTest();
41
42 @Test
43 public void testSample() {
44 final String[][] c = {{"0", "1"}, {"0", "2"}, {"0", "3"}, {"0", "4"},
45 {"1", "2"}, {"1", "3"}, {"1", "4"},
46 {"2", "3"}, {"2", "4"},
47 {"3", "4"}};
48 final long[] observed = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
49 final double[] expected = {100, 100, 100, 100, 100, 100, 100, 100, 100, 100};
50
51 final HashSet<String> cPop = new HashSet<String>();
52 for (int i = 0; i < 5; i++) {
53 cPop.add(Integer.toString(i));
54 }
55
56 final List<Set<String>> sets = new ArrayList<Set<String>>();
57 for (int i = 0; i < 10; i++) {
58 final HashSet<String> hs = new HashSet<String>();
59 hs.add(c[i][0]);
60 hs.add(c[i][1]);
61 sets.add(hs);
62 }
63
64 for (int i = 0; i < 1000; i++) {
65 observed[findSample(sets, ListSampler.sample(rng, new ArrayList<String>(cPop), 2))]++;
66 }
67
68
69 Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
70 }
71
72 @Test
73 public void testSampleWhole() {
74
75 final List<String> list = new ArrayList<String>();
76 list.add("one");
77
78 final List<String> one = ListSampler.sample(rng, list, 1);
79 Assert.assertEquals(1, one.size());
80 Assert.assertTrue(one.contains("one"));
81 }
82
83 @Test(expected = IllegalArgumentException.class)
84 public void testSamplePrecondition1() {
85
86 final List<String> list = new ArrayList<String>();
87 list.add("one");
88 ListSampler.sample(rng, list, 2);
89 }
90
91 @Test(expected = IllegalArgumentException.class)
92 public void testSamplePrecondition2() {
93
94 final List<String> list = new ArrayList<String>();
95 ListSampler.sample(rng, list, 1);
96 }
97
98 @Test
99 public void testShuffle() {
100 final List<Integer> orig = new ArrayList<Integer>();
101 for (int i = 0; i < 10; i++) {
102 orig.add((i + 1) * rng.nextInt());
103 }
104
105 final List<Integer> arrayList = new ArrayList<Integer>(orig);
106
107 ListSampler.shuffle(rng, arrayList);
108
109 Assert.assertTrue("ArrayList", compare(orig, arrayList, 0, orig.size(), false));
110
111 final List<Integer> linkedList = new LinkedList<Integer>(orig);
112
113 ListSampler.shuffle(rng, linkedList);
114
115 Assert.assertTrue("LinkedList", compare(orig, linkedList, 0, orig.size(), false));
116 }
117
118 @Test
119 public void testShuffleTail() {
120 final List<Integer> orig = new ArrayList<Integer>();
121 for (int i = 0; i < 10; i++) {
122 orig.add((i + 1) * rng.nextInt());
123 }
124 final List<Integer> list = new ArrayList<Integer>(orig);
125
126 final int start = 4;
127 ListSampler.shuffle(rng, list, start, false);
128
129
130 Assert.assertTrue(compare(orig, list, 0, start, true));
131
132
133 Assert.assertTrue(compare(orig, list, start, orig.size(), false));
134 }
135
136 @Test
137 public void testShuffleHead() {
138 final List<Integer> orig = new ArrayList<Integer>();
139 for (int i = 0; i < 10; i++) {
140 orig.add((i + 1) * rng.nextInt());
141 }
142 final List<Integer> list = new ArrayList<Integer>(orig);
143
144 final int start = 4;
145 ListSampler.shuffle(rng, list, start, true);
146
147
148 Assert.assertTrue(compare(orig, list, start + 1, orig.size(), true));
149
150
151 Assert.assertTrue(compare(orig, list, 0, start + 1, false));
152 }
153
154
155
156
157
158
159 @Test
160 public void testShuffleMatchesPermutationSamplerShuffle() {
161 final List<Integer> orig = new ArrayList<Integer>();
162 for (int i = 0; i < 10; i++) {
163 orig.add((i + 1) * rng.nextInt());
164 }
165
166 assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig));
167 assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig));
168 }
169
170
171
172
173
174
175 @Test
176 public void testShuffleMatchesPermutationSamplerShuffleDirectional() {
177 final List<Integer> orig = new ArrayList<Integer>();
178 for (int i = 0; i < 10; i++) {
179 orig.add((i + 1) * rng.nextInt());
180 }
181
182 assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, true);
183 assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, false);
184 assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, true);
185 assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, false);
186 }
187
188
189
190
191
192 @Test
193 public void testShuffleWithSmallLinkedList() {
194 final int size = 3;
195 final List<Integer> orig = new ArrayList<Integer>();
196 for (int i = 0; i < size; i++) {
197 orig.add((i + 1) * rng.nextInt());
198 }
199
200
201
202
203
204 for (int i = 0; i < 10; i++) {
205 assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), size - 1, true);
206 }
207 }
208
209
210
211
212
213
214
215
216 private <T> boolean compare(List<T> orig,
217 List<T> list,
218 int start,
219 int end,
220 boolean same) {
221 for (int i = start; i < end; i++) {
222 if (!orig.get(i).equals(list.get(i))) {
223 return same ? false : true;
224 }
225 }
226 return same ? true : false;
227 }
228
229 private <T extends Set<String>> int findSample(List<T> u,
230 Collection<String> sampList) {
231 final String[] samp = sampList.toArray(new String[sampList.size()]);
232 for (int i = 0; i < u.size(); i++) {
233 final T set = u.get(i);
234 final HashSet<String> sampSet = new HashSet<String>();
235 for (int j = 0; j < samp.length; j++) {
236 sampSet.add(samp[j]);
237 }
238 if (set.equals(sampSet)) {
239 return i;
240 }
241 }
242
243 Assert.fail("Sample not found: { " +
244 samp[0] + ", " + samp[1] + " }");
245 return -1;
246 }
247
248
249
250
251
252
253 private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list) {
254 final int[] array = new int[list.size()];
255 ListIterator<Integer> it = list.listIterator();
256 for (int i = 0; i < array.length; i++) {
257 array[i] = it.next();
258 }
259
260
261 final long seed = RandomSource.createLong();
262 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
263 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
264
265 ListSampler.shuffle(rng1, list);
266 PermutationSampler.shuffle(rng2, array);
267
268 final String msg = "Type=" + list.getClass().getSimpleName();
269 it = list.listIterator();
270 for (int i = 0; i < array.length; i++) {
271 Assert.assertEquals(msg, array[i], it.next().intValue());
272 }
273 }
274
275
276
277
278
279
280
281
282
283 private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list,
284 int start,
285 boolean towardHead) {
286 final int[] array = new int[list.size()];
287 ListIterator<Integer> it = list.listIterator();
288 for (int i = 0; i < array.length; i++) {
289 array[i] = it.next();
290 }
291
292
293 final long seed = RandomSource.createLong();
294 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
295 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
296
297 ListSampler.shuffle(rng1, list, start, towardHead);
298 PermutationSampler.shuffle(rng2, array, start, towardHead);
299
300 final String msg = String.format("Type=%s start=%d towardHead=%b",
301 list.getClass().getSimpleName(), start, towardHead);
302 it = list.listIterator();
303 for (int i = 0; i < array.length; i++) {
304 Assert.assertEquals(msg, array[i], it.next().intValue());
305 }
306 }
307 }