package org.apache.sysml.hops;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.DataGen;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysml/hops/DataGenOp.class */
/**
 * Hop for data-generating operations, selected by a {@link Hop.DataGenMethod}. This class has
 * explicit handling for {@code RAND}, {@code SINIT}, {@code SEQ} and {@code SAMPLE}.
 *
 * <p>The output is always declared as a {@code MATRIX} of {@code DOUBLE}. All operation parameters
 * are ordinary hop inputs. {@link #getParamIndexMap()} maps each parameter name (for example
 * {@code "rows"}, {@code "cols"}, {@code "min"}, {@code "max"}) to its position in
 * {@link #getInput()}.
 */
public class DataGenOp extends Hop implements Hop.MultiThreadedHop {

    /** Seed literal value that {@link #compare(Hop)} treats as "no seed specified". */
    public static final long UNSPECIFIED_SEED = -1;

    /** The data-generation method of this hop. */
    private Hop.DataGenMethod _op;

    /** Thread limit forwarded to {@code OptimizerUtils.getConstrainedNumThreads}; -1 until set. */
    private int _maxNumThreads = -1;

    /** Parameter name -> index of the corresponding hop in {@link #getInput()}. */
    private HashMap<String, Integer> _paramIndexMap = new HashMap<>();

    /** Output identifier forwarded to the {@link DataGen} lop. */
    private DataIdentifier _id;

    /** Sparsity of a RAND op whose sparsity parameter is a literal; -1 otherwise. */
    private double _sparsity = -1.0;

    /** Scratch-space directory forwarded to the {@link DataGen} lop (also compared in {@link #compare}). */
    private String _baseDir;

    /** Resolved SEQ increment; {@link Double#MAX_VALUE} while unknown. */
    private double _incr = Double.MAX_VALUE;

    /** Creates an empty instance for {@link #clone()}; all fields use their initializers. */
    private DataGenOp() {
        // attributes are copied by clone()
    }

    /**
     * Creates a data-generation hop and wires every parameter hop as an input.
     *
     * @param mthd the data-generation method
     * @param id identifier of the generated output; {@code id.getName()} is passed to the
     *        {@link Hop} constructor
     * @param inputParameters parameter name -> parameter hop. Each hop is appended to this hop's
     *        inputs (and this hop to its parents) in the map's iteration order.
     */
    public DataGenOp(Hop.DataGenMethod mthd, DataIdentifier id, HashMap<String, Hop> inputParameters) {
        super(id.getName(), Expression.DataType.MATRIX, Expression.ValueType.DOUBLE);
        _id = id;
        _op = mthd;

        int index = 0;
        for (Map.Entry<String, Hop> e : inputParameters.entrySet()) {
            Hop input = e.getValue();
            getInput().add(input);
            input.getParent().add(this);
            _paramIndexMap.put(e.getKey(), index++);
        }

        // A literal sparsity is remembered so nnz and memory estimates can use it.
        Hop sparsity = inputParameters.get(DataExpression.RAND_SPARSITY);
        if (mthd == DataGenMethod.RAND && sparsity instanceof LiteralOp) {
            _sparsity = Double.parseDouble(((LiteralOp) sparsity).getName());
        }

        // Note: the path format (including the double separator) is kept as-is because it is
        // forwarded to the lop and participates in compare().
        _baseDir = ConfigurationManager.getScratchSpace() + "/" + Lop.PROCESS_PREFIX
                + DMLScript.getUUID() + "//" + ProgramConverter.CP_ROOT_THREAD_ID + "/";

        refreshSizeInformation();
    }

    /**
     * Checks that the number of inputs equals the number of named parameters.
     *
     * @throws HopsException if the counts differ
     */
    @Override
    public void checkArity() throws HopsException {
        int numInputs = _input.size();
        int numParams = _paramIndexMap.size();
        HopsException.check(numInputs == numParams, this,
                "has %d inputs but %d parameters", numInputs, numParams);
    }

    /** @return {@code "dg(<method>)"} with the method name lower-cased locale-independently */
    @Override
    public String getOpString() {
        // Locale.ROOT: default-locale lower-casing would mangle e.g. "SINIT" under a Turkish locale
        return "dg(" + _op.toString().toLowerCase(Locale.ROOT) + ")";
    }

    /** @return the data-generation method of this hop */
    public Hop.DataGenMethod getOp() {
        return _op;
    }

    @Override
    public void setMaxNumThreads(int k) {
        _maxNumThreads = k;
    }

    @Override
    public int getMaxNumThreads() {
        return _maxNumThreads;
    }

    /** @return always {@code false} */
    @Override
    public boolean isGPUEnabled() {
        return false;
    }

