package org.apache.sysml.runtime.matrix.mapred;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.mr.AggregateInstruction;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.instructions.mr.TernaryInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.TaggedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.util.MapReduceTool;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/mapred/ReduceBase.class */
/**
 * Shared base class for reduce-side tasks of the matrix MapReduce jobs.
 *
 * <p>It reads the aggregate and reducer instructions from the job configuration, keeps
 * per-result statistics (non-zero counts and maximum row/column dimensions), aggregates
 * incoming tagged matrix values into the inherited {@code cachedValues} map, and publishes
 * the statistics when the task is closed.
 */
public class ReduceBase extends MRBaseForCommonInstructions {
    /** Aggregate operator (initial value 0, plus function) used when a tag has no explicit instruction. */
    protected static final AggregateOperator DEFAULT_AGG_OP = new AggregateOperator(0.0d, Plus.getPlusFnObject());

    /** Collector for the final outputs; created in {@link #configure(JobConf)}, closed in {@link #close()}. */
    protected CollectMultipleConvertedOutputs collectFinalMultipleOutputs;

    /** Value of the job property {@code dims.unknown.file.prefix}; directory prefix of the dims file. */
    protected String dimsUnknownFilePrefix;

    /** Value of the job property {@code mapred.task.id}; parsed in {@link #close()}. */
    protected String reducerID;

    /** Aggregate instructions grouped by the tag (input index) they are applied to. */
    protected HashMap<Byte, ArrayList<AggregateInstruction>> agg_instructions = new HashMap<>();

    /**
     * Reusable fallback instruction for tags without a registered aggregate instruction.
     * Its {@code input}/{@code output} are overwritten on every use.
     */
    protected AggregateInstruction defaultAggIns = new AggregateInstruction(DEFAULT_AGG_OP, (byte) 0, (byte) 0, "DEFAULT_AGG_OP");

    /** Instructions to run in the reducer, or {@code null} when the job configuration provides none. */
    protected ArrayList<MRInstruction> mixed_instructions = null;

    /** Tags of the final results; the position in this array is the result's output index. */
    protected byte[] resultIndexes = null;

    /** Per-result flag; a non-zero entry triggers writing of the dims file in {@link #close()}. */
    protected byte[] resultDimsUnknown = null;

    /** Per-result non-zero counts, reported as counters in {@link #close()}. */
    protected long[] resultsNonZeros = null;

    /** Per-result maximum row dimension, written to the dims file in {@link #close()}. */
    protected long[] resultsMaxRowDims = null;

    /** Per-result maximum column dimension, written to the dims file in {@link #close()}. */
    protected long[] resultsMaxColDims = null;

    /** First reporter handed to {@link #commonSetup(Reporter)}; {@code null} until then. */
    protected Reporter cachedReporter = null;

    /** {@code true} until {@link #commonSetup(Reporter)} has run once. */
    protected boolean firsttime = true;

    /** Correction values per output tag, used by aggregate operators with {@code correctionExists}. */
    protected CachedValueMap correctionCache = new CachedValueMap();

    /** Counter keys declared by this class. */
    public enum Counters {
        COMBINE_OR_REDUCE_TIME
    }

    /**
     * Remembers the reporter on the first invocation; subsequent calls are no-ops.
     *
     * @param reporter reporter to cache for use in {@link #close()}
     */
    public void commonSetup(Reporter reporter) {
        if (this.firsttime) {
            this.cachedReporter = reporter;
            this.firsttime = false;
        }
    }

