package org.apache.sysml.hops.globalopt;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.cost.CostEstimationWrapper;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFGraph;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFLoopNode;
import org.apache.sysml.hops.globalopt.gdfgraph.GDFNode;
import org.apache.sysml.hops.globalopt.gdfresolve.GDFMismatchHeuristic;
import org.apache.sysml.hops.globalopt.gdfresolve.MismatchHeuristicFactory;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.matrix.mapred.GMRCtableBuffer;
import org.apache.sysml.yarn.DMLYarnClient;

/**
 * Global data-flow (GDF) optimizer based on bottom-up plan enumeration with memoization.
 *
 * <p>For every GDF node, candidate rewrite configurations (execution type, block size,
 * format) are enumerated, combined with the plans of the node's inputs, pruned if invalid,
 * costed with the runtime cost model, and reduced to the cheapest plan per set of
 * interesting properties. The cheapest plans of the graph roots are finally applied to the
 * runtime program, which is then recompiled and costed again.
 *
 * <p>Enumeration statistics and the mismatch-resolution heuristic live in static fields,
 * so optimizer instances are not independent of each other. NOTE(review): no
 * synchronization is used; concurrent optimizations would interfere.
 */
public class GDFEnumOptimizer extends GlobalOptimizer {
    /** Prune plans whose costs exceed the (slightly relaxed) costs of the initial program. */
    private static final boolean BRANCH_AND_BOUND_PRUNING = true;
    /** On cost ties, prefer plans that report {@code isPreferredPlan()}. */
    private static final boolean PREFERRED_PLAN_SELECTION = true;
    /** Always recompile and cost the whole program instead of only the plan's hop. */
    private static final boolean COST_FULL_PROGRAMS = false;
    /** Enumerate all candidate block sizes for CP plans (otherwise only the first one). */
    private static final boolean ENUM_CP_BLOCKSIZES = false;

    private static final Log LOG = LogFactory.getLog(GDFEnumOptimizer.class);
    private static final GDFMismatchHeuristic.MismatchHeuristicType DEFAULT_MISMATCH_HEURISTIC = GDFMismatchHeuristic.MismatchHeuristicType.FIRST;
    // Candidate block sizes; index 0 is the only size used for CP plans.
    // NOTE(review): the 2nd/3rd entries reference constants of unrelated classes (YARN client,
    // ctable buffer) -- presumably a decompiler substituting equal-valued constants for
    // literals. Confirm the intended values and replace them with literals so that changes to
    // those unrelated constants cannot silently alter the enumerated block sizes.
    private static final int[] BLOCK_SIZES = {1024, DMLYarnClient.MAX_MEM_OVERHEAD, GMRCtableBuffer.MAX_BUFFER_SIZE};
    /** Relative slack added to the initial program costs to form the pruning upper bound. */
    private static final double BRANCH_AND_BOUND_REL_THRES = Math.pow(10.0d, -5.0d);

    /** Heuristic used to resolve conflicting configurations of shared nodes. */
    private static GDFMismatchHeuristic _resolve = null;

    // Enumeration statistics, reported via the optimization summary.
    private static long _enumeratedPlans = 0;
    private static long _prunedInvalidPlans = 0;
    private static long _prunedSuboptimalPlans = 0;
    private static long _compiledPlans = 0;
    private static long _costedPlans = 0;
    private static long _planMismatches = 0;

    /** Memo table of already enumerated (and pruned) plan sets per GDF node. */
    private final MemoStructure _memo;

    /**
     * Creates an optimizer with an empty memo table and the default mismatch heuristic.
     *
     * @throws DMLRuntimeException if the mismatch heuristic cannot be created
     */
    public GDFEnumOptimizer() throws DMLRuntimeException {
        _memo = new MemoStructure();
        _resolve = MismatchHeuristicFactory.createMismatchHeuristic(DEFAULT_MISMATCH_HEURISTIC);
    }

