View Javadoc
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17  package org.apache.commons.rng.sampling;
18  
19  import java.util.Arrays;
20  
21  import org.junit.Assert;
22  import org.junit.Test;
23  
24  import org.apache.commons.math3.stat.inference.ChiSquareTest;
25  import org.apache.commons.math3.util.CombinatoricsUtils;
26  import org.apache.commons.rng.UniformRandomProvider;
27  import org.apache.commons.rng.simple.RandomSource;
28  
29  /**
30   * Tests for {@link CombinationSampler}.
31   */
32  public class CombinationSamplerTest {
33      private final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S);
34  
35      @Test
36      public void testSampleIsInDomain() {
37          final int n = 6;
38          for (int k = 1; k <= n; k++) {
39              final CombinationSampler sampler = new CombinationSampler(rng, n, k);
40              final int[] random = sampler.sample();
41              for (int s : random) {
42                  assertIsInDomain(n, s);
43              }
44          }
45      }
46  
47      @Test
48      public void testUniformWithKlessThanHalfN() {
49          final int n = 8;
50          final int k = 2;
51          assertUniformSamples(n, k);
52      }
53  
54      @Test
55      public void testUniformWithKmoreThanHalfN() {
56          final int n = 8;
57          final int k = 6;
58          assertUniformSamples(n, k);
59      }
60  
61      @Test
62      public void testSampleWhenNequalsKIsNotShuffled() {
63          // Check n == k boundary case.
64          // This is allowed but the sample is not shuffled.
65          for (int n = 1; n < 3; n++) {
66              final int k = n;
67              final CombinationSampler sampler = new CombinationSampler(rng, n, k);
68              final int[] sample = sampler.sample();
69              Assert.assertEquals("Incorrect sample length", n, sample.length);
70              for (int i = 0; i < n; i++) {
71                  Assert.assertEquals("Sample was shuffled", i, sample[i]);
72              }
73          }
74      }
75  
76      @Test(expected = IllegalArgumentException.class)
77      public void testKgreaterThanNThrows() {
78          // Must fail for k > n.
79          final int n = 2;
80          final int k = 3;
81          new CombinationSampler(rng, n, k);
82      }
83  
84      @Test(expected = IllegalArgumentException.class)
85      public void testNequalsZeroThrows() {
86          // Must fail for n = 0.
87          final int n = 0;
88          final int k = 3;
89          new CombinationSampler(rng, n, k);
90      }
91  
92      @Test(expected = IllegalArgumentException.class)
93      public void testKequalsZeroThrows() {
94          // Must fail for k = 0.
95          final int n = 2;
96          final int k = 0;
97          new CombinationSampler(rng, n, k);
98      }
99  
100     @Test(expected = IllegalArgumentException.class)
101     public void testNisNegativeThrows() {
102         // Must fail for n <= 0.
103         final int n = -1;
104         final int k = 3;
105         new CombinationSampler(rng, n, k);
106     }
107 
108     @Test(expected = IllegalArgumentException.class)
109     public void testKisNegativeThrows() {
110         // Must fail for k <= 0.
111         final int n = 0;
112         final int k = -1;
113         new CombinationSampler(rng, n, k);
114     }
115 
116     /**
117      * Test the SharedStateSampler implementation.
118      */
119     @Test
120     public void testSharedStateSampler() {
121         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
122         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
123         final int n = 17;
124         final int k = 3;
125         final CombinationSampler sampler1 =
126             new CombinationSampler(rng1, n, k);
127         final CombinationSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
128         RandomAssert.assertProduceSameSequence(
129             new RandomAssert.Sampler<int[]>() {
130                 @Override
131                 public int[] sample() {
132                     return sampler1.sample();
133                 }
134             },
135             new RandomAssert.Sampler<int[]>() {
136                 @Override
137                 public int[] sample() {
138                     return sampler2.sample();
139                 }
140             });
141     }
142 
143     //// Support methods.
144 
145     /**
146      * Asserts the sample value is in the range 0 to n-1.
147      *
148      * @param n     the n
149      * @param value the sample value
150      */
151     private static void assertIsInDomain(int n, int value) {
152         if (value < 0 || value >= n) {
153             Assert.fail("sample " + value + " not in the domain " + n);
154         }
155     }
156 
157     private void assertUniformSamples(int n, int k) {
158         // The C(n, k) should generate a sample of unspecified order.
159         // To test this each combination is allocated a unique code
160         // based on setting k of the first n-bits in an integer.
161         // Codes are positive for all combinations of bits that use k-bits,
162         // otherwise they are negative.
163         final int totalBitCombinations = 1 << n;
164         int[] codeLookup = new int[totalBitCombinations];
165         Arrays.fill(codeLookup, -1); // initialise as negative
166         int codes = 0;
167         for (int i = 0; i < totalBitCombinations; i++) {
168             if (Integer.bitCount(i) == k) {
169                 // This is a valid sample so allocate a code
170                 codeLookup[i] = codes++;
171             }
172         }
173 
174         // The number of combinations C(n, k) is the binomial coefficient
175         Assert.assertEquals("Incorrect number of combination codes",
176                 CombinatoricsUtils.binomialCoefficient(n, k), codes);
177 
178         final long[] observed = new long[codes];
179         final int numSamples = 6000;
180 
181         final CombinationSampler sampler = new CombinationSampler(rng, n, k);
182         for (int i = 0; i < numSamples; i++) {
183             observed[findCode(codeLookup, sampler.sample())]++;
184         }
185 
186         // Chi squared test of uniformity
187         final double numExpected = numSamples / (double) codes;
188         final double[] expected = new double[codes];
189         Arrays.fill(expected, numExpected);
190         final ChiSquareTest chiSquareTest = new ChiSquareTest();
191         // Pass if we cannot reject null hypothesis that distributions are the same.
192         Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
193     }
194 
195     private static int findCode(int[] codeLookup, int[] sample) {
196         // Each sample index is used to set a bit in an integer.
197         // The resulting bits should be a valid code.
198         int bits = 0;
199         for (int s : sample) {
200             // This shift will be from 0 to n-1 since it is from the
201             // domain of size n.
202             bits |= 1 << s;
203         }
204         if (bits >= codeLookup.length) {
205             Assert.fail("Bad bit combination: " + Arrays.toString(sample));
206         }
207         final int code = codeLookup[bits];
208         if (code < 0) {
209             Assert.fail("Bad bit code: " + Arrays.toString(sample));
210         }
211         return code;
212     }
213 }