package org.apache.sysml.api.mlcontext;

import java.io.IOException;
import java.util.Collection;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.ScriptExecutorUtils;
import org.apache.sysml.api.jmlc.JMLCUtils;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.parser.ParserFactory;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.Statistics;

/* loaded from: input_file:org/apache/sysml/api/mlcontext/ScriptExecutor.class */
/**
 * Compiles and executes an MLContext {@link Script}.
 *
 * <p>{@link #compile(Script, boolean)} runs the compilation pipeline step by step. The steps are:
 * parse, live-variable analysis, validation, HOP construction and (optional) rewrites, removal of
 * persistent reads/writes, LOP construction, runtime program generation, explain output, optional
 * global data-flow optimization, and runtime program cleanup. {@link #execute(Script)} then builds
 * an {@link ExecutionContext} from the script's symbol table, runs the runtime program and wraps the
 * outputs in {@link MLResults}.
 *
 * <p>Several settings are written to static fields of {@link DMLScript}. The statistics and GPU
 * flags are saved by {@link #setGlobalFlags()} and restored by {@link #resetGlobalFlags()}.
 */
public class ScriptExecutor {
    protected DMLConfig config;
    protected DMLProgram dmlProgram;
    protected DMLTranslator dmlTranslator;
    protected Program runtimeProgram;
    protected ExecutionContext executionContext;
    protected Script script;
    /** Whether caching and scratch space are initialized (via DMLScript.initHadoopExecution) during compilation. */
    protected boolean init;
    protected boolean explain;
    protected boolean gpu;
    protected boolean oldGPU;
    protected boolean forceGPU;
    protected boolean oldForceGPU;
    protected boolean statistics;
    protected boolean oldStatistics;
    protected MLContext.ExplainLevel explainLevel;
    protected MLContext.ExecutionType executionType;
    protected int statisticsMaxHeavyHitters = 10;
    /** When true, remove-variable instructions are deleted so the symbol table survives execution. */
    protected boolean maintainSymbolTable;

    /**
     * True after {@link #setGlobalFlags()} has overwritten DMLScript's global flags and before
     * {@link #resetGlobalFlags()} has restored them.
     */
    private boolean globalFlagsSet;

    /** Creates an executor that uses the currently registered global DML configuration. */
    public ScriptExecutor() {
        this.config = ConfigurationManager.getDMLConfig();
    }

    /**
     * Creates an executor with the given configuration and registers it as the global configuration.
     *
     * @param config the DML configuration to use
     */
    public ScriptExecutor(DMLConfig config) {
        this.config = config;
        ConfigurationManager.setGlobalConfig(config);
    }

