1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Set;
20 import java.util.HashSet;
21 import java.util.List;
22 import java.util.ArrayList;
23 import java.util.Arrays;
24
25 import org.junit.Assert;
26 import org.junit.Test;
27
28 import org.apache.commons.math3.stat.inference.ChiSquareTest;
29 import org.apache.commons.math3.util.MathArrays;
30
31 import org.apache.commons.rng.UniformRandomProvider;
32 import org.apache.commons.rng.simple.RandomSource;
33
34
35
36
37 public class PermutationSamplerTest {
38 private final UniformRandomProvider rng = RandomSource.create(RandomSource.ISAAC, 1232343456L);
39 private final ChiSquareTest chiSquareTest = new ChiSquareTest();
40
41 @Test
42 public void testSampleTrivial() {
43 final int n = 6;
44 final int k = 3;
45 final PermutationSampler sampler = new PermutationSampler(RandomSource.create(RandomSource.KISS),
46 6, 3);
47 final int[] random = sampler.sample();
48 SAMPLE: for (int s : random) {
49 for (int i = 0; i < n; i++) {
50 if (i == s) {
51 continue SAMPLE;
52 }
53 }
54 Assert.fail("number " + s + " not in array");
55 }
56 }
57
58 @Test
59 public void testSampleChiSquareTest() {
60 final int[][] p = { { 0, 1, 2 }, { 0, 2, 1 },
61 { 1, 0, 2 }, { 1, 2, 0 },
62 { 2, 0, 1 }, { 2, 1, 0 } };
63 final int len = p.length;
64 final long[] observed = new long[len];
65 final int numSamples = 6000;
66 final double numExpected = numSamples / (double) len;
67 final double[] expected = new double[len];
68 Arrays.fill(expected, numExpected);
69
70 final PermutationSampler sampler = new PermutationSampler(rng, 3, 3);
71 for (int i = 0; i < numSamples; i++) {
72 observed[findPerm(p, sampler.sample())]++;
73 }
74
75
76 Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
77 }
78
79 @Test
80 public void testSampleBoundaryCase() {
81
82 final PermutationSampler sampler = new PermutationSampler(rng, 1, 1);
83 final int[] perm = sampler.sample();
84 Assert.assertEquals(1, perm.length);
85 Assert.assertEquals(0, perm[0]);
86 }
87
88 @Test(expected=IllegalArgumentException.class)
89 public void testSamplePrecondition1() {
90
91 new PermutationSampler(rng, 2, 3);
92 }
93
94 @Test(expected=IllegalArgumentException.class)
95 public void testSamplePrecondition2() {
96
97 new PermutationSampler(rng, 0, 0);
98 }
99
100 @Test(expected=IllegalArgumentException.class)
101 public void testSamplePrecondition3() {
102
103 new PermutationSampler(rng, -1, 0);
104 }
105
106 @Test(expected=IllegalArgumentException.class)
107 public void testSamplePrecondition4() {
108
109 new PermutationSampler(rng, 1, -1);
110 }
111
112 @Test
113 public void testNatural() {
114 final int n = 4;
115 final int[] expected = {0, 1, 2, 3};
116
117 final int[] natural = PermutationSampler.natural(n);
118 for (int i = 0; i < n; i++) {
119 Assert.assertEquals(expected[i], natural[i]);
120 }
121 }
122
123 @Test
124 public void testNaturalZero() {
125 final int[] natural = PermutationSampler.natural(0);
126 Assert.assertEquals(0, natural.length);
127 }
128
129 @Test
130 public void testShuffleNoDuplicates() {
131 final int n = 100;
132 final int[] orig = PermutationSampler.natural(n);
133 PermutationSampler.shuffle(rng, orig);
134
135
136 final int[] count = new int[n];
137 for (int i = 0; i < n; i++) {
138 count[orig[i]] += 1;
139 }
140
141 for (int i = 0; i < n; i++) {
142 Assert.assertEquals(1, count[i]);
143 }
144 }
145
146 @Test
147 public void testShuffleTail() {
148 final int[] orig = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
149 final int[] list = orig.clone();
150 final int start = 4;
151 PermutationSampler.shuffle(rng, list, start, false);
152
153
154 for (int i = 0; i < start; i++) {
155 Assert.assertEquals(orig[i], list[i]);
156 }
157
158
159 boolean ok = false;
160 for (int i = start; i < orig.length - 1; i++) {
161 if (orig[i] != list[i]) {
162 ok = true;
163 break;
164 }
165 }
166 Assert.assertTrue(ok);
167 }
168
169 @Test
170 public void testShuffleHead() {
171 final int[] orig = new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
172 final int[] list = orig.clone();
173 final int start = 4;
174 PermutationSampler.shuffle(rng, list, start, true);
175
176
177 for (int i = start + 1; i < orig.length; i++) {
178 Assert.assertEquals(orig[i], list[i]);
179 }
180
181
182 boolean ok = false;
183 for (int i = 0; i <= start; i++) {
184 if (orig[i] != list[i]) {
185 ok = true;
186 break;
187 }
188 }
189 Assert.assertTrue(ok);
190 }
191
192
193
194 private int findPerm(int[][] p,
195 int[] samp) {
196 for (int i = 0; i < p.length; i++) {
197 boolean good = true;
198 for (int j = 0; j < samp.length; j++) {
199 if (samp[j] != p[i][j]) {
200 good = false;
201 }
202 }
203 if (good) {
204 return i;
205 }
206 }
207
208 Assert.fail("Permutation not found");
209 return -1;
210 }
211 }