package org.apache.sysml.runtime.matrix;

import java.util.HashSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RunningJob;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.MRInstructionParser;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.data.TaggedFirstSecondIndexes;
import org.apache.sysml.runtime.matrix.mapred.MMCJMRMapper;
import org.apache.sysml.runtime.matrix.mapred.MMCJMRReducerWithAggregator;
import org.apache.sysml.runtime.matrix.mapred.MRConfigurationNames;
import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration;
import org.apache.sysml.yarn.DMLAppMasterUtils;
import org.apache.sysml.yarn.ropt.YarnClusterAnalyzer;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/MMCJMR.class */
/**
 * Configures and submits the Hadoop MapReduce job named "MMCJ-MR", which evaluates a single
 * aggregate-binary instruction with {@link MMCJMRMapper} as mapper and
 * {@link MMCJMRReducerWithAggregator} as reducer.
 *
 * <p>Static utility class; it cannot be instantiated.
 */
public class MMCJMR {
    // NOTE(review): not referenced anywhere in this class as shown; kept so the class layout is unchanged.
    private static final boolean AUTOMATIC_CONFIG_NUM_REDUCERS = true;
    private static final Log LOG = LogFactory.getLog(MMCJMR.class);

    /** Number of bytes in one mebibyte; used to convert byte sizes to MB. */
    private static final long BYTES_PER_MB = 1024L * 1024L;

    // Constants of the reducer cache size estimate.
    // NOTE(review): values are carried over unchanged from the original formula. Presumably they are
    // the fixed per-block bytes (77), bytes per cell (8), bookkeeping bytes per cached block (20) and
    // the fixed cache overhead (32) -- confirm against the block serialization format.
    private static final long BLOCK_FIXED_BYTES = 77L;
    private static final long BYTES_PER_CELL = 8L;
    private static final long CACHE_ENTRY_BYTES = 20L;
    private static final long CACHE_FIXED_BYTES = 32L;

    private MMCJMR() {
        // utility class: prevent instantiation
    }

    /**
     * Sets up the MMCJ-MR job, runs it synchronously and returns its result descriptor.
     *
     * @param mRJobInstruction job instruction; only used to print itself when trace logging is enabled
     * @param strArr           inputs, one entry per input matrix (handed to the multiple-input setup)
     * @param inputInfoArr     input formats, parallel to {@code strArr}
     * @param jArr             first dimension array of the inputs (presumably row counts -- confirm)
     * @param jArr2            second dimension array of the inputs (presumably column counts -- confirm)
     * @param iArr             first block-size array of the inputs
     * @param iArr2            second block-size array of the inputs
     * @param str              instructions executed in the mapper
     * @param str2             aggregate instructions
     * @param str3             the aggregate-binary instruction; must parse to an
     *                         {@link AggregateBinaryInstruction}
     * @param i                requested number of reducers; must be positive
     * @param i2               DFS replication factor for the job
     * @param str4             output of the job
     * @param outputInfo       output format
     * @return the characteristics of the single output (with its non-zero count taken from the job
     *         counters), the output format and the job's success flag
     * @throws Exception if {@code i <= 0}, or if job setup or execution fails
     */
    public static JobReturn runJob(MRJobInstruction mRJobInstruction, String[] strArr, InputInfo[] inputInfoArr, long[] jArr, long[] jArr2, int[] iArr, int[] iArr2, String str, String str2, String str3, int i, int i2, String str4, OutputInfo outputInfo) throws Exception {
        JobConf job = new JobConf(MMCJMR.class);
        boolean blockRepresentation = MRJobConfiguration.deriveRepresentation(inputInfoArr);

        // Initial setup registers the output with a "dimensions unknown" flag of 0.
        MatrixCharacteristics[] stats = commonSetup(job, blockRepresentation, strArr, inputInfoArr, jArr, jArr2, iArr, iArr2, str, str2, str3, i, i2, (byte) 0, str4, outputInfo);
        if (LOG.isTraceEnabled()) {
            mRJobInstruction.printCompleteMRJobInstruction(stats);
        }

        // Parse the aggregate-binary instruction once; its output index is needed twice below.
        AggregateBinaryInstruction aggBin = (AggregateBinaryInstruction) MRInstructionParser.parseSingleInstruction(str3);
        byte resultIndex = aggBin.output;

        // If either output dimension came back as -1, redo the output setup with the flag set to 1.
        if (stats[0].getRows() == -1 || stats[0].getCols() == -1) {
            MRJobConfiguration.setUpMultipleOutputs(job, new byte[]{resultIndex}, new byte[]{1}, new String[]{str4}, new OutputInfo[]{outputInfo}, blockRepresentation);
        }

        MatrixCharacteristics dim1 = MRJobConfiguration.getMatrixCharactristicsForBinAgg(job, aggBin.input1);
        MatrixCharacteristics dim2 = MRJobConfiguration.getMatrixCharactristicsForBinAgg(job, aggBin.input2);
        clampBlockSizesToDimensions(dim1);
        clampBlockSizesToDimensions(dim2);

        // The configuration stores the cache size as an int, hence the narrowing cast.
        MRJobConfiguration.setMMCJCacheSize(job, (int) (estimateCacheSize(dim1, dim2) + MRJobConfiguration.getMiscMemRequired(job)));
        MRJobConfiguration.setUniqueWorkingDir(job);

        RunningJob runningJob = JobClient.runJob(job);

        // The non-zero counter is looked up under the decimal string of the output index.
        long nonZeros = runningJob.getCounters()
                .getGroup(MRJobConfiguration.NUM_NONZERO_CELLS)
                .getCounter(Byte.toString(resultIndex));
        stats[0].setNonZeros(nonZeros);
        return new JobReturn(stats[0], outputInfo, runningJob.isSuccessful());
    }

