1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.rng.sampling; 19 20 import java.util.List; 21 import java.util.Map; 22 import java.util.HashMap; 23 import java.util.ArrayList; 24 import java.util.Arrays; 25 26 import org.apache.commons.rng.UniformRandomProvider; 27 28 /** 29 * Sampling from a collection of items with user-defined 30 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution"> 31 * probabilities</a>. 32 * Note that if all unique items are assigned the same probability, 33 * it is much more efficient to use {@link CollectionSampler}. 34 * 35 * @param <T> Type of items in the collection. 36 * 37 * @since 1.1 38 */ 39 public class DiscreteProbabilityCollectionSampler<T> { 40 /** Collection to be sampled from. */ 41 private final List<T> items; 42 /** RNG. */ 43 private final UniformRandomProvider rng; 44 /** Cumulative probabilities. */ 45 private final double[] cumulativeProbabilities; 46 47 /** 48 * Creates a sampler. 49 * 50 * @param rng Generator of uniformly distributed random numbers. 51 * @param collection Collection to be sampled, with the probabilities 52 * associated to each of its items. 53 * A (shallow) copy of the items will be stored in the created instance. 54 * The probabilities must be non-negative, but zero values are allowed 55 * and their sum does not have to equal one (input will be normalized 56 * to make the probabilities sum to one). 57 * @throws IllegalArgumentException if {@code collection} is empty, a 58 * probability is negative, infinite or {@code NaN}, or the sum of all 59 * probabilities is not strictly positive. 60 */ 61 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 62 Map<T, Double> collection) { 63 if (collection.isEmpty()) { 64 throw new IllegalArgumentException("Empty collection"); 65 } 66 67 this.rng = rng; 68 final int size = collection.size(); 69 items = new ArrayList<T>(size); 70 cumulativeProbabilities = new double[size]; 71 72 double sumProb = 0; 73 int count = 0; 74 for (Map.Entry<T, Double> e : collection.entrySet()) { 75 items.add(e.getKey()); 76 77 final double prob = e.getValue(); 78 if (prob < 0 || 79 Double.isInfinite(prob) || 80 Double.isNaN(prob)) { 81 throw new IllegalArgumentException("Invalid probability: " + 82 prob); 83 } 84 85 // Temporarily store probability. 86 cumulativeProbabilities[count++] = prob; 87 sumProb += prob; 88 } 89 90 if (!(sumProb > 0)) { 91 throw new IllegalArgumentException("Invalid sum of probabilities"); 92 } 93 94 // Compute and store cumulative probability. 95 for (int i = 0; i < size; i++) { 96 cumulativeProbabilities[i] /= sumProb; 97 if (i > 0) { 98 cumulativeProbabilities[i] += cumulativeProbabilities[i - 1]; 99 } 100 } 101 } 102 103 /** 104 * Creates a sampler. 105 * 106 * @param rng Generator of uniformly distributed random numbers. 107 * @param collection Collection to be sampled. 108 * A (shallow) copy of the items will be stored in the created instance. 109 * @param probabilities Probability associated to each item of the 110 * {@code collection}. 111 * The probabilities must be non-negative, but zero values are allowed 112 * and their sum does not have to equal one (input will be normalized 113 * to make the probabilities sum to one). 114 * @throws IllegalArgumentException if {@code collection} is empty or 115 * a probability is negative, infinite or {@code NaN}, or if the number 116 * of items in the {@code collection} is not equal to the number of 117 * provided {@code probabilities}. 118 */ 119 public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, 120 List<T> collection, 121 double[] probabilities) { 122 this(rng, consolidate(collection, probabilities)); 123 } 124 125 /** 126 * Picks one of the items from the collection passed to the constructor. 127 * 128 * @return a random sample. 129 */ 130 public T sample() { 131 final double rand = rng.nextDouble(); 132 133 int index = Arrays.binarySearch(cumulativeProbabilities, rand); 134 if (index < 0) { 135 index = -index - 1; 136 } 137 138 if (index >= 0 && 139 index < cumulativeProbabilities.length && 140 rand < cumulativeProbabilities[index]) { 141 return items.get(index); 142 } 143 144 // This should never happen, but it ensures we will return a correct 145 // object in case there is some floating point inequality problem 146 // wrt the cumulative probabilities. 147 return items.get(items.size() - 1); 148 } 149 150 /** 151 * @param collection Collection to be sampled. 152 * @param probabilities Probability associated to each item of the 153 * {@code collection}. 154 * @return a consolidated map (where probabilities of equal items 155 * have been summed). 156 * @throws IllegalArgumentException if the number of items in the 157 * {@code collection} is not equal to the number of provided 158 * {@code probabilities}. 159 * @param <T> Type of items in the collection. 160 */ 161 private static <T> Map<T, Double> consolidate(List<T> collection, 162 double[] probabilities) { 163 final int len = probabilities.length; 164 if (len != collection.size()) { 165 throw new IllegalArgumentException("Size mismatch: " + 166 len + " != " + 167 collection.size()); 168 } 169 170 final Map<T, Double> map = new HashMap<T, Double>(); 171 for (int i = 0; i < len; i++) { 172 final T item = collection.get(i); 173 final Double prob = probabilities[i]; 174 175 Double currentProb = map.get(item); 176 if (currentProb == null) { 177 currentProb = 0d; 178 } 179 180 map.put(item, currentProb + prob); 181 } 182 183 return map; 184 } 185 }