1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44 public class AhrensDieterMarsagliaTsangGammaSampler
45 extends SamplerBase
46 implements ContinuousSampler {
47
48 private static final double ONE_THIRD = 1d / 3;
49
50 private final double theta;
51
52 private final double alpha;
53
54 private final double oneOverTheta;
55
56 private final double bGSOptim;
57
58 private final double dOptim;
59
60 private final double cOptim;
61
62 private final NormalizedGaussianSampler gaussian;
63
64 private final UniformRandomProvider rng;
65
66
67
68
69
70
71 public AhrensDieterMarsagliaTsangGammaSampler(UniformRandomProvider rng,
72 double alpha,
73 double theta) {
74 super(null);
75 this.rng = rng;
76 this.alpha = alpha;
77 this.theta = theta;
78 gaussian = new ZigguratNormalizedGaussianSampler(rng);
79 oneOverTheta = 1 / theta;
80 bGSOptim = 1 + theta / Math.E;
81 dOptim = theta - ONE_THIRD;
82 cOptim = ONE_THIRD / Math.sqrt(dOptim);
83 }
84
85
86 @Override
87 public double sample() {
88 if (theta < 1) {
89
90
91 while (true) {
92
93 final double u = rng.nextDouble();
94 final double p = bGSOptim * u;
95
96 if (p <= 1) {
97
98
99 final double x = Math.pow(p, oneOverTheta);
100 final double u2 = rng.nextDouble();
101
102 if (u2 > Math.exp(-x)) {
103
104 continue;
105 } else {
106 return alpha * x;
107 }
108 } else {
109
110
111 final double x = -Math.log((bGSOptim - p) * oneOverTheta);
112 final double u2 = rng.nextDouble();
113
114 if (u2 > Math.pow(x, theta - 1)) {
115
116 continue;
117 } else {
118 return alpha * x;
119 }
120 }
121 }
122 } else {
123 while (true) {
124 final double x = gaussian.sample();
125 final double oPcTx = 1 + cOptim * x;
126 final double v = oPcTx * oPcTx * oPcTx;
127
128 if (v <= 0) {
129 continue;
130 }
131
132 final double x2 = x * x;
133 final double u = rng.nextDouble();
134
135
136 if (u < 1 - 0.0331 * x2 * x2) {
137 return alpha * dOptim * v;
138 }
139
140 if (Math.log(u) < 0.5 * x2 + dOptim * (1 - v + Math.log(v))) {
141 return alpha * dOptim * v;
142 }
143 }
144 }
145 }
146
147
148 @Override
149 public String toString() {
150 return "Ahrens-Dieter-Marsaglia-Tsang Gamma deviate [" + rng.toString() + "]";
151 }
152 }