View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
import java.util.Locale;

import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
27  
28  /**
29   * Test for the {@link GuideTableDiscreteSampler}.
30   */
31  public class GuideTableDiscreteSamplerTest {
32      @Test(expected = IllegalArgumentException.class)
33      public void testConstructorThrowsWithNullProbabilites() {
34          createSampler(null, 1.0);
35      }
36  
37      @Test(expected = IllegalArgumentException.class)
38      public void testConstructorThrowsWithZeroLengthProbabilites() {
39          createSampler(new double[0], 1.0);
40      }
41  
42      @Test(expected = IllegalArgumentException.class)
43      public void testConstructorThrowsWithNegativeProbabilites() {
44          createSampler(new double[] {-1, 0.1, 0.2}, 1.0);
45      }
46  
47      @Test(expected = IllegalArgumentException.class)
48      public void testConstructorThrowsWithNaNProbabilites() {
49          createSampler(new double[] {0.1, Double.NaN, 0.2}, 1.0);
50      }
51  
52      @Test(expected = IllegalArgumentException.class)
53      public void testConstructorThrowsWithInfiniteProbabilites() {
54          createSampler(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2}, 1.0);
55      }
56  
57      @Test(expected = IllegalArgumentException.class)
58      public void testConstructorThrowsWithInfiniteSumProbabilites() {
59          createSampler(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, 1.0);
60      }
61  
62      @Test(expected = IllegalArgumentException.class)
63      public void testConstructorThrowsWithZeroSumProbabilites() {
64          createSampler(new double[4], 1.0);
65      }
66  
67      @Test(expected = IllegalArgumentException.class)
68      public void testConstructorThrowsWithZeroAlpha() {
69          createSampler(new double[] {0.5, 0.5}, 0.0);
70      }
71  
72      @Test(expected = IllegalArgumentException.class)
73      public void testConstructorThrowsWithNegativeAlpha() {
74          createSampler(new double[] {0.5, 0.5}, -1.0);
75      }
76  
77      @Test
78      public void testToString() {
79          final SharedStateDiscreteSampler sampler = createSampler(new double[] {0.5, 0.5}, 1.0);
80          Assert.assertTrue(sampler.toString().toLowerCase().contains("guide table"));
81      }
82  
83      /**
84       * Creates the sampler.
85       *
86       * @param probabilities the probabilities
87       * @return the alias method discrete sampler
88       */
89      private static SharedStateDiscreteSampler createSampler(double[] probabilities, double alpha) {
90          final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
91          return GuideTableDiscreteSampler.of(rng, probabilities, alpha);
92      }
93  
94      /**
95       * Test sampling from a binomial distribution.
96       */
97      @Test
98      public void testBinomialSamples() {
99          final int trials = 67;
100         final double probabilityOfSuccess = 0.345;
101         final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
102         final double[] expected = new double[trials + 1];
103         for (int i = 0; i < expected.length; i++) {
104             expected[i] = dist.probability(i);
105         }
106         checkSamples(expected, 1.0);
107     }
108 
109     /**
110      * Test sampling from a Poisson distribution.
111      */
112     @Test
113     public void testPoissonSamples() {
114         final double mean = 3.14;
115         final PoissonDistribution dist = new PoissonDistribution(null, mean,
116             PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
117         final int maxN = dist.inverseCumulativeProbability(1 - 1e-6);
118         final double[] expected = new double[maxN];
119         for (int i = 0; i < expected.length; i++) {
120             expected[i] = dist.probability(i);
121         }
122         checkSamples(expected, 1.0);
123     }
124 
125     /**
126      * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
127      */
128     @Test
129     public void testNonUniformSamplesWithProbabilities() {
130         final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
131         checkSamples(expected, 1.0);
132     }
133 
134     /**
135      * Test sampling from a non-uniform distribution of probabilities with an alpha smaller than
136      * the default.
137      */
138     @Test
139     public void testNonUniformSamplesWithProbabilitiesWithSmallAlpha() {
140         final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
141         checkSamples(expected, 0.1);
142     }
143 
144     /**
145      * Test sampling from a non-uniform distribution of probabilities with an alpha larger than
146      * the default.
147      */
148     @Test
149     public void testNonUniformSamplesWithProbabilitiesWithLargeAlpha() {
150         final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
151         checkSamples(expected, 10.0);
152     }
153 
154     /**
155      * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
156      * probabilities).
157      */
158     @Test
159     public void testNonUniformSamplesWithObservations() {
160         final double[] expected = {1, 2, 3, 1, 3};
161         checkSamples(expected, 1.0);
162     }
163 
164     /**
165      * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
166      * Extra zero-values are added.
167      */
168     @Test
169     public void testNonUniformSamplesWithZeroProbabilities() {
170         final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
171         checkSamples(expected, 1.0);
172     }
173 
174     /**
175      * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
176      * probabilities). Extra zero-values are added.
177      */
178     @Test
179     public void testNonUniformSamplesWithZeroObservations() {
180         final double[] expected = {1, 2, 3, 0, 1, 3, 0};
181         checkSamples(expected, 1.0);
182     }
183 
184     /**
185      * Test sampling from a uniform distribution. This is an edge case where there
186      * are no probabilities less than the mean.
187      */
188     @Test
189     public void testUniformSamplesWithNoObservationLessThanTheMean() {
190         final double[] expected = {2, 2, 2, 2, 2, 2};
191         checkSamples(expected, 1.0);
192     }
193 
194     /**
195      * Check the distribution of samples match the expected probabilities.
196      *
197      * <p>If the expected probability is zero then this should never be sampled. The non-zero
198      * probabilities are compared to the sample distribution using a Chi-square test.</p>
199      *
200      * @param probabilies the probabilities
201      * @param alpha the alpha
202      */
203     private static void checkSamples(double[] probabilies, double alpha) {
204         final SharedStateDiscreteSampler sampler = createSampler(probabilies, alpha);
205 
206         final int numberOfSamples = 10000;
207         final long[] samples = new long[probabilies.length];
208         for (int i = 0; i < numberOfSamples; i++) {
209             samples[sampler.sample()]++;
210         }
211 
212         // Handle a test with some zero-probability observations by mapping them out.
213         // The results is the Chi-square test is performed using only the non-zero probabilities.
214         int mapSize = 0;
215         for (int i = 0; i < probabilies.length; i++) {
216             if (probabilies[i] != 0) {
217                 mapSize++;
218             }
219         }
220 
221         final double[] expected = new double[mapSize];
222         final long[] observed = new long[mapSize];
223         for (int i = 0; i < probabilies.length; i++) {
224             if (probabilies[i] == 0) {
225                 Assert.assertEquals("No samples expected from zero probability", 0, samples[i]);
226             } else {
227                 // This can be added for the Chi-square test
228                 --mapSize;
229                 expected[mapSize] = probabilies[i];
230                 observed[mapSize] = samples[i];
231             }
232         }
233 
234         final ChiSquareTest chiSquareTest = new ChiSquareTest();
235         // Pass if we cannot reject null hypothesis that the distributions are the same.
236         Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
237     }
238 
239     /**
240      * Test the SharedStateSampler implementation.
241      */
242     @Test
243     public void testSharedStateSampler() {
244         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
245         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
246         final double[] probabilities = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
247         final SharedStateDiscreteSampler sampler1 =
248             GuideTableDiscreteSampler.of(rng1, probabilities);
249         final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
250         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
251     }
252 }