package org.apache.sysml.runtime.compress.estim;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.UnivariateSolverUtils;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.compress.BitmapEncoder;
import org.apache.sysml.runtime.compress.ReaderColumnSelection;
import org.apache.sysml.runtime.compress.UncompressedBitmap;
import org.apache.sysml.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysml.runtime.compress.utils.DblArray;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysml/runtime/compress/estim/CompressedSizeEstimatorSample.class */
/**
 * Estimates compressed column-group sizes from a uniform row sample.
 *
 * <p>The constructor draws a sorted uniform sample of row indices (without replacement) and
 * replaces {@code _data} by the sampled rows only. Per-column-group estimates (distinct values,
 * non-zeros, runs and segments) are then extrapolated from the statistics of that sample; the
 * number of distinct values uses the Haas and Stokes hybrid estimator
 * ({@link #haasAndStokes}).
 *
 * <p>All ratios between counts (e.g. the sampling fraction {@code sampleSize / nRows}) are
 * computed in {@code double}; integer division would truncate them to 0.
 */
public class CompressedSizeEstimatorSample extends CompressedSizeEstimator {
    /** Probability passed to the chi-square inverse CDF to obtain the uniformity critical value. */
    private static final double SHLOSSER_JACKKNIFE_ALPHA = 0.975d;
    /** Squared-gamma threshold below which the Duj2 estimator is used. */
    public static final double HAAS_AND_STOKES_ALPHA1 = 0.9d;
    /** Squared-gamma threshold below which the Duj2a estimator is used; Sh3 is used above it. */
    public static final double HAAS_AND_STOKES_ALPHA2 = 30.0d;
    /**
     * Frequency cut-off used by Duj2a when {@link #HAAS_AND_STOKES_UJ2A_CUT2} is false: values
     * occurring more than this many times in the sample are handled separately.
     */
    public static final int HAAS_AND_STOKES_UJ2A_C = 50;
    /**
     * If true, Duj2a separately handles values whose sample frequency exceeds half of the
     * maximum sample frequency (instead of using {@link #HAAS_AND_STOKES_UJ2A_C}).
     */
    public static final boolean HAAS_AND_STOKES_UJ2A_CUT2 = true;
    /** If true, Duj2a derives the remaining population size via a method-of-moments solve. */
    public static final boolean HAAS_AND_STOKES_UJ2A_SOLVE = true;
    /** Maximum number of entries kept in the per-estimator method-of-moments solve cache. */
    public static final int MAX_SOLVE_CACHE_SIZE = 65536;
    /**
     * Row span used to charge extra segments/runs for distinct values not seen in the sample.
     * NOTE(review): presumably mirrors the bitmap encoder's segment size — confirm.
     */
    private static final double SEGMENT_SIZE = 65536.0d;

    private static final Log LOG = LogFactory.getLog(CompressedSizeEstimatorSample.class.getName());

    /** Sorted row indices (into the original input) that make up the sample. */
    private final int[] _sampleRows;
    /** Estimator-local cache of method-of-moments solutions, keyed by sample frequency. */
    private final HashMap<Integer, Double> _solveCache;

    /** Critical value of the uniformity chi-square test together with the sample size it was computed for. */
    public static class CriticalValue {
        public final double uniformityCriticalValue;
        public final int usedSampleSize;

        public CriticalValue(double d, int i) {
            this.uniformityCriticalValue = d;
            this.usedSampleSize = i;
        }
    }

    /**
     * Function whose root {@code x} solves the method-of-moments equation
     * {@code q * x / (1 - (1 - q)^x) = nj} for a value with sample frequency {@code nj} and
     * sampling fraction {@code q}.
     */
    public static class MethodOfMomentsFunction implements UnivariateFunction {
        private final int _nj;
        private final double _q;

        public MethodOfMomentsFunction(int i, double d) {
            this._nj = i;
            this._q = d;
        }

        @Override // org.apache.commons.math3.analysis.UnivariateFunction
        public double value(double d) {
            return ((this._q * d) / (1.0d - Math.pow(1.0d - this._q, d))) - this._nj;
        }
    }

