package hivemall.utils.math;

import hivemall.utils.lang.Preconditions;
import java.util.AbstractMap;
import java.util.Map;
import javax.annotation.Nonnull;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotPositiveException;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.util.MathArrays;

/* loaded from: input_file:hivemall/utils/math/StatsUtils.class */
/**
 * Static statistical helpers: probit, Gaussian densities and their negative log (log-loss),
 * Hellinger-style distances between Gaussians, and Pearson chi-square statistics.
 *
 * <p>This class is not instantiable.
 */
public final class StatsUtils {

    /** The constant 2&pi; used by the Gaussian normalization terms. */
    private static final double TWO_PI = 2.0d * Math.PI;

    private StatsUtils() {
    }

    /**
     * Computes the probit of {@code p} as {@code sqrt(2) * MathUtils.inverseErf(2p - 1)}.
     *
     * <p>Note that a NaN argument is not rejected, because both range comparisons are false for
     * NaN; it is passed on to {@code MathUtils.inverseErf}.
     *
     * @param p a probability
     * @return the probit of {@code p}
     * @throws IllegalArgumentException if {@code p < 0} or {@code p > 1}
     */
    public static double probit(double p) {
        if (p < 0.0d || p > 1.0d) {
            throw new IllegalArgumentException("p must be in [0,1]");
        }
        return Math.sqrt(2.0d) * MathUtils.inverseErf((2.0d * p) - 1.0d);
    }

    /**
     * Computes the probit of {@code p} clamped to {@code [-range, range]}.
     *
     * <p>{@code p == 0} maps directly to {@code -range} and {@code p == 1} to {@code range},
     * without evaluating the inverse error function.
     *
     * @param p a probability
     * @param range the clamp bound; must be strictly positive
     * @return the clamped probit
     * @throws IllegalArgumentException if {@code range <= 0}, or if {@code p} is outside [0,1]
     */
    public static double probit(double p, double range) {
        if (range <= 0.0d) {
            throw new IllegalArgumentException("range must be > 0: " + range);
        }
        if (p == 0.0d) {
            return -range;
        }
        if (p == 1.0d) {
            return range;
        }
        final double v = probit(p);
        return v < 0.0d ? Math.max(v, -range) : Math.min(v, range);
    }

    /**
     * Evaluates {@code exp(-0.5 * (x - xHat)^2 / sigma) / (sqrt(2π) * sqrt(sigma))}, i.e. a
     * univariate Gaussian density in which {@code sigma} plays the role of the variance.
     *
     * @param x the point at which to evaluate the density
     * @param xHat the mean
     * @param sigma the variance term
     * @return the density value, or {@code 0} when {@code sigma == 0}
     */
    public static double pdf(double x, double xHat, double sigma) {
        if (sigma == 0.0d) {
            // avoid dividing by zero
            return 0.0d;
        }
        final double diff = x - xHat;
        final double numerator = Math.exp((((-0.5d) * diff) * diff) / sigma);
        final double denominator = Math.sqrt(TWO_PI) * Math.sqrt(sigma);
        return numerator / denominator;
    }

    /**
     * Evaluates a multivariate Gaussian density:
     * {@code exp(-0.5 * (x - xHat)' * inv(sigma) * (x - xHat)) / ((2π)^(k/2) * det(sigma)^(1/2))},
     * where {@code k} is the dimension of {@code x}.
     *
     * <p>The inverse of {@code sigma} comes from its LU decomposition; when the LU solver reports
     * the matrix as singular, the inverse from a singular value decomposition solver is used
     * instead.
     *
     * @param x the point at which to evaluate the density
     * @param xHat the mean vector; must have the same dimension as {@code x}
     * @param sigma the covariance matrix; must be square with as many rows as {@code x} has entries
     * @return the density value, or {@code 0} when the normalization term is exactly {@code 0}
     */
    public static double pdf(@Nonnull RealVector x, @Nonnull RealVector xHat, @Nonnull RealMatrix sigma) {
        final int dim = x.getDimension();
        Preconditions.checkArgument(xHat.getDimension() == dim, "|x| != |x_hat|, |x|=" + dim + ", |x_hat|=" + xHat.getDimension());
        Preconditions.checkArgument(sigma.getRowDimension() == dim, "|x| != |sigma|, |x|=" + dim + ", |sigma|=" + sigma.getRowDimension());
        Preconditions.checkArgument(sigma.isSquare(), "Sigma is not square matrix");

        final LUDecomposition lu = new LUDecomposition(sigma);
        final double denominator = Math.pow(TWO_PI, 0.5d * dim) * Math.pow(lu.getDeterminant(), 0.5d);
        if (denominator == 0.0d) {
            // avoid dividing by zero
            return 0.0d;
        }

        final DecompositionSolver solver = lu.getSolver();
        final RealMatrix invSigma;
        if (solver.isNonSingular()) {
            invSigma = solver.getInverse();
        } else {
            invSigma = new SingularValueDecomposition(sigma).getSolver().getInverse();
        }

        final RealVector diff = x.subtract(xHat);
        return Math.exp((-0.5d) * invSigma.preMultiply(diff).dotProduct(diff)) / denominator;
    }

