package org.apache.sysml.runtime.instructions.cp;

import java.util.ArrayList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.Statistics;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.class */
/**
 * CP instruction covering the convolution-related opcodes handled in this class
 * ({@code im2col}, {@code reshape_col}, {@code rotate180}, {@code col2im},
 * {@code maxpooling} and {@code maxpooling_backward}).
 *
 * <p>The instruction carries four operand lists (stride, padding, input shape and
 * filter shape) that are resolved to scalars at execution time, plus the number of
 * threads handed to {@link LibMatrixDNN.ConvolutionParameters}.
 *
 * <p>NOTE(review): {@link #parseInstruction(String)} also accepts the opcodes
 * {@code pooling_pre_reshape}, {@code pooling_post_reshape} and
 * {@code pooling_backward_reshape}, but {@link #processInstruction(ExecutionContext)}
 * rejects them with "Unsupported op code" - confirm whether that is intended.
 */
public class ConvolutionCPInstruction extends UnaryCPInstruction {
    /** Second matrix input; only set by the two-input constructor, {@code null} otherwise. */
    private CPOperand _in2;
    /** Four operands describing the input shape (indices 0..3 are read in processInstruction). */
    private ArrayList<CPOperand> _input_shape;
    /** Four operands describing the filter shape (indices 0..3). */
    private ArrayList<CPOperand> _filter_shape;
    /** Two stride operands (indices 0 and 1). */
    private ArrayList<CPOperand> _stride;
    /** Two padding operands (indices 0 and 1). */
    private ArrayList<CPOperand> _padding;
    /** Reset to {@code false} on every output allocation and forwarded to the convolution parameters. */
    private boolean _reuseNonZeroedOutput;
    /** Thread count forwarded to {@link LibMatrixDNN.ConvolutionParameters}. */
    private int _numThreads;

