package org.apache.sysml.runtime.controlprogram.parfor;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.utils.Statistics;

/**
 * Data partitioner that materializes a parfor partitioning of a matrix with a single remote
 * Spark job. Every input block is mapped to partition keys by
 * {@link DataPartitionerRemoteSparkMapper}. The mapped blocks are then grouped by key, and each
 * group is handed to {@link DataPartitionerRemoteSparkReducer} together with the output path.
 */
public class DataPartitionerRemoteSpark extends DataPartitioner {
    /** Heavy-hitter label under which the partitioning time is recorded. */
    private static final String JOB_NAME = "ParFor-DPSP";

    private final ExecutionContext _ec;
    /** Lower bound on the number of groupByKey partitions ("reducers"). */
    private final long _numRed;
    /** Replication factor forwarded to the reducer. */
    private final int _replication;

    /**
     * Creates a remote Spark data partitioner.
     *
     * @param partitionFormat  partition format; its {@code _dpf} and {@code _N} are passed to the
     *                         {@link DataPartitioner} constructor
     * @param executionContext execution context; {@link #partitionMatrix} casts it to
     *                         {@link SparkExecutionContext}, so it must be one when partitioning
     * @param j                minimum number of reducers for the partitioning job
     * @param i                replication factor forwarded to the reducer
     * @param z                not used by this implementation; kept for signature compatibility
     */
    public DataPartitionerRemoteSpark(ParForProgramBlock.PartitionFormat partitionFormat, ExecutionContext executionContext, long j, int i, boolean z) {
        super(partitionFormat._dpf, partitionFormat._N);
        this._ec = executionContext;
        this._numRed = j;
        this._replication = i;
    }

    /**
     * Partitions {@code matrixObject} into the output path {@code str} using a Spark job.
     *
     * <p>Any existing file at {@code str} is deleted first. On success, one compiled and one
     * executed Spark instruction are counted. When statistics are enabled, the elapsed time is
     * also recorded under {@code "ParFor-DPSP"}.
     *
     * @param matrixObject input matrix; read as a binary-block RDD
     * @param str          output path handed to the reducer
     * @param inputInfo    input format forwarded to the mapper
     * @param outputInfo   output format forwarded to the mapper and reducer
     * @param j            not used here (dimensions come from the matrix characteristics)
     * @param j2           not used here (dimensions come from the matrix characteristics)
     * @param i            not used here (block sizes come from the matrix characteristics)
     * @param i2           not used here (block sizes come from the matrix characteristics)
     * @throws DMLRuntimeException wrapping any failure of the cleanup or of the Spark job
     */
    @Override
    protected void partitionMatrix(MatrixObject matrixObject, String str, InputInfo inputInfo, OutputInfo outputInfo, long j, long j2, int i, int i2) throws DMLRuntimeException {
        final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        final SparkExecutionContext sec = (SparkExecutionContext) this._ec;
        try {
            // Remove output left over from earlier runs before writing the new partitions.
            MapReduceTool.deleteFileIfExistOnHDFS(str);

            // The handle is requested in binary-block format but is typed with wildcards.
            // Narrow it to the block key/value types that the mapper and determineNumReducers
            // consume, so that the pipeline below type-checks.
            @SuppressWarnings("unchecked")
            final JavaPairRDD<MatrixIndexes, MatrixBlock> inRdd = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
                    sec.getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo);
            final MatrixCharacteristics mc = matrixObject.getMatrixCharacteristics();
            final int numRed = toPartitionCount(determineNumReducers(inRdd, mc, this._numRed));

            inRdd.flatMapToPair(new DataPartitionerRemoteSparkMapper(mc, inputInfo, outputInfo, this._format, this._n))
                    .groupByKey(numRed)
                    .foreach(new DataPartitionerRemoteSparkReducer(str, outputInfo, this._replication));

            Statistics.incrementNoOfCompiledSPInst();
            Statistics.incrementNoOfExecutedSPInst();
            if (DMLScript.STATISTICS) {
                Statistics.maintainCPHeavyHitters(JOB_NAME, System.nanoTime() - t0);
            }
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Converts a reducer count into a valid Spark partition count.
     *
     * <p>The count is clamped to {@code [1, Integer.MAX_VALUE]}. This avoids silent int
     * truncation of large values and avoids passing a zero or negative count to
     * {@code groupByKey}.
     */
    private static int toPartitionCount(long numReducers) {
        return (int) Math.min(Integer.MAX_VALUE, Math.max(1L, numReducers));
    }

    /**
     * Returns {@code ceil(len / blockSize)} for non-negative {@code len} and positive
     * {@code blockSize}.
     *
     * <p>Otherwise it returns -1, which is treated as "unknown number of partitions". A negative
     * length is presumably an unknown dimension; a non-positive block size would otherwise
     * divide by zero.
     */
    private static long ceilDiv(long len, long blockSize) {
        if (len < 0 || blockSize <= 0) {
            return -1;
        }
        return len / blockSize + (len % blockSize == 0 ? 0 : 1);
    }

    /**
     * Determines the number of groupByKey partitions.
     *
     * <p>The starting point is the number of output partitions implied by the partition format.
     * This is capped by Spark's preferred partition count for the input, but the result is never
     * below {@code minReducers}. Formats not listed below, and unknown or invalid sizes, yield -1
     * output partitions, so the result is then governed by {@code minReducers}.
     *
     * @param in          input blocks
     * @param mc          characteristics of the input matrix
     * @param minReducers lower bound on the result
     * @return the number of reducers (may still be &lt; 1 if {@code minReducers} is)
     */
    private long determineNumReducers(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MatrixCharacteristics mc, long minReducers) {
        final long rows = mc.getRows();
        final long cols = mc.getCols();
        final long numParts;
        switch (this._format) {
            case ROW_WISE:
                numParts = rows;
                break;
            case COLUMN_WISE:
                numParts = cols;
                break;
            case ROW_BLOCK_WISE:
                numParts = ceilDiv(rows, mc.getRowsPerBlock());
                break;
            case COLUMN_BLOCK_WISE:
                numParts = ceilDiv(cols, mc.getColsPerBlock());
                break;
            case ROW_BLOCK_WISE_N:
                numParts = ceilDiv(rows, this._n);
                break;
            case COLUMN_BLOCK_WISE_N:
                numParts = ceilDiv(cols, this._n);
                break;
            default:
                numParts = -1;
                break;
        }
        final long preferred = SparkUtils.getNumPreferredPartitions(mc, in);
        return Math.max(minReducers, Math.min(preferred, numParts));
    }
}
