package org.apache.sysml.api.mlcontext;

import java.util.Date;
import java.util.Set;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.apache.sysml.api.ConfigurableAPI;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.mlcontext.convenience.Scripts;
import org.apache.sysml.api.mlcontext.convenience.scripts.Algorithms;
import org.apache.sysml.api.mlcontext.convenience.scripts.Datagen;
import org.apache.sysml.api.mlcontext.convenience.scripts.Nn;
import org.apache.sysml.api.mlcontext.convenience.scripts.Utils;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.IntIdentifier;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.parser.StringIdentifier;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.MLContextProxy;

/**
 * Entry point for configuring and executing SystemML {@link Script}s from a Spark application.
 *
 * <p>Constructing an instance makes it the globally active context (see
 * {@link #getActiveMLContext()}), activates {@link MLContextProxy}, and applies the default and
 * compiler configuration. Execution settings (explain, statistics, GPU, execution type, ...) are
 * copied into a new {@link ScriptExecutor} on every {@link #execute(Script)} call.
 *
 * <p>Note that {@link #setExecutionType(ExecutionType)}, {@link #setStatistics(boolean)} and
 * {@link #setStatisticsMaxHeavyHitters(int)} also write static fields of {@link DMLScript}, so
 * those settings are process-wide rather than per instance.
 */
public class MLContext implements ConfigurableAPI {

    /** Spark session this context was created with; set to {@code null} by {@link #close()}. */
    private SparkSession spark = null;

    /**
     * Script most recently passed to {@link #execute(Script, ScriptExecutor)} or
     * {@link #setExecutionScript(Script)}; consulted by {@link InternalProxy}.
     */
    private Script executionScript = null;

    private final InternalProxy internalProxy = new InternalProxy();
    private boolean explain = false;
    private boolean statistics = false;
    private boolean gpu = false;
    private boolean forceGPU = false;
    private int statisticsMaxHeavyHitters = 10;
    private ExplainLevel explainLevel = null;
    private ExecutionType executionType = ExecutionType.DRIVER_AND_SPARK;
    private boolean maintainSymbolTable = false;

    /** Whether the next {@link #execute(Script)} passes {@code init=true} to its executor. */
    private boolean initBeforeExecution = true;

    protected static Logger log = Logger.getLogger(MLContext.class);

    /** Context most recently initialized by a constructor; cleared by {@link #close()}. */
    private static MLContext activeMLContext = null;

    /** Execution modes, each mapped to a {@link DMLScript.RUNTIME_PLATFORM}. */
    public enum ExecutionType {
        DRIVER,
        SPARK,
        HADOOP,
        DRIVER_AND_SPARK,
        DRIVER_AND_HADOOP;

        /**
         * Returns the runtime platform corresponding to this execution type.
         *
         * @return the matching {@link DMLScript.RUNTIME_PLATFORM}
         */
        public DMLScript.RUNTIME_PLATFORM getRuntimePlatform() {
            switch (this) {
                case DRIVER:
                    return DMLScript.RUNTIME_PLATFORM.SINGLE_NODE;
                case SPARK:
                    return DMLScript.RUNTIME_PLATFORM.SPARK;
                case HADOOP:
                    return DMLScript.RUNTIME_PLATFORM.HADOOP;
                case DRIVER_AND_SPARK:
                    return DMLScript.RUNTIME_PLATFORM.HYBRID_SPARK;
                case DRIVER_AND_HADOOP:
                    return DMLScript.RUNTIME_PLATFORM.HYBRID;
                default:
                    // Unreachable for the constants above; required by javac for a non-void switch.
                    return DMLScript.RUNTIME_PLATFORM.HYBRID_SPARK;
            }
        }
    }

    /** Explain output levels, each mapped to an {@link Explain.ExplainType}. */
    public enum ExplainLevel {
        NONE,
        HOPS,
        RUNTIME,
        RECOMPILE_HOPS,
        RECOMPILE_RUNTIME;

        /**
         * Returns the explain type corresponding to this level.
         *
         * @return the matching {@link Explain.ExplainType}
         */
        public Explain.ExplainType getExplainType() {
            switch (this) {
                case NONE:
                    return Explain.ExplainType.NONE;
                case HOPS:
                    return Explain.ExplainType.HOPS;
                case RUNTIME:
                    return Explain.ExplainType.RUNTIME;
                case RECOMPILE_HOPS:
                    return Explain.ExplainType.RECOMPILE_HOPS;
                case RECOMPILE_RUNTIME:
                    return Explain.ExplainType.RECOMPILE_RUNTIME;
                default:
                    // Unreachable for the constants above; required by javac for a non-void switch.
                    return Explain.ExplainType.HOPS;
            }
        }
    }

