View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Arrays;
20  
21  import org.junit.Assert;
22  import org.junit.Test;
23  
24  import org.apache.commons.math3.stat.inference.ChiSquareTest;
25  
26  import org.apache.commons.rng.UniformRandomProvider;
27  import org.apache.commons.rng.simple.RandomSource;
28  
29  /**
30   * Tests for {@link PermutationSampler}.
31   */
32  public class PermutationSamplerTest {
33      private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 1232343456L);
34      private final ChiSquareTest chiSquareTest = new ChiSquareTest();
35  
36      @Test
37      public void testSampleTrivial() {
38          final int n = 6;
39          final int k = 3;
40          final PermutationSampler sampler = new PermutationSampler(RandomSource.create(RandomSource.KISS),
41                                                                    n, k);
42          final int[] random = sampler.sample();
43          SAMPLE: for (int s : random) {
44              for (int i = 0; i < n; i++) {
45                  if (i == s) {
46                      continue SAMPLE;
47                  }
48              }
49              Assert.fail("number " + s + " not in array");
50          }
51      }
52  
53      @Test
54      public void testSampleChiSquareTest() {
55          final int n = 3;
56          final int k = 3;
57          final int[][] p = {{0, 1, 2}, {0, 2, 1},
58                             {1, 0, 2}, {1, 2, 0},
59                             {2, 0, 1}, {2, 1, 0}};
60          runSampleChiSquareTest(n, k, p);
61      }
62  
63      @Test
64      public void testSubSampleChiSquareTest() {
65          final int n = 4;
66          final int k = 2;
67          final int[][] p = {{0, 1}, {1, 0},
68                             {0, 2}, {2, 0},
69                             {0, 3}, {3, 0},
70                             {1, 2}, {2, 1},
71                             {1, 3}, {3, 1},
72                             {2, 3}, {3, 2}};
73          runSampleChiSquareTest(n, k, p);
74      }
75  
76      @Test
77      public void testSampleBoundaryCase() {
78          // Check size = 1 boundary case.
79          final PermutationSampler sampler = new PermutationSampler(rng, 1, 1);
80          final int[] perm = sampler.sample();
81          Assert.assertEquals(1, perm.length);
82          Assert.assertEquals(0, perm[0]);
83      }
84  
85      @Test(expected = IllegalArgumentException.class)
86      public void testSamplePrecondition1() {
87          // Must fail for k > n.
88          new PermutationSampler(rng, 2, 3);
89      }
90  
91      @Test(expected = IllegalArgumentException.class)
92      public void testSamplePrecondition2() {
93          // Must fail for n = 0.
94          new PermutationSampler(rng, 0, 0);
95      }
96  
97      @Test(expected = IllegalArgumentException.class)
98      public void testSamplePrecondition3() {
99          // Must fail for k < n < 0.
100         new PermutationSampler(rng, -1, 0);
101     }
102 
103     @Test(expected = IllegalArgumentException.class)
104     public void testSamplePrecondition4() {
105         // Must fail for k < n < 0.
106         new PermutationSampler(rng, 1, -1);
107     }
108 
109     @Test
110     public void testNatural() {
111         final int n = 4;
112         final int[] expected = {0, 1, 2, 3};
113 
114         final int[] natural = PermutationSampler.natural(n);
115         for (int i = 0; i < n; i++) {
116             Assert.assertEquals(expected[i], natural[i]);
117         }
118     }
119 
120     @Test
121     public void testNaturalZero() {
122         final int[] natural = PermutationSampler.natural(0);
123         Assert.assertEquals(0, natural.length);
124     }
125 
126     @Test
127     public void testShuffleNoDuplicates() {
128         final int n = 100;
129         final int[] orig = PermutationSampler.natural(n);
130         PermutationSampler.shuffle(rng, orig);
131 
132         // Test that all (unique) entries exist in the shuffled array.
133         final int[] count = new int[n];
134         for (int i = 0; i < n; i++) {
135             count[orig[i]] += 1;
136         }
137 
138         for (int i = 0; i < n; i++) {
139             Assert.assertEquals(1, count[i]);
140         }
141     }
142 
143     @Test
144     public void testShuffleTail() {
145         final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
146         final int[] list = orig.clone();
147         final int start = 4;
148         PermutationSampler.shuffle(rng, list, start, false);
149 
150         // Ensure that all entries below index "start" did not move.
151         for (int i = 0; i < start; i++) {
152             Assert.assertEquals(orig[i], list[i]);
153         }
154 
155         // Ensure that at least one entry has moved.
156         boolean ok = false;
157         for (int i = start; i < orig.length - 1; i++) {
158             if (orig[i] != list[i]) {
159                 ok = true;
160                 break;
161             }
162         }
163         Assert.assertTrue(ok);
164     }
165 
166     @Test
167     public void testShuffleHead() {
168         final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
169         final int[] list = orig.clone();
170         final int start = 4;
171         PermutationSampler.shuffle(rng, list, start, true);
172 
173         // Ensure that all entries above index "start" did not move.
174         for (int i = start + 1; i < orig.length; i++) {
175             Assert.assertEquals(orig[i], list[i]);
176         }
177 
178         // Ensure that at least one entry has moved.
179         boolean ok = false;
180         for (int i = 0; i <= start; i++) {
181             if (orig[i] != list[i]) {
182                 ok = true;
183                 break;
184             }
185         }
186         Assert.assertTrue(ok);
187     }
188 
189     /**
190      * Test the SharedStateSampler implementation.
191      */
192     @Test
193     public void testSharedStateSampler() {
194         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
195         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
196         final int n = 17;
197         final int k = 13;
198         final PermutationSampler sampler1 =
199             new PermutationSampler(rng1, n, k);
200         final PermutationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
201         RandomAssert.assertProduceSameSequence(
202             new RandomAssert.Sampler<int[]>() {
203                 @Override
204                 public int[] sample() {
205                     return sampler1.sample();
206                 }
207             },
208             new RandomAssert.Sampler<int[]>() {
209                 @Override
210                 public int[] sample() {
211                     return sampler2.sample();
212                 }
213             });
214     }
215 
216     //// Support methods.
217 
218     private void runSampleChiSquareTest(int n,
219                                         int k,
220                                         int[][] p) {
221         final int len = p.length;
222         final long[] observed = new long[len];
223         final int numSamples = 6000;
224         final double numExpected = numSamples / (double) len;
225         final double[] expected = new double[len];
226         Arrays.fill(expected, numExpected);
227 
228         final PermutationSampler sampler = new PermutationSampler(rng, n, k);
229         for (int i = 0; i < numSamples; i++) {
230             observed[findPerm(p, sampler.sample())]++;
231         }
232 
233         // Pass if we cannot reject null hypothesis that distributions are the same.
234         Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
235     }
236 
237     private static int findPerm(int[][] p,
238                                 int[] samp) {
239         for (int i = 0; i < p.length; i++) {
240             if (Arrays.equals(p[i], samp)) {
241                 return i;
242             }
243         }
244         Assert.fail("Permutation not found");
245         return -1;
246     }
247 }