View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.rng.sampling.distribution;
18  
19  import org.apache.commons.rng.RestorableUniformRandomProvider;
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.apache.commons.rng.sampling.RandomAssert;
22  import org.apache.commons.rng.sampling.SharedStateSampler;
23  import org.apache.commons.rng.simple.RandomSource;
24  import org.junit.Test;
25  
26  /**
27   * Test for the {@link GaussianSampler}. The tests hit edge cases for the sampler.
28   */
29  public class GaussianSamplerTest {
30      /**
31       * Test the constructor with a bad standard deviation.
32       */
33      @Test(expected = IllegalArgumentException.class)
34      public void testConstructorThrowsWithZeroStandardDeviation() {
35          final RestorableUniformRandomProvider rng =
36              RandomSource.create(RandomSource.SPLIT_MIX_64);
37          final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng);
38          final double mean = 1;
39          final double standardDeviation = 0;
40          GaussianSampler.of(gauss, mean, standardDeviation);
41      }
42  
43      /**
44       * Test the SharedStateSampler implementation.
45       */
46      @Test
47      public void testSharedStateSampler() {
48          final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
49          final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
50          final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng1);
51          final double mean = 1.23;
52          final double standardDeviation = 4.56;
53          final SharedStateContinuousSampler sampler1 =
54              GaussianSampler.of(gauss, mean, standardDeviation);
55          final SharedStateContinuousSampler sampler2 = sampler1.withUniformRandomProvider(rng2);
56          RandomAssert.assertProduceSameSequence(sampler1, sampler2);
57      }
58  
59      /**
60       * Test the SharedStateSampler implementation throws if the underlying sampler is
61       * not a SharedStateSampler.
62       */
63      @Test(expected = UnsupportedOperationException.class)
64      public void testSharedStateSamplerThrowsIfUnderlyingSamplerDoesNotShareState() {
65          final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
66          final NormalizedGaussianSampler gauss = new NormalizedGaussianSampler() {
67              @Override
68              public double sample() {
69                  return 0;
70              }
71          };
72          final double mean = 1.23;
73          final double standardDeviation = 4.56;
74          final SharedStateContinuousSampler sampler1 =
75              GaussianSampler.of(gauss, mean, standardDeviation);
76          sampler1.withUniformRandomProvider(rng2);
77      }
78  
79      /**
80       * Test the SharedStateSampler implementation throws if the underlying sampler is
81       * a SharedStateSampler that returns an incorrect type.
82       */
83      @Test(expected = UnsupportedOperationException.class)
84      public void testSharedStateSamplerThrowsIfUnderlyingSamplerReturnsWrongSharedState() {
85          final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
86          final NormalizedGaussianSampler gauss = new BadSharedStateNormalizedGaussianSampler();
87          final double mean = 1.23;
88          final double standardDeviation = 4.56;
89          final SharedStateContinuousSampler sampler1 =
90              GaussianSampler.of(gauss, mean, standardDeviation);
91          sampler1.withUniformRandomProvider(rng2);
92      }
93  
94      /**
95       * Test class to return an incorrect sampler from the SharedStateSampler method.
96       *
97       * <p>Note that due to type erasure the type returned by the SharedStateSampler is not
98       * available at run-time and the GaussianSampler has to assume it is the correct type.</p>
99       */
100     private static class BadSharedStateNormalizedGaussianSampler
101             implements NormalizedGaussianSampler, SharedStateSampler<Integer> {
102         @Override
103         public double sample() {
104             return 0;
105         }
106 
107         @Override
108         public Integer withUniformRandomProvider(UniformRandomProvider rng) {
109             // Something that is not a NormalizedGaussianSampler
110             return Integer.valueOf(44);
111         }
112     }
113 }