package org.apache.sysml.runtime.compress;

import java.util.Arrays;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;

/* loaded from: input_file:org/apache/sysml/runtime/compress/ColGroupValue.class */
/**
 * Abstract base for column groups that keep a dictionary of distinct value tuples.
 *
 * <p>The dictionary {@link #_values} is a flat array in tuple-major order: the value of
 * column {@code j} in tuple {@code k} is stored at {@code _values[k * numCols + j]}, where
 * {@code numCols} is the number of columns in this group. Concrete subclasses (presumably
 * offset- and DDC-encoded groups) define how rows refer to these tuples.
 *
 * <p>Some helpers may return per-thread scratch buffers (see {@link #allocDVector} and
 * {@link #allocIVector}); such buffers are shared across calls on the same thread.
 */
public abstract class ColGroupValue extends ColGroup {
    private static final long serialVersionUID = 3786247536054353658L;

    /** Whether value tuples may be re-ordered by frequency in the bitmap-based constructor. */
    public static final boolean SORT_VALUES_BY_LENGTH = true;

    /** Row count that must be exceeded before value tuples are sorted by frequency. */
    private static final int SORT_VALUES_MIN_ROWS = 65536;

    /** Flat dictionary of distinct value tuples (tuple-major, {@code getNumCols()} per tuple). */
    protected double[] _values;

    /** Global switch for low-level optimizations such as frequency sorting of values. */
    public static boolean LOW_LEVEL_OPT = true;

    /**
     * Per-thread scratch buffers: key holds an int buffer, value holds a double buffer.
     * Both are presumably null until {@link #setupThreadLocalMemory(int)} runs on the thread.
     */
    private static final ThreadLocal<Pair<int[], double[]>> memPool =
        new ThreadLocal<Pair<int[], double[]>>() {
            @Override
            protected Pair<int[], double[]> initialValue() {
                return new Pair<>();
            }
        };

    /**
     * Creates a group with no column indexes, row count -1 and no values; presumably used
     * before state is populated later (e.g. by deserialization).
     */
    public ColGroupValue() {
        super((int[]) null, -1);
    }

    /**
     * Creates a group whose dictionary is taken from a bitmap.
     *
     * @param colIndices column indexes covered by this group
     * @param numRows    number of rows in the group
     * @param ubm        bitmap supplying the distinct values; if low-level optimizations are
     *                   enabled and {@code numRows} exceeds {@link #SORT_VALUES_MIN_ROWS}, its
     *                   values are first sorted via {@code sortValuesByFrequency()}
     */
    public ColGroupValue(int[] colIndices, int numRows, UncompressedBitmap ubm) {
        super(colIndices, numRows);
        if (LOW_LEVEL_OPT && SORT_VALUES_BY_LENGTH && numRows > SORT_VALUES_MIN_ROWS) {
            ubm.sortValuesByFrequency();
        }
        _values = ubm.getValues();
    }

    /**
     * Creates a group with an explicit dictionary (stored by reference, not copied).
     *
     * @param colIndices column indexes covered by this group
     * @param numRows    number of rows in the group
     * @param values     flat tuple-major dictionary
     */
    public ColGroupValue(int[] colIndices, int numRows, double[] values) {
        super(colIndices, numRows);
        _values = values;
    }