    /**
     * Creates a sample-based estimator over {@code data}.
     *
     * @param data the matrix block whose column groups are estimated
     * @param sampleSize number of rows to sample uniformly without replacement
     * @throws DMLRuntimeException if constructing the sample fails
     */
    public CompressedSizeEstimatorSample(MatrixBlock data, int sampleSize) throws DMLRuntimeException {
        super(data);
        _sampleRows = getSortedUniformSample(_numRows, sampleSize);

        // 0/1 selection vector over all rows; keep only the sampled rows of the input so that
        // subsequent bitmap extraction only touches the sample
        MatrixBlock select = new MatrixBlock(_numRows, 1, false);
        for (int i = 0; i < sampleSize; i++) {
            select.quickSetValue(_sampleRows[i], 0, 1.0d);
        }
        _data = _data.removeEmptyOperations(new MatrixBlock(), false, true, select);

        _solveCache = new HashMap<>();
    }

    /**
     * Estimates the compressed size of the given column group by extrapolating sample statistics.
     *
     * @param colIndexes column indexes of the group
     * @return size information for RLE, OLE and DDC encodings
     */
    @Override // org.apache.sysml.runtime.compress.estim.CompressedSizeEstimator
    public CompressedSizeInfo estimateCompressedColGroupSize(int[] colIndexes) {
        final int sampleSize = _sampleRows.length;
        final int numCols = colIndexes.length;

        // statistics of this column group within the sample
        UncompressedBitmap ubm = BitmapEncoder.extractBitmap(colIndexes, _data);
        CompressedSizeEstimator.SizeEstimationFactors fact = computeSizeEstimationFactors(ubm, false);

        // distinct values: at least the number observed in the sample, at most the number of rows
        int totalCardinality = getNumDistinctValues(ubm, _numRows, _sampleRows, _solveCache);
        totalCardinality = Math.min(Math.max(totalCardinality, fact.numVals), _numRows);
        int unseenVals = totalCardinality - fact.numVals;

        // non-zeros scaled up from the sample; ratios must be computed in double
        double c = Math.max(1.0d - (double) fact.numSingle / sampleSize, (double) sampleSize / _numRows);
        int numZeros = sampleSize - fact.numOffs;
        int numNonZeros = (int) Math.ceil(_numRows - (double) _numRows / sampleSize * c * numZeros);
        numNonZeros = Math.max(numNonZeros, totalCardinality);

        if (totalCardinality <= 0 || unseenVals < 0 || numZeros < 0 || numNonZeros <= 0) {
            LOG.warn("Invalid estimates detected for " + Arrays.toString(colIndexes) + ": "
                + totalCardinality + " " + unseenVals + " " + numZeros + " " + numNonZeros);
        }

        // each unseen distinct value is charged ceil(numRows / SEGMENT_SIZE / 2) extra segments and runs
        int numUnseenSeg = (int) (unseenVals * Math.ceil((_numRows / SEGMENT_SIZE) / 2.0d));
        int totalNumRuns = getNumRuns(ubm, sampleSize, _numRows, _sampleRows) + numUnseenSeg;
        int totalNumSegs = fact.numSegs + numUnseenSeg;

        return new CompressedSizeInfo(totalCardinality, numNonZeros,
            getRLESize(totalCardinality, totalNumRuns, numCols),
            getOLESize(totalCardinality, numNonZeros, totalNumSegs, numCols),
            getDDCSize(totalCardinality, _numRows, numCols));
    }

    /**
     * Computes size information directly from an already extracted bitmap, without extrapolation.
     *
     * @param ubm the uncompressed bitmap of a column group
     * @return size information for RLE, OLE and DDC encodings
     */
    @Override // org.apache.sysml.runtime.compress.estim.CompressedSizeEstimator
    public CompressedSizeInfo estimateCompressedColGroupSize(UncompressedBitmap ubm) {
        CompressedSizeEstimator.SizeEstimationFactors fact = computeSizeEstimationFactors(ubm, true);
        int numCols = ubm.getNumColumns();
        return new CompressedSizeInfo(fact.numVals, fact.numOffs,
            getRLESize(fact.numVals, fact.numRuns, numCols),
            getOLESize(fact.numVals, fact.numOffs, fact.numSegs, numCols),
            getDDCSize(fact.numVals, _numRows, numCols));
    }

