package org.apache.sysml.hops;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Binary;
import org.apache.sysml.lops.DataPartition;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.MMCJ;
import org.apache.sysml.lops.MMRJ;
import org.apache.sysml.lops.MMTSJ;
import org.apache.sysml.lops.MMZip;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.lops.MapMultChain;
import org.apache.sysml.lops.PMMJ;
import org.apache.sysml.lops.PMapMult;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.lops.Transform;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.MMCJMRReducerWithAggregator;
import org.apache.sysml.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysml/hops/AggBinaryOp.class */
public class AggBinaryOp extends Hop implements Hop.MultiThreadedHop {
    // NOTE(review): not referenced in the visible part of this file; presumably a memory-estimate
    // multiplier for map-side matrix multiplies — confirm at its use sites.
    public static final double MAPMULT_MEM_MULTIPLIER = 1.0d;
    // When non-null, presumably forces the matrix-multiply method chosen by the
    // optFindMMultMethod* selectors (not visible here) — confirm.
    public static MMultMethod FORCED_MMULT_METHOD = null;
    // Inner (cell-wise) operator of the aggregate-binary; isMatrixMultiply() requires MULT.
    private Hop.OpOp2 innerOp;
    // Outer aggregation operator; isMatrixMultiply() requires SUM.
    private Hop.AggOp outerOp;
    // Physical method last selected by isGPUEnabled() or constructLops(); null until then.
    private MMultMethod _method;
    // Flag forwarded to the optFindMMultMethod* selectors; presumably marks the left input as a
    // permutation matrix (cf. the PMM construction paths) — confirm.
    private boolean _hasLeftPMInput;
    // Requested degree of parallelism, passed through OptimizerUtils.getConstrainedNumThreads; -1 by default.
    private int _maxNumThreads;

    /**
     * Physical matrix-multiply strategies. Each constant is mapped to a lop construction routine in
     * {@code constructLops()}; which constants are legal depends on the execution backend.
     */
    public enum MMultMethod {
        // MMCJ lop (Spark CP-MM path and MR path)
        CPMM,
        // MMRJ lop
        RMM,
        // MapMult with the left input as the partitioned/side input (MR partitions input 0)
        MAPMM_L,
        // MapMult with the right input as the partitioned/side input (MR partitions input 1)
        MAPMM_R,
        // MapMultChain lop for t(X)%*%(X%*%v) style chains
        MAPMM_CHAIN,
        // PMapMult lop (Spark only)
        PMAPMM,
        // PMMJ lop (permutation-matrix multiply)
        PMM,
        // MMTSJ lop (transpose-self multiply)
        TSMM,
        // MMTSJ lop with its extra boolean flag set (Spark only)
        TSMM2,
        // MMZip lop (Spark only)
        ZIPMM,
        // Plain CP/GPU Binary MATMULT lop
        MM
    }

    /**
     * Aggregation required after a Spark matrix multiply; obtained via
     * {@code getSparkMMAggregationType(boolean)} (not visible here). Presumably: no aggregation,
     * aggregation into a single block, or aggregation over multiple blocks — confirm.
     */
    public enum SparkAggType {
        NONE,
        SINGLE_BLOCK,
        MULTI_BLOCK
    }

    /** Private no-argument constructor; initializes the per-instance state to its defaults. */
    private AggBinaryOp() {
        _maxNumThreads = -1;
        _hasLeftPMInput = false;
        _method = null;
    }

    public AggBinaryOp(String str, Expression.DataType dataType, Expression.ValueType valueType, Hop.OpOp2 opOp2, Hop.AggOp aggOp, Hop hop, Hop hop2) {
        super(str, dataType, valueType);
        this._method = null;
        this._hasLeftPMInput = false;
        this._maxNumThreads = -1;
        this.innerOp = opOp2;
        this.outerOp = aggOp;
        getInput().add(0, hop);
        getInput().add(1, hop2);
        hop.getParent().add(this);
        hop2.getParent().add(this);
        refreshSizeInformation();
    }

    /** Validates that this hop has exactly two inputs. */
    @Override // org.apache.sysml.hops.Hop
    public void checkArity() throws HopsException {
        final int arity = this._input.size();
        HopsException.check(arity == 2, this, "should have arity 2 but has arity %d", Integer.valueOf(arity));
    }

    /** Sets the flag forwarded to the method selectors as {@code _hasLeftPMInput}. */
    public void setHasLeftPMInput(boolean z) {
        _hasLeftPMInput = z;
    }

    /** Returns the flag set via {@link #setHasLeftPMInput(boolean)}. */
    public boolean hasLeftPMInput() {
        final boolean flag = _hasLeftPMInput;
        return flag;
    }

    /** Records the requested degree of parallelism for multi-threaded lops. */
    @Override // org.apache.sysml.hops.Hop.MultiThreadedHop
    public void setMaxNumThreads(int i) {
        _maxNumThreads = i;
    }

    /** Returns the requested degree of parallelism (-1 unless set). */
    @Override // org.apache.sysml.hops.Hop.MultiThreadedHop
    public int getMaxNumThreads() {
        final int threads = _maxNumThreads;
        return threads;
    }

    /** Returns the physical method selected most recently, or null if none was selected yet. */
    public MMultMethod getMMultMethod() {
        final MMultMethod selected = _method;
        return selected;
    }

