View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling.distribution;
19  
20  import org.apache.commons.rng.UniformRandomProvider;
21  
22  /**
23   * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
24   * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a Gaussian
25   * distribution with mean 0 and standard deviation 1.
26   *
27   * <p>The algorithm is explained in this
28   * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
29   * and this implementation has been adapted from the C code provided therein.</p>
30   *
31   * <p>Sampling uses:</p>
32   *
33   * <ul>
34   *   <li>{@link UniformRandomProvider#nextLong()}
35   *   <li>{@link UniformRandomProvider#nextDouble()}
36   * </ul>
37   *
38   * @since 1.1
39   */
40  public class ZigguratNormalizedGaussianSampler
41      implements NormalizedGaussianSampler, SharedStateContinuousSampler {
42      /** Start of tail. */
43      private static final double R = 3.442619855899;
44      /** Inverse of R. */
45      private static final double ONE_OVER_R = 1 / R;
46      /** Rectangle area. */
47      private static final double V = 9.91256303526217e-3;
48      /** 2^63. */
49      private static final double MAX = Math.pow(2, 63);
50      /** 2^-63. */
51      private static final double ONE_OVER_MAX = 1d / MAX;
52      /** Number of entries. */
53      private static final int LEN = 128;
54      /** Index of last entry. */
55      private static final int LAST = LEN - 1;
56      /** Auxiliary table. */
57      private static final long[] K = new long[LEN];
58      /** Auxiliary table. */
59      private static final double[] W = new double[LEN];
60      /** Auxiliary table. */
61      private static final double[] F = new double[LEN];
62      /** Underlying source of randomness. */
63      private final UniformRandomProvider rng;
64  
65      static {
66          // Filling the tables.
67  
68          double d = R;
69          double t = d;
70          double fd = gauss(d);
71          final double q = V / fd;
72  
73          K[0] = (long) ((d / q) * MAX);
74          K[1] = 0;
75  
76          W[0] = q * ONE_OVER_MAX;
77          W[LAST] = d * ONE_OVER_MAX;
78  
79          F[0] = 1;
80          F[LAST] = fd;
81  
82          for (int i = LAST - 1; i >= 1; i--) {
83              d = Math.sqrt(-2 * Math.log(V / d + fd));
84              fd = gauss(d);
85  
86              K[i + 1] = (long) ((d / t) * MAX);
87              t = d;
88  
89              F[i] = fd;
90  
91              W[i] = d * ONE_OVER_MAX;
92          }
93      }
94  
95      /**
96       * @param rng Generator of uniformly distributed random numbers.
97       */
98      public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
99          this.rng = rng;
100     }
101 
102     /** {@inheritDoc} */
103     @Override
104     public double sample() {
105         final long j = rng.nextLong();
106         final int i = (int) (j & LAST);
107         if (Math.abs(j) < K[i]) {
108             return j * W[i];
109         } else {
110             return fix(j, i);
111         }
112     }
113 
114     /** {@inheritDoc} */
115     @Override
116     public String toString() {
117         return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
118     }
119 
120     /**
121      * Gets the value from the tail of the distribution.
122      *
123      * @param hz Start random integer.
124      * @param iz Index of cell corresponding to {@code hz}.
125      * @return the requested random value.
126      */
127     private double fix(long hz,
128                        int iz) {
129         double x;
130         double y;
131 
132         x = hz * W[iz];
133         if (iz == 0) {
134             // Base strip.
135             // This branch is called about 5.7624515E-4 times per sample.
136             do {
137                 y = -Math.log(rng.nextDouble());
138                 x = -Math.log(rng.nextDouble()) * ONE_OVER_R;
139             } while (y + y < x * x);
140 
141             final double out = R + x;
142             return hz > 0 ? out : -out;
143         } else {
144             // Wedge of other strips.
145             // This branch is called about 0.027323 times per sample.
146             if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x)) {
147                 return x;
148             } else {
149                 // Try again.
150                 // This branch is called about 0.012362 times per sample.
151                 return sample();
152             }
153         }
154     }
155 
156     /**
157      * @param x Argument.
158      * @return \( e^{-\frac{x^2}{2}} \)
159      */
160     private static double gauss(double x) {
161         return Math.exp(-0.5 * x * x);
162     }
163 
164     /**
165      * {@inheritDoc}
166      *
167      * @since 1.3
168      */
169     @Override
170     public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
171         return new ZigguratNormalizedGaussianSampler(rng);
172     }
173 
174     /**
175      * Create a new normalised Gaussian sampler.
176      *
177      * @param <S> Sampler type.
178      * @param rng Generator of uniformly distributed random numbers.
179      * @return the sampler
180      * @since 1.3
181      */
182     @SuppressWarnings("unchecked")
183     public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
184             of(UniformRandomProvider rng) {
185         return (S) new ZigguratNormalizedGaussianSampler(rng);
186     }
187 }