1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.sampling.distribution;
19
20 import org.apache.commons.rng.UniformRandomProvider;
21
22
23
24
25
26
27
28
29
30
31
32
33 public class ZigguratNormalizedGaussianSampler
34 implements NormalizedGaussianSampler {
35
36 private static final double R = 3.442619855899;
37
38 private static final double ONE_OVER_R = 1 / R;
39
40 private static final double V = 9.91256303526217e-3;
41
42 private static final double MAX = Math.pow(2, 63);
43
44 private static final double ONE_OVER_MAX = 1d / MAX;
45
46 private static final int LEN = 128;
47
48 private static final int LAST = LEN - 1;
49
50 private static final long[] K = new long[LEN];
51
52 private static final double[] W = new double[LEN];
53
54 private static final double[] F = new double[LEN];
55
56 private final UniformRandomProvider rng;
57
58 static {
59
60
61 double d = R;
62 double t = d;
63 double fd = gauss(d);
64 final double q = V / fd;
65
66 K[0] = (long) ((d / q) * MAX);
67 K[1] = 0;
68
69 W[0] = q * ONE_OVER_MAX;
70 W[LAST] = d * ONE_OVER_MAX;
71
72 F[0] = 1;
73 F[LAST] = fd;
74
75 for (int i = LAST - 1; i >= 1; i--) {
76 d = Math.sqrt(-2 * Math.log(V / d + fd));
77 fd = gauss(d);
78
79 K[i + 1] = (long) ((d / t) * MAX);
80 t = d;
81
82 F[i] = fd;
83
84 W[i] = d * ONE_OVER_MAX;
85 }
86 }
87
88
89
90
91 public ZigguratNormalizedGaussianSampler(UniformRandomProvider rng) {
92 this.rng = rng;
93 }
94
95
96 @Override
97 public double sample() {
98 final long j = rng.nextLong();
99 final int i = (int) (j & LAST);
100 if (Math.abs(j) < K[i]) {
101 return j * W[i];
102 } else {
103 return fix(j, i);
104 }
105 }
106
107
108 @Override
109 public String toString() {
110 return "Ziggurat normalized Gaussian deviate [" + rng.toString() + "]";
111 }
112
113
114
115
116
117
118
119
120 private double fix(long hz,
121 int iz) {
122 double x;
123 double y;
124
125 while (true) {
126 x = hz * W[iz];
127 if (iz == 0) {
128
129 do {
130 y = -Math.log(rng.nextDouble());
131 x = -Math.log(rng.nextDouble()) * ONE_OVER_R;
132 } while (y + y < x * x);
133
134 final double out = R + x;
135 return hz > 0 ? out : -out;
136 } else {
137
138 if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < gauss(x)) {
139 return x;
140 } else {
141 final long hzNew = rng.nextLong();
142 final int izNew = (int) (hzNew & LAST);
143 if (Math.abs(hzNew) < K[izNew]) {
144 return hzNew * W[izNew];
145 }
146 }
147 }
148 }
149 }
150
151
152
153
154
155 private static double gauss(double x) {
156 return Math.exp(-0.5 * x * x);
157 }
158 }