    /**
     * Creates a single-input convolution instruction.
     *
     * @param in           matrix input operand
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands
     * @param padding      padding operands
     * @param input_shape  input shape operands
     * @param filter_shape filter shape operands
     * @param numThreads   number of threads passed on to the convolution parameters
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
        // Same initialization as the two-input variant, just without a second input.
        this(in, null, out, opcode, istr, stride, padding, input_shape, filter_shape, numThreads);
    }

    /**
     * Creates a two-input convolution instruction.
     *
     * @param in           first matrix input operand
     * @param in2          second matrix input operand (may be {@code null} for single-input opcodes)
     * @param out          output operand
     * @param opcode       instruction opcode
     * @param istr         full instruction string
     * @param stride       stride operands
     * @param padding      padding operands
     * @param input_shape  input shape operands
     * @param filter_shape filter shape operands
     * @param numThreads   number of threads passed on to the convolution parameters
     */
    public ConvolutionCPInstruction(CPOperand in, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, int numThreads) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, istr);
        this._in2 = in2;
        this._cptype = CPInstruction.CPINSTRUCTION_TYPE.Convolution;
        this._stride = stride;
        this._padding = padding;
        this._input_shape = input_shape;
        this._filter_shape = filter_shape;
        this._numThreads = numThreads;
    }

    /**
     * Parses an instruction string into a {@code ConvolutionCPInstruction}.
     *
     * <p>Single-input opcodes expect 15 fields: input, 2 stride, 2 padding, 4 input-shape,
     * 4 filter-shape, output, thread count. The two-input opcodes expect 16 fields, with
     * the second input inserted right after the first.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unknown or the field count check fails
     */
    public static ConvolutionCPInstruction parseInstruction(String str) throws DMLRuntimeException {
        CPOperand in = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];

        if (opcode.equalsIgnoreCase("reshape_col") || opcode.equalsIgnoreCase("rotate180") || opcode.equalsIgnoreCase("im2col") || opcode.equalsIgnoreCase("col2im") || opcode.equalsIgnoreCase("pooling_pre_reshape") || opcode.equalsIgnoreCase("pooling_post_reshape") || opcode.equalsIgnoreCase("maxpooling")) {
            InstructionUtils.checkNumFields(parts, 15);
            in.split(parts[1]);
            out.split(parts[14]);
            ArrayList<CPOperand> stride = parseOperands(parts, 2, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 4, 2);
            ArrayList<CPOperand> inputShape = parseOperands(parts, 6, 4);
            ArrayList<CPOperand> filterShape = parseOperands(parts, 10, 4);
            int numThreads = Integer.parseInt(parts[15]);
            return new ConvolutionCPInstruction(in, out, opcode, str, stride, padding, inputShape, filterShape, numThreads);
        }

        if (opcode.equalsIgnoreCase("pooling_backward_reshape") || opcode.equalsIgnoreCase("maxpooling_backward")) {
            InstructionUtils.checkNumFields(parts, 16);
            in.split(parts[1]);
            CPOperand in2 = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
            in2.split(parts[2]);
            out.split(parts[15]);
            ArrayList<CPOperand> stride = parseOperands(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperands(parts, 5, 2);
            ArrayList<CPOperand> inputShape = parseOperands(parts, 7, 4);
            ArrayList<CPOperand> filterShape = parseOperands(parts, 11, 4);
            int numThreads = Integer.parseInt(parts[16]);
            return new ConvolutionCPInstruction(in, in2, out, opcode, str, stride, padding, inputShape, filterShape, numThreads);
        }

        throw new DMLRuntimeException("Unknown opcode while parsing a ConvolutionCPInstruction: " + str);
    }

    /**
     * Builds a list of operands from {@code count} consecutive instruction parts.
     *
     * @param parts instruction parts
     * @param from  index of the first part to convert
     * @param count number of consecutive parts to convert
     * @return the operands, in the order of the parts
     */
    private static ArrayList<CPOperand> parseOperands(String[] parts, int from, int count) throws DMLRuntimeException {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int idx = from; idx < from + count; idx++) {
            operands.add(new CPOperand(parts[idx]));
        }
        return operands;
    }

    /**
     * Resolves the operand at position {@code i} of {@code arrayList} to an int scalar.
     * The scalar's long value is narrowed to {@code int}.
     */
    private int getScalarInput(ExecutionContext executionContext, ArrayList<CPOperand> arrayList, int i) throws DMLRuntimeException {
        CPOperand operand = arrayList.get(i);
        return (int) executionContext.getScalarInput(operand.getName(), operand.getValueType(), operand.isLiteral()).getLongValue();
    }

    /**
     * Executes the instruction: resolves all scalar parameters, allocates a dense output
     * block of the opcode-specific size, delegates to the matching {@link LibMatrixDNN}
     * routine, releases the matrix input(s) and binds the output variable.
     *
     * @param executionContext the execution context providing inputs and receiving the output
     * @throws DMLRuntimeException if a shape check fails or the opcode is not supported
     */
    @Override // org.apache.sysml.runtime.instructions.cp.CPInstruction, org.apache.sysml.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) throws DMLRuntimeException {
        MatrixBlock input = executionContext.getMatrixInput(this.input1.getName());

        // Local names follow the presumed ConvolutionParameters argument order
        // (N, C, H, W, K, R, S, stride_h, stride_w, pad_h, pad_w) - TODO confirm against LibMatrixDNN.
        int padH = getScalarInput(executionContext, this._padding, 0);
        int padW = getScalarInput(executionContext, this._padding, 1);
        int strideH = getScalarInput(executionContext, this._stride, 0);
        int strideW = getScalarInput(executionContext, this._stride, 1);
        int n = getScalarInput(executionContext, this._input_shape, 0);
        int c = getScalarInput(executionContext, this._input_shape, 1);
        int h = getScalarInput(executionContext, this._input_shape, 2);
        int w = getScalarInput(executionContext, this._input_shape, 3);
        int k = getScalarInput(executionContext, this._filter_shape, 0);
        // _filter_shape[1] is only read inside checkHeightWidth, where it is compared to params.C.
        int r = getScalarInput(executionContext, this._filter_shape, 2);
        int s = getScalarInput(executionContext, this._filter_shape, 3);
        int p = (int) ConvolutionUtils.getP(h, r, strideH, padH);
        int q = (int) ConvolutionUtils.getQ(w, s, strideW, padW);
        LibMatrixDNN.ConvolutionParameters params = new LibMatrixDNN.ConvolutionParameters(n, c, h, w, k, r, s, strideH, strideW, padH, padW, this._numThreads);

        MatrixBlock output;
        if (this.instOpcode.equalsIgnoreCase("im2col")) {
            checkHeightWidth(executionContext, params);
            checkInputDimensionForIm2col(input, params);
            output = getDenseOutputBlock(executionContext, c * r * s, n * p * q, true);
            params.setReuseNonZeroedOutput(this._reuseNonZeroedOutput);
            LibMatrixDNN.im2col(input, output, params);
        } else if (this.instOpcode.equalsIgnoreCase("reshape_col")) {
            checkHeightWidth(executionContext, params);
            output = getDenseOutputBlock(executionContext, n, k * p * q, true);
            params.setReuseNonZeroedOutput(this._reuseNonZeroedOutput);
            LibMatrixDNN.reshape_col(input, output, params);
        } else if (this.instOpcode.equalsIgnoreCase("rotate180")) {
            checkHeightWidth(executionContext, params);
            output = getDenseOutputBlock(executionContext, n * p * q, k, true);
            params.setReuseNonZeroedOutput(this._reuseNonZeroedOutput);
            LibMatrixDNN.rotate180(input, output, params);
        } else if (this.instOpcode.equalsIgnoreCase("col2im")) {
            checkHeightWidth(executionContext, params);
            checkInputDimensionForCol2im(input, params);
            output = getDenseOutputBlock(executionContext, n, c * h * w, false);
            params.setReuseNonZeroedOutput(this._reuseNonZeroedOutput);
            LibMatrixDNN.col2im(input, output, params);
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling")) {
            output = getDenseOutputBlock(executionContext, n, c * p * q, true);
            params.setReuseNonZeroedOutput(this._reuseNonZeroedOutput);
            LibMatrixDNN.maxpooling(input, output, params);
        } else if (this.instOpcode.equalsIgnoreCase("maxpooling_backward")) {
            MatrixBlock input2 = executionContext.getMatrixInput(this._in2.getName());
            output = getDenseOutputBlock(executionContext, n, c * h * w, false);
            params.setReuseNonZeroedOutput(this._reuseNonZeroedOutput);
            LibMatrixDNN.maxpooling_backward(input, input2, output, params);
            executionContext.releaseMatrixInput(this._in2.getName());
        } else {
            throw new DMLRuntimeException("Unsupported op code " + this.instOpcode);
        }

        executionContext.releaseMatrixInput(this.input1.getName());
        executionContext.setMatrixOutput(getOutputVariableName(), output);
    }

    /**
     * Allocates a dense output block of {@code i x i2} cells and marks its non-zero
     * count as unknown ({@code -1}). When {@link DMLScript#STATISTICS} is enabled, the
     * elapsed allocation time is reported to {@link Statistics}.
     *
     * @param executionContext currently unused
     * @param i                number of rows
     * @param i2               number of columns
     * @param z                requested reuse of a non-zeroed output; currently has no
     *                         effect, {@code _reuseNonZeroedOutput} is always reset to {@code false}
     * @return the allocated block
     */
    private MatrixBlock getDenseOutputBlock(ExecutionContext executionContext, int i, int i2, boolean z) throws DMLRuntimeException {
        long start = DMLScript.STATISTICS ? System.nanoTime() : -1;

        // Widen before multiplying: i * i2 evaluated in int overflows for large outputs
        // and would hand a wrong (possibly negative) non-zero estimate to MatrixBlock.
        MatrixBlock block = new MatrixBlock(i, i2, (long) i * i2);
        this._reuseNonZeroedOutput = false;
        block.allocateDenseBlock();
        block.setNonZeros(-1L);

        if (DMLScript.STATISTICS) {
            Statistics.incrementAllocationTime(System.nanoTime() - start, false);
        }
        return block;
    }

    /**
     * Validates that the filter channel count matches the input channel count, that
     * stride/padding divide the padded extents evenly, and that the output patch
     * dimensions ({@code P}, {@code Q}) are positive.
     *
     * @throws DMLRuntimeException if any of the checks fails
     */
    private void checkHeightWidth(ExecutionContext executionContext, LibMatrixDNN.ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        int numChannelsInFilter = getScalarInput(executionContext, this._filter_shape, 1);
        if (numChannelsInFilter != convolutionParameters.C) {
            throw new DMLRuntimeException("The number of channels of input and filter should match");
        }
        if (((convolutionParameters.W + (2 * convolutionParameters.pad_w)) - convolutionParameters.S) % convolutionParameters.stride_w != 0) {
            throw new DMLRuntimeException("The width does not work (Hint: (W + 2 * pad_w - S) % stride_w should be 0 [ ==> (" + convolutionParameters.W + "+ 2*" + convolutionParameters.pad_w + "-" + convolutionParameters.S + ") % " + convolutionParameters.stride_w + "!= 0] ");
        }
        if (((convolutionParameters.H + (2 * convolutionParameters.pad_h)) - convolutionParameters.R) % convolutionParameters.stride_h != 0) {
            throw new DMLRuntimeException("The height does not work (Hint: (H + 2 * pad_h - R) % stride_h should be 0 [ ==> (" + convolutionParameters.H + "+ 2*" + convolutionParameters.pad_h + "-" + convolutionParameters.R + ") % " + convolutionParameters.stride_h + "!= 0] ");
        }
        // The output patch height is P (paired with Q below); the input height H was checked here before.
        if (convolutionParameters.P <= 0) {
            throw new DMLRuntimeException("Height of output patch should be greater than zero");
        }
        if (convolutionParameters.Q <= 0) {
            throw new DMLRuntimeException("Width of output patch should be greater than zero");
        }
    }

    /**
     * Verifies that the im2col input has {@code N} rows and {@code C*H*W} columns.
     *
     * @throws DMLRuntimeException if the input dimensions do not match
     */
    private void checkInputDimensionForIm2col(MatrixBlock matrixBlock, LibMatrixDNN.ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        boolean rowsMatch = convolutionParameters.N == matrixBlock.getNumRows();
        if (!rowsMatch || convolutionParameters.C * convolutionParameters.H * convolutionParameters.W != matrixBlock.getNumColumns()) {
            throw new DMLRuntimeException("Incorrect input shape in conv2d");
        }
    }

    /**
     * Verifies that the col2im input has {@code C*R*S} rows and {@code N*P*Q} columns.
     *
     * @throws DMLRuntimeException if the input dimensions do not match
     */
    private void checkInputDimensionForCol2im(MatrixBlock matrixBlock, LibMatrixDNN.ConvolutionParameters convolutionParameters) throws DMLRuntimeException {
        boolean rowsMatch = convolutionParameters.C * convolutionParameters.R * convolutionParameters.S == matrixBlock.getNumRows();
        if (!rowsMatch || convolutionParameters.N * convolutionParameters.P * convolutionParameters.Q != matrixBlock.getNumColumns()) {
            throw new DMLRuntimeException("Incorrect input shape in conv2d_backward_data");
        }
    }
}
