1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling;
19
20 import java.util.Arrays;
21 import org.junit.Assert;
22 import org.junit.Test;
23 import org.apache.commons.rng.UniformRandomProvider;
24 import org.apache.commons.rng.simple.RandomSource;
25
26
27
28
29 public class DiscreteProbabilityCollectionSamplerTest {
30
31 private static final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
32
33 @Test(expected=IllegalArgumentException.class)
34 public void testPrecondition1() {
35 new DiscreteProbabilityCollectionSampler<Double>(rng,
36 Arrays.asList(new Double[] {1d, 2d}),
37 new double[] {0d});
38 }
39 @Test(expected=IllegalArgumentException.class)
40 public void testPrecondition2() {
41 new DiscreteProbabilityCollectionSampler<Double>(rng,
42 Arrays.asList(new Double[] {1d, 2d}),
43 new double[] {0d, -1d});
44 }
45 @Test(expected=IllegalArgumentException.class)
46 public void testPrecondition3() {
47 new DiscreteProbabilityCollectionSampler<Double>(rng,
48 Arrays.asList(new Double[] {1d, 2d}),
49 new double[] {0d, 0d});
50 }
51 @Test(expected=IllegalArgumentException.class)
52 public void testPrecondition4() {
53 new DiscreteProbabilityCollectionSampler<Double>(rng,
54 Arrays.asList(new Double[] {1d, 2d}),
55 new double[] {0d, Double.NaN});
56 }
57 @Test(expected=IllegalArgumentException.class)
58 public void testPrecondition5() {
59 new DiscreteProbabilityCollectionSampler<Double>(rng,
60 Arrays.asList(new Double[] {1d, 2d}),
61 new double[] {0d, Double.POSITIVE_INFINITY});
62 }
63
64 @Test
65 public void testSample() {
66 final DiscreteProbabilityCollectionSampler<Double> sampler =
67 new DiscreteProbabilityCollectionSampler<Double>(rng,
68 Arrays.asList(new Double[] {3d, -1d, 3d, 7d, -2d, 8d}),
69 new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
70 final double expectedMean = 3.4;
71 final double expectedVariance = 7.84;
72
73 final int n = 100000000;
74 double sum = 0;
75 double sumOfSquares = 0;
76 for (int i = 0; i < n; i++) {
77 final double rand = sampler.sample();
78 sum += rand;
79 sumOfSquares += rand * rand;
80 }
81
82 final double mean = sum / n;
83 Assert.assertEquals(expectedMean, mean, 1e-3);
84 final double variance = sumOfSquares / n - mean * mean;
85 Assert.assertEquals(expectedVariance, variance, 2e-3);
86 }
87 }