    /**
     * Optimizes the runtime program of the given GDF graph and records statistics in
     * {@code summary}. The runtime program of the graph is modified in place and recompiled.
     *
     * @param gdfgraph the global data-flow graph, including its runtime program
     * @param summary receives initial/optimal costs, plan counts, and optimization time
     * @return the given graph
     */
    @Override // org.apache.sysml.hops.globalopt.GlobalOptimizer
    public GDFGraph optimize(GDFGraph gdfgraph, Summary summary) throws DMLRuntimeException, HopsException, LopsException {
        Timing time = new Timing(true);
        // The statistics are static; start from zero so the summary describes only this run.
        resetStatistics();

        Program prog = gdfgraph.getRuntimeProgram();
        ExecutionContext ec = ExecutionContextFactory.createContext(prog);
        ArrayList<GDFNode> roots = gdfgraph.getGraphRootNodes();

        // Step 1: baseline costing; the relaxed baseline serves as branch-and-bound upper bound.
        double initCosts = CostEstimationWrapper.getTimeEstimate(prog, ec);
        double upperBound = BRANCH_AND_BOUND_PRUNING
            ? initCosts * (1.0d + BRANCH_AND_BOUND_REL_THRES)
            : Double.MAX_VALUE;

        // Step 2: dynamic-programming enumeration; keep the cheapest plan of every root.
        ArrayList<Plan> rootPlans = new ArrayList<>();
        for (GDFNode root : roots) {
            rootPlans.add(enumOpt(root, _memo, upperBound).getPlanWithMinCosts());
        }
        long enumPlanMismatches = getPlanMismatches();

        // Step 3: apply the root plans; nodes shared by several roots are configured once
        // and conflicting configurations are counted as final mismatches.
        HashMap<Long, Plan> memo = new HashMap<>();
        resetPlanMismatches();
        for (Plan rootPlan : rootPlans) {
            rSetRuntimePlanConfig(rootPlan, memo);
        }
        long finalPlanMismatches = getPlanMismatches();

        // Step 4: generate and cost the final runtime program.
        Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new LocalVariableMap(), 0L, false);
        double optCosts = CostEstimationWrapper.getTimeEstimate(prog, ExecutionContextFactory.createContext(prog));

