View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Arrays;
20  
21  import org.junit.Assert;
22  import org.junit.Test;
23  
24  import org.apache.commons.math3.stat.inference.ChiSquareTest;
25  import org.apache.commons.math3.util.CombinatoricsUtils;
26  import org.apache.commons.rng.UniformRandomProvider;
27  import org.apache.commons.rng.simple.RandomSource;
28  
29  /**
30   * Tests for {@link CombinationSampler}.
31   */
32  public class CombinationSamplerTest {
33      private final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S);
34  
35      @Test
36      public void testSampleIsInDomain() {
37          final int n = 6;
38          for (int k = 1; k <= n; k++) {
39              final CombinationSampler sampler = new CombinationSampler(rng, n, k);
40              final int[] random = sampler.sample();
41              for (int s : random) {
42                  assertIsInDomain(n, s);
43              }
44          }
45      }
46  
47      @Test
48      public void testUniformWithKlessThanHalfN() {
49          final int n = 8;
50          final int k = 2;
51          assertUniformSamples(n, k);
52      }
53  
54      @Test
55      public void testUniformWithKmoreThanHalfN() {
56          final int n = 8;
57          final int k = 6;
58          assertUniformSamples(n, k);
59      }
60  
61      @Test
62      public void testSampleWhenNequalsKIsNotShuffled() {
63          // Check n == k boundary case.
64          // This is allowed but the sample is not shuffled.
65          for (int n = 1; n < 3; n++) {
66              final int k = n;
67              final CombinationSampler sampler = new CombinationSampler(rng, n, k);
68              final int[] sample = sampler.sample();
69              Assert.assertEquals("Incorrect sample length", n, sample.length);
70              for (int i = 0; i < n; i++) {
71                  Assert.assertEquals("Sample was shuffled", i, sample[i]);
72              }
73          }
74      }
75  
76      @Test(expected = IllegalArgumentException.class)
77      public void testKgreaterThanNThrows() {
78          // Must fail for k > n.
79          final int n = 2;
80          final int k = 3;
81          new CombinationSampler(rng, n, k);
82      }
83  
84      @Test(expected = IllegalArgumentException.class)
85      public void testNequalsZeroThrows() {
86          // Must fail for n = 0.
87          final int n = 0;
88          final int k = 3;
89          new CombinationSampler(rng, n, k);
90      }
91  
92      @Test(expected = IllegalArgumentException.class)
93      public void testKequalsZeroThrows() {
94          // Must fail for k = 0.
95          final int n = 2;
96          final int k = 0;
97          new CombinationSampler(rng, n, k);
98      }
99  
100     @Test(expected = IllegalArgumentException.class)
101     public void testNisNegativeThrows() {
102         // Must fail for n <= 0.
103         final int n = -1;
104         final int k = 3;
105         new CombinationSampler(rng, n, k);
106     }
107 
108     @Test(expected = IllegalArgumentException.class)
109     public void testKisNegativeThrows() {
110         // Must fail for k <= 0.
111         final int n = 0;
112         final int k = -1;
113         new CombinationSampler(rng, n, k);
114     }
115 
116     //// Support methods.
117 
118     /**
119      * Asserts the sample value is in the range 0 to n-1.
120      *
121      * @param n     the n
122      * @param value the sample value
123      */
124     private static final void assertIsInDomain(int n, int value) {
125         if (value < 0 || value >= n) {
126             Assert.fail("sample " + value + " not in the domain " + n);
127         }
128     }
129 
130     private void assertUniformSamples(int n, int k) {
131         // The C(n, k) should generate a sample of unspecified order.
132         // To test this each combination is allocated a unique code
133         // based on setting k of the first n-bits in an integer.
134         // Codes are positive for all combinations of bits that use k-bits,
135         // otherwise they are negative.
136         final int totalBitCombinations = 1 << n;
137         int[] codeLookup = new int[totalBitCombinations];
138         Arrays.fill(codeLookup, -1); // initialise as negative
139         int codes = 0;
140         for (int i = 0; i < totalBitCombinations; i++) {
141             if (Integer.bitCount(i) == k) {
142                 // This is a valid sample so allocate a code
143                 codeLookup[i] = codes++;
144             }
145         }
146 
147         // The number of combinations C(n, k) is the binomial coefficient
148         Assert.assertEquals("Incorrect number of combination codes",
149                 CombinatoricsUtils.binomialCoefficient(n, k), codes);
150 
151         final long[] observed = new long[codes];
152         final int numSamples = 6000;
153 
154         final CombinationSampler sampler = new CombinationSampler(rng, n, k);
155         for (int i = 0; i < numSamples; i++) {
156             observed[findCode(codeLookup, sampler.sample())]++;
157         }
158 
159         // Chi squared test of uniformity
160         final double numExpected = numSamples / (double) codes;
161         final double[] expected = new double[codes];
162         Arrays.fill(expected, numExpected);
163         final ChiSquareTest chiSquareTest = new ChiSquareTest();
164         // Pass if we cannot reject null hypothesis that distributions are the same.
165         Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
166     }
167 
168     private static int findCode(int[] codeLookup, int[] sample) {
169         // Each sample index is used to set a bit in an integer.
170         // The resulting bits should be a valid code.
171         int bits = 0;
172         for (int s : sample) {
173             // This shift will be from 0 to n-1 since it is from the
174             // domain of size n.
175             bits |= (1 << s);
176         }
177         if (bits >= codeLookup.length) {
178             Assert.fail("Bad bit combination: " + Arrays.toString(sample));
179         }
180         final int code = codeLookup[bits];
181         if (code < 0) {
182             Assert.fail("Bad bit code: " + Arrays.toString(sample));
183         }
184         return code;
185     }
186 }