package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.PMMJ;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/mr/PMMJMRInstruction.class */
/**
 * MR instruction for a permutation matrix multiplication ({@code PMMJ}) in which the
 * left-hand input ({@code input1}) is read from the distributed cache and the right-hand
 * input ({@code input2}) is streamed block by block.
 *
 * <p>For every streamed block of {@code input2}, the cached block of {@code input1} with the
 * same row-block index is fetched. Its non-zero values are treated as 1-based target row
 * positions. The selected rows are written into one or two output row blocks.
 */
public class PMMJMRInstruction extends BinaryMRInstructionBase implements IDistributedCacheConsumer {
    /** Total number of rows of the output; used to size the (possibly partial) output row blocks. */
    private final long _rlen;

    /** When {@code false}, output blocks that end up empty are removed instead of emitted. */
    private final boolean _outputEmptyBlocks;

    /**
     * @param op                operator passed through to the base instruction
     * @param in1               index of the distributed-cache (left-hand) input
     * @param in2               index of the streamed (right-hand) input
     * @param out               index of the output
     * @param nrow              total number of output rows
     * @param cacheType         parsed from the instruction but NOT stored or used by this class
     * @param outputEmptyBlocks whether empty output blocks are emitted
     * @param istr              the original instruction string
     */
    private PMMJMRInstruction(Operator op, byte in1, byte in2, byte out, long nrow,
            PMMJ.CacheType cacheType, boolean outputEmptyBlocks, String istr) {
        super(MRInstruction.MRType.PMMJ, op, in1, in2, out);
        this.instString = istr;
        this._rlen = nrow;
        this._outputEmptyBlocks = outputEmptyBlocks;
    }

    /** Returns the total number of output rows this instruction was parsed with. */
    public long getNumRows() {
        return _rlen;
    }

    /** Returns whether empty output blocks are emitted rather than dropped. */
    public boolean getOutputEmptyBlocks() {
        return _outputEmptyBlocks;
    }

    /**
     * Parses an instruction string of the form
     * {@code opcode, in1, in2, nrow, out, cacheType, outputEmptyBlocks} (as split by
     * {@link InstructionUtils#getInstructionParts}).
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the field count is wrong or the opcode is not {@code PMMJ.OPCODE}
     */
    public static PMMJMRInstruction parseInstruction(String str) throws DMLRuntimeException {
        InstructionUtils.checkNumFields(str, 6);
        String[] parts = InstructionUtils.getInstructionParts(str);

        // Check the opcode first so a foreign instruction is reported as such instead of
        // failing with an unrelated number/enum parse exception on its operands.
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(PMMJ.OPCODE)) {
            throw new DMLRuntimeException("Unknown opcode while parsing a PMMJMRInstruction: " + str);
        }

        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        long nrow = UtilFunctions.toLong(Double.parseDouble(parts[3]));
        byte out = Byte.parseByte(parts[4]);
        PMMJ.CacheType cacheType = PMMJ.CacheType.valueOf(parts[5]);
        boolean outputEmptyBlocks = Boolean.parseBoolean(parts[6]);