    /**
     * Returns the base estimate plus 8 bytes (presumably the {@code _values} reference) plus
     * {@link #getValuesSize()}.
     */
    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + 8 + getValuesSize();
    }

    /**
     * Returns the estimated size in bytes of the values array: 32 (presumably array header
     * overhead) plus 8 per double, or 0 if no values are set.
     */
    public long getValuesSize() {
        return (_values == null) ? 0L : 32 + (_values.length * 8);
    }

    /** Returns the number of distinct value tuples (values length divided by column count). */
    public int getNumValues() {
        return _values.length / _colIndexes.length;
    }

    /** Returns the internal values array itself (not a copy). */
    public double[] getValues() {
        return _values;
    }

    /** Replaces the internal values array (stored by reference, not copied). */
    public void setValues(double[] values) {
        _values = values;
    }

    /**
     * Returns one entry of the dictionary.
     *
     * @param tupleIx index of the value tuple
     * @param colIx   column position within this group (0-based, not the global column index)
     */
    public double getValue(int tupleIx, int colIx) {
        return _values[(tupleIx * getNumCols()) + colIx];
    }

    /**
     * Returns the flat values array as a single-column block. For an offset group whose
     * {@code _zeros} flag is set, one extra trailing row is allocated and left unset
     * (presumably representing the implicit zero value).
     */
    public MatrixBlock getValuesAsBlock() {
        // Explicit grouping: the decompiled one-liner mis-associated the nested ternary.
        boolean containsZeros = (this instanceof ColGroupOffset) && ((ColGroupOffset) this)._zeros;
        int rlen = containsZeros ? _values.length + 1 : _values.length;
        MatrixBlock ret = new MatrixBlock(rlen, 1, false);
        for (int i = 0; i < _values.length; i++) {
            ret.quickSetValue(i, 0, _values[i]);
        }
        return ret;
    }

    /** Returns the per-tuple counts in a freshly allocated array of length {@link #getNumValues()}. */
    public final int[] getCounts() {
        return getCounts(new int[getNumValues()]);
    }

    /** Computes per-tuple counts into the given array and returns it. */
    public abstract int[] getCounts(int[] counts);

    /** Returns per-tuple counts for the row range {@code [rl, ru)} in a fresh array. */
    public final int[] getCounts(int rl, int ru) {
        return getCounts(rl, ru, new int[getNumValues()]);
    }

    /** Computes per-tuple counts for the row range {@code [rl, ru)} into the given array. */
    public abstract int[] getCounts(int rl, int ru, int[] counts);

    /**
     * Returns per-tuple counts, optionally extended with one trailing entry.
     *
     * @param inclZeros if true and this is an offset group, append
     *                  {@code getNumRows() - sum(counts)}, i.e. the rows not counted by any tuple
     */
    public int[] getCounts(boolean inclZeros) {
        int[] counts = getCounts();
        if (inclZeros && (this instanceof ColGroupOffset)) {
            int covered = 0;
            for (int c : counts) {
                covered += c;
            }
            counts = Arrays.copyOf(counts, counts.length + 1);
            counts[counts.length - 1] = getNumRows() - covered;
        }
        return counts;
    }

    /** Returns {@link #getCounts()} as a single-column block. */
    public MatrixBlock getCountsAsBlock() {
        return getCountsAsBlock(getCounts());
    }

    /** Converts an int array into a single-column block with one row per entry. */
    public static MatrixBlock getCountsAsBlock(int[] counts) {
        MatrixBlock ret = new MatrixBlock(counts.length, 1, false);
        for (int i = 0; i < counts.length; i++) {
            ret.quickSetValue(i, 0, counts[i]);
        }
        return ret;
    }

    /**
     * Returns the index of the first value tuple whose entries are all {@code 0.0},
     * or -1 if there is none.
     */
    public int containsAllZeroValue() {
        final int numVals = getNumValues();
        final int numCols = getNumCols();
        for (int k = 0, off = 0; k < numVals; k++, off += numCols) {
            boolean allZero = true;
            for (int j = 0; j < numCols && allZero; j++) {
                allZero = _values[off + j] == 0.0d;
            }
            if (allZero) {
                return k;
            }
        }
        return -1;
    }

    /** Returns the plain (non-compensated) sum of the entries of tuple {@code tupleIx}. */
    public final double sumValues(int tupleIx) {
        final int numCols = getNumCols();
        final int off = tupleIx * numCols;
        double sum = 0.0d;
        for (int j = 0; j < numCols; j++) {
            sum += _values[off + j];
        }
        return sum;
    }

    /** Aggregates tuple {@code tupleIx} with the given Kahan function using a fresh buffer. */
    public final double sumValues(int tupleIx, KahanFunction kplus) {
        return sumValues(tupleIx, kplus, new KahanObject(0.0d, 0.0d));
    }

    /**
     * Aggregates tuple {@code tupleIx} with the given Kahan function.
     *
     * @param kbuff buffer that is reset to (0, 0) first and overwritten
     * @return the resulting {@code _sum} of the buffer
     */
    public final double sumValues(int tupleIx, KahanFunction kplus, KahanObject kbuff) {
        final int numCols = getNumCols();
        final int off = tupleIx * numCols;
        kbuff.set(0.0d, 0.0d);
        for (int j = 0; j < numCols; j++) {
            kplus.execute2(kbuff, _values[off + j]);
        }
        return kbuff._sum;
    }

    /** Same as {@code sumAllValues(kplus, kbuff, true)}. */
    public final double[] sumAllValues(KahanFunction kplus, KahanObject kbuff) {
        return sumAllValues(kplus, kbuff, true);
    }

    /**
     * Returns one aggregate per value tuple.
     *
     * <p>For a single-column group with {@link KahanPlus}, this returns {@code _values} itself
     * (not a copy) regardless of {@code allocNew}; callers must not modify the result.
     *
     * @param allocNew if false, the result may be the thread-local scratch buffer
     */
    public final double[] sumAllValues(KahanFunction kplus, KahanObject kbuff, boolean allocNew) {
        if (getNumCols() == 1 && (kplus instanceof KahanPlus)) {
            return _values;
        }
        final int numVals = getNumValues();
        double[] ret = allocNew ? new double[numVals] : allocDVector(numVals, false);
        for (int k = 0; k < numVals; k++) {
            ret[k] = sumValues(k, kplus, kbuff);
        }
        return ret;
    }

    /** Returns the dot product of tuple {@code tupleIx} with {@code b} (first numCols entries). */
    public final double sumValues(int tupleIx, double[] b) {
        final int numCols = getNumCols();
        final int off = tupleIx * numCols;
        double sum = 0.0d;
        for (int j = 0; j < numCols; j++) {
            sum += _values[off + j] * b[j];
        }
        return sum;
    }

    /** Same as {@code preaggValues(numVals, b, false)}; may return the thread-local buffer. */
    public final double[] preaggValues(int numVals, double[] b) {
        return preaggValues(numVals, b, false);
    }

    /**
     * Returns, for each of the first {@code numVals} tuples, its dot product with {@code b}.
     *
     * @param allocNew if false, the result may be the thread-local scratch buffer
     */
    public final double[] preaggValues(int numVals, double[] b, boolean allocNew) {
        double[] ret = allocNew ? new double[numVals] : allocDVector(numVals, false);
        for (int k = 0; k < numVals; k++) {
            ret[k] = sumValues(k, b);
        }
        return ret;
    }

    /** Returns the starting value for a min/max fold: -inf for MAX, +inf otherwise. */
    private static double initMxxValue(Builtin builtin) {
        return (builtin.getBuiltinCode() == Builtin.BuiltinCode.MAX)
            ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
    }

    /**
     * Folds {@code builtin} over all dictionary entries (and 0.0 if {@code zeros}) and merges
     * the outcome with the current value of {@code result(0,0)}, writing it back there.
     */
    public void computeMxx(MatrixBlock result, Builtin builtin, boolean zeros) {
        double val = initMxxValue(builtin);
        if (zeros) {
            val = builtin.execute2(val, 0.0d);
        }
        final int numVals = getNumValues();
        final int numCols = getNumCols();
        for (int k = 0; k < numVals; k++) {
            final int off = k * numCols;
            for (int j = 0; j < numCols; j++) {
                val = builtin.execute2(val, _values[off + j]);
            }
        }
        result.quickSetValue(0, 0, builtin.execute2(val, result.quickGetValue(0, 0)));
    }

    /**
     * Folds {@code builtin} per column over all tuples (and 0.0 if {@code zeros}) and writes
     * each column's outcome to row 0 of {@code result} at this group's column index. Existing
     * values in those cells are overwritten, not merged.
     */
    public void computeColMxx(MatrixBlock result, Builtin builtin, boolean zeros) {
        final int numVals = getNumValues();
        final int numCols = getNumCols();
        double[] vals = new double[numCols];
        Arrays.fill(vals, initMxxValue(builtin));
        if (zeros) {
            for (int j = 0; j < numCols; j++) {
                vals[j] = builtin.execute2(vals[j], 0.0d);
            }
        }
        for (int k = 0; k < numVals; k++) {
            final int off = k * numCols;
            for (int j = 0; j < numCols; j++) {
                vals[j] = builtin.execute2(vals[j], _values[off + j]);
            }
        }
        for (int j = 0; j < numCols; j++) {
            result.quickSetValue(0, _colIndexes[j], vals[j]);
        }
    }

    /** Left-multiplication by a row vector, represented here by a DDC column group. */
    public abstract void leftMultByRowVector(ColGroupDDC vector, MatrixBlock result) throws DMLRuntimeException;

    /** Returns a new array with {@code op} applied to every dictionary entry. */
    public double[] applyScalarOp(ScalarOperator op) throws DMLRuntimeException {
        double[] ret = new double[_values.length];
        for (int i = 0; i < _values.length; i++) {
            ret[i] = op.executeScalar(_values[i]);
        }
        return ret;
    }

    /**
     * Returns a new array with {@code op} applied to every dictionary entry, followed by
     * {@code numAppend} copies of {@code newVal}.
     */
    public double[] applyScalarOp(ScalarOperator op, double newVal, int numAppend) throws DMLRuntimeException {
        final int len = _values.length;
        double[] ret = new double[len + numAppend];
        for (int i = 0; i < len; i++) {
            ret[i] = op.executeScalar(_values[i]);
        }
        Arrays.fill(ret, len, len + numAppend, newVal);
        return ret;
    }

    /** Applies the aggregate over all rows, i.e. the range {@code [0, getNumRows())}. */
    @Override
    public void unaryAggregateOperations(AggregateUnaryOperator op, MatrixBlock result) throws DMLRuntimeException {
        unaryAggregateOperations(op, result, 0, getNumRows());
    }

    /** Applies the aggregate over the row range {@code [rl, ru)}. */
    public abstract void unaryAggregateOperations(AggregateUnaryOperator op, MatrixBlock result, int rl, int ru) throws DMLRuntimeException;

    /** Installs int and double scratch buffers of length {@code len} for the calling thread. */
    public static void setupThreadLocalMemory(int len) {
        Pair<int[], double[]> pair = new Pair<>();
        pair.setKey(new int[len]);
        pair.setValue(new double[len]);
        memPool.set(pair);
    }

    /** Removes the calling thread's scratch buffers. */
    public static void cleanupThreadLocalMemory() {
        memPool.remove();
    }

    /**
     * Returns a double vector of at least {@code len} entries. If the calling thread has a
     * scratch buffer of sufficient capacity, that shared buffer is returned (optionally with
     * its first {@code len} entries zeroed). Otherwise a fresh zeroed array of length
     * {@code len} is returned.
     *
     * @param reset whether to zero the first {@code len} entries of a reused buffer
     */
    public static double[] allocDVector(int len, boolean reset) {
        double[] pooled = memPool.get().getValue();
        // Fall back to a fresh array if the pool was never set up or is too small;
        // otherwise callers would hit ArrayIndexOutOfBoundsException.
        if (pooled == null || pooled.length < len) {
            return new double[len];
        }
        if (reset) {
            Arrays.fill(pooled, 0, len, 0.0d);
        }
        return pooled;
    }

    /**
     * Returns an int vector of at least {@code len} entries. If the calling thread has a
     * scratch buffer of sufficient capacity, that shared buffer is returned (optionally with
     * its first {@code len} entries zeroed). Otherwise a fresh zeroed array of length
     * {@code len} is returned.
     *
     * @param reset whether to zero the first {@code len} entries of a reused buffer
     */
    public static int[] allocIVector(int len, boolean reset) {
        int[] pooled = memPool.get().getKey();
        // Fall back to a fresh array if the pool was never set up or is too small.
        if (pooled == null || pooled.length < len) {
            return new int[len];
        }
        if (reset) {
            Arrays.fill(pooled, 0, len, 0);
        }
        return pooled;
    }
}
