package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.lops.BinaryM;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixScalarUnaryFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction;
import org.apache.sysml.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/BinarySPInstruction.class */
public abstract class BinarySPInstruction extends ComputationSPInstruction {

    /**
     * Creates a binary Spark instruction with two operands and a result operand.
     * All arguments are forwarded unchanged to the {@code ComputationSPInstruction} constructor.
     *
     * @param op     operator applied by this instruction
     * @param in1    first operand
     * @param in2    second operand
     * @param out    result operand
     * @param opcode presumably the instruction opcode (verify against callers)
     * @param istr   presumably the full instruction string (verify against callers)
     */
    public BinarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
    }

    /**
     * Creates a binary Spark instruction with four operands.
     * All arguments are forwarded unchanged to the {@code ComputationSPInstruction} constructor.
     *
     * @param op     operator applied by this instruction
     * @param in1    first operand
     * @param in2    second operand
     * @param in3    third operand (presumably an additional input -- confirm against the super constructor)
     * @param out    fourth operand (presumably the result -- confirm against the super constructor)
     * @param opcode presumably the instruction opcode (verify against callers)
     * @param istr   presumably the full instruction string (verify against callers)
     */
    public BinarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
            String opcode, String istr) {
        super(op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Parses an instruction string with exactly three operand fields into the given operands.
     *
     * @param instr instruction string, split via {@code InstructionUtils.getInstructionPartsWithValueType}
     * @param in1   receives field 1
     * @param in2   receives field 2
     * @param out   receives field 3
     * @return field 0 of the instruction (the leading token, presumably the opcode)
     * @throws DMLRuntimeException if the field count check fails
     */
    public static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out)
            throws DMLRuntimeException {
        return parseOperands(instr, in1, in2, out);
    }

    /**
     * Parses an instruction string with exactly four operand fields into the given operands.
     *
     * @param instr instruction string, split via {@code InstructionUtils.getInstructionPartsWithValueType}
     * @param in1   receives field 1
     * @param in2   receives field 2
     * @param in3   receives field 3
     * @param out   receives field 4
     * @return field 0 of the instruction (the leading token, presumably the opcode)
     * @throws DMLRuntimeException if the field count check fails
     */
    public static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3,
            CPOperand out) throws DMLRuntimeException {
        return parseOperands(instr, in1, in2, in3, out);
    }

    /**
     * Shared parser for the {@code parseBinaryInstruction} overloads.
     * Checks that the instruction has exactly {@code operands.length} fields after field 0.
     * It then splits field {@code i+1} into {@code operands[i]} and returns field 0.
     */
    private static String parseOperands(String instr, CPOperand... operands) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, operands.length);
        for (int i = 0; i < operands.length; i++) {
            operands[i].split(parts[i + 1]);
        }
        return parts[0];
    }

    /**
     * Executes a matrix-matrix binary operation by joining the two blocked input RDDs on their block indexes.
     *
     * <p>If one side is a vector that must be broadcast across the other side's blocks, its blocks are
     * first replicated with {@link ReplicateVectorFunction}. The replica count comes from
     * {@link #getNumReplicas}, so every matrix block finds a join partner. The result RDD is registered
     * as the output variable, with lineage to both inputs.
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the input characteristics are unknown or incompatible
     */
    public void processMatrixMatrixBinaryInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        checkMatrixMatrixBinaryCharacteristics(sec);

        String rddVar1 = input1.getName();
        String rddVar2 = input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar1);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable(rddVar2);
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(rddVar1);
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(rddVar2);
        BinaryOperator bop = (BinaryOperator) _optr;

        // True when the right input is a single row applied against a multi-row left input.
        boolean rightIsRowVector = mc2.getRows() == 1 && mc1.getRows() > 1;
        long numRepLeft = getNumReplicas(mc1, mc2, true);
        long numRepRight = getNumReplicas(mc1, mc2, false);
        if (numRepLeft > 1) {
            in1 = in1.flatMapToPair(new ReplicateVectorFunction(false, numRepLeft));
        }
        if (numRepRight > 1) {
            in2 = in2.flatMapToPair(new ReplicateVectorFunction(rightIsRowVector, numRepRight));
        }

        JavaPairRDD<MatrixIndexes, ?> out = in1.join(in2).mapValues(new MatrixMatrixBinaryOpFunction(bop));

        updateBinaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar1);
        sec.addLineageRDD(output.getName(), rddVar2);
    }

    /**
     * Executes a matrix-vector binary operation where the second input is read as a partitioned broadcast.
     *
     * <p>When the left input is a column vector (more than one row, one column) and the right input is a
     * row vector (one row, more than one column), the outer-operation path
     * ({@link OuterVectorBinaryOpFunction}) is used. Otherwise each partition of the left RDD is processed
     * with {@link MatrixVectorBinaryOpPartitionFunction}, preserving partitioning.
     *
     * @param ec    execution context; must be a {@link SparkExecutionContext}
     * @param vtype vector type forwarded to the partition function
     * @throws DMLRuntimeException if the input characteristics are unknown or incompatible
     */
    public void processMatrixBVectorBinaryInstruction(ExecutionContext ec, BinaryM.VectorType vtype)
            throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        checkMatrixMatrixBinaryCharacteristics(sec);

        String rddVar = input1.getName();
        String bcastVar = input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
        PartitionedBroadcastMatrix in2 = sec.getBroadcastForVariable(bcastVar);
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(rddVar);
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(bcastVar);
        BinaryOperator bop = (BinaryOperator) _optr;

        boolean outer = mc1.getRows() > 1 && mc1.getCols() == 1
                && mc2.getRows() == 1 && mc2.getCols() > 1;

        JavaPairRDD<MatrixIndexes, ?> out;
        if (outer) {
            out = in1.flatMapToPair(new OuterVectorBinaryOpFunction(bop, in2));
        } else {
            out = in1.mapPartitionsToPair(new MatrixVectorBinaryOpPartitionFunction(bop, in2, vtype), true);
        }

        updateBinaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
        sec.addLineageBroadcast(output.getName(), bcastVar);
    }

    /**
     * Executes a matrix-scalar binary operation.
     *
     * <p>Whichever input has data type MATRIX is read as a blocked RDD; the other input is read as a scalar.
     * Its double value is set as the constant of this instruction's {@link ScalarOperator}. Note that this
     * mutates {@code _optr}. The operator is then applied to every block via {@link MatrixScalarUnaryFunction}.
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the inputs cannot be read
     */
    public void processMatrixScalarBinaryInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean matrixFirst = input1.getDataType() == Expression.DataType.MATRIX;
        String rddVar = matrixFirst ? input1.getName() : input2.getName();
        CPOperand scalarOperand = matrixFirst ? input2 : input1;

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
        ScalarObject constant = ec.getScalarInput(
                scalarOperand.getName(), scalarOperand.getValueType(), scalarOperand.isLiteral());

        ScalarOperator sop = (ScalarOperator) _optr;
        sop.setConstant(constant.getDoubleValue());

        JavaPairRDD<MatrixIndexes, ?> out = in1.mapValues(new MatrixScalarUnaryFunction(sop));

        updateUnaryOutputMatrixCharacteristics(sec, rddVar, output.getName());
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
    }

    /**
     * Infers the output characteristics of a matrix-multiply-shaped operation, unless they are already known.
     * The output becomes {@code rows(in1) x cols(in2)} with the block sizes of the first input.
     *
     * @param sec            Spark execution context holding the matrix characteristics
     * @param checkCommonDim if true, require {@code cols(in1) == rows(in2)}
     * @throws DMLRuntimeException if the input dims are unknown, block sizes differ, or the common
     *                             dimension mismatches (when checked)
     */
    public void updateBinaryMMOutputMatrixCharacteristics(SparkExecutionContext sec, boolean checkCommonDim)
            throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
        if (mcOut.dimsKnown()) {
            return;
        }
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
        }
        if (mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
            throw new DMLRuntimeException("Incompatible block sizes for BinarySPInstruction.");
        }
        if (checkCommonDim && mc1.getCols() != mc2.getRows()) {
            throw new DMLRuntimeException("Incompatible dimensions for BinarySPInstruction");
        }
        mcOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
    }

    /**
     * Infers the output characteristics of an append (cbind/rbind) when unknown.
     * The output nnz is also filled in from the two inputs when both are known.
     *
     * @param sec   Spark execution context holding the matrix characteristics
     * @param cbind true to add column counts (cbind), false to add row counts (rbind)
     * @throws DMLRuntimeException if the output dims are unknown and cannot be inferred from the inputs
     */
    public void updateBinaryAppendOutputMatrixCharacteristics(SparkExecutionContext sec, boolean cbind)
            throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());

        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
                throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
            }
            long rows = cbind ? mc1.getRows() : mc1.getRows() + mc2.getRows();
            long cols = cbind ? mc1.getCols() + mc2.getCols() : mc1.getCols();
            mcOut.set(rows, cols, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        }

        // Append never merges cells, so the output nnz is the sum of the input nnz.
        if (!mcOut.nnzKnown() && mc1.nnzKnown() && mc2.nnzKnown()) {
            mcOut.setNonZeros(mc1.getNonZeros() + mc2.getNonZeros());
        }
    }

    /**
     * Returns how many times one input's blocks must be replicated so that a vector operand joins with
     * every block of the other operand.
     *
     * <p>Left side ({@code left == true}): if the left input has a single column, it is replicated once per
     * column block of the right input. Right side: a row vector against a multi-row left input is
     * replicated once per row block of the left input. A column vector against a multi-column left input is
     * replicated once per column block. In all other cases the result is 1.
     *
     * <p>Block counts use ceiling division. The previous {@code Math.ceil(long / int)} floored the count,
     * which under-replicated partially filled trailing blocks and dropped them from the join.
     *
     * @param mc1  characteristics of the left input
     * @param mc2  characteristics of the right input
     * @param left whether to compute the replica count for the left (true) or right (false) input
     * @return number of replicas (1 means no replication needed)
     */
    protected long getNumReplicas(MatrixCharacteristics mc1, MatrixCharacteristics mc2, boolean left) {
        if (left) {
            return (mc1.getCols() == 1) ? ceilDiv(mc2.getCols(), mc2.getColsPerBlock()) : 1L;
        }
        if (mc2.getRows() == 1 && mc1.getRows() > 1) {
            return ceilDiv(mc1.getRows(), mc1.getRowsPerBlock());
        }
        if (mc2.getCols() == 1 && mc1.getCols() > 1) {
            return ceilDiv(mc1.getCols(), mc1.getColsPerBlock());
        }
        return 1L;
    }

    /**
     * Exact integer ceiling of {@code len / blen}.
     * Like plain integer division, it throws {@link ArithmeticException} when {@code blen == 0}.
     */
    private static long ceilDiv(long len, long blen) {
        long q = len / blen;
        // Java integer division truncates toward zero, so round up only when there is a remainder
        // and the true quotient is positive (operands have the same sign).
        return (len % blen != 0 && (len ^ blen) >= 0) ? q + 1 : q;
    }

    /**
     * Validates that the two inputs of a matrix-matrix binary operation have known, compatible dimensions
     * and identical block sizes.
     *
     * <p>Accepted shapes (left vs. right):
     * <ul>
     *   <li>identical dimensions;</li>
     *   <li>same row count with a right column vector;</li>
     *   <li>same column count with a right row vector;</li>
     *   <li>a left column vector with a right row vector (outer operation).</li>
     * </ul>
     *
     * @param sec Spark execution context holding the matrix characteristics
     * @throws DMLRuntimeException on unknown dimensions, dimension mismatch, or block size mismatch
     */
    protected void checkMatrixMatrixBinaryCharacteristics(SparkExecutionContext sec) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());

        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("Unknown dimensions matrix-matrix binary operations: "
                    + dimsPair(mc1.getRows(), mc1.getCols(), mc2.getRows(), mc2.getCols()));
        }

        boolean sameDims = mc1.getRows() == mc2.getRows() && mc1.getCols() == mc2.getCols();
        boolean rightColVector = mc1.getRows() == mc2.getRows() && mc2.getCols() == 1;
        boolean rightRowVector = mc1.getCols() == mc2.getCols() && mc2.getRows() == 1;
        boolean outer = mc1.getCols() == 1 && mc2.getRows() == 1;
        if (!(sameDims || rightColVector || rightRowVector || outer)) {
            throw new DMLRuntimeException("Dimensions mismatch matrix-matrix binary operations: "
                    + dimsPair(mc1.getRows(), mc1.getCols(), mc2.getRows(), mc2.getCols()));
        }

        if (mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
            throw new DMLRuntimeException("Blocksize mismatch matrix-matrix binary operations: "
                    + dimsPair(mc1.getRowsPerBlock(), mc1.getColsPerBlock(),
                               mc2.getRowsPerBlock(), mc2.getColsPerBlock()));
        }
    }

    /** Formats two (rows, cols) pairs as {@code "[r1xc1 vs r2xc2]"} for error messages. */
    private static String dimsPair(long r1, long c1, long r2, long c2) {
        return "[" + r1 + "x" + c1 + " vs " + r2 + "x" + c2 + "]";
    }

    /**
     * Validates the inputs of an append (cbind/rbind) instruction.
     *
     * @param sec            Spark execution context holding the matrix characteristics
     * @param cbind          true for cbind (row counts must match), false for rbind (column counts must match)
     * @param checkSingleBlk if true, require {@code cols(in1) + cols(in2)} to fit into one column block
     * @param checkAligned   if true, require the first input to end on a block boundary along the append
     *                       dimension: columns for cbind, rows for rbind
     * @throws DMLRuntimeException if any enabled check fails
     */
    public void checkBinaryAppendInputCharacteristics(SparkExecutionContext sec, boolean cbind,
            boolean checkSingleBlk, boolean checkAligned) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(input2.getName());

        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("The dimensions unknown for inputs");
        }
        if (cbind && mc1.getRows() != mc2.getRows()) {
            throw new DMLRuntimeException("The number of rows of inputs should match for append-cbind instruction");
        }
        if (!cbind && mc1.getCols() != mc2.getCols()) {
            throw new DMLRuntimeException("The number of columns of inputs should match for append-rbind instruction");
        }
        if (mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
            throw new DMLRuntimeException("The block sizes donot match for input matrices");
        }
        if (checkSingleBlk && mc1.getCols() + mc2.getCols() > mc1.getColsPerBlock()) {
            throw new DMLRuntimeException("Output must have at most one column block");
        }
        if (checkAligned) {
            // Alignment must be checked along the dimension being appended; the previous code always
            // checked columns, which is meaningless for rbind.
            long appendDimLen = cbind ? mc1.getCols() : mc1.getRows();
            long appendDimBlen = cbind ? mc1.getColsPerBlock() : mc1.getRowsPerBlock();
            if (appendDimLen % appendDimBlen != 0) {
                throw new DMLRuntimeException("Input matrices are not aligned to blocksize boundaries. Wrong append selected");
            }
        }
    }
}
