package org.apache.sysml.runtime.instructions.mr;

import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.TernaryOperator;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/mr/TernaryInstruction.class */
/**
 * MR instruction that applies a {@link TernaryOperator} to three operands.
 *
 * <p>Each operand is either a matrix, in which case its name is parsed as the
 * byte index used to look the block up in the {@link CachedValueMap}, or a
 * non-matrix operand, in which case its name is parsed as a double and wrapped
 * in a {@link MatrixBlock} once at construction time.
 */
public class TernaryInstruction extends MRInstruction {
    /** Index value stored for operands that are not matrices. */
    private static final byte NO_MATRIX_INDEX = (byte) -1;

    private final CPOperand input1;
    private final CPOperand input2;
    private final CPOperand input3;
    private final CPOperand output;
    private final byte ixinput1;
    private final byte ixinput2;
    private final byte ixinput3;
    private final byte ixoutput;
    private final MatrixBlock m1;
    private final MatrixBlock m2;
    private final MatrixBlock m3;

    /**
     * Creates the instruction from already-parsed operands.
     *
     * @param op the operator handed to the superclass (cast to {@link TernaryOperator} on execution)
     * @param in1 first input operand
     * @param in2 second input operand
     * @param in3 third input operand
     * @param out output operand; its name is always parsed as a byte for the superclass
     * @param instructionString the raw instruction string this instance was parsed from
     */
    private TernaryInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String instructionString) {
        super(MRInstruction.MRType.Ternary, op, Byte.parseByte(out.getName()));
        instString = instructionString;

        input1 = in1;
        input2 = in2;
        input3 = in3;
        output = out;

        // Matrix operands carry their cache index in the operand name.
        ixinput1 = matrixIndexOf(in1);
        ixinput2 = matrixIndexOf(in2);
        ixinput3 = matrixIndexOf(in3);
        ixoutput = matrixIndexOf(out);

        // Non-matrix operands are materialized once as blocks from their literal value.
        m1 = literalBlockOf(in1);
        m2 = literalBlockOf(in2);
        m3 = literalBlockOf(in3);
    }

    /** Returns the byte index encoded in a matrix operand's name, or -1 for a non-matrix operand. */
    private static byte matrixIndexOf(CPOperand operand) {
        if (operand.isMatrix()) {
            return Byte.parseByte(operand.getName());
        }
        return NO_MATRIX_INDEX;
    }

    /** Returns null for a matrix operand, otherwise a block built from the operand name parsed as a double. */
    private static MatrixBlock literalBlockOf(CPOperand operand) {
        if (operand.isMatrix()) {
            return null;
        }
        return new MatrixBlock(Double.parseDouble(operand.getName()));
    }

    /** Returns the first cached block at the given index for a matrix operand, otherwise the prebuilt literal block. */
    private static MatrixBlock resolveBlock(CPOperand operand, byte index, MatrixBlock literalBlock, CachedValueMap cache) {
        if (operand.isMatrix()) {
            return (MatrixBlock) cache.getFirst(index).getValue();
        }
        return literalBlock;
    }

    /**
     * Parses an instruction string into a {@code TernaryInstruction}.
     *
     * <p>Part 0 of the split instruction is the opcode; parts 1 to 3 are the
     * inputs and part 4 is the output.
     *
     * @param str the instruction string, checked to have 4 fields
     * @return the parsed instruction
     * @throws DMLRuntimeException declared by this method's signature
     */
    public static TernaryInstruction parseInstruction(String str) throws DMLRuntimeException {
        InstructionUtils.checkNumFields(str, 4);
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        Operator op = InstructionUtils.parseTernaryOperator(opcode);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        return new TernaryInstruction(op, in1, in2, in3, out, str);
    }

    /**
     * Runs the ternary operation and adds the result to the cache under the output index.
     *
     * <p>The result takes its block indexes from the first input that is a
     * matrix; when neither the first nor the second input is a matrix, the
     * third input's cache entry is used. Only {@code cachedValueMap} is read
     * by this implementation; the remaining parameters are unused here.
     *
     * @param cls unused in this implementation
     * @param cachedValueMap cache the matrix inputs are read from and the result is added to
     * @param indexedMatrixValue unused in this implementation
     * @param indexedMatrixValue2 unused in this implementation
     * @param i unused in this implementation
     * @param i2 unused in this implementation
     * @throws DMLRuntimeException declared by this method's signature
     */
    @Override // org.apache.sysml.runtime.instructions.mr.MRInstruction
    public void processInstruction(Class<? extends MatrixValue> cls, CachedValueMap cachedValueMap, IndexedMatrixValue indexedMatrixValue, IndexedMatrixValue indexedMatrixValue2, int i, int i2) throws DMLRuntimeException {
        MatrixBlock first = resolveBlock(input1, ixinput1, m1, cachedValueMap);
        MatrixBlock second = resolveBlock(input2, ixinput2, m2, cachedValueMap);
        MatrixBlock third = resolveBlock(input3, ixinput3, m3, cachedValueMap);

        // Pick which input's cache entry supplies the output block indexes.
        final byte indexSource;
        if (input1.isMatrix()) {
            indexSource = ixinput1;
        } else if (input2.isMatrix()) {
            indexSource = ixinput2;
        } else {
            indexSource = ixinput3;
        }
        MatrixIndexes sourceIndexes = cachedValueMap.getFirst(indexSource).getIndexes();

        // Prepare an empty output entry positioned at the source indexes.
        IndexedMatrixValue result = new IndexedMatrixValue(new MatrixIndexes(), new MatrixBlock());
        result.getIndexes().setIndexes(sourceIndexes);

        // Compute into the output block, then publish it under the output index.
        TernaryOperator ternaryOp = (TernaryOperator) optr;
        MatrixBlock target = (MatrixBlock) result.getValue();
        first.ternaryOperations(ternaryOp, second, third, target);
        cachedValueMap.add(ixoutput, result);
    }

    /**
     * Returns all indexes reported by {@link #getAllIndexes()} except the last one.
     *
     * @return a copy of the all-indexes array without its final element
     */
    @Override // org.apache.sysml.runtime.instructions.mr.MRInstruction
    public byte[] getInputIndexes() {
        byte[] all = getAllIndexes();
        int inputCount = all.length - 1;
        return Arrays.copyOfRange(all, 0, inputCount);
    }

    /**
     * Returns the byte indexes of every matrix operand, in the order
     * input1, input2, input3, output; non-matrix operands are skipped.
     *
     * @return the parsed indexes of the matrix operands
     */
    @Override // org.apache.sysml.runtime.instructions.mr.MRInstruction
    public byte[] getAllIndexes() {
        CPOperand[] operands = {input1, input2, input3, output};
        Byte[] boxedIndexes = Arrays.stream(operands)
            .filter(CPOperand::isMatrix)
            .map(operand -> Byte.valueOf(operand.getName()))
            .toArray(Byte[]::new);
        return ArrayUtils.toPrimitive(boxedIndexes);
    }
}