    /** Estimates the number of distinct values in the full data (delegates to Haas and Stokes). */
    private static int getNumDistinctValues(UncompressedBitmap ubm, int nRows, int[] sampleRows,
            HashMap<Integer, Double> solveCache) {
        return haasAndStokes(ubm, nRows, sampleRows.length, solveCache);
    }

    /**
     * Estimates the total number of runs over all distinct values in the full data.
     *
     * <p>For each value, runs are counted directly across consecutive sampled rows. For each gap of
     * unsampled rows between sampled rows, the expected contribution
     * {@code (size - expected) * expected / size} is added, with {@code expected = ratio * size}
     * and {@code ratio} the value's sample frequency. Boundary corrections are also added for
     * runs that cross a gap. If a value's expected occurrences per gap row are below 1, its
     * expected occurrence count {@code offsets * N / n} is used as the run count instead.
     *
     * <p>NOTE(review): {@code offsets} come from a bitmap extracted from the row-compacted sample
     * ({@code _data} after {@code removeEmptyOperations}), while {@code intervalStart} holds
     * original row indices from {@code sampleRows}. Confirm both are in the same index space.
     *
     * @param ubm bitmap of the column group within the sample
     * @param sampleSize number of sampled rows
     * @param totalNumRows number of rows of the full data
     * @param sampleRows sorted original row indices of the sample
     * @return estimated number of runs, capped at {@code OptimizerUtils.MAX_NUMCELLS_CP_DENSE}
     */
    private static int getNumRuns(UncompressedBitmap ubm, int sampleSize, int totalNumRows, int[] sampleRows) {
        int numVals = ubm.getNumValues();
        if (numVals == 0) {
            return 0; // no non-zero values in the sample
        }
        double numRuns = 0.0d;
        for (int vi = 0; vi < numVals; vi++) {
            int[] offsets = ubm.getOffsetsList(vi).extractValues();
            int offsetsSize = ubm.getNumOffsets(vi);
            double offsetsRatio = (double) offsetsSize / sampleSize;
            if (offsetsRatio * totalNumRows / sampleSize < 1.0d) {
                // fall back to the expected number of occurrences as an upper bound on runs
                numRuns += (double) offsetsSize * totalNumRows / sampleSize;
                continue;
            }

            // probability that the previous gap contains a row without this value
            double prevNonOffsetProb = 1.0d;
            int intervalStart;
            if (sampleRows[0] == 0) {
                intervalStart = 0; // no leading gap
            } else {
                int intervalSize = sampleRows[0]; // rows [0, sampleRows[0]) are unsampled
                double additionalOffsets = offsetsRatio * intervalSize;
                numRuns += (intervalSize - additionalOffsets) * additionalOffsets / intervalSize;
                prevNonOffsetProb = (intervalSize - additionalOffsets) / intervalSize;
                intervalStart = sampleRows[0];
            }

            boolean reachedSampleEnd = false;
            boolean seenNonOffset = false;
            boolean startedWithOffset = false;
            int withinSepRun = 0;
            int offsetsPtr = 0;
            for (int ix = 1; ix < sampleSize; ix++) {
                // intervalStart is the current sampled row of a block of consecutive sampled rows
                boolean endedWithOffset;
                if (offsetsPtr < offsetsSize && offsets[offsetsPtr] == intervalStart) {
                    startedWithOffset = true;
                    offsetsPtr++;
                    endedWithOffset = true;
                } else {
                    seenNonOffset = true;
                    endedWithOffset = false;
                }
                while (intervalStart + 1 == sampleRows[ix]) {
                    intervalStart = sampleRows[ix];
                    if (seenNonOffset) {
                        if (offsetsPtr < offsetsSize && offsets[offsetsPtr] == intervalStart) {
                            withinSepRun = 1;
                            offsetsPtr++;
                            endedWithOffset = true;
                        } else {
                            numRuns += withinSepRun;
                            withinSepRun = 0;
                            endedWithOffset = false;
                        }
                    } else if (offsetsPtr < offsetsSize && offsets[offsetsPtr] == intervalStart) {
                        offsetsPtr++;
                        endedWithOffset = true;
                    } else {
                        seenNonOffset = true;
                        endedWithOffset = false;
                    }
                    ix++;
                    if (ix == sampleSize) {
                        reachedSampleEnd = true;
                        break;
                    }
                }
                if (reachedSampleEnd) {
                    break;
                }

                // gap of unsampled rows between intervalStart and the next sampled row
                int intervalEnd = sampleRows[ix];
                int intervalSize = intervalEnd - intervalStart - 1;
                double additionalOffsets = offsetsRatio * intervalSize;
                numRuns += (intervalSize - additionalOffsets) * additionalOffsets / intervalSize;
                double nonOffsetProb = (intervalSize - additionalOffsets) / intervalSize;

                // additional runs caused by the value on the boundaries of the sampled block
                if (seenNonOffset) {
                    if (startedWithOffset) {
                        numRuns += prevNonOffsetProb;
                    }
                    if (endedWithOffset) {
                        numRuns += nonOffsetProb;
                    }
                } else {
                    numRuns += prevNonOffsetProb * nonOffsetProb;
                }
                prevNonOffsetProb = nonOffsetProb;
                intervalStart = intervalEnd;
                seenNonOffset = false;
                startedWithOffset = false;
                withinSepRun = 0;
            }

            // trailing gap after the last sampled row
            double lastNonOffsetProb;
            if (intervalStart != totalNumRows - 1) {
                int intervalSize = totalNumRows - intervalStart - 1;
                double additionalOffsets = offsetsRatio * intervalSize;
                numRuns += (intervalSize - additionalOffsets) * additionalOffsets / intervalSize;
                lastNonOffsetProb = (intervalSize - additionalOffsets) / intervalSize;
            } else {
                lastNonOffsetProb = 1.0d;
            }
            boolean lastIsOffset = intervalStart == offsets[offsetsSize - 1];
            if (seenNonOffset) {
                if (startedWithOffset) {
                    numRuns += prevNonOffsetProb;
                }
                if (lastIsOffset) {
                    numRuns += lastNonOffsetProb;
                }
            } else if (lastIsOffset) {
                numRuns += prevNonOffsetProb * lastNonOffsetProb;
            }
        }
        return (int) Math.min(Math.round(numRuns), OptimizerUtils.MAX_NUMCELLS_CP_DENSE);
    }

