View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import java.util.Arrays;
20  import java.util.List;
21  import java.util.ArrayList;
22  
23  import org.junit.Assert;
24  import org.junit.Test;
25  import org.junit.runner.RunWith;
26  import org.junit.runners.Parameterized;
27  import org.junit.runners.Parameterized.Parameters;
28  
29  import org.apache.commons.math3.distribution.ChiSquaredDistribution;
30  import org.apache.commons.math3.stat.inference.ChiSquareTest;
31  
32  /**
33   * Tests for random deviates generators.
34   */
35  @RunWith(value=Parameterized.class)
36  public class DiscreteSamplerParametricTest {
37      /** Sampler under test. */
38      private final DiscreteSamplerTestData sampler;
39  
40      /**
41       * Initializes generator instance.
42       *
43       * @param rng RNG to be tested.
44       */
45      public DiscreteSamplerParametricTest(DiscreteSamplerTestData data) {
46          sampler = data;
47      }
48  
49      @Parameters(name = "{index}: data={0}")
50      public static Iterable<DiscreteSamplerTestData[]> getList() {
51          return DiscreteSamplersList.list();
52      }
53  
54      @Test
55      public void testSampling() {
56          final int sampleSize = 10000;
57  
58          final double[] prob = sampler.getProbabilities();
59          final int len = prob.length; 
60          final double[] expected = new double[len];
61          for (int i = 0; i < len; i++) {
62              expected[i] = prob[i] * sampleSize;
63          }
64          check(sampleSize,
65                sampler.getSampler(),
66                sampler.getPoints(),
67                expected);
68      }
69  
70      /**
71       * Performs a chi-square test of homogeneity of the observed
72       * distribution with the expected distribution.
73       * An average failure rate higher than 5% causes the test case
74       * to fail.
75       *
76       * @param sampler Sampler.
77       * @param sampleSize Number of random values to generate.
78       * @param points Outcomes.
79       * @param expected Expected counts of the given outcomes.
80       */
81      private void check(long sampleSize,
82                         DiscreteSampler sampler,
83                         int[] points,
84                         double[] expected) {
85          final ChiSquareTest chiSquareTest = new ChiSquareTest();
86          final int numTests = 50;
87  
88          // Run the tests.
89          int numFailures = 0;
90  
91          final int numBins = points.length;
92          final long[] observed = new long[numBins];
93  
94          // For storing chi2 larger than the critical value.
95          final List<Double> failedStat = new ArrayList<Double>();
96          try {
97              for (int i = 0; i < numTests; i++) {
98                  Arrays.fill(observed, 0);
99                  SAMPLE: for (long j = 0; j < sampleSize; j++) {
100                     final int value = sampler.sample();
101 
102                     for (int k = 0; k < numBins; k++) {
103                         if (value == points[k]) {
104                             ++observed[k];
105                             continue SAMPLE;
106                         }
107                     }
108                 }
109 
110                 if (chiSquareTest.chiSquareTest(expected, observed, 0.001)) {
111                     failedStat.add(chiSquareTest.chiSquareTest(expected, observed));
112                     ++numFailures;
113                 }
114             }
115         } catch (Exception e) {
116             // Should never happen.
117             throw new RuntimeException("Unexpected", e);
118         }
119 
120         if ((double) numFailures / (double) numTests > 0.05) {
121             Assert.fail(sampler + ": Too many failures for sample size = " + sampleSize +
122                         " (" + numFailures + " out of " + numTests + " tests failed, " +
123                         "chi2=" + Arrays.toString(failedStat.toArray(new Double[0])));
124         }
125     }
126 }