    /**
     * Reports whether this multiply can run on the GPU.
     *
     * <p>Side effect: when the accelerator is enabled, the CP method selector runs and its result is
     * stored in {@code _method}. Only the plain {@code MM} method is GPU-capable.
     *
     * @throws RuntimeException if the selector returns a method other than TSMM, MAPMM_CHAIN, PMM or MM
     */
    @Override // org.apache.sysml.hops.Hop
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        _method = optFindMMultMethodCP(left.getDim1(), left.getDim2(), right.getDim1(), right.getDim2(), checkTransposeSelf(), checkMapMultChain(), _hasLeftPMInput);
        switch (_method) {
            case MM:
                return true;
            case TSMM:
            case MAPMM_CHAIN:
            case PMM:
                return false;
            default:
                throw new RuntimeException("Unsupported method:" + _method);
        }
    }

    /**
     * Builds the lop DAG for this matrix multiplication and caches it on the hop.
     *
     * <p>The execution type from {@link #optFindExecType()} picks the method selector: CP and GPU use
     * {@code optFindMMultMethodCP}, SPARK uses {@code optFindMMultMethodSpark} and MR uses
     * {@code optFindMMultMethodMR}. The selected method is stored in {@code _method} and dispatched to
     * the matching construct* routine. For any other execution type no lop is built here.
     *
     * @return the lop of this hop (cached after the first call)
     * @throws HopsException if this is not a matrix multiply or the method is invalid for the backend
     * @throws LopsException if lop construction fails
     */
    @Override // org.apache.sysml.hops.Hop
    public Lop constructLops() throws HopsException, LopsException {
        final Lop existing = getLops();
        if (existing != null) {
            return existing;
        }
        if (!isMatrixMultiply()) {
            throw new HopsException(printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + this.innerOp + "," + this.outerOp + ") while constructing lops.");
        }

        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        final LopProperties.ExecType et = optFindExecType();
        final MMTSJ.MMTSJType tsmmType = checkTransposeSelf();
        final MapMultChain.ChainType chainType = checkMapMultChain();

        if (et == LopProperties.ExecType.CP || et == LopProperties.ExecType.GPU) {
            _method = optFindMMultMethodCP(left.getDim1(), left.getDim2(), right.getDim1(), right.getDim2(), tsmmType, chainType, _hasLeftPMInput);
            switch (_method) {
                case TSMM:
                    constructCPLopsTSMM(tsmmType, et);
                    break;
                case MAPMM_CHAIN:
                    constructCPLopsMMChain(chainType);
                    break;
                case PMM:
                    constructCPLopsPMM();
                    break;
                case MM:
                    constructCPLopsMM(et);
                    break;
                default:
                    throw new HopsException(printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops.");
            }
        } else if (et == LopProperties.ExecType.SPARK) {
            _method = optFindMMultMethodSpark(left.getDim1(), left.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), left.getNnz(), right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), right.getNnz(), tsmmType, chainType, _hasLeftPMInput, HopRewriteUtils.isTransposeOperation(left));
            switch (_method) {
                case TSMM:
                case TSMM2:
                    constructSparkLopsTSMM(tsmmType, _method == MMultMethod.TSMM2);
                    break;
                case MAPMM_CHAIN:
                    constructSparkLopsMapMMChain(chainType);
                    break;
                case PMM:
                    constructSparkLopsPMM();
                    break;
                case MAPMM_L:
                case MAPMM_R:
                    constructSparkLopsMapMM(_method);
                    break;
                case PMAPMM:
                    constructSparkLopsPMapMM();
                    break;
                case CPMM:
                    constructSparkLopsCPMM();
                    break;
                case RMM:
                    constructSparkLopsRMM();
                    break;
                case ZIPMM:
                    constructSparkLopsZIPMM();
                    break;
                default:
                    throw new HopsException(printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
            }
        } else if (et == LopProperties.ExecType.MR) {
            _method = optFindMMultMethodMR(left.getDim1(), left.getDim2(), left.getRowsInBlock(), left.getColsInBlock(), left.getNnz(), right.getDim1(), right.getDim2(), right.getRowsInBlock(), right.getColsInBlock(), right.getNnz(), tsmmType, chainType, _hasLeftPMInput);
            switch (_method) {
                case TSMM:
                    constructMRLopsTSMM(tsmmType);
                    break;
                case MAPMM_CHAIN:
                    constructMRLopsMapMMChain(chainType);
                    break;
                case PMM:
                    constructMRLopsPMM();
                    break;
                case MAPMM_L:
                case MAPMM_R:
                    constructMRLopsMapMM(_method);
                    break;
                case CPMM:
                    constructMRLopsCPMM();
                    break;
                case RMM:
                    constructMRLopsRMM();
                    break;
                default:
                    throw new HopsException(printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing MR lops.");
            }
        }

        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /** Returns the display string {@code ba(<agg><op>)}, e.g. built from the outer and inner operators. */
    @Override // org.apache.sysml.hops.Hop
    public String getOpString() {
        StringBuilder sb = new StringBuilder("ba(");
        sb.append(HopsAgg2String.get(this.outerOp));
        sb.append(HopsOpOp2String.get(this.innerOp));
        return sb.append(')').toString();
    }

    /**
     * Computes the memory estimate. For left transpose-self (t(X)%*%X) with a known right input of
     * more than one column, the left input's output estimate is subtracted. The TSMM lop consumes
     * X directly (see constructCPLopsTSMM), so t(X) does not need to be held.
     */
    @Override // org.apache.sysml.hops.Hop
    public void computeMemEstimate(MemoTable memoTable) {
        super.computeMemEstimate(memoTable);
        final Hop right = getInput().get(1);
        final boolean leftTsmm = checkTransposeSelf().isLeft();
        if (!leftTsmm || !right.dimsKnown() || right.getDim2() <= 1) {
            return;
        }
        this._memEstimate -= getInput().get(0)._outputMemEstimate;
    }

    /**
     * Estimates the output size assuming a dense result (sparsity 1.0). The nnz argument {@code j3}
     * is not used.
     */
    @Override // org.apache.sysml.hops.Hop
    protected double computeOutputMemEstimate(long j, long j2, long j3) {
        final double denseSparsity = 1.0d;
        return OptimizerUtils.estimateSizeExactSparsity(j, j2, denseSparsity);
    }

    /**
     * Estimates intermediate memory.
     *
     * <p>On GPU, a sparse-left / dense-right multiply reserves an additional dense output-sized
     * buffer. Independently, outputs with at least two columns reserve a sparse-representation
     * estimate at sparsity just below 0.4. That constant is presumably the sparse/dense turn point —
     * confirm against MatrixBlock.
     */
    @Override // org.apache.sysml.hops.Hop
    protected double computeIntermediateMemEstimate(long j, long j2, long j3) {
        final double turnPoint = 0.4d;
        double estimate = 0.0d;
        if (isGPUEnabled()) {
            final Hop left = this._input.get(0);
            final Hop right = this._input.get(1);
            final boolean leftSparse = OptimizerUtils.getSparsity(left.getDim1(), left.getDim2(), left.getNnz()) < turnPoint;
            final boolean rightSparse = OptimizerUtils.getSparsity(right.getDim1(), right.getDim2(), right.getNnz()) < turnPoint;
            if (leftSparse && !rightSparse) {
                estimate += OptimizerUtils.estimateSizeExactSparsity(j, j2, 1.0d);
            }
        }
        if (j2 >= 2) {
            estimate += MatrixBlock.estimateSizeSparseInMemory(j, j2, turnPoint - UtilFunctions.DOUBLE_EPS);
        }
        return estimate;
    }

    /**
     * Infers {@code [rows, cols, nnz]} of the product from the memoized input statistics.
     *
     * <p>Rows come from the left input and columns from the right input. Input sparsities default
     * to 1.0 when the input nnz is unknown (not positive).
     *
     * @return the inferred characteristics, or null if the left rows or right columns are unknown
     */
    @Override // org.apache.sysml.hops.Hop
    protected long[] inferOutputCharacteristics(MemoTable memoTable) {
        MatrixCharacteristics[] mc = memoTable.getAllInputStats(getInput());
        if (!mc[0].rowsKnown() || !mc[1].colsKnown()) {
            return null;
        }
        long rows = mc[0].getRows();
        long cols = mc[1].getCols();
        double sp1 = mc[0].getNonZeros() > 0 ? OptimizerUtils.getSparsity(mc[0].getRows(), mc[0].getCols(), mc[0].getNonZeros()) : 1.0d;
        double sp2 = mc[1].getNonZeros() > 0 ? OptimizerUtils.getSparsity(mc[1].getRows(), mc[1].getCols(), mc[1].getNonZeros()) : 1.0d;
        double spOut = OptimizerUtils.getMatMultSparsity(sp1, sp2, rows, mc[0].getCols(), cols, true);
        // Multiply in double precision: rows * cols as a long product can silently overflow for
        // very large (typically ultra-sparse) outputs.
        long nnz = (long) ((double) rows * (double) cols * spOut);
        return new long[] {rows, cols, nnz};
    }

    /** Returns true iff this aggregate-binary is a sum-of-products, i.e. a matrix multiplication. */
    public boolean isMatrixMultiply() {
        final boolean sumOfProducts = this.outerOp == Hop.AggOp.SUM && this.innerOp == Hop.OpOp2.MULT;
        return sumOfProducts;
    }

    /**
     * Returns true if this multiply is an outer product: a column vector (n x 1, n > 1) times a row
     * vector (1 x m, m > 1), which yields an n x m result.
     *
     * <p>The previous condition required {@code getDim1() == 1 && getDim1() > 1} on the left input,
     * which can never hold, so it always returned false. Its remaining terms also described an
     * inner product rather than an outer product.
     */
    private boolean isOuterProduct() {
        Hop left = getInput().get(0);
        Hop right = getInput().get(1);
        return left.isVector() && right.isVector()
            && left.getDim1() > 1 && left.getDim2() == 1
            && right.getDim1() == 1 && right.getDim2() > 1;
    }

    /** Matrix multiplication places no restriction on the execution type; always true. */
    @Override // org.apache.sysml.hops.Hop
    public boolean allowsAllExecTypes() {
        final boolean unrestricted = true;
        return unrestricted;
    }

    /**
     * Chooses and records the execution type of this multiply.
     *
     * <p>A forced type wins. Otherwise the memory-based optimizer is used when enabled. If it is not,
     * CP is chosen for small inputs or non-outer vector-vector products, and the distributed backend
     * (SPARK or MR, depending on the execution mode) is chosen for everything else. A CP mmchain whose
     * X exceeds the local budget falls back to the distributed backend. Finally, a CP choice (unless
     * CP was forced) is promoted to SPARK when an input qualifies for transitive Spark execution.
     */
    @Override // org.apache.sysml.hops.Hop
    public LopProperties.ExecType optFindExecType() throws HopsException {
        checkAndSetForcedPlatform();
        final LopProperties.ExecType distributed;
        if (OptimizerUtils.isSparkExecutionMode()) {
            distributed = LopProperties.ExecType.SPARK;
        } else {
            distributed = LopProperties.ExecType.MR;
        }

        if (this._etypeForced == null) {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = findExecTypeByMemEstimate();
            } else {
                final Hop left = getInput().get(0);
                final Hop right = getInput().get(1);
                final boolean runLocally = (left.areDimsBelowThreshold() && right.areDimsBelowThreshold())
                    || (left.isVector() && right.isVector() && !isOuterProduct());
                this._etype = runLocally ? LopProperties.ExecType.CP : distributed;
            }
            // A CP mmchain needs X in local memory; otherwise fall back to the distributed backend.
            final boolean chainTooLarge = this._etype == LopProperties.ExecType.CP
                && checkMapMultChain() != MapMultChain.ChainType.NONE
                && OptimizerUtils.getLocalMemBudget() < getInput().get(0).getInput().get(0).getOutputMemEstimate();
            if (chainTooLarge) {
                this._etype = distributed;
            }
            checkAndSetInvalidCPDimsAndSize();
        } else {
            this._etype = this._etypeForced;
        }

        final boolean promoteToSpark = this._etype == LopProperties.ExecType.CP
            && this._etypeForced != LopProperties.ExecType.CP
            && (isApplicableForTransitiveSparkExecType(true) || isApplicableForTransitiveSparkExecType(false));
        if (promoteToSpark) {
            this._etype = LopProperties.ExecType.SPARK;
        }
        setRequiresRecompileIfNecessary();
        return this._etype;
    }

    /**
     * Checks whether the left ({@code z == true}) or right input would justify running this multiply
     * in Spark as well.
     *
     * <p>This requires that the input is not a checkpointed DataOp and is not a transpose (except a
     * left transpose that the CP left-transpose rewrite could not handle). The input must also have
     * this hop as its only parent, must not be below the dims threshold, must itself execute in
     * SPARK, and must have a larger output estimate than this hop. The checks run in this order and
     * short-circuit; the input's {@code optFindExecType()} is only invoked once the earlier checks pass.
     */
    private boolean isApplicableForTransitiveSparkExecType(boolean z) throws HopsException {
        final Hop in = getInput().get(z ? 0 : 1);
        if (in instanceof DataOp && ((DataOp) in).requiresCheckpoint()) {
            return false;
        }
        if (HopRewriteUtils.isTransposeOperation(in) && (!z || isLeftTransposeRewriteApplicable(true, false))) {
            return false;
        }
        if (in.getParent().size() != 1 || in.areDimsBelowThreshold()) {
            return false;
        }
        if (in.optFindExecType() != LopProperties.ExecType.SPARK) {
            return false;
        }
        return in.getOutputMemEstimate() > getOutputMemEstimate();
    }

    /**
     * Detects transpose-self patterns: LEFT for {@code t(X) %*% X}, RIGHT for {@code X %*% t(X)},
     * NONE otherwise. If both patterns matched, RIGHT would take precedence.
     */
    public MMTSJ.MMTSJType checkTransposeSelf() {
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        final boolean leftIsTransposeOfRight = HopRewriteUtils.isTransposeOperation(left) && left.getInput().get(0) == right;
        final boolean rightIsTransposeOfLeft = HopRewriteUtils.isTransposeOperation(right) && right.getInput().get(0) == left;
        if (rightIsTransposeOfLeft) {
            return MMTSJ.MMTSJType.RIGHT;
        }
        return leftIsTransposeOfRight ? MMTSJ.MMTSJType.LEFT : MMTSJ.MMTSJType.NONE;
    }

    /**
     * Detects mmchain patterns rooted at a left transpose {@code t(X)}.
     *
     * <ul>
     *   <li>XtwXv: {@code t(X) %*% (w * (X %*% v))}; the product is the second MULT operand</li>
     *   <li>XtXvy: {@code t(X) %*% ((X %*% v) - y)} with matrix {@code y}</li>
     *   <li>XtXv: {@code t(X) %*% (X %*% v)}</li>
     * </ul>
     * Returns NONE if no pattern matches.
     */
    public MapMultChain.ChainType checkMapMultChain() {
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        if (!HopRewriteUtils.isTransposeOperation(left)) {
            return MapMultChain.ChainType.NONE;
        }
        final Hop x = left.getInput().get(0);

        if ((right instanceof BinaryOp) && ((BinaryOp) right).getOp() == Hop.OpOp2.MULT) {
            final Hop xv = right.getInput().get(1);
            final boolean matches = (xv instanceof AggBinaryOp) && x == xv.getInput().get(0);
            return matches ? MapMultChain.ChainType.XtwXv : MapMultChain.ChainType.NONE;
        }
        if ((right instanceof BinaryOp) && ((BinaryOp) right).getOp() == Hop.OpOp2.MINUS) {
            final Hop xv = right.getInput().get(0);
            final Hop y = right.getInput().get(1);
            final boolean matches = (xv instanceof AggBinaryOp) && y.getDataType() == Expression.DataType.MATRIX && x == xv.getInput().get(0);
            return matches ? MapMultChain.ChainType.XtXvy : MapMultChain.ChainType.NONE;
        }
        if ((right instanceof AggBinaryOp) && x == right.getInput().get(0)) {
            return MapMultChain.ChainType.XtXv;
        }
        return MapMultChain.ChainType.NONE;
    }

    /**
     * Builds a multi-threaded MMTSJ lop for CP/GPU. It consumes the non-transposed input directly:
     * input 1 for LEFT, input 0 otherwise.
     */
    private void constructCPLopsTSMM(MMTSJ.MMTSJType mMTSJType, LopProperties.ExecType execType) throws HopsException, LopsException {
        final int inputPos = mMTSJType.isLeft() ? 1 : 0;
        final Lop input = getInput().get(inputPos).constructLops();
        final int numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        final MMTSJ tsmm = new MMTSJ(input, getDataType(), getValueType(), execType, mMTSJType, false, numThreads);
        tsmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(tsmm);
        setLops(tsmm);
    }

    /**
     * Builds a multi-threaded CP MapMultChain lop. X is the input of the left transpose and v comes
     * from the nested X%*%v. For XtwXv / XtXvy the third operand (w or y) is the other operand of
     * the right-hand binary op.
     */
    private void constructCPLopsMMChain(MapMultChain.ChainType chainType) throws LopsException, HopsException {
        final Hop x = getInput().get(0).getInput().get(0);
        final Hop rhs = getInput().get(1);
        final MapMultChain chain;
        if (chainType != MapMultChain.ChainType.XtXv) {
            final boolean weighted = chainType == MapMultChain.ChainType.XtwXv;
            final int auxPos = weighted ? 0 : 1; // position of w (XtwXv) or y (XtXvy)
            final int xvPos = weighted ? 1 : 0;  // position of X %*% v
            final Lop xLop = x.constructLops();
            final Lop vLop = rhs.getInput().get(xvPos).getInput().get(1).constructLops();
            final Lop auxLop = rhs.getInput().get(auxPos).constructLops();
            chain = new MapMultChain(xLop, vLop, auxLop, chainType, getDataType(), getValueType(), LopProperties.ExecType.CP);
        } else {
            final Lop xLop = x.constructLops();
            final Lop vLop = rhs.getInput().get(1).constructLops();
            chain = new MapMultChain(xLop, vLop, getDataType(), getValueType(), LopProperties.ExecType.CP);
        }
        chain.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        setOutputDimensions(chain);
        setLineNumbers(chain);
        setLops(chain);
    }

    /**
     * Builds a CP PMMJ lop. A temporary CP value hop is created over the left input and fed as the
     * third operand; its parent link is removed afterwards.
     */
    private void constructCPLopsPMM() throws HopsException, LopsException {
        final Hop pmInput = getInput().get(0);
        final Hop rightInput = getInput().get(1);

        final Hop nrow = HopRewriteUtils.createValueHop(pmInput, true);
        nrow.setOutputBlocksizes(0L, 0L);
        nrow.setForcedExecType(LopProperties.ExecType.CP);
        HopRewriteUtils.copyLineNumbers(this, nrow);

        final Lop pmLop = pmInput.constructLops();
        final Lop rightLop = rightInput.constructLops();
        final Lop nrowLop = nrow.constructLops();
        final PMMJ pmm = new PMMJ(pmLop, rightLop, nrowLop, getDataType(), getValueType(), false, false, LopProperties.ExecType.CP);
        pmm.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        pmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(pmm);
        setLops(pmm);

        // Detach the temporary hop from the permutation input again.
        HopRewriteUtils.removeChildReference(pmInput, nrow);
    }

    /**
     * Builds the plain matmult lop for CP or GPU.
     *
     * <p>GPU uses a Binary MATMULT with its two trailing boolean flags set to false (presumably the
     * left/right transpose indicators — confirm against Binary). CP uses the left-transpose rewrite
     * when applicable, otherwise a multi-threaded Binary MATMULT.
     *
     * <p>The former {@code 0 == 0 ? a : b} operands were dead constant conditions (always {@code a})
     * and have been removed; behavior is unchanged.
     */
    private void constructCPLopsMM(LopProperties.ExecType execType) throws HopsException, LopsException {
        Lop matmult;
        if (execType == LopProperties.ExecType.GPU) {
            Lop leftLop = getInput().get(0).constructLops();
            Lop rightLop = getInput().get(1).constructLops();
            matmult = new Binary(leftLop, rightLop, Binary.OperationTypes.MATMULT, getDataType(), getValueType(), execType, false, false);
        } else if (isLeftTransposeRewriteApplicable(true, false)) {
            matmult = constructCPLopsMMWithLeftTransposeRewrite();
        } else {
            Lop leftLop = getInput().get(0).constructLops();
            Lop rightLop = getInput().get(1).constructLops();
            matmult = new Binary(leftLop, rightLop, Binary.OperationTypes.MATMULT, getDataType(), getValueType(), execType, OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        }
        setOutputDimensions(matmult);
        setLineNumbers(matmult);
        setLops(matmult);
    }

    /**
     * CP left-transpose rewrite: {@code t(X) %*% Y} is computed as {@code t(t(Y) %*% X)}.
     *
     * <p>If Y's lop is already a transpose, its input is reused as t(Y) instead of adding a second
     * transpose. The returned outer Transform has no dimensions or line numbers set; the caller
     * handles both.
     */
    private Lop constructCPLopsMMWithLeftTransposeRewrite() throws HopsException, LopsException {
        final Hop x = getInput().get(0).getInput().get(0);
        final Hop y = getInput().get(1);
        final int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);

        final Lop yLop = y.constructLops();
        final boolean yAlreadyTransposed = (yLop instanceof Transform) && ((Transform) yLop).getOperationType() == Transform.OperationTypes.Transpose;
        final Lop tY;
        if (yAlreadyTransposed) {
            tY = yLop.getInputs().get(0);
        } else {
            tY = new Transform(yLop, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP, k);
        }
        tY.getOutputParameters().setDimensions(y.getDim2(), y.getDim1(), getRowsInBlock(), getColsInBlock(), y.getNnz());
        setLineNumbers(tY);

        final Binary product = new Binary(tY, x.constructLops(), Binary.OperationTypes.MATMULT, getDataType(), getValueType(), LopProperties.ExecType.CP, k);
        product.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(product);

        return new Transform(product, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP, k);
    }

    /**
     * Builds a Spark MMTSJ lop over the non-transposed input. {@code z} is passed as MMTSJ's flag;
     * the caller sets it for TSMM2.
     */
    private void constructSparkLopsTSMM(MMTSJ.MMTSJType mMTSJType, boolean z) throws HopsException, LopsException {
        final Hop source = getInput().get(mMTSJType.isLeft() ? 1 : 0);
        final MMTSJ tsmm = new MMTSJ(source.constructLops(), getDataType(), getValueType(), LopProperties.ExecType.SPARK, mMTSJType, z);
        setOutputDimensions(tsmm);
        setLineNumbers(tsmm);
        setLops(tsmm);
    }

    /**
     * Builds a Spark MapMult lop, or the left-transpose-rewritten variant when applicable. Updates
     * {@code _outputEmptyBlocks} from the optimizer's empty-block filter decision.
     */
    private void constructSparkLopsMapMM(MMultMethod mMultMethod) throws LopsException, HopsException {
        final Lop mapmm;
        if (!isLeftTransposeRewriteApplicable(false, false)) {
            final SparkAggType aggType = getSparkMMAggregationType(requiresAggregation(mMultMethod));
            this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
            final Lop leftLop = getInput().get(0).constructLops();
            final Lop rightLop = getInput().get(1).constructLops();
            final boolean rightVariant = mMultMethod == MMultMethod.MAPMM_R;
            mapmm = new MapMult(leftLop, rightLop, getDataType(), getValueType(), rightVariant, false, this._outputEmptyBlocks, aggType);
        } else {
            mapmm = constructSparkLopsMapMMWithLeftTransposeRewrite();
        }
        setOutputDimensions(mapmm);
        setLineNumbers(mapmm);
        setLops(mapmm);
    }

    /**
     * Spark left-transpose rewrite for MapMM: {@code t(X) %*% Y} is computed as
     * {@code t(t(Y) %*% X)}. Both transposes run in CP; the aggregation type is derived from MAPMM_R.
     * The returned outer Transform gets its dimensions and line numbers from the caller.
     */
    private Lop constructSparkLopsMapMMWithLeftTransposeRewrite() throws HopsException, LopsException {
        final Hop x = getInput().get(0).getInput().get(0);
        final Hop y = getInput().get(1);

        final Transform tY = new Transform(y.constructLops(), Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        tY.getOutputParameters().setDimensions(y.getDim2(), y.getDim1(), getRowsInBlock(), getColsInBlock(), y.getNnz());
        setLineNumbers(tY);

        final SparkAggType aggType = getSparkMMAggregationType(requiresAggregation(MMultMethod.MAPMM_R));
        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        final Lop xLop = x.constructLops();
        final MapMult product = new MapMult(tY, xLop, getDataType(), getValueType(), false, false, this._outputEmptyBlocks, aggType);
        product.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(product);

        return new Transform(product, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
    }

    /**
     * Builds a Spark MapMultChain lop. Operands are resolved exactly as in the CP variant: X from
     * the left transpose, v from the nested X%*%v, and w or y as the other right-hand operand.
     */
    private void constructSparkLopsMapMMChain(MapMultChain.ChainType chainType) throws LopsException, HopsException {
        final Hop x = getInput().get(0).getInput().get(0);
        final Hop rhs = getInput().get(1);
        final MapMultChain chain;
        if (chainType != MapMultChain.ChainType.XtXv) {
            final boolean weighted = chainType == MapMultChain.ChainType.XtwXv;
            final int auxPos = weighted ? 0 : 1; // w (XtwXv) or y (XtXvy)
            final int xvPos = weighted ? 1 : 0;  // X %*% v
            final Lop xLop = x.constructLops();
            final Lop vLop = rhs.getInput().get(xvPos).getInput().get(1).constructLops();
            final Lop auxLop = rhs.getInput().get(auxPos).constructLops();
            chain = new MapMultChain(xLop, vLop, auxLop, chainType, getDataType(), getValueType(), LopProperties.ExecType.SPARK);
        } else {
            final Lop xLop = x.constructLops();
            final Lop vLop = rhs.getInput().get(1).constructLops();
            chain = new MapMultChain(xLop, vLop, getDataType(), getValueType(), LopProperties.ExecType.SPARK);
        }
        setOutputDimensions(chain);
        setLineNumbers(chain);
        setLops(chain);
    }

    /** Builds a Spark PMapMult lop over both inputs. */
    private void constructSparkLopsPMapMM() throws LopsException, HopsException {
        final Lop leftLop = getInput().get(0).constructLops();
        final Lop rightLop = getInput().get(1).constructLops();
        final PMapMult pmapmm = new PMapMult(leftLop, rightLop, getDataType(), getValueType());
        setOutputDimensions(pmapmm);
        setLineNumbers(pmapmm);
        setLops(pmapmm);
    }

    /**
     * Builds a Spark MMCJ (CPMM) lop, or the left-transpose-rewritten variant when applicable. The
     * aggregation type is always derived with aggregation required.
     */
    private void constructSparkLopsCPMM() throws HopsException, LopsException {
        if (!isLeftTransposeRewriteApplicable(false, false)) {
            final Lop leftLop = getInput().get(0).constructLops();
            final Lop rightLop = getInput().get(1).constructLops();
            final SparkAggType aggType = getSparkMMAggregationType(true);
            final MMCJ cpmm = new MMCJ(leftLop, rightLop, getDataType(), getValueType(), aggType, LopProperties.ExecType.SPARK);
            setOutputDimensions(cpmm);
            setLineNumbers(cpmm);
            setLops(cpmm);
        } else {
            setLops(constructSparkLopsCPMMWithLeftTransposeRewrite());
        }
    }

    /**
     * Spark left-transpose rewrite for CPMM: {@code t(X) %*% Y} is computed as
     * {@code t(t(Y) %*% X)}, with both transposes in CP.
     *
     * <p>The intermediate MMCJ computes t(Y) %*% X, so its dimensions are ncol(Y) x ncol(X). These
     * are the transpose of this hop's dimensions. It was previously given getDim1() x getDim2();
     * the dimensions are now consistent with the MapMM and CP rewrites. The returned Transform also
     * gets line numbers, because the caller sets the lop directly without adding them.
     *
     * @return the final CP transpose lop with dimensions ncol(X) x ncol(Y)
     */
    private Lop constructSparkLopsCPMMWithLeftTransposeRewrite() throws HopsException, LopsException {
        SparkAggType aggType = getSparkMMAggregationType(true);
        Hop x = getInput().get(0).getInput().get(0);
        Hop y = getInput().get(1);

        // t(Y) in CP
        Transform tY = new Transform(y.constructLops(), Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        tY.getOutputParameters().setDimensions(y.getDim2(), y.getDim1(), getRowsInBlock(), getColsInBlock(), y.getNnz());
        setLineNumbers(tY);

        // t(Y) %*% X in Spark: ncol(Y) x ncol(X)
        MMCJ cpmm = new MMCJ(tY, x.constructLops(), getDataType(), getValueType(), aggType, LopProperties.ExecType.SPARK);
        cpmm.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(cpmm);

        // t(t(Y) %*% X) in CP: ncol(X) x ncol(Y)
        Transform out = new Transform(cpmm, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        out.getOutputParameters().setDimensions(x.getDim2(), y.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(out);
        return out;
    }

    /** Builds a Spark MMRJ (RMM) lop over both inputs. */
    private void constructSparkLopsRMM() throws LopsException, HopsException {
        final Lop leftLop = getInput().get(0).constructLops();
        final Lop rightLop = getInput().get(1).constructLops();
        final MMRJ rmm = new MMRJ(leftLop, rightLop, getDataType(), getValueType(), LopProperties.ExecType.SPARK);
        setOutputDimensions(rmm);
        setLineNumbers(rmm);
        setLops(rmm);
    }

    /**
     * Builds a Spark PMMJ lop for a permutation-matrix multiply.
     *
     * <p>If the left input is not a column vector, it is condensed into a vector
     * {@code rowMaxIndex(t(P)) * rowMax(t(P))} using temporary Spark hops, and the target row
     * count is a CP value hop over P. Otherwise the row count is {@code max(v)}, which runs in CP
     * or MR depending on whether a dim1 x 1 vector fits the local budget. Temporary hops are
     * detached from the left input afterwards.
     *
     * <p>The row-count hop is declared as {@code Hop}: {@code createValueHop} returns a {@code Hop}
     * (as used in constructCPLopsPMM). The former {@code AggUnaryOp} declaration did not type-check
     * for that branch.
     */
    private void constructSparkLopsPMM() throws HopsException, LopsException {
        Hop pmInput = getInput().get(0);
        Hop rightInput = getInput().get(1);
        Lop pmLop = pmInput.constructLops();
        LopProperties.ExecType vectorExecType = ((double) OptimizerUtils.estimateSize(pmInput.getDim1(), 1L)) > OptimizerUtils.getLocalMemBudget() ? LopProperties.ExecType.MR : LopProperties.ExecType.CP;
        Hop nrow;
        if (pmInput.getDim2() != 1) {
            // Condense the full permutation matrix into a vector of target positions.
            ReorgOp transpose = HopRewriteUtils.createTranspose(pmInput);
            transpose.setForcedExecType(LopProperties.ExecType.SPARK);
            AggUnaryOp maxIndex = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAXINDEX, Hop.Direction.Row);
            maxIndex.setForcedExecType(LopProperties.ExecType.SPARK);
            AggUnaryOp maxValue = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAX, Hop.Direction.Row);
            maxValue.setForcedExecType(LopProperties.ExecType.SPARK);
            BinaryOp condensed = HopRewriteUtils.createBinary(maxIndex, maxValue, Hop.OpOp2.MULT);
            condensed.setForcedExecType(LopProperties.ExecType.SPARK);

            nrow = HopRewriteUtils.createValueHop(pmInput, true);
            nrow.setOutputBlocksizes(0L, 0L);
            nrow.setForcedExecType(LopProperties.ExecType.CP);
            HopRewriteUtils.copyLineNumbers(this, nrow);

            pmLop = condensed.constructLops();
            HopRewriteUtils.removeChildReference(pmInput, transpose);
        } else {
            // Already a vector of target positions: the row count is max(v).
            nrow = HopRewriteUtils.createAggUnaryOp(pmInput, Hop.AggOp.MAX, Hop.Direction.RowCol);
            nrow.setOutputBlocksizes(0L, 0L);
            nrow.setForcedExecType(vectorExecType);
            HopRewriteUtils.copyLineNumbers(this, nrow);
        }
        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        PMMJ pmm = new PMMJ(pmLop, rightInput.constructLops(), nrow.constructLops(), getDataType(), getValueType(), false, this._outputEmptyBlocks, LopProperties.ExecType.SPARK);
        setOutputDimensions(pmm);
        setLineNumbers(pmm);
        setLops(pmm);
        HopRewriteUtils.removeChildReference(pmInput, nrow);
    }

    /**
     * Builds a Spark MMZip lop over the left input's own input and the right input. The left input
     * is presumably t(X) for this method — confirm with the selector. MMZip's boolean flag is true
     * when the first operand has at least as many cells as the second.
     */
    private void constructSparkLopsZIPMM() throws HopsException, LopsException {
        final Hop first = getInput().get(0).getInput().get(0);
        final Hop second = getInput().get(1);
        final Lop firstLop = first.constructLops();
        final Lop secondLop = second.constructLops();
        final long firstCells = first.getDim1() * first.getDim2();
        final long secondCells = second.getDim1() * second.getDim2();
        final MMZip zipmm = new MMZip(firstLop, secondLop, getDataType(), getValueType(), firstCells >= secondCells, LopProperties.ExecType.SPARK);
        setOutputDimensions(zipmm);
        setLineNumbers(zipmm);
        setLops(zipmm);
    }

    /**
     * Builds the MR MapMult lop, optionally followed by a sort Group and an MR Aggregate.
     *
     * <p>MAPMM_R uses the left-transpose rewrite when it is applicable. If partitioning is required,
     * the side input is wrapped in a DataPartition: the left input column-block-wise for MAPMM_L,
     * the right input row-block-wise otherwise. The partition runs in CP when its estimated size
     * fits the local budget, and in MR otherwise.
     */
    private void constructMRLopsMapMM(MMultMethod mMultMethod) throws HopsException, LopsException {
        final boolean rewrite = mMultMethod == MMultMethod.MAPMM_R && isLeftTransposeRewriteApplicable(false, true);
        if (rewrite) {
            setLops(constructMRLopsMapMMWithLeftTransposeRewrite());
            return;
        }

        final boolean needsAgg = requiresAggregation(mMultMethod);
        final boolean needsPart = requiresPartitioning(mMultMethod, false);
        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        Lop leftLop = getInput().get(0).constructLops();
        Lop rightLop = getInput().get(1).constructLops();

        if (needsPart) {
            final boolean partitionLeft = mMultMethod == MMultMethod.MAPMM_L;
            final Hop side = getInput().get(partitionLeft ? 0 : 1);
            final Lop sideLop = side.constructLops();
            final double sideSize = (double) OptimizerUtils.estimateSizeExactSparsity(side.getDim1(), side.getDim2(), OptimizerUtils.getSparsity(side.getDim1(), side.getDim2(), side.getNnz()));
            final LopProperties.ExecType partType = sideSize < OptimizerUtils.getLocalMemBudget() ? LopProperties.ExecType.CP : LopProperties.ExecType.MR;
            final ParForProgramBlock.PDataPartitionFormat format = partitionLeft ? ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N : ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N;
            final Lop partitioned = new DataPartition(sideLop, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, partType, format);
            partitioned.getOutputParameters().setDimensions(side.getDim1(), side.getDim2(), getRowsInBlock(), getColsInBlock(), side.getNnz());
            setLineNumbers(partitioned);
            if (partitionLeft) {
                leftLop = partitioned;
            } else {
                rightLop = partitioned;
            }
        }

        final Lop mapmm = new MapMult(leftLop, rightLop, getDataType(), getValueType(), mMultMethod == MMultMethod.MAPMM_R, needsPart, this._outputEmptyBlocks);
        mapmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(mapmm);

        if (needsAgg) {
            final Group group = new Group(mapmm, Group.OperationTypes.Sort, getDataType(), getValueType());
            final Aggregate agg = new Aggregate(group, HopsAgg2Lops.get(this.outerOp), getDataType(), getValueType(), LopProperties.ExecType.MR);
            group.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
            agg.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
            setLineNumbers(agg);
            agg.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
            setLops(agg);
        } else {
            setLops(mapmm);
        }
    }

    /**
     * Builds the MR lop plan for {@code t(X) %*% Y} as {@code t(t(Y) %*% X)}, where X is the input of
     * this hop's left transpose and Y is this hop's right input. Both transposes run in CP; the
     * product is a left-cached MapMult.
     *
     * <p>t(Y) is partitioned (COLUMN_BLOCK_WISE_N) when {@code requiresPartitioning(MAPMM_R, true)}
     * holds. A sort group plus MR aggregate is added unless X's row count is known, positive and
     * within one row block.
     *
     * @return the final transpose lop; the caller installs it
     */
    private Lop constructMRLopsMapMMWithLeftTransposeRewrite() throws HopsException, LopsException {
        final Hop x = getInput().get(0).getInput().get(0);
        final Hop y = getInput().get(1);

        final Lop yTransposed = new Transform(y.constructLops(), Transform.OperationTypes.Transpose,
                getDataType(), getValueType(), LopProperties.ExecType.CP);
        yTransposed.getOutputParameters().setDimensions(y.getDim2(), y.getDim1(), getRowsInBlock(), getColsInBlock(), y.getNnz());
        setLineNumbers(yTransposed);

        // Aggregation is unnecessary only if X has a known row count that fits a single block.
        final boolean needAgg = !(x.getDim1() > 0 && x.getDim1() <= x.getRowsInBlock());
        final boolean needPart = requiresPartitioning(MMultMethod.MAPMM_R, true);

        Lop cachedLeft = yTransposed;
        if (needPart) {
            final double tySize = OptimizerUtils.estimateSizeExactSparsity(y.getDim2(), y.getDim1(),
                    OptimizerUtils.getSparsity(y.getDim2(), y.getDim1(), y.getNnz()));
            final LopProperties.ExecType partitionExecType = (tySize < OptimizerUtils.getLocalMemBudget())
                    ? LopProperties.ExecType.CP
                    : LopProperties.ExecType.MR;
            cachedLeft = new DataPartition(yTransposed, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE,
                    partitionExecType, ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N);
            cachedLeft.getOutputParameters().setDimensions(y.getDim2(), y.getDim1(), getRowsInBlock(), getColsInBlock(), y.getNnz());
            setLineNumbers(cachedLeft);
        }

        // Intermediate result t(Y) %*% X has shape Y.dim2 x X.dim2.
        final Lop mapMult = new MapMult(cachedLeft, x.constructLops(), getDataType(), getValueType(), false, needPart, false);
        mapMult.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(mapMult);

        Lop product = mapMult;
        if (needAgg) {
            final Lop grouped = new Group(mapMult, Group.OperationTypes.Sort, getDataType(), getValueType());
            grouped.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
            setLineNumbers(grouped);
            final Aggregate combined = new Aggregate(grouped, HopsAgg2Lops.get(this.outerOp), getDataType(),
                    getValueType(), LopProperties.ExecType.MR);
            combined.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
            setLineNumbers(combined);
            combined.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
            product = combined;
        }

        final Transform result = new Transform(product, Transform.OperationTypes.Transpose, getDataType(),
                getValueType(), LopProperties.ExecType.CP);
        result.getOutputParameters().setDimensions(x.getDim2(), y.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        return result;
    }

    /**
     * Builds the MR lop plan for a fused multiply chain ({@link MapMultChain}), followed by a sort
     * group and an MR aggregate, and installs it via {@code setLops}.
     *
     * <p>For XtXv the chain takes X and v only. For XtwXv and XtXvy it also takes a third operand,
     * which is partitioned (ROW_BLOCK_WISE_N) when its dimensions are unknown or its cell count
     * exceeds {@link DistributedCacheInput#PARTITION_SIZE}. That partitioning runs in MR if the
     * operand's estimated size exceeds the local memory budget, and in CP otherwise.
     *
     * @param chainType the detected chain pattern
     */
    private void constructMRLopsMapMMChain(MapMultChain.ChainType chainType) throws HopsException, LopsException {
        final Lop chain;
        if (chainType == MapMultChain.ChainType.XtXv) {
            final Lop xLop = getInput().get(0).getInput().get(0).constructLops();
            final Lop vLop = getInput().get(1).getInput().get(1).constructLops();
            chain = new MapMultChain(xLop, vLop, getDataType(), getValueType(), LopProperties.ExecType.MR);
        } else {
            // The operand order below the right input differs between XtwXv and XtXvy.
            final boolean xtwxv = (chainType == MapMultChain.ChainType.XtwXv);
            final Hop xHop = getInput().get(0).getInput().get(0);
            final Hop wHop = getInput().get(1).getInput().get(xtwxv ? 0 : 1);
            final Hop vHop = getInput().get(1).getInput().get(xtwxv ? 1 : 0).getInput().get(1);
            final double wSize = OptimizerUtils.estimateSize(wHop.getDim1(), wHop.getDim2());
            final boolean partitionW = !wHop.dimsKnown()
                    || wHop.getDim1() * wHop.getDim2() > DistributedCacheInput.PARTITION_SIZE;
            final Lop xLop = xHop.constructLops();
            final Lop vLop = vHop.constructLops();
            final Lop wLop;
            if (!partitionW) {
                wLop = wHop.constructLops();
            } else {
                final LopProperties.ExecType partitionExecType = (wSize > OptimizerUtils.getLocalMemBudget())
                        ? LopProperties.ExecType.MR
                        : LopProperties.ExecType.CP;
                wLop = new DataPartition(wHop.constructLops(), Expression.DataType.MATRIX, Expression.ValueType.DOUBLE,
                        partitionExecType, ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N);
                wLop.getOutputParameters().setDimensions(wHop.getDim1(), wHop.getDim2(), getRowsInBlock(), getColsInBlock(), wHop.getNnz());
                setLineNumbers(wLop);
            }
            chain = new MapMultChain(xLop, vLop, wLop, chainType, getDataType(), getValueType(), LopProperties.ExecType.MR);
        }
        chain.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(chain);

        final Group grouped = new Group(chain, Group.OperationTypes.Sort, getDataType(), getValueType());
        grouped.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        final Aggregate combined = new Aggregate(grouped, HopsAgg2Lops.get(this.outerOp), getDataType(),
                getValueType(), LopProperties.ExecType.MR);
        combined.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        combined.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        setLineNumbers(combined);
        setLops(combined);
    }

    /**
     * Builds the MR lop plan for a cross-product matrix multiply (MMCJ), followed by a sort group and
     * an MR aggregate, and installs it via {@code setLops}. If the left-transpose rewrite applies, the
     * rewritten plan is installed instead.
     */
    private void constructMRLopsCPMM() throws HopsException, LopsException {
        if (isLeftTransposeRewriteApplicable(false, false)) {
            setLops(constructMRLopsCPMMWithLeftTransposeRewrite());
            return;
        }
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        final Lop leftLop = left.constructLops();
        final Lop rightLop = right.constructLops();

        final Lop crossJoin = new MMCJ(leftLop, rightLop, getDataType(), getValueType(),
                getMMCJAggregationType(left, right), LopProperties.ExecType.MR);
        setOutputDimensions(crossJoin);
        setLineNumbers(crossJoin);

        final Lop grouped = new Group(crossJoin, Group.OperationTypes.Sort, getDataType(), getValueType());
        setOutputDimensions(grouped);
        setLineNumbers(grouped);

        final Aggregate combined = new Aggregate(grouped, HopsAgg2Lops.get(this.outerOp), getDataType(),
                getValueType(), LopProperties.ExecType.MR);
        setOutputDimensions(combined);
        setLineNumbers(combined);
        combined.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        setLops(combined);
    }

    /**
     * Builds the MR lop plan for {@code t(X) %*% Y} as {@code t(t(Y) %*% X)} using MMCJ. Both
     * transposes run in CP.
     *
     * <p>The intermediate MMCJ, group and aggregate lops produce {@code t(Y) %*% X}, whose shape is
     * {@code Y.dim2 x X.dim2}, i.e. the transpose of this hop's output. Their dimensions are
     * therefore set explicitly to that shape rather than copied from this hop. This matches the
     * MapMM left-transpose rewrite. Only the final transpose carries this hop's shape.
     *
     * @return the final transpose lop; the caller installs it
     */
    private Lop constructMRLopsCPMMWithLeftTransposeRewrite() throws HopsException, LopsException {
        final Hop x = getInput().get(0).getInput().get(0);
        final Hop y = getInput().get(1);

        final Lop yTransposed = new Transform(y.constructLops(), Transform.OperationTypes.Transpose,
                getDataType(), getValueType(), LopProperties.ExecType.CP);
        yTransposed.getOutputParameters().setDimensions(y.getDim2(), y.getDim1(), getRowsInBlock(), getColsInBlock(), y.getNnz());
        setLineNumbers(yTransposed);

        final Lop crossJoin = new MMCJ(yTransposed, x.constructLops(), getDataType(), getValueType(),
                getMMCJAggregationType(x, y), LopProperties.ExecType.MR);
        // Intermediate shape is that of t(Y) %*% X (transposed w.r.t. this hop).
        crossJoin.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(crossJoin);

        final Lop grouped = new Group(crossJoin, Group.OperationTypes.Sort, getDataType(), getValueType());
        grouped.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(grouped);

        final Aggregate combined = new Aggregate(grouped, HopsAgg2Lops.get(this.outerOp), getDataType(),
                getValueType(), LopProperties.ExecType.MR);
        combined.getOutputParameters().setDimensions(y.getDim2(), x.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(combined);
        combined.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);

        final Transform result = new Transform(combined, Transform.OperationTypes.Transpose, getDataType(),
                getValueType(), LopProperties.ExecType.CP);
        result.getOutputParameters().setDimensions(x.getDim2(), y.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        return result;
    }

    /**
     * Builds a single MR lop for a replication-based matrix multiply (MMRJ) over both inputs and
     * installs it via {@code setLops}.
     */
    private void constructMRLopsRMM() throws HopsException, LopsException {
        final Lop leftLop = getInput().get(0).constructLops();
        final Lop rightLop = getInput().get(1).constructLops();
        final MMRJ replicated = new MMRJ(leftLop, rightLop, getDataType(), getValueType(), LopProperties.ExecType.MR);
        setOutputDimensions(replicated);
        setLineNumbers(replicated);
        setLops(replicated);
    }

    /**
     * Builds the MR lop plan for a transpose-self multiply (MMTSJ) followed by an MR aggregate, and
     * installs it via {@code setLops}.
     *
     * @param mMTSJType the TSMM variant; a left variant reads input 1, otherwise input 0 is read
     */
    private void constructMRLopsTSMM(MMTSJ.MMTSJType mMTSJType) throws HopsException, LopsException {
        final Hop selfInput = getInput().get(mMTSJType.isLeft() ? 1 : 0);
        final Lop selfProduct = new MMTSJ(selfInput.constructLops(), getDataType(), getValueType(),
                LopProperties.ExecType.MR, mMTSJType);
        selfProduct.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(selfProduct);

        final Aggregate combined = new Aggregate(selfProduct, HopsAgg2Lops.get(this.outerOp), getDataType(),
                getValueType(), LopProperties.ExecType.MR);
        combined.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        combined.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        setLineNumbers(combined);
        setLops(combined);
    }

    /**
     * Builds the MR lop plan for a permutation-matrix multiply (PMMJ) of the left input P with the
     * right input, followed by an MR aggregate, and installs it via {@code setLops}.
     *
     * <p>Two cases depend on whether P is a column vector:
     * <ul>
     *   <li>If P is not a column vector, it is condensed into a vector by multiplying the row-wise
     *       MAXINDEX of {@code t(P)} with the row-wise MAX of {@code t(P)} (all forced to MR). The
     *       target row count is {@code nrow(P)}.</li>
     *   <li>If P is a column vector, it is used directly and the target row count is {@code max(P)}.</li>
     * </ul>
     *
     * <p>The vector is partitioned (ROW_BLOCK_WISE_N) when P's dimensions are unknown or its row
     * count exceeds {@link DistributedCacheInput#PARTITION_SIZE}. The temporary transpose and
     * row-count hops are detached from P again via {@link HopRewriteUtils#removeChildReference}.
     */
    private void constructMRLopsPMM() throws HopsException, LopsException {
        final Hop pmInput = getInput().get(0);
        final Hop rightInput = getInput().get(1);
        Lop pmLop = pmInput.constructLops();
        // Exec type for the vector-sized auxiliary operations (max / partitioning).
        final LopProperties.ExecType vectorExecType =
                ((double) OptimizerUtils.estimateSize(pmInput.getDim1(), 1L)) > OptimizerUtils.getLocalMemBudget()
                        ? LopProperties.ExecType.MR
                        : LopProperties.ExecType.CP;

        // Declared as Hop (not AggUnaryOp): createValueHop returns a generic Hop, and only Hop
        // methods are used on this reference below.
        final Hop nrow;
        if (pmInput.getDim2() != 1) {
            // Full permutation matrix: condense into a vector of target row positions.
            final ReorgOp transpose = HopRewriteUtils.createTranspose(pmInput);
            transpose.setForcedExecType(LopProperties.ExecType.MR);
            final AggUnaryOp maxIndex = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAXINDEX, Hop.Direction.Row);
            maxIndex.setForcedExecType(LopProperties.ExecType.MR);
            final AggUnaryOp maxValue = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAX, Hop.Direction.Row);
            maxValue.setForcedExecType(LopProperties.ExecType.MR);
            final BinaryOp condensed = HopRewriteUtils.createBinary(maxIndex, maxValue, Hop.OpOp2.MULT);
            condensed.setForcedExecType(LopProperties.ExecType.MR);

            nrow = HopRewriteUtils.createValueHop(pmInput, true);
            nrow.setOutputBlocksizes(0L, 0L);
            nrow.setForcedExecType(LopProperties.ExecType.CP);
            HopRewriteUtils.copyLineNumbers(this, nrow);

            pmLop = condensed.constructLops();
            HopRewriteUtils.removeChildReference(pmInput, transpose);
        } else {
            // Already a vector of target row positions: number of target rows is max(v).
            nrow = HopRewriteUtils.createAggUnaryOp(pmInput, Hop.AggOp.MAX, Hop.Direction.RowCol);
            nrow.setOutputBlocksizes(0L, 0L);
            nrow.setForcedExecType(vectorExecType);
            HopRewriteUtils.copyLineNumbers(this, nrow);
        }

        final boolean needPart = !pmInput.dimsKnown() || pmInput.getDim1() > DistributedCacheInput.PARTITION_SIZE;
        if (needPart) {
            pmLop = new DataPartition(pmLop, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, vectorExecType,
                    ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N);
            pmLop.getOutputParameters().setDimensions(pmInput.getDim1(), 1L, getRowsInBlock(), getColsInBlock(), pmInput.getDim1());
            setLineNumbers(pmLop);
        }

        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        final PMMJ pmmj = new PMMJ(pmLop, rightInput.constructLops(), nrow.constructLops(), getDataType(), getValueType(),
                needPart, this._outputEmptyBlocks, LopProperties.ExecType.MR);
        pmmj.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(pmmj);

        final Aggregate combined = new Aggregate(pmmj, HopsAgg2Lops.get(this.outerOp), getDataType(), getValueType(),
                LopProperties.ExecType.MR);
        combined.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        combined.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        setLineNumbers(combined);
        setLops(combined);
        HopRewriteUtils.removeChildReference(pmInput, nrow);
    }

    /**
     * Decides whether {@code t(X) %*% Y} should be compiled as {@code t(t(Y) %*% X)}.
     *
     * <p>This always returns false on the HADOOP and SPARK platforms. Otherwise the left input must
     * be a transpose, and all of {@code m = dim1(t(X))}, {@code cd = dim2(t(X))} and
     * {@code n = dim2(Y)} must be known and positive. The rewrite must also be cheaper by cell count
     * ({@code m*cd > cd*n + m*n}), and both new transposes must fit twice into the local memory budget.
     *
     * @param cp true for the in-memory (CP) check. This additionally requires the rewritten
     *     operation's memory estimate to fit the local budget, and on success stores that estimate
     *     in {@code _memEstimate}.
     * @param checkMemMR for the MR check, additionally require {@code t(Y)} to fit the remote map
     *     memory budget
     */
    private boolean isLeftTransposeRewriteApplicable(boolean cp, boolean checkMemMR) {
        final DMLScript.RUNTIME_PLATFORM platform = DMLScript.rtplatform;
        if (platform == DMLScript.RUNTIME_PLATFORM.HADOOP || platform == DMLScript.RUNTIME_PLATFORM.SPARK) {
            return false;
        }
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);

        if (cp) {
            if (!HopRewriteUtils.isTransposeOperation(left)) {
                return false;
            }
            final long m = left.getDim1();
            final long cd = left.getDim2();
            final long n = right.getDim2();
            boolean applicable = m > 0 && cd > 0 && n > 0;
            final double rewrittenMem = OptimizerUtils.estimateSizeExactSparsity(n, cd, 1.0d)
                    + left.getInput().get(0).getOutputMemEstimate()
                    + OptimizerUtils.estimateSizeExactSparsity(n, m, 1.0d);
            // Compound '&=' keeps the original's non-short-circuit evaluation of each clause.
            applicable &= rewrittenMem < OptimizerUtils.getLocalMemBudget();
            applicable &= m * cd > (cd * n) + (m * n)
                    && ((double) (2 * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0d))) < OptimizerUtils.getLocalMemBudget()
                    && ((double) (2 * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0d))) < OptimizerUtils.getLocalMemBudget();
            if (applicable) {
                this._memEstimate = rewrittenMem;
            }
            return applicable;
        }

        if (!(left instanceof ReorgOp) || ((ReorgOp) left).getOp() != Hop.ReOrgOp.TRANSPOSE) {
            return false;
        }
        final long rows = left.getDim1();
        final long common = left.getDim2();
        final long cols = right.getDim2();
        return rows > 0 && common > 0 && cols > 0
                && rows * common > (common * cols) + (rows * cols)
                && 2 * OptimizerUtils.estimateSizeExactSparsity(common, cols, 1.0d) < OptimizerUtils.getLocalMemBudget()
                && 2 * OptimizerUtils.estimateSizeExactSparsity(rows, cols, 1.0d) < OptimizerUtils.getLocalMemBudget()
                && (!checkMemMR
                        || OptimizerUtils.estimateSizeExactSparsity(common, cols, 1.0d) < OptimizerUtils.getRemoteMemBudgetMap(true));
    }

    /**
     * Chooses between aggregating (AGG) and non-aggregating (NO_AGG) MMCJ.
     *
     * <p>NO_AGG is chosen only when all of the following hold:
     * <ul>
     *   <li>this hop's dimensions are known;</li>
     *   <li>twice the output size exceeds the remote reduce memory budget;</li>
     *   <li>either the first input's column-block slice or the second input's row-block slice is
     *       below {@code MMCJMRReducerWithAggregator.MIN_CACHE_SIZE}.</li>
     * </ul>
     * Otherwise AGG is returned.
     */
    private MMCJ.MMCJType getMMCJAggregationType(Hop hop, Hop hop2) {
        if (!dimsKnown()) {
            return MMCJ.MMCJType.AGG;
        }
        final double doubledOutput = 2 * OptimizerUtils.estimateSize(getDim1(), getDim2());
        if (doubledOutput <= OptimizerUtils.getRemoteMemBudgetReduce()) {
            return MMCJ.MMCJType.AGG;
        }
        final double minCache = (double) MMCJMRReducerWithAggregator.MIN_CACHE_SIZE;
        final double leftSlice = (double) OptimizerUtils.estimateSize(hop.getDim1(), Math.min(hop.getDim2(), hop.getColsInBlock()));
        if (leftSlice < minCache) {
            return MMCJ.MMCJType.NO_AGG;
        }
        final double rightSlice = (double) OptimizerUtils.estimateSize(Math.min(hop2.getDim1(), hop2.getRowsInBlock()), hop2.getDim2());
        return rightSlice < minCache ? MMCJ.MMCJType.NO_AGG : MMCJ.MMCJType.AGG;
    }

    /**
     * Returns the Spark aggregation type for a matrix multiply: NONE if no aggregation is needed,
     * SINGLE_BLOCK if this hop's known output fits in one block, and MULTI_BLOCK otherwise.
     *
     * @param needsAgg whether the plan aggregates partial results at all
     */
    private SparkAggType getSparkMMAggregationType(boolean needsAgg) {
        if (!needsAgg) {
            return SparkAggType.NONE;
        }
        final boolean fitsOneBlock = dimsKnown()
                && getDim1() <= getRowsInBlock()
                && getDim2() <= getColsInBlock();
        return fitsOneBlock ? SparkAggType.SINGLE_BLOCK : SparkAggType.MULTI_BLOCK;
    }

    /**
     * Returns whether a map-side multiply needs a subsequent aggregation.
     *
     * <p>The answer is false for MAPMM_R when the left input's column count is known and fits one
     * column block, and false for MAPMM_L when the right input's row count is known and fits one row
     * block. It is true otherwise, which is the conservative default.
     */
    private boolean requiresAggregation(MMultMethod mMultMethod) {
        if (mMultMethod == MMultMethod.MAPMM_R) {
            final Hop left = getInput().get(0);
            return !(left.getDim2() >= 0 && left.getDim2() <= left.getColsInBlock());
        }
        if (mMultMethod == MMultMethod.MAPMM_L) {
            final Hop right = getInput().get(1);
            return !(right.getDim1() >= 0 && right.getDim1() <= right.getRowsInBlock());
        }
        return true;
    }

    /**
     * Returns whether the cached input of a map-side multiply must be partitioned.
     *
     * <p>The cached input is the right input for MAPMM_R and the left input for MAPMM_L. For those
     * methods, partitioning is needed when the cached input's dimensions are unknown or its cell
     * count exceeds {@link DistributedCacheInput#PARTITION_SIZE}. For any other method the answer
     * is true.
     *
     * @param mMultMethod the map-side method
     * @param unused not read by this method
     */
    private boolean requiresPartitioning(MMultMethod mMultMethod, boolean unused) {
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        final Hop cached;
        if (mMultMethod == MMultMethod.MAPMM_R) {
            cached = right;
        } else if (mMultMethod == MMultMethod.MAPMM_L) {
            cached = left;
        } else {
            return true;
        }
        if (!cached.dimsKnown()) {
            return true;
        }
        return cached.getDim1() * cached.getDim2() > DistributedCacheInput.PARTITION_SIZE;
    }

    /**
     * Estimates the per-task memory footprint of a map-side matrix multiply.
     *
     * <p>Each input is described by (rows, cols, rows-per-block, cols-per-block, nnz), as passed by
     * the callers in this class. The footprint is computed as follows:
     * <ul>
     *   <li>{@code pmm}: the partitioned left input plus three blocks of the right input.</li>
     *   <li>{@code cachedInputIndex == 1}: the partitioned left input plus one right block plus the
     *       output produced per right block ({@code m1Rows x min(m2Cols, m2Cpb)}).</li>
     *   <li>otherwise: one left block plus the partitioned right input plus the output produced per
     *       left block ({@code min(m1Rows, m1Rpb) x m2Cols}).</li>
     * </ul>
     */
    public static double getMapmmMemEstimate(long m1Rows, long m1Cols, long m1Rpb, long m1Cpb, long m1Nnz,
            long m2Rows, long m2Cols, long m2Rpb, long m2Cpb, long m2Nnz, int cachedInputIndex, boolean pmm) {
        final double m1Partitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1Rows, m1Cols, m1Rpb, m1Cpb, m1Nnz);
        final double m2Partitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2Rows, m2Cols, m2Rpb, m2Cpb, m2Nnz);
        final double m1Block = OptimizerUtils.estimateSize(Math.min(m1Rows, m1Rpb), Math.min(m1Cols, m1Cpb));
        final double m2Block = OptimizerUtils.estimateSize(Math.min(m2Rows, m2Rpb), Math.min(m2Cols, m2Cpb));

        if (pmm) {
            return m1Partitioned + (3.0d * m2Block);
        }
        if (cachedInputIndex == 1) {
            return m1Partitioned + m2Block + OptimizerUtils.estimateSize(m1Rows, Math.min(m2Cols, m2Cpb));
        }
        return m1Block + m2Partitioned + OptimizerUtils.estimateSize(Math.min(m1Rows, m1Rpb), m2Cols);
    }

    /**
     * Selects the matrix-multiply method for MR execution. The candidates are tried in this order:
     * <ol>
     *   <li>a forced method;</li>
     *   <li>TSMM;</li>
     *   <li>MAPMM_CHAIN;</li>
     *   <li>PMM;</li>
     *   <li>MAPMM_L / MAPMM_R;</li>
     *   <li>CPMM when any dimension is -1;</li>
     *   <li>otherwise the cheaper of CPMM and RMM by cost model.</li>
     * </ol>
     * Memory checks use the remote map memory budget.
     *
     * @param leftPM whether PMM is permitted (presumably: the left input is a permutation matrix)
     */
    private static MMultMethod optFindMMultMethodMR(long m1Rows, long m1Cols, long m1Rpb, long m1Cpb, long m1Nnz,
            long m2Rows, long m2Cols, long m2Rpb, long m2Cpb, long m2Nnz,
            MMTSJ.MMTSJType mMTSJType, MapMultChain.ChainType chainType, boolean leftPM) {
        final double memBudget = 1.0d * OptimizerUtils.getRemoteMemBudgetMap(true);
        if (FORCED_MMULT_METHOD != null) {
            return FORCED_MMULT_METHOD;
        }

        // TSMM when the relevant dimension fits within one block.
        final boolean tsmmLeft = mMTSJType == MMTSJ.MMTSJType.LEFT && m2Cols >= 0 && m2Cols <= m2Cpb;
        final boolean tsmmRight = mMTSJType == MMTSJ.MMTSJType.RIGHT && m1Rows >= 0 && m1Rows <= m1Rpb;
        if (tsmmLeft || tsmmRight) {
            return MMultMethod.TSMM;
        }

        // Fused chain when X's relevant dimension fits one block, the vector is a column vector, and
        // the chain's vectors fit the map budget.
        if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && chainType != MapMultChain.ChainType.NONE
                && m1Rows >= 0 && m1Rows <= m1Rpb && m2Cols == 1) {
            final boolean knownDims = m1Rows >= 0 && m2Cols >= 0;
            if (chainType == MapMultChain.ChainType.XtXv) {
                if (knownDims && OptimizerUtils.estimateSize(m1Rows, m2Cols) < memBudget) {
                    return MMultMethod.MAPMM_CHAIN;
                }
            } else if (chainType == MapMultChain.ChainType.XtwXv || chainType == MapMultChain.ChainType.XtXvy) {
                if (knownDims && m1Cols >= 0
                        && OptimizerUtils.estimateSize(m1Rows, m2Cols) + OptimizerUtils.estimateSize(m1Cols, m2Cols) < memBudget) {
                    return MMultMethod.MAPMM_CHAIN;
                }
            }
        }

        // PMM when permitted and the permutation vector side fits the map budget.
        final double pmFootprintLeft = getMapmmMemEstimate(m1Rows, 1L, m1Rpb, m1Cpb, m1Nnz, m2Rows, m2Cols, m2Rpb, m2Cpb, m2Nnz, 1, true);
        final double pmFootprintRight = getMapmmMemEstimate(m2Rows, 1L, m1Rpb, m1Cpb, m1Nnz, m2Rows, m2Cols, m2Rpb, m2Cpb, m2Nnz, 1, true);
        if (leftPM && ((pmFootprintLeft < memBudget && m1Rows >= 0) || (pmFootprintRight < memBudget && m2Rows >= 0))) {
            return MMultMethod.PMM;
        }

        // Map-side multiply when one input can be cached.
        final double m1Partitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1Rows, m1Cols, m1Rpb, m1Cpb, m1Nnz);
        final double m2Partitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2Rows, m2Cols, m2Rpb, m2Cpb, m2Nnz);
        final double leftCachedFootprint = getMapmmMemEstimate(m1Rows, m1Cols, m1Rpb, m1Cpb, m1Nnz, m2Rows, m2Cols, m2Rpb, m2Cpb, m2Nnz, 1, false);
        final double rightCachedFootprint = getMapmmMemEstimate(m1Rows, m1Cols, m1Rpb, m1Cpb, m1Nnz, m2Rows, m2Cols, m2Rpb, m2Cpb, m2Nnz, 2, false);
        final boolean leftFits = leftCachedFootprint < memBudget && m1Rows >= 0 && m1Cols >= 0;
        final boolean rightFits = rightCachedFootprint < memBudget && m2Rows >= 0 && m2Cols >= 0;
        if (leftFits || rightFits) {
            final boolean cacheRight = m1Partitioned >= m2Partitioned || m1Rows < 0 || m1Cols < 0;
            return cacheRight ? MMultMethod.MAPMM_R : MMultMethod.MAPMM_L;
        }

        if (m1Rows == -1 || m1Cols == -1 || m2Rows == -1 || m2Cols == -1) {
            return MMultMethod.CPMM;
        }
        final double cpmmCost = getCPMMCostEstimate(m1Rows, m1Cols, m1Rpb, m1Cpb, m2Rows, m2Cols, m2Rpb, m2Cpb);
        final double rmmCost = getRMMCostEstimate(m1Rows, m1Cols, m1Rpb, m1Cpb, m2Rows, m2Cols, m2Rpb, m2Cpb);
        return cpmmCost < rmmCost ? MMultMethod.CPMM : MMultMethod.RMM;
    }

    /**
     * Selects the matrix-multiply method for in-memory (CP) execution. The checks are made in this
     * order:
     * <ol>
     *   <li>TSMM for any transpose-self pattern;</li>
     *   <li>MAPMM_CHAIN for a chain pattern with a single-column vector, when sum-product rewrites
     *       are allowed;</li>
     *   <li>PMM when permitted, the left input has one column and the right input does not have
     *       exactly one row;</li>
     *   <li>otherwise MM.</li>
     * </ol>
     */
    private static MMultMethod optFindMMultMethodCP(long m1Rows, long m1Cols, long m2Rows, long m2Cols,
            MMTSJ.MMTSJType mMTSJType, MapMultChain.ChainType chainType, boolean leftPM) {
        if (mMTSJType != MMTSJ.MMTSJType.NONE) {
            return MMultMethod.TSMM;
        }
        if (chainType != MapMultChain.ChainType.NONE && OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && m2Cols == 1) {
            return MMultMethod.MAPMM_CHAIN;
        }
        if (leftPM && m1Cols == 1 && m2Rows != 1) {
            return MMultMethod.PMM;
        }
        return MMultMethod.MM;
    }

    /**
     * Selects the matrix-multiply method for Spark execution.
     *
     * <p>Parameter names are decompiler-generated. From their use below, {@code j..j5} presumably
     * describe the left input as (rows, cols, rows-per-block, cols-per-block, nnz), and
     * {@code j6..j10} describe the right input in the same order. TODO confirm against callers.
     * {@code z} gates PMM (presumably "left input is a permutation matrix"). {@code z2} gates ZIPMM.
     *
     * <p>Side effect: {@code _spBroadcastMemEstimate} is reset to 0 first. It is then set to the
     * estimated broadcast footprint whenever MAPMM_CHAIN (XtwXv/XtXvy only), PMM, MAPMM_L or
     * MAPMM_R is returned.
     *
     * @return a forced method if one is set; otherwise the first applicable of TSMM, MAPMM_CHAIN,
     *     PMM, MAPMM_L/MAPMM_R, TSMM2, CPMM (unknown dims), CPMM/RMM by cost, or ZIPMM
     */
    private MMultMethod optFindMMultMethodSpark(long j, long j2, long j3, long j4, long j5, long j6, long j7, long j8, long j9, long j10, MMTSJ.MMTSJType mMTSJType, MapMultChain.ChainType chainType, boolean z, boolean z2) {
        // Broadcast budget (multiplier 1.0) and local memory budget.
        double broadcastMemoryBudget = 1.0d * SparkExecutionContext.getBroadcastMemoryBudget();
        double localMemBudget = OptimizerUtils.getLocalMemBudget();
        this._spBroadcastMemEstimate = 0.0d;
        if (FORCED_MMULT_METHOD != null) {
            return FORCED_MMULT_METHOD;
        }
        // Step 1: TSMM when the relevant dimension fits within one block.
        if ((mMTSJType == MMTSJ.MMTSJType.LEFT && j7 >= 0 && j7 <= j9) || (mMTSJType == MMTSJ.MMTSJType.RIGHT && j >= 0 && j <= j3)) {
            return MMultMethod.TSMM;
        }
        // Step 2: fused chain. The broadcast vectors must fit the broadcast budget; for XtwXv/XtXvy
        // twice their size must also fit the local budget.
        if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && chainType != MapMultChain.ChainType.NONE && j >= 0 && j <= j3 && j7 == 1) {
            if (chainType == MapMultChain.ChainType.XtXv && j >= 0 && j7 >= 0 && OptimizerUtils.estimateSize(j, j7) < broadcastMemoryBudget) {
                return MMultMethod.MAPMM_CHAIN;
            }
            if ((chainType == MapMultChain.ChainType.XtwXv || chainType == MapMultChain.ChainType.XtXvy) && j >= 0 && j7 >= 0 && j2 >= 0 && OptimizerUtils.estimateSize(j, j7) + OptimizerUtils.estimateSize(j2, j7) < broadcastMemoryBudget && 2 * (OptimizerUtils.estimateSize(j, j7) + OptimizerUtils.estimateSize(j2, j7)) < localMemBudget) {
                this._spBroadcastMemEstimate = 2 * (OptimizerUtils.estimateSize(j, j7) + OptimizerUtils.estimateSize(j2, j7));
                return MMultMethod.MAPMM_CHAIN;
            }
        }
        // Step 3: PMM when permitted (z), the permutation side fits the broadcast budget, and twice
        // a j x 1 vector fits the local budget.
        double mapmmMemEstimate = getMapmmMemEstimate(j, 1L, j3, j4, j5, j6, j7, j8, j9, j10, 1, true);
        double mapmmMemEstimate2 = getMapmmMemEstimate(j6, 1L, j3, j4, j5, j6, j7, j8, j9, j10, 1, true);
        if (((mapmmMemEstimate < broadcastMemoryBudget && j >= 0) || (mapmmMemEstimate2 < broadcastMemoryBudget && j6 >= 0)) && 2 * OptimizerUtils.estimateSize(j, 1L) < localMemBudget && z) {
            this._spBroadcastMemEstimate = 2 * OptimizerUtils.estimateSize(j, 1L);
            return MMultMethod.PMM;
        }
        // Step 4: broadcast map-side multiply. The broadcast input's exact plus partitioned size
        // must fit the local budget. MAPMM_L is preferred when the left input's partitioned size is
        // smaller; each choice also requires valid CP dimensions. If neither qualifies, fall through.
        double estimateSizeExactSparsity = OptimizerUtils.estimateSizeExactSparsity(j, j2, j5);
        double estimateSizeExactSparsity2 = OptimizerUtils.estimateSizeExactSparsity(j6, j7, j10);
        double estimatePartitionedSizeExactSparsity = OptimizerUtils.estimatePartitionedSizeExactSparsity(j, j2, j3, j4, j5);
        double estimatePartitionedSizeExactSparsity2 = OptimizerUtils.estimatePartitionedSizeExactSparsity(j6, j7, j8, j9, j10);
        double mapmmMemEstimate3 = getMapmmMemEstimate(j, j2, j3, j4, j5, j6, j7, j8, j9, j10, 1, false);
        double mapmmMemEstimate4 = getMapmmMemEstimate(j, j2, j3, j4, j5, j6, j7, j8, j9, j10, 2, false);
        if ((mapmmMemEstimate3 < broadcastMemoryBudget && estimateSizeExactSparsity + estimatePartitionedSizeExactSparsity < localMemBudget && j >= 0 && j2 >= 0) || (mapmmMemEstimate4 < broadcastMemoryBudget && estimateSizeExactSparsity2 + estimatePartitionedSizeExactSparsity2 < localMemBudget && j6 >= 0 && j7 >= 0)) {
            if (estimatePartitionedSizeExactSparsity < estimatePartitionedSizeExactSparsity2 && j >= 0 && j2 >= 0 && OptimizerUtils.isValidCPDimensions(j, j2)) {
                this._spBroadcastMemEstimate = estimateSizeExactSparsity + estimatePartitionedSizeExactSparsity;
                return MMultMethod.MAPMM_L;
            }
            if (OptimizerUtils.isValidCPDimensions(j6, j7)) {
                this._spBroadcastMemEstimate = estimateSizeExactSparsity2 + estimatePartitionedSizeExactSparsity2;
                return MMultMethod.MAPMM_R;
            }
        }
        // Step 5: TSMM2 for a transpose-self pattern spanning at most two blocks in the relevant
        // dimension. The broadcast remainder (j7 - j9 columns or j - j3 rows, dense) must fit the
        // broadcast and local budgets and stay below 2.147483648E9 (2^31); presumably a broadcast
        // size limit, verify.
        if (mMTSJType != MMTSJ.MMTSJType.NONE && j >= 0 && j2 >= 0 && j6 >= 0 && j7 >= 0) {
            double estimateSizeExactSparsity3 = mMTSJType == MMTSJ.MMTSJType.LEFT ? OptimizerUtils.estimateSizeExactSparsity(j6, j7 - j9, 1.0d) : OptimizerUtils.estimateSizeExactSparsity(j - j3, j2, 1.0d);
            double estimatePartitionedSizeExactSparsity3 = mMTSJType == MMTSJ.MMTSJType.LEFT ? OptimizerUtils.estimatePartitionedSizeExactSparsity(j6, j7 - j9, j8, j9, 1.0d) : OptimizerUtils.estimatePartitionedSizeExactSparsity(j - j3, j2, j3, j4, 1.0d);
            if (estimatePartitionedSizeExactSparsity3 < broadcastMemoryBudget && estimateSizeExactSparsity3 + estimatePartitionedSizeExactSparsity3 < localMemBudget && (mMTSJType != MMTSJ.MMTSJType.LEFT ? j <= 2 * j3 : j7 <= 2 * j9) && estimatePartitionedSizeExactSparsity3 < 2.147483648E9d) {
                return MMultMethod.TSMM2;
            }
        }
        // Step 6: any unknown (-1) dimension -> CPMM.
        if (j == -1 || j2 == -1 || j6 == -1 || j7 == -1) {
            return MMultMethod.CPMM;
        }
        // Step 7: ZIPMM only if z2 holds, j fits one row block and j7 fits one column block;
        // otherwise the cheaper of CPMM and RMM by cost model.
        if (!z2 || j < 0 || j > j3 || j7 < 0 || j7 > j9) {
            return getCPMMCostEstimate(j, j2, j3, j4, j6, j7, j8, j9) < getRMMCostEstimate(j, j2, j3, j4, j6, j7, j8, j9) ? MMultMethod.CPMM : MMultMethod.RMM;
        }
        return MMultMethod.ZIPMM;
    }

    /**
     * Estimates the cost of a replication-based matrix multiply (RMM) as
     * {@code (shuffle + io) / reducers}.
     *
     * <ul>
     *   <li>shuffle = {@code m2ColBlocks * |m1| + m1RowBlocks * |m2|};</li>
     *   <li>io = {@code |m1| + |m2| + |result|};</li>
     *   <li>reducers = {@code min(m1RowBlocks * m2ColBlocks, available reducers)}.</li>
     * </ul>
     * Sizes are cell counts.
     *
     * <p>Block counts use floating-point division so that a partial trailing block counts as one
     * block. Long division would truncate them, e.g. to 0 for an input smaller than one block, and
     * then divide by zero. The reducer count is floored at 1 for the same reason. Cell counts are
     * multiplied in double to avoid long overflow.
     *
     * @param m1Rows rows of the left input
     * @param m1Cols columns of the left input
     * @param m1Rpb rows per block of the left input
     * @param m1Cpb columns per block of the left input (unused)
     * @param m2Rows rows of the right input
     * @param m2Cols columns of the right input
     * @param m2Rpb rows per block of the right input (unused)
     * @param m2Cpb columns per block of the right input
     * @return the estimated RMM cost (comparable only with {@code getCPMMCostEstimate})
     */
    private static double getRMMCostEstimate(long m1Rows, long m1Cols, long m1Rpb, long m1Cpb,
            long m2Rows, long m2Cols, long m2Rpb, long m2Cpb) {
        final long m1RowBlocks = (long) Math.ceil((double) m1Rows / m1Rpb);
        final long m2ColBlocks = (long) Math.ceil((double) m2Cols / m2Cpb);

        final double m1Size = (double) m1Rows * m1Cols;
        final double m2Size = (double) m2Rows * m2Cols;
        final double resultSize = (double) m1Rows * m2Cols;

        final double shuffle = (m2ColBlocks * m1Size) + (m1RowBlocks * m2Size);
        final double io = m1Size + m2Size + resultSize;
        final double reducers = Math.max(1L, Math.min(m1RowBlocks * m2ColBlocks, OptimizerUtils.getNumReducers(true)));
        return (shuffle + io) / reducers;
    }

    /**
     * Estimates the cost of a cross-product matrix multiply (CPMM) as the sum of two phases.
     *
     * <ul>
     *   <li>Phase 1 (join on the common dimension): shuffle = {@code |m1| + |m2|};
     *       io = {@code |m1| + |m2| + r1 * |result|}; r1 = {@code min(m1ColBlocks, reducers)}.</li>
     *   <li>Phase 2 (aggregation): shuffle = {@code r1 * |result|};
     *       io = {@code r1 * |result| + |result|};
     *       r2 = {@code min(m1RowBlocks * m2ColBlocks, reducers)}.</li>
     * </ul>
     * Cost = {@code (shuffle1 + io1) / r1 + (shuffle2 + io2) / r2}. Sizes are cell counts.
     *
     * <p>Block counts use floating-point division. Long division truncated them, so
     * {@code m1ColBlocks} became 0 whenever {@code m1Cols < m1Cpb}, making r1 zero and the CPMM
     * cost infinite. Reducer counts are floored at 1. Cell counts are multiplied in double to avoid
     * long overflow.
     *
     * @param m1Rows rows of the left input
     * @param m1Cols columns of the left input
     * @param m1Rpb rows per block of the left input
     * @param m1Cpb columns per block of the left input
     * @param m2Rows rows of the right input
     * @param m2Cols columns of the right input
     * @param m2Rpb rows per block of the right input (unused)
     * @param m2Cpb columns per block of the right input
     * @return the estimated CPMM cost (comparable only with {@code getRMMCostEstimate})
     */
    private static double getCPMMCostEstimate(long m1Rows, long m1Cols, long m1Rpb, long m1Cpb,
            long m2Rows, long m2Cols, long m2Rpb, long m2Cpb) {
        final long m1RowBlocks = (long) Math.ceil((double) m1Rows / m1Rpb);
        final long m1ColBlocks = (long) Math.ceil((double) m1Cols / m1Cpb);
        final long m2ColBlocks = (long) Math.ceil((double) m2Cols / m2Cpb);

        final double m1Size = (double) m1Rows * m1Cols;
        final double m2Size = (double) m2Rows * m2Cols;
        final double resultSize = (double) m1Rows * m2Cols;
        final long availableReducers = OptimizerUtils.getNumReducers(false);

        // Phase 1: cross-product join, parallel over column blocks of the left input.
        final double reducers1 = Math.max(1L, Math.min(m1ColBlocks, availableReducers));
        final double shuffle1 = m1Size + m2Size;
        final double io1 = m1Size + m2Size + reducers1 * resultSize;

        // Phase 2: aggregation of partial results, parallel over output blocks.
        final double reducers2 = Math.max(1L, Math.min(m1RowBlocks * m2ColBlocks, availableReducers));
        final double shuffle2 = reducers1 * resultSize;
        final double io2 = reducers1 * resultSize + resultSize;

        return (shuffle1 + io1) / reducers1 + (shuffle2 + io2) / reducers2;
    }

    /**
     * Recomputes this hop's output dimensions for a matrix multiply: rows come from the left input
     * and columns from the right input. For other operations the dimensions are left unchanged.
     */
    @Override // org.apache.sysml.hops.Hop
    public void refreshSizeInformation() {
        final Hop left = getInput().get(0);
        final Hop right = getInput().get(1);
        if (!isMatrixMultiply()) {
            return;
        }
        setDim1(left.getDim1());
        setDim2(right.getDim2());
    }

    /**
     * Creates a copy of this hop. Generic hop attributes are copied via {@code clone(this, false)};
     * the inner/outer operators, the left-permutation-input flag and the thread limit are copied
     * explicitly.
     */
    @Override // org.apache.sysml.hops.Hop
    public Object clone() throws CloneNotSupportedException {
        final AggBinaryOp copy = new AggBinaryOp();
        copy.clone(this, false);
        copy.innerOp = innerOp;
        copy.outerOp = outerOp;
        copy._hasLeftPMInput = _hasLeftPMInput;
        copy._maxNumThreads = _maxNumThreads;
        return copy;
    }

    /**
     * Returns true if {@code hop} is an AggBinaryOp with the same inner/outer operators, the
     * identical (reference-equal) two inputs, the same left-permutation-input flag and the same
     * thread limit.
     */
    @Override // org.apache.sysml.hops.Hop
    public boolean compare(Hop hop) {
        if (!(hop instanceof AggBinaryOp)) {
            return false;
        }
        final AggBinaryOp other = (AggBinaryOp) hop;
        if (innerOp != other.innerOp || outerOp != other.outerOp) {
            return false;
        }
        if (getInput().get(0) != other.getInput().get(0) || getInput().get(1) != other.getInput().get(1)) {
            return false;
        }
        return _hasLeftPMInput == other._hasLeftPMInput && _maxNumThreads == other._maxNumThreads;
    }
}
