package org.apache.sysml.hops;

import java.util.Iterator;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Data;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.RangeBasedReIndex;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysml/hops/IndexingOp.class */
/**
 * Hop for right indexing ({@code rix}), i.e. selecting the sub-range {@code X[rl:ru, cl:cu]}.
 *
 * <p>Inputs, in order: 0 = the indexed hop, 1 = row lower bound, 2 = row upper bound,
 * 3 = column lower bound, 4 = column upper bound. The two flags record whether the row
 * (respectively column) lower and upper bounds are equal, i.e. exactly one row (column) is
 * selected.
 */
public class IndexingOp extends Hop {
    public static String OPSTRING = "rix";

    /** True if the row lower and upper bounds are equal (output has a single row). */
    private boolean _rowLowerEqualsUpper;
    /** True if the column lower and upper bounds are equal (output has a single column). */
    private boolean _colLowerEqualsUpper;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/hops/IndexingOp$IndexingMethod.class */
    /**
     * Physical indexing strategies. {@link #optFindIndexingMethod} chooses between
     * {@code MR_RIX} and {@code MR_VRIX}; {@code CP_RIX} is not selected within this class.
     */
    public enum IndexingMethod {
        CP_RIX,
        MR_RIX,
        MR_VRIX
    }

    /** Creates an empty hop; used only by {@link #clone()}. */
    private IndexingOp() {
        this._rowLowerEqualsUpper = false;
        this._colLowerEqualsUpper = false;
    }

    /**
     * Creates an indexing hop and wires it into the hop DAG as parent of all five inputs.
     *
     * @param str       hop name
     * @param dataType  output data type
     * @param valueType output value type
     * @param hop       the indexed input
     * @param hop2      row lower bound
     * @param hop3      row upper bound
     * @param hop4      column lower bound
     * @param hop5      column upper bound
     * @param z         whether the row lower and upper bounds are equal
     * @param z2        whether the column lower and upper bounds are equal
     */
    public IndexingOp(String str, Expression.DataType dataType, Expression.ValueType valueType, Hop hop, Hop hop2, Hop hop3, Hop hop4, Hop hop5, boolean z, boolean z2) {
        super(str, dataType, valueType);
        this._rowLowerEqualsUpper = false;
        this._colLowerEqualsUpper = false;
        getInput().add(0, hop);
        getInput().add(1, hop2);
        getInput().add(2, hop3);
        getInput().add(3, hop4);
        getInput().add(4, hop5);
        hop.getParent().add(this);
        hop2.getParent().add(this);
        hop3.getParent().add(this);
        hop4.getParent().add(this);
        hop5.getParent().add(this);
        setRowLowerEqualsUpper(z);
        setColLowerEqualsUpper(z2);
    }

    /** @return whether the row lower and upper bounds are equal */
    public boolean getRowLowerEqualsUpper() {
        return this._rowLowerEqualsUpper;
    }

    /** @return whether the column lower and upper bounds are equal */
    public boolean getColLowerEqualsUpper() {
        return this._colLowerEqualsUpper;
    }

    /** @param z whether the row lower and upper bounds are equal */
    public void setRowLowerEqualsUpper(boolean z) {
        this._rowLowerEqualsUpper = z;
    }

    /** @param z whether the column lower and upper bounds are equal */
    public void setColLowerEqualsUpper(boolean z) {
        this._colLowerEqualsUpper = z;
    }

