package org.apache.sysml.runtime.matrix.mapred;

import java.io.IOException;
import java.util.Iterator;
import org.apache.hadoop.hdfs.web.resources.OffsetParam;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.MMCJ;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.data.TaggedFirstSecondIndexes;
import org.apache.sysml.runtime.matrix.mapred.ReduceBase;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.util.MapReduceTool;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/mapred/MMCJMRReducerWithAggregator.class */
/**
 * Reducer for the MMCJ matrix-multiplication MapReduce job.
 *
 * <p>Records are keyed by {@link TaggedFirstSecondIndexes}. For each distinct first index,
 * blocks tagged {@code 0} are buffered in an {@link MMCJMRInputCache}; every block with any
 * other tag is multiplied against all currently cached blocks. Depending on the instruction's
 * MMCJ type, the products are either accumulated in a {@link PartialAggregator} (type
 * {@code AGG}, flushed in {@link #close()}) or written straight to the output.
 *
 * <p>The task whose unique key is {@code "0"} additionally writes an empty block in
 * {@link #close()} for every result block index not held by the aggregator (or for every block
 * index when there is no aggregator).
 */
public class MMCJMRReducerWithAggregator extends MMCJMRCombinerReducerBase implements Reducer<TaggedFirstSecondIndexes, MatrixValue, Writable, Writable> {

    /**
     * Minimum memory reserved for the aggregation buffer: 64 * 1024 * 1024 = 67108864.
     * Presumably bytes, since it is combined with OptimizerUtils.getLocalMemBudget() in
     * {@link #configure(JobConf)} — TODO confirm unit.
     */
    public static long MIN_CACHE_SIZE = 64L * 1024 * 1024;

    private MMCJMRInputCache cache = null;
    /** Only created when the MMCJ type is {@code AGG}; null otherwise. */
    private PartialAggregator aggregator = null;
    /** First index of the current join group; -1 before the first record. */
    private long prevFirstIndex = -1;
    /** Tag of the previous record within the current join group. */
    private int prevTag = -1;
    private MatrixIndexes indexesBuffer = new MatrixIndexes();
    private MatrixValue valueBuffer = null;
    /** True only for the task whose unique key is "0" (see configure). */
    private boolean outputDummyRecords = false;

    /**
     * Processes one key: optionally pre-aggregates its values, then either caches the block
     * (tag 0) or joins it against the cached blocks of the same first index.
     *
     * @throws IOException if joining or writing the output fails
     * @throws RuntimeException if, within one first index, tags arrive in decreasing order
     */
    @Override
    public void reduce(TaggedFirstSecondIndexes taggedFirstSecondIndexes, Iterator<MatrixValue> it, OutputCollector<Writable, Writable> outputCollector, Reporter reporter) throws IOException {
        long start = System.currentTimeMillis();
        commonSetup(reporter);

        MatrixValue value;
        if (valueClass == MatrixBlock.class) {
            value = it.next();
        } else {
            value = performAggregateInstructions(taggedFirstSecondIndexes, it);
            if (value == null) {
                return;
            }
        }

        byte tag = taggedFirstSecondIndexes.getTag();
        long firstIndex = taggedFirstSecondIndexes.getFirstIndex();
        long secondIndex = taggedFirstSecondIndexes.getSecondIndex();

        if (firstIndex != prevFirstIndex) {
            // New join group: discard blocks cached for the previous first index.
            cache.resetCache(true);
            prevFirstIndex = firstIndex;
        } else if (prevTag > tag) {
            // Tag-0 blocks fill the cache and other tags probe it, so within one group the
            // tags must be non-decreasing for probes to see every cached block.
            throw new RuntimeException("tag is not ordered correctly: " + prevTag + " > " + ((int) tag));
        }
        prevTag = tag;

        processJoin(tag, secondIndex, value);
        reporter.incrCounter(ReduceBase.Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis() - start);
    }