    /**
     * Returns the negative natural log of {@link #pdf(double, double, double)}.
     *
     * @param x the observed value
     * @param xHat the mean
     * @param sigma the variance term
     * @return {@code -log(pdf)}, or {@code 0} when the density is exactly {@code 0}
     */
    public static double logLoss(double x, double xHat, double sigma) {
        final double p = pdf(x, xHat, sigma);
        if (p == 0.0d) {
            return 0.0d;
        }
        return -Math.log(p);
    }

    /**
     * Returns the negative natural log of {@link #pdf(RealVector, RealVector, RealMatrix)}.
     *
     * @param x the observed vector
     * @param xHat the mean vector
     * @param sigma the covariance matrix
     * @return {@code -log(pdf)}, or {@code 0} when the density is exactly {@code 0}
     */
    public static double logLoss(@Nonnull RealVector x, @Nonnull RealVector xHat, @Nonnull RealMatrix sigma) {
        final double p = pdf(x, xHat, sigma);
        if (p == 0.0d) {
            return 0.0d;
        }
        return -Math.log(p);
    }

    /**
     * Computes
     * {@code 1 - (sigma1^(1/4) * sigma2^(1/4) * exp(-0.25 * (mu1 - mu2)^2 / (sigma1 + sigma2)))
     * / sqrt((sigma1 + sigma2) / 2)} for two univariate Gaussians, where {@code sigma1} and
     * {@code sigma2} enter the formula as variances.
     *
     * @param mu1 mean of the first distribution
     * @param sigma1 variance term of the first distribution
     * @param mu2 mean of the second distribution
     * @param sigma2 variance term of the second distribution
     * @return the distance; {@code 0} when {@code sigma1 + sigma2 == 0}, and {@code 1} when the
     *         denominator evaluates to exactly {@code 0}
     */
    public static double hellingerDistance(double mu1, double sigma1, double mu2, double sigma2) {
        final double sigmaSum = sigma1 + sigma2;
        if (sigmaSum == 0.0d) {
            return 0.0d;
        }
        final double numerator = Math.pow(sigma1, 0.25d) * Math.pow(sigma2, 0.25d)
                * Math.exp(((-0.25d) * Math.pow(mu1 - mu2, 2.0d)) / sigmaSum);
        final double denominator = Math.sqrt(sigmaSum / 2.0d);
        if (denominator == 0.0d) {
            // avoid dividing by zero
            return 1.0d;
        }
        return 1.0d - (numerator / denominator);
    }

    /**
     * Multivariate counterpart of {@link #hellingerDistance(double, double, double, double)}:
     * {@code 1 - (det(sigma1)^(1/4) * det(sigma2)^(1/4)
     * * exp(-0.125 * (mu1 - mu2)' * inv(S) * (mu1 - mu2))) / sqrt(det(S))}
     * with {@code S = (sigma1 + sigma2) / 2}.
     *
     * @param mu1 mean vector of the first distribution
     * @param sigma1 covariance matrix of the first distribution
     * @param mu2 mean vector of the second distribution
     * @param sigma2 covariance matrix of the second distribution
     * @return the distance, or {@code 1} when {@code sqrt(det(S))} is exactly {@code 0}
     */
    public static double hellingerDistance(@Nonnull RealVector mu1, @Nonnull RealMatrix sigma1, @Nonnull RealVector mu2, @Nonnull RealMatrix sigma2) {
        final RealVector muDiff = mu1.subtract(mu2);
        final RealMatrix sigmaMean = sigma1.add(sigma2).scalarMultiply(0.5d);
        final LUDecomposition lu = new LUDecomposition(sigmaMean);
        final double denominator = Math.sqrt(lu.getDeterminant());
        if (denominator == 0.0d) {
            // avoid dividing by zero
            return 1.0d;
        }
        final RealMatrix invSigmaMean = lu.getSolver().getInverse();
        final double detProduct = Math.pow(MatrixUtils.det(sigma1), 0.25d) * Math.pow(MatrixUtils.det(sigma2), 0.25d);
        final double numerator = detProduct * Math.exp((-0.125d) * invSigmaMean.preMultiply(muDiff).dotProduct(muDiff));
        return 1.0d - (numerator / denominator);
    }