        // Report the real baseline costs, not the relaxed pruning bound.
        summary.setCostsInitial(initCosts);
        summary.setCostsOptimal(optCosts);
        summary.setNumEnumPlans(_enumeratedPlans);
        summary.setNumPrunedInvalidPlans(_prunedInvalidPlans);
        summary.setNumPrunedSuboptPlans(_prunedSuboptimalPlans);
        summary.setNumCompiledPlans(_compiledPlans);
        summary.setNumCostedPlans(_costedPlans);
        summary.setNumEnumPlanMismatch(enumPlanMismatches);
        summary.setNumFinalPlanMismatch(finalPlanMismatches);
        summary.setTimeOptim(time.stop());
        return gdfgraph;
    }

    /**
     * Enumerates the pruned plan set of a node: its own candidate plans crossed with the
     * plan sets of all its inputs, with invalid and suboptimal plans removed. Results are
     * memoized per node in {@code memo}.
     *
     * @param node the GDF node to enumerate
     * @param memo memo table consulted first and updated with the result
     * @param maxCosts plans costing more than this are pruned
     * @return the (memoized) pruned plan set of {@code node}
     * @throws DMLRuntimeException if costing a plan fails
     */
    public static PlanSet enumOpt(GDFNode node, MemoStructure memo, double maxCosts) throws DMLRuntimeException {
        if (memo.constainsEntry(node)) {
            return memo.getEntry(node);
        }
        PlanSet plans = enumNodePlans(node, memo, maxCosts);
        for (GDFNode input : node.getInputs()) {
            PlanSet inputPlans = enumOpt(input, memo, maxCosts);
            // Loop nodes carry the plans of all loop outputs; narrow them for this consumer.
            if (input instanceof GDFLoopNode) {
                inputPlans = inputPlans.selectChild(node);
            }
            plans = plans.crossProductChild(inputPlans);
            _enumeratedPlans += plans.size();
            pruneInvalidPlans(plans);
        }
        pruneSuboptimalPlans(plans, maxCosts);
        memo.putEntry(node, plans);
        return plans;
    }

    /**
     * Creates the node-local candidate plans (without input plans) for the given node.
     * Loop nodes additionally trigger enumeration of their loop inputs, predicate, and
     * outputs, and include the plans of their outputs.
     */
    private static PlanSet enumNodePlans(GDFNode node, MemoStructure memo, double maxCosts) throws DMLRuntimeException {
        ArrayList<Plan> plans = new ArrayList<>();
        LopProperties.ExecType clusterType = clusterExecType();
        Hop hop = node.getHop();

        if (node.getNodeType() == GDFNode.NodeType.HOP_NODE && !(hop instanceof DataOp)) {
            enumHopNodePlans(node, plans);
        } else if (hop instanceof DataOp) {
            DataOp dop = (DataOp) hop;
            Hop.DataOpTypes dopType = dop.getDataOpType();
            if (dopType == Hop.DataOpTypes.PERSISTENTREAD) {
                // Reads that exceed the local budget or always need a reblock run in the
                // cluster with all candidate block sizes; otherwise a single CP plan.
                boolean useCluster = dop.getMemEstimate() > OptimizerUtils.getLocalMemBudget()
                    || HopRewriteUtils.alwaysRequiresReblock(dop);
                LopProperties.ExecType et = useCluster ? clusterType : LopProperties.ExecType.CP;
                int[] blockSizes = (et == clusterType) ? BLOCK_SIZES : new int[]{BLOCK_SIZES[0]};
                for (int bs : blockSizes) {
                    addPlan(node, new RewriteConfig(et, bs, Hop.FileFormatTypes.BINARY), plans);
                }
            } else if (dopType == Hop.DataOpTypes.PERSISTENTWRITE) {
                // Persistent writes keep the block size and format of the write itself.
                LopProperties.ExecType et = dop.getMemEstimate() > OptimizerUtils.getLocalMemBudget()
                    ? clusterType
                    : LopProperties.ExecType.CP;
                addPlan(node, new RewriteConfig(et, (int) dop.getRowsInBlock(), dop.getInputFormatType()), plans);
            } else if (dopType == Hop.DataOpTypes.TRANSIENTREAD || dopType == Hop.DataOpTypes.TRANSIENTWRITE) {
                enumHopNodePlans(node, plans);
            }
        } else if (node.getNodeType() == GDFNode.NodeType.LOOP_NODE) {
            GDFLoopNode loop = (GDFLoopNode) node;
            for (GDFNode loopInput : loop.getLoopInputs().values()) {
                enumOpt(loopInput, memo, maxCosts);
            }
            addPlan(node, new RewriteConfig(LopProperties.ExecType.CP, -1, null), plans);
            if (loop.getLoopPredicate() != null) {
                enumOpt(loop.getLoopPredicate(), memo, maxCosts);
            }
            PlanSet outputPlans = new PlanSet();
            for (GDFNode loopOutput : loop.getLoopOutputs().values()) {
                outputPlans = outputPlans.union(enumOpt(loopOutput, memo, maxCosts));
            }
            plans.addAll(outputPlans.getPlans());
        } else if (node.getNodeType() == GDFNode.NodeType.CROSS_BLOCK_NODE) {
            // No node-local plans are enumerated for cross-block nodes.
        }
        return new PlanSet(plans);
    }

    /**
     * Adds the candidate plans of a regular hop node: a CP plan if the hop fits into the
     * local memory budget, and cluster plans for all block sizes if the node requires it.
     */
    private static void enumHopNodePlans(GDFNode node, ArrayList<Plan> plans) {
        LopProperties.ExecType clusterType = clusterExecType();
        if (node.getHop().getMemEstimate() < OptimizerUtils.getLocalMemBudget()) {
            int[] cpBlockSizes = ENUM_CP_BLOCKSIZES ? BLOCK_SIZES : new int[]{BLOCK_SIZES[0]};
            for (int bs : cpBlockSizes) {
                addPlan(node, new RewriteConfig(LopProperties.ExecType.CP, bs, Hop.FileFormatTypes.BINARY), plans);
            }
        }
        if (node.requiresMREnumeration()) {
            for (int bs : BLOCK_SIZES) {
                addPlan(node, new RewriteConfig(clusterType, bs, Hop.FileFormatTypes.BINARY), plans);
            }
        }
    }

    /** Appends a child-less plan for {@code node} with the given rewrite configuration. */
    private static void addPlan(GDFNode node, RewriteConfig config, ArrayList<Plan> plans) {
        plans.add(new Plan(node, config.deriveInterestingProperties(), config, null));
    }

    /** Removes all plans that fail any of the plan validity checks. */
    private static void pruneInvalidPlans(PlanSet plans) {
        ArrayList<Plan> valid = new ArrayList<>();
        for (Plan p : plans.getPlans()) {
            if (p.checkValidBlocksizesInMR() && p.checkValidBlocksizesTRead()
                && p.checkValidFormatInMR() && p.checkValidExecutionType()) {
                valid.add(p);
            }
        }
        int before = plans.size();
        int after = valid.size();
        _prunedInvalidPlans += before - after;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Pruned invalid plans: " + before + " --> " + after);
        }
        plans.setPlans(valid);
    }

    /**
     * Costs every plan and keeps, per set of interesting properties, only the cheapest plan
     * whose costs do not exceed {@code maxCosts}. On cost ties a preferred plan wins.
     */
    private static void pruneSuboptimalPlans(PlanSet plans, double maxCosts) throws DMLRuntimeException {
        for (Plan p : plans.getPlans()) {
            p.setCosts(costRuntimePlan(p));
        }
        // Keyed by the plans' interesting properties (compared via equals/hashCode).
        HashMap<Object, Plan> bestPerProps = new HashMap<>();
        for (Plan p : plans.getPlans()) {
            if (p.getCosts() <= maxCosts) {
                Object props = p.getInterestingProperties();
                Plan best = bestPerProps.get(props);
                if (best == null || p.getCosts() < best.getCosts()
                    || (p.getCosts() == best.getCosts() && PREFERRED_PLAN_SELECTION && p.isPreferredPlan())) {
                    bestPerProps.put(props, p);
                }
            }
        }
        ArrayList<Plan> kept = new ArrayList<>(bestPerProps.values());
        int before = plans.size();
        int after = kept.size();
        _prunedSuboptimalPlans += before - after;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Pruned suboptimal plans: " + before + " --> " + after);
        }
        plans.setPlans(kept);
    }

    /**
     * Estimates the execution time of a plan by temporarily applying its configuration,
     * recompiling, and costing. The plan configuration is always reset afterwards, also
     * when recompilation or costing fails.
     *
     * @throws DMLRuntimeException if no program is attached to the plan's node, or if
     *         recompilation/costing fails
     */
    private static double costRuntimePlan(Plan plan) throws DMLRuntimeException {
        Program prog = plan.getNode().getProgram();
        if (prog == null) {
            throw new DMLRuntimeException("Program not available for runtime plan costing.");
        }
        rSetRuntimePlanConfig(plan, new HashMap<Long, Plan>());
        double costs;
        try {
            Hop hop = plan.getNode().getHop();
            ProgramBlock pb = plan.getNode().getProgramBlock();
            if (COST_FULL_PROGRAMS || hop == null || pb == null) {
                costs = recompileAndCost(prog);
            } else {
                costs = costHopInProgramBlock(prog, hop, pb);
            }
        } finally {
            // Undo the forced configuration even if recompilation or costing failed.
            rResetRuntimePlanConfig(plan, new HashMap<Long, Plan>());
        }
        _costedPlans++;
        return costs;
    }

    /** Recompiles the whole program and returns its estimated execution time. */
    private static double recompileAndCost(Program prog) throws DMLRuntimeException {
        Recompiler.recompileProgramBlockHierarchy(prog.getProgramBlocks(), new LocalVariableMap(), 0L, false);
        _compiledPlans++;
        return CostEstimationWrapper.getTimeEstimate(prog, ExecutionContextFactory.createContext(prog));
    }

    /**
     * Costs the program with the given program block reduced to the sub-DAG of {@code hop}.
     * Unless the hop is itself a write, a temporary transient write becomes the block's only
     * root. The original roots are restored and the temporary write is detached from
     * {@code hop} in all cases, including failures.
     */
    private static double costHopInProgramBlock(Program prog, Hop hop, ProgramBlock pb) throws DMLRuntimeException {
        try {
            ArrayList<Hop> origRoots = pb.getStatementBlock().get_hops();
            DataOp tmpWrite = null;
            try {
                if (!(hop instanceof DataOp) || !((DataOp) hop).isWrite()) {
                    ArrayList<Hop> tmpRoots = new ArrayList<>();
                    tmpWrite = new DataOp("_tmp", hop.getDataType(), hop.getValueType(), hop, Hop.DataOpTypes.TRANSIENTWRITE, "tmp");
                    tmpWrite.setVisited(Hop.VisitStatus.DONE);
                    tmpRoots.add(tmpWrite);
                    pb.getStatementBlock().set_hops(tmpRoots);
                }
                return recompileAndCost(prog);
            } finally {
                if (tmpWrite != null) {
                    HopRewriteUtils.removeChildReference(tmpWrite, hop);
                }
                pb.getStatementBlock().set_hops(origRoots);
            }
        } catch (HopsException e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Recursively applies the rewrite configurations of a plan and its children to the
     * underlying hops. {@code memo} maps GDF node ids to the plan already applied. When a
     * node is reached again with different interesting properties, the mismatch heuristic
     * decides which plan is recorded, a warning is logged, the mismatch counter is
     * incremented, and the subtree is not applied again.
     */
    private static void rSetRuntimePlanConfig(Plan plan, HashMap<Long, Plan> memo) {
        LopProperties.ExecType clusterType = clusterExecType();
        GDFNode node = plan.getNode();
        Long nodeId = node.getID();
        Hop hop = node.getHop();

        Plan prior = memo.get(nodeId);
        if (prior != null && !plan.getInterestingProperties().equals(prior.getInterestingProperties())) {
            if (_resolve.resolveMismatch(prior.getRewriteConfig(), plan.getRewriteConfig())) {
                memo.put(nodeId, plan);
            }
            // Nodes without a hop (e.g. loop nodes) are identified by their GDF node id.
            String nodeRef = (hop != null) ? String.valueOf(hop.getHopID()) : "gdf node " + nodeId;
            LOG.warn("Configuration mismatch on shared node (" + nodeRef + "). Falling back to heuristic '" + _resolve.getName() + "'.");
            LOG.warn(plan.getInterestingProperties().toString());
            LOG.warn(memo.get(nodeId).getInterestingProperties());
            _planMismatches++;
            return;
        }

        if (hop != null) {
            RewriteConfig config = plan.getRewriteConfig();
            hop.setForcedExecType(config.getExecType());
            hop.setRowsInBlock(config.getBlockSize());
            hop.setColsInBlock(config.getBlockSize());
            if (config.getExecType() == clusterType) {
                // Cluster ops reblock if always required, or (non-DataOps) if matrix inputs
                // have block sizes different from the configured one.
                hop.setRequiresReblock(HopRewriteUtils.alwaysRequiresReblock(hop)
                    || (hop.hasMatrixInputWithDifferentBlocksizes() && !(hop instanceof DataOp)));
            } else {
                hop.setRequiresReblock(false);
            }
        }
        if (plan.getChilds() != null) {
            for (Plan child : plan.getChilds()) {
                rSetRuntimePlanConfig(child, memo);
            }
        }
        memo.put(nodeId, plan);
    }

    /**
     * Recursively resets the hops of a plan and its children to the default configuration:
     * no forced exec type, the configured default block size, and no reblock unless always
     * required. {@code memo} serves as a visited set of GDF node ids.
     */
    private static void rResetRuntimePlanConfig(Plan plan, HashMap<Long, Plan> memo) {
        Long nodeId = plan.getNode().getID();
        if (memo.containsKey(nodeId)) {
            return;
        }
        Hop hop = plan.getNode().getHop();
        if (hop != null) {
            hop.setForcedExecType(null);
            hop.setRowsInBlock(ConfigurationManager.getBlocksize());
            hop.setColsInBlock(ConfigurationManager.getBlocksize());
            if (!HopRewriteUtils.alwaysRequiresReblock(hop)) {
                hop.setRequiresReblock(false);
            }
        }
        if (plan.getChilds() != null) {
            for (Plan child : plan.getChilds()) {
                rResetRuntimePlanConfig(child, memo);
            }
        }
        memo.put(nodeId, plan);
    }

    /** Returns the execution type used for distributed plans (SPARK or MR). */
    private static LopProperties.ExecType clusterExecType() {
        return OptimizerUtils.isSparkExecutionMode() ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;
    }

    private static long getPlanMismatches() {
        return _planMismatches;
    }

    private static void resetPlanMismatches() {
        _planMismatches = 0L;
    }

    /** Resets all static enumeration statistics. */
    private static void resetStatistics() {
        _enumeratedPlans = 0L;
        _prunedInvalidPlans = 0L;
        _prunedSuboptimalPlans = 0L;
        _compiledPlans = 0L;
        _costedPlans = 0L;
        _planMismatches = 0L;
    }
}
