1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Arrays;
20
21 import org.junit.Assert;
22 import org.junit.Test;
23
24 import org.apache.commons.math3.stat.inference.ChiSquareTest;
25 import org.apache.commons.math3.util.CombinatoricsUtils;
26 import org.apache.commons.rng.UniformRandomProvider;
27 import org.apache.commons.rng.simple.RandomSource;
28
29
30
31
32 public class CombinationSamplerTest {
33 private final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S);
34
35 @Test
36 public void testSampleIsInDomain() {
37 final int n = 6;
38 for (int k = 1; k <= n; k++) {
39 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
40 final int[] random = sampler.sample();
41 for (int s : random) {
42 assertIsInDomain(n, s);
43 }
44 }
45 }
46
47 @Test
48 public void testUniformWithKlessThanHalfN() {
49 final int n = 8;
50 final int k = 2;
51 assertUniformSamples(n, k);
52 }
53
54 @Test
55 public void testUniformWithKmoreThanHalfN() {
56 final int n = 8;
57 final int k = 6;
58 assertUniformSamples(n, k);
59 }
60
61 @Test
62 public void testSampleWhenNequalsKIsNotShuffled() {
63
64
65 for (int n = 1; n < 3; n++) {
66 final int k = n;
67 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
68 final int[] sample = sampler.sample();
69 Assert.assertEquals("Incorrect sample length", n, sample.length);
70 for (int i = 0; i < n; i++) {
71 Assert.assertEquals("Sample was shuffled", i, sample[i]);
72 }
73 }
74 }
75
76 @Test(expected = IllegalArgumentException.class)
77 public void testKgreaterThanNThrows() {
78
79 final int n = 2;
80 final int k = 3;
81 new CombinationSampler(rng, n, k);
82 }
83
84 @Test(expected = IllegalArgumentException.class)
85 public void testNequalsZeroThrows() {
86
87 final int n = 0;
88 final int k = 3;
89 new CombinationSampler(rng, n, k);
90 }
91
92 @Test(expected = IllegalArgumentException.class)
93 public void testKequalsZeroThrows() {
94
95 final int n = 2;
96 final int k = 0;
97 new CombinationSampler(rng, n, k);
98 }
99
100 @Test(expected = IllegalArgumentException.class)
101 public void testNisNegativeThrows() {
102
103 final int n = -1;
104 final int k = 3;
105 new CombinationSampler(rng, n, k);
106 }
107
108 @Test(expected = IllegalArgumentException.class)
109 public void testKisNegativeThrows() {
110
111 final int n = 0;
112 final int k = -1;
113 new CombinationSampler(rng, n, k);
114 }
115
116
117
118
119 @Test
120 public void testSharedStateSampler() {
121 final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
122 final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
123 final int n = 17;
124 final int k = 3;
125 final CombinationSampler sampler1 =
126 new CombinationSampler(rng1, n, k);
127 final CombinationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
128 RandomAssert.assertProduceSameSequence(
129 new RandomAssert.Sampler<int[]>() {
130 @Override
131 public int[] sample() {
132 return sampler1.sample();
133 }
134 },
135 new RandomAssert.Sampler<int[]>() {
136 @Override
137 public int[] sample() {
138 return sampler2.sample();
139 }
140 });
141 }
142
143
144
145
146
147
148
149
150
151 private static void assertIsInDomain(int n, int value) {
152 if (value < 0 || value >= n) {
153 Assert.fail("sample " + value + " not in the domain " + n);
154 }
155 }
156
157 private void assertUniformSamples(int n, int k) {
158
159
160
161
162
163 final int totalBitCombinations = 1 << n;
164 int[] codeLookup = new int[totalBitCombinations];
165 Arrays.fill(codeLookup, -1);
166 int codes = 0;
167 for (int i = 0; i < totalBitCombinations; i++) {
168 if (Integer.bitCount(i) == k) {
169
170 codeLookup[i] = codes++;
171 }
172 }
173
174
175 Assert.assertEquals("Incorrect number of combination codes",
176 CombinatoricsUtils.binomialCoefficient(n, k), codes);
177
178 final long[] observed = new long[codes];
179 final int numSamples = 6000;
180
181 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
182 for (int i = 0; i < numSamples; i++) {
183 observed[findCode(codeLookup, sampler.sample())]++;
184 }
185
186
187 final double numExpected = numSamples / (double) codes;
188 final double[] expected = new double[codes];
189 Arrays.fill(expected, numExpected);
190 final ChiSquareTest chiSquareTest = new ChiSquareTest();
191
192 Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
193 }
194
195 private static int findCode(int[] codeLookup, int[] sample) {
196
197
198 int bits = 0;
199 for (int s : sample) {
200
201
202 bits |= 1 << s;
203 }
204 if (bits >= codeLookup.length) {
205 Assert.fail("Bad bit combination: " + Arrays.toString(sample));
206 }
207 final int code = codeLookup[bits];
208 if (code < 0) {
209 Assert.fail("Bad bit code: " + Arrays.toString(sample));
210 }
211 return code;
212 }
213 }