View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.examples.jmh.sampling.distribution;
19  
20  import org.apache.commons.rng.UniformRandomProvider;
21  import org.apache.commons.rng.examples.jmh.RandomSources;
22  import org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler;
23  import org.apache.commons.rng.sampling.distribution.AhrensDieterMarsagliaTsangGammaSampler;
24  import org.apache.commons.rng.sampling.distribution.BoxMullerNormalizedGaussianSampler;
25  import org.apache.commons.rng.sampling.distribution.ChengBetaSampler;
26  import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
27  import org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler;
28  import org.apache.commons.rng.sampling.distribution.InverseTransformParetoSampler;
29  import org.apache.commons.rng.sampling.distribution.LogNormalSampler;
30  import org.apache.commons.rng.sampling.distribution.MarsagliaNormalizedGaussianSampler;
31  import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
32  
33  import org.openjdk.jmh.annotations.Benchmark;
34  import org.openjdk.jmh.annotations.BenchmarkMode;
35  import org.openjdk.jmh.annotations.Fork;
36  import org.openjdk.jmh.annotations.Measurement;
37  import org.openjdk.jmh.annotations.Mode;
38  import org.openjdk.jmh.annotations.OutputTimeUnit;
39  import org.openjdk.jmh.annotations.Param;
40  import org.openjdk.jmh.annotations.Scope;
41  import org.openjdk.jmh.annotations.Setup;
42  import org.openjdk.jmh.annotations.State;
43  import org.openjdk.jmh.annotations.Warmup;
44  
45  import java.util.concurrent.TimeUnit;
46  
47  /**
48   * Executes benchmark to compare the speed of generation of random numbers
49   * from the various source providers for different types of {@link ContinuousSampler}.
50   */
51  @BenchmarkMode(Mode.AverageTime)
52  @OutputTimeUnit(TimeUnit.NANOSECONDS)
53  @Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
54  @Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
55  @State(Scope.Benchmark)
56  @Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
57  public class ContinuousSamplersPerformance {
58      /**
59       * The value.
60       *
61       * <p>This must NOT be final!</p>
62       */
63      private double value;
64  
65      /**
66       * The {@link ContinuousSampler} samplers to use for testing. Creates the sampler for each
67       * {@link org.apache.commons.rng.simple.RandomSource RandomSource} in the default
68       * {@link RandomSources}.
69       */
70      @State(Scope.Benchmark)
71      public static class Sources extends RandomSources {
72          /**
73           * The sampler type.
74           */
75          @Param({"BoxMullerNormalizedGaussianSampler",
76                  "MarsagliaNormalizedGaussianSampler",
77                  "ZigguratNormalizedGaussianSampler",
78                  "AhrensDieterExponentialSampler",
79                  "AhrensDieterGammaSampler",
80                  "MarsagliaTsangGammaSampler",
81                  "LogNormalBoxMullerNormalizedGaussianSampler",
82                  "LogNormalMarsagliaNormalizedGaussianSampler",
83                  "LogNormalZigguratNormalizedGaussianSampler",
84                  "ChengBetaSampler",
85                  "ContinuousUniformSampler",
86                  "InverseTransformParetoSampler",
87                  })
88          private String samplerType;
89  
90          /** The sampler. */
91          private ContinuousSampler sampler;
92  
93          /**
94           * @return the sampler.
95           */
96          public ContinuousSampler getSampler() {
97              return sampler;
98          }
99  
100         /** Instantiates sampler. */
101         @Override
102         @Setup
103         public void setup() {
104             super.setup();
105             final UniformRandomProvider rng = getGenerator();
106             if ("BoxMullerNormalizedGaussianSampler".equals(samplerType)) {
107                 sampler = BoxMullerNormalizedGaussianSampler.of(rng);
108             } else if ("MarsagliaNormalizedGaussianSampler".equals(samplerType)) {
109                 sampler = MarsagliaNormalizedGaussianSampler.of(rng);
110             } else if ("ZigguratNormalizedGaussianSampler".equals(samplerType)) {
111                 sampler = ZigguratNormalizedGaussianSampler.of(rng);
112             } else if ("AhrensDieterExponentialSampler".equals(samplerType)) {
113                 sampler = AhrensDieterExponentialSampler.of(rng, 4.56);
114             } else if ("AhrensDieterGammaSampler".equals(samplerType)) {
115                 // This tests the Ahrens-Dieter algorithm since alpha < 1
116                 sampler = AhrensDieterMarsagliaTsangGammaSampler.of(rng, 0.76, 9.8);
117             } else if ("MarsagliaTsangGammaSampler".equals(samplerType)) {
118                 // This tests the Marsaglia-Tsang algorithm since alpha > 1
119                 sampler = AhrensDieterMarsagliaTsangGammaSampler.of(rng, 12.34, 9.8);
120             } else if ("LogNormalBoxMullerNormalizedGaussianSampler".equals(samplerType)) {
121                 sampler = LogNormalSampler.of(BoxMullerNormalizedGaussianSampler.of(rng), 12.3, 4.6);
122             } else if ("LogNormalMarsagliaNormalizedGaussianSampler".equals(samplerType)) {
123                 sampler = LogNormalSampler.of(MarsagliaNormalizedGaussianSampler.of(rng), 12.3, 4.6);
124             } else if ("LogNormalZigguratNormalizedGaussianSampler".equals(samplerType)) {
125                 sampler = LogNormalSampler.of(ZigguratNormalizedGaussianSampler.of(rng), 12.3, 4.6);
126             } else if ("ChengBetaSampler".equals(samplerType)) {
127                 sampler = ChengBetaSampler.of(rng, 0.45, 6.7);
128             } else if ("ContinuousUniformSampler".equals(samplerType)) {
129                 sampler = ContinuousUniformSampler.of(rng, 123.4, 5678.9);
130             } else if ("InverseTransformParetoSampler".equals(samplerType)) {
131                 sampler = InverseTransformParetoSampler.of(rng, 23.45, 0.1234);
132             }
133         }
134     }
135 
136     // Benchmarks methods below.
137 
138     /**
139      * Baseline for the JMH timing overhead for production of an {@code double} value.
140      *
141      * @return the {@code double} value
142      */
143     @Benchmark
144     public double baseline() {
145         return value;
146     }
147 
148     /**
149      * Run the sampler.
150      *
151      * @param sources Source of randomness.
152      * @return the sample value
153      */
154     @Benchmark
155     public double sample(Sources sources) {
156         return sources.getSampler().sample();
157     }
158 }