1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Arrays;
20
21 import org.junit.Assert;
22 import org.junit.Test;
23
24 import org.apache.commons.math3.stat.inference.ChiSquareTest;
25
26 import org.apache.commons.rng.UniformRandomProvider;
27 import org.apache.commons.rng.simple.RandomSource;
28
29
30
31
32 public class PermutationSamplerTest {
33 private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 1232343456L);
34 private final ChiSquareTest chiSquareTest = new ChiSquareTest();
35
36 @Test
37 public void testSampleTrivial() {
38 final int n = 6;
39 final int k = 3;
40 final PermutationSampler sampler = new PermutationSampler(RandomSource.create(RandomSource.KISS),
41 n, k);
42 final int[] random = sampler.sample();
43 SAMPLE: for (int s : random) {
44 for (int i = 0; i < n; i++) {
45 if (i == s) {
46 continue SAMPLE;
47 }
48 }
49 Assert.fail("number " + s + " not in array");
50 }
51 }
52
53 @Test
54 public void testSampleChiSquareTest() {
55 final int n = 3;
56 final int k = 3;
57 final int[][] p = {{0, 1, 2}, {0, 2, 1},
58 {1, 0, 2}, {1, 2, 0},
59 {2, 0, 1}, {2, 1, 0}};
60 runSampleChiSquareTest(n, k, p);
61 }
62
63 @Test
64 public void testSubSampleChiSquareTest() {
65 final int n = 4;
66 final int k = 2;
67 final int[][] p = {{0, 1}, {1, 0},
68 {0, 2}, {2, 0},
69 {0, 3}, {3, 0},
70 {1, 2}, {2, 1},
71 {1, 3}, {3, 1},
72 {2, 3}, {3, 2}};
73 runSampleChiSquareTest(n, k, p);
74 }
75
76 @Test
77 public void testSampleBoundaryCase() {
78
79 final PermutationSampler sampler = new PermutationSampler(rng, 1, 1);
80 final int[] perm = sampler.sample();
81 Assert.assertEquals(1, perm.length);
82 Assert.assertEquals(0, perm[0]);
83 }
84
85 @Test(expected = IllegalArgumentException.class)
86 public void testSamplePrecondition1() {
87
88 new PermutationSampler(rng, 2, 3);
89 }
90
91 @Test(expected = IllegalArgumentException.class)
92 public void testSamplePrecondition2() {
93
94 new PermutationSampler(rng, 0, 0);
95 }
96
97 @Test(expected = IllegalArgumentException.class)
98 public void testSamplePrecondition3() {
99
100 new PermutationSampler(rng, -1, 0);
101 }
102
103 @Test(expected = IllegalArgumentException.class)
104 public void testSamplePrecondition4() {
105
106 new PermutationSampler(rng, 1, -1);
107 }
108
109 @Test
110 public void testNatural() {
111 final int n = 4;
112 final int[] expected = {0, 1, 2, 3};
113
114 final int[] natural = PermutationSampler.natural(n);
115 for (int i = 0; i < n; i++) {
116 Assert.assertEquals(expected[i], natural[i]);
117 }
118 }
119
120 @Test
121 public void testNaturalZero() {
122 final int[] natural = PermutationSampler.natural(0);
123 Assert.assertEquals(0, natural.length);
124 }
125
126 @Test
127 public void testShuffleNoDuplicates() {
128 final int n = 100;
129 final int[] orig = PermutationSampler.natural(n);
130 PermutationSampler.shuffle(rng, orig);
131
132
133 final int[] count = new int[n];
134 for (int i = 0; i < n; i++) {
135 count[orig[i]] += 1;
136 }
137
138 for (int i = 0; i < n; i++) {
139 Assert.assertEquals(1, count[i]);
140 }
141 }
142
143 @Test
144 public void testShuffleTail() {
145 final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
146 final int[] list = orig.clone();
147 final int start = 4;
148 PermutationSampler.shuffle(rng, list, start, false);
149
150
151 for (int i = 0; i < start; i++) {
152 Assert.assertEquals(orig[i], list[i]);
153 }
154
155
156 boolean ok = false;
157 for (int i = start; i < orig.length - 1; i++) {
158 if (orig[i] != list[i]) {
159 ok = true;
160 break;
161 }
162 }
163 Assert.assertTrue(ok);
164 }
165
166 @Test
167 public void testShuffleHead() {
168 final int[] orig = new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
169 final int[] list = orig.clone();
170 final int start = 4;
171 PermutationSampler.shuffle(rng, list, start, true);
172
173
174 for (int i = start + 1; i < orig.length; i++) {
175 Assert.assertEquals(orig[i], list[i]);
176 }
177
178
179 boolean ok = false;
180 for (int i = 0; i <= start; i++) {
181 if (orig[i] != list[i]) {
182 ok = true;
183 break;
184 }
185 }
186 Assert.assertTrue(ok);
187 }
188
189
190
191
192 @Test
193 public void testSharedStateSampler() {
194 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
195 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
196 final int n = 17;
197 final int k = 13;
198 final PermutationSampler sampler1 =
199 new PermutationSampler(rng1, n, k);
200 final PermutationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
201 RandomAssert.assertProduceSameSequence(
202 new RandomAssert.Sampler<int[]>() {
203 @Override
204 public int[] sample() {
205 return sampler1.sample();
206 }
207 },
208 new RandomAssert.Sampler<int[]>() {
209 @Override
210 public int[] sample() {
211 return sampler2.sample();
212 }
213 });
214 }
215
216
217
218 private void runSampleChiSquareTest(int n,
219 int k,
220 int[][] p) {
221 final int len = p.length;
222 final long[] observed = new long[len];
223 final int numSamples = 6000;
224 final double numExpected = numSamples / (double) len;
225 final double[] expected = new double[len];
226 Arrays.fill(expected, numExpected);
227
228 final PermutationSampler sampler = new PermutationSampler(rng, n, k);
229 for (int i = 0; i < numSamples; i++) {
230 observed[findPerm(p, sampler.sample())]++;
231 }
232
233
234 Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
235 }
236
237 private static int findPerm(int[][] p,
238 int[] samp) {
239 for (int i = 0; i < p.length; i++) {
240 if (Arrays.equals(p[i], samp)) {
241 return i;
242 }
243 }
244 Assert.fail("Permutation not found");
245 return -1;
246 }
247 }