/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Set;
20  import java.util.HashSet;
21  import java.util.List;
22  import java.util.ArrayList;
23  import java.util.Collection;
24  import java.util.Arrays;
25  
26  import org.junit.Assert;
27  import org.junit.Test;
28  
29  import org.apache.commons.math3.stat.inference.ChiSquareTest;
30  
31  import org.apache.commons.rng.UniformRandomProvider;
32  import org.apache.commons.rng.simple.RandomSource;
33  
34  /**
35   * Tests for {@link ListSampler}.
36   */
37  public class ListSamplerTest {
38      private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 6543432321L);
39      private final ChiSquareTest chiSquareTest = new ChiSquareTest();
40  
41      @Test
42      public void testSample() {
43          final String[][] c = { { "0", "1" }, { "0", "2" }, { "0", "3" }, { "0", "4" },
44                                 { "1", "2" }, { "1", "3" }, { "1", "4" },
45                                 { "2", "3" }, { "2", "4" },
46                                 { "3", "4" } };
47          final long[] observed = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
48          final double[] expected = { 100, 100, 100, 100, 100, 100, 100, 100, 100, 100 };
49  
50          final HashSet<String> cPop = new HashSet<String>(); // {0, 1, 2, 3, 4}.
51          for (int i = 0; i < 5; i++) {
52              cPop.add(Integer.toString(i));
53          }
54  
55          final List<Set<String>> sets = new ArrayList<Set<String>>(); // 2-sets from 5.
56          for (int i = 0; i < 10; i++) {
57              final HashSet<String> hs = new HashSet<String>();
58              hs.add(c[i][0]);
59              hs.add(c[i][1]);
60              sets.add(hs);
61          }
62  
63          for (int i = 0; i < 1000; i++) {
64              observed[findSample(sets, ListSampler.sample(rng, new ArrayList<String>(cPop), 2))]++;
65          }
66  
67          // Pass if we cannot reject null hypothesis that distributions are the same.
68          Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
69      }
70  
71      @Test
72      public void testSampleWhole() {
73          // Sample of size = size of collection must return the same collection.
74          final List<String> list = new ArrayList<String>();
75          list.add("one");
76  
77          final List<String> one = ListSampler.sample(rng, list, 1);
78          Assert.assertEquals(1, one.size());
79          Assert.assertTrue(one.contains("one"));
80      }
81  
82      @Test(expected=IllegalArgumentException.class)
83      public void testSamplePrecondition1() {
84          // Must fail for sample size > collection size.
85          final List<String> list = new ArrayList<String>();
86          list.add("one");
87          ListSampler.sample(rng, list, 2);
88      }
89  
90      @Test(expected=IllegalArgumentException.class)
91      public void testSamplePrecondition2() {
92          // Must fail for empty collection.
93          final List<String> list = new ArrayList<String>();
94          ListSampler.sample(rng, list, 1);
95      }
96  
97      @Test
98      public void testShuffle() {
99          final List<Integer> orig = new ArrayList<Integer>();
100         for (int i = 0; i < 10; i++) {
101             orig.add((i + 1) * rng.nextInt());
102         }
103         final List<Integer> list = new ArrayList<Integer>(orig);
104 
105         ListSampler.shuffle(rng, list);
106         // Ensure that at least one entry has moved.
107         Assert.assertTrue(compare(orig, list, 0, orig.size(), false));
108     }
109 
110     @Test
111     public void testShuffleTail() {
112         final List<Integer> orig = new ArrayList<Integer>();
113         for (int i = 0; i < 10; i++) {
114             orig.add((i + 1) * rng.nextInt());
115         }
116         final List<Integer> list = new ArrayList<Integer>(orig);
117 
118         final int start = 4;
119         ListSampler.shuffle(rng, list, start, false);
120 
121         // Ensure that all entries below index "start" did not move.
122         Assert.assertTrue(compare(orig, list, 0, start, true));
123 
124         // Ensure that at least one entry has moved.
125         Assert.assertTrue(compare(orig, list, start, orig.size(), false));
126     }
127 
128     @Test
129     public void testShuffleHead() {
130         final List<Integer> orig = new ArrayList<Integer>();
131         for (int i = 0; i < 10; i++) {
132             orig.add((i + 1) * rng.nextInt());
133         }
134         final List<Integer> list = new ArrayList<Integer>(orig);
135 
136         final int start = 4;
137         ListSampler.shuffle(rng, list, start, true);
138 
139         // Ensure that all entries above index "start" did not move.
140         Assert.assertTrue(compare(orig, list, start + 1, orig.size(), true));
141 
142         // Ensure that at least one entry has moved.
143         Assert.assertTrue(compare(orig, list, 0, start + 1, false));
144     }
145 
146     //// Support methods.
147 
148     /**
149      * If {@code same == true}, return {@code true} if all entries are
150      * the same; if {@code same == false}, return {@code true} if at
151      * least one entry is different.
152      */
153     private <T> boolean compare(List<T> orig,
154                                 List<T> list,
155                                 int start,
156                                 int end,
157                                 boolean same) {
158         for (int i = start; i < end; i++) {
159             if (!orig.get(i).equals(list.get(i))) {
160                 return same ? false : true;
161             }
162         }
163         return same ? true : false;
164     }
165 
166     private <T extends Set<String>> int findSample(List<T> u,
167                                                    Collection<String> sampList) {
168         final String[] samp = sampList.toArray(new String[sampList.size()]);
169         for (int i = 0; i < u.size(); i++) {
170             final T set = u.get(i);
171             final HashSet<String> sampSet = new HashSet<String>();
172             for (int j = 0; j < samp.length; j++) {
173                 sampSet.add(samp[j]);
174             }
175             if (set.equals(sampSet)) {
176                 return i;
177             }
178         }
179 
180         Assert.fail("Sample not found: { " +
181                     samp[0] + ", " + samp[1] + " }");
182         return -1;
183     }
184 }