001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math3.stat.descriptive;
018    
019    import java.io.Serializable;
020    import java.util.Arrays;
021    
022    import org.apache.commons.math3.exception.util.LocalizedFormats;
023    import org.apache.commons.math3.exception.DimensionMismatchException;
024    import org.apache.commons.math3.exception.MathIllegalStateException;
025    import org.apache.commons.math3.linear.RealMatrix;
026    import org.apache.commons.math3.stat.descriptive.moment.GeometricMean;
027    import org.apache.commons.math3.stat.descriptive.moment.Mean;
028    import org.apache.commons.math3.stat.descriptive.moment.VectorialCovariance;
029    import org.apache.commons.math3.stat.descriptive.rank.Max;
030    import org.apache.commons.math3.stat.descriptive.rank.Min;
031    import org.apache.commons.math3.stat.descriptive.summary.Sum;
032    import org.apache.commons.math3.stat.descriptive.summary.SumOfLogs;
033    import org.apache.commons.math3.stat.descriptive.summary.SumOfSquares;
034    import org.apache.commons.math3.util.MathUtils;
035    import org.apache.commons.math3.util.MathArrays;
036    import org.apache.commons.math3.util.Precision;
037    import org.apache.commons.math3.util.FastMath;
038    
039    /**
040     * <p>Computes summary statistics for a stream of n-tuples added using the
041     * {@link #addValue(double[]) addValue} method. The data values are not stored
042     * in memory, so this class can be used to compute statistics for very large
043     * n-tuple streams.</p>
044     *
045     * <p>The {@link StorelessUnivariateStatistic} instances used to maintain
046     * summary state and compute statistics are configurable via setters.
047     * For example, the default implementation for the mean can be overridden by
048     * calling {@link #setMeanImpl(StorelessUnivariateStatistic[])}. Actual
049     * parameters to these methods must implement the
050     * {@link StorelessUnivariateStatistic} interface and configuration must be
051     * completed before <code>addValue</code> is called. No configuration is
052     * necessary to use the default, commons-math provided implementations.</p>
053     *
054     * <p>To compute statistics for a stream of n-tuples, construct a
055     * MultivariateStatistics instance with dimension n and then use
056     * {@link #addValue(double[])} to add n-tuples. The <code>getXxx</code>
057     * methods where Xxx is a statistic return an array of <code>double</code>
058     * values, where for <code>i = 0,...,n-1</code> the i<sup>th</sup> array element is the
059     * value of the given statistic for data range consisting of the i<sup>th</sup> element of
060     * each of the input n-tuples.  For example, if <code>addValue</code> is called
061     * with actual parameters {0, 1, 2}, then {3, 4, 5} and finally {6, 7, 8},
062     * <code>getSum</code> will return a three-element array with values
063     * {0+3+6, 1+4+7, 2+5+8}</p>
064     *
065     * <p>Note: This class is not thread-safe. Use
066     * {@link SynchronizedMultivariateSummaryStatistics} if concurrent access from multiple
067     * threads is required.</p>
068     *
069     * @since 1.2
070     * @version $Id: MultivariateSummaryStatistics.java 1416643 2012-12-03 19:37:14Z tn $
071     */
072    public class MultivariateSummaryStatistics
073        implements StatisticalMultivariateSummary, Serializable {
074    
075        /** Serialization UID */
076        private static final long serialVersionUID = 2271900808994826718L;
077    
078        /** Dimension of the data. */
079        private int k;
080    
081        /** Count of values that have been added */
082        private long n = 0;
083    
084        /** Sum statistic implementation - can be reset by setter. */
085        private StorelessUnivariateStatistic[] sumImpl;
086    
087        /** Sum of squares statistic implementation - can be reset by setter. */
088        private StorelessUnivariateStatistic[] sumSqImpl;
089    
090        /** Minimum statistic implementation - can be reset by setter. */
091        private StorelessUnivariateStatistic[] minImpl;
092    
093        /** Maximum statistic implementation - can be reset by setter. */
094        private StorelessUnivariateStatistic[] maxImpl;
095    
096        /** Sum of log statistic implementation - can be reset by setter. */
097        private StorelessUnivariateStatistic[] sumLogImpl;
098    
099        /** Geometric mean statistic implementation - can be reset by setter. */
100        private StorelessUnivariateStatistic[] geoMeanImpl;
101    
102        /** Mean statistic implementation - can be reset by setter. */
103        private StorelessUnivariateStatistic[] meanImpl;
104    
105        /** Covariance statistic implementation - cannot be reset. */
106        private VectorialCovariance covarianceImpl;
107    
108        /**
109         * Construct a MultivariateSummaryStatistics instance
110         * @param k dimension of the data
111         * @param isCovarianceBiasCorrected if true, the unbiased sample
112         * covariance is computed, otherwise the biased population covariance
113         * is computed
114         */
115        public MultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
116            this.k = k;
117    
118            sumImpl     = new StorelessUnivariateStatistic[k];
119            sumSqImpl   = new StorelessUnivariateStatistic[k];
120            minImpl     = new StorelessUnivariateStatistic[k];
121            maxImpl     = new StorelessUnivariateStatistic[k];
122            sumLogImpl  = new StorelessUnivariateStatistic[k];
123            geoMeanImpl = new StorelessUnivariateStatistic[k];
124            meanImpl    = new StorelessUnivariateStatistic[k];
125    
126            for (int i = 0; i < k; ++i) {
127                sumImpl[i]     = new Sum();
128                sumSqImpl[i]   = new SumOfSquares();
129                minImpl[i]     = new Min();
130                maxImpl[i]     = new Max();
131                sumLogImpl[i]  = new SumOfLogs();
132                geoMeanImpl[i] = new GeometricMean();
133                meanImpl[i]    = new Mean();
134            }
135    
136            covarianceImpl =
137                new VectorialCovariance(k, isCovarianceBiasCorrected);
138    
139        }
140    
141        /**
142         * Add an n-tuple to the data
143         *
144         * @param value  the n-tuple to add
145         * @throws DimensionMismatchException if the length of the array
146         * does not match the one used at construction
147         */
148        public void addValue(double[] value) throws DimensionMismatchException {
149            checkDimension(value.length);
150            for (int i = 0; i < k; ++i) {
151                double v = value[i];
152                sumImpl[i].increment(v);
153                sumSqImpl[i].increment(v);
154                minImpl[i].increment(v);
155                maxImpl[i].increment(v);
156                sumLogImpl[i].increment(v);
157                geoMeanImpl[i].increment(v);
158                meanImpl[i].increment(v);
159            }
160            covarianceImpl.increment(value);
161            n++;
162        }
163    
164        /**
165         * Returns the dimension of the data
166         * @return The dimension of the data
167         */
168        public int getDimension() {
169            return k;
170        }
171    
172        /**
173         * Returns the number of available values
174         * @return The number of available values
175         */
176        public long getN() {
177            return n;
178        }
179    
180        /**
181         * Returns an array of the results of a statistic.
182         * @param stats univariate statistic array
183         * @return results array
184         */
185        private double[] getResults(StorelessUnivariateStatistic[] stats) {
186            double[] results = new double[stats.length];
187            for (int i = 0; i < results.length; ++i) {
188                results[i] = stats[i].getResult();
189            }
190            return results;
191        }
192    
193        /**
194         * Returns an array whose i<sup>th</sup> entry is the sum of the
195         * i<sup>th</sup> entries of the arrays that have been added using
196         * {@link #addValue(double[])}
197         *
198         * @return the array of component sums
199         */
200        public double[] getSum() {
201            return getResults(sumImpl);
202        }
203    
204        /**
205         * Returns an array whose i<sup>th</sup> entry is the sum of squares of the
206         * i<sup>th</sup> entries of the arrays that have been added using
207         * {@link #addValue(double[])}
208         *
209         * @return the array of component sums of squares
210         */
211        public double[] getSumSq() {
212            return getResults(sumSqImpl);
213        }
214    
215        /**
216         * Returns an array whose i<sup>th</sup> entry is the sum of logs of the
217         * i<sup>th</sup> entries of the arrays that have been added using
218         * {@link #addValue(double[])}
219         *
220         * @return the array of component log sums
221         */
222        public double[] getSumLog() {
223            return getResults(sumLogImpl);
224        }
225    
226        /**
227         * Returns an array whose i<sup>th</sup> entry is the mean of the
228         * i<sup>th</sup> entries of the arrays that have been added using
229         * {@link #addValue(double[])}
230         *
231         * @return the array of component means
232         */
233        public double[] getMean() {
234            return getResults(meanImpl);
235        }
236    
237        /**
238         * Returns an array whose i<sup>th</sup> entry is the standard deviation of the
239         * i<sup>th</sup> entries of the arrays that have been added using
240         * {@link #addValue(double[])}
241         *
242         * @return the array of component standard deviations
243         */
244        public double[] getStandardDeviation() {
245            double[] stdDev = new double[k];
246            if (getN() < 1) {
247                Arrays.fill(stdDev, Double.NaN);
248            } else if (getN() < 2) {
249                Arrays.fill(stdDev, 0.0);
250            } else {
251                RealMatrix matrix = covarianceImpl.getResult();
252                for (int i = 0; i < k; ++i) {
253                    stdDev[i] = FastMath.sqrt(matrix.getEntry(i, i));
254                }
255            }
256            return stdDev;
257        }
258    
259        /**
260         * Returns the covariance matrix of the values that have been added.
261         *
262         * @return the covariance matrix
263         */
264        public RealMatrix getCovariance() {
265            return covarianceImpl.getResult();
266        }
267    
268        /**
269         * Returns an array whose i<sup>th</sup> entry is the maximum of the
270         * i<sup>th</sup> entries of the arrays that have been added using
271         * {@link #addValue(double[])}
272         *
273         * @return the array of component maxima
274         */
275        public double[] getMax() {
276            return getResults(maxImpl);
277        }
278    
279        /**
280         * Returns an array whose i<sup>th</sup> entry is the minimum of the
281         * i<sup>th</sup> entries of the arrays that have been added using
282         * {@link #addValue(double[])}
283         *
284         * @return the array of component minima
285         */
286        public double[] getMin() {
287            return getResults(minImpl);
288        }
289    
290        /**
291         * Returns an array whose i<sup>th</sup> entry is the geometric mean of the
292         * i<sup>th</sup> entries of the arrays that have been added using
293         * {@link #addValue(double[])}
294         *
295         * @return the array of component geometric means
296         */
297        public double[] getGeometricMean() {
298            return getResults(geoMeanImpl);
299        }
300    
301        /**
302         * Generates a text report displaying
303         * summary statistics from values that
304         * have been added.
305         * @return String with line feeds displaying statistics
306         */
307        @Override
308        public String toString() {
309            final String separator = ", ";
310            final String suffix = System.getProperty("line.separator");
311            StringBuilder outBuffer = new StringBuilder();
312            outBuffer.append("MultivariateSummaryStatistics:" + suffix);
313            outBuffer.append("n: " + getN() + suffix);
314            append(outBuffer, getMin(), "min: ", separator, suffix);
315            append(outBuffer, getMax(), "max: ", separator, suffix);
316            append(outBuffer, getMean(), "mean: ", separator, suffix);
317            append(outBuffer, getGeometricMean(), "geometric mean: ", separator, suffix);
318            append(outBuffer, getSumSq(), "sum of squares: ", separator, suffix);
319            append(outBuffer, getSumLog(), "sum of logarithms: ", separator, suffix);
320            append(outBuffer, getStandardDeviation(), "standard deviation: ", separator, suffix);
321            outBuffer.append("covariance: " + getCovariance().toString() + suffix);
322            return outBuffer.toString();
323        }
324    
325        /**
326         * Append a text representation of an array to a buffer.
327         * @param buffer buffer to fill
328         * @param data data array
329         * @param prefix text prefix
330         * @param separator elements separator
331         * @param suffix text suffix
332         */
333        private void append(StringBuilder buffer, double[] data,
334                            String prefix, String separator, String suffix) {
335            buffer.append(prefix);
336            for (int i = 0; i < data.length; ++i) {
337                if (i > 0) {
338                    buffer.append(separator);
339                }
340                buffer.append(data[i]);
341            }
342            buffer.append(suffix);
343        }
344    
345        /**
346         * Resets all statistics and storage
347         */
348        public void clear() {
349            this.n = 0;
350            for (int i = 0; i < k; ++i) {
351                minImpl[i].clear();
352                maxImpl[i].clear();
353                sumImpl[i].clear();
354                sumLogImpl[i].clear();
355                sumSqImpl[i].clear();
356                geoMeanImpl[i].clear();
357                meanImpl[i].clear();
358            }
359            covarianceImpl.clear();
360        }
361    
362        /**
363         * Returns true iff <code>object</code> is a <code>MultivariateSummaryStatistics</code>
364         * instance and all statistics have the same values as this.
365         * @param object the object to test equality against.
366         * @return true if object equals this
367         */
368        @Override
369        public boolean equals(Object object) {
370            if (object == this ) {
371                return true;
372            }
373            if (object instanceof MultivariateSummaryStatistics == false) {
374                return false;
375            }
376            MultivariateSummaryStatistics stat = (MultivariateSummaryStatistics) object;
377            return MathArrays.equalsIncludingNaN(stat.getGeometricMean(), getGeometricMean()) &&
378                   MathArrays.equalsIncludingNaN(stat.getMax(),           getMax())           &&
379                   MathArrays.equalsIncludingNaN(stat.getMean(),          getMean())          &&
380                   MathArrays.equalsIncludingNaN(stat.getMin(),           getMin())           &&
381                   Precision.equalsIncludingNaN(stat.getN(),             getN())             &&
382                   MathArrays.equalsIncludingNaN(stat.getSum(),           getSum())           &&
383                   MathArrays.equalsIncludingNaN(stat.getSumSq(),         getSumSq())         &&
384                   MathArrays.equalsIncludingNaN(stat.getSumLog(),        getSumLog())        &&
385                   stat.getCovariance().equals( getCovariance());
386        }
387    
388        /**
389         * Returns hash code based on values of statistics
390         *
391         * @return hash code
392         */
393        @Override
394        public int hashCode() {
395            int result = 31 + MathUtils.hash(getGeometricMean());
396            result = result * 31 + MathUtils.hash(getGeometricMean());
397            result = result * 31 + MathUtils.hash(getMax());
398            result = result * 31 + MathUtils.hash(getMean());
399            result = result * 31 + MathUtils.hash(getMin());
400            result = result * 31 + MathUtils.hash(getN());
401            result = result * 31 + MathUtils.hash(getSum());
402            result = result * 31 + MathUtils.hash(getSumSq());
403            result = result * 31 + MathUtils.hash(getSumLog());
404            result = result * 31 + getCovariance().hashCode();
405            return result;
406        }
407    
408        // Getters and setters for statistics implementations
409        /**
410         * Sets statistics implementations.
411         * @param newImpl new implementations for statistics
412         * @param oldImpl old implementations for statistics
413         * @throws DimensionMismatchException if the array dimension
414         * does not match the one used at construction
415         * @throws MathIllegalStateException if data has already been added
416         * (i.e. if n > 0)
417         */
418        private void setImpl(StorelessUnivariateStatistic[] newImpl,
419                             StorelessUnivariateStatistic[] oldImpl) throws MathIllegalStateException,
420                             DimensionMismatchException {
421            checkEmpty();
422            checkDimension(newImpl.length);
423            System.arraycopy(newImpl, 0, oldImpl, 0, newImpl.length);
424        }
425    
426        /**
427         * Returns the currently configured Sum implementation
428         *
429         * @return the StorelessUnivariateStatistic implementing the sum
430         */
431        public StorelessUnivariateStatistic[] getSumImpl() {
432            return sumImpl.clone();
433        }
434    
435        /**
436         * <p>Sets the implementation for the Sum.</p>
437         * <p>This method must be activated before any data has been added - i.e.,
438         * before {@link #addValue(double[]) addValue} has been used to add data;
439         * otherwise an IllegalStateException will be thrown.</p>
440         *
441         * @param sumImpl the StorelessUnivariateStatistic instance to use
442         * for computing the Sum
443         * @throws DimensionMismatchException if the array dimension
444         * does not match the one used at construction
445         * @throws MathIllegalStateException if data has already been added
446         *  (i.e if n > 0)
447         */
448        public void setSumImpl(StorelessUnivariateStatistic[] sumImpl)
449        throws MathIllegalStateException, DimensionMismatchException {
450            setImpl(sumImpl, this.sumImpl);
451        }
452    
453        /**
454         * Returns the currently configured sum of squares implementation
455         *
456         * @return the StorelessUnivariateStatistic implementing the sum of squares
457         */
458        public StorelessUnivariateStatistic[] getSumsqImpl() {
459            return sumSqImpl.clone();
460        }
461    
462        /**
463         * <p>Sets the implementation for the sum of squares.</p>
464         * <p>This method must be activated before any data has been added - i.e.,
465         * before {@link #addValue(double[]) addValue} has been used to add data;
466         * otherwise an IllegalStateException will be thrown.</p>
467         *
468         * @param sumsqImpl the StorelessUnivariateStatistic instance to use
469         * for computing the sum of squares
470         * @throws DimensionMismatchException if the array dimension
471         * does not match the one used at construction
472         * @throws MathIllegalStateException if data has already been added
473         *  (i.e if n > 0)
474         */
475        public void setSumsqImpl(StorelessUnivariateStatistic[] sumsqImpl)
476        throws MathIllegalStateException, DimensionMismatchException {
477            setImpl(sumsqImpl, this.sumSqImpl);
478        }
479    
480        /**
481         * Returns the currently configured minimum implementation
482         *
483         * @return the StorelessUnivariateStatistic implementing the minimum
484         */
485        public StorelessUnivariateStatistic[] getMinImpl() {
486            return minImpl.clone();
487        }
488    
489        /**
490         * <p>Sets the implementation for the minimum.</p>
491         * <p>This method must be activated before any data has been added - i.e.,
492         * before {@link #addValue(double[]) addValue} has been used to add data;
493         * otherwise an IllegalStateException will be thrown.</p>
494         *
495         * @param minImpl the StorelessUnivariateStatistic instance to use
496         * for computing the minimum
497         * @throws DimensionMismatchException if the array dimension
498         * does not match the one used at construction
499         * @throws MathIllegalStateException if data has already been added
500         *  (i.e if n > 0)
501         */
502        public void setMinImpl(StorelessUnivariateStatistic[] minImpl)
503        throws MathIllegalStateException, DimensionMismatchException {
504            setImpl(minImpl, this.minImpl);
505        }
506    
507        /**
508         * Returns the currently configured maximum implementation
509         *
510         * @return the StorelessUnivariateStatistic implementing the maximum
511         */
512        public StorelessUnivariateStatistic[] getMaxImpl() {
513            return maxImpl.clone();
514        }
515    
516        /**
517         * <p>Sets the implementation for the maximum.</p>
518         * <p>This method must be activated before any data has been added - i.e.,
519         * before {@link #addValue(double[]) addValue} has been used to add data;
520         * otherwise an IllegalStateException will be thrown.</p>
521         *
522         * @param maxImpl the StorelessUnivariateStatistic instance to use
523         * for computing the maximum
524         * @throws DimensionMismatchException if the array dimension
525         * does not match the one used at construction
526         * @throws MathIllegalStateException if data has already been added
527         *  (i.e if n > 0)
528         */
529        public void setMaxImpl(StorelessUnivariateStatistic[] maxImpl)
530        throws MathIllegalStateException, DimensionMismatchException{
531            setImpl(maxImpl, this.maxImpl);
532        }
533    
534        /**
535         * Returns the currently configured sum of logs implementation
536         *
537         * @return the StorelessUnivariateStatistic implementing the log sum
538         */
539        public StorelessUnivariateStatistic[] getSumLogImpl() {
540            return sumLogImpl.clone();
541        }
542    
543        /**
544         * <p>Sets the implementation for the sum of logs.</p>
545         * <p>This method must be activated before any data has been added - i.e.,
546         * before {@link #addValue(double[]) addValue} has been used to add data;
547         * otherwise an IllegalStateException will be thrown.</p>
548         *
549         * @param sumLogImpl the StorelessUnivariateStatistic instance to use
550         * for computing the log sum
551         * @throws DimensionMismatchException if the array dimension
552         * does not match the one used at construction
553         * @throws MathIllegalStateException if data has already been added
554         *  (i.e if n > 0)
555         */
556        public void setSumLogImpl(StorelessUnivariateStatistic[] sumLogImpl)
557        throws MathIllegalStateException, DimensionMismatchException{
558            setImpl(sumLogImpl, this.sumLogImpl);
559        }
560    
561        /**
562         * Returns the currently configured geometric mean implementation
563         *
564         * @return the StorelessUnivariateStatistic implementing the geometric mean
565         */
566        public StorelessUnivariateStatistic[] getGeoMeanImpl() {
567            return geoMeanImpl.clone();
568        }
569    
570        /**
571         * <p>Sets the implementation for the geometric mean.</p>
572         * <p>This method must be activated before any data has been added - i.e.,
573         * before {@link #addValue(double[]) addValue} has been used to add data;
574         * otherwise an IllegalStateException will be thrown.</p>
575         *
576         * @param geoMeanImpl the StorelessUnivariateStatistic instance to use
577         * for computing the geometric mean
578         * @throws DimensionMismatchException if the array dimension
579         * does not match the one used at construction
580         * @throws MathIllegalStateException if data has already been added
581         *  (i.e if n > 0)
582         */
583        public void setGeoMeanImpl(StorelessUnivariateStatistic[] geoMeanImpl)
584        throws MathIllegalStateException, DimensionMismatchException {
585            setImpl(geoMeanImpl, this.geoMeanImpl);
586        }
587    
588        /**
589         * Returns the currently configured mean implementation
590         *
591         * @return the StorelessUnivariateStatistic implementing the mean
592         */
593        public StorelessUnivariateStatistic[] getMeanImpl() {
594            return meanImpl.clone();
595        }
596    
597        /**
598         * <p>Sets the implementation for the mean.</p>
599         * <p>This method must be activated before any data has been added - i.e.,
600         * before {@link #addValue(double[]) addValue} has been used to add data;
601         * otherwise an IllegalStateException will be thrown.</p>
602         *
603         * @param meanImpl the StorelessUnivariateStatistic instance to use
604         * for computing the mean
605         * @throws DimensionMismatchException if the array dimension
606         * does not match the one used at construction
607         * @throws MathIllegalStateException if data has already been added
608         *  (i.e if n > 0)
609         */
610        public void setMeanImpl(StorelessUnivariateStatistic[] meanImpl)
611        throws MathIllegalStateException, DimensionMismatchException{
612            setImpl(meanImpl, this.meanImpl);
613        }
614    
615        /**
616         * Throws MathIllegalStateException if the statistic is not empty.
617         * @throws MathIllegalStateException if n > 0.
618         */
619        private void checkEmpty() throws MathIllegalStateException {
620            if (n > 0) {
621                throw new MathIllegalStateException(
622                        LocalizedFormats.VALUES_ADDED_BEFORE_CONFIGURING_STATISTIC, n);
623            }
624        }
625    
626        /**
627         * Throws DimensionMismatchException if dimension != k.
628         * @param dimension dimension to check
629         * @throws DimensionMismatchException if dimension != k
630         */
631        private void checkDimension(int dimension) throws DimensionMismatchException {
632            if (dimension != k) {
633                throw new DimensionMismatchException(dimension, k);
634            }
635        }
636    }