package org.apache.sysml.runtime.codegen;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.data.DenseBlock;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.util.DataConverter;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Abstract base class of all generated ("spoof") fused operators.
 *
 * <p>Subclasses implement the matrix- and/or scalar-returning {@code execute}
 * variants. This base class provides the shared runtime helpers used by the
 * generated code: preparation of matrix side inputs and scalar inputs, size and
 * nnz summaries of the inputs, and accessors that read cells or row vectors
 * from side inputs held in dense or sparse representation.
 */
public abstract class SpoofOperator implements Serializable {
    private static final long serialVersionUID = 3834006998853573319L;
    private static final Log LOG = LogFactory.getLog(SpoofOperator.class.getName());

    /**
     * Cell-count threshold (2^20). Not referenced in this class; presumably used
     * by subclasses to decide whether multi-threaded execution pays off.
     */
    protected static final long PAR_NUMCELL_THRESHOLD = 1048576;

    /**
     * Operation-count threshold (2^21). Not referenced in this class; presumably
     * used by subclasses to decide whether multi-threaded execution pays off.
     */
    protected static final long PAR_MINFLOP_THRESHOLD = 2097152;

    /**
     * Wrapper around one matrix side input. As created by
     * {@code prepInputMatrices}, the data is held either as a dense block in
     * {@code ddat}, as a (sparse or unallocated) matrix block in {@code mdat},
     * or in neither field for an empty input.
     */
    public static class SideInput {
        /** Dense data of the input, or {@code null} if not held in dense form. */
        public final DenseBlock ddat;
        /** Matrix block for sparse or unallocated inputs, or {@code null}. */
        public final MatrixBlock mdat;
        /** Number of columns recorded for this input. */
        public final int clen;

        public SideInput(DenseBlock ddat, MatrixBlock mdat, int clen) {
            this.ddat = ddat;
            this.mdat = mdat;
            this.clen = clen;
        }

        /**
         * Returns the offset of row {@code rowIndex} within the array returned by
         * {@link #values(int)}: the dense block's row position if a dense block is
         * present, otherwise {@code rowIndex * clen}.
         */
        public int pos(int rowIndex) {
            return (ddat != null) ? ddat.pos(rowIndex) : rowIndex * clen;
        }

        /**
         * Returns the dense value array containing row {@code rowIndex}, or
         * {@code null} if this side input has no dense block.
         */
        public double[] values(int rowIndex) {
            return (ddat != null) ? ddat.values(rowIndex) : null;
        }

        /** Returns cell ({@code rowIndex}, {@code colIndex}) of this side input. */
        public double getValue(int rowIndex, int colIndex) {
            return SpoofOperator.getValue(this, clen, rowIndex, colIndex);
        }

        /** Resets per-row cursor state; a no-op for plain side inputs. */
        public void reset() {
            // plain side inputs carry no cursor state
        }
    }

    /**
     * Sparse side input with a forward-moving cursor for cell-wise access.
     *
     * <p>{@link #next(int, int)} is designed for access patterns that visit rows
     * in increasing order and, within a row, columns in increasing order; the
     * cursor only ever moves forward. A row index lower than the current row is
     * not re-seeked, so such access patterns are unsupported.
     * {@link #reset()} rewinds the cursor to the start of the current row.
     */
    public static class SideInputSparseCell extends SideInput {
        private int currRowIndex;
        // start offset of the current row within indexes/values (pos(row));
        // not necessarily 0, since a sparse block may share arrays across rows
        private int currRowPos;
        private int currColPos;
        private int currLen;
        private int[] indexes;
        private double[] values;

        public SideInputSparseCell(SideInput in) {
            super(in.ddat, in.mdat, in.clen);
            this.currRowIndex = -1;
            this.currRowPos = 0;
            this.currColPos = 0;
            this.currLen = 0;
        }

