package org.apache.sysml.runtime.matrix;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineBinaryInstruction;
import org.apache.sysml.runtime.instructions.mr.CombineTernaryInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.data.TaggedMatrixBlock;
import org.apache.sysml.runtime.matrix.data.TaggedMatrixCell;
import org.apache.sysml.runtime.matrix.data.TaggedMatrixValue;
import org.apache.sysml.runtime.matrix.data.WeightedPair;
import org.apache.sysml.runtime.matrix.mapred.GMRMapper;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRConfigurationNames;
import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration;
import org.apache.sysml.runtime.matrix.mapred.ReduceBase;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Standalone MR job ("Standalone-MR") that applies combine instructions
 * ({@link CombineBinaryInstruction} / {@link CombineTernaryInstruction}) to co-located blocks of
 * its inputs and emits one {@link WeightedPair} record per cell.
 *
 * <p>The mapper is {@link GMRMapper}. It only tags the input blocks (mapper, aggregate and reducer
 * instruction strings are all empty). The actual combining happens in {@link InnerReducer}.
 */
public class CombineMR {
    private static final Log LOG = LogFactory.getLog(CombineMR.class.getName());

    /**
     * Reducer that groups all tagged blocks sharing a {@link MatrixIndexes} key and emits, per
     * combine instruction, one (cell index, {@link WeightedPair}) record for every cell of the block.
     */
    public static class InnerReducer extends ReduceBase implements Reducer<MatrixIndexes, TaggedMatrixValue, MatrixIndexes, WeightedPair> {
        /** Combine instructions parsed from the job configuration in {@link #configure(JobConf)}. */
        protected MRInstruction[] comb_instructions = null;
        // Reused output key/value objects, overwritten for every emitted cell.
        private final MatrixIndexes keyBuff = new MatrixIndexes();
        private final WeightedPair valueBuff = new WeightedPair();
        // Result index -> (rows per block, cols per block) of that output.
        private final HashMap<Byte, Pair<Integer, Integer>> outputBlockSizes = new HashMap<>();
        // Instruction output index -> output collector indexes to write to.
        private final HashMap<Byte, ArrayList<Integer>> outputIndexesMapping = new HashMap<>();

        /**
         * Caches all values for one block key, then runs the combine instructions on them.
         *
         * @param indexes  block key shared by all {@code values}
         * @param values   tagged blocks (one per input tag) for this key
         * @param out      unused here; records are written through {@code collectFinalMultipleOutputs}
         * @param reporter Hadoop reporter, also used to record COMBINE_OR_REDUCE_TIME
         * @throws IOException if an unsupported combine instruction is configured
         */
        @Override
        public void reduce(MatrixIndexes indexes, Iterator<TaggedMatrixValue> values, OutputCollector<MatrixIndexes, WeightedPair> out, Reporter reporter) throws IOException {
            long start = System.currentTimeMillis();
            if (firsttime) {
                cachedReporter = reporter;
                firsttime = false;
            }
            cachedValues.reset();
            while (values.hasNext()) {
                TaggedMatrixValue value = values.next();
                cachedValues.set(value.getTag(), indexes, value.getBaseObject(), true);
            }
            processCombineInstructionsAndOutput(reporter);
            reporter.incrCounter(ReduceBase.Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis() - start);
        }