    /**
     * Hooks used while compiling a script to adapt {@code read(...)} expressions for variables that
     * were registered as inputs of the current {@link #executionScript}. This is a non-static inner
     * class because it reads the enclosing context's execution script.
     */
    public class InternalProxy {

        public InternalProxy() {
        }

        /**
         * Rewrites a read expression for a registered input so that it uses the metadata of the
         * in-memory input instead of on-disk metadata.
         *
         * <p>Does nothing unless {@code varName} is a registered input of the current script and
         * {@code expression} is a read {@link DataExpression}. In that case metadata checking is
         * disabled. For non-frame inputs backed by a {@link MatrixObject}, the rows, cols, nnz,
         * data-type and value-type parameters are added. When the matrix has
         * {@link MetaDataFormat} metadata, format parameters (and block sizes for binary-block) are
         * added as well.
         *
         * @param expression the expression being compiled
         * @param varName    name of the variable the expression assigns/reads
         * @throws MLContextException if the input is not in the script's symbol table (see
         *                            {@link #getMatrixObject(String)}) or has an output format other
         *                            than CSV, text-cell or binary-block
         */
        public void setAppropriateVarsForRead(Expression expression, String varName) {
            boolean registeredInput = isRegisteredAsInput(varName);
            boolean isReadExpression = (expression instanceof DataExpression)
                    && ((DataExpression) expression).isRead();
            if (!registeredInput || !isReadExpression) {
                return;
            }

            DataExpression dataExpression = (DataExpression) expression;
            dataExpression.setCheckMetadata(false);

            // An absent data_type parameter is treated as a matrix read.
            Expression dataTypeParam = dataExpression.getVarParam(DataExpression.DATATYPEPARAM);
            String dataType = (dataTypeParam == null) ? Statement.MATRIX_DATA_TYPE : dataTypeParam.toString();
            if (dataType.equalsIgnoreCase(Statement.FRAME_DATA_TYPE)) {
                return;
            }

            MatrixObject matrixObject = getMatrixObject(varName);
            if (matrixObject == null) {
                // Registered input is a scalar: there is no matrix metadata to inject.
                return;
            }

            dataExpression.addVarParam("rows", new IntIdentifier(matrixObject.getNumRows(), expression));
            dataExpression.addVarParam("cols", new IntIdentifier(matrixObject.getNumColumns(), expression));
            dataExpression.addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(matrixObject.getNnz(), expression));
            dataExpression.addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier(Statement.MATRIX_DATA_TYPE, expression));
            dataExpression.addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier(Statement.DOUBLE_VALUE_TYPE, expression));

            if (!(matrixObject.getMetaData() instanceof MetaDataFormat)) {
                return;
            }
            OutputInfo outputInfo = ((MetaDataFormat) matrixObject.getMetaData()).getOutputInfo();
            if (outputInfo == OutputInfo.CSVOutputInfo) {
                dataExpression.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, expression));
            } else if (outputInfo == OutputInfo.TextCellOutputInfo) {
                dataExpression.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, expression));
            } else if (outputInfo == OutputInfo.BinaryBlockOutputInfo) {
                dataExpression.addVarParam(DataExpression.ROWBLOCKCOUNTPARAM, new IntIdentifier(matrixObject.getNumRowsPerBlock(), expression));
                dataExpression.addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, new IntIdentifier(matrixObject.getNumColumnsPerBlock(), expression));
                dataExpression.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, expression));
            } else {
                throw new MLContextException("Unsupported format through MLContext");
            }
        }

        /**
         * Returns whether {@code varName} is one of the current script's input variables.
         * Returns {@code false} when no script is set or it reports no inputs.
         */
        private boolean isRegisteredAsInput(String varName) {
            Script script = MLContext.this.executionScript;
            if (script == null) {
                return false;
            }
            Set<String> inputVariables = script.getInputVariables();
            return inputVariables != null && inputVariables.contains(varName);
        }

        /**
         * Looks up {@code varName} in the current script's symbol table.
         *
         * @return the {@link MatrixObject} bound to the name, or {@code null} if it is bound to a
         *         {@link ScalarObject}
         * @throws MLContextException if there is no script or symbol table, or the name is unbound
         *                            or bound to any other kind of data
         */
        private MatrixObject getMatrixObject(String varName) {
            Script script = MLContext.this.executionScript;
            LocalVariableMap symbolTable = (script == null) ? null : script.getSymbolTable();
            if (symbolTable != null) {
                Data data = symbolTable.get(varName);
                if (data instanceof MatrixObject) {
                    return (MatrixObject) data;
                }
                if (data instanceof ScalarObject) {
                    return null;
                }
            }
            throw new MLContextException("getMatrixObject not set for parameter: " + varName);
        }
    }

    /**
     * Returns the context most recently initialized by a constructor.
     *
     * @return the active context, or {@code null} if none was created or {@link #close()} was
     *         called since
     */
    public static MLContext getActiveMLContext() {
        return activeMLContext;
    }

    /**
     * Creates a context bound to the given Spark session.
     *
     * @param sparkSession the Spark session to use
     */
    public MLContext(SparkSession sparkSession) {
        initMLContext(sparkSession);
    }

    /**
     * Creates a context using the Spark session obtained from (or created for) the given
     * {@link SparkContext}.
     *
     * @param sparkContext the Spark context to use
     */
    public MLContext(SparkContext sparkContext) {
        initMLContext(SparkSession.builder().sparkContext(sparkContext).getOrCreate());
    }

    /**
     * Creates a context using the Spark session obtained from (or created for) the Scala
     * {@link SparkContext} underlying the given {@link JavaSparkContext}.
     *
     * @param javaSparkContext the Java Spark context to use
     */
    public MLContext(JavaSparkContext javaSparkContext) {
        initMLContext(SparkSession.builder().sparkContext(javaSparkContext.sc()).getOrCreate());
    }

    /**
     * Shared constructor logic. It warns if the Spark version is not supported, prints the
     * welcome message when no context is active yet, and registers this instance as the active
     * context. It also sets the runtime platform and applies the default and compiler
     * configuration.
     */
    private void initMLContext(SparkSession sparkSession) {
        try {
            MLContextUtil.verifySparkVersionSupported(sparkSession);
        } catch (MLContextException unsupportedVersion) {
            warnRecommendedSparkVersion();
        }
        if (activeMLContext == null) {
            System.out.println(MLContextUtil.welcomeMessage());
        }
        this.spark = sparkSession;
        DMLScript.rtplatform = this.executionType.getRuntimePlatform();
        activeMLContext = this;
        MLContextProxy.setActive(true);
        MLContextUtil.setDefaultConfig();
        MLContextUtil.setCompilerConfig();
    }

    /**
     * Logs the minimum recommended Spark version. The version comes from the project info when
     * available, otherwise from the pom. Logs an error if neither source provides it.
     */
    private void warnRecommendedSparkVersion() {
        // Fetch once: info() reloads the project info on every call.
        ProjectInfo projectInfo = info();
        if (projectInfo != null) {
            log.warn("Apache Spark " + projectInfo.minimumRecommendedSparkVersion() + " or above is recommended for SystemML " + projectInfo.version());
            return;
        }
        try {
            log.warn("Apache Spark " + MLContextUtil.getMinimumRecommendedSparkVersionFromPom() + " or above is recommended for this version of SystemML.");
        } catch (MLContextException notInPom) {
            log.error("Minimum recommended Spark version could not be determined from SystemML jar file manifest or pom.xml");
        }
    }

    /** Restores the default configuration via {@link MLContextUtil#setDefaultConfig()}. */
    @Override // org.apache.sysml.api.ConfigurableAPI
    public void resetConfig() {
        MLContextUtil.setDefaultConfig();
    }

    /**
     * Sets a single property on the current DML configuration.
     *
     * @param property the configuration property name
     * @param value    the value to set, as text
     * @throws MLContextException wrapping any {@link DMLRuntimeException} raised by the config
     */
    @Override // org.apache.sysml.api.ConfigurableAPI
    public void setConfigProperty(String property, String value) {
        try {
            ConfigurationManager.getDMLConfig().setTextValue(property, value);
        } catch (DMLRuntimeException e) {
            throw new MLContextException(e);
        }
    }

    /**
     * Executes a script with a new {@link ScriptExecutor} configured from this context's current
     * settings.
     *
     * <p>{@code init=true} is passed to the executor only on the first execution after
     * construction or after {@link #setInitBeforeExecution(boolean) setInitBeforeExecution(true)};
     * the flag is cleared before execution starts.
     *
     * @param script the script to execute
     * @return the execution results
     * @throws MLContextException if execution fails with a runtime exception
     */
    public MLResults execute(Script script) {
        ScriptExecutor scriptExecutor = new ScriptExecutor();
        scriptExecutor.setExecutionType(this.executionType);
        scriptExecutor.setExplain(this.explain);
        scriptExecutor.setExplainLevel(this.explainLevel);
        scriptExecutor.setGPU(this.gpu);
        scriptExecutor.setForceGPU(this.forceGPU);
        scriptExecutor.setStatistics(this.statistics);
        scriptExecutor.setStatisticsMaxHeavyHitters(this.statisticsMaxHeavyHitters);
        scriptExecutor.setInit(this.initBeforeExecution);
        this.initBeforeExecution = false;
        scriptExecutor.setMaintainSymbolTable(this.maintainSymbolTable);
        return execute(script, scriptExecutor);
    }

    /**
     * Executes a script with the supplied executor and records it as the current execution script.
     * A script with a {@code null} or empty name is named after the current time in milliseconds.
     *
     * @param script         the script to execute
     * @param scriptExecutor the executor to use
     * @return the execution results
     * @throws MLContextException wrapping any runtime exception raised during execution
     */
    public MLResults execute(Script script, ScriptExecutor scriptExecutor) {
        try {
            this.executionScript = script;
            String name = script.getName();
            if (name == null || name.isEmpty()) {
                script.setName(Long.toString(System.currentTimeMillis()));
            }
            return scriptExecutor.execute(script);
        } catch (RuntimeException e) {
            throw new MLContextException("Exception when executing script", e);
        }
    }

    /**
     * Sets the script consulted by {@link InternalProxy}.
     *
     * @param script the script to record
     */
    public void setExecutionScript(Script script) {
        this.executionScript = script;
    }

    /**
     * Applies a configuration via {@link MLContextUtil#setConfig(String)}.
     *
     * @param config configuration argument forwarded to {@code MLContextUtil.setConfig}
     *               (NOTE(review): presumably a config file path — confirm)
     */
    public void setConfig(String config) {
        MLContextUtil.setConfig(config);
    }

    /** @return the Spark session, or {@code null} after {@link #close()} */
    public SparkSession getSparkSession() {
        return this.spark;
    }

    /** @return whether explain output is enabled for subsequent executions */
    public boolean isExplain() {
        return this.explain;
    }

    /** @param explain whether to enable explain output for subsequent executions */
    public void setExplain(boolean explain) {
        this.explain = explain;
    }

    /** @return whether the executor is asked to maintain the symbol table */
    public boolean isMaintainSymbolTable() {
        return this.maintainSymbolTable;
    }

    /** @param maintainSymbolTable whether the executor should maintain the symbol table */
    public void setMaintainSymbolTable(boolean maintainSymbolTable) {
        this.maintainSymbolTable = maintainSymbolTable;
    }

    /** @param explainLevel the explain level passed to subsequent executors (may be {@code null}) */
    public void setExplainLevel(ExplainLevel explainLevel) {
        this.explainLevel = explainLevel;
    }

    /**
     * Sets the explain level by its case-insensitive name.
     *
     * @param explainLevel one of {@code none}, {@code hops}, {@code runtime},
     *                     {@code recompile_hops}, {@code recompile_runtime}
     * @throws MLContextException if the name is {@code null} or unrecognized
     */
    public void setExplainLevel(String explainLevel) {
        if (explainLevel != null) {
            for (ExplainLevel candidate : ExplainLevel.values()) {
                if (candidate.toString().equalsIgnoreCase(explainLevel)) {
                    setExplainLevel(candidate);
                    return;
                }
            }
        }
        throw new MLContextException("Failed to parse explain level: " + explainLevel + " (valid types: none, hops, runtime, recompile_hops, recompile_runtime).");
    }

    /** @param enable whether GPU use is enabled for subsequent executions */
    public void setGPU(boolean enable) {
        this.gpu = enable;
    }

    /** @param enable whether GPU use is forced for subsequent executions */
    public void setForceGPU(boolean enable) {
        this.forceGPU = enable;
    }

    /** @return whether GPU use is enabled */
    public boolean isGPU() {
        return this.gpu;
    }

    /** @return whether GPU use is forced */
    public boolean isForceGPU() {
        return this.forceGPU;
    }

    /** @return the compile-time hook object bound to this context */
    public InternalProxy getInternalProxy() {
        return this.internalProxy;
    }

    /** @return whether statistics output is enabled */
    public boolean isStatistics() {
        return this.statistics;
    }

    /**
     * Enables or disables statistics. Also writes the static {@code DMLScript.STATISTICS} flag.
     *
     * @param statistics whether statistics are enabled
     */
    public void setStatistics(boolean statistics) {
        DMLScript.STATISTICS = statistics;
        this.statistics = statistics;
    }

    /**
     * Sets how many heavy hitters statistics report. Also writes the static
     * {@code DMLScript.STATISTICS_COUNT}.
     *
     * @param maxHeavyHitters number of heavy hitters to report
     */
    public void setStatisticsMaxHeavyHitters(int maxHeavyHitters) {
        DMLScript.STATISTICS_COUNT = maxHeavyHitters;
        this.statisticsMaxHeavyHitters = maxHeavyHitters;
    }

    /**
     * Releases the global state held by this context and cleans up working directories.
     *
     * <p>This method performs the following steps:
     * <ol>
     *   <li>Resets the static Spark context state in {@link SparkExecutionContext}, deactivates
     *       {@link MLContextProxy}, and clears {@link #getActiveMLContext()}.</li>
     *   <li>Runs {@code DMLScript.cleanupHadoopExecution}.</li>
     *   <li>Clears the last execution script and restores the default configuration.</li>
     *   <li>Drops the Spark session reference. The Spark session itself is not stopped here.</li>
     * </ol>
     *
     * <p>NOTE(review): the static resets happen even when this instance is not the active context,
     * which would also deactivate a newer MLContext — confirm whether that is intended.
     *
     * @throws MLContextException if any cleanup step fails
     */
    public void close() {
        SparkExecutionContext.resetSparkContextStatic();
        MLContextProxy.setActive(false);
        activeMLContext = null;
        try {
            DMLScript.cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
            if (this.executionScript != null) {
                this.executionScript.clearAll();
            }
            resetConfig();
            this.spark = null;
        } catch (Exception e) {
            throw new MLContextException("Failed to cleanup working directories.", e);
        }
    }

    /**
     * Loads the project information.
     *
     * @return the project info, or {@code null} (after logging a warning) if it cannot be loaded
     */
    public ProjectInfo info() {
        try {
            return ProjectInfo.getProjectInfo();
        } catch (Exception e) {
            log.warn("Project information not available");
            return null;
        }
    }

    /** @return the project version, or {@code MLContextUtil.VERSION_NOT_AVAILABLE} if unknown */
    public String version() {
        ProjectInfo projectInfo = info();
        return projectInfo == null ? MLContextUtil.VERSION_NOT_AVAILABLE : projectInfo.version();
    }

    /** @return the project build time, or {@code MLContextUtil.BUILD_TIME_NOT_AVAILABLE} if unknown */
    public String buildTime() {
        ProjectInfo projectInfo = info();
        return projectInfo == null ? MLContextUtil.BUILD_TIME_NOT_AVAILABLE : projectInfo.buildTime();
    }

    /** @return the number of heavy hitters statistics report */
    public int getStatisticsMaxHeavyHitters() {
        return this.statisticsMaxHeavyHitters;
    }

    /** @return whether the next {@link #execute(Script)} will request initialization */
    public boolean isInitBeforeExecution() {
        return this.initBeforeExecution;
    }

    /** @param initBeforeExecution whether the next {@link #execute(Script)} requests initialization */
    public void setInitBeforeExecution(boolean initBeforeExecution) {
        this.initBeforeExecution = initBeforeExecution;
    }

    /** @return the execution type used for subsequent executions */
    public ExecutionType getExecutionType() {
        return this.executionType;
    }

    /**
     * Sets the execution type. Also writes the static {@code DMLScript.rtplatform}.
     *
     * @param executionType the execution type; must not be {@code null}
     */
    public void setExecutionType(ExecutionType executionType) {
        DMLScript.rtplatform = executionType.getRuntimePlatform();
        this.executionType = executionType;
    }

    /** @return a new {@link Scripts} instance */
    public Scripts scripts() {
        return new Scripts();
    }

    /** @return a new {@link Nn} instance */
    public Nn nn() {
        return new Nn();
    }

    /** @return a new {@link Algorithms} instance */
    public Algorithms algorithms() {
        return new Algorithms();
    }

    /** @return a new {@link Utils} instance */
    public Utils utils() {
        return new Utils();
    }

    /** @return a new {@link Datagen} instance */
    public Datagen datagen() {
        return new Datagen();
    }
}
