View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.math3.distribution.BinomialDistribution;
20  import org.apache.commons.math3.distribution.PoissonDistribution;
21  import org.apache.commons.math3.stat.inference.ChiSquareTest;
22  import org.apache.commons.rng.UniformRandomProvider;
23  import org.apache.commons.rng.sampling.RandomAssert;
24  import org.apache.commons.rng.simple.RandomSource;
25  import org.junit.Assert;
26  import org.junit.Test;
27  
28  import java.util.Arrays;
29  
30  /**
31   * Test for the {@link AliasMethodDiscreteSampler}.
32   */
33  public class AliasMethodDiscreteSamplerTest {
34      @Test(expected = IllegalArgumentException.class)
35      public void testConstructorThrowsWithNullProbabilites() {
36          createSampler(null);
37      }
38  
39      @Test(expected = IllegalArgumentException.class)
40      public void testConstructorThrowsWithZeroLengthProbabilites() {
41          createSampler(new double[0]);
42      }
43  
44      @Test(expected = IllegalArgumentException.class)
45      public void testConstructorThrowsWithNegativeProbabilites() {
46          createSampler(new double[] {-1, 0.1, 0.2});
47      }
48  
49      @Test(expected = IllegalArgumentException.class)
50      public void testConstructorThrowsWithNaNProbabilites() {
51          createSampler(new double[] {0.1, Double.NaN, 0.2});
52      }
53  
54      @Test(expected = IllegalArgumentException.class)
55      public void testConstructorThrowsWithInfiniteProbabilites() {
56          createSampler(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2});
57      }
58  
59      @Test(expected = IllegalArgumentException.class)
60      public void testConstructorThrowsWithInfiniteSumProbabilites() {
61          createSampler(new double[] {Double.MAX_VALUE, Double.MAX_VALUE});
62      }
63  
64      @Test(expected = IllegalArgumentException.class)
65      public void testConstructorThrowsWithZeroSumProbabilites() {
66          createSampler(new double[4]);
67      }
68  
69      @Test
70      public void testToString() {
71          final SharedStateDiscreteSampler sampler = createSampler(new double[] {0.5, 0.5});
72          Assert.assertTrue(sampler.toString().toLowerCase().contains("alias method"));
73      }
74  
75      /**
76       * Creates the sampler without zero-padding enabled.
77       *
78       * @param probabilities the probabilities
79       * @return the alias method discrete sampler
80       */
81      private static SharedStateDiscreteSampler createSampler(double[] probabilities) {
82          final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
83          return AliasMethodDiscreteSampler.of(rng, probabilities, -1);
84      }
85  
86      /**
87       * Test sampling from a binomial distribution.
88       */
89      @Test
90      public void testBinomialSamples() {
91          final int trials = 67;
92          final double probabilityOfSuccess = 0.345;
93          final BinomialDistribution dist = new BinomialDistribution(trials, probabilityOfSuccess);
94          final double[] expected = new double[trials + 1];
95          for (int i = 0; i < expected.length; i++) {
96              expected[i] = dist.probability(i);
97          }
98          checkSamples(expected);
99      }
100 
101     /**
102      * Test sampling from a Poisson distribution.
103      */
104     @Test
105     public void testPoissonSamples() {
106         final double mean = 3.14;
107         final PoissonDistribution dist = new PoissonDistribution(null, mean,
108             PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
109         final int maxN = dist.inverseCumulativeProbability(1 - 1e-6);
110         double[] expected = new double[maxN];
111         for (int i = 0; i < expected.length; i++) {
112             expected[i] = dist.probability(i);
113         }
114         checkSamples(expected);
115     }
116 
117     /**
118      * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
119      */
120     @Test
121     public void testNonUniformSamplesWithProbabilities() {
122         final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3 };
123         checkSamples(expected);
124     }
125 
126     /**
127      * Test sampling from a non-uniform distribution using the factory constructor to zero pad
128      * the input probabilities.
129      */
130     @Test
131     public void testNonUniformSamplesWithProbabilitiesWithDefaultFactoryConstructor() {
132         final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3 };
133         checkSamples(AliasMethodDiscreteSampler.of(RandomSource.create(RandomSource.SPLIT_MIX_64), expected), expected);
134     }
135 
136     /**
137      * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
138      * probabilities).
139      */
140     @Test
141     public void testNonUniformSamplesWithObservations() {
142         final double[] expected = {1, 2, 3, 1, 3 };
143         checkSamples(expected);
144     }
145 
146     /**
147      * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
148      * Extra zero-values are added to make the table size a power of 2.
149      */
150     @Test
151     public void testNonUniformSamplesWithProbabilitiesPaddedToPowerOf2() {
152         final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0, 0 };
153         checkSamples(expected);
154     }
155 
156     /**
157      * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
158      * probabilities). Extra zero-values are added to make the table size a power of 2.
159      */
160     @Test
161     public void testNonUniformSamplesWithObservationsPaddedToPowerOf2() {
162         final double[] expected = {1, 2, 3, 0, 1, 3, 0, 0 };
163         checkSamples(expected);
164     }
165 
166     /**
167      * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
168      * Extra zero-values are added.
169      */
170     @Test
171     public void testNonUniformSamplesWithZeroProbabilities() {
172         final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0 };
173         checkSamples(expected);
174     }
175 
176     /**
177      * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
178      * probabilities). Extra zero-values are added.
179      */
180     @Test
181     public void testNonUniformSamplesWithZeroObservations() {
182         final double[] expected = {1, 2, 3, 0, 1, 3, 0 };
183         checkSamples(expected);
184     }
185 
186     /**
187      * Test sampling from a uniform distribution. This is an edge case where there
188      * are no probabilities less than the mean.
189      */
190     @Test
191     public void testUniformSamplesWithNoObservationLessThanTheMean() {
192         final double[] expected = {2, 2, 2, 2, 2, 2 };
193         checkSamples(expected);
194     }
195 
196     /**
197      * Test sampling from a non-uniform distribution which is zero-padded to a large size.
198      */
199     @Test
200     public void testLargeTableSize() {
201         double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3 };
202         // Pad to a large table size not supported for fast sampling (anything > 2^11)
203         expected = Arrays.copyOf(expected, 1 << 12);
204         checkSamples(expected);
205     }
206 
207     /**
208      * Check the distribution of samples match the expected probabilities.
209      *
210      * @param expected the expected probabilities
211      */
212     private static void checkSamples(double[] probabilies) {
213         checkSamples(createSampler(probabilies), probabilies);
214     }
215 
216     /**
217      * Check the distribution of samples match the expected probabilities.
218      *
219      * @param expected the expected probabilities
220      */
221     private static void checkSamples(SharedStateDiscreteSampler sampler, double[] probabilies) {
222         final int numberOfSamples = 10000;
223         final long[] samples = new long[probabilies.length];
224         for (int i = 0; i < numberOfSamples; i++) {
225             samples[sampler.sample()]++;
226         }
227 
228         // Handle a test with some zero-probability observations by mapping them out
229         int mapSize = 0;
230         for (int i = 0; i < probabilies.length; i++) {
231             if (probabilies[i] != 0) {
232                 mapSize++;
233             }
234         }
235 
236         double[] expected = new double[mapSize];
237         long[] observed = new long[mapSize];
238         for (int i = 0; i < probabilies.length; i++) {
239             if (probabilies[i] != 0) {
240                 --mapSize;
241                 expected[mapSize] = probabilies[i];
242                 observed[mapSize] = samples[i];
243             } else {
244                 Assert.assertEquals("No samples expected from zero probability", 0, samples[i]);
245             }
246         }
247 
248         final ChiSquareTest chiSquareTest = new ChiSquareTest();
249         // Pass if we cannot reject null hypothesis that the distributions are the same.
250         Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
251     }
252 
253     /**
254      * Test the SharedStateSampler implementation for the specialised power-of-2 table size.
255      */
256     @Test
257     public void testSharedStateSamplerWithPowerOf2TableSize() {
258         testSharedStateSampler(new double[] {0.1, 0.2, 0.3, 0.4});
259     }
260 
261     /**
262      * Test the SharedStateSampler implementation for the generic non power-of-2 table size.
263      */
264     @Test
265     public void testSharedStateSamplerWithNonPowerOf2TableSize() {
266         testSharedStateSampler(new double[] {0.1, 0.2, 0.3});
267     }
268 
269     /**
270      * Test the SharedStateSampler implementation.
271      *
272      * @param probabilities The probabilities
273      */
274     private static void testSharedStateSampler(double[] probabilities) {
275         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
276         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
277         // Use negative alpha to disable padding
278         final SharedStateDiscreteSampler sampler1 =
279             AliasMethodDiscreteSampler.of(rng1, probabilities, -1);
280         final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
281         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
282     }
283 }