        /**
         * Reads the combine instructions, the block size of every result index, and the output
         * collector indexes of every instruction output from the job configuration.
         *
         * @throws RuntimeException wrapping any failure while reading the configuration
         */
        @Override // org.apache.sysml.runtime.matrix.mapred.ReduceBase, org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions
        public void configure(JobConf job) {
            super.configure(job);
            try {
                comb_instructions = MRJobConfiguration.getCombineInstruction(job);
                for (byte resultIndex : resultIndexes) {
                    MatrixCharacteristics mc = MRJobConfiguration.getMatrixCharacteristicsForOutput(job, resultIndex);
                    outputBlockSizes.put(resultIndex, new Pair<>(mc.getRowsPerBlock(), mc.getColsPerBlock()));
                }
                for (MRInstruction ins : comb_instructions) {
                    outputIndexesMapping.put(ins.output, getOutputIndexes(ins.output));
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Runs every configured combine instruction against the currently cached block values.
         *
         * @throws IOException if an instruction is neither binary nor ternary combine
         */
        void processCombineInstructionsAndOutput(Reporter reporter) throws IOException {
            for (MRInstruction ins : comb_instructions) {
                if (ins instanceof CombineBinaryInstruction) {
                    processBinaryCombineInstruction((CombineBinaryInstruction) ins, reporter);
                } else if (ins instanceof CombineTernaryInstruction) {
                    processTernaryCombineInstruction((CombineTernaryInstruction) ins, reporter);
                } else {
                    throw new IOException("unsupported instruction: " + ins);
                }
            }
        }

        /**
         * Returns {@code in} if present. Otherwise returns the shared {@code zeroInput} block reset
         * to {@code numRows x numColumns}, which stands in for the missing input.
         */
        private IndexedMatrixValue orZeroBlock(IndexedMatrixValue in, int numRows, int numColumns) {
            if (in != null) {
                return in;
            }
            zeroInput.getValue().reset(numRows, numColumns);
            return zeroInput;
        }

        /**
         * Emits, for every cell of the current block, value = input1, otherValue = input2 and
         * weight = input3. A missing input is replaced by a zero block. Nothing is emitted when
         * all three inputs are missing.
         */
        private void processTernaryCombineInstruction(CombineTernaryInstruction ins, Reporter reporter) throws IOException {
            IndexedMatrixValue in1 = cachedValues.getFirst(ins.input1);
            IndexedMatrixValue in2 = cachedValues.getFirst(ins.input2);
            IndexedMatrixValue in3 = cachedValues.getFirst(ins.input3);
            if (in1 == null && in2 == null && in3 == null) {
                return;
            }
            // Take indexes and dimensions from a block that is actually present. Reading them after
            // zero-block substitution would pick up zeroInput's indexes instead of this key's.
            IndexedMatrixValue present = in1 != null ? in1 : (in2 != null ? in2 : in3);
            MatrixIndexes indexes = present.getIndexes();
            int numRows = present.getValue().getNumRows();
            int numColumns = present.getValue().getNumColumns();
            in1 = orZeroBlock(in1, numRows, numColumns);
            in2 = orZeroBlock(in2, numRows, numColumns);
            in3 = orZeroBlock(in3, numRows, numColumns);
            try {
                ArrayList<Integer> outputIndexes = outputIndexesMapping.get(ins.output);
                Pair<Integer, Integer> blockSize = outputBlockSizes.get(ins.output);
                int rowsPerBlock = blockSize.getKey();
                int colsPerBlock = blockSize.getValue();
                for (int r = 0; r < numRows; r++) {
                    for (int c = 0; c < numColumns; c++) {
                        keyBuff.setIndexes(UtilFunctions.computeCellIndex(indexes.getRowIndex(), rowsPerBlock, r), UtilFunctions.computeCellIndex(indexes.getColumnIndex(), colsPerBlock, c));
                        valueBuff.setValue(in1.getValue().getValue(r, c));
                        valueBuff.setOtherValue(in2.getValue().getValue(r, c));
                        valueBuff.setWeight(in3.getValue().getValue(r, c));
                        for (int outputIndex : outputIndexes) {
                            collectFinalMultipleOutputs.collectOutput(keyBuff, valueBuff, outputIndex, reporter);
                        }
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Emits, for every cell of the current block, value = input1 and either
         * weight = input2 (with otherValue 0) when the second input is a weight, or
         * otherValue = input2 (with weight 1) otherwise. A missing input is replaced by a zero
         * block. Nothing is emitted when both inputs are missing.
         */
        private void processBinaryCombineInstruction(CombineBinaryInstruction ins, Reporter reporter) throws IOException {
            IndexedMatrixValue in1 = cachedValues.getFirst(ins.input1);
            IndexedMatrixValue in2 = cachedValues.getFirst(ins.input2);
            if (in1 == null && in2 == null) {
                return;
            }
            // Capture indexes/dimensions before any zero-block substitution.
            IndexedMatrixValue present = in1 != null ? in1 : in2;
            MatrixIndexes indexes = present.getIndexes();
            int numRows = present.getValue().getNumRows();
            int numColumns = present.getValue().getNumColumns();
            in1 = orZeroBlock(in1, numRows, numColumns);
            in2 = orZeroBlock(in2, numRows, numColumns);
            boolean secondIsWeight = ins.isSecondInputWeight();
            try {
                ArrayList<Integer> outputIndexes = outputIndexesMapping.get(ins.output);
                Pair<Integer, Integer> blockSize = outputBlockSizes.get(ins.output);
                int rowsPerBlock = blockSize.getKey();
                int colsPerBlock = blockSize.getValue();
                for (int r = 0; r < numRows; r++) {
                    for (int c = 0; c < numColumns; c++) {
                        keyBuff.setIndexes(UtilFunctions.computeCellIndex(indexes.getRowIndex(), rowsPerBlock, r), UtilFunctions.computeCellIndex(indexes.getColumnIndex(), colsPerBlock, c));
                        valueBuff.setValue(in1.getValue().getValue(r, c));
                        double second = in2.getValue().getValue(r, c);
                        if (secondIsWeight) {
                            valueBuff.setWeight(second);
                            valueBuff.setOtherValue(0.0); // no "other" value in the weighted form
                        } else {
                            valueBuff.setWeight(1.0);
                            valueBuff.setOtherValue(second);
                        }
                        for (int outputIndex : outputIndexes) {
                            collectFinalMultipleOutputs.collectOutput(keyBuff, valueBuff, outputIndex, reporter);
                        }
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    private CombineMR() {
    }

    /**
     * Configures and synchronously runs the combine job.
     *
     * @param inst                job instruction (printed at TRACE level)
     * @param inputs              input paths; input i gets tag i
     * @param inputInfos          input formats, one per input
     * @param rlens               row counts, one per input
     * @param clens               column counts, one per input
     * @param brlens              rows per block, one per input
     * @param bclens              columns per block, one per input
     * @param combineInstructions combine instruction string
     * @param numReducers         reducer setting forwarded to {@code MRJobConfiguration.setNumReducers}
     * @param replication         DFS replication factor for the outputs
     * @param resultIndexes       result matrix indexes, one per output
     * @param outputs             output paths
     * @param outputInfos         output formats, one per output
     * @return computed output characteristics and whether the Hadoop job succeeded
     * @throws IllegalArgumentException if there are more than {@link Byte#MAX_VALUE} inputs
     *     (inputs are tagged with a byte)
     * @throws Exception on configuration or job-submission failure
     */
    public static JobReturn runJob(MRJobInstruction inst, String[] inputs, InputInfo[] inputInfos, long[] rlens, long[] clens, int[] brlens, int[] bclens, String combineInstructions, int numReducers, int replication, byte[] resultIndexes, String[] outputs, OutputInfo[] outputInfos) throws Exception {
        JobConf job = new JobConf(CombineMR.class);
        job.setJobName("Standalone-MR");
        boolean blockRepresentation = MRJobConfiguration.deriveRepresentation(inputInfos);
        MRJobConfiguration.setMatrixValueClass(job, blockRepresentation);

        // Inputs are tagged 0..n-1 with byte tags, so at most Byte.MAX_VALUE inputs are supported.
        if (inputs.length > Byte.MAX_VALUE) {
            throw new IllegalArgumentException("CombineMR supports at most " + Byte.MAX_VALUE + " inputs, but got " + inputs.length);
        }
        byte[] realIndexes = new byte[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            realIndexes[i] = (byte) i;
        }

        MRJobConfiguration.setUpMultipleInputs(job, realIndexes, inputs, inputInfos, brlens, bclens, true, blockRepresentation ? MRJobConfiguration.ConvertTarget.BLOCK : MRJobConfiguration.ConvertTarget.CELL);
        MRJobConfiguration.setMatricesDimensions(job, realIndexes, rlens, clens);
        MRJobConfiguration.setBlocksSizes(job, realIndexes, brlens, bclens);
        MRJobConfiguration.setInstructionsInMapper(job, "");
        MRJobConfiguration.setAggregateInstructions(job, "");
        MRJobConfiguration.setInstructionsInReducer(job, "");
        MRJobConfiguration.setCombineInstructions(job, combineInstructions);
        job.setInt(MRConfigurationNames.DFS_REPLICATION, replication);
        MRJobConfiguration.setupCustomMRConfigurations(job, ConfigurationManager.getDMLConfig());

        HashSet<Byte> mapOutputIndexes = MRJobConfiguration.setUpOutputIndexesForMapper(job, realIndexes, null, null, combineInstructions, resultIndexes);
        MRJobConfiguration.setUpMultipleOutputs(job, resultIndexes, null, outputs, outputInfos, blockRepresentation);

        job.setMapperClass(GMRMapper.class);
        job.setMapOutputKeyClass(MatrixIndexes.class);
        if (blockRepresentation) {
            job.setMapOutputValueClass(TaggedMatrixBlock.class);
        } else {
            job.setMapOutputValueClass(TaggedMatrixCell.class);
        }
        job.setReducerClass(InnerReducer.class);

        MRJobConfiguration.MatrixChar_N_ReducerGroups ret = MRJobConfiguration.computeMatrixCharacteristics(job, realIndexes, null, null, null, combineInstructions, resultIndexes, mapOutputIndexes, false);
        MatrixCharacteristics[] stats = ret.stats;
        MRJobConfiguration.setNumReducers(job, ret.numReducerGroups, numReducers);

        if (LOG.isTraceEnabled()) {
            inst.printCompleteMRJobInstruction(stats);
        }

        MRJobConfiguration.setUniqueWorkingDir(job);
        return new JobReturn(stats, JobClient.runJob(job).isSuccessful());
    }
}