    /**
     * Caches a tag-0 block, or multiplies a block of any other tag with every cached block.
     *
     * @param tag         record tag; 0 means "store in cache"
     * @param secondIndex the record's second index, used as the cache key or the output index
     * @param value       the block to cache or to join
     * @throws IOException on output failure, or wrapping any other exception from the join
     */
    private void processJoin(int tag, long secondIndex, MatrixValue value) throws IOException {
        try {
            if (tag == 0) {
                cache.put(secondIndex, value);
                return;
            }
            AggregateBinaryOperator op = (AggregateBinaryOperator) aggBinInstruction.getOperator();
            boolean leftIsCached = (tagForLeft == 0);
            boolean aggregate = aggBinInstruction.getMMCJType() == MMCJ.MMCJType.AGG;

            for (int k = 0; k < cache.getCacheSize(); k++) {
                Pair<MatrixIndexes, MatrixValue> cached = cache.get(k);
                // Keep operand order matching the matrix sides: the cached block is the left
                // operand when tagForLeft == 0, otherwise the right operand.
                if (leftIsCached) {
                    indexesBuffer.setIndexes(cached.getKey().getRowIndex(), secondIndex);
                    OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(cached.getValue(), value, valueBuffer, op);
                } else {
                    indexesBuffer.setIndexes(secondIndex, cached.getKey().getColumnIndex());
                    OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(value, cached.getValue(), valueBuffer, op);
                }

                if (aggregate) {
                    aggregator.aggregateToBuffer(indexesBuffer, valueBuffer, leftIsCached);
                } else {
                    collectFinalMultipleOutputs.collectOutput(indexesBuffer, valueBuffer, 0, cachedReporter);
                    resultsNonZeros[0] += valueBuffer.getNonZeros();
                }
            }
        } catch (IOException e) {
            throw e; // already the declared type; do not double-wrap
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Sets up the input cache and, for the {@code AGG} MMCJ type, the partial aggregator.
     *
     * <p>The memory budget is split between the two: the configured MMCJ cache size is used
     * as-is when more than MIN_CACHE_SIZE of the local budget remains for the aggregator.
     * Otherwise the cache is reduced to {@code budget - 2 * MIN_CACHE_SIZE} and the aggregator
     * gets MIN_CACHE_SIZE.
     *
     * @throws RuntimeException if more than one result is requested or initialization fails
     */
    @Override
    public void configure(JobConf jobConf) {
        super.configure(jobConf);
        if (resultIndexes.length > 1) {
            throw new RuntimeException("MMCJMR only outputs one result");
        }

        // Exactly one task (unique key "0") emits the empty placeholder blocks in close().
        outputDummyRecords = MapReduceTool.getUniqueKeyPerTask(jobConf, false).equals("0");

        try {
            valueBuffer = buffer;

            long cacheSize = MRJobConfiguration.getMMCJCacheSize(jobConf);
            long memBudget = (long) OptimizerUtils.getLocalMemBudget();
            long aggregatorSize;
            if (memBudget - cacheSize > MIN_CACHE_SIZE) {
                aggregatorSize = memBudget - cacheSize;
            } else {
                // Requested cache leaves too little room: shrink it, giving the aggregator
                // MIN_CACHE_SIZE and leaving another MIN_CACHE_SIZE of the budget unassigned.
                cacheSize = memBudget - 2 * MIN_CACHE_SIZE;
                aggregatorSize = MIN_CACHE_SIZE;
            }

            if (tagForLeft == 0) {
                cache = new MMCJMRInputCache(jobConf, cacheSize, dim1.getRows(), dim1.getCols(), dim1.getRowsPerBlock(), dim1.getColsPerBlock(), true, valueClass);
            } else {
                cache = new MMCJMRInputCache(jobConf, cacheSize, dim2.getRows(), dim2.getCols(), dim2.getRowsPerBlock(), dim2.getColsPerBlock(), false, valueClass);
            }

            if (aggBinInstruction.getMMCJType() == MMCJ.MMCJType.AGG) {
                aggregator = new PartialAggregator(jobConf, aggregatorSize, dim1.getRows(), dim2.getCols(), dim1.getRowsPerBlock(), dim2.getColsPerBlock(), tagForLeft != 0, (AggregateBinaryOperator) aggBinInstruction.getOperator(), valueClass);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Flushes the aggregator (AGG type), writes placeholder blocks if this is the designated
     * task, then releases the cache and the base-class resources. The cache and super.close()
     * are closed even if writing output fails.
     */
    @Override
    public void close() throws IOException {
        try {
            if (cachedReporter != null) {
                long start = System.currentTimeMillis();
                if (aggBinInstruction.getMMCJType() == MMCJ.MMCJType.AGG) {
                    resultsNonZeros[0] += aggregator.outputToHadoop(collectFinalMultipleOutputs, 0, cachedReporter);
                }
                cachedReporter.incrCounter(ReduceBase.Counters.COMBINE_OR_REDUCE_TIME, System.currentTimeMillis() - start);
            }
            if (outputDummyRecords) {
                writeEmptyBlocks();
            }
        } finally {
            try {
                if (cache != null) {
                    cache.close();
                }
            } finally {
                super.close();
            }
        }
    }

    /**
     * Writes an empty MatrixBlock for each block index of the dim1.rows x dim2.cols result
     * that the aggregator's buffer map does not contain (every index when there is no
     * aggregator). Edge blocks are sized to the remaining rows/columns.
     */
    private void writeEmptyBlocks() throws IOException {
        long rows = dim1.getRows();
        long cols = dim2.getCols();
        int rowsPerBlock = dim1.getRowsPerBlock();
        int colsPerBlock = dim2.getColsPerBlock();
        MatrixIndexes ix = new MatrixIndexes();
        MatrixBlock empty = new MatrixBlock();

        // Block indices are 1-based; rowOffset/colOffset are the 0-based cell offsets.
        for (long rowOffset = 0, bi = 1; rowOffset < rows; rowOffset += rowsPerBlock, bi++) {
            int blockRows = (int) Math.min(rowsPerBlock, rows - rowOffset);
            for (long colOffset = 0, bj = 1; colOffset < cols; colOffset += colsPerBlock, bj++) {
                int blockCols = (int) Math.min(colsPerBlock, cols - colOffset);
                ix.setIndexes(bi, bj);
                // The aggregator only exists for the AGG type, so a null aggregator is the
                // NO_AGG case, where every block index gets a placeholder.
                if (aggregator == null || !aggregator.getBufferMap().containsKey(ix)) {
                    empty.reset(blockRows, blockCols);
                    collectFinalMultipleOutputs.collectOutput(ix, empty, 0, cachedReporter);
                }
            }
        }
    }
}