    /** Builds the HOP (high-level operator) DAGs of the parsed program. */
    protected void constructHops() {
        try {
            this.dmlTranslator.constructHops(this.dmlProgram);
        } catch (LanguageException | ParseException e) {
            throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e);
        }
    }

    /** Applies the translator's HOP DAG rewrites. */
    protected void rewriteHops() {
        try {
            this.dmlTranslator.rewriteHopsDAG(this.dmlProgram);
        } catch (HopsException | LanguageException | ParseException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
        }
    }

    /**
     * Prints the explain output to standard output when explain is enabled. Uses the configured
     * explain level, or RUNTIME when none is set.
     */
    protected void showExplanation() {
        if (!this.explain) {
            return;
        }
        Explain.ExplainType explainType =
                this.explainLevel == null ? Explain.ExplainType.RUNTIME : this.explainLevel.getExplainType();
        try {
            System.out.println(Explain.display(this.dmlProgram, this.runtimeProgram, explainType, null));
        } catch (Exception e) {
            throw new MLContextException("Exception occurred while explaining dml program", e);
        }
    }

    /** Builds the LOPs (low-level operators) from the HOP DAGs. */
    protected void constructLops() {
        try {
            this.dmlTranslator.constructLops(this.dmlProgram);
        } catch (HopsException | LopsException | LanguageException | ParseException e) {
            throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
        }
    }

    /** Generates the runtime program from the compiled program and stores it in {@link #runtimeProgram}. */
    protected void generateRuntimeProgram() {
        try {
            this.runtimeProgram = this.dmlTranslator.getRuntimeProgram(this.dmlProgram, this.config);
        } catch (IOException | HopsException | LopsException | LanguageException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred while generating runtime program", e);
        }
    }

    /** Records the runtime program's distributed-operation job count as the compiled-job count in Statistics. */
    protected void countCompiledMRJobsAndSparkInstructions() {
        Statistics.resetNoOfCompiledJobs(Explain.countDistributedOperations(this.runtimeProgram).numJobs);
    }

    /**
     * Creates the execution context for the runtime program. Installs the script's symbol table,
     * if any, as its variables and registers the script's output variables.
     */
    protected void createAndInitializeExecutionContext() {
        this.executionContext = ExecutionContextFactory.createContext(this.runtimeProgram);
        LocalVariableMap symbolTable = this.script.getSymbolTable();
        if (symbolTable != null) {
            this.executionContext.setVariables(symbolTable);
        }
        this.executionContext.getVariables().setRegisteredOutputs(new HashSet<>(this.script.getOutputVariables()));
    }

    /**
     * Saves the current global statistics/GPU flags and overwrites them with this executor's
     * settings. Also sets the heavy-hitter count, the compiler configuration, the available GPUs and
     * the GPU eviction policy.
     *
     * @throws RuntimeException if the compiler configuration cannot be built or the configured
     *     eviction policy is not a known {@code DMLScript.EvictionPolicy}
     */
    protected void setGlobalFlags() {
        // Save the caller's settings before touching anything so resetGlobalFlags() can undo them.
        this.oldStatistics = DMLScript.STATISTICS;
        this.oldForceGPU = DMLScript.FORCE_ACCELERATOR;
        this.oldGPU = DMLScript.USE_ACCELERATOR;
        this.globalFlagsSet = true;

        DMLScript.STATISTICS = this.statistics;
        DMLScript.FORCE_ACCELERATOR = this.forceGPU;
        DMLScript.USE_ACCELERATOR = this.gpu;
        DMLScript.STATISTICS_COUNT = this.statisticsMaxHeavyHitters;
        try {
            OptimizerUtils.resetStaticCompilerFlags();
            ConfigurationManager.setGlobalConfig(
                    OptimizerUtils.constructCompilerConfig(ConfigurationManager.getCompilerConfig(), this.config));
            GPUContextPool.AVAILABLE_GPUS = this.config.getTextValue(DMLConfig.AVAILABLE_GPUS);
            // Locale.ROOT: enum names are ASCII; locale-sensitive casing (e.g. Turkish dotted I)
            // would otherwise break valueOf for valid policy names.
            String evictionPolicy =
                    this.config.getTextValue(DMLConfig.GPU_EVICTION_POLICY).toUpperCase(Locale.ROOT);
            try {
                DMLScript.GPU_EVICTION_POLICY = DMLScript.EvictionPolicy.valueOf(evictionPolicy);
            } catch (IllegalArgumentException e) {
                throw new RuntimeException("Unsupported eviction policy:" + evictionPolicy, e);
            }
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Restores the statistics/GPU flags saved by {@link #setGlobalFlags()}. Also resets the
     * heavy-hitter count to the default options value.
     */
    protected void resetGlobalFlags() {
        DMLScript.STATISTICS = this.oldStatistics;
        DMLScript.FORCE_ACCELERATOR = this.oldForceGPU;
        DMLScript.USE_ACCELERATOR = this.oldGPU;
        DMLScript.STATISTICS_COUNT = DMLScript.DMLOptions.defaultOptions.statsCount;
        this.globalFlagsSet = false;
    }

    /**
     * Compiles the script with HOP rewrites enabled.
     *
     * @param script the script to compile
     */
    public void compile(Script script) {
        compile(script, true);
    }

    /**
     * Compiles the script into a runtime program.
     *
     * @param script the script to compile
     * @param performHopRewrites whether to apply HOP DAG rewrites
     */
    public void compile(Script script, boolean performHopRewrites) {
        setup(script);
        if (this.statistics) {
            Statistics.startCompileTimer();
        }
        parseScript();
        liveVariableAnalysis();
        validateScript();
        constructHops();
        if (performHopRewrites) {
            rewriteHops();
        }
        rewritePersistentReadsAndWrites();
        constructLops();
        generateRuntimeProgram();
        showExplanation();
        globalDataFlowOptimization();
        countCompiledMRJobsAndSparkInstructions();
        initializeCachingAndScratchSpace();
        cleanupRuntimeProgram();
        if (this.statistics) {
            Statistics.stopCompileTimer();
        }
    }

    /**
     * Compiles and executes the script and attaches the results to it.
     *
     * @param script the script to execute
     * @return the results of execution, also stored on the script via {@code setResults}
     */
    public MLResults execute(Script script) {
        boolean compiled = false;
        try {
            compile(script);
            compiled = true;
        } finally {
            // A failed compilation never reaches cleanupAfterExecution(); undo the global flag
            // changes made by setup() so they do not leak into later scripts.
            if (!compiled && this.globalFlagsSet) {
                resetGlobalFlags();
            }
        }
        try {
            createAndInitializeExecutionContext();
            executeRuntimeProgram();
            MLResults results = new MLResults(script);
            script.setResults(results);
            return results;
        } finally {
            cleanupAfterExecution();
        }
    }

    /**
     * Prepares for compilation. Stores and validates the script, links it to this executor, sets
     * the global script type and flags, and resets job (and, if enabled, full) statistics.
     *
     * @param script the script to prepare
     */
    public void setup(Script script) {
        this.script = script;
        checkScriptHasTypeAndString();
        script.setScriptExecutor(this);
        DMLScript.SCRIPT_TYPE = script.getScriptType();
        setGlobalFlags();
        Statistics.resetNoOfExecutedJobs();
        if (this.statistics) {
            Statistics.reset();
        }
    }

    /** Re-registers inputs missing from the symbol table and restores the global flags. */
    protected void cleanupAfterExecution() {
        restoreInputsInSymbolTable();
        resetGlobalFlags();
    }

    /** Re-adds any script input (with its metadata) that is no longer present in the symbol table. */
    protected void restoreInputsInSymbolTable() {
        Map<String, Object> inputs = this.script.getInputs();
        Map<String, Metadata> inputMetadata = this.script.getInputMetadata();
        LocalVariableMap symbolTable = this.script.getSymbolTable();
        for (String inputName : this.script.getInputVariables()) {
            if (symbolTable.get(inputName) == null) {
                this.script.in(inputName, inputs.get(inputName), inputMetadata.get(inputName));
            }
        }
    }

    /**
     * Cleans up the runtime program. If the symbol table is maintained, removes the
     * remove-variable instructions; otherwise applies JMLC cleanup keeping the output variables.
     */
    protected void cleanupRuntimeProgram() {
        if (this.maintainSymbolTable) {
            MLContextUtil.deleteRemoveVariableInstructions(this.runtimeProgram);
        } else {
            JMLCUtils.cleanupRuntimeProgram(this.runtimeProgram, toNameArray(this.script.getOutputVariables()));
        }
    }

    /** Runs the runtime program, passing the heavy-hitter count only when statistics are enabled. */
    protected void executeRuntimeProgram() {
        try {
            ScriptExecutorUtils.executeRuntimeProgram(this, this.statistics ? this.statisticsMaxHeavyHitters : 0);
        } catch (DMLRuntimeException e) {
            throw new MLContextException("Exception occurred while executing runtime program", e);
        }
    }

    /** Initializes caching and scratch space via {@code DMLScript.initHadoopExecution} when {@link #init} is set. */
    protected void initializeCachingAndScratchSpace() {
        if (!this.init) {
            return;
        }
        try {
            DMLScript.initHadoopExecution(this.config);
        } catch (IOException | ParseException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred initializing caching and scratch space", e);
        }
    }

    /** Runs the global data-flow optimizer when the optimization level is O4_GLOBAL_TIME_MEMORY. */
    protected void globalDataFlowOptimization() {
        if (!OptimizerUtils.isOptLevel(OptimizerUtils.OptimizationLevel.O4_GLOBAL_TIME_MEMORY)) {
            return;
        }
        try {
            this.runtimeProgram = GlobalOptimizerWrapper.optimizeProgram(this.dmlProgram, this.runtimeProgram);
        } catch (HopsException | LopsException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred during global data flow optimization", e);
        }
    }

    /** Parses the script's execution string, with its input parameters, into {@link #dmlProgram}. */
    public void parseScript() {
        try {
            this.dmlProgram = ParserFactory.createParser(this.script.getScriptType())
                    .parse(null, this.script.getScriptExecutionString(),
                            MLContextUtil.convertInputParametersForParser(
                                    this.script.getInputParameters(), this.script.getScriptType()));
        } catch (ParseException e) {
            throw new MLContextException("Exception occurred while parsing script", e);
        }
    }

    /**
     * Applies the persistent read/write removal rewrite for the script's inputs and outputs. Runs
     * only when the script has a symbol table.
     */
    protected void rewritePersistentReadsAndWrites() {
        LocalVariableMap symbolTable = this.script.getSymbolTable();
        if (symbolTable == null) {
            return;
        }
        try {
            RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(
                    toNameArray(this.script.getInputVariables()),
                    toNameArray(this.script.getOutputVariables()),
                    symbolTable);
            new ProgramRewriter(rewrite).rewriteProgramHopDAGs(this.dmlProgram);
        } catch (HopsException | LanguageException e) {
            throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e);
        }
    }

    /**
     * Replaces the configuration and registers it as the global configuration.
     *
     * @param config the new DML configuration
     */
    public void setConfig(DMLConfig config) {
        this.config = config;
        ConfigurationManager.setGlobalConfig(config);
    }

    /** Creates the translator for the parsed program and runs live-variable analysis. */
    protected void liveVariableAnalysis() {
        try {
            this.dmlTranslator = new DMLTranslator(this.dmlProgram);
            this.dmlTranslator.liveVariableAnalysis(this.dmlProgram);
        } catch (LanguageException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred during live variable analysis", e);
        }
    }

    /** Validates the parse tree of the program. */
    protected void validateScript() {
        try {
            this.dmlTranslator.validateParseTree(this.dmlProgram);
        } catch (IOException | LanguageException | ParseException e) {
            throw new MLContextException("Exception occurred while validating script", e);
        }
    }

    /**
     * Ensures a script is set and has both a script type and a non-blank script string.
     *
     * @throws MLContextException if any of these conditions is violated
     */
    protected void checkScriptHasTypeAndString() {
        if (this.script == null) {
            throw new MLContextException("Script is null");
        }
        if (this.script.getScriptType() == null) {
            throw new MLContextException("ScriptType (DML or PYDML) needs to be specified");
        }
        if (this.script.getScriptString() == null) {
            throw new MLContextException("Script string is null");
        }
        if (StringUtils.isBlank(this.script.getScriptString())) {
            throw new MLContextException("Script string is blank");
        }
    }

    /** Converts a possibly-null collection of variable names to an array (empty when null). */
    private static String[] toNameArray(Collection<String> names) {
        return names == null ? new String[0] : names.toArray(new String[0]);
    }

    public DMLProgram getDmlProgram() {
        return this.dmlProgram;
    }

    public DMLTranslator getDmlTranslator() {
        return this.dmlTranslator;
    }

    public Program getRuntimeProgram() {
        return this.runtimeProgram;
    }

    public ExecutionContext getExecutionContext() {
        return this.executionContext;
    }

    public Script getScript() {
        return this.script;
    }

    public void setExplain(boolean explain) {
        this.explain = explain;
    }

    public void setStatistics(boolean statistics) {
        this.statistics = statistics;
    }

    public void setStatisticsMaxHeavyHitters(int maxHeavyHitters) {
        this.statisticsMaxHeavyHitters = maxHeavyHitters;
    }

    public boolean isMaintainSymbolTable() {
        return this.maintainSymbolTable;
    }

    public void setMaintainSymbolTable(boolean maintainSymbolTable) {
        this.maintainSymbolTable = maintainSymbolTable;
    }

    public void setInit(boolean init) {
        this.init = init;
    }

    /**
     * Sets the explain level and immediately updates the global {@code DMLScript.EXPLAIN}.
     *
     * @param explainLevel the explain level; {@code null} means {@code ExplainType.NONE}
     */
    public void setExplainLevel(MLContext.ExplainLevel explainLevel) {
        this.explainLevel = explainLevel;
        DMLScript.EXPLAIN = explainLevel == null ? Explain.ExplainType.NONE : explainLevel.getExplainType();
    }

    public void setGPU(boolean gpu) {
        this.gpu = gpu;
    }

    public void setForceGPU(boolean forceGPU) {
        this.forceGPU = forceGPU;
    }

    public DMLConfig getConfig() {
        return this.config;
    }

    public MLContext.ExecutionType getExecutionType() {
        return this.executionType;
    }

    /**
     * Sets the execution type and immediately updates the global {@code DMLScript.rtplatform}.
     *
     * @param executionType the execution type (must not be null)
     */
    public void setExecutionType(MLContext.ExecutionType executionType) {
        DMLScript.rtplatform = executionType.getRuntimePlatform();
        this.executionType = executionType;
    }
}