        return new PMMJMRInstruction(new Operator(true), in1, in2, out, nrow,
                cacheType, outputEmptyBlocks, str);
    }

    /**
     * Multiplies the cached block of {@code input1} with the streamed block of {@code input2}.
     * The resulting rows are placed in up to two output blocks held in {@code cachedValues}.
     *
     * @param valueClass     concrete value class used when reserving output slots
     * @param cachedValues   per-task value cache holding the inputs and receiving the outputs
     * @param tempValue      unused by this instruction
     * @param zeroInput      unused by this instruction
     * @param blockRowFactor rows per block
     * @param blockColFactor unused by this instruction
     * @throws DMLRuntimeException if the distributed-cache input is missing or an operation fails
     */
    @Override // org.apache.sysml.runtime.instructions.mr.MRInstruction
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues,
            IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor)
            throws DMLRuntimeException {
        // Streamed right-hand block: if this task holds none, there is nothing to multiply.
        IndexedMatrixValue in2 = cachedValues.getFirst(input2);
        if (in2 == null) {
            return;
        }

        DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(Byte.valueOf(input1));
        if (dcInput == null) {
            throw new DMLRuntimeException("Missing distributed cache input " + input1
                    + " for instruction: " + instString);
        }

        // Left-hand block aligned with the row-block index of the streamed block.
        MatrixBlock mb1 = (MatrixBlock) dcInput.getDataBlock((int) in2.getIndexes().getRowIndex(), 1).getValue();
        MatrixBlock mb2 = (MatrixBlock) in2.getValue();

        // Smallest and largest selected target row (1-based). A value below 1 means the cached
        // block selects no row (e.g. an empty block), so nothing is produced.
        long minPos = UtilFunctions.toLong(mb1.minNonZero());
        long maxPos = UtilFunctions.toLong(mb1.max());
        if (minPos < 1) {
            return;
        }

        // Output row-block indexes covering the selected range. Only two output blocks are
        // produced, so the selected rows are assumed to span at most two row blocks.
        // NOTE(review): presumably guaranteed by how the permutation input is built; verify.
        long rowIX1 = (minPos - 1) / blockRowFactor + 1;
        long rowIX2 = (maxPos - 1) / blockRowFactor + 1;
        boolean multipleOuts = rowIX1 != rowIX2;

        // Output sparsity estimate: fraction of non-zero selector rows times nnz of the streamed block.
        double spmb1 = OptimizerUtils.getSparsity(mb1.getNumRows(), 1L, mb1.getNonZeros());
        long estnnz = (long) (spmb1 * mb2.getNonZeros());
        boolean sparse = MatrixBlock.evalSparseFormatInMemory(blockRowFactor, mb2.getNumColumns(), estnnz);

        // Reserve and allocate the output block(s).
        IndexedMatrixValue out1 = cachedValues.holdPlace(output, valueClass);
        IndexedMatrixValue out2 = multipleOuts ? cachedValues.holdPlace(output, valueClass) : null;
        out1.getValue().reset(blockRowFactor, mb2.getNumColumns(), sparse);
        if (out2 != null) {
            out2.getValue().reset(UtilFunctions.computeBlockSize(_rlen, rowIX2, blockRowFactor),
                    mb2.getNumColumns(), sparse);
        }

        // out1 is allocated with the default block height; its row count is corrected afterwards
        // to the actual size of row block rowIX1.
        mb1.permutationMatrixMultOperations(mb2, out1.getValue(), (out2 != null) ? out2.getValue() : null);
        ((MatrixBlock) out1.getValue()).setNumRows(UtilFunctions.computeBlockSize(_rlen, rowIX1, blockRowFactor));
        out1.getIndexes().setIndexes(rowIX1, in2.getIndexes().getColumnIndex());
        if (out2 != null) {
            out2.getIndexes().setIndexes(rowIX2, in2.getIndexes().getColumnIndex());
        }

        // Empty-block filter: drop all outputs of this instruction only if every produced block is empty.
        if (!_outputEmptyBlocks && out1.getValue().isEmpty()
                && (out2 == null || out2.getValue().isEmpty())) {
            cachedValues.remove(output);
        }
    }

    /**
     * Reports whether {@code index} is consumed only through the distributed cache by this
     * instruction: it must be {@code input1} and must not also be the streamed {@code input2}.
     *
     * @param inst  instruction string (not used here)
     * @param index input index to test
     */
    @Override // org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        return index == input1 && index != input2;
    }

    /**
     * Appends the index of the distributed-cache input ({@code input1}) to {@code indexes}.
     *
     * @param inst    instruction string (not used here)
     * @param indexes list receiving the distributed-cache input index
     */
    @Override // org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        indexes.add(Byte.valueOf(input1));
    }
}
