package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/BuiltinNarySPInstruction.class */
/**
 * Spark instruction for the n-ary builtins {@code cbind} and {@code rbind}.
 * <p>
 * The inputs are consumed in order. Each input RDD is shifted (via
 * {@link AppendGSPInstruction.ShiftMatrix}) relative to an offset that accumulates the dimensions
 * of all preceding inputs. It is then padded to the output block geometry with
 * {@link PadBlocksFunction}. Finally all inputs are unioned and merged by block index with
 * {@link RDDAggregateUtils#mergeByKey}.
 */
public class BuiltinNarySPInstruction extends SPInstruction {
    private final CPOperand[] inputs;
    private final CPOperand output;

    /**
     * Pads a (shifted) input block with empty rows and/or columns so that its dimensions match the
     * dimensions of the block at the same index in the output matrix.
     */
    public static class PadBlocksFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1291358959908299855L;
        private final MatrixCharacteristics _mcOut;

        /**
         * @param mcOut characteristics of the output matrix; they define the target block sizes
         */
        public PadBlocksFunction(MatrixCharacteristics mcOut) {
            this._mcOut = mcOut;
        }

        /**
         * Returns the input tuple unchanged if the block already has the expected output size.
         * Otherwise returns a new tuple with the same index and a zero-padded block.
         */
        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> in) throws Exception {
            MatrixIndexes ix = in._1();
            MatrixBlock mb = in._2();
            int brlen = UtilFunctions.computeBlockSize(_mcOut.getRows(), ix.getRowIndex(), _mcOut.getRowsPerBlock());
            int bclen = UtilFunctions.computeBlockSize(_mcOut.getCols(), ix.getColumnIndex(), _mcOut.getColsPerBlock());

            // pass-through: block already has the output block geometry
            if (brlen == mb.getNumRows() && bclen == mb.getNumColumns()) {
                return in;
            }

            // Pad columns first (cbind an empty block with the block's current row count)...
            if (bclen > mb.getNumColumns()) {
                mb = mb.append(new MatrixBlock(mb.getNumRows(), bclen - mb.getNumColumns(), true), new MatrixBlock(), true);
            }
            // ...then rows (rbind an empty block), so that the padding block always matches the
            // already-widened column count. This also handles blocks that need padding in both
            // dimensions, which a single either/or padding step could not.
            if (brlen > mb.getNumRows()) {
                mb = mb.append(new MatrixBlock(brlen - mb.getNumRows(), mb.getNumColumns(), true), new MatrixBlock(), false);
            }
            return new Tuple2<>(ix, mb);
        }
    }

    protected BuiltinNarySPInstruction(CPOperand[] inputs, CPOperand output, String opcode, String istr) {
        super(SPInstruction.SPType.BuiltinNary, opcode, istr);
        this.inputs = inputs;
        this.output = output;
    }

    /**
     * Parses an instruction string of the form {@code opcode, in_1, ..., in_k, out} with k >= 1.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the instruction does not contain at least one input and an output
     */
    public static BuiltinNarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // need opcode + at least one input + output; fewer parts would otherwise surface later as
        // NegativeArraySizeException / ArrayIndexOutOfBoundsException
        if (parts.length < 3) {
            throw new DMLRuntimeException("Invalid number of operands in BuiltinNary instruction: " + str);
        }
        String opcode = parts[0];
        CPOperand out = new CPOperand(parts[parts.length - 1]);
        CPOperand[] ins = new CPOperand[parts.length - 2];
        for (int i = 1; i < parts.length - 1; i++) {
            ins[i - 1] = new CPOperand(parts[i]);
        }
        return new BuiltinNarySPInstruction(ins, out, opcode, str);
    }

    /**
     * Builds the output RDD as the union of all shifted and padded inputs, merged by block index.
     * It then registers the output RDD, its characteristics, and lineage to every input.
     */
    @Override // org.apache.sysml.runtime.instructions.spark.SPInstruction, org.apache.sysml.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean cbind = getOpcode().equals("cbind");

        // characteristics of the final output (dims, block sizes, nnz if known)
        MatrixCharacteristics mcOut = computeOutputMatrixCharacteristics(sec, inputs, cbind);

        // running offset: accumulates the dims of the inputs processed so far
        MatrixCharacteristics off = new MatrixCharacteristics(0L, 0L, mcOut.getRowsPerBlock(), mcOut.getColsPerBlock(), 0L);
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
        for (CPOperand input : inputs) {
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName());
            JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec
                .getBinaryBlockRDDHandleForVariable(input.getName())
                .flatMapToPair(new AppendGSPInstruction.ShiftMatrix(off, mcIn, cbind))
                .mapToPair(new PadBlocksFunction(mcOut));
            out = (out != null) ? out.union(in) : in;
            // advance the offset only after ShiftMatrix was built for this input
            updateMatrixCharacteristics(mcIn, off, cbind);
        }

        // combine partially overlapping blocks produced by different inputs
        out = RDDAggregateUtils.mergeByKey(out, SparkUtils.getNumPreferredPartitions(mcOut), false);

        // set output characteristics and RDD, and record lineage to all inputs
        sec.getMatrixCharacteristics(output.getName()).set(mcOut);
        sec.setRDDHandleForVariable(output.getName(), out);
        for (CPOperand input : inputs) {
            sec.addLineageRDD(output.getName(), input.getName());
        }
    }

    /**
     * Computes the output characteristics by folding all inputs into an empty characteristics
     * object. The block sizes are taken from the first input.
     */
    private static MatrixCharacteristics computeOutputMatrixCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) throws DMLRuntimeException {
        MatrixCharacteristics mc0 = sec.getMatrixCharacteristics(inputs[0].getName());
        MatrixCharacteristics mcOut = new MatrixCharacteristics(0L, 0L, mc0.getRowsPerBlock(), mc0.getColsPerBlock(), 0L);
        for (CPOperand input : inputs) {
            updateMatrixCharacteristics(sec.getMatrixCharacteristics(input.getName()), mcOut, cbind);
        }
        return mcOut;
    }

    /**
     * Accumulates {@code in} into {@code out} in place.
     * <ul>
     *   <li>cbind: rows become the max of both, columns are summed.</li>
     *   <li>rbind: rows are summed, columns become the max of both.</li>
     *   <li>nnz is summed, or set to -1 (unknown) once any contributor has unknown dims or nnz.</li>
     * </ul>
     */
    private static void updateMatrixCharacteristics(MatrixCharacteristics in, MatrixCharacteristics out, boolean cbind) {
        long rows = cbind ? Math.max(out.getRows(), in.getRows()) : out.getRows() + in.getRows();
        long cols = cbind ? out.getCols() + in.getCols() : Math.max(out.getCols(), in.getCols());
        out.setDimension(rows, cols);
        out.setNonZeros((out.getNonZeros() == -1 || !in.dimsKnown(true)) ? -1L : out.getNonZeros() + in.getNonZeros());
    }
}
