package org.apache.sysml.runtime.controlprogram.parfor;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.utils.Statistics;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/ResultMergeRemoteSpark.class */
/**
 * Parfor result merge that combines the per-worker result matrices using Spark RDD
 * operations and returns a new matrix object backed by the merged RDD.
 */
public class ResultMergeRemoteSpark extends ResultMerge {
    /** Label under which the merge time is recorded in the heavy-hitter statistics. */
    private static final String JOB_NAME = "ParFor-RMSP";

    private final ExecutionContext _ec;
    private final int _numMappers;
    private final int _numReducers;

    /**
     * Creates a Spark-based result merge.
     *
     * @param out the parfor result variable the worker results are merged into
     * @param in the per-worker result matrices
     * @param outputFilename passed through to the {@code ResultMerge} constructor
     *        (presumably the file name of the merged output)
     * @param ec execution context; must be a {@link SparkExecutionContext} when a merge runs
     * @param numMappers value forwarded by {@link #executeSerialMerge()} as the parallelism argument
     * @param numReducers upper bound on the number of partitions used when grouping by key
     */
    public ResultMergeRemoteSpark(MatrixObject out, MatrixObject[] in, String outputFilename,
            ExecutionContext ec, int numMappers, int numReducers) {
        super(out, in, outputFilename);
        _ec = ec;
        _numMappers = numMappers;
        _numReducers = numReducers;
    }

    /**
     * There is no separate serial Spark merge; this degrades gracefully to
     * {@link #executeParallelMerge(int)} with the configured number of mappers.
     */
    @Override // org.apache.sysml.runtime.controlprogram.parfor.ResultMerge
    public MatrixObject executeSerialMerge() throws DMLRuntimeException {
        return executeParallelMerge(_numMappers);
    }

