package org.apache.sysml.runtime.matrix;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import org.apache.sysml.lops.MMTSJ;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.AggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.AggregateUnaryInstruction;
import org.apache.sysml.runtime.instructions.mr.AppendInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryMInstruction;
import org.apache.sysml.runtime.instructions.mr.BinaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.CM_N_COVInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineUnaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CtableInstruction;
import org.apache.sysml.runtime.instructions.mr.CumulativeAggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.DataGenMRInstruction;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateMInstruction;
import org.apache.sysml.runtime.instructions.mr.MMTSJMRInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction;
import org.apache.sysml.runtime.instructions.mr.MatrixReshapeMRInstruction;
import org.apache.sysml.runtime.instructions.mr.PMMJMRInstruction;
import org.apache.sysml.runtime.instructions.mr.ParameterizedBuiltinMRInstruction;
import org.apache.sysml.runtime.instructions.mr.QuaternaryInstruction;
import org.apache.sysml.runtime.instructions.mr.RandInstruction;
import org.apache.sysml.runtime.instructions.mr.RangeBasedReIndexInstruction;
import org.apache.sysml.runtime.instructions.mr.ReblockInstruction;
import org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction;
import org.apache.sysml.runtime.instructions.mr.ReorgInstruction;
import org.apache.sysml.runtime.instructions.mr.ReplicateInstruction;
import org.apache.sysml.runtime.instructions.mr.ScalarInstruction;
import org.apache.sysml.runtime.instructions.mr.SeqInstruction;
import org.apache.sysml.runtime.instructions.mr.TernaryInstruction;
import org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction;
import org.apache.sysml.runtime.instructions.mr.UnaryInstruction;
import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase;
import org.apache.sysml.runtime.instructions.mr.ZeroOutInstruction;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/MatrixCharacteristics.class */
/**
 * Metadata of a matrix: its dimensions, the block size used to partition it, and its number
 * of non-zero values.
 *
 * <p>Negative row/column/non-zero counts mean "unknown". The non-zero field may alternatively
 * hold an upper bound (see {@link #setNonZerosBound(long)}); in that case
 * {@link #getNonZeros()} reports {@code -1} and {@link #nnzKnown()} is false.
 */
public class MatrixCharacteristics implements Serializable {
    private static final long serialVersionUID = 8300479822915546000L;

    /** Number of rows; negative if unknown. */
    private long numRows = -1;
    /** Number of columns; negative if unknown. */
    private long numColumns = -1;
    /** Rows per block. */
    private int numRowsPerBlock = 1;
    /** Columns per block. */
    private int numColumnsPerBlock = 1;
    /** Exact non-zero count, or an upper bound if {@link #ubNnz} is set; negative if unknown. */
    private long nonZero = -1;
    /** True if {@link #nonZero} is only an upper bound rather than an exact count. */
    private boolean ubNnz = false;

    /** Creates characteristics with unknown dimensions/non-zeros and 1x1 blocks. */
    public MatrixCharacteristics() {
    }

    /**
     * Creates characteristics with the given dimensions and block sizes; non-zeros stay unknown.
     *
     * @param nr number of rows
     * @param nc number of columns
     * @param bnr rows per block
     * @param bnc columns per block
     */
    public MatrixCharacteristics(long nr, long nc, int bnr, int bnc) {
        set(nr, nc, bnr, bnc);
    }

    /**
     * Creates characteristics with the given dimensions, block sizes and exact non-zero count.
     *
     * @param nr number of rows
     * @param nc number of columns
     * @param bnr rows per block
     * @param bnc columns per block
     * @param nnz exact number of non-zeros
     */
    public MatrixCharacteristics(long nr, long nc, int bnr, int bnc, long nnz) {
        set(nr, nc, bnr, bnc, nnz);
    }

    /**
     * Copy constructor. Copies every field, including the upper-bound flag, so a non-zero
     * upper bound is not turned into an exact count by copying.
     *
     * @param that characteristics to copy
     */
    public MatrixCharacteristics(MatrixCharacteristics that) {
        set(that);
    }

    /**
     * Sets dimensions and block sizes; the non-zero information is left untouched.
     */
    public void set(long nr, long nc, int bnr, int bnc) {
        this.numRows = nr;
        this.numColumns = nc;
        this.numRowsPerBlock = bnr;
        this.numColumnsPerBlock = bnc;
    }

    /**
     * Sets dimensions, block sizes and an exact non-zero count (clears the upper-bound flag).
     */
    public void set(long nr, long nc, int bnr, int bnc, long nnz) {
        this.numRows = nr;
        this.numColumns = nc;
        this.numRowsPerBlock = bnr;
        this.numColumnsPerBlock = bnc;
        this.nonZero = nnz;
        this.ubNnz = false;
    }