        /**
         * Returns cell ({@code rowIndex}, {@code colIndex}), or 0 if the cell is
         * not stored, the row is empty, or there is no sparse block.
         *
         * @param rowIndex row index; must not decrease between calls for the same row sequence
         * @param colIndex column index; should not decrease within a row unless {@link #reset()} is called
         * @return the stored value or 0
         */
        public double next(int rowIndex, int colIndex) {
            SparseBlock sblock = mdat.getSparseBlock();
            if (sblock == null || sblock.isEmpty(rowIndex)) {
                return 0.0d;
            }
            // move to a new row: cache its arrays and [start, end) offsets
            if (rowIndex > currRowIndex) {
                currRowIndex = rowIndex;
                currRowPos = sblock.pos(rowIndex);
                currColPos = currRowPos;
                currLen = currRowPos + sblock.size(rowIndex);
                indexes = sblock.indexes(rowIndex);
                values = sblock.values(rowIndex);
            }
            // skip stored entries left of the requested column
            while (currColPos < currLen && indexes[currColPos] < colIndex) {
                currColPos++;
            }
            boolean found = currColPos < currLen && indexes[currColPos] == colIndex;
            return found ? values[currColPos] : 0.0d;
        }

        @Override
        public void reset() {
            // Rewind to the start of the current row (pos(row)) rather than to
            // absolute offset 0: with shared per-block arrays, offset 0 belongs to
            // an earlier row and scanning from there returns wrong values.
            currColPos = currRowPos;
        }
    }

    /**
     * Sparse side input that materializes one row at a time into a reusable
     * dense buffer of length {@code clen}, so generated code can read it like a
     * dense row ({@link #pos(int)} is always 0). A row is only (re)materialized
     * when its index is larger than the last materialized row.
     */
    public static class SideInputSparseRow extends SideInput {
        private final double[] values;
        private int currRowIndex;

        public SideInputSparseRow(SideInput in) {
            super(in.ddat, in.mdat, in.clen);
            this.currRowIndex = -1;
            this.values = new double[in.clen];
        }

        @Override
        public int pos(int rowIndex) {
            return 0;
        }

        @Override
        public double[] values(int rowIndex) {
            if (rowIndex > currRowIndex) {
                nextRow(rowIndex);
            }
            return values;
        }

        /** Copies row {@code rowIndex} of the sparse block into the dense row buffer. */
        private void nextRow(int rowIndex) {
            currRowIndex = rowIndex;
            SparseBlock sblock = mdat.getSparseBlock();
            if (sblock == null) {
                return; // no sparse data: buffer is left untouched
            }
            Arrays.fill(values, 0.0d);
            if (sblock.isEmpty(rowIndex)) {
                return;
            }
            int apos = sblock.pos(rowIndex);
            int alen = sblock.size(rowIndex);
            int[] aix = sblock.indexes(rowIndex);
            double[] avals = sblock.values(rowIndex);
            for (int k = apos; k < apos + alen; k++) {
                values[aix[k]] = avals[k];
            }
        }
    }