    /**
     * Merges all worker results into a new matrix object.
     *
     * <p>If there are no inputs, the original output object is returned unchanged. Otherwise a
     * new {@link MatrixObject} is created with the output's value type, dimensions and block
     * sizes, a variable name suffixed with {@code _rm} (unless it already contains it), the
     * non-zero count from {@code computeNonZeros}, and the merged RDD as its handle.
     *
     * @param par requested parallelism; not used here, the partition count is derived from
     *        the matrix dimensions and the configured number of reducers
     * @return the merged matrix object, or the unchanged output if there are no inputs
     * @throws DMLRuntimeException if the merge fails
     */
    @Override // org.apache.sysml.runtime.controlprogram.parfor.ResultMerge
    public MatrixObject executeParallelMerge(int par) throws DMLRuntimeException {
        LOG.trace("ResultMerge (remote, spark): Execute parallel merge for output "
                + _output.getVarName() + " (fname=" + _output.getFileName() + ")");

        if (_inputs == null || _inputs.length == 0) {
            return _output;
        }

        try {
            MatrixFormatMetaData metadata = (MatrixFormatMetaData) _output.getMetaData();
            MatrixCharacteristics mcOld = metadata.getMatrixCharacteristics();

            // Only compare against the existing output when it holds non-zeros
            // (handled by ResultMergeRemoteSparkWCompare in executeMerge).
            MatrixObject compare = (mcOld.getNonZeros() == 0) ? null : _output;
            RDDObject merged = executeMerge(compare, _inputs, _output.getVarName(),
                    mcOld.getRows(), mcOld.getCols(),
                    mcOld.getRowsPerBlock(), mcOld.getColsPerBlock());

            // Always create a new matrix object for the merged result.
            String varName = _output.getVarName();
            MatrixObject moNew = new MatrixObject(_output.getValueType(), _outputFName);
            moNew.setVarName(varName.contains("_rm") ? varName : varName + "_rm");
            moNew.setDataType(Expression.DataType.MATRIX);

            MatrixCharacteristics mcNew = new MatrixCharacteristics(mcOld.getRows(), mcOld.getCols(),
                    mcOld.getRowsPerBlock(), mcOld.getColsPerBlock());
            mcNew.setNonZeros(computeNonZeros(_output, convertToList(_inputs)));
            moNew.setMetaData(new MatrixFormatMetaData(mcNew,
                    metadata.getOutputInfo(), metadata.getInputInfo()));
            moNew.setRDDHandle(merged);
            return moNew;
        } catch (DMLRuntimeException e) {
            throw e; // already carries context; avoid double wrapping
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Builds the RDD holding the merged result of all inputs.
     *
     * <p>All input RDDs (binary block) are unioned. Without a compare matrix, blocks with equal
     * keys are combined via {@link RDDAggregateUtils#mergeByKey}. With a compare matrix, the
     * union is grouped by key, joined with the compare matrix's RDD and mapped through
     * {@code ResultMergeRemoteSparkWCompare}.
     *
     * @param compare matrix to compare against, or {@code null} for a plain merge
     * @param inputs worker result matrices; must be non-empty
     * @param varname variable name attached to the returned {@link RDDObject}
     * @param rlen number of rows of the result
     * @param clen number of columns of the result
     * @param brlen rows per block
     * @param bclen columns per block
     * @return the merged RDD, with the input RDD handles registered as lineage children
     * @throws DMLRuntimeException if there are no inputs, the execution context is not a
     *         Spark context, or the merge fails
     */
    protected RDDObject executeMerge(MatrixObject compare, MatrixObject[] inputs, String varname,
            long rlen, long clen, int brlen, int bclen) throws DMLRuntimeException {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;

        if (inputs == null || inputs.length == 0) {
            throw new DMLRuntimeException("Execute merge should never be called with no inputs.");
        }
        if (!(_ec instanceof SparkExecutionContext)) {
            throw new DMLRuntimeException("ResultMerge (remote, spark) requires a SparkExecutionContext, but got: "
                    + (_ec == null ? "null" : _ec.getClass().getName()));
        }
        SparkExecutionContext sec = (SparkExecutionContext) _ec;
        int numRed = determineNumReducers(rlen, clen, brlen, bclen, _numReducers);

        try {
            // NOTE(review): chained pairwise unions grow the lineage with the number of
            // inputs; with very many parfor tasks this may get deep — confirm acceptable.
            JavaPairRDD<MatrixIndexes, MatrixBlock> in = getBinaryBlockRDD(sec, inputs[0]);
            for (int k = 1; k < inputs.length; k++) {
                in = in.union(getBinaryBlockRDD(sec, inputs[k]));
            }

            JavaPairRDD<?, ?> out;
            if (compare != null) {
                // NOTE(review): join is an inner join, so input blocks whose key is absent from
                // the compare RDD are dropped — confirm the compare RDD contains every block.
                out = in.groupByKey(numRed)
                        .join(getBinaryBlockRDD(sec, compare))
                        .mapToPair(new ResultMergeRemoteSparkWCompare());
            } else {
                out = RDDAggregateUtils.mergeByKey(in);
            }

            RDDObject ret = new RDDObject(out, varname);
            for (MatrixObject input : inputs) {
                ret.addLineageChild(input.getRDDHandle());
            }

            Statistics.incrementNoOfCompiledSPInst();
            Statistics.incrementNoOfExecutedSPInst();
            if (DMLScript.STATISTICS) {
                Statistics.maintainCPHeavyHitters(JOB_NAME, System.nanoTime() - t0);
            }
            return ret;
        } catch (DMLRuntimeException e) {
            throw e; // already carries context; avoid double wrapping
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Fetches the binary-block RDD of a matrix object with concrete key/value types.
     * The execution context returns a wildcard-typed RDD; since binary-block input info is
     * requested, the pairs are presumably (MatrixIndexes, MatrixBlock).
     */
    @SuppressWarnings("unchecked")
    private static JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlockRDD(
            SparkExecutionContext sec, MatrixObject mo) throws DMLRuntimeException {
        return (JavaPairRDD<MatrixIndexes, MatrixBlock>) sec.getRDDHandleForMatrixObject(
                mo, InputInfo.BinaryBlockInputInfo);
    }

    /**
     * Chooses the number of partitions for grouping by key: the number of blocks of an
     * rlen x clen matrix, bounded by {@code numRed}.
     *
     * @return a partition count of at least 1; a non-positive {@code numRed} yields 1
     */
    private int determineNumReducers(long rlen, long clen, int brlen, int bclen, long numRed) {
        long blocks = numBlocks(rlen, brlen) * numBlocks(clen, bclen);
        long reducers = (numRed > 0) ? Math.min(numRed, blocks) : 1L;
        return (int) Math.min(reducers, Integer.MAX_VALUE);
    }

    /**
     * Number of blocks of size {@code blen} needed to cover {@code len} cells (ceiling division).
     * Returns 1 for non-positive lengths or block sizes so callers never divide by zero.
     */
    private static long numBlocks(long len, int blen) {
        if (len <= 0 || blen <= 0) {
            return 1L;
        }
        return (len - 1) / blen + 1;
    }
}
