package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/CumulativeOffsetSPInstruction.class */
/**
 * Spark instruction that combines a blocked data matrix (input1) with a matrix of
 * per-block offset rows (input2) and then applies a cumulative unary operator to
 * each data block.
 *
 * <p>The opcode selects the operator pair: {@code bcumoffk+} (plus / {@code ucumk+}),
 * {@code bcumoff*} (multiply / {@code ucum*}), {@code bcumoffmin} (min / {@code ucummin})
 * and {@code bcumoffmax} (max / {@code ucummax}). For any other opcode both operators
 * stay {@code null}.
 */
public class CumulativeOffsetSPInstruction extends BinarySPInstruction {
    private BinaryOperator _bop;
    private UnaryOperator _uop;
    private double _initValue;

    /**
     * Merges one offset row into the first row of a data block and then runs the
     * cumulative unary operator over the adjusted block.
     */
    private static class RDDCumOffsetFunction implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -5804080263258064743L;
        private final UnaryOperator _uop;
        private final BinaryOperator _bop;

        /**
         * @param unaryOperator  cumulative operator applied to the adjusted data block
         * @param binaryOperator operator used to merge the offset row into the first data row
         */
        public RDDCumOffsetFunction(UnaryOperator unaryOperator, BinaryOperator binaryOperator) {
            this._uop = unaryOperator;
            this._bop = binaryOperator;
        }