    /**
     * Draws {@code smplSize} distinct row indices from {@code [0, range)} and returns them sorted.
     *
     * @param range number of rows to sample from
     * @param smplSize number of rows to draw
     * @return sorted sample; empty if {@code smplSize == 0}
     */
    private static int[] getSortedUniformSample(int range, int smplSize) {
        if (smplSize == 0) {
            return new int[0];
        }
        int[] sample = new RandomDataGenerator().nextPermutation(range, smplSize);
        Arrays.sort(sample);
        return sample;
    }

    /**
     * Guaranteed-error estimator: values seen more than once are counted once; values seen
     * exactly once are scaled by {@code sqrt(nRows / sampleSize)}.
     */
    private static int guaranteedErrorEstimator(int nRows, int sampleSize, ReaderColumnSelection sampleRowsReader) {
        int singletonValsCount = 0;
        int otherValsCount = 0;
        for (int count : getValCounts(sampleRowsReader).values()) {
            if (count == 1) {
                singletonValsCount++;
            } else {
                otherValsCount++;
            }
        }
        return (int) Math.round(otherValsCount + singletonValsCount * Math.sqrt((double) nRows / sampleSize));
    }

    /**
     * Shlosser estimator of the number of distinct values.
     *
     * @return the estimate, at least 1 (also 1 when the sample has no values)
     */
    private static int shlosserEstimator(UncompressedBitmap ubm, int nRows, int sampleSize) {
        int numVals = ubm.getNumValues();
        int[] freqCounts = getFreqCounts(ubm);
        if (freqCounts.length == 0) {
            return 1; // no values in the sample; respect the lower bound of 1
        }
        double q = (double) sampleSize / nRows;
        double oneMinusQ = 1.0d - q;
        double numerSum = 0.0d;
        double denomSum = 0.0d;
        for (int i = 0; i < freqCounts.length; i++) {
            int iPlusOne = i + 1;
            numerSum += Math.pow(oneMinusQ, iPlusOne) * freqCounts[i];
            denomSum += iPlusOne * q * Math.pow(oneMinusQ, i) * freqCounts[i];
        }
        int estimate = (int) Math.round(numVals + freqCounts[0] * numerSum / denomSum);
        return Math.max(estimate, 1);
    }

