package org.apache.sysml.runtime.controlprogram.parfor;

import java.util.Arrays;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.JobConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.utils.Statistics;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/ResultMergeRemoteSpark.class */
/**
 * Result merge that combines the given input matrices into one output matrix
 * by building a Spark RDD pipeline over the exported input files.
 *
 * <p>Serial and parallel merge share the same implementation; the merged
 * result is attached to a new {@link MatrixObject} as an RDD handle.
 */
public class ResultMergeRemoteSpark extends ResultMerge {
    private static final long serialVersionUID = -6924566953903424820L;

    /** Execution context; cast to {@link SparkExecutionContext} when the merge is executed. */
    private final ExecutionContext _ec;
    /** Degree of parallelism that the serial merge forwards to the parallel merge. */
    private final int _numMappers;
    /** Requested upper bound for the number of partitions of the group-by in the compare merge. */
    private final int _numReducers;

    /**
     * Creates a Spark-based result merge.
     *
     * @param out            forwarded to the {@link ResultMerge} constructor (presumably the output matrix)
     * @param in             forwarded to the {@link ResultMerge} constructor (presumably the worker results)
     * @param outputFilename forwarded to the {@link ResultMerge} constructor (presumably the output file name)
     * @param accum          forwarded to the {@link ResultMerge} constructor (presumably the accumulation flag)
     * @param ec             execution context, expected to be a {@link SparkExecutionContext}
     * @param numMappers     parallelism passed on by {@link #executeSerialMerge()}
     * @param numReducers    requested maximum number of reducers (partitions) for the compare merge
     */
    public ResultMergeRemoteSpark(MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum, ExecutionContext ec, int numMappers, int numReducers) {
        super(out, in, outputFilename, accum);
        this._ec = ec;
        this._numMappers = numMappers;
        this._numReducers = numReducers;
    }

    /**
     * Delegates to {@link #executeParallelMerge(int)} with the configured number of mappers.
     *
     * @return the merged matrix object
     * @throws DMLRuntimeException if the merge fails
     */
    @Override // org.apache.sysml.runtime.controlprogram.parfor.ResultMerge
    public MatrixObject executeSerialMerge() throws DMLRuntimeException {
        return executeParallelMerge(this._numMappers);
    }