    /** Copies all fields (including the upper-bound flag) from {@code that}. */
    public void set(MatrixCharacteristics that) {
        this.numRows = that.numRows;
        this.numColumns = that.numColumns;
        this.numRowsPerBlock = that.numRowsPerBlock;
        this.numColumnsPerBlock = that.numColumnsPerBlock;
        this.nonZero = that.nonZero;
        this.ubNnz = that.ubNnz;
    }

    public long getRows() {
        return this.numRows;
    }

    public long getCols() {
        return this.numColumns;
    }

    /** Returns {@code rows * cols} (not meaningful if either dimension is unknown). */
    public long getLength() {
        return this.numRows * this.numColumns;
    }

    public int getRowsPerBlock() {
        return this.numRowsPerBlock;
    }

    public void setRowsPerBlock(int brlen) {
        this.numRowsPerBlock = brlen;
    }

    public int getColsPerBlock() {
        return this.numColumnsPerBlock;
    }

    public void setColsPerBlock(int bclen) {
        this.numColumnsPerBlock = bclen;
    }

    /** Returns the total number of blocks (row blocks times column blocks). */
    public long getNumBlocks() {
        return getNumRowBlocks() * getNumColBlocks();
    }

    /**
     * Returns the number of row blocks, i.e. {@code ceil(rows / rowsPerBlock)}, but at least 1.
     *
     * @throws ArithmeticException if the rows-per-block value is zero
     */
    public long getNumRowBlocks() {
        return Math.max(ceilDiv(getRows(), getRowsPerBlock()), 1L);
    }

    /**
     * Returns the number of column blocks, i.e. {@code ceil(cols / colsPerBlock)}, but at least 1.
     *
     * @throws ArithmeticException if the columns-per-block value is zero
     */
    public long getNumColBlocks() {
        return Math.max(ceilDiv(getCols(), getColsPerBlock()), 1L);
    }

    /**
     * Integer ceiling of {@code numerator / denominator}, exact for a non-negative numerator and
     * a positive denominator. Plain {@code long / int} division truncates, so applying
     * {@code Math.ceil} to its result has no effect.
     *
     * @throws ArithmeticException if {@code denominator} is zero
     */
    private static long ceilDiv(long numerator, int denominator) {
        return (numerator + denominator - 1) / denominator;
    }

    /**
     * Adds two dimension sizes, propagating "unknown": if either operand is negative the result
     * is {@code -1} rather than a misleading (possibly non-negative) sum.
     */
    private static long addDims(long a, long b) {
        return (a < 0 || b < 0) ? -1 : a + b;
    }

    @Override
    public String toString() {
        return "[" + this.numRows + " x " + this.numColumns + ", nnz=" + this.nonZero + " (" + this.ubNnz + "), blocks (" + this.numRowsPerBlock + " x " + this.numColumnsPerBlock + ")]";
    }

    public void setDimension(long nr, long nc) {
        this.numRows = nr;
        this.numColumns = nc;
    }

    /** Sets a square block size. */
    public void setBlockSize(int blen) {
        setBlockSize(blen, blen);
    }

    public void setBlockSize(int bnr, int bnc) {
        this.numRowsPerBlock = bnr;
        this.numColumnsPerBlock = bnc;
    }

    /** Sets an exact non-zero count (clears the upper-bound flag). */
    public void setNonZeros(long nnz) {
        this.ubNnz = false;
        this.nonZero = nnz;
    }

    /** Returns the exact non-zero count, or {@code -1} if only an upper bound is stored. */
    public long getNonZeros() {
        if (this.ubNnz) {
            return -1L;
        }
        return this.nonZero;
    }

    /** Stores {@code nnz} as an upper bound on the number of non-zeros. */
    public void setNonZerosBound(long nnz) {
        this.ubNnz = true;
        this.nonZero = nnz;
    }

    /** Returns the stored non-zero value regardless of whether it is exact or a bound. */
    public long getNonZerosBound() {
        return this.nonZero;
    }

    /** Returns true if both rows and columns are known (non-negative). */
    public boolean dimsKnown() {
        return this.numRows >= 0 && this.numColumns >= 0;
    }

    /**
     * Returns true if both dimensions are known and, when {@code includeNnz} is set, the exact
     * non-zero count is known too.
     */
    public boolean dimsKnown(boolean includeNnz) {
        return this.numRows >= 0 && this.numColumns >= 0 && (!includeNnz || nnzKnown());
    }