        /**
         * @param tuple2 pair of (data block, offset row) produced by the join
         * @return a new block holding the result of the unary operator; the input data
         *         block is copied first and not modified by this method
         */
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> tuple2) throws Exception {
            MatrixBlock dataBlock = tuple2._1();
            MatrixBlock offsetRow = tuple2._2();
            MatrixBlock result = new MatrixBlock(dataBlock.getNumRows(), dataBlock.getNumColumns(), dataBlock.isInSparseFormat());

            // Work on a copy so the joined input block is left untouched.
            MatrixBlock adjusted = new MatrixBlock(dataBlock);
            int lastCol = adjusted.getNumColumns() - 1;

            // Extract row 0, fold the offset row into it, and write it back.
            MatrixBlock firstRow = adjusted.slice(0, 0, 0, lastCol, (CacheBlock) new MatrixBlock());
            firstRow.binaryOperationsInPlace(this._bop, offsetRow);
            adjusted.copy(0, 0, 0, lastCol, firstRow, true);

            adjusted.unaryOperations(this._uop, result);
            return result;
        }
    }

    /**
     * Splits a block of the offset matrix into single-row blocks and re-keys each
     * row with a shifted row-block index, so that row {@code r} (1-based, global)
     * is emitted under row index {@code r + 1}. An additional row filled with the
     * init value is emitted under row index 1 for every column block.
     */
    private static class RDDCumSplitFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8407407527406576965L;
        private final double _initValue;
        private final int _brlen;
        private final long _lastRowBlockIndex;

        /**
         * @param d init value used for the row emitted under row index 1
         * @param j number of rows of the matrix being split
         * @param i number of rows per block
         */
        public RDDCumSplitFunction(double d, long j, int i) {
            this._initValue = d;
            this._brlen = i;
            // Divide in floating point: with integer division the quotient is already
            // truncated and ceil() is a no-op, which yields a last-block index that is
            // one too small whenever j is not a multiple of i.
            this._lastRowBlockIndex = (long) Math.ceil((double) j / i);
        }

        /**
         * @param tuple2 (block index, block) of the matrix being split
         * @return iterator over the re-keyed single-row blocks
         */
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> rows = new ArrayList<Tuple2<MatrixIndexes, MatrixBlock>>();
            MatrixIndexes blockIx = tuple2._1();
            MatrixBlock block = tuple2._2();
            int numCols = block.getNumColumns();

            // Global 0-based row position of this block's first row.
            long rowOffset = (blockIx.getRowIndex() - 1) * this._brlen;
            boolean firstBlock = blockIx.getRowIndex() == 1;
            boolean lastBlock = blockIx.getRowIndex() == this._lastRowBlockIndex;

            if (firstBlock) {
                // Row index 1 has no preceding row; emit the init value instead.
                MatrixIndexes initIx = new MatrixIndexes(1L, blockIx.getColumnIndex());
                MatrixBlock initRow = new MatrixBlock(1, numCols, block.isInSparseFormat());
                if (this._initValue != 0.0d) {
                    for (int c = 0; c < numCols; c++) {
                        initRow.appendValue(0, c, this._initValue);
                    }
                }
                rows.add(new Tuple2<MatrixIndexes, MatrixBlock>(initIx, initRow));
            }

            int numRows = block.getNumRows();
            for (int r = 0; r < numRows; r++) {
                // The very last row overall would map past the final target index; skip it.
                if (lastBlock && r == numRows - 1) {
                    continue;
                }
                MatrixIndexes targetIx = new MatrixIndexes(rowOffset + r + 2, blockIx.getColumnIndex());
                MatrixBlock rowBlock = new MatrixBlock(1, numCols, block.isInSparseFormat());
                block.slice(r, r, 0, numCols - 1, (CacheBlock) rowBlock);
                rows.add(new Tuple2<MatrixIndexes, MatrixBlock>(targetIx, rowBlock));
            }
            return rows.iterator();
        }
    }

    /**
     * Creates the instruction and selects the binary/unary operator pair from the opcode.
     *
     * @param operator  operator handed to the super constructor
     * @param cPOperand first input operand
     * @param cPOperand2 second input operand
     * @param cPOperand3 output operand
     * @param d         init value
     * @param str       opcode
     * @param str2      full instruction string
     */
    private CumulativeOffsetSPInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, double d, String str, String str2) {
        super(SPInstruction.SPType.CumsumOffset, operator, cPOperand, cPOperand2, cPOperand3, str, str2);
        // NOTE(review): an unrecognized opcode leaves both operators null, as before.
        if ("bcumoffk+".equals(str)) {
            this._bop = new BinaryOperator(Plus.getPlusFnObject());
            this._uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
        } else if ("bcumoff*".equals(str)) {
            this._bop = new BinaryOperator(Multiply.getMultiplyFnObject());
            this._uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
        } else if ("bcumoffmin".equals(str)) {
            this._bop = new BinaryOperator(Builtin.getBuiltinFnObject("min"));
            this._uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucummin"));
        } else if ("bcumoffmax".equals(str)) {
            this._bop = new BinaryOperator(Builtin.getBuiltinFnObject("max"));
            this._uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucummax"));
        }
        this._initValue = d;
    }

    /**
     * Parses an instruction string into a {@link CumulativeOffsetSPInstruction}.
     * The parsed parts are used as: [0] opcode, [1] input1, [2] input2, [3] output,
     * [4] init value (parsed as a double).
     *
     * @param str full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException declared by this method's signature
     */
    public static CumulativeOffsetSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 4);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        double initValue = Double.parseDouble(parts[4]);
        return new CumulativeOffsetSPInstruction(null, in1, in2, out, initValue, parts[0], str);
    }

    /**
     * Builds the output RDD: input2 is split into re-keyed single-row blocks, joined
     * with input1 by block index, and each (data, offset) pair is mapped to its
     * result block. The output RDD is registered together with lineage to both inputs.
     *
     * @param executionContext execution context; cast to {@link SparkExecutionContext}
     */
    @Override // org.apache.sysml.runtime.instructions.spark.SPInstruction, org.apache.sysml.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext) executionContext;
        MatrixCharacteristics mcOffsets = sec.getMatrixCharacteristics(this.input2.getName());

        JavaPairRDD<MatrixIndexes, MatrixBlock> data = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> offsets = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());

        // Turn every offset row into its own block keyed by the data block it applies to.
        JavaPairRDD<MatrixIndexes, MatrixBlock> offsetRows = offsets.flatMapToPair(
                new RDDCumSplitFunction(this._initValue, mcOffsets.getRows(), mcOffsets.getRowsPerBlock()));

        JavaPairRDD<?, ?> result = data.join(offsetRows).mapValues(new RDDCumOffsetFunction(this._uop, this._bop));

        updateUnaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(this.output.getName(), result);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineageRDD(this.output.getName(), this.input2.getName());
    }
}
