package org.apache.sysml.runtime.instructions.spark;

import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.lops.PickByCount;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.class */
public class QuantilePickSPInstruction extends BinarySPInstruction {
    private PickByCount.OperationTypes _type;

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction$ExtractAndSumFunction.class */
    private static class ExtractAndSumFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -584044441055250489L;
        private long _minRowIndex;
        private long _maxRowIndex;
        private int _minPos;
        private int _maxPos;

        public ExtractAndSumFunction(long j, long j2, int i) {
            this._minRowIndex = UtilFunctions.computeBlockIndex(j, i);
            this._maxRowIndex = UtilFunctions.computeBlockIndex(j2, i);
            this._minPos = UtilFunctions.computeCellInBlock(j, i);
            this._maxPos = UtilFunctions.computeCellInBlock(j2, i);
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes matrixIndexes = (MatrixIndexes) tuple2._1();
            MatrixBlock matrixBlock = (MatrixBlock) tuple2._2();
            if (this._minRowIndex == this._maxRowIndex) {
                matrixBlock = matrixBlock.sliceOperations(this._minPos - 1, this._maxPos - 1, 0, 0, new MatrixBlock());
            } else if (matrixIndexes.getRowIndex() == this._minRowIndex) {
                matrixBlock = matrixBlock.sliceOperations(this._minPos, matrixBlock.getNumRows() - 1, 0, 0, new MatrixBlock());
            } else if (matrixIndexes.getRowIndex() == this._maxRowIndex) {
                matrixBlock = matrixBlock.sliceOperations(0, this._maxPos, 0, 0, new MatrixBlock());
            }
            MatrixBlock matrixBlock2 = new MatrixBlock(1, 2, false);
            matrixBlock2.setValue(0, 0, matrixBlock.sum());
            return new Tuple2<>(new MatrixIndexes(1L, 1L), matrixBlock2);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction$ExtractAndSumWeightsFunction.class */
    public static class ExtractAndSumWeightsFunction implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 7169831202450745373L;

        private ExtractAndSumWeightsFunction() {
        }

        public MatrixBlock call(MatrixBlock matrixBlock) throws Exception {
            MatrixBlock sliceOperations = matrixBlock.sliceOperations(0, matrixBlock.getNumRows() - 1, 1, 1, new MatrixBlock());
            MatrixBlock matrixBlock2 = new MatrixBlock(1, 2, false);
            matrixBlock2.setValue(0, 0, sliceOperations.sum());
            return matrixBlock2;
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction$FilterFunction.class */
    private static class FilterFunction implements Function<Tuple2<MatrixIndexes, MatrixBlock>, Boolean> {
        private static final long serialVersionUID = -8249102381116157388L;
        private long _minRowIndex;
        private long _maxRowIndex;

        public FilterFunction(long j, long j2, int i) {
            this._minRowIndex = UtilFunctions.computeBlockIndex(j, i);
            this._maxRowIndex = UtilFunctions.computeBlockIndex(j2, i);
        }

        public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            long rowIndex = ((MatrixIndexes) tuple2._1()).getRowIndex();
            return Boolean.valueOf(rowIndex >= this._minRowIndex && rowIndex <= this._maxRowIndex);
        }
    }

    public QuantilePickSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, PickByCount.OperationTypes operationTypes, boolean z, String str, String str2) {
        this(operator, cPOperand, null, cPOperand2, operationTypes, z, str, str2);
    }

