View Javadoc
1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.rng.sampling;
19  
20  import java.util.List;
21  import java.util.Map;
22  import java.util.HashMap;
23  import java.util.ArrayList;
24  import java.util.Arrays;
25  
26  import org.apache.commons.rng.UniformRandomProvider;
27  
28  /**
29   * Sampling from a collection of items with user-defined
30   * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
31   * probabilities</a>.
32   * Note that if all unique items are assigned the same probability,
33   * it is much more efficient to use {@link CollectionSampler}.
34   *
35   * @param <T> Type of items in the collection.
36   *
37   * @since 1.1
38   */
39  public class DiscreteProbabilityCollectionSampler<T> {
40      /** Collection to be sampled from. */
41      private final List<T> items;
42      /** RNG. */
43      private final UniformRandomProvider rng;
44      /** Cumulative probabilities. */
45      private final double[] cumulativeProbabilities;
46  
47      /**
48       * Creates a sampler.
49       *
50       * @param rng Generator of uniformly distributed random numbers.
51       * @param collection Collection to be sampled, with the probabilities
52       * associated to each of its items.
53       * A (shallow) copy of the items will be stored in the created instance.
54       * The probabilities must be non-negative, but zero values are allowed
55       * and their sum does not have to equal one (input will be normalized
56       * to make the probabilities sum to one).
57       * @throws IllegalArgumentException if {@code collection} is empty, a
58       * probability is negative, infinite or {@code NaN}, or the sum of all
59       * probabilities is not strictly positive.
60       */
61      public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
62                                                  Map<T, Double> collection) {
63          if (collection.isEmpty()) {
64              throw new IllegalArgumentException("Empty collection");
65          }
66  
67          this.rng = rng;
68          final int size = collection.size();
69          items = new ArrayList<T>(size);
70          cumulativeProbabilities = new double[size];
71  
72          double sumProb = 0;
73          int count = 0;
74          for (Map.Entry<T, Double> e : collection.entrySet()) {
75              items.add(e.getKey());
76  
77              final double prob = e.getValue();
78              if (prob < 0 ||
79                  Double.isInfinite(prob) ||
80                  Double.isNaN(prob)) {
81                  throw new IllegalArgumentException("Invalid probability: " +
82                                                     prob);
83              }
84  
85              // Temporarily store probability.
86              cumulativeProbabilities[count++] = prob;
87              sumProb += prob;
88          }
89  
90          if (!(sumProb > 0)) {
91              throw new IllegalArgumentException("Invalid sum of probabilities");
92          }
93  
94          // Compute and store cumulative probability.
95          for (int i = 0; i < size; i++) {
96              cumulativeProbabilities[i] /= sumProb;
97              if (i > 0) {
98                  cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
99              }
100         }
101     }
102 
103     /**
104      * Creates a sampler.
105      *
106      * @param rng Generator of uniformly distributed random numbers.
107      * @param collection Collection to be sampled.
108      * A (shallow) copy of the items will be stored in the created instance.
109      * @param probabilities Probability associated to each item of the
110      * {@code collection}.
111      * The probabilities must be non-negative, but zero values are allowed
112      * and their sum does not have to equal one (input will be normalized
113      * to make the probabilities sum to one).
114      * @throws IllegalArgumentException if {@code collection} is empty or
115      * a probability is negative, infinite or {@code NaN}, or if the number
116      * of items in the {@code collection} is not equal to the number of
117      * provided {@code probabilities}.
118      */
119     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
120                                                 List<T> collection,
121                                                 double[] probabilities) {
122         this(rng, consolidate(collection, probabilities));
123     }
124 
125     /**
126      * Picks one of the items from the collection passed to the constructor.
127      *
128      * @return a random sample.
129      */
130     public T sample() {
131         final double rand = rng.nextDouble();
132 
133         int index = Arrays.binarySearch(cumulativeProbabilities, rand);
134         if (index < 0) {
135             index = -index - 1;
136         }
137 
138         if (index >= 0 &&
139             index < cumulativeProbabilities.length &&
140             rand < cumulativeProbabilities[index]) {
141             return items.get(index);
142         }
143 
144         // This should never happen, but it ensures we will return a correct
145         // object in case there is some floating point inequality problem
146         // wrt the cumulative probabilities.
147         return items.get(items.size() - 1);
148     }
149 
150     /**
151      * @param collection Collection to be sampled.
152      * @param probabilities Probability associated to each item of the
153      * {@code collection}.
154      * @return a consolidated map (where probabilities of equal items
155      * have been summed).
156      * @throws IllegalArgumentException if the number of items in the
157      * {@code collection} is not equal to the number of provided
158      * {@code probabilities}.
159      * @param <T> Type of items in the collection.
160      */
161     private static <T> Map<T, Double> consolidate(List<T> collection,
162                                                   double[] probabilities) {
163         final int len = probabilities.length;
164         if (len != collection.size()) {
165             throw new IllegalArgumentException("Size mismatch: " +
166                                                len + " != " +
167                                                collection.size());
168         }
169 
170         final Map<T, Double> map = new HashMap<T, Double>();
171         for (int i = 0; i < len; i++) {
172             final T item = collection.get(i);
173             final Double prob = probabilities[i];
174 
175             Double currentProb = map.get(item);
176             if (currentProb == null) {
177                 currentProb = 0d;
178             }
179 
180             map.put(item, currentProb + prob);
181         }
182 
183         return map;
184     }
185 }