View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Set;
20  import java.util.HashSet;
21  import java.util.LinkedList;
22  import java.util.List;
23  import java.util.ListIterator;
24  import java.util.ArrayList;
25  import java.util.Collection;
26  
27  import org.junit.Assert;
28  import org.junit.Test;
29  
30  import org.apache.commons.math3.stat.inference.ChiSquareTest;
31  
32  import org.apache.commons.rng.UniformRandomProvider;
33  import org.apache.commons.rng.simple.RandomSource;
34  
35  /**
36   * Tests for {@link ListSampler}.
37   */
38  public class ListSamplerTest {
39      private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 6543432321L);
40      private final ChiSquareTest chiSquareTest = new ChiSquareTest();
41  
42      @Test
43      public void testSample() {
44          final String[][] c = {{"0", "1"}, {"0", "2"}, {"0", "3"}, {"0", "4"},
45                                {"1", "2"}, {"1", "3"}, {"1", "4"},
46                                {"2", "3"}, {"2", "4"},
47                                {"3", "4"}};
48          final long[] observed = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
49          final double[] expected = {100, 100, 100, 100, 100, 100, 100, 100, 100, 100};
50  
51          final HashSet<String> cPop = new HashSet<String>(); // {0, 1, 2, 3, 4}.
52          for (int i = 0; i < 5; i++) {
53              cPop.add(Integer.toString(i));
54          }
55  
56          final List<Set<String>> sets = new ArrayList<Set<String>>(); // 2-sets from 5.
57          for (int i = 0; i < 10; i++) {
58              final HashSet<String> hs = new HashSet<String>();
59              hs.add(c[i][0]);
60              hs.add(c[i][1]);
61              sets.add(hs);
62          }
63  
64          for (int i = 0; i < 1000; i++) {
65              observed[findSample(sets, ListSampler.sample(rng, new ArrayList<String>(cPop), 2))]++;
66          }
67  
68          // Pass if we cannot reject null hypothesis that distributions are the same.
69          Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
70      }
71  
72      @Test
73      public void testSampleWhole() {
74          // Sample of size = size of collection must return the same collection.
75          final List<String> list = new ArrayList<String>();
76          list.add("one");
77  
78          final List<String> one = ListSampler.sample(rng, list, 1);
79          Assert.assertEquals(1, one.size());
80          Assert.assertTrue(one.contains("one"));
81      }
82  
83      @Test(expected = IllegalArgumentException.class)
84      public void testSamplePrecondition1() {
85          // Must fail for sample size > collection size.
86          final List<String> list = new ArrayList<String>();
87          list.add("one");
88          ListSampler.sample(rng, list, 2);
89      }
90  
91      @Test(expected = IllegalArgumentException.class)
92      public void testSamplePrecondition2() {
93          // Must fail for empty collection.
94          final List<String> list = new ArrayList<String>();
95          ListSampler.sample(rng, list, 1);
96      }
97  
98      @Test
99      public void testShuffle() {
100         final List<Integer> orig = new ArrayList<Integer>();
101         for (int i = 0; i < 10; i++) {
102             orig.add((i + 1) * rng.nextInt());
103         }
104 
105         final List<Integer> arrayList = new ArrayList<Integer>(orig);
106 
107         ListSampler.shuffle(rng, arrayList);
108         // Ensure that at least one entry has moved.
109         Assert.assertTrue("ArrayList", compare(orig, arrayList, 0, orig.size(), false));
110 
111         final List<Integer> linkedList = new LinkedList<Integer>(orig);
112 
113         ListSampler.shuffle(rng, linkedList);
114         // Ensure that at least one entry has moved.
115         Assert.assertTrue("LinkedList", compare(orig, linkedList, 0, orig.size(), false));
116     }
117 
118     @Test
119     public void testShuffleTail() {
120         final List<Integer> orig = new ArrayList<Integer>();
121         for (int i = 0; i < 10; i++) {
122             orig.add((i + 1) * rng.nextInt());
123         }
124         final List<Integer> list = new ArrayList<Integer>(orig);
125 
126         final int start = 4;
127         ListSampler.shuffle(rng, list, start, false);
128 
129         // Ensure that all entries below index "start" did not move.
130         Assert.assertTrue(compare(orig, list, 0, start, true));
131 
132         // Ensure that at least one entry has moved.
133         Assert.assertTrue(compare(orig, list, start, orig.size(), false));
134     }
135 
136     @Test
137     public void testShuffleHead() {
138         final List<Integer> orig = new ArrayList<Integer>();
139         for (int i = 0; i < 10; i++) {
140             orig.add((i + 1) * rng.nextInt());
141         }
142         final List<Integer> list = new ArrayList<Integer>(orig);
143 
144         final int start = 4;
145         ListSampler.shuffle(rng, list, start, true);
146 
147         // Ensure that all entries above index "start" did not move.
148         Assert.assertTrue(compare(orig, list, start + 1, orig.size(), true));
149 
150         // Ensure that at least one entry has moved.
151         Assert.assertTrue(compare(orig, list, 0, start + 1, false));
152     }
153 
154     /**
155      * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
156      * The implementation may be different but the result is a Fisher-Yates shuffle so the
157      * output order should match.
158      */
159     @Test
160     public void testShuffleMatchesPermutationSamplerShuffle() {
161         final List<Integer> orig = new ArrayList<Integer>();
162         for (int i = 0; i < 10; i++) {
163             orig.add((i + 1) * rng.nextInt());
164         }
165 
166         assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig));
167         assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig));
168     }
169 
170     /**
171      * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
172      * The implementation may be different but the result is a Fisher-Yates shuffle so the
173      * output order should match.
174      */
175     @Test
176     public void testShuffleMatchesPermutationSamplerShuffleDirectional() {
177         final List<Integer> orig = new ArrayList<Integer>();
178         for (int i = 0; i < 10; i++) {
179             orig.add((i + 1) * rng.nextInt());
180         }
181 
182         assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, true);
183         assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, false);
184         assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, true);
185         assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, false);
186     }
187 
188     /**
189      * This test hits the edge case when a LinkedList is small enough that the algorithm
190      * using a RandomAccess list is faster than the one with an iterator.
191      */
192     @Test
193     public void testShuffleWithSmallLinkedList() {
194         final int size = 3;
195         final List<Integer> orig = new ArrayList<Integer>();
196         for (int i = 0; i < size; i++) {
197             orig.add((i + 1) * rng.nextInt());
198         }
199 
200         // When the size is small there is a chance that the list has no entries that move.
201         // E.g. The number of permutations of 3 items is only 6 giving a 1/6 chance of no change.
202         // So repeat test that the small shuffle matches the PermutationSampler.
203         // 10 times is (1/6)^10 or 1 in 60,466,176 of no change.
204         for (int i = 0; i < 10; i++) {
205             assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), size - 1, true);
206         }
207     }
208 
209     //// Support methods.
210 
211     /**
212      * If {@code same == true}, return {@code true} if all entries are
213      * the same; if {@code same == false}, return {@code true} if at
214      * least one entry is different.
215      */
216     private <T> boolean compare(List<T> orig,
217                                 List<T> list,
218                                 int start,
219                                 int end,
220                                 boolean same) {
221         for (int i = start; i < end; i++) {
222             if (!orig.get(i).equals(list.get(i))) {
223                 return same ? false : true;
224             }
225         }
226         return same ? true : false;
227     }
228 
229     private <T extends Set<String>> int findSample(List<T> u,
230                                                    Collection<String> sampList) {
231         final String[] samp = sampList.toArray(new String[sampList.size()]);
232         for (int i = 0; i < u.size(); i++) {
233             final T set = u.get(i);
234             final HashSet<String> sampSet = new HashSet<String>();
235             for (int j = 0; j < samp.length; j++) {
236                 sampSet.add(samp[j]);
237             }
238             if (set.equals(sampSet)) {
239                 return i;
240             }
241         }
242 
243         Assert.fail("Sample not found: { " +
244                     samp[0] + ", " + samp[1] + " }");
245         return -1;
246     }
247 
248     /**
249      * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
250      *
251      * @param list Array whose entries will be shuffled (in-place).
252      */
253     private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list) {
254         final int[] array = new int[list.size()];
255         ListIterator<Integer> it = list.listIterator();
256         for (int i = 0; i < array.length; i++) {
257             array[i] = it.next();
258         }
259 
260         // Identical RNGs
261         final long seed = RandomSource.createLong();
262         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
263         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
264 
265         ListSampler.shuffle(rng1, list);
266         PermutationSampler.shuffle(rng2, array);
267 
268         final String msg = "Type=" + list.getClass().getSimpleName();
269         it = list.listIterator();
270         for (int i = 0; i < array.length; i++) {
271             Assert.assertEquals(msg, array[i], it.next().intValue());
272         }
273     }
274     /**
275      * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
276      *
277      * @param list Array whose entries will be shuffled (in-place).
278      * @param start Index at which shuffling begins.
279      * @param towardHead Shuffling is performed for index positions between
280      * {@code start} and either the end (if {@code false}) or the beginning
281      * (if {@code true}) of the array.
282      */
283     private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list,
284                                                                     int start,
285                                                                     boolean towardHead) {
286         final int[] array = new int[list.size()];
287         ListIterator<Integer> it = list.listIterator();
288         for (int i = 0; i < array.length; i++) {
289             array[i] = it.next();
290         }
291 
292         // Identical RNGs
293         final long seed = RandomSource.createLong();
294         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
295         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
296 
297         ListSampler.shuffle(rng1, list, start, towardHead);
298         PermutationSampler.shuffle(rng2, array, start, towardHead);
299 
300         final String msg = String.format("Type=%s start=%d towardHead=%b",
301                 list.getClass().getSimpleName(), start, towardHead);
302         it = list.listIterator();
303         for (int i = 0; i < array.length; i++) {
304             Assert.assertEquals(msg, array[i], it.next().intValue());
305         }
306     }
307 }