    /**
     * Merges all inputs into a new matrix object backed by an RDD handle.
     * If there are no inputs, the existing output object is returned as is.
     *
     * @param par degree of parallelism; not used by this implementation
     * @return a new matrix object carrying the merged RDD, or the original output if there are no inputs
     * @throws DMLRuntimeException wrapping any failure during the merge
     */
    @Override // org.apache.sysml.runtime.controlprogram.parfor.ResultMerge
    public MatrixObject executeParallelMerge(int par) throws DMLRuntimeException {
        // NOTE: the trace text says "serial" although this is the shared serial/parallel path;
        // it is kept unchanged so existing log filters keep matching.
        if (LOG.isTraceEnabled()) {
            LOG.trace("ResultMerge (remote, spark): Execute serial merge for output " + this._output.hashCode() + " (fname=" + this._output.getFileName() + ")");
        }
        try {
            // nothing to merge: hand back the existing output
            if (this._inputs == null || this._inputs.length == 0) {
                return this._output;
            }

            MetaDataFormat oldMeta = (MetaDataFormat) this._output.getMetaData();
            MatrixCharacteristics oldMc = oldMeta.getMatrixCharacteristics();

            // an output with zero non-zeros is not used as compare matrix
            MatrixObject compare = (oldMc.getNonZeros() == 0) ? null : this._output;
            RDDObject mergedRdd = executeMerge(compare, this._inputs, oldMc.getRows(), oldMc.getCols(), oldMc.getRowsPerBlock(), oldMc.getColsPerBlock());

            // wrap the merged RDD in a fresh matrix object with copied meta data
            MatrixObject merged = new MatrixObject(this._output.getValueType(), this._outputFName);
            OutputInfo oldOutputInfo = oldMeta.getOutputInfo();
            InputInfo oldInputInfo = oldMeta.getInputInfo();
            MatrixCharacteristics newMc = new MatrixCharacteristics(oldMc);
            // with accumulation the non-zero count is marked unknown (-1)
            newMc.setNonZeros(this._isAccum ? -1L : computeNonZeros(this._output, Arrays.asList(this._inputs)));
            merged.setMetaData(new MetaDataFormat(newMc, oldOutputInfo, oldInputInfo));
            merged.setRDDHandle(mergedRdd);
            return merged;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Builds the RDD that merges all inputs, optionally against a compare matrix.
     *
     * <p>Every input is exported and registered with an RDD handle; all input files
     * are then read as one binary-block Hadoop RDD. With a compare matrix the blocks
     * are grouped by key, joined with the compare RDD and mapped through
     * {@link ResultMergeRemoteSparkWCompare}; without one they are combined via
     * {@code RDDAggregateUtils.sumByKeyStable} (accumulation) or
     * {@code RDDAggregateUtils.mergeByKey}.
     *
     * @param compare compare matrix, or {@code null} to merge without compare
     * @param inputs  matrices to merge; must contain at least one element
     * @param rlen    number of rows of the output
     * @param clen    number of columns of the output
     * @param brlen   number of rows per block
     * @param bclen   number of columns per block
     * @return RDD object of the merged result with the inputs (and compare) as lineage children
     * @throws DMLRuntimeException if no inputs are given or the RDD construction fails
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected RDDObject executeMerge(MatrixObject compare, MatrixObject[] inputs, long rlen, long clen, int brlen, int bclen) throws DMLRuntimeException {
        // validate first so that no other work happens on an invalid call
        if (inputs == null || inputs.length == 0) {
            throw new DMLRuntimeException("Execute merge should never be called with no inputs.");
        }

        long startNanos = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        SparkExecutionContext sec = (SparkExecutionContext) this._ec;
        boolean withCompare = compare != null;
        int numReducers = determineNumReducers(rlen, clen, brlen, bclen, this._numReducers);

        try {
            InputInfo inputInfo = InputInfo.BinaryBlockInputInfo;
            JobConf job = new JobConf(ResultMergeRemoteMR.class);
            job.setJobName("ParFor-RMSP");
            job.setInputFormat(inputInfo.inputFormatClass);

            // export every input, collect its file path and attach an RDD handle for lineage
            Path[] paths = new Path[inputs.length];
            for (int k = 0; k < paths.length; k++) {
                inputs[k].exportData();
                paths[k] = new Path(inputs[k].getFileName());
                setRDDHandleForMerge(inputs[k], sec);
            }
            FileInputFormat.setInputPaths(job, paths);

            // read all input files as one pair RDD
            JavaPairRDD blocks = sec.getSparkContext()
                    .hadoopRDD(job, inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass)
                    .mapPartitionsToPair(new CopyBlockPairFunction(true), true);

            JavaPairRDD mergedBlocks;
            if (withCompare) {
                JavaPairRDD compareBlocks = sec.getRDDHandleForMatrixObject(compare, InputInfo.BinaryBlockInputInfo);
                mergedBlocks = blocks.groupByKey(numReducers)
                        .join(compareBlocks)
                        .mapToPair(new ResultMergeRemoteSparkWCompare(this._isAccum));
            } else if (this._isAccum) {
                mergedBlocks = RDDAggregateUtils.sumByKeyStable(blocks, false);
            } else {
                mergedBlocks = RDDAggregateUtils.mergeByKey(blocks, false);
            }

            // register inputs (and compare) as lineage children of the result
            RDDObject result = new RDDObject(mergedBlocks);
            for (int k = 0; k < paths.length; k++) {
                result.addLineageChild(inputs[k].getRDDHandle());
            }
            if (withCompare) {
                result.addLineageChild(compare.getRDDHandle());
            }

            Statistics.incrementNoOfCompiledSPInst();
            Statistics.incrementNoOfExecutedSPInst();
            if (DMLScript.STATISTICS) {
                Statistics.maintainCPHeavyHitters("ParFor-RMSP", System.nanoTime() - startNanos);
            }
            return result;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Determines the number of partitions for the group-by of the compare merge:
     * the requested number of reducers, capped by the number of matrix blocks and
     * never smaller than one.
     *
     * <p>The block count uses ceiling division so that partially filled boundary
     * blocks are counted. Non-positive dimensions or block sizes count as a single
     * block, and a non-positive reducer request results in one partition instead of
     * an invalid partition count.
     *
     * @param rlen   number of rows
     * @param clen   number of columns
     * @param brlen  number of rows per block
     * @param bclen  number of columns per block
     * @param numRed requested number of reducers
     * @return partition count in the range {@code [1, max(1, numRed)]}
     */
    private static int determineNumReducers(long rlen, long clen, int brlen, int bclen, long numRed) {
        long rowBlocks = numBlocks(rlen, brlen);
        long colBlocks = numBlocks(clen, bclen);
        // saturate instead of overflowing for extremely large block grids
        long reducerGroups = (rowBlocks > Long.MAX_VALUE / colBlocks) ? Long.MAX_VALUE : rowBlocks * colBlocks;
        return (int) Math.max(1L, Math.min(numRed, reducerGroups));
    }

    /**
     * Returns the number of blocks needed to cover {@code len} cells with blocks of
     * size {@code blen} (ceiling division), or 1 if either value is non-positive.
     */
    private static long numBlocks(long len, int blen) {
        if (len <= 0 || blen <= 0) {
            return 1L;
        }
        // overflow-safe form of ceil(len / blen) for len > 0
        return (len - 1) / blen + 1;
    }

    /**
     * Attaches an RDD handle to the given matrix object that reads its file as
     * binary-block Hadoop file and is flagged as HDFS file.
     *
     * @param mo  matrix object whose file name is used as input path
     * @param sec Spark execution context providing the Spark context
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static void setRDDHandleForMerge(MatrixObject mo, SparkExecutionContext sec) {
        InputInfo inputInfo = InputInfo.BinaryBlockInputInfo;
        JavaPairRDD fileBlocks = sec.getSparkContext().hadoopFile(mo.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass);
        RDDObject handle = new RDDObject(fileBlocks);
        handle.setHDFSFile(true);
        mo.setRDDHandle(handle);
    }
}
