001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.distribution;
018    
019    import java.util.List;
020    import java.util.ArrayList;
021    import org.apache.commons.math3.exception.DimensionMismatchException;
022    import org.apache.commons.math3.exception.NotPositiveException;
023    import org.apache.commons.math3.exception.MathArithmeticException;
024    import org.apache.commons.math3.exception.util.LocalizedFormats;
025    import org.apache.commons.math3.random.RandomGenerator;
026    import org.apache.commons.math3.random.Well19937c;
027    import org.apache.commons.math3.util.Pair;
028    
029    /**
030     * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
031     * mixture model</a> distributions.
032     *
033     * @param <T> Type of the mixture components.
034     *
035     * @version $Id: MixtureMultivariateRealDistribution.java 1416643 2012-12-03 19:37:14Z tn $
036     * @since 3.1
037     */
038    public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
039        extends AbstractMultivariateRealDistribution {
040        /** Normalized weight of each mixture component. */
041        private final double[] weight;
042        /** Mixture components. */
043        private final List<T> distribution;
044    
045        /**
046         * Creates a mixture model from a list of distributions and their
047         * associated weights.
048         *
049         * @param components List of (weight, distribution) pairs from which to sample.
050         */
051        public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
052            this(new Well19937c(), components);
053        }
054    
055        /**
056         * Creates a mixture model from a list of distributions and their
057         * associated weights.
058         *
059         * @param rng Random number generator.
060         * @param components Distributions from which to sample.
061         * @throws NotPositiveException if any of the weights is negative.
062         * @throws DimensionMismatchException if not all components have the same
063         * number of variables.
064         */
065        public MixtureMultivariateRealDistribution(RandomGenerator rng,
066                                                   List<Pair<Double, T>> components) {
067            super(rng, components.get(0).getSecond().getDimension());
068    
069            final int numComp = components.size();
070            final int dim = getDimension();
071            double weightSum = 0;
072            for (int i = 0; i < numComp; i++) {
073                final Pair<Double, T> comp = components.get(i);
074                if (comp.getSecond().getDimension() != dim) {
075                    throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
076                }
077                if (comp.getFirst() < 0) {
078                    throw new NotPositiveException(comp.getFirst());
079                }
080                weightSum += comp.getFirst();
081            }
082    
083            // Check for overflow.
084            if (Double.isInfinite(weightSum)) {
085                throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
086            }
087    
088            // Store each distribution and its normalized weight.
089            distribution = new ArrayList<T>();
090            weight = new double[numComp];
091            for (int i = 0; i < numComp; i++) {
092                final Pair<Double, T> comp = components.get(i);
093                weight[i] = comp.getFirst() / weightSum;
094                distribution.add(comp.getSecond());
095            }
096        }
097    
098        /** {@inheritDoc} */
099        public double density(final double[] values) {
100            double p = 0;
101            for (int i = 0; i < weight.length; i++) {
102                p += weight[i] * distribution.get(i).density(values);
103            }
104            return p;
105        }
106    
107        /** {@inheritDoc} */
108        public double[] sample() {
109            // Sampled values.
110            double[] vals = null;
111    
112            // Determine which component to sample from.
113            final double randomValue = random.nextDouble();
114            double sum = 0;
115    
116            for (int i = 0; i < weight.length; i++) {
117                sum += weight[i];
118                if (randomValue <= sum) {
119                    // pick model i
120                    vals = distribution.get(i).sample();
121                    break;
122                }
123            }
124    
125            if (vals == null) {
126                // This should never happen, but it ensures we won't return a null in
127                // case the loop above has some floating point inequality problem on
128                // the final iteration.
129                vals = distribution.get(weight.length - 1).sample();
130            }
131    
132            return vals;
133        }
134    
135        /** {@inheritDoc} */
136        public void reseedRandomGenerator(long seed) {
137            // Seed needs to be propagated to underlying components
138            // in order to maintain consistency between runs.
139            super.reseedRandomGenerator(seed);
140    
141            for (int i = 0; i < distribution.size(); i++) {
142                // Make each component's seed different in order to avoid
143                // using the same sequence of random numbers.
144                distribution.get(i).reseedRandomGenerator(i + 1 + seed);
145            }
146        }
147    
148        /**
149         * Gets the distributions that make up the mixture model.
150         *
151         * @return the component distributions and associated weights.
152         */
153        public List<Pair<Double, T>> getComponents() {
154            final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>();
155    
156            for (int i = 0; i < weight.length; i++) {
157                list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
158            }
159    
160            return list;
161        }
162    }