    /**
     * Computes the chi-square statistic {@code sum((observed[i] - e[i])^2 / e[i])}.
     *
     * <p>When the totals of {@code observed} and {@code expected} differ by more than
     * {@code 1.0E-5}, each expected count is rescaled by {@code sum(observed) / sum(expected)}
     * before use ({@code e[i] = ratio * expected[i]}); otherwise {@code e[i] = expected[i]}.
     *
     * @param observed observed counts; must have at least 2 entries, none negative
     * @param expected expected counts; same length as {@code observed}, validated with
     *        {@code MathArrays.checkPositive}
     * @return the chi-square statistic
     * @throws DimensionMismatchException if {@code observed} has fewer than 2 entries or the two
     *         arrays differ in length
     * @throws NotPositiveException if any observed count is negative
     */
    public static double chiSquare(@Nonnull double[] observed, @Nonnull double[] expected) {
        if (observed.length < 2) {
            throw new DimensionMismatchException(observed.length, 2);
        }
        if (expected.length != observed.length) {
            throw new DimensionMismatchException(observed.length, expected.length);
        }
        MathArrays.checkPositive(expected);
        for (double o : observed) {
            if (o < 0.0d) {
                throw new NotPositiveException(Double.valueOf(o));
            }
        }

        double sumObserved = 0.0d;
        double sumExpected = 0.0d;
        for (int i = 0; i < observed.length; i++) {
            sumObserved += observed[i];
            sumExpected += expected[i];
        }

        // A ratio of exactly 1.0 leaves every expected count bit-for-bit unchanged, so a single
        // code path covers both the rescaled and the non-rescaled case.
        double ratio = 1.0d;
        if (Math.abs(sumObserved - sumExpected) > 1.0E-5d) {
            ratio = sumObserved / sumExpected;
        }

        double sumSq = 0.0d;
        for (int i = 0; i < observed.length; i++) {
            final double scaledExpected = ratio * expected[i];
            final double dev = observed[i] - scaledExpected;
            sumSq += (dev * dev) / scaledExpected;
        }
        return sumSq;
    }

    /**
     * Computes {@code 1 - CDF(chiSquare(observed, expected))} under a chi-squared distribution
     * with {@code expected.length - 1} degrees of freedom.
     *
     * @param observed observed counts
     * @param expected expected counts
     * @return the upper-tail probability of the chi-square statistic
     */
    public static double chiSquareTest(@Nonnull double[] observed, @Nonnull double[] expected) {
        final ChiSquaredDistribution distribution = new ChiSquaredDistribution(expected.length - 1.0d);
        return 1.0d - distribution.cumulativeProbability(chiSquare(observed, expected));
    }

    /**
     * Row-wise chi-square: for every row {@code i}, computes the statistic of
     * {@code observeds[i]} against {@code expecteds[i]} and its upper-tail probability.
     *
     * <p>One chi-squared distribution is shared by all rows; its degrees of freedom are taken
     * from the length of the first expected row only ({@code expecteds[0].length - 1}).
     * When there are no rows, two empty arrays are returned.
     *
     * @param observeds observed counts, one row per test
     * @param expecteds expected counts, one row per test; same number of rows as
     *        {@code observeds}
     * @return an entry whose key holds the chi-square statistics and whose value holds the
     *         corresponding upper-tail probabilities, both indexed by row
     */
    public static Map.Entry<double[], double[]> chiSquare(@Nonnull double[][] observeds, @Nonnull double[][] expecteds) {
        Preconditions.checkArgument(observeds.length == expecteds.length);

        final int rows = expecteds.length;
        final double[] statistics = new double[rows];
        final double[] pValues = new double[rows];
        if (rows > 0) {
            // Guarded so that an empty input does not dereference expecteds[0].
            final ChiSquaredDistribution distribution = new ChiSquaredDistribution(expecteds[0].length - 1.0d);
            for (int i = 0; i < rows; i++) {
                statistics[i] = chiSquare(observeds[i], expecteds[i]);
                pValues[i] = 1.0d - distribution.cumulativeProbability(statistics[i]);
            }
        }
        return new AbstractMap.SimpleEntry<double[], double[]>(statistics, pValues);
    }
}
