1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.rng.examples.sampling;
19
20 import java.io.PrintWriter;
21 import java.util.EnumSet;
22 import java.util.concurrent.Callable;
23 import java.io.IOException;
24 import org.apache.commons.rng.UniformRandomProvider;
25 import org.apache.commons.rng.simple.RandomSource;
26
27 import picocli.CommandLine.Command;
28 import picocli.CommandLine.Mixin;
29 import picocli.CommandLine.Option;
30
31 import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
32 import org.apache.commons.rng.sampling.distribution.MarsagliaNormalizedGaussianSampler;
33 import org.apache.commons.rng.sampling.distribution.BoxMullerNormalizedGaussianSampler;
34 import org.apache.commons.rng.sampling.distribution.ChengBetaSampler;
35 import org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler;
36 import org.apache.commons.rng.sampling.distribution.AhrensDieterMarsagliaTsangGammaSampler;
37 import org.apache.commons.rng.sampling.distribution.InverseTransformParetoSampler;
38 import org.apache.commons.rng.sampling.distribution.LogNormalSampler;
39 import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
40 import org.apache.commons.rng.sampling.distribution.GaussianSampler;
41 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
42
43
44
45
46 @Command(name = "density",
47 description = {"Approximate the probability density of samplers."})
48 class ProbabilityDensityApproximationCommand implements Callable<Void> {
49
50 @Mixin
51 private StandardOptions reusableOptions;
52
53
54 @Option(names = {"-b", "--bins"},
55 description = "The number of bins in the histogram (default: ${DEFAULT-VALUE}).")
56 private int numBins = 25_000;
57
58
59 @Option(names = {"-n", "--samples"},
60 description = "The number of samples in the histogram (default: ${DEFAULT-VALUE}).")
61 private long numSamples = 1_000_000_000;
62
63
64 @Option(names = {"-s", "--samplers"},
65 split = ",",
66 description = {"The samplers (comma-delimited for multiple options).",
67 "Valid values: ${COMPLETION-CANDIDATES}."})
68 private EnumSet<Sampler> samplers = EnumSet.noneOf(Sampler.class);
69
70
71 @Option(names = {"-a", "--all"},
72 description = "Output all samplers")
73 private boolean allSamplers;
74
75
76
77
78 enum Sampler {
79
80 ZigguratGaussianSampler,
81
82 MarsagliaGaussianSampler,
83
84 BoxMullerGaussianSampler,
85
86 ChengBetaSamplerCase1,
87
88 ChengBetaSamplerCase2,
89
90 AhrensDieterExponentialSampler,
91
92 AhrensDieterMarsagliaTsangGammaSamplerCase1,
93
94 AhrensDieterMarsagliaTsangGammaSamplerCase2,
95
96 InverseTransformParetoSampler,
97
98 ContinuousUniformSampler,
99
100 LogNormalZigguratGaussianSampler,
101
102 LogNormalMarsagliaGaussianSampler,
103
104 LogNormalBoxMullerGaussianSampler,
105 }
106
107
108
109
110
111
112
113
114
115
116
117
118 private void createDensity(ContinuousSampler sampler,
119 double min,
120 double max,
121 String outputFile)
122 throws IOException {
123 final double binSize = (max - min) / numBins;
124 final long[] histogram = new long[numBins];
125
126 long belowMin = 0;
127 long aboveMax = 0;
128 for (long n = 0; n < numSamples; n++) {
129 final double r = sampler.sample();
130
131 if (r < min) {
132 ++belowMin;
133 continue;
134 }
135
136 if (r >= max) {
137 ++aboveMax;
138 continue;
139 }
140
141 final int binIndex = (int) ((r - min) / binSize);
142 ++histogram[binIndex];
143 }
144
145 final double binHalfSize = 0.5 * binSize;
146 final double norm = 1 / (binSize * numSamples);
147
148 try (PrintWriter out = new PrintWriter(outputFile)) {
149
150 out.println("# Sampler: " + sampler);
151 out.println("# Number of bins: " + numBins);
152 out.println("# Min: " + min + " (fraction of samples below: " + (belowMin / (double) numSamples) + ")");
153 out.println("# Max: " + max + " (fraction of samples above: " + (aboveMax / (double) numSamples) + ")");
154 out.println("# Bin width: " + binSize);
155 out.println("# Histogram normalization factor: " + norm);
156 out.println("#");
157 out.println("# " + (min - binHalfSize) + " " + (belowMin * norm));
158 for (int i = 0; i < numBins; i++) {
159 out.println((min + (i + 1) * binSize - binHalfSize) + " " + (histogram[i] * norm));
160 }
161 out.println("# " + (max + binHalfSize) + " " + (aboveMax * norm));
162
163 }
164 }
165
166
167
168
169
170
171 @Override
172 public Void call() throws IOException {
173 if (allSamplers) {
174 samplers = EnumSet.allOf(Sampler.class);
175 } else if (samplers.isEmpty()) {
176
177 System.err.println("ERROR: No samplers specified");
178
179 System.exit(1);
180 }
181
182 final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S_PHI);
183
184 final double gaussMean = 1;
185 final double gaussSigma = 2;
186 final double gaussMin = -9;
187 final double gaussMax = 11;
188 if (samplers.contains(Sampler.ZigguratGaussianSampler)) {
189 createDensity(GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
190 gaussMean, gaussSigma),
191 gaussMin, gaussMax, "gauss.ziggurat.txt");
192 }
193 if (samplers.contains(Sampler.MarsagliaGaussianSampler)) {
194 createDensity(GaussianSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
195 gaussMean, gaussSigma),
196 gaussMin, gaussMax, "gauss.marsaglia.txt");
197 }
198 if (samplers.contains(Sampler.BoxMullerGaussianSampler)) {
199 createDensity(GaussianSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
200 gaussMean, gaussSigma),
201 gaussMin, gaussMax, "gauss.boxmuller.txt");
202 }
203
204 final double betaMin = 0;
205 final double betaMax = 1;
206 if (samplers.contains(Sampler.ChengBetaSamplerCase1)) {
207 final double alphaBeta = 4.3;
208 final double betaBeta = 2.1;
209 createDensity(ChengBetaSampler.of(rng, alphaBeta, betaBeta),
210 betaMin, betaMax, "beta.case1.txt");
211 }
212 if (samplers.contains(Sampler.ChengBetaSamplerCase2)) {
213 final double alphaBetaAlt = 0.5678;
214 final double betaBetaAlt = 0.1234;
215 createDensity(ChengBetaSampler.of(rng, alphaBetaAlt, betaBetaAlt),
216 betaMin, betaMax, "beta.case2.txt");
217 }
218
219 if (samplers.contains(Sampler.AhrensDieterExponentialSampler)) {
220 final double meanExp = 3.45;
221 final double expMin = 0;
222 final double expMax = 60;
223 createDensity(AhrensDieterExponentialSampler.of(rng, meanExp),
224 expMin, expMax, "exp.txt");
225 }
226
227 final double gammaMin = 0;
228 final double gammaMax1 = 40;
229 final double thetaGamma = 3.456;
230 if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase1)) {
231 final double alphaGammaSmallerThanOne = 0.1234;
232 createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaSmallerThanOne, thetaGamma),
233 gammaMin, gammaMax1, "gamma.case1.txt");
234 }
235 if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase2)) {
236 final double alphaGammaLargerThanOne = 2.345;
237 final double gammaMax2 = 70;
238 createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaLargerThanOne, thetaGamma),
239 gammaMin, gammaMax2, "gamma.case2.txt");
240 }
241
242 final double scalePareto = 23.45;
243 final double shapePareto = 0.789;
244 final double paretoMin = 23;
245 final double paretoMax = 400;
246 if (samplers.contains(Sampler.InverseTransformParetoSampler)) {
247 createDensity(InverseTransformParetoSampler.of(rng, scalePareto, shapePareto),
248 paretoMin, paretoMax, "pareto.txt");
249 }
250
251 final double loUniform = -9.876;
252 final double hiUniform = 5.432;
253 if (samplers.contains(Sampler.ContinuousUniformSampler)) {
254 createDensity(ContinuousUniformSampler.of(rng, loUniform, hiUniform),
255 loUniform, hiUniform, "uniform.txt");
256 }
257
258 final double scaleLogNormal = 2.345;
259 final double shapeLogNormal = 0.1234;
260 final double logNormalMin = 5;
261 final double logNormalMax = 25;
262 if (samplers.contains(Sampler.LogNormalZigguratGaussianSampler)) {
263 createDensity(LogNormalSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
264 scaleLogNormal, shapeLogNormal),
265 logNormalMin, logNormalMax, "lognormal.ziggurat.txt");
266 }
267 if (samplers.contains(Sampler.LogNormalMarsagliaGaussianSampler)) {
268 createDensity(LogNormalSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
269 scaleLogNormal, shapeLogNormal),
270 logNormalMin, logNormalMax, "lognormal.marsaglia.txt");
271 }
272 if (samplers.contains(Sampler.LogNormalBoxMullerGaussianSampler)) {
273 createDensity(LogNormalSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
274 scaleLogNormal, shapeLogNormal),
275 logNormalMin, logNormalMax, "lognormal.boxmuller.txt");
276 }
277
278 return null;
279 }
280 }