    public QuantilePickSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, PickByCount.OperationTypes operationTypes, boolean z, String str, String str2) {
        super(operator, cPOperand, cPOperand2, cPOperand3, str, str2);
        this._type = null;
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.QPick;
        this._type = operationTypes;
    }

    public static QuantilePickSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] instructionPartsWithValueType = InstructionUtils.getInstructionPartsWithValueType(str);
        String str2 = instructionPartsWithValueType[0];
        if (!str2.equalsIgnoreCase(PickByCount.OPCODE)) {
            throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickCPInstruction: " + str);
        }
        if (instructionPartsWithValueType.length == 4) {
            return new QuantilePickSPInstruction(null, new CPOperand(instructionPartsWithValueType[1]), new CPOperand(instructionPartsWithValueType[2]), new CPOperand(instructionPartsWithValueType[3]), PickByCount.OperationTypes.IQM, false, str2, str);
        }
        if (instructionPartsWithValueType.length == 5) {
            return new QuantilePickSPInstruction(null, new CPOperand(instructionPartsWithValueType[1]), new CPOperand(instructionPartsWithValueType[2]), PickByCount.OperationTypes.valueOf(instructionPartsWithValueType[3]), Boolean.parseBoolean(instructionPartsWithValueType[4]), str2, str);
        }
        if (instructionPartsWithValueType.length == 6) {
            return new QuantilePickSPInstruction(null, new CPOperand(instructionPartsWithValueType[1]), new CPOperand(instructionPartsWithValueType[2]), new CPOperand(instructionPartsWithValueType[3]), PickByCount.OperationTypes.valueOf(instructionPartsWithValueType[4]), Boolean.parseBoolean(instructionPartsWithValueType[5]), str2, str);
        }
        return null;
    }

    @Override // org.apache.sysml.runtime.instructions.spark.SPInstruction, org.apache.sysml.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        SparkExecutionContext sparkExecutionContext = (SparkExecutionContext) executionContext;
        MatrixCharacteristics matrixCharacteristics = sparkExecutionContext.getMatrixCharacteristics(this.input1.getName());
        boolean z = matrixCharacteristics.getCols() == 2;
        JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockRDDHandleForVariable = sparkExecutionContext.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        switch (this._type) {
            case VALUEPICK:
                executionContext.setScalarOutput(this.output.getName(), new DoubleObject(lookupKey(binaryBlockRDDHandleForVariable, (long) Math.ceil(executionContext.getScalarInput(this.input2.getName(), this.input2.getValueType(), this.input2.isLiteral()).getDoubleValue() * (z ? sumWeights(binaryBlockRDDHandleForVariable) : matrixCharacteristics.getRows())), matrixCharacteristics.getRowsPerBlock())));
                return;
            case MEDIAN:
                executionContext.setScalarOutput(this.output.getName(), new DoubleObject(lookupKey(binaryBlockRDDHandleForVariable, (long) Math.ceil(0.5d * (z ? sumWeights(binaryBlockRDDHandleForVariable) : matrixCharacteristics.getRows())), matrixCharacteristics.getRowsPerBlock())));
                return;
            case IQM:
                double sumWeights = z ? sumWeights(binaryBlockRDDHandleForVariable) : matrixCharacteristics.getRows();
                long ceil = (long) Math.ceil(0.25d * sumWeights);
                long ceil2 = (long) Math.ceil(0.75d * sumWeights);
                executionContext.setScalarOutput(this.output.getName(), new DoubleObject(((RDDAggregateUtils.sumStable(binaryBlockRDDHandleForVariable.filter(new FilterFunction(ceil + 1, ceil2, matrixCharacteristics.getRowsPerBlock())).mapToPair(new ExtractAndSumFunction(ceil + 1, ceil2, matrixCharacteristics.getRowsPerBlock()))).getValue(0, 0) + ((ceil - (0.25d * sumWeights)) * lookupKey(binaryBlockRDDHandleForVariable, ceil, matrixCharacteristics.getRowsPerBlock()))) - ((ceil2 - (0.75d * sumWeights)) * lookupKey(binaryBlockRDDHandleForVariable, ceil2, matrixCharacteristics.getRowsPerBlock()))) / (0.5d * sumWeights)));
                return;
            default:
                throw new DMLRuntimeException("Unsupported qpick operation type: " + this._type);
        }
    }

    private double lookupKey(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, long j, int i) {
        return ((MatrixBlock) javaPairRDD.lookup(new MatrixIndexes(UtilFunctions.computeBlockIndex(j, i), 1L)).get(0)).quickGetValue(UtilFunctions.computeCellInBlock(j, i), 0);
    }

    private double sumWeights(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD) {
        return RDDAggregateUtils.sumStable(javaPairRDD.mapValues(new ExtractAndSumWeightsFunction())).quickGetValue(0, 0);
    }
}
