package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ipa.InterProceduralAnalysis;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteRule;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteConstantFolding;
import org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryBranches;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.ForProgramBlock;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.IfProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.WhileProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Stat;
import org.apache.sysml.runtime.controlprogram.parfor.stat.StatisticMonitor;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.class */
/**
 * Entry points for running the ParFOR optimizer, either over all parfor loops of a DML
 * program or over a single parfor program block.
 */
public class OptimizationWrapper {
    private static final Log LOG = LogFactory.getLog(OptimizationWrapper.class.getName());

    /** Factor applied to max(CP, MR) max parallelism when deriving the optimizer's degree-of-parallelism bound. */
    public static final double PAR_FACTOR_INFRASTRUCTURE = 1.0d;

    /** When false, optimizers requiring the RUNTIME_METRICS cost model are rejected. */
    private static final boolean ALLOW_RUNTIME_COSTMODEL = false;

    /**
     * Optimizes every parfor program block found in the function program blocks and
     * top-level program blocks of the given program whose optimization mode is not NONE.
     * Bodies of parfor blocks are not searched, so only the outermost parfor of each nest
     * is collected. Blocks are optimized in HashMap iteration order (unspecified), all
     * sharing one execution context created via {@link ExecutionContextFactory}.
     *
     * @param prog    the DML (statement-block) program
     * @param rtprog  the runtime program compiled from {@code prog}; its blocks are
     *                paired positionally with the statement blocks of {@code prog}
     * @param monitor if true, optimizer statistics are recorded via {@link StatisticMonitor}
     */
    public static void optimize(DMLProgram prog, Program rtprog, boolean monitor) throws DMLRuntimeException, LanguageException, DMLUnsupportedOperationException {
        LOG.debug("ParFOR Opt: Running optimize all on DML program " + DMLScript.getUUID());

        // both maps are keyed by the parfor program block ID
        HashMap<Long, ParForStatementBlock> sbs = new HashMap<Long, ParForStatementBlock>();
        HashMap<Long, ParForProgramBlock> pbs = new HashMap<Long, ParForProgramBlock>();
        findParForProgramBlocks(prog, rtprog, sbs, pbs);

        ExecutionContext ec = ExecutionContextFactory.createContext();
        for (Map.Entry<Long, ParForProgramBlock> entry : pbs.entrySet()) {
            ParForStatementBlock sb = sbs.get(entry.getKey());
            ParForProgramBlock pb = entry.getValue();
            optimize(pb.getOptimizationMode(), sb, pb, ec, monitor);
        }

        LOG.debug("ParFOR Opt: Finished optimization for DML program " + DMLScript.getUUID());
    }

    /**
     * Optimizes a single parfor program block with the given optimizer type. The
     * parallelism bound is max(ckMaxCP, ckMaxMR) * {@link #PAR_FACTOR_INFRASTRUCTURE};
     * the memory bound is cmMax * {@code OptimizerUtils.MEM_UTIL_FACTOR}.
     *
     * @param type    optimizer to use
     * @param sb      parfor statement block of {@code pb}
     * @param pb      parfor program block to optimize
     * @param ec      execution context whose variables feed constant replacement and recompilation
     * @param monitor if true, the total optimization time is recorded via {@link StatisticMonitor}
     */
    public static void optimize(ParForProgramBlock.POptMode type, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException, DMLUnsupportedOperationException {
        Timing time = new Timing(true);
        LOG.debug("ParFOR Opt: Running optimization for ParFOR(" + pb.getID() + ")");

        int ck = UtilFunctions.toInt(Math.max(InfrastructureAnalyzer.getCkMaxCP(), InfrastructureAnalyzer.getCkMaxMR()) * PAR_FACTOR_INFRASTRUCTURE);
        double cm = InfrastructureAnalyzer.getCmMax() * OptimizerUtils.MEM_UTIL_FACTOR;
        optimize(type, ck, cm, sb, pb, ec, monitor);

        double elapsed = time.stop();
        LOG.debug("ParFOR Opt: Finished optimization for PARFOR(" + pb.getID() + ") in " + elapsed + "ms.");
        if (monitor) {
            StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_T, elapsed);
        }
    }

    /**
     * Sets the log4j level of the logger for package
     * {@code org.apache.sysml.runtime.controlprogram.parfor.opt}.
     *
     * @param level the new log level
     */
    public static void setLogLevel(Level level) {
        Logger.getLogger("org.apache.sysml.runtime.controlprogram.parfor.opt").setLevel(level);
    }