    /**
     * Smoothed jackknife estimator of the number of distinct values.
     *
     * <p>The ratio of gamma functions {@code h} is evaluated as a product of the non-cancelling
     * factors, which avoids overflow of the individual gamma values for large {@code nRows}.
     *
     * @return 0 if the sample has no values, otherwise the estimate, at least 1
     */
    private static int smoothedJackknifeEstimator(UncompressedBitmap ubm, int nRows, int sampleSize) {
        int numVals = ubm.getNumValues();
        int[] freqCounts = getFreqCounts(ubm);
        if (freqCounts.length == 0) {
            return 0;
        }
        double f1 = freqCounts[0];
        double nn = (double) nRows * sampleSize; // N*n in double: the int product can overflow
        double d0 = (numVals - f1 / sampleSize) / (1.0d - (nRows - sampleSize + 1) * f1 / nn);
        double nTilde = nRows / d0;

        // h = [min(A,D-1) * ... * B] / [C * ... * max(A+1,D)]
        // with A = N - Ntilde, B = A - n + 1, C = N, D = N - n + 1
        double a = nRows - nTilde;
        double b = a - sampleSize + 1.0d;
        double c = nRows;
        double d = (nRows - sampleSize) + 1;
        a = Math.min(a, d - 1.0d);
        d = Math.max(a + 1.0d, d);
        double h = 1.0d;
        for (; a >= b || c >= d; a -= 1.0d, c -= 1.0d) {
            if (a >= b) {
                h *= a;
            }
            if (c >= d) {
                h /= c;
            }
        }

        double g = 0.0d;
        for (int k = 2; k <= sampleSize + 1; k++) {
            g += 1.0d / (nRows - nTilde - sampleSize + k);
        }
        double gamma = 0.0d;
        for (int i = 1; i <= freqCounts.length; i++) {
            gamma += (double) i * (i - 1) * freqCounts[i - 1]; // double: i*(i-1) overflows int
        }
        gamma *= (nRows - 1) * d0 / nn / (sampleSize - 1);
        gamma += d0 / nRows - 1.0d;

        double estimate = (numVals + nRows * h * g * gamma)
            / (1.0d - (nRows - nTilde - sampleSize + 1.0d) * f1 / nn);
        return estimate < 1.0d ? 1 : (int) Math.round(estimate);
    }

    /**
     * Hybrid estimator: runs a chi-square uniformity test on the sample frequencies and uses the
     * smoothed jackknife estimator if the frequencies look uniform, Shlosser otherwise.
     */
    private static int shlosserJackknifeEstimator(UncompressedBitmap ubm, int nRows, int sampleSize) {
        int numVals = ubm.getNumValues();
        CriticalValue cv = computeCriticalValue(sampleSize);

        // chi-square test statistic against the mean frequency
        double nBar = (double) sampleSize / numVals;
        double u = 0.0d;
        for (int i = 0; i < numVals; i++) {
            u += Math.pow(ubm.getNumOffsets(i) - nBar, 2.0d);
        }
        u /= nBar;

        return u < cv.uniformityCriticalValue
            ? smoothedJackknifeEstimator(ubm, nRows, sampleSize)
            : shlosserEstimator(ubm, nRows, sampleSize);
    }

    /** Computes the chi-square critical value with {@code i - 1} degrees of freedom. */
    private static CriticalValue computeCriticalValue(int i) {
        return new CriticalValue(new ChiSquaredDistribution(i - 1).inverseCumulativeProbability(SHLOSSER_JACKKNIFE_ALPHA), i);
    }