    /**
     * Executes the operator on the given matrix and scalar inputs.
     *
     * @param inputs matrix inputs; by default side inputs start at index 1 (see {@code prepInputMatrices})
     * @param scalars scalar inputs
     * @param out output matrix block
     * @return the result matrix block
     * @throws DMLRuntimeException if execution fails
     */
    public abstract MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalars, MatrixBlock out) throws DMLRuntimeException;

    /**
     * Variant taking an additional integer {@code k} (presumably the degree of
     * parallelism). The base implementation ignores {@code k} and delegates to
     * {@link #execute(ArrayList, ArrayList, MatrixBlock)}.
     */
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalars, MatrixBlock out, int k) throws DMLRuntimeException {
        return execute(inputs, scalars, out);
    }

    /** Returns a type name of this operator, defined by the subclass. */
    public abstract String getSpoofType();

    /**
     * Scalar-returning execution. Not supported by the base class: subclasses
     * that produce scalars must override it.
     *
     * @throws RuntimeException always, in the base class
     */
    public ScalarObject execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalars) throws DMLRuntimeException {
        throw new RuntimeException("Invalid invocation in base class.");
    }

    /**
     * Scalar-returning variant with an additional integer {@code k}; the base
     * implementation ignores {@code k} and delegates to
     * {@link #execute(ArrayList, ArrayList)}.
     */
    public ScalarObject execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalars, int k) throws DMLRuntimeException {
        return execute(inputs, scalars);
    }

    /** Prepares all inputs from index 1 onward as side inputs, allowing sparse representation. */
    public SideInput[] prepInputMatrices(ArrayList<MatrixBlock> inputs) throws DMLRuntimeException {
        return prepInputMatrices(inputs, 1, inputs.size() - 1, false, false);
    }

    /** Prepares all inputs from index 1 onward; {@code denseOnly} forces dense side inputs. */
    protected SideInput[] prepInputMatrices(ArrayList<MatrixBlock> inputs, boolean denseOnly) throws DMLRuntimeException {
        return prepInputMatrices(inputs, 1, inputs.size() - 1, denseOnly, false);
    }

    /** Prepares all inputs from {@code offset} onward; {@code denseOnly} forces dense side inputs. */
    public SideInput[] prepInputMatrices(ArrayList<MatrixBlock> inputs, int offset, boolean denseOnly) throws DMLRuntimeException {
        return prepInputMatrices(inputs, offset, inputs.size() - offset, denseOnly, false);
    }

    /**
     * Prepares all inputs from index 1 onward; {@code denseOnly} forces dense side
     * inputs and {@code tB1} transposes the input at index 1.
     */
    protected SideInput[] prepInputMatrices(ArrayList<MatrixBlock> inputs, boolean denseOnly, boolean tB1) throws DMLRuntimeException {
        return prepInputMatrices(inputs, 1, inputs.size() - 1, denseOnly, tB1);
    }

    /**
     * Wraps {@code len} inputs starting at {@code offset} as side inputs.
     *
     * <p>Compressed inputs are decompressed and the decompressed block is written
     * back into {@code inputs} (the list is modified). If {@code tB1} is set, the
     * input at absolute list index 1 (not relative to {@code offset}) is transposed
     * into a new block before wrapping.
     *
     * @param inputs matrix inputs (modified in place on decompression)
     * @param offset index of the first input to wrap
     * @param len number of inputs to wrap
     * @param denseOnly if true, sparse or unallocated inputs are converted to temporary dense blocks
     * @param tB1 if true, transpose the input at index 1
     * @return array of {@code len} side inputs, element {@code j} corresponding to {@code inputs.get(offset + j)}
     * @throws DMLRuntimeException if decompression, transposition, or conversion fails
     */
    public SideInput[] prepInputMatrices(ArrayList<MatrixBlock> inputs, int offset, int len, boolean denseOnly, boolean tB1) throws DMLRuntimeException {
        SideInput[] ret = new SideInput[len];
        for (int i = offset; i < offset + len; i++) {
            // decompress and store back so later accesses see the uncompressed block
            if (inputs.get(i) instanceof CompressedMatrixBlock) {
                inputs.set(i, ((CompressedMatrixBlock) inputs.get(i)).decompress());
            }
            MatrixBlock orig = inputs.get(i);
            // NOTE(review): clen is taken from the input *before* the optional
            // transpose below; confirm this is what callers of the transposed
            // side input expect.
            int clen = orig.getNumColumns();
            MatrixBlock in = (tB1 && i == 1)
                ? LibMatrixReorg.transpose(orig, new MatrixBlock(clen, orig.getNumRows(), false))
                : orig;
            ret[i - offset] = createSideInput(in, clen, denseOnly);
        }
        return ret;
    }

    /** Wraps one (already decompressed/transposed) block as a side input. */
    private SideInput createSideInput(MatrixBlock in, int clen, boolean denseOnly) throws DMLRuntimeException {
        boolean sparseOrUnallocated = in.isInSparseFormat() || !in.isAllocated();
        if (!sparseOrUnallocated) {
            return new SideInput(in.getDenseBlock(), null, clen);
        }
        if (!denseOnly) {
            return new SideInput(null, in, clen);
        }
        // dense-only requested: an empty column vector needs no data at all
        if (in.getNumColumns() == 1 && in.isEmptyBlock(false)) {
            return new SideInput(null, null, clen);
        }
        // convert to a separate temporary dense block (the input block itself is not modified)
        SideInput ret = new SideInput(DataConverter.convertToDenseBlock(in, false), null, clen);
        LOG.warn(getClass().getName() + ": Converted " + in.getNumRows() + "x" + in.getNumColumns() + ", nnz=" + in.getNonZeros() + " sideways input matrix from sparse to dense.");
        return ret;
    }

    /** Equivalent to {@code createSparseSideInputs(input, false)} (cell-wise wrappers). */
    public static SideInput[] createSparseSideInputs(SideInput[] input) {
        return createSparseSideInputs(input, false);
    }

    /**
     * Wraps every side input held as a matrix block with more than one column in
     * a sparse accessor: {@link SideInputSparseRow} if {@code row} is true,
     * otherwise {@link SideInputSparseCell}. Other side inputs are kept as is.
     *
     * @param input side inputs
     * @param row whether to use row-wise (instead of cell-wise) wrappers
     * @return {@code input} itself if nothing needs wrapping, otherwise a new array
     */
    public static SideInput[] createSparseSideInputs(SideInput[] input, boolean row) {
        boolean anyToWrap = false;
        for (SideInput in : input) {
            anyToWrap |= needsSparseWrapper(in);
        }
        if (!anyToWrap) {
            return input;
        }
        SideInput[] ret = new SideInput[input.length];
        for (int i = 0; i < input.length; i++) {
            SideInput in = input[i];
            if (!needsSparseWrapper(in)) {
                ret[i] = in;
            } else if (row) {
                ret[i] = new SideInputSparseRow(in);
            } else {
                ret[i] = new SideInputSparseCell(in);
            }
        }
        return ret;
    }

    /** True if the side input is held as a matrix block with more than one column. */
    private static boolean needsSparseWrapper(SideInput in) {
        return in.mdat != null && in.clen > 1;
    }

    /** Returns the dense blocks of the given side inputs ({@code null} entries for non-dense ones). */
    public static DenseBlock[] getDenseMatrices(SideInput[] input) {
        DenseBlock[] ret = new DenseBlock[input.length];
        for (int i = 0; i < input.length; i++) {
            ret[i] = input[i].ddat;
        }
        return ret;
    }

    /** Converts the scalar inputs to an array of their double values, in list order. */
    public static double[] prepInputScalars(ArrayList<ScalarObject> scalars) {
        double[] ret = new double[scalars.size()];
        for (int i = 0; i < scalars.size(); i++) {
            ret[i] = scalars.get(i).getDoubleValue();
        }
        return ret;
    }

    /** Returns the sum of the non-zero counts of all inputs. */
    public static long getTotalInputNnz(ArrayList<MatrixBlock> inputs) {
        return inputs.stream().mapToLong(MatrixBlock::getNonZeros).sum();
    }

    /**
     * Returns the total number of cells (rows x columns) over all inputs.
     * The per-input product is computed in {@code long} so that inputs with more
     * than {@code Integer.MAX_VALUE} cells do not overflow.
     */
    public static long getTotalInputSize(ArrayList<MatrixBlock> inputs) {
        return inputs.stream()
            .mapToLong(mb -> (long) mb.getNumRows() * mb.getNumColumns())
            .sum();
    }

    /** Returns {@code data[toInt(index)]}, or 0 if {@code data} is {@code null}. */
    protected static double getValue(double[] data, double index) {
        return getValue(data, UtilFunctions.toInt(index));
    }

    /** Returns {@code data[index]}, or 0 if {@code data} is {@code null}. */
    protected static double getValue(double[] data, int index) {
        return (data != null) ? data[index] : 0.0d;
    }

    /** Double-index variant of {@link #getValue(double[], int, int, int)}. */
    protected static double getValue(double[] data, int n, double rowIndex, double colIndex) {
        return getValue(data, n, UtilFunctions.toInt(rowIndex), UtilFunctions.toInt(colIndex));
    }

    /**
     * Returns cell ({@code rowIndex}, {@code colIndex}) of a row-major array with
     * {@code n} columns, or 0 if {@code data} is {@code null}.
     */
    protected static double getValue(double[] data, int n, int rowIndex, int colIndex) {
        return (data != null) ? data[rowIndex * n + colIndex] : 0.0d;
    }

    /** Double-index variant of {@link #getValue(double[], int[], int, int, int)}. */
    protected static double getValue(double[] avals, int[] aix, int ai, int alen, double colIndex) {
        return getValue(avals, aix, ai, alen, UtilFunctions.toInt(colIndex));
    }

    /**
     * Looks up column {@code colIndex} in the sparse row stored at
     * {@code aix/avals[ai, ai+alen)} by binary search; the column indexes in that
     * range must be sorted ascending.
     *
     * @return the stored value, or 0 if the column is not present
     */
    protected static double getValue(double[] avals, int[] aix, int ai, int alen, int colIndex) {
        int k = Arrays.binarySearch(aix, ai, ai + alen, colIndex);
        return (k >= 0) ? avals[k] : 0.0d;
    }

    /** Double-index variant of {@link #getValue(SideInput, int)}. */
    protected static double getValue(SideInput data, double rowIndex) {
        return getValue(data, UtilFunctions.toInt(rowIndex));
    }

    /**
     * Returns entry {@code rowIndex} of a column-vector side input: read from the
     * first array of the dense block, else from column 0 of the matrix block,
     * else 0.
     */
    protected static double getValue(SideInput data, int rowIndex) {
        if (data.ddat != null) {
            return data.ddat.valuesAt(0)[rowIndex];
        }
        if (data.mdat != null) {
            return data.mdat.quickGetValue(rowIndex, 0);
        }
        return 0.0d;
    }

    /** Double-index variant of {@link #getValue(SideInput, int, int, int)}. */
    protected static double getValue(SideInput data, int n, double rowIndex, double colIndex) {
        return getValue(data, n, UtilFunctions.toInt(rowIndex), UtilFunctions.toInt(colIndex));
    }

    /**
     * Returns cell ({@code rowIndex}, {@code colIndex}) of a side input, using the
     * dense block, the cell-wise sparse cursor, or the matrix block, in that order
     * of preference; 0 if no data is held. {@code n} is not used here.
     */
    protected static double getValue(SideInput data, int n, int rowIndex, int colIndex) {
        if (data.ddat != null) {
            return data.ddat.get(rowIndex, colIndex);
        }
        if (data instanceof SideInputSparseCell) {
            return ((SideInputSparseCell) data).next(rowIndex, colIndex);
        }
        if (data.mdat != null) {
            return data.mdat.quickGetValue(rowIndex, colIndex);
        }
        return 0.0d;
    }

    /** Double-index variant of {@link #getVector(SideInput, int, int, int)}. */
    protected static double[] getVector(SideInput data, int n, double rowIndex, double colIndex) {
        return getVector(data, n, UtilFunctions.toInt(rowIndex), UtilFunctions.toInt(colIndex));
    }

    /**
     * Copies the first {@code colIndex + 1} values of row {@code rowIndex} into a
     * vector obtained from {@code LibSpoofPrimitives.allocVector}. Requires
     * {@code data.values(rowIndex)} to be non-null (dense or row-wise sparse side
     * input). {@code n} is not used here.
     */
    protected static double[] getVector(SideInput data, int n, int rowIndex, int colIndex) {
        int len = colIndex + 1;
        double[] ret = LibSpoofPrimitives.allocVector(len, false);
        System.arraycopy(data.values(rowIndex), data.pos(rowIndex), ret, 0, len);
        return ret;
    }
}
