View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling;
19  
20  import java.util.Arrays;
21  import java.util.Collections;
22  import java.util.HashMap;
23  import java.util.List;
24  import java.util.Map;
25  import java.util.TreeMap;
26  
27  import org.junit.Assert;
28  import org.junit.Test;
29  import org.apache.commons.rng.UniformRandomProvider;
30  import org.apache.commons.rng.simple.RandomSource;
31  
32  /**
33   * Test class for {@link DiscreteProbabilityCollectionSampler}.
34   */
35  public class DiscreteProbabilityCollectionSamplerTest {
36      /** RNG. */
37      private final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
38  
39      @Test(expected = IllegalArgumentException.class)
40      public void testPrecondition1() {
41          // Size mismatch
42          new DiscreteProbabilityCollectionSampler<Double>(rng,
43                                                           Arrays.asList(new Double[] {1d, 2d}),
44                                                           new double[] {0d});
45      }
46      @Test(expected = IllegalArgumentException.class)
47      public void testPrecondition2() {
48          // Negative probability
49          new DiscreteProbabilityCollectionSampler<Double>(rng,
50                                                           Arrays.asList(new Double[] {1d, 2d}),
51                                                           new double[] {0d, -1d});
52      }
53      @Test(expected = IllegalArgumentException.class)
54      public void testPrecondition3() {
55          // Probabilities do not sum above 0
56          new DiscreteProbabilityCollectionSampler<Double>(rng,
57                                                           Arrays.asList(new Double[] {1d, 2d}),
58                                                           new double[] {0d, 0d});
59      }
60      @Test(expected = IllegalArgumentException.class)
61      public void testPrecondition4() {
62          // NaN probability
63          new DiscreteProbabilityCollectionSampler<Double>(rng,
64                                                           Arrays.asList(new Double[] {1d, 2d}),
65                                                           new double[] {0d, Double.NaN});
66      }
67      @Test(expected = IllegalArgumentException.class)
68      public void testPrecondition5() {
69          // Infinite probability
70          new DiscreteProbabilityCollectionSampler<Double>(rng,
71                                                           Arrays.asList(new Double[] {1d, 2d}),
72                                                           new double[] {0d, Double.POSITIVE_INFINITY});
73      }
74      @Test(expected = IllegalArgumentException.class)
75      public void testPrecondition6() {
76          // Empty Map<T, Double> not allowed
77          new DiscreteProbabilityCollectionSampler<Double>(rng,
78                                                           new HashMap<Double, Double>());
79      }
80      @Test(expected = IllegalArgumentException.class)
81      public void testPrecondition7() {
82          // Empty List<T> not allowed
83          new DiscreteProbabilityCollectionSampler<Double>(rng,
84                                                           Collections.<Double>emptyList(),
85                                                           new double[0]);
86      }
87  
88      @Test
89      public void testSample() {
90          final DiscreteProbabilityCollectionSampler<Double> sampler =
91              new DiscreteProbabilityCollectionSampler<Double>(rng,
92                                                               Arrays.asList(new Double[] {3d, -1d, 3d, 7d, -2d, 8d}),
93                                                               new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
94          final double expectedMean = 3.4;
95          final double expectedVariance = 7.84;
96  
97          final int n = 100000000;
98          double sum = 0;
99          double sumOfSquares = 0;
100         for (int i = 0; i < n; i++) {
101             final double rand = sampler.sample();
102             sum += rand;
103             sumOfSquares += rand * rand;
104         }
105 
106         final double mean = sum / n;
107         Assert.assertEquals(expectedMean, mean, 1e-3);
108         final double variance = sumOfSquares / n - mean * mean;
109         Assert.assertEquals(expectedVariance, variance, 2e-3);
110     }
111 
112 
113     @Test
114     public void testSampleUsingMap() {
115         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
116         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
117         final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
118         final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
119         final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
120             new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities);
121 
122         // Create a map version. The map iterator must be ordered so use a TreeMap.
123         final Map<Integer, Double> map = new TreeMap<Integer, Double>();
124         for (int i = 0; i < probabilities.length; i++) {
125             map.put(items.get(i), probabilities[i]);
126         }
127         final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
128             new DiscreteProbabilityCollectionSampler<Integer>(rng2, map);
129 
130         for (int i = 0; i < 50; i++) {
131             Assert.assertEquals(sampler1.sample(), sampler2.sample());
132         }
133     }
134 
135     /**
136      * Edge-case test:
137      * Create a sampler that will return 1 for nextDouble() forcing the search to
138      * identify the end item of the cumulative probability array.
139      */
140     @Test
141     public void testSampleWithProbabilityAtLastItem() {
142         // Ensure the samples pick probability 0 (the first item) and then
143         // a probability (for the second item) that hits an edge case.
144         final UniformRandomProvider dummyRng = new UniformRandomProvider() {
145             private int count;
146             // CHECKSTYLE: stop all
147             public long nextLong(long n) { return 0; }
148             public long nextLong() { return 0; }
149             public int nextInt(int n) { return 0; }
150             public int nextInt() { return 0; }
151             public float nextFloat() { return 0; }
152             // Return 0 then the given probability
153             public double nextDouble() { return (count++ == 0) ? 0 : 1.0; }
154             public void nextBytes(byte[] bytes, int start, int len) {}
155             public void nextBytes(byte[] bytes) {}
156             public boolean nextBoolean() { return false; }
157             // CHECKSTYLE: resume all
158         };
159 
160         final List<Double> items = Arrays.asList(new Double[] {1d, 2d});
161         final DiscreteProbabilityCollectionSampler<Double> sampler =
162             new DiscreteProbabilityCollectionSampler<Double>(dummyRng,
163                                                              items,
164                                                              new double[] {0.5, 0.5});
165         final Double item1 = sampler.sample();
166         final Double item2 = sampler.sample();
167         // Check they are in the list
168         Assert.assertTrue("Sample item1 is not from the list", items.contains(item1));
169         Assert.assertTrue("Sample item2 is not from the list", items.contains(item2));
170         // Test the two samples are different items
171         Assert.assertNotSame("Item1 and 2 should be different", item1, item2);
172     }
173 
174     /**
175      * Test the SharedStateSampler implementation.
176      */
177     @Test
178     public void testSharedStateSampler() {
179         final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
180         final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
181         final List<Double> items = Arrays.asList(new Double[] {1d, 2d, 3d, 4d});
182         final DiscreteProbabilityCollectionSampler<Double> sampler1 =
183             new DiscreteProbabilityCollectionSampler<Double>(rng1,
184                                                              items,
185                                                              new double[] {0.1, 0.2, 0.3, 0.4});
186         final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
187         RandomAssert.assertProduceSameSequence(
188             new RandomAssert.Sampler<Double>() {
189                 @Override
190                 public Double sample() {
191                     return sampler1.sample();
192                 }
193             },
194             new RandomAssert.Sampler<Double>() {
195                 @Override
196                 public Double sample() {
197                     return sampler2.sample();
198                 }
199             });
200     }
201 }