    /**
     * Haas and Stokes hybrid estimator of the number of distinct values: chooses Duj2, Duj2a or
     * Sh3 depending on the squared gamma derived from the Duj1 estimate.
     *
     * @param ubm bitmap of the column group within the sample
     * @param nRows number of rows of the full data
     * @param sampleSize number of sampled rows
     * @param solveCache cache for method-of-moments solutions (may be updated)
     * @return the estimate, at least 1
     */
    private static int haasAndStokes(UncompressedBitmap ubm, int nRows, int sampleSize,
            HashMap<Integer, Double> solveCache) {
        int numVals = ubm.getNumValues();
        if (numVals == 0) {
            return 1; // no non-zero values in the sample
        }
        int[] freqCounts = getFreqCounts(ubm);
        double q = (double) sampleSize / nRows; // sampling fraction; int division would yield 0
        double f1 = freqCounts[0];

        double duj1 = getDuj1Estimate(q, f1, sampleSize, numVals);
        double gamma = getGammaSquared(duj1, freqCounts, sampleSize, nRows);
        double estimate;
        if (gamma < HAAS_AND_STOKES_ALPHA1) {
            estimate = getDuj2Estimate(q, f1, sampleSize, numVals, gamma);
        } else if (gamma < HAAS_AND_STOKES_ALPHA2) {
            estimate = getDuj2aEstimate(q, freqCounts, sampleSize, numVals, gamma, nRows, solveCache);
        } else {
            estimate = getSh3Estimate(q, freqCounts, numVals);
        }
        return Math.max(1, (int) Math.round(estimate));
    }

    /**
     * Counts occurrences of each distinct row returned by the reader. Each row is copied before
     * being used as a key, so the reader may reuse its returned instance.
     */
    private static HashMap<DblArray, Integer> getValCounts(ReaderColumnSelection sampleRowsReader) {
        HashMap<DblArray, Integer> valsCount = new HashMap<>();
        for (DblArray row = sampleRowsReader.nextRow(); row != null; row = sampleRowsReader.nextRow()) {
            Integer count = valsCount.get(row);
            int next = (count == null) ? 1 : count + 1;
            valsCount.put(new DblArray(row), next);
        }
        return valsCount;
    }

    /**
     * Builds the frequency histogram {@code f}, where {@code f[k-1]} is the number of distinct
     * values occurring exactly {@code k} times in the sample. The array length equals the maximum
     * frequency (0 if there are no values). Assumes every value has at least one offset.
     */
    private static int[] getFreqCounts(UncompressedBitmap ubm) {
        int numVals = ubm.getNumValues();
        int maxFreq = 0;
        for (int i = 0; i < numVals; i++) {
            maxFreq = Math.max(maxFreq, ubm.getNumOffsets(i));
        }
        int[] freqCounts = new int[maxFreq];
        for (int i = 0; i < numVals; i++) {
            freqCounts[ubm.getNumOffsets(i) - 1]++;
        }
        return freqCounts;
    }

    /** Duj1 estimate: {@code dn / (1 - (1 - q) * f1 / n)}. */
    private static double getDuj1Estimate(double q, double f1, int n, int dn) {
        return dn / (1.0d - (((1.0d - q) * f1) / n));
    }

    /** Duj2 estimate: Duj1 corrected by the squared coefficient of variation {@code gammaDuj1}. */
    private static double getDuj2Estimate(double q, double f1, int n, int dn, double gammaDuj1) {
        return (dn - (((((1.0d - q) * f1) * Math.log(1.0d - q)) * gammaDuj1) / q)) / (1.0d - (((1.0d - q) * f1) / n));
    }

