package org.apache.sysml.runtime.matrix.mapred;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.Reporter;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.io.MatrixWriter;
import org.apache.sysml.runtime.matrix.data.CTableMap;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.util.LongLongDoubleHashMap;

/* loaded from: input_file:org/apache/sysml/runtime/matrix/mapred/GMRCtableBuffer.class */
/**
 * Buffers ctable results per output tag and flushes them to a
 * {@link CollectMultipleConvertedOutputs} collector.
 *
 * <p>Exactly one of two buffers is active, chosen at construction time:
 * <ul>
 *   <li>a cell buffer ({@code _mapBuffer}): one {@link CTableMap} of (row, col) -> value per tag,
 *       emitted as individual {@link MatrixCell}s;</li>
 *   <li>a block buffer ({@code _blockBuffer}): one {@link MatrixBlock} per tag, emitted as-is when
 *       it fits in a single {@value #BLOCK_SIZE}x{@value #BLOCK_SIZE} tile, otherwise sliced into
 *       tiles of that size.</li>
 * </ul>
 *
 * <p>While flushing, the metadata arrays registered via {@link #setMetadataReferences} are updated
 * in place (non-zero counts and, for results with unknown dimensions, max row/column indexes).
 * This class performs no synchronization.
 */
public class GMRCtableBuffer {
    /** Not referenced in this class; presumably a flush threshold used by callers — TODO confirm. */
    public static final int MAX_BUFFER_SIZE = 4096;

    /**
     * Tile edge length used when slicing an oversized block buffer entry. This value was hard-coded
     * in the original code; presumably it must match the job's row/column block size — TODO confirm.
     */
    private static final int BLOCK_SIZE = 1000;

    private HashMap<Byte, CTableMap> _mapBuffer = null;
    private HashMap<Byte, MatrixBlock> _blockBuffer = null;
    private final CollectMultipleConvertedOutputs _collector;

    // Result metadata arrays owned by the caller; stored by reference and mutated during flush.
    private byte[] _resultIndexes = null;
    private long[] _resultNonZeros = null;
    private byte[] _resultDimsUnknown = null;
    private long[] _resultMaxRowDims = null;
    private long[] _resultMaxColDims = null;

    /**
     * Creates a buffer that emits to the given collector.
     *
     * @param collectMultipleConvertedOutputs collector that receives flushed output
     * @param z {@code true} to buffer whole {@link MatrixBlock}s, {@code false} to buffer cells
     *     in {@link CTableMap}s
     */
    public GMRCtableBuffer(CollectMultipleConvertedOutputs collectMultipleConvertedOutputs, boolean z) {
        if (z) {
            _blockBuffer = new HashMap<>();
        } else {
            _mapBuffer = new HashMap<>();
        }
        _collector = collectMultipleConvertedOutputs;
    }

    /**
     * Registers the caller-owned result metadata arrays. The arrays are not copied: later flushes
     * write into them directly.
     *
     * @param bArr result tag indexes, passed to {@code ReduceBase.getOutputIndexes}
     * @param jArr per-output non-zero counters (incremented on flush)
     * @param bArr2 per-output flag; a value of 1 enables max-dimension tracking for that output
     * @param jArr2 per-output max row index seen (cell buffer only)
     * @param jArr3 per-output max column index seen (cell buffer only)
     */
    public void setMetadataReferences(byte[] bArr, long[] jArr, byte[] bArr2, long[] jArr2, long[] jArr3) {
        _resultIndexes = bArr;
        _resultNonZeros = jArr;
        _resultDimsUnknown = bArr2;
        _resultMaxRowDims = jArr2;
        _resultMaxColDims = jArr3;
    }

    /**
     * Reports the current buffer size.
     *
     * <p>For the cell buffer this is the total number of buffered cells across all tags. For the
     * block buffer it is the sum of {@link MatrixBlock#estimateSizeInMemory} estimates, saturated
     * at {@link Integer#MAX_VALUE}. As a side effect, each buffered block's non-zero count is
     * recomputed.
     *
     * @return buffer size, or 0 if neither buffer is allocated
     */
    public int getBufferSize() {
        if (_mapBuffer != null) {
            int cells = 0;
            for (CTableMap ctable : _mapBuffer.values()) {
                cells += ctable.size();
            }
            return cells;
        }
        if (_blockBuffer == null) {
            return 0;
        }
        // Accumulate in long: the per-block estimates are longs and their sum can exceed int range.
        long total = 0;
        for (MatrixBlock block : _blockBuffer.values()) {
            block.recomputeNonZeros();
            int rows = block.getNumRows();
            int cols = block.getNumColumns();
            // Sparsity is nnz / (rows * cols), computed in floating point. The previous code used
            // long integer division, (nnz / rows) * cols, which is not a sparsity and divides by
            // zero for empty blocks.
            double cellCount = (double) rows * cols;
            double sparsity = cellCount > 0 ? block.getNonZeros() / cellCount : 0.0;
            total += MatrixBlock.estimateSizeInMemory(rows, cols, sparsity);
        }
        return (int) Math.min(total, Integer.MAX_VALUE);
    }

    public HashMap<Byte, CTableMap> getMapBuffer() {
        return _mapBuffer;
    }

    public HashMap<Byte, MatrixBlock> getBlockBuffer() {
        return _blockBuffer;
    }

