package org.apache.sysml.hops;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.ConvolutionTransform;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;

/* loaded from: input_file:org/apache/sysml/hops/ConvolutionOp.class */
public class ConvolutionOp extends Hop implements Hop.MultiThreadedHop {
    /** The convolution-family operation this hop represents. */
    private Hop.ConvOp op;

    /** Requested degree of parallelism; -1 until {@link #setMaxNumThreads(int)} is called. */
    private int _maxNumThreads = -1;

    /** Used only by {@link #clone()}, which populates the fields afterwards. */
    private ConvolutionOp() {
    }

    /**
     * Creates a convolution hop over a single input hop and registers this hop as its parent.
     *
     * NOTE(review): refreshSizeInformation() reads scalar parameters from inputs 1..12 (or
     * 2..13), so this constructor fails unless the input list already holds those entries.
     *
     * @param str       hop name
     * @param dataType  output data type
     * @param valueType output value type
     * @param convOp    the convolution operation
     * @param hop       the (only) input hop
     */
    public ConvolutionOp(String str, Expression.DataType dataType, Expression.ValueType valueType, Hop.ConvOp convOp, Hop hop) {
        super(str, dataType, valueType);
        this.op = convOp;
        getInput().add(0, hop);
        hop.getParent().add(this);
        refreshSizeInformation();
    }

    /**
     * Creates a convolution hop over the given inputs (in order) and registers this hop as
     * the parent of each of them.
     *
     * @param str       hop name
     * @param dataType  output data type
     * @param valueType output value type
     * @param convOp    the convolution operation
     * @param arrayList input hops: data operand(s) followed by the scalar parameters
     */
    public ConvolutionOp(String str, Expression.DataType dataType, Expression.ValueType valueType, Hop.ConvOp convOp, ArrayList<Hop> arrayList) {
        super(str, dataType, valueType);
        this.op = convOp;
        for (int i = 0; i < arrayList.size(); i++) {
            Hop in = arrayList.get(i);
            getInput().add(i, in);
            in.getParent().add(this);
        }
        refreshSizeInformation();
    }

    /** @return the convolution operation of this hop */
    public Hop.ConvOp getOp() {
        return this.op;
    }

    @Override
    public String getOpString() {
        return "" + HopsConv2Lops.get(this.op);
    }

    /**
     * Builds (once) and returns the lop for this hop. Only CP execution is implemented.
     *
     * @throws HopsException if the chosen exec type is not CP, or the op is unsupported
     */
    @Override
    public Lop constructLops() throws HopsException, LopsException {
        if (getLops() != null) {
            return getLops();
        }
        LopProperties.ExecType et = optFindExecType();
        ArrayList<Hop> inputs = getInput();
        switch (this.op) {
            case IM2COL:
            case RESHAPE_COL:
            case ROTATE180:
            case COL2IM:
            case MAX_POOLING:
            case MAX_POOLING_BACKWARD:
            case DIRECT_CONV2D:
            case DIRECT_CONV2D_BACKWARD_DATA:
            case DIRECT_CONV2D_BACKWARD_FILTER:
                if (et != LopProperties.ExecType.CP) {
                    throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
                }
                setLops(constructConvolutionLops(et, inputs));
                constructAndSetLopsDataFlowProperties();
                return getLops();
            default:
                throw new HopsException("Unsupported lops construction for operation type '" + this.op + "'.");
        }
    }

    /** @param convOp the new convolution operation for this hop */
    public void setOp(Hop.ConvOp convOp) {
        this.op = convOp;
    }