    /**
     * Duj2a estimate: values with high sample frequency are counted exactly and removed; Duj2 is
     * then applied to the remaining values. The input histogram {@code f} is not modified.
     *
     * @param q sampling fraction
     * @param f frequency histogram (see {@link #getFreqCounts})
     * @param n sample size
     * @param dn number of distinct values in the sample
     * @param gammaDuj1 squared gamma from the Duj1 estimate (used if nothing remains)
     * @param N number of rows of the full data
     * @param solveCache cache for method-of-moments solutions (may be updated)
     */
    private static double getDuj2aEstimate(double q, int[] f, int n, int dn, double gammaDuj1, int N,
            HashMap<Integer, Double> solveCache) {
        int c = HAAS_AND_STOKES_UJ2A_CUT2 ? (f.length / 2) + 1 : HAAS_AND_STOKES_UJ2A_C + 1;

        // sample size and distinct count of the high-frequency values
        int nB = 0;
        int dB = 0;
        for (int i = c; i <= f.length; i++) {
            if (f[i - 1] != 0) {
                nB += f[i - 1] * i;
                dB += f[i - 1];
            }
        }
        int nA = n - nB;
        int dA = dn - dB;
        if (nA == 0) {
            return getDuj2Estimate(q, f[0], n, dn, gammaDuj1);
        }

        // population size remaining after removing the high-frequency values
        int nTildeA = N;
        if (HAAS_AND_STOKES_UJ2A_SOLVE) {
            for (int i = c; i <= f.length; i++) {
                if (f[i - 1] != 0) {
                    nTildeA = (int) (nTildeA - (f[i - 1] * getMethodOfMomentsEstimate(i, q, 1.0d, N, solveCache)));
                }
            }
        } else {
            nTildeA = (int) (N - nB / q);
        }

        // histogram of the remaining values (copy, so the caller's array is left untouched)
        int[] fA = Arrays.copyOf(f, f.length);
        for (int i = c; i <= fA.length; i++) {
            fA[i - 1] = 0;
        }
        double duj1A = getDuj1Estimate(q, fA[0], nA, dA);
        double gammaA = getGammaSquared(duj1A, fA, nA, nTildeA);
        return getDuj2Estimate(q, fA[0], nA, dA, gammaA) + dB;
    }

    /** Sh3 estimate for highly skewed frequency distributions. */
    private static double getSh3Estimate(double q, int[] f, double dn) {
        double fraq11 = 0.0d;
        double fraq12 = 0.0d;
        double fraq21 = 0.0d;
        double fraq22 = 0.0d;
        for (int i = 1; i <= f.length; i++) {
            if (f[i - 1] != 0) {
                fraq11 += i * q * q * Math.pow(1.0d - (q * q), i - 1) * f[i - 1];
                fraq12 += (Math.pow(1.0d - (q * q), i) - Math.pow(1.0d - q, i)) * f[i - 1];
                fraq21 += Math.pow(1.0d - q, i) * f[i - 1];
                fraq22 += i * q * Math.pow(1.0d - q, i - 1) * f[i - 1];
            }
        }
        return dn + (((f[0] * fraq11) / fraq12) * Math.pow(fraq21 / fraq22, 2.0d));
    }

    /**
     * Squared coefficient of variation of the value frequencies, clamped at 0:
     * {@code max(0, sum_i i*(i-1)*f_i * D / n^2 + D / N - 1)}.
     */
    private static double getGammaSquared(double D, int[] f, int n, int N) {
        double sum = 0.0d;
        for (int i = 1; i <= f.length; i++) {
            if (f[i - 1] != 0) {
                sum += (double) i * (i - 1) * f[i - 1]; // double: i*(i-1) overflows int
            }
        }
        return Math.max(0.0d, (sum * ((D / n) / n)) + ((D / N) - 1.0d));
    }

    /**
     * Solves the method-of-moments equation for sample frequency {@code nj} on {@code [min, max]},
     * caching results by {@code nj} (up to {@link #MAX_SOLVE_CACHE_SIZE} entries). The cache key
     * ignores {@code q}, so one cache must only be used with a single sampling fraction.
     */
    private static double getMethodOfMomentsEstimate(int nj, double q, double min, double max,
            HashMap<Integer, Double> solveCache) {
        Double cached = solveCache.get(nj);
        if (cached != null) {
            return cached;
        }
        double est = UnivariateSolverUtils.solve(new MethodOfMomentsFunction(nj, q), min, max, 1.0E-9d);
        if (solveCache.size() < MAX_SOLVE_CACHE_SIZE) {
            solveCache.put(nj, est);
        }
        return est;
    }
}
