1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.Arrays;
20 import java.util.List;
21 import java.util.ArrayList;
22
23 import org.junit.Assert;
24 import org.junit.Test;
25 import org.junit.runner.RunWith;
26 import org.junit.runners.Parameterized;
27 import org.junit.runners.Parameterized.Parameters;
28
29 import org.apache.commons.math3.distribution.ChiSquaredDistribution;
30 import org.apache.commons.math3.stat.inference.ChiSquareTest;
31
32
33
34
35 @RunWith(value=Parameterized.class)
36 public class DiscreteSamplerParametricTest {
37
38 private final DiscreteSamplerTestData sampler;
39
40
41
42
43
44
45 public DiscreteSamplerParametricTest(DiscreteSamplerTestData data) {
46 sampler = data;
47 }
48
49 @Parameters(name = "{index}: data={0}")
50 public static Iterable<DiscreteSamplerTestData[]> getList() {
51 return DiscreteSamplersList.list();
52 }
53
54 @Test
55 public void testSampling() {
56 final int sampleSize = 10000;
57
58 final double[] prob = sampler.getProbabilities();
59 final int len = prob.length;
60 final double[] expected = new double[len];
61 for (int i = 0; i < len; i++) {
62 expected[i] = prob[i] * sampleSize;
63 }
64 check(sampleSize,
65 sampler.getSampler(),
66 sampler.getPoints(),
67 expected);
68 }
69
70
71
72
73
74
75
76
77
78
79
80
81 private void check(long sampleSize,
82 DiscreteSampler sampler,
83 int[] points,
84 double[] expected) {
85 final ChiSquareTest chiSquareTest = new ChiSquareTest();
86 final int numTests = 50;
87
88
89 int numFailures = 0;
90
91 final int numBins = points.length;
92 final long[] observed = new long[numBins];
93
94
95 final List<Double> failedStat = new ArrayList<Double>();
96 try {
97 for (int i = 0; i < numTests; i++) {
98 Arrays.fill(observed, 0);
99 SAMPLE: for (long j = 0; j < sampleSize; j++) {
100 final int value = sampler.sample();
101
102 for (int k = 0; k < numBins; k++) {
103 if (value == points[k]) {
104 ++observed[k];
105 continue SAMPLE;
106 }
107 }
108 }
109
110 if (chiSquareTest.chiSquareTest(expected, observed, 0.001)) {
111 failedStat.add(chiSquareTest.chiSquareTest(expected, observed));
112 ++numFailures;
113 }
114 }
115 } catch (Exception e) {
116
117 throw new RuntimeException("Unexpected", e);
118 }
119
120 if ((double) numFailures / (double) numTests > 0.05) {
121 Assert.fail(sampler + ": Too many failures for sample size = " + sampleSize +
122 " (" + numFailures + " out of " + numTests + " tests failed, " +
123 "chi2=" + Arrays.toString(failedStat.toArray(new Double[0])));
124 }
125 }
126 }