    /**
     * Initializes the task from the job configuration: task id, dims-file prefix, result
     * indexes, statistics arrays, output collector, reducer instructions, distributed cache
     * files and the aggregate instruction table.
     *
     * @param jobConf job configuration to read from
     * @throws RuntimeException wrapping any {@link IOException}, {@link DMLRuntimeException} or
     *         {@link DMLUnsupportedOperationException} raised during setup
     */
    @Override
    public void configure(JobConf jobConf) {
        super.configure(jobConf);
        this.reducerID = jobConf.get("mapred.task.id");
        this.dimsUnknownFilePrefix = jobConf.get("dims.unknown.file.prefix");
        this.resultIndexes = MRJobConfiguration.getResultIndexes(jobConf);
        this.resultDimsUnknown = MRJobConfiguration.getResultDimsUnknown(jobConf);

        // One statistics slot per final result.
        int numResults = this.resultIndexes.length;
        this.resultsNonZeros = new long[numResults];
        this.resultsMaxRowDims = new long[numResults];
        this.resultsMaxColDims = new long[numResults];

        this.collectFinalMultipleOutputs = MRJobConfiguration.getMultipleConvertedOutputs(jobConf);
        try {
            AggregateInstruction[] aggregateInstructions = MRJobConfiguration.getAggregateInstructions(jobConf);
            MRInstruction[] instructionsInReducer = MRJobConfiguration.getInstructionsInReducer(jobConf);
            if (instructionsInReducer != null) {
                this.mixed_instructions = new ArrayList<>();
                Collections.addAll(this.mixed_instructions, instructionsInReducer);
            }

            setupDistCacheFiles(jobConf);

            if (aggregateInstructions != null) {
                for (AggregateInstruction instruction : aggregateInstructions) {
                    registerAggregateInstruction(instruction);
                    if (instruction.input != instruction.output) {
                        // Also register the same operator under the output tag, so that values
                        // arriving already tagged with the output index are aggregated too
                        // (presumably partial aggregates from a combiner -- confirm with callers).
                        registerAggregateInstruction(new AggregateInstruction(
                                instruction.getOperator(), instruction.output, instruction.output, instruction.toString()));
                    }
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        } catch (DMLUnsupportedOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /** Appends the instruction to the list kept for its input tag, creating the list on first use. */
    private void registerAggregateInstruction(AggregateInstruction instruction) {
        Byte inputTag = Byte.valueOf(instruction.input);
        ArrayList<AggregateInstruction> forTag = this.agg_instructions.get(inputTag);
        if (forTag == null) {
            forTag = new ArrayList<>();
            this.agg_instructions.put(inputTag, forTag);
        }
        forTag.add(instruction);
    }

    /**
     * Convenience overload that forwards to the nine-argument variant (defined outside this
     * class body) using this task's collector and statistics arrays.
     *
     * @param matrixIndexes indexes of the value to emit
     * @param matrixValue value to emit
     * @param i output index, i.e. position within {@link #resultIndexes}
     * @param reporter reporter passed through to the delegate
     * @throws IOException if the delegate fails
     */
    public void collectOutput_N_Increase_Counter(MatrixIndexes matrixIndexes, MatrixValue matrixValue, int i, Reporter reporter) throws IOException {
        collectOutput_N_Increase_Counter(matrixIndexes, matrixValue, i, reporter, this.collectFinalMultipleOutputs, this.resultDimsUnknown, this.resultsNonZeros, this.resultsMaxRowDims, this.resultsMaxColDims);
    }

    /**
     * Returns every position of {@link #resultIndexes} whose entry equals the given tag.
     *
     * @param b tag to look for
     * @return matching positions in ascending order; empty if there is none
     */
    public ArrayList<Integer> getOutputIndexes(byte b) {
        return getOutputIndexes(b, this.resultIndexes);
    }

    /**
     * Returns every position of {@code bArr} whose entry equals {@code b}.
     *
     * @param b tag to look for
     * @param bArr array of tags to search
     * @return matching positions in ascending order; empty if there is none
     */
    public static ArrayList<Integer> getOutputIndexes(byte b, byte[] bArr) {
        ArrayList<Integer> positions = new ArrayList<>();
        for (int i = 0; i < bArr.length; i++) {
            if (bArr[i] == b) {
                positions.add(Integer.valueOf(i));
            }
        }
        return positions;
    }

    /**
     * Publishes the collected statistics (only if a reporter was cached) and closes the
     * output collector.
     *
     * <p>The collector is closed in a {@code finally} block so it is released even when
     * reporting or writing the dims file fails.
     *
     * @throws IOException if writing the dims file or closing the collector fails
     * @throws RuntimeException if the task id does not have a recognized format
     */
    @Override
    public void close() throws IOException {
        try {
            if (this.cachedReporter != null) {
                reportCountersAndWriteDimsFile();
            }
        } finally {
            if (this.collectFinalMultipleOutputs != null) {
                this.collectFinalMultipleOutputs.close();
            }
        }
    }

    /**
     * Reports the non-zero counters for all results and, if any result has its
     * dims-unknown flag set, writes the dims file for this task.
     */
    private void reportCountersAndWriteDimsFile() throws IOException {
        String[] idParts = (this.reducerID == null) ? new String[0] : this.reducerID.split("_");
        // Segments 1 and 2 form the job name below, so at least three segments are required.
        if (idParts.length < 3) {
            throw new RuntimeException("Unrecognized format for reducerID: " + this.reducerID);
        }
        String jobName = "job_" + idParts[1] + "_" + idParts[2];

        int taskNumber;
        if (idParts[0].equalsIgnoreCase("task")) {
            // "task_..." ids end with the task number.
            taskNumber = Integer.parseInt(idParts[idParts.length - 1]);
        } else if (idParts[0].equalsIgnoreCase("attempt")) {
            // "attempt_..." ids carry one more trailing segment after the task number.
            taskNumber = Integer.parseInt(idParts[idParts.length - 2]);
        } else {
            throw new RuntimeException("Unrecognized format for reducerID: " + this.reducerID);
        }

        boolean anyDimsUnknown = false;
        for (int i = 0; i < this.resultIndexes.length; i++) {
            this.cachedReporter.incrCounter(MRJobConfiguration.NUM_NONZERO_CELLS, Integer.toString(i), this.resultsNonZeros[i]);
            if (this.resultDimsUnknown != null && this.resultDimsUnknown[i] != 0) {
                anyDimsUnknown = true;
            }
        }
        if (anyDimsUnknown) {
            MapReduceTool.writeDimsFile(this.dimsUnknownFilePrefix + "/" + jobName + "_dimsFile/r_" + taskNumber, this.resultDimsUnknown, this.resultsMaxRowDims, this.resultsMaxColDims);
        }
    }

    /**
     * Runs the reducer instructions held in {@link #mixed_instructions}.
     *
     * @throws IOException wrapping any exception raised while processing
     */
    public void processReducerInstructions() throws IOException {
        try {
            processMixedInstructions(this.mixed_instructions);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Emits every cached value to the collector under the given key, tagged with the index
     * it is cached under.
     *
     * @param matrixIndexes key used for every emitted record
     * @param taggedMatrixValue reusable carrier; its base object and tag are overwritten per record
     * @param outputCollector destination of the records
     * @throws IOException if the collector fails
     */
    public void outputInCombinerFromCachedValues(MatrixIndexes matrixIndexes, TaggedMatrixValue taggedMatrixValue, OutputCollector<MatrixIndexes, TaggedMatrixValue> outputCollector) throws IOException {
        Iterator<Byte> tags = this.cachedValues.getIndexesOfAll().iterator();
        while (tags.hasNext()) {
            byte tag = tags.next().byteValue();
            ArrayList<IndexedMatrixValue> cached = this.cachedValues.get(tag);
            if (cached == null) {
                continue;
            }
            for (IndexedMatrixValue entry : cached) {
                taggedMatrixValue.setBaseObject(entry.getValue());
                taggedMatrixValue.setTag(tag);
                outputCollector.collect(matrixIndexes, taggedMatrixValue);
            }
        }
    }

    /**
     * Emits the cached values of every final result through
     * {@link #collectOutput_N_Increase_Counter(MatrixIndexes, MatrixValue, int, Reporter)}.
     *
     * @param reporter reporter passed through to the collector call
     * @throws IOException if emitting a value fails
     */
    public void outputResultsFromCachedValues(Reporter reporter) throws IOException {
        for (int i = 0; i < this.resultIndexes.length; i++) {
            ArrayList<IndexedMatrixValue> cached = this.cachedValues.get(this.resultIndexes[i]);
            if (cached == null) {
                continue;
            }
            for (IndexedMatrixValue entry : cached) {
                collectOutput_N_Increase_Counter(entry.getIndexes(), entry.getValue(), i, reporter);
            }
        }
    }

    /**
     * Folds one incoming value into the cached aggregate of the instruction's output tag.
     * On the first value for that tag the aggregate (and, for operators with a correction,
     * the correction entry) is created and initialized before the value is added.
     *
     * @param j row index assigned to a newly created aggregate
     * @param j2 column index assigned to a newly created aggregate
     * @param matrixValue value to add
     * @param aggregateInstruction instruction supplying the operator and the output tag
     * @param z flag passed unchanged to {@code OperationsOnMatrixValues}
     *        (presumably "correction is embedded in the value" -- confirm)
     */
    private void processAggregateHelp(long j, long j2, MatrixValue matrixValue, AggregateInstruction aggregateInstruction, boolean z) throws DMLUnsupportedOperationException, DMLRuntimeException {
        AggregateOperator aggregateOperator = (AggregateOperator) aggregateInstruction.getOperator();
        IndexedMatrixValue aggregate = this.cachedValues.getFirst(aggregateInstruction.output);
        IndexedMatrixValue correction = null;
        if (aggregateOperator.correctionExists) {
            correction = this.correctionCache.getFirst(aggregateInstruction.output);
        }

        if (aggregate == null) {
            // First value for this output tag: reserve the slot and initialize the aggregation.
            aggregate = this.cachedValues.holdPlace(aggregateInstruction.output, this.valueClass);
            aggregate.getIndexes().setIndexes(j, j2);
            if (aggregateOperator.correctionExists) {
                if (correction == null) {
                    correction = this.correctionCache.holdPlace(aggregateInstruction.output, this.valueClass);
                }
                OperationsOnMatrixValues.startAggregation(aggregate.getValue(), correction.getValue(), aggregateOperator, matrixValue.getNumRows(), matrixValue.getNumColumns(), matrixValue.isInSparseFormat(), z);
            } else {
                OperationsOnMatrixValues.startAggregation(aggregate.getValue(), null, aggregateOperator, matrixValue.getNumRows(), matrixValue.getNumColumns(), matrixValue.isInSparseFormat(), z);
            }
        }

        if (aggregateOperator.correctionExists) {
            OperationsOnMatrixValues.incrementalAggregation(aggregate.getValue(), correction.getValue(), matrixValue, aggregateOperator, z);
        } else {
            OperationsOnMatrixValues.incrementalAggregation(aggregate.getValue(), null, matrixValue, aggregateOperator, z);
        }
    }

    /**
     * Aggregates all values for one key; equivalent to calling the three-argument overload
     * with {@code false}.
     *
     * @param matrixIndexes key shared by all values
     * @param it values to aggregate
     * @throws IOException wrapping any exception raised during aggregation
     */
    public void processAggregateInstructions(MatrixIndexes matrixIndexes, Iterator<TaggedMatrixValue> it) throws IOException {
        processAggregateInstructions(matrixIndexes, it, false);
    }

    /**
     * Aggregates all values for one key. Each value is processed by every aggregate
     * instruction registered for its tag; a tag without registered instructions is handled
     * by {@link #defaultAggIns}, retargeted to that tag.
     *
     * @param matrixIndexes key shared by all values
     * @param it values to aggregate
     * @param z flag passed unchanged to the aggregation helper
     * @throws IOException wrapping any exception raised during aggregation
     */
    public void processAggregateInstructions(MatrixIndexes matrixIndexes, Iterator<TaggedMatrixValue> it, boolean z) throws IOException {
        while (it.hasNext()) {
            try {
                TaggedMatrixValue tagged = it.next();
                byte tag = tagged.getTag();
                ArrayList<AggregateInstruction> instructions = this.agg_instructions.get(Byte.valueOf(tag));
                if (instructions == null) {
                    // No explicit instruction: aggregate in place with the default operator.
                    this.defaultAggIns.input = tag;
                    this.defaultAggIns.output = tag;
                    processAggregateHelp(matrixIndexes.getRowIndex(), matrixIndexes.getColumnIndex(), tagged.getBaseObject(), this.defaultAggIns, z);
                } else {
                    for (AggregateInstruction instruction : instructions) {
                        processAggregateHelp(matrixIndexes.getRowIndex(), matrixIndexes.getColumnIndex(), tagged.getBaseObject(), instruction, z);
                    }
                }
            } catch (Exception e) {
                throw new IOException(e);
            }
        }
    }

    /**
     * Tells whether the reducer instructions include a {@link TernaryInstruction}.
     *
     * @return {@code true} if at least one is present; {@code false} otherwise, including
     *         when there are no reducer instructions
     */
    public boolean containsTernaryInstruction() {
        if (this.mixed_instructions == null) {
            return false;
        }
        for (MRInstruction instruction : this.mixed_instructions) {
            if (instruction instanceof TernaryInstruction) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether every {@link TernaryInstruction} among the reducer instructions reports
     * known output dimensions.
     *
     * @return {@code false} if any ternary instruction has unknown output dimensions;
     *         {@code true} otherwise, including when there are no reducer instructions
     */
    public boolean dimsKnownForTernaryInstructions() {
        if (this.mixed_instructions == null) {
            return true;
        }
        for (MRInstruction instruction : this.mixed_instructions) {
            if ((instruction instanceof TernaryInstruction) && !((TernaryInstruction) instruction).knownOutputDims()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads the matrix characteristics of the first input of every
     * {@link TernaryInstruction} into the inherited {@code dimensions} map. Instructions
     * whose {@code input1} is {@code -1} are skipped.
     *
     * @param jobConf job configuration holding the input characteristics
     */
    public void prepareMatrixCharacteristicsTernaryInstruction(JobConf jobConf) {
        if (this.mixed_instructions == null) {
            return;
        }
        for (MRInstruction instruction : this.mixed_instructions) {
            if (!(instruction instanceof TernaryInstruction)) {
                continue;
            }
            TernaryInstruction ternaryInstruction = (TernaryInstruction) instruction;
            if (ternaryInstruction.input1 != -1) {
                this.dimensions.put(Byte.valueOf(ternaryInstruction.input1), MRJobConfiguration.getMatrixCharacteristicsForInput(jobConf, ternaryInstruction.input1));
            }
        }
    }
}
