1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling.distribution;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 public class ZigguratNormalizedGaussianSampler
41 implements NormalizedGaussianSampler, SharedStateContinuousSampler {
42
43 private static final double R = 3.442619855899;
44
45 private static final double ONE_OVER_R = 1 / R;
46
47 private static final double V = 9.91256303526217e-3;
48
49 private static final double MAX = Math.pow(2, 63);
50
51 private static final double ONE_OVER_MAX = 1d / MAX;
52
53 private static final int LEN = 128;
54
55 private static final int LAST = LEN - 1;
56
57 private static final long[] K = new long[LEN];
58
59 private static final double[] W = new double[LEN];
60
61 private static final double[] F = new double[LEN];
62
63 private final UniformRandomProvider rng;
64
65 static {
66
67
68 double d = R;
69 double t = d;
70 double fd = gauss(d);
71 final double q = V / fd;
72
73 K[0] = (long) ((d / q) * MAX);
74 K[1] = 0;
75
76 W[0] = q * ONE_OVER_MAX;
77 W[LAST] = d * ONE_OVER_MAX;
78
79 F[0] = 1;
80 F[LAST] = fd;
81
82 for (int i = LAST - 1; i >= 1; i--) {
83 d = Math.sqrt(-2 * Math.log(V / d + fd));
84 fd = gauss(d);
85
86 K[i + 1] = (long) ((d / t) * MAX);
87 t = d;
88
89 F[i] = fd;
90
91 W[i] = d * ONE_OVER_MAX;
92 }
93 }
94
95
96
97
98 public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
99 this.rng = rng;
100 }
101
102
103 @Override
104 public double sample() {
105 final long j = rng.nextLong();
106 final int i = (int) (j & LAST);
107 if (Math.abs(j) < K[i]) {
108 return j * W[i];
109 } else {
110 return fix(j, i);
111 }
112 }
113
114
115 @Override
116 public String toString() {
117 return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
118 }
119
120
121
122
123
124
125
126
127 private double fix(long hz,
128 int iz) {
129 double x;
130 double y;
131
132 x = hz * W[iz];
133 if (iz == 0) {
134
135
136 do {
137 y = -Math.log(rng.nextDouble());
138 x = -Math.log(rng.nextDouble()) * ONE_OVER_R;
139 } while (y + y < x * x);
140
141 final double out = R + x;
142 return hz > 0 ? out : -out;
143 } else {
144
145
146 if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x)) {
147 return x;
148 } else {
149
150
151 return sample();
152 }
153 }
154 }
155
156
157
158
159
160 private static double gauss(double x) {
161 return Math.exp(-0.5 * x * x);
162 }
163
164
165
166
167
168
169 @Override
170 public SharedStateContinuousSampler withUniformRandomProvider(UniformRandomProvider rng) {
171 return new ZigguratNormalizedGaussianSampler(rng);
172 }
173
174
175
176
177
178
179
180
181
182 @SuppressWarnings("unchecked")
183 public static <S extends NormalizedGaussianSampler & SharedStateContinuousSampler> S
184 of(UniformRandomProvider rng) {
185 return (S) new ZigguratNormalizedGaussianSampler(rng);
186 }
187 }