package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.class */
/**
 * Spark instruction for broadcast-based matrix multiplication ({@code mapmm}).
 * <p>
 * One operand is distributed as an RDD of blocks, the other is read from a
 * {@link PartitionedBroadcastMatrix}. With {@link MapMult.CacheType#LEFT} the broadcast is the
 * left input ({@code input1}) and the RDD is the right input ({@code input2}); every other cache
 * type is handled as the right-broadcast case (broadcast = {@code input2}, RDD = {@code input1}).
 */
public class MapmmSPInstruction extends BinarySPInstruction {
    /** Which operand is broadcast; any value other than LEFT is treated as right-broadcast. */
    private final MapMult.CacheType _type;
    /** If false, all-empty input blocks and all-empty output blocks are filtered out. */
    private final boolean _outputEmpty;
    /** How partial products are aggregated: into one block, by output key, or not at all. */
    private final AggBinaryOp.SparkAggType _aggtype;

    /**
     * Creates the sum-of-products operator (multiply, then aggregate with plus starting at 0)
     * shared by the instruction and all its Spark functions.
     */
    private static AggregateBinaryOperator createMatMultOperator() {
        return new AggregateBinaryOperator(Multiply.getMultiplyFnObject(),
                new AggregateOperator(0.0d, Plus.getPlusFnObject()));
    }

    /**
     * Multiplies one RDD block with every broadcast block it pairs with, emitting one partial
     * product per block along the broadcast's outer dimension. Used when the broadcast spans
     * more than one block in that dimension.
     */
    private static class RDDFlatMapMMFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -6076256569118957281L;

        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcastMatrix _pbc;

        public RDDFlatMapMMFunction(MapMult.CacheType type, PartitionedBroadcastMatrix binput) {
            _type = type;
            _pbc = binput;
            _op = createMatMultOperator();
        }

        @Override
        public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();

            if (_type == MapMult.CacheType.LEFT) {
                // RDD block (r,c) of the right operand pairs with left blocks (i,r) for all i
                int numBlocks = _pbc.getNumRowBlocks();
                ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(numBlocks);
                for (int i = 1; i <= numBlocks; i++) {
                    MatrixBlock left = _pbc.getMatrixBlock(i, (int) ixIn.getRowIndex());
                    MatrixIndexes ixOut = new MatrixIndexes();
                    MatrixBlock blkOut = new MatrixBlock();
                    OperationsOnMatrixValues.performAggregateBinary(
                            new MatrixIndexes(i, ixIn.getRowIndex()), left, ixIn, blkIn, ixOut, blkOut, _op);
                    ret.add(new Tuple2<>(ixOut, blkOut));
                }
                return ret;
            }

