View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.examples.sampling;
19  
20  import java.io.PrintWriter;
21  import java.util.EnumSet;
22  import java.util.concurrent.Callable;
23  import java.io.IOException;
24  import org.apache.commons.rng.UniformRandomProvider;
25  import org.apache.commons.rng.simple.RandomSource;
26  
27  import picocli.CommandLine.Command;
28  import picocli.CommandLine.Mixin;
29  import picocli.CommandLine.Option;
30  
31  import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
32  import org.apache.commons.rng.sampling.distribution.MarsagliaNormalizedGaussianSampler;
33  import org.apache.commons.rng.sampling.distribution.BoxMullerNormalizedGaussianSampler;
34  import org.apache.commons.rng.sampling.distribution.ChengBetaSampler;
35  import org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler;
36  import org.apache.commons.rng.sampling.distribution.AhrensDieterMarsagliaTsangGammaSampler;
37  import org.apache.commons.rng.sampling.distribution.InverseTransformParetoSampler;
38  import org.apache.commons.rng.sampling.distribution.LogNormalSampler;
39  import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
40  import org.apache.commons.rng.sampling.distribution.GaussianSampler;
41  import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
42  
43  /**
44   * Approximation of the probability density by the histogram of the sampler output.
45   */
46  @Command(name = "density",
47           description = {"Approximate the probability density of samplers."})
48  class ProbabilityDensityApproximationCommand  implements Callable<Void> {
49      /** The standard options. */
50      @Mixin
51      private StandardOptions reusableOptions;
52  
53      /** Number of (equal-width) bins in the histogram. */
54      @Option(names = {"-b", "--bins"},
55              description = "The number of bins in the histogram (default: ${DEFAULT-VALUE}).")
56      private int numBins = 25_000;
57  
58      /** Number of samples to be generated. */
59      @Option(names = {"-n", "--samples"},
60              description = "The number of samples in the histogram (default: ${DEFAULT-VALUE}).")
61      private long numSamples = 1_000_000_000;
62  
63      /** The samplers. */
64      @Option(names = {"-s", "--samplers"},
65              split = ",",
66              description = {"The samplers (comma-delimited for multiple options).",
67                            "Valid values: ${COMPLETION-CANDIDATES}."})
68      private EnumSet<Sampler> samplers = EnumSet.noneOf(Sampler.class);
69  
70      /** Flag to output all samplers. */
71      @Option(names = {"-a", "--all"},
72              description = "Output all samplers")
73      private boolean allSamplers;
74  
75      /**
76       * The sampler. This enum uses lower case for clarity when matching the distribution name.
77       */
78      enum Sampler {
79          /** The Ziggurat gaussian sampler. */
80          ZigguratGaussianSampler,
81          /** The Marsaglia gaussian sampler. */
82          MarsagliaGaussianSampler,
83          /** The Box muller gaussian sampler. */
84          BoxMullerGaussianSampler,
85          /** The Cheng beta sampler case 1. */
86          ChengBetaSamplerCase1,
87          /** The Cheng beta sampler case 2. */
88          ChengBetaSamplerCase2,
89          /** The Ahrens dieter exponential sampler. */
90          AhrensDieterExponentialSampler,
91          /** The Ahrens dieter marsaglia tsang gamma sampler small gamma. */
92          AhrensDieterMarsagliaTsangGammaSamplerCase1,
93          /** The Ahrens dieter marsaglia tsang gamma sampler large gamma. */
94          AhrensDieterMarsagliaTsangGammaSamplerCase2,
95          /** The Inverse transform pareto sampler. */
96          InverseTransformParetoSampler,
97          /** The Continuous uniform sampler. */
98          ContinuousUniformSampler,
99          /** The Log normal ziggurat gaussian sampler. */
100         LogNormalZigguratGaussianSampler,
101         /** The Log normal marsaglia gaussian sampler. */
102         LogNormalMarsagliaGaussianSampler,
103         /** The Log normal box muller gaussian sampler. */
104         LogNormalBoxMullerGaussianSampler,
105     }
106 
107     /**
108      * @param sampler Sampler.
109      * @param min Right abscissa of the first bin: every sample smaller
110      * than that value will increment an additional bin (of infinite width)
111      * placed before the first "equal-width" bin.
112      * @param max abscissa of the last bin: every sample larger than or
113      * equal to that value will increment an additional bin (of infinite
114      * width) placed after the last "equal-width" bin.
115      * @param outputFile Filename.
116      * @throws IOException Signals that an I/O exception has occurred.
117      */
118     private void createDensity(ContinuousSampler sampler,
119                                double min,
120                                double max,
121                                String outputFile)
122         throws IOException {
123         final double binSize = (max - min) / numBins;
124         final long[] histogram = new long[numBins];
125 
126         long belowMin = 0;
127         long aboveMax = 0;
128         for (long n = 0; n < numSamples; n++) {
129             final double r = sampler.sample();
130 
131             if (r < min) {
132                 ++belowMin;
133                 continue;
134             }
135 
136             if (r >= max) {
137                 ++aboveMax;
138                 continue;
139             }
140 
141             final int binIndex = (int) ((r - min) / binSize);
142             ++histogram[binIndex];
143         }
144 
145         final double binHalfSize = 0.5 * binSize;
146         final double norm = 1 / (binSize * numSamples);
147 
148         try (PrintWriter out = new PrintWriter(outputFile)) {
149             // CHECKSTYLE: stop MultipleStringLiteralsCheck
150             out.println("# Sampler: " + sampler);
151             out.println("# Number of bins: " + numBins);
152             out.println("# Min: " + min + " (fraction of samples below: " + (belowMin / (double) numSamples) + ")");
153             out.println("# Max: " + max + " (fraction of samples above: " + (aboveMax / (double) numSamples) + ")");
154             out.println("# Bin width: " + binSize);
155             out.println("# Histogram normalization factor: " + norm);
156             out.println("#");
157             out.println("# " + (min - binHalfSize) + " " + (belowMin * norm));
158             for (int i = 0; i < numBins; i++) {
159                 out.println((min + (i + 1) * binSize - binHalfSize) + " " + (histogram[i] * norm));
160             }
161             out.println("# " + (max + binHalfSize) + " " + (aboveMax * norm));
162             // CHECKSTYLE: resume MultipleStringLiteralsCheck
163         }
164     }
165 
166     /**
167      * Program entry point.
168      *
169      * @throws IOException if failure occurred while writing to files.
170      */
171     @Override
172     public Void call() throws IOException {
173         if (allSamplers) {
174             samplers = EnumSet.allOf(Sampler.class);
175         } else if (samplers.isEmpty()) {
176             // CHECKSTYLE: stop regexp
177             System.err.println("ERROR: No samplers specified");
178             // CHECKSTYLE: resume regexp
179             System.exit(1);
180         }
181 
182         final UniformRandomProvider rng = RandomSource.create(RandomSource.XOR_SHIFT_1024_S_PHI);
183 
184         final double gaussMean = 1;
185         final double gaussSigma = 2;
186         final double gaussMin = -9;
187         final double gaussMax = 11;
188         if (samplers.contains(Sampler.ZigguratGaussianSampler)) {
189             createDensity(GaussianSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
190                                              gaussMean, gaussSigma),
191                           gaussMin, gaussMax, "gauss.ziggurat.txt");
192         }
193         if (samplers.contains(Sampler.MarsagliaGaussianSampler)) {
194             createDensity(GaussianSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
195                                              gaussMean, gaussSigma),
196                           gaussMin, gaussMax, "gauss.marsaglia.txt");
197         }
198         if (samplers.contains(Sampler.BoxMullerGaussianSampler)) {
199             createDensity(GaussianSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
200                                              gaussMean, gaussSigma),
201                           gaussMin, gaussMax, "gauss.boxmuller.txt");
202         }
203 
204         final double betaMin = 0;
205         final double betaMax = 1;
206         if (samplers.contains(Sampler.ChengBetaSamplerCase1)) {
207             final double alphaBeta = 4.3;
208             final double betaBeta = 2.1;
209             createDensity(ChengBetaSampler.of(rng, alphaBeta, betaBeta),
210                           betaMin, betaMax, "beta.case1.txt");
211         }
212         if (samplers.contains(Sampler.ChengBetaSamplerCase2)) {
213             final double alphaBetaAlt = 0.5678;
214             final double betaBetaAlt = 0.1234;
215             createDensity(ChengBetaSampler.of(rng, alphaBetaAlt, betaBetaAlt),
216                           betaMin, betaMax, "beta.case2.txt");
217         }
218 
219         if (samplers.contains(Sampler.AhrensDieterExponentialSampler)) {
220             final double meanExp = 3.45;
221             final double expMin = 0;
222             final double expMax = 60;
223             createDensity(AhrensDieterExponentialSampler.of(rng, meanExp),
224                           expMin, expMax, "exp.txt");
225         }
226 
227         final double gammaMin = 0;
228         final double gammaMax1 = 40;
229         final double thetaGamma = 3.456;
230         if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase1)) {
231             final double alphaGammaSmallerThanOne = 0.1234;
232             createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaSmallerThanOne, thetaGamma),
233                           gammaMin, gammaMax1, "gamma.case1.txt");
234         }
235         if (samplers.contains(Sampler.AhrensDieterMarsagliaTsangGammaSamplerCase2)) {
236             final double alphaGammaLargerThanOne = 2.345;
237             final double gammaMax2 = 70;
238             createDensity(AhrensDieterMarsagliaTsangGammaSampler.of(rng, alphaGammaLargerThanOne, thetaGamma),
239                           gammaMin, gammaMax2, "gamma.case2.txt");
240         }
241 
242         final double scalePareto = 23.45;
243         final double shapePareto = 0.789;
244         final double paretoMin = 23;
245         final double paretoMax = 400;
246         if (samplers.contains(Sampler.InverseTransformParetoSampler)) {
247             createDensity(InverseTransformParetoSampler.of(rng, scalePareto, shapePareto),
248                           paretoMin, paretoMax, "pareto.txt");
249         }
250 
251         final double loUniform = -9.876;
252         final double hiUniform = 5.432;
253         if (samplers.contains(Sampler.ContinuousUniformSampler)) {
254             createDensity(ContinuousUniformSampler.of(rng, loUniform, hiUniform),
255                           loUniform, hiUniform, "uniform.txt");
256         }
257 
258         final double scaleLogNormal = 2.345;
259         final double shapeLogNormal = 0.1234;
260         final double logNormalMin = 5;
261         final double logNormalMax = 25;
262         if (samplers.contains(Sampler.LogNormalZigguratGaussianSampler)) {
263             createDensity(LogNormalSampler.of(ZigguratNormalizedGaussianSampler.of(rng),
264                                               scaleLogNormal, shapeLogNormal),
265                           logNormalMin, logNormalMax, "lognormal.ziggurat.txt");
266         }
267         if (samplers.contains(Sampler.LogNormalMarsagliaGaussianSampler)) {
268             createDensity(LogNormalSampler.of(MarsagliaNormalizedGaussianSampler.of(rng),
269                                               scaleLogNormal, shapeLogNormal),
270                           logNormalMin, logNormalMax, "lognormal.marsaglia.txt");
271         }
272         if (samplers.contains(Sampler.LogNormalBoxMullerGaussianSampler)) {
273             createDensity(LogNormalSampler.of(BoxMullerNormalizedGaussianSampler.of(rng),
274                                               scaleLogNormal, shapeLogNormal),
275                           logNormalMin, logNormalMax, "lognormal.boxmuller.txt");
276         }
277 
278         return null;
279     }
280 }