    /**
     * Builds a {@link ConvolutionTransform} lop: input 0 is the primary operand, every remaining
     * input is attached as an additional lop input in order.
     *
     * @param execType execution type for the lop
     * @param arrayList the hop inputs; must have exactly {@link #getExpectedNumInputs()} entries
     * @return the constructed lop
     * @throws HopsException if the number of inputs is wrong for this op
     */
    public Lop constructConvolutionLops(LopProperties.ExecType execType, ArrayList<Hop> arrayList) throws HopsException, LopsException {
        final int expected = getExpectedNumInputs();
        if (arrayList.size() != expected) {
            throw new HopsException("Incorrect number of inputs for " + this.op.name());
        }
        ConvolutionTransform transform = new ConvolutionTransform(arrayList.get(0).constructLops(), HopsConv2Lops.get(this.op), getDataType(), getValueType(), execType, OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads));
        setOutputDimensions(transform);
        setLineNumbers(transform);
        for (int idx = 1; idx < expected; idx++) {
            Lop inLop = arrayList.get(idx).constructLops();
            transform.addInput(inLop);
            inLop.addOutput(transform);
        }
        transform.setLevel();
        return transform;
    }

    /**
     * Output memory estimate: RESHAPE_COL and ROTATE180 use the sparsity derived from
     * (dim1, dim2, nnz); every other op is estimated as dense (sparsity 1.0).
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = 1.0d;
        if (this.op == Hop.ConvOp.RESHAPE_COL || this.op == Hop.ConvOp.ROTATE180) {
            sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        }
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * No intermediate memory is accounted for. (The decompiled form returned
     * DataExpression.DEFAULT_DELIM_FILL_VALUE, an unrelated constant presumably substituted by
     * the decompiler for the literal 0.0.)
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0;
    }

    /**
     * Infers [rows, cols, nnz] for RESHAPE_COL and ROTATE180 from the literal scalar parameters
     * and the memoized input statistics. Returns null (unknown) for all other ops.
     * When no input statistics are available, nnz is reported as -1 (unknown) instead of failing.
     */
    @Override
    protected long[] inferOutputCharacteristics(MemoTable memoTable) {
        MatrixCharacteristics inputStats = memoTable.getAllInputStats(getInput().get(0));
        long inputNnz = (inputStats != null) ? inputStats.getNonZeros() : -1L;
        LibMatrixDNN.ConvolutionParameters p;
        try {
            p = parseInput();
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
        switch (this.op) {
            case RESHAPE_COL:
                return new long[]{p.N, getExtractedVal(p.K, p.P, p.Q), inputNnz};
            case ROTATE180:
                return new long[]{getExtractedVal(p.N, p.P, p.Q), p.K, inputNnz};
            default:
                return null;
        }
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Selects the execution type and marks the hop for dynamic recompilation when its dims are
     * unknown and a distributed type would have been chosen. The final assignment then forces
     * CP unconditionally, because {@link #constructLops()} only implements CP.
     *
     * NOTE(review): the decompiler reports this method was originally protected; it is kept
     * public here so existing callers keep compiling.
     */
    @Override
    public LopProperties.ExecType optFindExecType() throws HopsException {
        checkAndSetForcedPlatform();
        LopProperties.ExecType remote = OptimizerUtils.isSparkExecutionMode() ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = findExecTypeByMemEstimate();
            } else if (getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVector()) {
                this._etype = LopProperties.ExecType.CP;
            } else {
                this._etype = remote;
            }
            checkAndSetInvalidCPDimsAndSize();
        }
        // Mark for recompile while dimensions are still unknown and a remote type was selected.
        if (ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && this._etype == remote) {
            setRequiresRecompile();
        }
        // Only CP is implemented for convolution ops (see constructLops), so force it.
        this._etype = LopProperties.ExecType.CP;
        return this._etype;
    }

    /**
     * Builds convolution parameters from the literal scalar inputs (non-literal inputs become -1).
     *
     * Input layout for ops with a single data operand (13 inputs): 0 = data,
     * 1..4 = stride/padding, 5..9 = N, C, H, W, K, 10 = unused here, 11..12 = R, S.
     * Ops with a second data operand (14 inputs) carry every scalar one slot further right.
     * NOTE(review): the positional meaning of each constructor argument is presumed from the
     * fields read elsewhere (N, C, H, W, K, R, S); confirm against LibMatrixDNN.ConvolutionParameters.
     */
    LibMatrixDNN.ConvolutionParameters parseInput() throws DMLRuntimeException {
        final int off = hasSecondDataInput() ? 1 : 0;
        ArrayList<Hop> in = getInput();
        return new LibMatrixDNN.ConvolutionParameters(
            extractValue(in.get(5 + off)),
            extractValue(in.get(6 + off)),
            extractValue(in.get(7 + off)),
            extractValue(in.get(8 + off)),
            extractValue(in.get(9 + off)),
            extractValue(in.get(11 + off)),
            extractValue(in.get(12 + off)),
            extractValue(in.get(1 + off)),
            extractValue(in.get(2 + off)),
            extractValue(in.get(3 + off)),
            extractValue(in.get(4 + off)),
            this._maxNumThreads);
    }

    /** True for ops whose second input is another matrix operand rather than a scalar parameter. */
    private boolean hasSecondDataInput() {
        return this.op == Hop.ConvOp.MAX_POOLING_BACKWARD
            || this.op == Hop.ConvOp.DIRECT_CONV2D
            || this.op == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
            || this.op == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA;
    }

    /** Number of inputs this op requires: 14 with a second data operand, otherwise 13. */
    private int getExpectedNumInputs() {
        return hasSecondDataInput() ? 14 : 13;
    }

    /** Product of two dimensions, or -1 if either is unknown (-1). */
    long getExtractedVal(long j, long j2) {
        if (j == -1 || j2 == -1) {
            return -1L;
        }
        return j * j2;
    }

    /** Product of three dimensions, or -1 if any is unknown (-1). */
    long getExtractedVal(long j, long j2, long j3) {
        if (j == -1 || j2 == -1 || j3 == -1) {
            return -1L;
        }
        return j * j2 * j3;
    }

    /**
     * Recomputes output dimensions (and nnz where it is known) from the literal scalar inputs.
     * RESHAPE_COL and ROTATE180 propagate the input's nnz; all other ops set nnz to -1.
     *
     * @throws RuntimeException for ops without a size rule, or wrapping a parameter-parsing failure
     */
    @Override
    public void refreshSizeInformation() {
        Hop input = getInput().get(0);
        LibMatrixDNN.ConvolutionParameters p;
        try {
            p = parseInput();
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
        switch (this.op) {
            case IM2COL:
                this._dim1 = getExtractedVal(p.C, p.R, p.S);
                this._dim2 = getExtractedVal(p.N, p.P, p.Q);
                this._nnz = -1L;
                return;
            case RESHAPE_COL:
                this._dim1 = p.N;
                this._dim2 = getExtractedVal(p.K, p.P, p.Q);
                this._nnz = input.getNnz();
                return;
            case ROTATE180:
                this._dim1 = getExtractedVal(p.N, p.P, p.Q);
                this._dim2 = p.K;
                this._nnz = input.getNnz();
                return;
            case MAX_POOLING:
                this._dim1 = p.N;
                this._dim2 = getExtractedVal(p.C, p.P, p.Q);
                this._nnz = -1L;
                return;
            case DIRECT_CONV2D:
                this._dim1 = p.N;
                this._dim2 = getExtractedVal(p.K, p.P, p.Q);
                this._nnz = -1L;
                return;
            case COL2IM:
            case MAX_POOLING_BACKWARD:
            case DIRECT_CONV2D_BACKWARD_DATA:
                // Output has the shape of the original image batch: N x (C*H*W).
                this._dim1 = p.N;
                this._dim2 = getExtractedVal(p.C, p.H, p.W);
                this._nnz = -1L;
                return;
            case DIRECT_CONV2D_BACKWARD_FILTER:
                this._dim1 = p.K;
                this._dim2 = getExtractedVal(p.C, p.R, p.S);
                this._nnz = -1L;
                return;
            default:
                throw new RuntimeException("The sizes are not refreshed for " + this.op.name());
        }
    }

    /** Returns the literal value of {@code hop} truncated to long, or -1 if it is not a literal. */
    private long extractValue(Hop hop) {
        if (hop instanceof LiteralOp) {
            return (long) HopRewriteUtils.getDoubleValueSafe((LiteralOp) hop);
        }
        return -1L;
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        ConvolutionOp copy = new ConvolutionOp();
        copy.clone(this, false);
        copy.op = this.op;
        copy._maxNumThreads = this._maxNumThreads;
        return copy;
    }

    /**
     * Two convolution hops are equal if they have the same op and thread limit, and
     * identical (reference-equal) inputs in the same order.
     */
    @Override
    public boolean compare(Hop hop) {
        if (!(hop instanceof ConvolutionOp)) {
            return false;
        }
        ConvolutionOp that = (ConvolutionOp) hop;
        if (this.op != that.op
            || getInput().size() != hop.getInput().size()
            || this._maxNumThreads != that._maxNumThreads) {
            return false;
        }
        for (int i = 0; i < getInput().size(); i++) {
            if (getInput().get(i) != that.getInput().get(i)) {
                return false;
            }
        }
        return true;
    }

    /** Debug-logs this hop and, recursively, its inputs (each hop printed once). */
    @Override
    public void printMe() throws HopsException {
        if (LOG.isDebugEnabled()) {
            if (getVisited() != Hop.VisitStatus.DONE) {
                super.printMe();
                LOG.debug("  Operation: " + this.op);
                for (Hop in : getInput()) {
                    in.printMe();
                }
            }
            setVisited(Hop.VisitStatus.DONE);
        }
    }

    @Override
    public void setMaxNumThreads(int i) {
        this._maxNumThreads = i;
    }

    @Override
    public int getMaxNumThreads() {
        return this._maxNumThreads;
    }
}