            // RDD block (r,c) of the left operand pairs with right blocks (c,j) for all j
            int numBlocks = _pbc.getNumColumnBlocks();
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(numBlocks);
            for (int j = 1; j <= numBlocks; j++) {
                MatrixBlock right = _pbc.getMatrixBlock((int) ixIn.getColumnIndex(), j);
                MatrixIndexes ixOut = new MatrixIndexes();
                MatrixBlock blkOut = new MatrixBlock();
                OperationsOnMatrixValues.performAggregateBinary(
                        ixIn, blkIn, new MatrixIndexes(ixIn.getColumnIndex(), j), right, ixOut, blkOut, _op);
                ret.add(new Tuple2<>(ixOut, blkOut));
            }
            return ret;
        }
    }

    /**
     * Multiplies one RDD block with its single matching broadcast block. Used when the broadcast
     * fits into one block along its outer dimension, so each input block yields exactly one
     * partial product.
     */
    private static class RDDMapMMFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;

        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcastMatrix _pbc;

        public RDDMapMMFunction(MapMult.CacheType type, PartitionedBroadcastMatrix binput) {
            _type = type;
            _pbc = binput;
            _op = createMatMultOperator();
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            MatrixIndexes ixOut = new MatrixIndexes();
            MatrixBlock blkOut = new MatrixBlock();

            if (_type == MapMult.CacheType.LEFT) {
                // left operand is broadcast as a single row of blocks: (1, r) x (r, c)
                MatrixBlock left = _pbc.getMatrixBlock(1, (int) ixIn.getRowIndex());
                OperationsOnMatrixValues.performAggregateBinary(
                        new MatrixIndexes(1L, ixIn.getRowIndex()), left, ixIn, blkIn, ixOut, blkOut, _op);
            } else {
                // right operand is broadcast as a single column of blocks: (r, c) x (c, 1)
                MatrixBlock right = _pbc.getMatrixBlock((int) ixIn.getColumnIndex(), 1);
                OperationsOnMatrixValues.performAggregateBinary(
                        ixIn, blkIn, new MatrixIndexes(ixIn.getColumnIndex(), 1L), right, ixOut, blkOut, _op);
            }
            return new Tuple2<>(ixOut, blkOut);
        }
    }

    /**
     * Partition-wise variant of {@link RDDMapMMFunction} that keeps each input block's index as
     * the output index. This is only valid when the output keys coincide with the input keys,
     * which {@code preservesPartitioning} checks before this function is chosen, so the RDD's
     * existing partitioning can be retained.
     */
    public static class RDDMapMMPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1886318890063064287L;

        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcastMatrix _pbc;

        /** Lazily multiplies the blocks of one partition as they are iterated. */
        public class MapMMPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapMMPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            public Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes ixIn = arg._1();
                MatrixBlock blkIn = arg._2();
                MatrixBlock blkOut = new MatrixBlock();

                if (_type == MapMult.CacheType.LEFT) {
                    MatrixBlock left = _pbc.getMatrixBlock(1, (int) ixIn.getRowIndex());
                    left.aggregateBinaryOperations(left, blkIn, blkOut, _op);
                } else {
                    MatrixBlock right = _pbc.getMatrixBlock((int) ixIn.getColumnIndex(), 1);
                    blkIn.aggregateBinaryOperations(blkIn, right, blkOut, _op);
                }

                // output key equals input key (see class comment)
                return new Tuple2<>(ixIn, blkOut);
            }
        }

        public RDDMapMMPartitionFunction(MapMult.CacheType type, PartitionedBroadcastMatrix binput) {
            _type = type;
            _pbc = binput;
            _op = createMatMultOperator();
        }

        @Override
        public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception {
            return new MapMMPartitionIterator(arg0);
        }
    }

    public MapmmSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, MapMult.CacheType cacheType, boolean z, AggBinaryOp.SparkAggType sparkAggType, String str, String str2) {
        super(operator, cPOperand, cPOperand2, cPOperand3, str, str2);
        _sptype = SPInstruction.SPINSTRUCTION_TYPE.MAPMM;
        _type = cacheType;
        _outputEmpty = z;
        _aggtype = sparkAggType;
    }

    /**
     * Parses a {@code mapmm} instruction string.
     *
     * @param str instruction string with fields: opcode, in1, in2, out, cache type,
     *            output-empty flag, aggregation type
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@link MapMult#OPCODE} or fields are missing
     */
    public static MapmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(MapMult.OPCODE)) {
            throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        // opcode plus six operand/parameter fields are required
        if (parts.length < 7) {
            throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Invalid number of fields ("
                    + (parts.length - 1) + ") in instruction: " + str);
        }

        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        MapMult.CacheType type = MapMult.CacheType.valueOf(parts[4]);
        boolean outputEmpty = Boolean.parseBoolean(parts[5]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[6]);

        return new MapmmSPInstruction(createMatMultOperator(), in1, in2, out, type, outputEmpty, aggtype, opcode, str);
    }

    /**
     * Executes the broadcast matrix multiplication and binds the result to the output variable,
     * either as a single in-memory block ({@code SINGLE_BLOCK}) or as an RDD handle with lineage
     * to the RDD input and the broadcast.
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException, DMLUnsupportedOperationException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;

        // resolve which operand is distributed and which is broadcast
        boolean leftBroadcast = (_type == MapMult.CacheType.LEFT);
        String rddVar = leftBroadcast ? input2.getName() : input1.getName();
        String bcastVar = leftBroadcast ? input1.getName() : input2.getName();
        MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
        MatrixCharacteristics mcBc = sec.getMatrixCharacteristics(bcastVar);

        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
        PartitionedBroadcastMatrix in2 = sec.getBroadcastForVariable(bcastVar);

        // drop all-empty input blocks unless empty outputs must be materialized
        if (!_outputEmpty) {
            in1 = in1.filter(new FilterNonEmptyBlocksFunction());
        }

        // pick the cheapest function that is correct for the broadcast/RDD shapes
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (requiresFlatMapFunction(_type, mcBc)) {
            out = in1.flatMapToPair(new RDDFlatMapMMFunction(_type, in2));
        } else if (preservesPartitioning(mcRdd, _type)) {
            out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(_type, in2), true);
        } else {
            out = in1.mapToPair(new RDDMapMMFunction(_type, in2));
        }

        // drop all-empty output blocks unless they must be materialized
        if (!_outputEmpty) {
            out = out.filter(new FilterNonEmptyBlocksFunction());
        }

        if (_aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            // collapse all partial products into one block; no RDD lineage needed
            sec.setMatrixOutput(output.getName(), RDDAggregateUtils.sumStable(out));
            return;
        }

        if (_aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
            out = RDDAggregateUtils.sumByKeyStable(out);
        }

        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), rddVar);
        sec.addLineageBroadcast(output.getName(), bcastVar);
        updateBinaryMMOutputMatrixCharacteristics(sec, true);
    }

    /**
     * Returns true if every output block keeps the index of its input block. That holds when the
     * RDD operand has a single block along the dimension consumed by the multiplication (rows
     * for LEFT, columns otherwise). Unknown dimensions are treated conservatively as false.
     */
    private static boolean preservesPartitioning(MatrixCharacteristics mcIn, MapMult.CacheType type) {
        if (!mcIn.dimsKnown()) {
            return false;
        }
        return (type == MapMult.CacheType.LEFT)
                ? mcIn.getRows() <= mcIn.getRowsPerBlock()
                : mcIn.getCols() <= mcIn.getColsPerBlock();
    }

    /**
     * Returns true if the broadcast spans more than one block along its outer dimension (rows for
     * LEFT, columns otherwise), i.e. each RDD block produces several partial products. Non-LEFT
     * types are handled like RIGHT, matching every other branch of this class.
     */
    private static boolean requiresFlatMapFunction(MapMult.CacheType type, MatrixCharacteristics mcBc) {
        return (type == MapMult.CacheType.LEFT)
                ? mcBc.getRows() > mcBc.getRowsPerBlock()
                : mcBc.getCols() > mcBc.getColsPerBlock();
    }
}