    public boolean rowsKnown() {
        return this.numRows >= 0;
    }

    public boolean colsKnown() {
        return this.numColumns >= 0;
    }

    /** Returns true if an exact (not upper-bound) non-negative non-zero count is stored. */
    public boolean nnzKnown() {
        return !this.ubNnz && this.nonZero >= 0;
    }

    /**
     * Conservatively reports whether some blocks might be empty. Returns false only if the exact
     * non-zero count is known, both dimensions are non-zero, and the non-zero count is at least
     * the total cell count minus the cell count of one block (block dims clipped to the matrix
     * dims, each at least 1).
     */
    public boolean mightHaveEmptyBlocks() {
        return !nnzKnown() || this.numRows == 0 || this.numColumns == 0 || this.nonZero < (this.numRows * this.numColumns) - (Math.max(Math.min(this.numRows, (long) this.numRowsPerBlock), 1L) * Math.max(Math.min(this.numColumns, (long) this.numColumnsPerBlock), 1L));
    }

    /** Computes output dimensions of a reorg operation by delegating to the operator's index function. */
    public static void reorg(MatrixCharacteristics dim, ReorgOperator op, MatrixCharacteristics dimOut) throws DMLRuntimeException {
        op.fn.computeDimension(dim, dimOut);
    }

    /** Computes output dimensions of a unary aggregate by delegating to the operator's index function. */
    public static void aggregateUnary(MatrixCharacteristics dim, AggregateUnaryOperator op, MatrixCharacteristics dimOut) throws DMLRuntimeException {
        op.indexFn.computeDimension(dim, dimOut);
    }

    /**
     * Sets {@code dimOut} to {@code dim1.rows x dim2.cols} with blocks
     * {@code dim1.rowsPerBlock x dim2.colsPerBlock}. The operator is not consulted and the
     * output non-zero information is left untouched.
     */
    public static void aggregateBinary(MatrixCharacteristics dim1, MatrixCharacteristics dim2, AggregateBinaryOperator op, MatrixCharacteristics dimOut) {
        dimOut.set(dim1.numRows, dim2.numColumns, dim1.numRowsPerBlock, dim2.numColumnsPerBlock);
    }