    /**
     * Emits every buffered entry to the collector, updates the registered result metadata, and
     * clears the active buffer. The buffer is not cleared if emission fails part-way.
     *
     * @param reporter Hadoop reporter forwarded to the collector
     * @throws RuntimeException wrapping any failure, including the case where neither buffer
     *     is allocated
     */
    public void flushBuffer(Reporter reporter) throws RuntimeException {
        try {
            if (_mapBuffer != null) {
                flushMapBuffer(reporter);
                _mapBuffer.clear();
            } else if (_blockBuffer != null) {
                flushBlockBuffer(reporter);
                _blockBuffer.clear();
            } else {
                throw new DMLRuntimeException("Unexpected.. both ctable buffers are empty.");
            }
        } catch (Exception e) {
            throw new RuntimeException("Failed to flush ctable buffer.", e);
        }
    }

    /** Emits all buffered cells and updates nnz and (if unknown) max-dimension metadata. */
    private void flushMapBuffer(Reporter reporter) throws Exception {
        // A single cell value object is reused across entries (the original code did the same).
        MatrixCell cell = new MatrixCell();
        for (Map.Entry<Byte, CTableMap> entry : _mapBuffer.entrySet()) {
            ArrayList<Integer> outputs = ReduceBase.getOutputIndexes(entry.getKey().byteValue(), _resultIndexes);
            CTableMap ctable = entry.getValue();
            for (int out : outputs) {
                _resultNonZeros[out] += ctable.size();
                if (_resultDimsUnknown[out] == 1) {
                    _resultMaxRowDims[out] = Math.max(ctable.getMaxRow(), _resultMaxRowDims[out]);
                    _resultMaxColDims[out] = Math.max(ctable.getMaxColumn(), _resultMaxColDims[out]);
                }
            }
            Iterator<LongLongDoubleHashMap.LLDoubleEntry> cells = ctable.entrySet().iterator();
            while (cells.hasNext()) {
                LongLongDoubleHashMap.LLDoubleEntry e = cells.next();
                MatrixIndexes ix = new MatrixIndexes(e.key1, e.key2);
                cell.setValue(e.value);
                for (int out : outputs) {
                    _collector.collectOutput(ix, cell, out, reporter);
                }
            }
        }
    }

    /** Emits every buffered block, tiling any block larger than one tile in either dimension. */
    private void flushBlockBuffer(Reporter reporter) throws Exception {
        for (Map.Entry<Byte, MatrixBlock> entry : _blockBuffer.entrySet()) {
            ArrayList<Integer> outputs = ReduceBase.getOutputIndexes(entry.getKey().byteValue(), _resultIndexes);
            MatrixBlock block = entry.getValue();
            block.recomputeNonZeros();
            int rows = block.getNumRows();
            int cols = block.getNumColumns();
            if (rows > BLOCK_SIZE || cols > BLOCK_SIZE) {
                flushTiled(block, rows, cols, outputs, reporter);
            } else {
                collectBlock(new MatrixIndexes(1L, 1L), block, outputs, reporter);
            }
        }
    }

    /** Slices {@code block} into BLOCK_SIZE tiles (last row/column of tiles may be partial) and emits each one. */
    private void flushTiled(MatrixBlock block, int rows, int cols, ArrayList<Integer> outputs, Reporter reporter)
            throws Exception {
        MatrixBlock[] reuse = MatrixWriter.createMatrixBlocksForReuse(
                rows, cols, BLOCK_SIZE, BLOCK_SIZE, true, block.getNonZeros());
        MatrixIndexes ix = new MatrixIndexes(1L, 1L);
        // Ceiling division. The previous Math.ceil(int / int) truncated first and dropped partial tiles.
        int rowTiles = numTiles(rows);
        int colTiles = numTiles(cols);
        for (int bi = 0; bi < rowTiles; bi++) {
            int rl = bi * BLOCK_SIZE;
            int tileRows = Math.min(BLOCK_SIZE, rows - rl);
            for (int bj = 0; bj < colTiles; bj++) {
                int cl = bj * BLOCK_SIZE;
                int tileCols = Math.min(BLOCK_SIZE, cols - cl);
                MatrixBlock tile = MatrixWriter.getMatrixBlockForReuse(reuse, tileRows, tileCols, BLOCK_SIZE, BLOCK_SIZE);
                // Slice bounds are inclusive, zero-based.
                block.sliceOperations(rl, rl + tileRows - 1, cl, cl + tileCols - 1, tile);
                ix.setIndexes(bi + 1, bj + 1); // block indexes are one-based
                collectBlock(ix, tile, outputs, reporter);
                tile.reset();
            }
        }
    }

    /** Sends one block to every output index and adds its non-zero count to that output's counter. */
    private void collectBlock(MatrixIndexes ix, MatrixBlock blk, ArrayList<Integer> outputs, Reporter reporter)
            throws Exception {
        for (int out : outputs) {
            _collector.collectOutput(ix, blk, out, reporter);
            _resultNonZeros[out] += blk.getNonZeros();
        }
    }

    /** Number of BLOCK_SIZE tiles needed to cover {@code len} (ceiling division, overflow-safe). */
    private static int numTiles(int len) {
        return len / BLOCK_SIZE + (len % BLOCK_SIZE == 0 ? 0 : 1);
    }
}
