package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode;
import org.apache.sysml.runtime.controlprogram.parfor.opt.PerfTestTool;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimator.class */
/**
 * Base class for parfor optimizer cost models.
 *
 * <p>Subclasses supply estimates for leaf nodes of an {@link OptTree}. This class aggregates those
 * leaf estimates bottom-up over inner (program-block) nodes and derives per-node value bounds by
 * walking the tree from the root.
 */
public abstract class CostEstimator {
    protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());

    // NOTE(review): not referenced in this class; presumably used by subclasses or callers.
    public static final double DEFAULT_EST_PARALLELISM = 1.0d;
    /** Iteration count assumed for WHILE loops and for FOR/PARFOR loops without a known count. */
    public static final long FACTOR_NUM_ITERATIONS = 10;
    /** Estimate returned by {@link #getDefaultEstimate} for EXEC_TIME. */
    public static final double DEFAULT_TIME_ESTIMATE = 5.0d;
    /** Estimate returned by {@link #getDefaultEstimate} for MEMORY_USAGE. */
    public static final double DEFAULT_MEM_ESTIMATE_CP = 1024.0d;
    // NOTE(review): not referenced in this class; the name suggests an MR memory default - confirm.
    public static final double DEFAULT_MEM_ESTIMATE_MR = 1.048576E7d;

    /** Sentinel returned when no estimate or bound can be derived. */
    private static final double UNKNOWN = -1.0d;

    /**
     * Estimates {@code measure} for a leaf node using the leaf's default execution type.
     *
     * @throws DMLRuntimeException if the estimate cannot be computed
     */
    public abstract double getLeafNodeEstimate(PerfTestTool.TestMeasure testMeasure, OptNode optNode) throws DMLRuntimeException;

    /**
     * Estimates {@code measure} for a leaf node as if executed with the given execution type.
     *
     * @throws DMLRuntimeException if the estimate cannot be computed
     */
    public abstract double getLeafNodeEstimate(PerfTestTool.TestMeasure testMeasure, OptNode optNode, LopProperties.ExecType execType) throws DMLRuntimeException;

    /**
     * Estimates {@code measure} for the subtree rooted at {@code node}, letting each leaf use its
     * default execution type. Equivalent to {@code getEstimate(measure, node, null)}.
     */
    public double getEstimate(PerfTestTool.TestMeasure measure, OptNode node) throws DMLRuntimeException {
        return getEstimate(measure, node, null);
    }

    /**
     * Estimates {@code measure} for the subtree rooted at {@code node}.
     *
     * <p>Leaves are delegated to {@link #getLeafNodeEstimate}; inner nodes aggregate their
     * children's estimates according to the node type and the measure.
     *
     * @param measure the measure to estimate
     * @param node root of the subtree to estimate
     * @param et execution type forced onto leaf estimates, or {@code null} for each leaf's default
     * @return the estimate, or -1 if the measure or node type is not handled
     * @throws DMLRuntimeException propagated from leaf estimation
     */
    public double getEstimate(PerfTestTool.TestMeasure measure, OptNode node, LopProperties.ExecType et)
            throws DMLRuntimeException {
        if (node.isLeaf()) {
            return (et != null) ? getLeafNodeEstimate(measure, node, et) : getLeafNodeEstimate(measure, node);
        }
        // Each measure is aggregated independently; there must be no fall-through between cases,
        // otherwise the time estimate would be overwritten by the memory aggregation.
        switch (measure) {
            case EXEC_TIME:
                return getInnerTimeEstimate(measure, node, et);
            case MEMORY_USAGE:
                return getInnerMemoryEstimate(measure, node, et);
            default:
                return UNKNOWN;
        }
    }

    /**
     * Aggregates child estimates of an inner node for EXEC_TIME.
     *
     * <p>GENERIC/FUNCCALL sum their children. IF averages exactly two children, otherwise it takes
     * the max. WHILE scales the sum by {@link #FACTOR_NUM_ITERATIONS}. FOR scales it by the known
     * (or default) iteration count. PARFOR additionally divides by {@code node.getK()}.
     */
    private double getInnerTimeEstimate(PerfTestTool.TestMeasure measure, OptNode node, LopProperties.ExecType et)
            throws DMLRuntimeException {
        switch (node.getNodeType()) {
            case GENERIC:
            case FUNCCALL:
                return getSumEstimate(measure, node.getChilds(), et);
            case IF:
                return (node.getChilds().size() == 2)
                        ? getWeightedEstimate(measure, node.getChilds(), et)
                        : getMaxEstimate(measure, node.getChilds(), et);
            case WHILE:
                return FACTOR_NUM_ITERATIONS * getSumEstimate(measure, node.getChilds(), et);
            case FOR:
                return getNumIterations(node) * getSumEstimate(measure, node.getChilds(), et);
            case PARFOR:
                return getNumIterations(node) * getSumEstimate(measure, node.getChilds(), et) / node.getK();
            default:
                return UNKNOWN;
        }
    }

    /**
     * Aggregates child estimates of an inner node for MEMORY_USAGE.
     *
     * <p>All sequential node types take the max over their children. A PARFOR node takes the max
     * when its exec type is MR, the max times {@code node.getK()} when it is CP, and is unknown
     * (-1) for any other exec type.
     */
    private double getInnerMemoryEstimate(PerfTestTool.TestMeasure measure, OptNode node, LopProperties.ExecType et)
            throws DMLRuntimeException {
        switch (node.getNodeType()) {
            case GENERIC:
            case FUNCCALL:
            case IF:
            case WHILE:
            case FOR:
                return getMaxEstimate(measure, node.getChilds(), et);
            case PARFOR:
                if (node.getExecType() == OptNode.ExecType.MR) {
                    return getMaxEstimate(measure, node.getChilds(), et);
                }
                if (node.getExecType() == OptNode.ExecType.CP) {
                    return getMaxEstimate(measure, node.getChilds(), et) * node.getK();
                }
                return UNKNOWN;
            default:
                return UNKNOWN;
        }
    }

    /**
     * Returns the node's NUM_ITERATIONS parameter parsed as a long, or
     * {@link #FACTOR_NUM_ITERATIONS} when the parameter is absent.
     */
    private static double getNumIterations(OptNode node) {
        String iterations = node.getParam(OptNode.ParamType.NUM_ITERATIONS);
        return (iterations != null) ? Long.parseLong(iterations) : FACTOR_NUM_ITERATIONS;
    }

    /**
     * Returns the floored value bound for {@code node}, starting from the tree's {@code getCK()}
     * value at the root and dividing by {@code getK()} at every PARFOR on the path.
     *
     * @return the bound, or -1 if {@code node} is not reachable from the root
     */
    public double computeLocalParBound(OptTree optTree, OptNode optNode) {
        return Math.floor(rComputeLocalValueBound(optTree.getRoot(), optNode, optTree.getCK()));
    }

    /**
     * Returns the value bound for {@code node}, starting from the tree's {@code getCM()} value at
     * the root and dividing by {@code getK()} at every PARFOR on the path.
     *
     * @return the bound, or -1 if {@code node} is not reachable from the root
     */
    public double computeLocalMemoryBound(OptTree optTree, OptNode optNode) {
        return rComputeLocalValueBound(optTree.getRoot(), optNode, optTree.getCM());
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public double getMinMemoryUsage(OptNode optNode) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Returns the fallback estimate for {@code measure}: {@link #DEFAULT_TIME_ESTIMATE} for
     * EXEC_TIME, {@link #DEFAULT_MEM_ESTIMATE_CP} for MEMORY_USAGE, and -1 otherwise.
     * Intended for use by subclasses.
     */
    public double getDefaultEstimate(PerfTestTool.TestMeasure measure) {
        switch (measure) {
            case EXEC_TIME:
                return DEFAULT_TIME_ESTIMATE;
            case MEMORY_USAGE:
                return DEFAULT_MEM_ESTIMATE_CP;
            default:
                return UNKNOWN;
        }
    }

    /**
     * Returns the maximum estimate over {@code nodes}, or 0 if the list is empty.
     *
     * @throws DMLRuntimeException propagated from {@link #getEstimate}
     */
    protected double getMaxEstimate(PerfTestTool.TestMeasure measure, ArrayList<OptNode> nodes,
            LopProperties.ExecType et) throws DMLRuntimeException {
        if (nodes.isEmpty()) {
            return 0; // no children contribute nothing, consistent with getSumEstimate
        }
        // Start at -infinity rather than Double.MIN_VALUE (the smallest *positive* double), so that
        // zero or negative (-1 = unknown) child estimates are reported faithfully.
        double max = Double.NEGATIVE_INFINITY;
        for (OptNode child : nodes) {
            double estimate = getEstimate(measure, child, et);
            if (estimate > max) {
                max = estimate;
            }
        }
        return max;
    }

    /**
     * Returns the sum of the estimates over {@code nodes} (0 for an empty list).
     *
     * @throws DMLRuntimeException propagated from {@link #getEstimate}
     */
    protected double getSumEstimate(PerfTestTool.TestMeasure measure, ArrayList<OptNode> nodes,
            LopProperties.ExecType et) throws DMLRuntimeException {
        double sum = 0.0d;
        for (OptNode child : nodes) {
            sum += getEstimate(measure, child, et);
        }
        return sum;
    }

    /**
     * Returns the arithmetic mean of the estimates over {@code nodes} (equal weights), or 0 for
     * an empty list instead of NaN.
     *
     * @throws DMLRuntimeException propagated from {@link #getEstimate}
     */
    protected double getWeightedEstimate(PerfTestTool.TestMeasure measure, ArrayList<OptNode> nodes,
            LopProperties.ExecType et) throws DMLRuntimeException {
        if (nodes.isEmpty()) {
            return 0;
        }
        return getSumEstimate(measure, nodes, et) / nodes.size();
    }

    /**
     * Depth-first search from {@code current} for {@code target}, propagating {@code bound}
     * downwards and dividing it by {@code getK()} whenever a PARFOR node is entered.
     *
     * @param current node currently visited
     * @param target node whose bound is requested (compared by identity)
     * @param bound bound available at {@code current}
     * @return the bound at {@code target}, or -1 if {@code target} is not in this subtree
     */
    protected double rComputeLocalValueBound(OptNode current, OptNode target, double bound) {
        if (current == target) {
            return bound;
        }
        if (current.isLeaf()) {
            return UNKNOWN;
        }
        double childBound;
        switch (current.getNodeType()) {
            case GENERIC:
            case FUNCCALL:
            case IF:
            case WHILE:
            case FOR:
                childBound = bound;
                break;
            case PARFOR:
                // The parfor's k workers share the bound available at this level.
                childBound = bound / current.getK();
                break;
            default:
                return UNKNOWN;
        }
        for (OptNode child : current.getChilds()) {
            double result = rComputeLocalValueBound(child, target, childBound);
            if (result > 0) { // a positive value means the target was found in this child
                return result;
            }
        }
        return UNKNOWN;
    }
}