    /**
     * Derives the characteristics of {@code ins}'s output from those of its inputs. All
     * characteristics live in {@code dims}, keyed by matrix index; the output entry is created if
     * absent and updated in place. Unrecognized instruction types get unknown dimensions and
     * 1x1 blocks (their non-zero information is left as is).
     *
     * @param dims characteristics per matrix index (inputs must be present)
     * @param ins the MR instruction whose output is computed
     */
    public static void computeDimension(HashMap<Byte, MatrixCharacteristics> dims, MRInstruction ins) throws DMLRuntimeException {
        MatrixCharacteristics dimOut = dims.get(Byte.valueOf(ins.output));
        if (dimOut == null) {
            dimOut = new MatrixCharacteristics();
            dims.put(Byte.valueOf(ins.output), dimOut);
        }
        if (ins instanceof ReorgInstruction) {
            ReorgInstruction realIns = (ReorgInstruction) ins;
            reorg(dims.get(Byte.valueOf(realIns.input)), (ReorgOperator) realIns.getOperator(), dimOut);
            return;
        }
        if (ins instanceof AppendInstruction) {
            AppendInstruction realIns = (AppendInstruction) ins;
            MatrixCharacteristics in1 = dims.get(Byte.valueOf(realIns.input1));
            MatrixCharacteristics in2 = dims.get(Byte.valueOf(realIns.input2));
            // addDims keeps an unknown (negative) operand from yielding a bogus known size
            if (realIns.isCBind()) {
                dimOut.set(in1.numRows, addDims(in1.numColumns, in2.numColumns), in1.numRowsPerBlock, in2.numColumnsPerBlock);
            } else {
                dimOut.set(addDims(in1.numRows, in2.numRows), in1.numColumns, in1.numRowsPerBlock, in2.numColumnsPerBlock);
            }
            return;
        }
        if (ins instanceof CumulativeAggregateInstruction) {
            MatrixCharacteristics in = dims.get(Byte.valueOf(((AggregateUnaryInstruction) ins).input));
            // one output row per input row block; stays unknown if the input rows are unknown
            long rows = in.rowsKnown() ? ceilDiv(in.getRows(), in.getRowsPerBlock()) : -1L;
            dimOut.set(rows, in.getCols(), in.getRowsPerBlock(), in.getColsPerBlock());
            return;
        }
        if (ins instanceof AggregateUnaryInstruction) {
            AggregateUnaryInstruction realIns = (AggregateUnaryInstruction) ins;
            aggregateUnary(dims.get(Byte.valueOf(realIns.input)), (AggregateUnaryOperator) realIns.getOperator(), dimOut);
            return;
        }
        if (ins instanceof AggregateBinaryInstruction) {
            AggregateBinaryInstruction realIns = (AggregateBinaryInstruction) ins;
            aggregateBinary(dims.get(Byte.valueOf(realIns.input1)), dims.get(Byte.valueOf(realIns.input2)), (AggregateBinaryOperator) realIns.getOperator(), dimOut);
            return;
        }
        if (ins instanceof MapMultChainInstruction) {
            MapMultChainInstruction realIns = (MapMultChainInstruction) ins;
            MatrixCharacteristics in1 = dims.get(Byte.valueOf(realIns.getInput1()));
            MatrixCharacteristics in2 = dims.get(Byte.valueOf(realIns.getInput2()));
            dimOut.set(in1.numColumns, in2.numColumns, in1.numRowsPerBlock, in1.numColumnsPerBlock);
            return;
        }
        if (ins instanceof QuaternaryInstruction) {
            QuaternaryInstruction realIns = (QuaternaryInstruction) ins;
            realIns.computeMatrixCharacteristics(dims.get(Byte.valueOf(realIns.getInput1())), dims.get(Byte.valueOf(realIns.getInput2())), dims.get(Byte.valueOf(realIns.getInput3())), dimOut);
            return;
        }
        if (ins instanceof ReblockInstruction) {
            ReblockInstruction realIns = (ReblockInstruction) ins;
            MatrixCharacteristics in = dims.get(Byte.valueOf(realIns.input));
            dimOut.set(in.numRows, in.numColumns, realIns.brlen, realIns.bclen, in.nonZero);
            // the input's non-zero value is carried over as is; keep its bound flag so an
            // upper bound is not reported as an exact count
            dimOut.ubNnz = in.ubNnz;
            return;
        }
        if (ins instanceof MatrixReshapeMRInstruction) {
            MatrixReshapeMRInstruction realIns = (MatrixReshapeMRInstruction) ins;
            MatrixCharacteristics in = dims.get(Byte.valueOf(realIns.input));
            dimOut.set(realIns.getNumRows(), realIns.getNumColunms(), in.getRowsPerBlock(), in.getColsPerBlock(), in.getNonZeros());
            return;
        }
        if ((ins instanceof RandInstruction) || (ins instanceof SeqInstruction)) {
            dimOut.set(dims.get(Byte.valueOf(((DataGenMRInstruction) ins).getInput())));
            return;
        }
        if (ins instanceof ReplicateInstruction) {
            ReplicateInstruction realIns = (ReplicateInstruction) ins;
            realIns.computeOutputDimension(dims.get(Byte.valueOf(realIns.input)), dimOut);
            return;
        }
        if (ins instanceof ParameterizedBuiltinMRInstruction) {
            ParameterizedBuiltinMRInstruction realIns = (ParameterizedBuiltinMRInstruction) ins;
            realIns.computeOutputCharacteristics(dims.get(Byte.valueOf(realIns.input)), dimOut);
            return;
        }
        if ((ins instanceof ScalarInstruction) || (ins instanceof AggregateInstruction) || ((ins instanceof UnaryInstruction) && !(ins instanceof MMTSJMRInstruction)) || (ins instanceof ZeroOutInstruction)) {
            // shape-preserving: output copies the input characteristics
            dimOut.set(dims.get(Byte.valueOf(((UnaryMRInstructionBase) ins).input)));
            return;
        }
        if (ins instanceof MMTSJMRInstruction) {
            MMTSJMRInstruction realIns = (MMTSJMRInstruction) ins;
            MMTSJ.MMTSJType type = realIns.getMMTSJType();
            MatrixCharacteristics in = dims.get(Byte.valueOf(realIns.input));
            // square output: cols x cols for the left variant, rows x rows otherwise
            long dim = type.isLeft() ? in.numColumns : in.numRows;
            dimOut.set(dim, dim, in.numRowsPerBlock, in.numColumnsPerBlock);
            return;
        }
        if (ins instanceof PMMJMRInstruction) {
            PMMJMRInstruction realIns = (PMMJMRInstruction) ins;
            MatrixCharacteristics in2 = dims.get(Byte.valueOf(realIns.input2));
            dimOut.set(realIns.getNumRows(), in2.numColumns, in2.numRowsPerBlock, in2.numColumnsPerBlock);
            return;
        }
        if (ins instanceof RemoveEmptyMRInstruction) {
            RemoveEmptyMRInstruction realIns = (RemoveEmptyMRInstruction) ins;
            MatrixCharacteristics in = dims.get(Byte.valueOf(realIns.input1));
            // with empty-return enabled the reduced dimension is at least 1
            long minLen = realIns.isEmptyReturn() ? 1L : 0L;
            long outLen = Math.max(realIns.getOutputLen(), minLen);
            if (realIns.isRemoveRows()) {
                dimOut.set(outLen, in.getCols(), in.numRowsPerBlock, in.numColumnsPerBlock);
            } else {
                dimOut.set(in.getRows(), outLen, in.numRowsPerBlock, in.numColumnsPerBlock);
            }
            return;
        }
        if (ins instanceof UaggOuterChainInstruction) {
            UaggOuterChainInstruction realIns = (UaggOuterChainInstruction) ins;
            realIns.computeOutputCharacteristics(dims.get(Byte.valueOf(realIns.input1)), dims.get(Byte.valueOf(realIns.input2)), dimOut);
            return;
        }
        if (ins instanceof GroupedAggregateMInstruction) {
            GroupedAggregateMInstruction realIns = (GroupedAggregateMInstruction) ins;
            realIns.computeOutputCharacteristics(dims.get(Byte.valueOf(realIns.input1)), dimOut);
            return;
        }
        if ((ins instanceof BinaryInstruction) || (ins instanceof BinaryMInstruction) || (ins instanceof CombineBinaryInstruction)) {
            BinaryMRInstructionBase realIns = (BinaryMRInstructionBase) ins;
            MatrixCharacteristics in1 = dims.get(Byte.valueOf(realIns.input1));
            MatrixCharacteristics in2 = dims.get(Byte.valueOf(realIns.input2));
            // column vector (rows > 1) op row vector (cols > 1) yields an in1.rows x in2.cols
            // output; every other combination takes the first input's characteristics
            boolean outerVectors = in1.getRows() > 1 && in1.getCols() == 1 && in2.getRows() == 1 && in2.getCols() > 1;
            if (outerVectors) {
                dimOut.set(in1.getRows(), in2.getCols(), in1.getRowsPerBlock(), in2.getColsPerBlock());
            } else {
                dimOut.set(in1);
            }
            return;
        }
        if (ins instanceof TernaryInstruction) {
            dimOut.set(dims.get(Byte.valueOf(ins.getInputIndexes()[0])));
            return;
        }
        if (ins instanceof CombineTernaryInstruction) {
            dimOut.set(dims.get(Byte.valueOf(((CtableInstruction) ins).input1)));
            return;
        }
        if (ins instanceof CombineUnaryInstruction) {
            dimOut.set(dims.get(Byte.valueOf(((CombineUnaryInstruction) ins).input)));
            return;
        }
        if ((ins instanceof CM_N_COVInstruction) || (ins instanceof GroupedAggregateInstruction)) {
            dimOut.set(1L, 1L, 1, 1);
            return;
        }
        if (ins instanceof RangeBasedReIndexInstruction) {
            RangeBasedReIndexInstruction realIns = (RangeBasedReIndexInstruction) ins;
            realIns.computeOutputCharacteristics(dims.get(Byte.valueOf(realIns.input)), dimOut);
        } else if (ins instanceof CtableInstruction) {
            CtableInstruction realIns = (CtableInstruction) ins;
            MatrixCharacteristics in1 = dims.get(Byte.valueOf(realIns.input1));
            dimOut.set(realIns.getOutputDim1(), realIns.getOutputDim2(), in1.numRowsPerBlock, in1.numColumnsPerBlock);
        } else {
            // unrecognized instruction: dimensions unknown, default 1x1 blocks
            dimOut.numRows = -1L;
            dimOut.numColumns = -1L;
            dimOut.numRowsPerBlock = 1;
            dimOut.numColumnsPerBlock = 1;
        }
    }

    /**
     * Compares dimensions, block sizes and the stored non-zero value. The upper-bound flag is
     * not compared (consistent with {@link #hashCode()}).
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof MatrixCharacteristics)) {
            return false;
        }
        MatrixCharacteristics that = (MatrixCharacteristics) obj;
        return this.numRows == that.numRows && this.numColumns == that.numColumns && this.numRowsPerBlock == that.numRowsPerBlock && this.numColumnsPerBlock == that.numColumnsPerBlock && this.nonZero == that.nonZero;
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(new long[]{this.numRows, this.numColumns, this.numRowsPerBlock, this.numColumnsPerBlock, this.nonZero});
    }
}
