1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import java.util.Arrays;
20 import java.util.List;
21 import java.util.ArrayList;
22
23 import org.junit.Assert;
24 import org.junit.Test;
25 import org.junit.runner.RunWith;
26 import org.junit.runners.Parameterized;
27 import org.junit.runners.Parameterized.Parameters;
28
29 import org.apache.commons.math3.stat.inference.ChiSquareTest;
30
31
32
33
34 @RunWith(value = Parameterized.class)
35 public class DiscreteSamplerParametricTest {
36
37 private final DiscreteSamplerTestData sampler;
38
39
40
41
42
43
44 public DiscreteSamplerParametricTest(DiscreteSamplerTestData data) {
45 sampler = data;
46 }
47
48 @Parameters(name = "{index}: data={0}")
49 public static Iterable<DiscreteSamplerTestData[]> getList() {
50 return DiscreteSamplersList.list();
51 }
52
53 @Test
54 public void testSampling() {
55 final int sampleSize = 10000;
56
57 final double[] prob = sampler.getProbabilities();
58 final int len = prob.length;
59 final double[] expected = new double[len];
60 for (int i = 0; i < len; i++) {
61 expected[i] = prob[i] * sampleSize;
62 }
63 check(sampleSize,
64 sampler.getSampler(),
65 sampler.getPoints(),
66 expected);
67 }
68
69
70
71
72
73
74
75
76
77
78
79
80 private static void check(long sampleSize,
81 DiscreteSampler sampler,
82 int[] points,
83 double[] expected) {
84 final ChiSquareTest chiSquareTest = new ChiSquareTest();
85 final int numTests = 50;
86
87
88 int numFailures = 0;
89
90 final int numBins = points.length;
91 final long[] observed = new long[numBins];
92
93
94 final List<Double> failedStat = new ArrayList<Double>();
95 try {
96 for (int i = 0; i < numTests; i++) {
97 Arrays.fill(observed, 0);
98 SAMPLE: for (long j = 0; j < sampleSize; j++) {
99 final int value = sampler.sample();
100
101 for (int k = 0; k < numBins; k++) {
102 if (value == points[k]) {
103 ++observed[k];
104 continue SAMPLE;
105 }
106 }
107 }
108
109 if (chiSquareTest.chiSquareTest(expected, observed, 0.01)) {
110 failedStat.add(chiSquareTest.chiSquareTest(expected, observed));
111 ++numFailures;
112 }
113 }
114 } catch (Exception e) {
115
116 throw new RuntimeException("Unexpected", e);
117 }
118
119
120
121
122
123
124
125
126
127 if (numFailures > 3) {
128 Assert.fail(sampler + ": Too many failures for sample size = " + sampleSize +
129 " (" + numFailures + " out of " + numTests + " tests failed, " +
130 "chi2=" + Arrays.toString(failedStat.toArray(new Double[0])));
131 }
132 }
133 }