package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction;
import org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import scala.Tuple2;

/**
 * Spark instruction that evaluates a unary aggregate over a binary-block matrix RDD.
 * <p>
 * How partial block results are combined is decided by the {@link AggBinaryOp.SparkAggType}
 * parsed from the instruction string:
 * <ul>
 *   <li>{@code SINGLE_BLOCK}: block-local partials are merged via
 *       {@link RDDAggregateUtils#aggStable} into one in-memory {@link MatrixBlock}, which is
 *       stored as the matrix output.</li>
 *   <li>{@code MULTI_BLOCK}: block-local partials are merged per output key via
 *       {@link RDDAggregateUtils#aggByKeyStable}; correction cells are dropped afterwards
 *       when the operator carries a correction.</li>
 *   <li>{@code NONE}: every block is aggregated on its own (keeping its key) and its
 *       correction is dropped immediately; there is no cross-block merge.</li>
 * </ul>
 */
public class AggregateUnarySPInstruction extends UnarySPInstruction {
    /** Strategy used to combine partial block aggregates. */
    private final AggBinaryOp.SparkAggType _aggtype;
    /** Operator used to merge partial aggregates and to locate their correction cells. */
    private final AggregateOperator _aop;

    /**
     * Aggregates one (index, block) pair into a block-local partial result, re-keyed by the
     * output index that {@link OperationsOnMatrixValues#performAggregateUnary} produces.
     * Correction cells are NOT dropped here, so the partials can still be merged later.
     */
    private static class RDDUAggFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 2672082409287856038L;
        private final AggregateUnaryOperator _op;
        private final int _brlen;
        private final int _bclen;

        public RDDUAggFunction(AggregateUnaryOperator op, int brlen, int bclen) {
            _op = op;
            _brlen = brlen;
            _bclen = bclen;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixOut = new MatrixIndexes();
            MatrixBlock blkOut = new MatrixBlock();
            OperationsOnMatrixValues.performAggregateUnary(arg0._1(), arg0._2(), ixOut, blkOut, _op, _brlen, _bclen);
            return new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut);
        }
    }

    /**
     * Aggregates a single block in place of its value (the key is left untouched by the
     * surrounding {@code mapValues}) and drops the correction cells right away.
     */
    private static class RDDUAggValueFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5352374590399929673L;
        private final AggregateUnaryOperator _op;
        private final int _brlen;
        private final int _bclen;
        // The same fixed (1,1) block index is handed to every block.
        // NOTE(review): presumably only relevant for index-returning aggregates; confirm.
        private final MatrixIndexes _ix = new MatrixIndexes(1L, 1L);

        public RDDUAggValueFunction(AggregateUnaryOperator op, int brlen, int bclen) {
            _op = op;
            _brlen = brlen;
            _bclen = bclen;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            MatrixBlock blkOut = new MatrixBlock();
            arg0.aggregateUnaryOperations(_op, blkOut, _brlen, _bclen, _ix);
            // no further merge follows on this path, so the correction can go immediately
            blkOut.dropLastRowsOrColums(_op.aggOp.correctionLocation);
            return blkOut;
        }
    }

    /**
     * @param auop    block-level unary aggregate operator
     * @param aop     operator used to merge partial aggregates
     * @param in      input matrix operand
     * @param out     output operand
     * @param aggtype strategy for combining partial block results
     * @param opcode  instruction opcode
     * @param istr    full instruction string
     */
    public AggregateUnarySPInstruction(AggregateUnaryOperator auop, AggregateOperator aop, CPOperand in, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(auop, in, out, opcode, istr);
        _aggtype = aggtype;
        _aop = aop;
    }

    /**
     * Parses an instruction whose parts are: opcode, input operand, output operand and
     * the name of a {@link AggBinaryOp.SparkAggType} constant.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the field count is wrong or an operator cannot be parsed
     */
    public static AggregateUnarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);

        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]);

        // the merge operator is derived from the opcode, with a correction iff a location exists
        String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
        PartialAggregate.CorrectionLocationType corrLoc = InstructionUtils.deriveAggregateOperatorCorrectionLocation(opcode);
        String corrExists = (corrLoc != PartialAggregate.CorrectionLocationType.NONE) ? "true" : "false";

        AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
        AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrExists, corrLoc.toString());
        return new AggregateUnarySPInstruction(aggun, aop, in1, out, aggtype, opcode, str);
    }

    /**
     * Runs the aggregate. {@code SINGLE_BLOCK} results are stored as an in-memory matrix
     * output; all other results are stored as an RDD handle with lineage to the input.
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException, DMLUnsupportedOperationException {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName());

        // for trace, pre-filter the input with FilterDiagBlocksFunction
        // (presumably keeps only blocks on the diagonal)
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(input1.getName());
        if (getOpcode().equalsIgnoreCase("uaktrace")) {
            in = in.filter(new FilterDiagBlocksFunction());
        }

        AggregateUnaryOperator auop = (AggregateUnaryOperator) _optr;
        AggregateOperator aggop = _aop;
        int brlen = mc.getRowsPerBlock();
        int bclen = mc.getColsPerBlock();

        if (_aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            JavaPairRDD<MatrixIndexes, MatrixBlock> partials = in.mapToPair(new RDDUAggFunction(auop, brlen, bclen));
            MatrixBlock result = RDDAggregateUtils.aggStable(partials, aggop);
            // corrections are only dropped once every partial has been merged
            result.dropLastRowsOrColums(aggop.correctionLocation);
            sec.setMatrixOutput(output.getName(), result);
            return;
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (_aggtype == AggBinaryOp.SparkAggType.NONE) {
            // key-preserving per-block aggregation; correction dropped inside the function
            out = in.mapValues(new RDDUAggValueFunction(auop, brlen, bclen));
        } else {
            out = in.mapToPair(new RDDUAggFunction(auop, brlen, bclen));
            if (_aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
                out = RDDAggregateUtils.aggByKeyStable(out, aggop);
                if (auop.aggOp.correctionExists) {
                    out = out.mapValues(new AggregateDropCorrectionFunction(aggop));
                }
            }
        }

        updateUnaryAggOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);
        sec.addLineageRDD(output.getName(), input1.getName());
    }

    /**
     * Fills in the output matrix characteristics when they are not yet known, inferring
     * them from the input and the operator's index function. {@link ReduceAll} gives 1x1,
     * {@link ReduceCol} gives rows x 1, and {@link ReduceRow} gives 1 x cols. Block sizes are
     * copied from the input. Any other index function leaves the output unchanged.
     *
     * @param sec the Spark execution context holding both characteristics
     * @throws DMLRuntimeException if neither the output nor the input dimensions are known
     */
    protected void updateUnaryAggOutputMatrixCharacteristics(SparkExecutionContext sec) throws DMLRuntimeException {
        AggregateUnaryOperator auop = (AggregateUnaryOperator) _optr;
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(input1.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());

        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown()) {
                throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from input:" + mc1.toString() + " " + mcOut.toString());
            }
            int brlen = mc1.getRowsPerBlock();
            int bclen = mc1.getColsPerBlock();
            if (auop.indexFn instanceof ReduceAll) {
                mcOut.set(1L, 1L, brlen, bclen);
            } else if (auop.indexFn instanceof ReduceCol) {
                mcOut.set(mc1.getRows(), 1L, brlen, bclen);
            } else if (auop.indexFn instanceof ReduceRow) {
                mcOut.set(1L, mc1.getCols(), brlen, bclen);
            }
        }
    }
}