    /**
     * Core optimization driver. It creates the optimizer and rejects unsupported cost
     * models. When dynamic recompilation is enabled, it rewrites and recompiles the parfor
     * body. It then builds the opt tree and runs the optimizer on it.
     */
    private static void optimize(ParForProgramBlock.POptMode otype, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException, DMLUnsupportedOperationException {
        Timing time = new Timing(true);
        if (DMLScript.STATISTICS) {
            Statistics.incrementParForOptimCount();
        }

        Optimizer opt = createOptimizer(otype);
        Optimizer.CostModelType cmtype = opt.getCostModelType();
        LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + opt.getCostModelType() + ")");

        if (!ALLOW_RUNTIME_COSTMODEL && cmtype == Optimizer.CostModelType.RUNTIME_METRICS) {
            throw new DMLRuntimeException("ParFOR Optimizer " + otype + " requires cost model " + cmtype + " that is not supported yet.");
        }

        if (OptimizerUtils.ALLOW_DYN_RECOMPILATION) {
            rewriteAndRecompile(opt, ck, cm, sb, pb, ec);
        }

        try {
            OptTree tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
            LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));

            CostEstimator est = createCostEstimator(cmtype);
            LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");

            opt.optimize(sb, pb, tree, est, ec);
            LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));

            long elapsed = (long) time.stop();
            LOG.trace("ParFOR Opt: Optimized plan in " + elapsed + "ms.");
            if (DMLScript.STATISTICS) {
                Statistics.incrementParForOptimTime(elapsed);
            }

            OptTreeConverter.clear();

            if (monitor) {
                StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
                StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
                StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        }
    }

    /**
     * Prepares the parfor body before optimization in three independent steps. Each step
     * wraps its failure in exactly one {@link DMLRuntimeException}.
     * <ol>
     *   <li>Replace reusable scalar variables from {@code ec} in the parfor statement block.</li>
     *   <li>Apply constant folding and unnecessary-branch removal; when branches were
     *       removed, regenerate the parfor's child program blocks from the rewritten body.</li>
     *   <li>Recompile the child block hierarchy, plus any functions reported by
     *       inter-procedural analysis of the parfor statement block.</li>
     * </ol>
     * When debug logging is enabled, the input plan is first logged.
     */
    private static void rewriteAndRecompile(Optimizer opt, int ck, double cm, ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec) throws DMLRuntimeException, DMLUnsupportedOperationException {
        ForStatement fs = (ForStatement) sb.getStatement(0);

        if (LOG.isDebugEnabled()) {
            try {
                LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec).explain(false));
                OptTreeConverter.clear();
            } catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }
        }

        // step 1: constant replacement of reusable scalar inputs
        try {
            ProgramRecompiler.replaceConstantScalarVariables(sb, ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables()));
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }

        // step 2: HOP-level and statement-block-level rewrites of the parfor body
        try {
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus status = new ProgramRewriteStatus();
            rewriter.rewriteStatementBlockHopDAGs(sb, status);
            fs.setBody(rewriter.rewriteStatementBlocks(fs.getBody(), status));
            if (status.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                // the runtime child blocks no longer match the rewritten body; regenerate them
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }

        // step 3: recompile the body (on a copy of the variables) and called functions
        try {
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), (LocalVariableMap) ec.getVariables().clone(), 0L, true);
            if (pb.hasFunctions()) {
                Set<String> fkeys = new InterProceduralAnalysis().analyzeSubProgram(sb);
                for (String fkey : fkeys) {
                    String[] parts = DMLProgram.splitFunctionKey(fkey);
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(parts[0], parts[1]);
                    Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0L, fpb.isRecompileOnce());
                }
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Searches all function program blocks and all top-level program blocks for parfor
     * blocks. Runtime blocks are paired with statement blocks by function key and by
     * position, respectively.
     */
    private static void findParForProgramBlocks(DMLProgram prog, Program rtprog, HashMap<Long, ParForStatementBlock> sbs, HashMap<Long, ParForProgramBlock> pbs) throws LanguageException {
        for (Map.Entry<String, FunctionProgramBlock> fentry : rtprog.getFunctionProgramBlocks().entrySet()) {
            String[] parts = fentry.getKey().split(Program.KEY_DELIM);
            rfindParForProgramBlocks(prog.getFunctionStatementBlock(parts[0], parts[1]), fentry.getValue(), sbs, pbs);
        }

        ArrayList<ProgramBlock> blocks = rtprog.getProgramBlocks();
        for (int pos = 0; pos < blocks.size(); pos++) {
            rfindParForProgramBlocks(prog.getStatementBlock(pos), blocks.get(pos), sbs, pbs);
        }
    }

    /**
     * Recursively walks for/while/if blocks and registers parfor blocks whose optimization
     * mode is not NONE under their ID. The bodies of parfor blocks are not descended into.
     */
    private static void rfindParForProgramBlocks(StatementBlock sb, ProgramBlock pb, HashMap<Long, ParForStatementBlock> sbs, HashMap<Long, ParForProgramBlock> pbs) {
        // the parfor check must precede the for check to keep the original dispatch order
        if (pb instanceof ParForProgramBlock) {
            ParForProgramBlock pfpb = (ParForProgramBlock) pb;
            ParForStatementBlock pfsb = (ParForStatementBlock) sb;
            LOG.trace("ParFOR: found ParForProgramBlock with POptMode=" + pfpb.getOptimizationMode().toString());
            if (pfpb.getOptimizationMode() != ParForProgramBlock.POptMode.NONE) {
                long id = pfpb.getID();
                pbs.put(id, pfpb);
                sbs.put(id, pfsb);
            }
        } else if (pb instanceof ForProgramBlock) {
            ForStatement fs = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
            rfindInBlockList(fs.getBody(), ((ForProgramBlock) pb).getChildBlocks(), sbs, pbs);
        } else if (pb instanceof WhileProgramBlock) {
            WhileStatement ws = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
            rfindInBlockList(ws.getBody(), ((WhileProgramBlock) pb).getChildBlocks(), sbs, pbs);
        } else if (pb instanceof IfProgramBlock) {
            IfProgramBlock ipb = (IfProgramBlock) pb;
            IfStatement is = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
            rfindInBlockList(is.getIfBody(), ipb.getChildBlocksIfBody(), sbs, pbs);
            rfindInBlockList(is.getElseBody(), ipb.getChildBlocksElseBody(), sbs, pbs);
        }
    }

    /** Recurses into each (statement block, program block) pair of two positionally aligned lists, driven by the program-block list size. */
    private static void rfindInBlockList(ArrayList<StatementBlock> body, ArrayList<ProgramBlock> children, HashMap<Long, ParForStatementBlock> sbs, HashMap<Long, ParForProgramBlock> pbs) {
        for (int pos = 0; pos < children.size(); pos++) {
            rfindParForProgramBlocks(body.get(pos), children.get(pos), sbs, pbs);
        }
    }

    /** Instantiates the optimizer for the given mode; NONE and unknown modes are rejected. */
    private static Optimizer createOptimizer(ParForProgramBlock.POptMode otype) throws DMLRuntimeException {
        switch (otype) {
            case HEURISTIC:
                return new OptimizerHeuristic();
            case RULEBASED:
                return new OptimizerRuleBased();
            case CONSTRAINED:
                return new OptimizerConstrained();
            default:
                throw new DMLRuntimeException("Undefined optimizer: '" + otype + "'.");
        }
    }

    /** Instantiates the cost estimator for the given cost model type. */
    private static CostEstimator createCostEstimator(Optimizer.CostModelType cmtype) throws DMLRuntimeException {
        switch (cmtype) {
            case STATIC_MEM_METRIC:
                return new CostEstimatorHops(OptTreeConverter.getAbstractPlanMapping());
            case RUNTIME_METRICS:
                return new CostEstimatorRuntime();
            default:
                throw new DMLRuntimeException("Undefined cost model type: '" + cmtype + "'.");
        }
    }

    /** Builds a rewriter with constant folding (HOP rule) and unnecessary-branch removal (statement-block rule). */
    private static ProgramRewriter createProgramRewriterWithRuleSets() {
        ArrayList<HopRewriteRule> hopRules = new ArrayList<HopRewriteRule>();
        hopRules.add(new RewriteConstantFolding());

        ArrayList<StatementBlockRewriteRule> sbRules = new ArrayList<StatementBlockRewriteRule>();
        sbRules.add(new RewriteRemoveUnnecessaryBranches());

        return new ProgramRewriter(hopRules, sbRules);
    }
}