    /** Shrinks a block size that exceeds the corresponding matrix dimension down to that dimension. */
    private static void clampBlockSizesToDimensions(MatrixCharacteristics mc) {
        if (mc.getRowsPerBlock() > mc.getRows()) {
            mc.setRowsPerBlock((int) mc.getRows());
        }
        if (mc.getColsPerBlock() > mc.getCols()) {
            mc.setColsPerBlock((int) mc.getCols());
        }
    }

    /** Estimated bytes of one block with the given block dimensions, computed in long arithmetic. */
    private static long estimateBlockBytes(long rowsPerBlock, long colsPerBlock) {
        return BLOCK_FIXED_BYTES + BYTES_PER_CELL * rowsPerBlock * colsPerBlock;
    }

    /**
     * Estimates the reducer cache size in bytes, excluding the miscellaneous memory the caller adds.
     *
     * <p>The cached side is the first input when it has fewer rows than the second input has
     * columns, otherwise the second input. The estimate adds twice the larger input block and one
     * result block on top of the cached blocks.
     */
    private static long estimateCacheSize(MatrixCharacteristics dim1, MatrixCharacteristics dim2) {
        long blockBytes1 = estimateBlockBytes(dim1.getRowsPerBlock(), dim1.getColsPerBlock());
        long blockBytes2 = estimateBlockBytes(dim2.getRowsPerBlock(), dim2.getColsPerBlock());
        long resultBlockBytes = estimateBlockBytes(dim1.getRowsPerBlock(), dim2.getColsPerBlock());

        long cachedBytes;
        if (dim1.getRows() < dim2.getCols()) {
            // Divide as double so ceil() rounds a partial trailing block up instead of dropping it.
            long numBlocks = (long) Math.ceil((double) dim1.getRows() / (double) dim1.getRowsPerBlock());
            cachedBytes = numBlocks * (CACHE_ENTRY_BYTES + blockBytes1) + CACHE_FIXED_BYTES;
        } else {
            long numBlocks = (long) Math.ceil((double) dim2.getCols() / (double) dim2.getColsPerBlock());
            cachedBytes = numBlocks * (CACHE_ENTRY_BYTES + blockBytes2) + CACHE_FIXED_BYTES;
        }
        return cachedBytes + 2 * Math.max(blockBytes1, blockBytes2) + resultBlockBytes;
    }