    /**
     * Builds (once) the {@link DataGen} lop for this hop.
     *
     * <p>Parameters {@code rows}/{@code cols} are replaced by literal lops when the corresponding
     * dimension is already known. All other parameters use the lops of their input hops.
     *
     * @return the (possibly cached) lop of this hop
     */
    @Override
    public Lop constructLops() throws HopsException, LopsException {
        // reuse lops already constructed for this hop
        if (getLops() != null) {
            return getLops();
        }

        LopProperties.ExecType et = optFindExecType();

        HashMap<String, Lop> inputLops = new HashMap<>();
        for (Map.Entry<String, Integer> e : _paramIndexMap.entrySet()) {
            String name = e.getKey();
            Lop lop;
            if (name.equals("rows") && rowsKnown()) {
                lop = new LiteralOp(_dim1).constructLops();
            } else if (name.equals("cols") && colsKnown()) {
                lop = new LiteralOp(_dim2).constructLops();
            } else {
                lop = getInput().get(e.getValue()).constructLops();
            }
            inputLops.put(name, lop);
        }

        DataGen rnd = new DataGen(_op, _id, inputLops, _baseDir, getDataType(), getValueType(), et);
        rnd.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));

        // RAND on Spark reports a non-zero nnz as unknown (-1)
        long nnz = (_op == DataGenMethod.RAND && et == LopProperties.ExecType.SPARK && getNnz() != 0)
                ? -1 : getNnz();
        rnd.getOutputParameters().setDimensions(getDim1(), getDim2(),
                getRowsInBlock() > 0 ? getRowsInBlock() : ConfigurationManager.getBlocksize(),
                getColsInBlock() > 0 ? getColsInBlock() : ConfigurationManager.getBlocksize(),
                nnz, getUpdateType());

        setLineNumbers(rnd);
        setLops(rnd);
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /** @return always {@code true} */
    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /**
     * Estimates the output size.
     *
     * <ul>
     *   <li>An all-zero RAND output uses an empty-block estimate.</li>
     *   <li>Any other RAND op with a literal sparsity uses that sparsity.</li>
     *   <li>Everything else is estimated with sparsity 1.0.</li>
     * </ul>
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        if (_op == DataGenMethod.RAND && _sparsity != -1) {
            if (hasConstantValue(0.0)) {
                return OptimizerUtils.estimateSizeEmptyBlock(dim1, dim2);
            }
            return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, _sparsity);
        }
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
    }

    /**
     * Estimates intermediate memory of a RAND op with known dimensions. The estimate is 32 bytes
     * plus 8 bytes per output block, with blocks sized by
     * {@code ConfigurationManager.getBlocksize()}. All other cases return 0.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        if (_op == DataGenMethod.RAND && dimsKnown()) {
            // Divide in floating point so partial boundary blocks are counted by Math.ceil.
            // Integer division would truncate first and undercount the blocks.
            long numBlocks = (long) (Math.ceil((double) dim1 / ConfigurationManager.getBlocksize())
                    * Math.ceil((double) dim2 / ConfigurationManager.getBlocksize()));
            // NOTE(review): presumably one 8-byte long per block (e.g. a block seed) - confirm
            return 32 + numBlocks * 8.0;
        }
        return 0;
    }

    /**
     * Infers worst-case output characteristics from the parameter expressions.
     *
     * <ul>
     *   <li>RAND/SINIT (only if worst-case size-expression evaluation is enabled): rows and cols are
     *       evaluated; nnz is derived from a literal sparsity.</li>
     *   <li>SEQ: handles the patterns {@code seq(1, x, 1)} and {@code seq(x, 1, -1)}.</li>
     * </ul>
     *
     * @return {@code {rows, cols, nnz}}, or {@code null} if nothing can be inferred
     */
    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        if ((_op == DataGenMethod.RAND || _op == DataGenMethod.SINIT)
                && OptimizerUtils.ALLOW_WORSTCASE_SIZE_EXPRESSION_EVALUATION) {
            long rows = computeDimParameterInformation(getInput("rows"), memo);
            long cols = computeDimParameterInformation(getInput("cols"), memo);
            if (rows < 0 || cols < 0) {
                return null;
            }
            long nnz = (_sparsity >= 0) ? (long) (_sparsity * rows * cols) : -1;
            return new long[] {rows, cols, nnz};
        }

        if (_op != DataGenMethod.SEQ) {
            return null;
        }

        Hop from = getInput(Statement.SEQ_FROM);
        Hop to = getInput(Statement.SEQ_TO);
        Hop incr = getInput(Statement.SEQ_INCR);

        // seq(1, x, 1): length is given by 'to'
        if (isLiteralValue(from, 1) && isLiteralValue(incr, 1)) {
            long len = computeDimParameterInformation(to, memo);
            if (len > 0) {
                return new long[] {len, 1, -1};
            }
        }
        // seq(x, 1, -1): length is given by 'from'
        if (isLiteralValue(to, 1) && isLiteralValue(incr, -1)) {
            long len = computeDimParameterInformation(from, memo);
            if (len > 0) {
                return new long[] {len, 1, -1};
            }
        }
        return null;
    }

    /**
     * Chooses the execution type.
     *
     * <ol>
     *   <li>A forced platform wins.</li>
     *   <li>Otherwise the choice is by memory estimate (memory-based opt level), or CP for small or
     *       vector outputs, or else SPARK/MR depending on the execution mode.</li>
     *   <li>SINIT is always CP; a SAMPLE op is never left on MR.</li>
     * </ol>
     */
    // NOTE(review): a decompiler note indicated this was originally protected; kept public so
    // existing callers keep compiling.
    @Override
    public LopProperties.ExecType optFindExecType() throws HopsException {
        checkAndSetForcedPlatform();

        LopProperties.ExecType remote = OptimizerUtils.isSparkExecutionMode()
                ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;

        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else if (areDimsBelowThreshold() || isVector()) {
                _etype = LopProperties.ExecType.CP;
            } else {
                _etype = remote;
            }
            checkAndSetInvalidCPDimsAndSize();
        }

        setRequiresRecompileIfNecessary();

        if (_op == DataGenMethod.SINIT) {
            _etype = LopProperties.ExecType.CP;
        }
        if (_op == DataGenMethod.SAMPLE && _etype == LopProperties.ExecType.MR) {
            _etype = LopProperties.ExecType.CP;
        }
        return _etype;
    }

    /**
     * Recomputes dimensions and nnz from the current parameter inputs.
     *
     * <ul>
     *   <li>RAND/SINIT: refreshes rows/cols from the {@code rows}/{@code cols} inputs.</li>
     *   <li>SEQ: if from, to and incr are all known, sets a column vector of the sequence length
     *       and stores the increment.</li>
     *   <li>nnz: 0 for an all-zero RAND, sparsity * rows * cols when both are known, -1 otherwise.</li>
     * </ul>
     */
    @Override
    public void refreshSizeInformation() {
        if (_op == DataGenMethod.RAND || _op == DataGenMethod.SINIT) {
            Hop rowsInput = getInput("rows");
            Hop colsInput = getInput("cols");
            refreshRowsParameterInformation(rowsInput);
            refreshColsParameterInformation(colsInput);
        } else if (_op == DataGenMethod.SEQ) {
            Hop fromInput = getInput(Statement.SEQ_FROM);
            Hop toInput = getInput(Statement.SEQ_TO);
            Hop incrInput = getInput(Statement.SEQ_INCR);

            // computeBoundsInformation signals "unknown" with Double.MAX_VALUE
            double from = computeBoundsInformation(fromInput);
            boolean fromKnown = from != Double.MAX_VALUE;
            double to = computeBoundsInformation(toInput);
            boolean toKnown = to != Double.MAX_VALUE;
            double incr = computeBoundsInformation(incrInput);
            boolean incrKnown = incr != Double.MAX_VALUE;

            // An increment of 1 becomes -1 when from >= to, i.e. seq(5, 1) is treated as descending.
            if (fromKnown && toKnown && incr == 1) {
                incr = (from >= to) ? -1 : 1;
            }
            if (fromKnown && toKnown && incrKnown) {
                setDim1(UtilFunctions.getSeqLength(from, to, incr, false));
                setDim2(1);
                _incr = incr;
            }
        }

        if (_op == DataGenMethod.RAND && hasConstantValue(0.0)) {
            _nnz = 0;
        } else if (!dimsKnown() || _sparsity < 0) {
            _nnz = -1;
        } else {
            _nnz = (long) (_sparsity * _dim1 * _dim2);
        }
    }

    /** @return the live parameter-name -> input-index map (not a copy) */
    public HashMap<String, Integer> getParamIndexMap() {
        return _paramIndexMap;
    }

    /**
     * @param name parameter name
     * @return index of the named parameter in {@link #getInput()}
     * @throws NullPointerException if there is no parameter with that name
     */
    public int getParamIndex(String name) {
        return _paramIndexMap.get(name);
    }

    /**
     * @param name parameter name
     * @return the input hop bound to the named parameter
     * @throws NullPointerException if there is no parameter with that name
     */
    public Hop getInput(String name) {
        return getInput().get(getParamIndex(name));
    }

    /**
     * Replaces the input hop bound to the named parameter.
     *
     * <p>Only the input list slot is overwritten; parent links of the old and new hop are not
     * touched here.
     */
    public void setInput(String name, Hop hop) {
        getInput().set(getParamIndex(name), hop);
    }

    /**
     * Tests whether this is a RAND op that produces a single constant value: sparsity 1 and
     * {@code min == max}.
     *
     * <p>If min, max and sparsity are all literals, their values are compared. Otherwise min and
     * max must be the same hop instance and sparsity a literal 1.
     *
     * @return {@code true} if the output is a constant matrix
     */
    public boolean hasConstantValue() {
        if (_op != DataGenMethod.RAND) {
            return false;
        }
        Hop min = getInput("min");
        Hop max = getInput("max");
        Hop sparsity = getInput(DataExpression.RAND_SPARSITY);

        if (min instanceof LiteralOp && max instanceof LiteralOp && sparsity instanceof LiteralOp) {
            try {
                return HopRewriteUtils.getDoubleValue((LiteralOp) sparsity) == 1
                        && HopRewriteUtils.getDoubleValue((LiteralOp) min)
                            == HopRewriteUtils.getDoubleValue((LiteralOp) max);
            } catch (Exception ignored) {
                // best effort: a literal that cannot be read as double is treated as non-constant
                return false;
            }
        }
        return min == max && isLiteralValue(sparsity, 1);
    }

    /**
     * Tests whether this is a RAND op whose output is the given constant value.
     *
     * <p>This requires literal min and max equal to {@code val}. For a non-zero {@code val} it also
     * requires a missing or literal-1 sparsity. For 0 the sparsity is not checked, presumably
     * because an all-zero result does not depend on it.
     */
    public boolean hasConstantValue(double val) {
        if (_op != DataGenMethod.RAND) {
            return false;
        }
        Hop min = getInput("min");
        Hop max = getInput("max");
        boolean ret = isLiteralValue(min, val) && isLiteralValue(max, val);
        if (ret && val != 0) {
            Hop sparsity = getInput(DataExpression.RAND_SPARSITY);
            ret = sparsity == null || isLiteralValue(sparsity, 1);
        }
        return ret;
    }

    /** @return the {@code min} input hop (the constant when {@link #hasConstantValue()} holds) */
    public Hop getConstantValue() {
        return getInput("min");
    }

    /** Sets the SEQ increment. */
    public void setIncrementValue(double incr) {
        _incr = incr;
    }

    /** @return the SEQ increment, or {@link Double#MAX_VALUE} if unknown */
    public double getIncrementValue() {
        return _incr;
    }

    /** @return a seed derived from {@link System#nanoTime()} */
    public static long generateRandomSeed() {
        return System.nanoTime();
    }

    /**
     * Creates a copy sharing the same input hops.
     *
     * <p>The parameter map is copied shallowly; the resolved SEQ increment is copied as well.
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        DataGenOp ret = new DataGenOp();
        ret.clone(this, false);
        ret._op = _op;
        ret._id = _id;
        ret._sparsity = _sparsity;
        ret._baseDir = _baseDir;
        ret._paramIndexMap = new HashMap<>(_paramIndexMap);
        ret._maxNumThreads = _maxNumThreads;
        ret._incr = _incr;
        return ret;
    }

    /**
     * Tests whether {@code that} is an equivalent data-generation hop.
     *
     * <p>Both hops must have the same method, sparsity, base dir, thread limit and parameter
     * names, with each parameter bound to the same input hop instance. For RAND/SINIT, an
     * unspecified seed additionally requires min and max to be the same hop.
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof DataGenOp)) {
            return false;
        }
        DataGenOp other = (DataGenOp) that;
        if (_op != other._op
                || _sparsity != other._sparsity
                || !_baseDir.equals(other._baseDir)
                || _paramIndexMap == null || other._paramIndexMap == null
                || _paramIndexMap.size() != other._paramIndexMap.size()
                || _maxNumThreads != other._maxNumThreads) {
            return false;
        }

        // Same parameter names, each bound to the very same input hop. A missing name means
        // "not equal" rather than an NPE.
        for (Map.Entry<String, Integer> e : _paramIndexMap.entrySet()) {
            Integer otherPos = other._paramIndexMap.get(e.getKey());
            if (otherPos == null) {
                return false;
            }
            Hop otherInput = other.getInput().get(otherPos);
            if (otherInput == null || getInput().get(e.getValue()) != otherInput) {
                return false;
            }
        }

        // NOTE(review): with an unspecified seed each op presumably gets its own seed at runtime,
        // so two such ops are only interchangeable when min and max are the same hop.
        if (_op == DataGenMethod.RAND || _op == DataGenMethod.SINIT) {
            Hop seed = getInput(DataExpression.RAND_SEED);
            Hop min = getInput("min");
            Hop max = getInput("max");
            if (seed.getName().equals(String.valueOf(UNSPECIFIED_SEED)) && min != max) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether {@code hop} is a literal whose (safe) double value equals {@code val}. */
    private static boolean isLiteralValue(Hop hop, double val) {
        return hop instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) hop) == val;
    }
}
