1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.rng.sampling.distribution;
18
19 import org.apache.commons.rng.UniformRandomProvider;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 public class ChengBetaSampler
35 extends SamplerBase
36 implements ContinuousSampler {
37
38 private final double alphaShape;
39
40 private final double betaShape;
41
42 private final UniformRandomProvider rng;
43
44
45
46
47
48
49
50
51 public ChengBetaSampler(UniformRandomProvider rng,
52 double alpha,
53 double beta) {
54 super(null);
55 this.rng = rng;
56 alphaShape = alpha;
57 betaShape = beta;
58 }
59
60
61 @Override
62 public double sample() {
63 final double a = Math.min(alphaShape, betaShape);
64 final double b = Math.max(alphaShape, betaShape);
65
66 if (a > 1) {
67 return algorithmBB(a, b);
68 } else {
69 return algorithmBC(b, a);
70 }
71 }
72
73
74 @Override
75 public String toString() {
76 return "Cheng Beta deviate [" + rng.toString() + "]";
77 }
78
79
80
81
82
83
84
85
86
87 private double algorithmBB(double a,
88 double b) {
89 final double alpha = a + b;
90 final double beta = Math.sqrt((alpha - 2) / (2 * a * b - alpha));
91 final double gamma = a + 1 / beta;
92
93 double r;
94 double w;
95 double t;
96 do {
97 final double u1 = rng.nextDouble();
98 final double u2 = rng.nextDouble();
99 final double v = beta * (Math.log(u1) - Math.log1p(-u1));
100 w = a * Math.exp(v);
101 final double z = u1 * u1 * u2;
102 r = gamma * v - 1.3862944;
103 final double s = a + r - w;
104 if (s + 2.609438 >= 5 * z) {
105 break;
106 }
107
108 t = Math.log(z);
109 if (s >= t) {
110 break;
111 }
112 } while (r + alpha * (Math.log(alpha) - Math.log(b + w)) < t);
113
114 w = Math.min(w, Double.MAX_VALUE);
115
116 return equals(a, alphaShape) ? w / (b + w) : b / (b + w);
117 }
118
119
120
121
122
123
124
125
126
127 private double algorithmBC(double a,
128 double b) {
129 final double alpha = a + b;
130 final double beta = 1 / b;
131 final double delta = 1 + a - b;
132 final double k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778);
133 final double k2 = 0.25 + (0.5 + 0.25 / delta) * b;
134
135 double w;
136 while (true) {
137 final double u1 = rng.nextDouble();
138 final double u2 = rng.nextDouble();
139 final double y = u1 * u2;
140 final double z = u1 * y;
141 if (u1 < 0.5) {
142 if (0.25 * u2 + z - y >= k1) {
143 continue;
144 }
145 } else {
146 if (z <= 0.25) {
147 final double v = beta * (Math.log(u1) - Math.log1p(-u1));
148 w = a * Math.exp(v);
149 break;
150 }
151
152 if (z >= k2) {
153 continue;
154 }
155 }
156
157 final double v = beta * (Math.log(u1) - Math.log1p(-u1));
158 w = a * Math.exp(v);
159 if (alpha * (Math.log(alpha) - Math.log(b + w) + v) - 1.3862944 >= Math.log(z)) {
160 break;
161 }
162 }
163
164 w = Math.min(w, Double.MAX_VALUE);
165
166 return equals(a, alphaShape) ? w / (b + w) : b / (b + w);
167 }
168
169
170
171
172
173
174 private boolean equals(double a,
175 double b) {
176 return Math.abs(a - b) <= Double.MIN_VALUE;
177 }
178 }