    /**
     * Constructs the lop for this hop, or returns the already constructed one.
     *
     * <p>If both this hop and its input have known, identical dimensions, the range covers the
     * whole input. In that case the input's lop is reused and no indexing lop is generated.
     *
     * <p>Otherwise a {@link RangeBasedReIndex} lop is built for the selected execution type:
     * <ul>
     *   <li>MR with {@link IndexingMethod#MR_RIX}: the lop is followed by a sort {@link Group}
     *       and a sum {@link Aggregate}.</li>
     *   <li>Spark: the aggregation type is {@code NONE} for {@link IndexingMethod#MR_VRIX} and
     *       {@code MULTI_BLOCK} otherwise.</li>
     * </ul>
     *
     * @return the lop representing this hop
     * @throws HopsException if building the indexing lops fails (the cause is wrapped)
     * @throws LopsException as declared by {@link Hop#constructLops()}
     */
    @Override // org.apache.sysml.hops.Hop
    public Lop constructLops() throws HopsException, LopsException {
        if (getLops() != null) {
            return getLops();
        }
        Hop input = getInput().get(0);
        if (dimsKnown() && input.dimsKnown() && getDim1() == input.getDim1() && getDim2() == input.getDim2()) {
            // Output is as large as the input, so the indexing selects everything: reuse input lops.
            setLops(input.constructLops());
        } else {
            try {
                LopProperties.ExecType et = optFindExecType();
                if (et == LopProperties.ExecType.MR) {
                    IndexingMethod method = optFindIndexingMethod(this._rowLowerEqualsUpper, this._colLowerEqualsUpper,
                            input._dim1, input._dim2, this._dim1, this._dim2);
                    // A -1 literal is passed for the two trailing size arguments of RangeBasedReIndex.
                    // NOTE(review): presumably means "unknown" -- confirm against RangeBasedReIndex.
                    Data unknownSize = Data.createLiteralLop(Expression.ValueType.INT, Integer.toString(-1));
                    RangeBasedReIndex reindex = new RangeBasedReIndex(input.constructLops(),
                            getInput().get(1).constructLops(), getInput().get(2).constructLops(),
                            getInput().get(3).constructLops(), getInput().get(4).constructLops(),
                            unknownSize, unknownSize, getDataType(), getValueType(), et);
                    setOutputDimensions(reindex);
                    setLineNumbers(reindex);
                    if (method == IndexingMethod.MR_RIX) {
                        // General MR case: group (sort) and sum the reindexed blocks.
                        Group group = new Group(reindex, Group.OperationTypes.Sort, Expression.DataType.MATRIX, getValueType());
                        setOutputDimensions(group);
                        setLineNumbers(group);
                        Aggregate agg = new Aggregate(group, Aggregate.OperationTypes.Sum, Expression.DataType.MATRIX, getValueType(), et);
                        setOutputDimensions(agg);
                        setLineNumbers(agg);
                        setLops(agg);
                    } else {
                        setLops(reindex);
                    }
                } else if (et == LopProperties.ExecType.SPARK) {
                    IndexingMethod method = optFindIndexingMethod(this._rowLowerEqualsUpper, this._colLowerEqualsUpper,
                            input._dim1, input._dim2, this._dim1, this._dim2);
                    AggBinaryOp.SparkAggType aggType = (method == IndexingMethod.MR_VRIX)
                            ? AggBinaryOp.SparkAggType.NONE
                            : AggBinaryOp.SparkAggType.MULTI_BLOCK;
                    Data unknownSize = Data.createLiteralLop(Expression.ValueType.INT, Integer.toString(-1));
                    RangeBasedReIndex reindex = new RangeBasedReIndex(input.constructLops(),
                            getInput().get(1).constructLops(), getInput().get(2).constructLops(),
                            getInput().get(3).constructLops(), getInput().get(4).constructLops(),
                            unknownSize, unknownSize, getDataType(), getValueType(), aggType, et);
                    setOutputDimensions(reindex);
                    setLineNumbers(reindex);
                    setLops(reindex);
                } else {
                    Data unknownSize = Data.createLiteralLop(Expression.ValueType.INT, Integer.toString(-1));
                    RangeBasedReIndex reindex = new RangeBasedReIndex(input.constructLops(),
                            getInput().get(1).constructLops(), getInput().get(2).constructLops(),
                            getInput().get(3).constructLops(), getInput().get(4).constructLops(),
                            unknownSize, unknownSize, getDataType(), getValueType(), et);
                    setOutputDimensions(reindex);
                    setLineNumbers(reindex);
                    setLops(reindex);
                }
            } catch (Exception e) {
                throw new HopsException(printErrorLocation() + "In IndexingOp Hop, error constructing Lops ", e);
            }
        }
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /** @return the operator string of this hop ({@link #OPSTRING}) */
    @Override // org.apache.sysml.hops.Hop
    public String getOpString() {
        // String.valueOf keeps the previous "" + OPSTRING semantics (yields "null" if unset).
        return String.valueOf(OPSTRING);
    }

    /**
     * Prints this hop and, recursively, its inputs once; marks the hop as visited afterwards.
     *
     * @throws HopsException if printing an input fails
     */
    @Override // org.apache.sysml.hops.Hop
    public void printMe() throws HopsException {
        if (getVisited() != Hop.VisitStatus.DONE) {
            super.printMe();
            for (Hop in : getInput()) {
                in.printMe();
            }
        }
        setVisited(Hop.VisitStatus.DONE);
    }

    /** @return {@code true}; indexing can run under every execution type */
    @Override // org.apache.sysml.hops.Hop
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Computes the default memory estimate and then tries to tighten it. When the output
     * dimensions are known, the input's non-zero count is used as a worst-case output nnz, and
     * the resulting output estimate replaces the current one if it is smaller.
     *
     * @param memoTable memo of input statistics
     */
    @Override // org.apache.sysml.hops.Hop
    public void computeMemEstimate(MemoTable memoTable) {
        super.computeMemEstimate(memoTable);
        MatrixCharacteristics inStats = memoTable.getAllInputStats(getInput().get(0));
        // Guard against missing stats, consistent with inferOutputCharacteristics.
        if (inStats == null || !dimsKnown() || inStats.getNonZeros() < 0) {
            return;
        }
        double outEstimate = computeOutputMemEstimate(this._dim1, this._dim2, inStats.getNonZeros());
        if (outEstimate < this._outputMemEstimate) {
            this._outputMemEstimate = outEstimate;
            this._memEstimate = getInputOutputSize();
        }
    }

    /**
     * Estimates the output size for the given dimensions and non-zero count.
     *
     * @param j  number of rows
     * @param j2 number of columns
     * @param j3 number of non-zeros
     * @return estimated output size via {@code OptimizerUtils.estimateSizeExactSparsity}
     */
    @Override // org.apache.sysml.hops.Hop
    protected double computeOutputMemEstimate(long j, long j2, long j3) {
        return OptimizerUtils.estimateSizeExactSparsity(j, j2, OptimizerUtils.getSparsity(j, j2, j3));
    }

    /**
     * Indexing accounts for no intermediate memory.
     *
     * @return always {@code 0}
     */
    @Override // org.apache.sysml.hops.Hop
    protected double computeIntermediateMemEstimate(long j, long j2, long j3) {
        return 0;
    }

    /**
     * Infers worst-case output characteristics from the indexed input's memoized statistics.
     *
     * @param memoTable memo of input statistics
     * @return {@code {rows, cols, nnz}}, where nnz is an upper bound (or -1 if input dims are
     *         unknown), or {@code null} if no input statistics are available
     */
    @Override // org.apache.sysml.hops.Hop
    protected long[] inferOutputCharacteristics(MemoTable memoTable) {
        MatrixCharacteristics inStats = memoTable.getAllInputStats(getInput().get(0));
        if (inStats == null) {
            return null;
        }
        long nnzBound = inStats.dimsKnown()
                ? Math.min(inStats.getRows() * inStats.getCols(), inStats.getNonZeros())
                : -1L;
        long[] ret = new long[]{inStats.getRows(), inStats.getCols(), nnzBound};
        if (this._rowLowerEqualsUpper) {
            ret[0] = 1;
        }
        if (this._colLowerEqualsUpper) {
            ret[1] = 1;
        }
        Hop rowLower = getInput().get(1);
        Hop rowUpper = getInput().get(2);
        Hop colLower = getInput().get(3);
        Hop colUpper = getInput().get(4);
        if (isBlockIndexingExpression(rowLower, rowUpper)) {
            ret[0] = getBlockIndexingExpressionSize(rowLower, rowUpper);
        }
        if (isBlockIndexingExpression(colLower, colUpper)) {
            ret[1] = getBlockIndexingExpressionSize(colLower, colUpper);
        }
        return ret;
    }

    /**
     * Checks whether the bounds form the block-indexing pattern
     * {@code ((v - 1) * k) + 1 : k * v}, with the same literal {@code k} and the same variable
     * {@code v} on both sides. Such a range always selects exactly {@code k} rows/columns.
     *
     * @param hop  lower bound
     * @param hop2 upper bound
     * @return true if the pattern matches
     */
    private boolean isBlockIndexingExpression(Hop hop, Hop hop2) {
        // Lower bound: PLUS(MULT(k, MINUS(v, 1)), 1)
        if (!(hop instanceof BinaryOp) || ((BinaryOp) hop).getOp() != Hop.OpOp2.PLUS
                || !(hop.getInput().get(1) instanceof LiteralOp)
                || HopRewriteUtils.getDoubleValueSafe((LiteralOp) hop.getInput().get(1)) != 1.0d
                || !(hop.getInput().get(0) instanceof BinaryOp)) {
            return false;
        }
        BinaryOp mult = (BinaryOp) hop.getInput().get(0);
        if (mult.getOp() != Hop.OpOp2.MULT || !(mult.getInput().get(0) instanceof LiteralOp)
                || !(mult.getInput().get(1) instanceof BinaryOp)) {
            return false;
        }
        BinaryOp minus = (BinaryOp) mult.getInput().get(1);
        if (minus.getOp() != Hop.OpOp2.MINUS || !(minus.getInput().get(1) instanceof LiteralOp)
                || HopRewriteUtils.getDoubleValueSafe((LiteralOp) minus.getInput().get(1)) != 1.0d
                || !(minus.getInput().get(0) instanceof DataOp)) {
            return false;
        }
        LiteralOp blockSize = (LiteralOp) mult.getInput().get(0);
        DataOp var = (DataOp) minus.getInput().get(0);

        // Upper bound: MULT(k, v). The operator must be checked, otherwise e.g. k + v would match.
        return (hop2 instanceof BinaryOp)
                && ((BinaryOp) hop2).getOp() == Hop.OpOp2.MULT
                && (hop2.getInput().get(0) instanceof LiteralOp)
                && (hop2.getInput().get(1) instanceof DataOp)
                && hop2.getInput().get(1).getName().equals(var.getName())
                && HopRewriteUtils.getDoubleValueSafe(blockSize)
                        == HopRewriteUtils.getDoubleValueSafe((LiteralOp) hop2.getInput().get(0));
    }

    /**
     * Returns the block size {@code k} of a range already accepted by
     * {@link #isBlockIndexingExpression}, read from the upper bound's literal operand.
     *
     * @param hop  lower bound (unused)
     * @param hop2 upper bound {@code k * v}
     * @return the block size {@code k}
     */
    private long getBlockIndexingExpressionSize(Hop hop, Hop hop2) {
        return HopRewriteUtils.getIntValueSafe((LiteralOp) hop2.getInput().get(0));
    }

    /**
     * Selects the execution type.
     *
     * <p>A forced platform takes precedence. Otherwise the choice is, in order:
     * <ul>
     *   <li>the memory-based decision, if the optimizer level is memory based;</li>
     *   <li>CP, if the input's dimensions are below threshold;</li>
     *   <li>the distributed backend (SPARK in Spark execution mode, else MR).</li>
     * </ul>
     * Recompilation is requested if the result is the distributed backend and dimensions are
     * unknown.
     *
     * @return the chosen execution type
     * @throws HopsException if the execution type cannot be determined
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // org.apache.sysml.hops.Hop
    public LopProperties.ExecType optFindExecType() throws HopsException {
        checkAndSetForcedPlatform();
        LopProperties.ExecType remote = OptimizerUtils.isSparkExecutionMode() ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = findExecTypeByMemEstimate();
            } else if (getInput().get(0).areDimsBelowThreshold()) {
                this._etype = LopProperties.ExecType.CP;
            } else {
                this._etype = remote;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        if (OptimizerUtils.ALLOW_DYN_RECOMPILATION && !dimsKnown(true) && this._etype == remote) {
            setRequiresRecompile();
        }
        return this._etype;
    }

    /**
     * Chooses between vector reindexing and general reindexing on the distributed backend.
     *
     * @param singleRow  whether a single row is selected
     * @param singleCol  whether a single column is selected
     * @param inRows     input rows
     * @param inCols     input columns
     * @param outRows    output rows (-1 if unknown)
     * @param outCols    output columns (-1 if unknown)
     * @return {@code MR_VRIX} for a single row with all (known) columns kept, or a single column
     *         with all (known) rows kept; {@code MR_RIX} otherwise
     */
    private static IndexingMethod optFindIndexingMethod(boolean singleRow, boolean singleCol, long inRows, long inCols, long outRows, long outCols) {
        boolean fullRowVector = singleRow && inCols == outCols && outCols != -1;
        boolean fullColVector = singleCol && inRows == outRows && outRows != -1;
        return (fullRowVector || fullColVector) ? IndexingMethod.MR_VRIX : IndexingMethod.MR_RIX;
    }

    /**
     * Re-derives the output dimensions from the index expressions. A dimension is left
     * unchanged if none of the recognized patterns apply. For each dimension the recognized
     * cases are, in order:
     * <ul>
     *   <li>lower equals upper: 1;</li>
     *   <li>{@code 1:nrow(..)} / {@code 1:ncol(..)}: the input's dimension;</li>
     *   <li>literal bounds: {@code upper - lower + 1};</li>
     *   <li>the block-indexing pattern: its block size.</li>
     * </ul>
     */
    @Override // org.apache.sysml.hops.Hop
    public void refreshSizeInformation() {
        Hop input = getInput().get(0);
        Hop rowLower = getInput().get(1);
        Hop rowUpper = getInput().get(2);
        Hop colLower = getInput().get(3);
        Hop colUpper = getInput().get(4);

        // NOTE(review): the argument of nrow(..)/ncol(..) is not checked to be `input`;
        // X[1:nrow(Y),] is treated as all rows of X -- confirm this is intended.
        boolean allRows = (rowLower instanceof LiteralOp) && HopRewriteUtils.getIntValueSafe((LiteralOp) rowLower) == 1
                && (rowUpper instanceof UnaryOp) && ((UnaryOp) rowUpper).getOp() == Hop.OpOp1.NROW;
        boolean allCols = (colLower instanceof LiteralOp) && HopRewriteUtils.getIntValueSafe((LiteralOp) colLower) == 1
                && (colUpper instanceof UnaryOp) && ((UnaryOp) colUpper).getOp() == Hop.OpOp1.NCOL;
        boolean literalRows = (rowLower instanceof LiteralOp) && (rowUpper instanceof LiteralOp);
        boolean literalCols = (colLower instanceof LiteralOp) && (colUpper instanceof LiteralOp);

        if (this._rowLowerEqualsUpper) {
            setDim1(1L);
        } else if (allRows) {
            setDim1(input.getDim1());
        } else if (literalRows) {
            setDim1((HopRewriteUtils.getIntValueSafe((LiteralOp) rowUpper) - HopRewriteUtils.getIntValueSafe((LiteralOp) rowLower)) + 1);
        } else if (isBlockIndexingExpression(rowLower, rowUpper)) {
            setDim1(getBlockIndexingExpressionSize(rowLower, rowUpper));
        }

        if (this._colLowerEqualsUpper) {
            setDim2(1L);
        } else if (allCols) {
            setDim2(input.getDim2());
        } else if (literalCols) {
            setDim2((HopRewriteUtils.getIntValueSafe((LiteralOp) colUpper) - HopRewriteUtils.getIntValueSafe((LiteralOp) colLower)) + 1);
        } else if (isBlockIndexingExpression(colLower, colUpper)) {
            setDim2(getBlockIndexingExpressionSize(colLower, colUpper));
        }
    }

    /**
     * Creates a copy of this hop, including the generic hop attributes (via
     * {@code clone(this, false)}) and the indexing-specific flags.
     *
     * @return the cloned hop
     * @throws CloneNotSupportedException if the generic clone fails
     */
    @Override // org.apache.sysml.hops.Hop
    public Object clone() throws CloneNotSupportedException {
        IndexingOp ret = new IndexingOp();
        ret.clone(this, false);
        // Copy specific attributes; without them the clone would lose single-row/column info.
        ret._rowLowerEqualsUpper = this._rowLowerEqualsUpper;
        ret._colLowerEqualsUpper = this._colLowerEqualsUpper;
        return ret;
    }

    /**
     * Two indexing hops are considered equal if they have the identical (same object) inputs.
     *
     * @param hop the hop to compare with
     * @return true if {@code hop} is an {@code IndexingOp} over the same five inputs
     */
    @Override // org.apache.sysml.hops.Hop
    public boolean compare(Hop hop) {
        if (!(hop instanceof IndexingOp) || getInput().size() != hop.getInput().size()) {
            return false;
        }
        for (int i = 0; i < 5; i++) {
            if (getInput().get(i) != hop.getInput().get(i)) {
                return false;
            }
        }
        return true;
    }
}
