1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling;
18
19 import java.util.Arrays;
20
21 import org.junit.Assert;
22 import org.junit.Test;
23
24 import org.apache.commons.math3.stat.inference.ChiSquareTest;
25 import org.apache.commons.math3.util.CombinatoricsUtils;
26 import org.apache.commons.rng.UniformRandomProvider;
27 import org.apache.commons.rng.simple.RandomSource;
28
29
30
31
32 public class CombinationSamplerTest {
33 private final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S);
34
35 @Test
36 public void testSampleIsInDomain() {
37 final int n = 6;
38 for (int k = 1; k <= n; k++) {
39 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
40 final int[] random = sampler.sample();
41 for (int s : random) {
42 assertIsInDomain(n, s);
43 }
44 }
45 }
46
47 @Test
48 public void testUniformWithKlessThanHalfN() {
49 final int n = 8;
50 final int k = 2;
51 assertUniformSamples(n, k);
52 }
53
54 @Test
55 public void testUniformWithKmoreThanHalfN() {
56 final int n = 8;
57 final int k = 6;
58 assertUniformSamples(n, k);
59 }
60
61 @Test
62 public void testSampleWhenNequalsKIsNotShuffled() {
63
64
65 for (int n = 1; n < 3; n++) {
66 final int k = n;
67 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
68 final int[] sample = sampler.sample();
69 Assert.assertEquals("Incorrect sample length", n, sample.length);
70 for (int i = 0; i < n; i++) {
71 Assert.assertEquals("Sample was shuffled", i, sample[i]);
72 }
73 }
74 }
75
76 @Test(expected = IllegalArgumentException.class)
77 public void testKgreaterThanNThrows() {
78
79 final int n = 2;
80 final int k = 3;
81 new CombinationSampler(rng, n, k);
82 }
83
84 @Test(expected = IllegalArgumentException.class)
85 public void testNequalsZeroThrows() {
86
87 final int n = 0;
88 final int k = 3;
89 new CombinationSampler(rng, n, k);
90 }
91
92 @Test(expected = IllegalArgumentException.class)
93 public void testKequalsZeroThrows() {
94
95 final int n = 2;
96 final int k = 0;
97 new CombinationSampler(rng, n, k);
98 }
99
100 @Test(expected = IllegalArgumentException.class)
101 public void testNisNegativeThrows() {
102
103 final int n = -1;
104 final int k = 3;
105 new CombinationSampler(rng, n, k);
106 }
107
108 @Test(expected = IllegalArgumentException.class)
109 public void testKisNegativeThrows() {
110
111 final int n = 0;
112 final int k = -1;
113 new CombinationSampler(rng, n, k);
114 }
115
116
117
118
119
120
121
122
123
124 private static final void assertIsInDomain(int n, int value) {
125 if (value < 0 || value >= n) {
126 Assert.fail("sample " + value + " not in the domain " + n);
127 }
128 }
129
130 private void assertUniformSamples(int n, int k) {
131
132
133
134
135
136 final int totalBitCombinations = 1 << n;
137 int[] codeLookup = new int[totalBitCombinations];
138 Arrays.fill(codeLookup, -1);
139 int codes = 0;
140 for (int i = 0; i < totalBitCombinations; i++) {
141 if (Integer.bitCount(i) == k) {
142
143 codeLookup[i] = codes++;
144 }
145 }
146
147
148 Assert.assertEquals("Incorrect number of combination codes",
149 CombinatoricsUtils.binomialCoefficient(n, k), codes);
150
151 final long[] observed = new long[codes];
152 final int numSamples = 6000;
153
154 final CombinationSampler sampler = new CombinationSampler(rng, n, k);
155 for (int i = 0; i < numSamples; i++) {
156 observed[findCode(codeLookup, sampler.sample())]++;
157 }
158
159
160 final double numExpected = numSamples / (double) codes;
161 final double[] expected = new double[codes];
162 Arrays.fill(expected, numExpected);
163 final ChiSquareTest chiSquareTest = new ChiSquareTest();
164
165 Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
166 }
167
168 private static int findCode(int[] codeLookup, int[] sample) {
169
170
171 int bits = 0;
172 for (int s : sample) {
173
174
175 bits |= (1 << s);
176 }
177 if (bits >= codeLookup.length) {
178 Assert.fail("Bad bit combination: " + Arrays.toString(sample));
179 }
180 final int code = codeLookup[bits];
181 if (code < 0) {
182 Assert.fail("Bad bit code: " + Arrays.toString(sample));
183 }
184 return code;
185 }
186 }