package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.io.FileUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.class */
public class MapmmSPInstruction extends BinarySPInstruction {
    // Cache type of this instruction: when isRight() holds, input1 is the RDD operand and
    // input2 the broadcast operand; otherwise the roles are swapped (see processInstruction).
    private MapMult.CacheType _type;
    // When false, empty blocks are filtered out of both the RDD input and the multiplication result.
    private boolean _outputEmpty;
    // Aggregation mode: SINGLE_BLOCK sums all partial products into one output block,
    // MULTI_BLOCK sums partial products by block index; any other value skips aggregation.
    private AggBinaryOp.SparkAggType _aggtype;

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction$RDDFlatMapMMFunction.class */
    private static class RDDFlatMapMMFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -6076256569118957281L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0.0d, Plus.getPlusFnObject()));
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDFlatMapMMFunction(MapMult.CacheType cacheType, PartitionedBroadcast<MatrixBlock> partitionedBroadcast) {
            this._type = cacheType;
            this._pbc = partitionedBroadcast;
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            ArrayList arrayList = new ArrayList();
            MatrixIndexes matrixIndexes = (MatrixIndexes) tuple2._1();
            MatrixBlock matrixBlock = (MatrixBlock) tuple2._2();
            if (this._type == MapMult.CacheType.LEFT) {
                int numRowBlocks = this._pbc.getNumRowBlocks();
                for (int i = 1; i <= numRowBlocks; i++) {
                    MatrixBlock block = this._pbc.getBlock(i, (int) matrixIndexes.getRowIndex());
                    MatrixIndexes matrixIndexes2 = new MatrixIndexes();
                    MatrixBlock matrixBlock2 = new MatrixBlock();
                    OperationsOnMatrixValues.performAggregateBinary(new MatrixIndexes(i, matrixIndexes.getRowIndex()), block, matrixIndexes, matrixBlock, matrixIndexes2, matrixBlock2, this._op);
                    arrayList.add(new Tuple2(matrixIndexes2, matrixBlock2));
                }
            } else {
                int numColumnBlocks = this._pbc.getNumColumnBlocks();
                for (int i2 = 1; i2 <= numColumnBlocks; i2++) {
                    MatrixBlock block2 = this._pbc.getBlock((int) matrixIndexes.getColumnIndex(), i2);
                    MatrixIndexes matrixIndexes3 = new MatrixIndexes();
                    MatrixBlock matrixBlock3 = new MatrixBlock();
                    OperationsOnMatrixValues.performAggregateBinary(matrixIndexes, matrixBlock, new MatrixIndexes(matrixIndexes.getColumnIndex(), i2), block2, matrixIndexes3, matrixBlock3, this._op);
                    arrayList.add(new Tuple2(matrixIndexes3, matrixBlock3));
                }
            }
            return arrayList.iterator();
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction$RDDMapMMFunction.class */
    private static class RDDMapMMFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0.0d, Plus.getPlusFnObject()));
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDMapMMFunction(MapMult.CacheType cacheType, PartitionedBroadcast<MatrixBlock> partitionedBroadcast) {
            this._type = cacheType;
            this._pbc = partitionedBroadcast;
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes matrixIndexes = (MatrixIndexes) tuple2._1();
            MatrixBlock matrixBlock = (MatrixBlock) tuple2._2();
            MatrixIndexes matrixIndexes2 = new MatrixIndexes();
            MatrixBlock matrixBlock2 = new MatrixBlock();
            if (this._type == MapMult.CacheType.LEFT) {
                OperationsOnMatrixValues.performAggregateBinary(new MatrixIndexes(1L, matrixIndexes.getRowIndex()), this._pbc.getBlock(1, (int) matrixIndexes.getRowIndex()), matrixIndexes, matrixBlock, matrixIndexes2, matrixBlock2, this._op);
            } else {
                OperationsOnMatrixValues.performAggregateBinary(matrixIndexes, matrixBlock, new MatrixIndexes(matrixIndexes.getColumnIndex(), 1L), this._pbc.getBlock((int) matrixIndexes.getColumnIndex(), 1), matrixIndexes2, matrixBlock2, this._op);
            }
            return new Tuple2<>(matrixIndexes2, matrixBlock2);
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction$RDDMapMMFunction2.class */
    private static class RDDMapMMFunction2 implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -2753453898072910182L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0.0d, Plus.getPlusFnObject()));
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDMapMMFunction2(MapMult.CacheType cacheType, PartitionedBroadcast<MatrixBlock> partitionedBroadcast) {
            this._type = cacheType;
            this._pbc = partitionedBroadcast;
        }

        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes matrixIndexes = (MatrixIndexes) tuple2._1();
            MatrixBlock matrixBlock = (MatrixBlock) tuple2._2();
            return this._type == MapMult.CacheType.LEFT ? (MatrixBlock) OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(this._pbc.getBlock(1, (int) matrixIndexes.getRowIndex()), matrixBlock, new MatrixBlock(), this._op) : (MatrixBlock) OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(matrixBlock, this._pbc.getBlock((int) matrixIndexes.getColumnIndex(), 1), new MatrixBlock(), this._op);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction$RDDMapMMPartitionFunction.class */
    /**
     * Partition-level function that wraps a partition's block iterator in an iterator
     * multiplying each RDD block with its matching broadcast block. Every output tuple
     * reuses the index of its input block.
     */
    public static class RDDMapMMPartitionFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1886318890063064287L;
        private final MapMult.CacheType _type;
        // Multiply-then-sum operator applied to every block pair.
        private final AggregateBinaryOperator _op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), new AggregateOperator(0.0d, Plus.getPlusFnObject()));
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        /**
         * @param cacheType            which side of the multiplication the broadcast is on
         * @param partitionedBroadcast the broadcast operand
         */
        public RDDMapMMPartitionFunction(MapMult.CacheType cacheType, PartitionedBroadcast<MatrixBlock> partitionedBroadcast) {
            this._type = cacheType;
            this._pbc = partitionedBroadcast;
        }

        /**
         * Wraps the partition's input iterator; no multiplication happens in this call itself.
         *
         * @param it iterator over the blocks of one partition
         * @return iterator producing one multiplied block per input block
         * @throws Exception declared by the functional interface
         */
        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) throws Exception {
            return new MapMMPartitionIterator(it);
        }

        /**
         * Iterator computing the product for each input block via {@code computeNext};
         * it reads the cache type, operator, and broadcast of the enclosing function.
         */
        public class MapMMPartitionIterator extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapMMPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> it) {
                super(it);
            }

            /**
             * Multiplies one RDD block with its matching broadcast block.
             *
             * @param tuple2 index and data of the RDD block
             * @return the input index paired with the product block
             * @throws Exception if the block multiplication fails
             */
            @Override // org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator
            public Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
                MatrixIndexes ixIn = tuple2._1();
                MatrixBlock blkIn = tuple2._2();
                MatrixBlock blkOut = new MatrixBlock();

                // The broadcast block is the left operand for LEFT, the right operand otherwise.
                boolean broadcastOnLeft = RDDMapMMPartitionFunction.this._type == MapMult.CacheType.LEFT;
                MatrixBlock lhs = broadcastOnLeft
                        ? RDDMapMMPartitionFunction.this._pbc.getBlock(1, (int) ixIn.getRowIndex())
                        : blkIn;
                MatrixBlock rhs = broadcastOnLeft
                        ? blkIn
                        : RDDMapMMPartitionFunction.this._pbc.getBlock((int) ixIn.getColumnIndex(), 1);
                lhs.aggregateBinaryOperations(lhs, rhs, blkOut, RDDMapMMPartitionFunction.this._op);

                return new Tuple2<>(ixIn, blkOut);
            }
        }
    }

    /**
     * Creates a map-side matrix multiplication instruction; instances are obtained through
     * {@link #parseInstruction(String)}.
     *
     * @param operator     operator forwarded to the superclass constructor
     * @param cPOperand    first input operand
     * @param cPOperand2   second input operand
     * @param cPOperand3   output operand
     * @param cacheType    which input is the broadcast side
     * @param z            whether empty blocks are kept ({@code false} filters them out)
     * @param sparkAggType aggregation mode applied to the partial products
     * @param str          opcode
     * @param str2         full instruction string
     */
    private MapmmSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, MapMult.CacheType cacheType, boolean z, AggBinaryOp.SparkAggType sparkAggType, String str, String str2) {
        super(SPInstruction.SPType.MAPMM, operator, cPOperand, cPOperand2, cPOperand3, str, str2);
        // The former default stores (_type = null, _outputEmpty = true) were dead: both
        // fields are unconditionally overwritten here.
        this._type = cacheType;
        this._outputEmpty = z;
        this._aggtype = sparkAggType;
    }

    /**
     * Parses a serialized mapmm instruction.
     *
     * <p>Part 0 is the opcode, parts 1-3 are the two inputs and the output operand, part 4
     * the cache type, part 5 the output-empty flag, and part 6 the aggregation type.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not the mapmm opcode
     */
    public static MapmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(MapMult.OPCODE)) {
            throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }

        // Multiply-then-sum operator describing the matrix multiplication.
        AggregateOperator sum = new AggregateOperator(0.0d, Plus.getPlusFnObject());
        AggregateBinaryOperator mmOp = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), sum);

        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        MapMult.CacheType type = MapMult.CacheType.valueOf(parts[4]);
        boolean outputEmpty = Boolean.parseBoolean(parts[5]);
        AggBinaryOp.SparkAggType aggType = AggBinaryOp.SparkAggType.valueOf(parts[6]);

        return new MapmmSPInstruction(mmOp, in1, in2, out, type, outputEmpty, aggType, opcode, str);
    }

    /**
     * Executes the map-side matrix multiplication of an RDD operand with a broadcast operand.
     *
     * <p>Steps: (1) resolve which input is the RDD and which the broadcast from the cache
     * type, flipping the roles if a flat-map with repartitioning is needed and the flipped
     * setup yields more partitions; (2) optionally filter empty input blocks; (3) for
     * SINGLE_BLOCK aggregation, sum all partial products into one matrix output and return;
     * (4) otherwise choose a flat-map, partition-preserving map, or plain map, optionally
     * filter empty result blocks, aggregate by key for MULTI_BLOCK, and register the
     * resulting RDD together with its lineage.
     *
     * @param executionContext execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException declared by the overridden method
     */
    @Override // org.apache.sysml.runtime.instructions.spark.SPInstruction, org.apache.sysml.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        MapMult.CacheType type = this._type;

        // Resolve RDD and broadcast operands according to the cache type.
        String rddVar = type.isRight() ? this.input1.getName() : this.input2.getName();
        String bcastVar = type.isRight() ? this.input2.getName() : this.input1.getName();
        MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
        MatrixCharacteristics mcBc = sec.getMatrixCharacteristics(bcastVar);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);

        // If a flat-map with repartitioning is required, check whether swapping the RDD and
        // broadcast roles allows more output partitions.
        if (requiresFlatMapFunction(type, mcBc) && requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions())) {
            int numPartsFlipped = getNumRepartitioning(type.getFlipped(), mcBc, mcRdd);
            int numParts = getNumRepartitioning(type, mcRdd, mcBc);
            if (numPartsFlipped > numParts) {
                type = type.getFlipped();
                rddVar = type.isRight() ? this.input1.getName() : this.input2.getName();
                bcastVar = type.isRight() ? this.input2.getName() : this.input1.getName();
                mcRdd = sec.getMatrixCharacteristics(rddVar);
                mcBc = sec.getMatrixCharacteristics(bcastVar);
                in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
                LOG.warn("Mapmm: Switching rdd ('" + bcastVar + "') and broadcast ('" + rddVar + "') inputs for repartitioning because this allows better control of output partition sizes (" + numParts + " < " + numPartsFlipped + ").");
            }
        }

        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);

        // Drop empty input blocks if empty outputs are not requested.
        if (!this._outputEmpty) {
            in1 = in1.filter(new FilterNonEmptyBlocksFunction());
        }

        // Single-block result: sum the partial products and store the matrix output directly.
        if (this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            JavaRDD<MatrixBlock> partials = in1.map(new RDDMapMMFunction2(type, in2));
            sec.setMatrixOutput(this.output.getName(), RDDAggregateUtils.sumStable(partials), getExtendedOpcode());
            return;
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (!requiresFlatMapFunction(type, mcBc)) {
            // One broadcast block per RDD block: map, keeping the partitioning where possible.
            if (preservesPartitioning(mcRdd, type)) {
                out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(type, in2), true);
            } else {
                out = in1.mapToPair(new RDDMapMMFunction(type, in2));
            }
        } else {
            // Several broadcast blocks per RDD block: flat-map, repartitioning first if needed.
            if (requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions())) {
                int targetParts = getNumRepartitioning(type, mcRdd, mcBc);
                LOG.warn("Mapmm: Repartition input rdd '" + rddVar + "' from " + in1.getNumPartitions() + " to " + targetParts + " partitions to satisfy size restrictions of output partitions.");
                in1 = in1.repartition(targetParts);
            }
            out = in1.flatMapToPair(new RDDFlatMapMMFunction(type, in2));
        }

        // Drop empty result blocks if empty outputs are not requested.
        if (!this._outputEmpty) {
            out = out.filter(new FilterNonEmptyBlocksFunction());
        }

        // Multi-block result: sum partial products that share an output index.
        if (this._aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
            out = RDDAggregateUtils.sumByKeyStable(out, false);
        }

        // Register the output RDD, its lineage, and the output matrix characteristics.
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
        sec.addLineageBroadcast(this.output.getName(), bcastVar);
        updateBinaryMMOutputMatrixCharacteristics(sec, true);
    }

    /**
     * Decides whether the partition-preserving map variant can be used.
     *
     * @param matrixCharacteristics characteristics of the RDD operand
     * @param cacheType             cache type of the multiplication
     * @return {@code true} if the dimensions are known and the rows (cache type LEFT) or
     *         columns (any other cache type) fit into a single block
     */
    private static boolean preservesPartitioning(MatrixCharacteristics matrixCharacteristics, MapMult.CacheType cacheType) {
        if (!matrixCharacteristics.dimsKnown()) {
            return false;
        }
        if (cacheType == MapMult.CacheType.LEFT) {
            return matrixCharacteristics.getRows() <= (long) matrixCharacteristics.getRowsPerBlock();
        }
        return matrixCharacteristics.getCols() <= (long) matrixCharacteristics.getColsPerBlock();
    }

    /**
     * Decides whether one RDD block must be multiplied with several broadcast blocks.
     *
     * @param cacheType             cache type of the multiplication
     * @param matrixCharacteristics characteristics of the broadcast operand
     * @return {@code true} if the broadcast spans more than one row block (LEFT) or more
     *         than one column block (RIGHT); {@code false} for any other cache type
     */
    private static boolean requiresFlatMapFunction(MapMult.CacheType cacheType, MatrixCharacteristics matrixCharacteristics) {
        if (cacheType == MapMult.CacheType.LEFT) {
            return matrixCharacteristics.getRows() > (long) matrixCharacteristics.getRowsPerBlock();
        }
        if (cacheType == MapMult.CacheType.RIGHT) {
            return matrixCharacteristics.getCols() > (long) matrixCharacteristics.getColsPerBlock();
        }
        return false;
    }

    /**
     * Decides whether the RDD operand should be repartitioned before the flat-map
     * multiplication.
     *
     * <p>Repartitioning is requested only if all of the following hold: both operands have
     * known dimensions, the RDD operand has a single block along its rows (cache type LEFT)
     * or columns (otherwise), and the estimated size of the output, computed with sparsity
     * 1.0 and divided by the current number of partitions, exceeds 1 GB.
     *
     * @param cacheType              cache type of the multiplication
     * @param matrixCharacteristics  characteristics of the RDD operand
     * @param matrixCharacteristics2 characteristics of the broadcast operand
     * @param i                      current number of partitions of the RDD operand
     * @return {@code true} if repartitioning is required; {@code false} otherwise, including
     *         when {@code i} is not positive
     */
    private static boolean requiresRepartitioning(MapMult.CacheType cacheType, MatrixCharacteristics matrixCharacteristics, MatrixCharacteristics matrixCharacteristics2, int i) {
        // Unknown dimensions always yield false, so check them before estimating sizes.
        if (!matrixCharacteristics.dimsKnown() || !matrixCharacteristics2.dimsKnown()) {
            return false;
        }
        // Without partitions there is nothing to repartition; this also avoids a division by zero.
        if (i <= 0) {
            return false;
        }

        boolean isLeft = cacheType == MapMult.CacheType.LEFT;

        // The RDD operand must have a single block along the relevant dimension.
        boolean singleBlock = isLeft
                ? matrixCharacteristics.getRows() <= (long) matrixCharacteristics.getRowsPerBlock()
                : matrixCharacteristics.getCols() <= (long) matrixCharacteristics.getColsPerBlock();
        if (!singleBlock) {
            return false;
        }

        // Estimate the output size once; the decompiled comparison evaluated it twice.
        long outputSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(
                isLeft ? matrixCharacteristics2.getRows() : matrixCharacteristics.getRows(),
                isLeft ? matrixCharacteristics.getCols() : matrixCharacteristics2.getCols(),
                isLeft ? (long) matrixCharacteristics2.getRowsPerBlock() : (long) matrixCharacteristics.getRowsPerBlock(),
                isLeft ? (long) matrixCharacteristics.getColsPerBlock() : (long) matrixCharacteristics2.getColsPerBlock(),
                1.0d);
        return outputSize / ((long) i) > FileUtils.ONE_GB;
    }

    /**
     * Computes the number of partitions to repartition the RDD operand into.
     *
     * <p>The result is the smaller of (a) the estimated output size, computed with sparsity
     * 1.0, divided by the HDFS block size and (b) the number of column blocks (cache type
     * LEFT) or row blocks (otherwise) of the RDD operand, narrowed to {@code int}.
     *
     * @param cacheType              cache type of the multiplication
     * @param matrixCharacteristics  characteristics of the RDD operand
     * @param matrixCharacteristics2 characteristics of the broadcast operand
     * @return the target number of partitions
     */
    private static int getNumRepartitioning(MapMult.CacheType cacheType, MatrixCharacteristics matrixCharacteristics, MatrixCharacteristics matrixCharacteristics2) {
        boolean isLeft = cacheType == MapMult.CacheType.LEFT;

        // Output shape: rows come from the left operand, columns from the right operand.
        long outRows = isLeft ? matrixCharacteristics2.getRows() : matrixCharacteristics.getRows();
        long outCols = isLeft ? matrixCharacteristics.getCols() : matrixCharacteristics2.getCols();
        long outRowsPerBlock = isLeft ? matrixCharacteristics2.getRowsPerBlock() : matrixCharacteristics.getRowsPerBlock();
        long outColsPerBlock = isLeft ? matrixCharacteristics.getColsPerBlock() : matrixCharacteristics2.getColsPerBlock();

        long outputSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(outRows, outCols, outRowsPerBlock, outColsPerBlock, 1.0d);
        long numBlocks = isLeft ? matrixCharacteristics.getNumColBlocks() : matrixCharacteristics.getNumRowBlocks();

        return (int) Math.min(outputSize / InfrastructureAnalyzer.getHDFSBlockSize(), numBlocks);
    }
}