    /**
     * Applies the job configuration shared by MMCJ-MR runs: inputs, dimensions, block sizes,
     * instructions, replication, serialization, outputs, mapper/reducer classes, key/value classes,
     * comparator, partitioner and the number of reduce tasks.
     *
     * @param jobConf job configuration to populate
     * @param z       {@code true} for block representation (map output values are
     *                {@link MatrixBlock}), {@code false} for cell representation ({@link MatrixCell})
     * @param b       "dimensions unknown" flag registered for the single output
     * @return the matrix characteristics computed for the job
     * @throws Exception if {@code i <= 0} or any configuration step fails
     * @throws IllegalArgumentException if there are more inputs than a byte index can address
     */
    private static MatrixCharacteristics[] commonSetup(JobConf jobConf, boolean z, String[] strArr, InputInfo[] inputInfoArr, long[] jArr, long[] jArr2, int[] iArr, int[] iArr2, String str, String str2, String str3, int i, int i2, byte b, String str4, OutputInfo outputInfo) throws Exception {
        jobConf.setJobName("MMCJ-MR");
        if (i <= 0) {
            throw new Exception("MMCJ-MR has to have at least one reduce task!");
        }
        MRJobConfiguration.setMatrixValueClass(jobConf, z);

        // Inputs are identified by their position stored in a byte, so only indexes 0..127 exist.
        if (strArr.length > Byte.MAX_VALUE + 1) {
            throw new IllegalArgumentException("MMCJ-MR supports at most " + (Byte.MAX_VALUE + 1) + " inputs, got " + strArr.length);
        }
        byte[] inputIndexes = new byte[strArr.length];
        for (int idx = 0; idx < inputIndexes.length; idx++) {
            inputIndexes[idx] = (byte) idx;
        }

        MRJobConfiguration.ConvertTarget target = z ? MRJobConfiguration.ConvertTarget.BLOCK : MRJobConfiguration.ConvertTarget.CELL;
        MRJobConfiguration.setUpMultipleInputs(jobConf, inputIndexes, strArr, inputInfoArr, iArr, iArr2, true, target);
        MRJobConfiguration.setMatricesDimensions(jobConf, inputIndexes, jArr, jArr2);
        MRJobConfiguration.setBlocksSizes(jobConf, inputIndexes, iArr, iArr2);

        MRJobConfiguration.setInstructionsInMapper(jobConf, str);
        MRJobConfiguration.setAggregateInstructions(jobConf, str2);
        MRJobConfiguration.setAggregateBinaryInstructions(jobConf, str3);

        jobConf.setInt(MRConfigurationNames.DFS_REPLICATION, i2);
        MRJobConfiguration.addBinaryBlockSerializationFramework(jobConf);

        DMLConfig dmlConfig = ConfigurationManager.getDMLConfig();
        DMLAppMasterUtils.setupMRJobRemoteMaxMemory(jobConf, dmlConfig);
        MRJobConfiguration.setupCustomMRConfigurations(jobConf, dmlConfig);

        // The job has exactly one result: the output of the aggregate-binary instruction.
        byte[] resultIndexes = {MRInstructionParser.parseSingleInstruction(str3).output};
        HashSet<Byte> mapOutputIndexes = MRJobConfiguration.setUpOutputIndexesForMapper(jobConf, inputIndexes, str, str2, str3, resultIndexes);
        MRJobConfiguration.setUpMultipleOutputs(jobConf, resultIndexes, new byte[]{b}, new String[]{str4}, new OutputInfo[]{outputInfo}, z);

        jobConf.setMapperClass(MMCJMRMapper.class);
        jobConf.setMapOutputKeyClass(TaggedFirstSecondIndexes.class);
        if (z) {
            jobConf.setMapOutputValueClass(MatrixBlock.class);
        } else {
            jobConf.setMapOutputValueClass(MatrixCell.class);
        }
        jobConf.setOutputKeyComparatorClass(TaggedFirstSecondIndexes.Comparator.class);
        jobConf.setPartitionerClass(TaggedFirstSecondIndexes.FirstIndexPartitioner.class);

        MRJobConfiguration.MatrixChar_N_ReducerGroups result = MRJobConfiguration.computeMatrixCharacteristics(jobConf, inputIndexes, str, str2, str3, null, resultIndexes, mapOutputIndexes, true);
        jobConf.setNumReduceTasks(determineNumReducers(jArr, jArr2, i, result.numReducerGroups));
        jobConf.setReducerClass(MMCJMRReducerWithAggregator.class);
        return result.stats;
    }

    /**
     * Chooses the number of reduce tasks for the job.
     *
     * <p>Starts from {@code i}, raises it to the largest estimated input size divided by the HDFS
     * block size (both in MB) but not beyond the cluster's parallel reduce tasks (on YARN: at least
     * half the number of cores), then caps it at {@code j} and finally enforces a minimum of one.
     *
     * @param jArr  first dimension array of the inputs
     * @param jArr2 second dimension array of the inputs, parallel to {@code jArr}
     * @param i     requested (default) number of reducers
     * @param j     number of reducer groups; upper bound for the result
     * @return the number of reducers, always at least 1
     */
    protected static int determineNumReducers(long[] jArr, long[] jArr2, int i, long j) {
        long maxReducers = InfrastructureAnalyzer.getRemoteParallelReduceTasks();
        // Floor at 1 MB so a block size below one mebibyte cannot produce a zero divisor.
        long hdfsBlockMb = Math.max(InfrastructureAnalyzer.getHDFSBlockSize() / BYTES_PER_MB, 1L);

        // Largest on-disk estimate over all inputs, in MB; the product of both dimensions is passed
        // as the third argument of the estimate.
        long maxInputMb = -1;
        for (int idx = 0; idx < jArr.length; idx++) {
            long sizeMb = MatrixBlock.estimateSizeOnDisk(jArr[idx], jArr2[idx], jArr[idx] * jArr2[idx]) / BYTES_PER_MB;
            maxInputMb = Math.max(maxInputMb, sizeMb);
        }

        if (InfrastructureAnalyzer.isYarnEnabled()) {
            maxReducers = Math.max(maxReducers, YarnClusterAnalyzer.getNumCores() / 2);
        }

        int reducers = (int) Math.max(i, Math.min(maxInputMb / hdfsBlockMb, maxReducers));
        reducers = (int) Math.min(reducers, j);
        return Math.max(reducers, 1);
    }
}
