package org.apache.sysml.runtime.controlprogram.context;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.package$;
import org.apache.spark.storage.RDDInfo;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLContextUtil;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.Checkpoint;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.compress.CompressedMatrixBlock;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.spark.data.BroadcastObject;
import org.apache.sysml.runtime.instructions.spark.data.LineageObject;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBlock;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
import org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockNnzFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyFrameBlockPairFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CreateSparseBlockFunction;
import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.data.SparseBlock;
import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration;
import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.MLContextProxy;
import org.apache.sysml.utils.Statistics;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.class */
public class SparkExecutionContext extends ExecutionContext {
    // Local debug switch; disabled in this build.
    private static final boolean LDEBUG = false;
    // Class-wide logger (also used by the nested SparkClusterConfig).
    private static final Log LOG = LogFactory.getLog(SparkExecutionContext.class.getName());
    // When true, the constructor defers Spark context creation (unless the runtime platform
    // is SPARK) and getSparkContext() triggers the initialization on demand.
    private static boolean LAZY_SPARKCTX_CREATION = true;
    // NOTE(review): not referenced in the visible part of this file; presumably controls
    // asynchronous destruction of broadcast variables - confirm against the cleanup code.
    private static boolean ASYNCHRONOUS_VAR_DESTROY = true;
    // When true, createSystemMLSparkConf() sets spark.scheduler.mode=FAIR.
    public static boolean FAIR_SCHEDULER_MODE = true;
    // NOTE(review): not referenced in the visible part of this file; presumably a lazily
    // created, shared SparkClusterConfig - confirm against its accessor.
    private static SparkClusterConfig _sconf = null;
    // The single, process-wide Spark context handle (null until initSparkContext() ran).
    private static JavaSparkContext _spctx = null;
    // Bookkeeping for parallelized RDDs, limited to 10% of the local max memory.
    private static MemoryManagerParRDDs _parRDDs = new MemoryManagerParRDDs(0.1d);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext$MemoryManagerParRDDs.class */
    /**
     * Tracks the estimated driver-side memory held by parallelized RDDs and enforces an
     * upper limit on their total size. All methods are synchronized on this instance.
     */
    public static class MemoryManagerParRDDs {
        /** Maximum total number of bytes that may be reserved. */
        private final long _limit;
        /** Number of bytes currently reserved. */
        private long _size = 0;
        /** Registered RDD ids mapped to their reserved size in bytes. */
        private final HashMap<Integer, Long> _rdds = new HashMap<>();

        /**
         * @param d fraction of the local max memory (as reported by
         *          {@code InfrastructureAnalyzer.getLocalMaxMemory()}) usable for reservations
         */
        public MemoryManagerParRDDs(double d) {
            this._limit = (long) (d * InfrastructureAnalyzer.getLocalMaxMemory());
        }

        /**
         * Tries to reserve the given number of bytes.
         *
         * @param j requested size in bytes
         * @return true if the reservation fit below the limit and was recorded, false otherwise
         */
        public synchronized boolean reserve(long j) {
            boolean z = j + this._size < this._limit;
            this._size += z ? j : 0L;
            return z;
        }

        /**
         * Registers an RDD whose size has already been reserved via {@link #reserve(long)}.
         *
         * @param i RDD id
         * @param j reserved size in bytes
         * @param z whether the size was reserved beforehand; must be true
         * @throws RuntimeException if {@code z} is false
         */
        public synchronized void registerRDD(int i, long j, boolean z) {
            if (!z) {
                throw new RuntimeException("Unsupported rdd registration without size reservation for " + j + " bytes.");
            }
            this._rdds.put(Integer.valueOf(i), Long.valueOf(j));
        }

        /**
         * Removes the given RDD and releases its reserved size. Unknown ids are ignored.
         *
         * @param i RDD id
         */
        public synchronized void deregisterRDD(int i) {
            // remove() yields null for ids that are not (or no longer) registered, e.g. after
            // clear(); unboxing that null would raise a NullPointerException
            Long reserved = this._rdds.remove(Integer.valueOf(i));
            if (reserved != null) {
                this._size -= reserved.longValue();
            }
        }

        /** Drops all registrations and resets the reserved size to zero. */
        public synchronized void clear() {
            this._size = 0L;
            this._rdds.clear();
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext$SparkClusterConfig.class */
    /**
     * Snapshot of the Spark cluster configuration (executor memory, memory fractions,
     * number of executors, default parallelism) derived from the Spark configuration and,
     * where the configuration is insufficient, from the live Spark context.
     */
    public static class SparkClusterConfig {
        /** Share of the max data fraction that is budgeted for broadcasts. */
        private static final double BROADCAST_DATA_FRACTION = 0.35d;
        /** Bytes subtracted from the executor memory in the non-legacy analysis. */
        private static final long RESERVED_SYSTEM_MEMORY_BYTES = 314572800;
        private boolean _legacyVersion;
        private boolean _confOnly;
        private long _memExecutor = -1;
        private double _memDataMinFrac = -1.0d;
        private double _memDataMaxFrac = -1.0d;
        private double _memBroadcastFrac = -1.0d;
        private int _numExecutors = -1;
        private int _defaultPar = -1;

        /**
         * Analyzes the SystemML Spark configuration, choosing the legacy analysis for Spark
         * versions before 1.6.0 or if spark.memory.useLegacyMode is set.
         */
        public SparkClusterConfig() {
            SparkConf conf = SparkExecutionContext.createSystemMLSparkConf();
            // start optimistic; the parallelism analysis clears this flag if it has to
            // consult the live Spark context
            this._confOnly = true;
            boolean preUnifiedMemory = UtilFunctions.compareVersion(getSparkVersionString(), "1.6.0") < 0;
            this._legacyVersion = preUnifiedMemory || conf.getBoolean("spark.memory.useLegacyMode", false);
            if (!this._legacyVersion) {
                analyzeSparkConfiguation(conf);
            } else {
                analyzeSparkConfiguationLegacy(conf);
            }
            if (SparkExecutionContext.LOG.isDebugEnabled()) {
                SparkExecutionContext.LOG.debug(toString());
            }
        }

        /**
         * @return executor memory multiplied by the broadcast fraction, in bytes
         */
        public long getBroadcastMemoryBudget() {
            double budget = this._memExecutor * this._memBroadcastFrac;
            return (long) budget;
        }

        /**
         * Computes the aggregate data memory budget over all executors.
         *
         * @param z  if true use the minimum data fraction, otherwise the maximum data fraction
         * @param z2 if true (and the configuration was not sufficient on its own), refresh the
         *           number of executors from the Spark context; a refresh also happens
         *           whenever the Spark context already exists
         * @return data memory budget in bytes
         */
        public long getDataMemoryBudget(boolean z, boolean z2) {
            int executors = this._numExecutors;
            boolean refresh = (z2 && !this._confOnly) || SparkExecutionContext.isSparkContextCreated();
            if (refresh) {
                int statusEntries = SparkExecutionContext.getSparkContextStatic().sc().getExecutorMemoryStatus().size();
                executors = Math.max(statusEntries - 1, 1);
            }
            double fraction = z ? this._memDataMinFrac : this._memDataMaxFrac;
            return (long) (executors * this._memExecutor * fraction);
        }

        /**
         * @return the number of executors, analyzing the parallelism configuration first if
         *         it is not yet known
         */
        public int getNumExecutors() {
            if (this._numExecutors < 0) {
                analyzeSparkParallelismConfiguation(null);
            }
            return this._numExecutors;
        }

        /**
         * @param z if true, prefer the default parallelism of the Spark context unless the
         *          configuration alone was sufficient and no context exists yet
         * @return the default parallelism, at least 1
         */
        public int getDefaultParallelism(boolean z) {
            if (!z && this._defaultPar < 0) {
                analyzeSparkParallelismConfiguation(null);
            }
            boolean useConfigured = (!z || this._confOnly) && !SparkExecutionContext.isSparkContextCreated();
            int parallelism;
            if (useConfigured) {
                parallelism = this._defaultPar;
            } else {
                parallelism = SparkExecutionContext.getSparkContextStatic().defaultParallelism().intValue();
            }
            return Math.max(parallelism, 1);
        }

        /**
         * Analysis for the legacy memory configuration: a single storage fraction
         * (spark.storage.memoryFraction) serves as both minimum and maximum data fraction.
         *
         * @param sparkConf configuration to analyze; if null a fresh SystemML Spark
         *                  configuration is created
         */
        public void analyzeSparkConfiguationLegacy(SparkConf sparkConf) {
            SparkConf conf = sparkConf;
            if (conf == null) {
                conf = SparkExecutionContext.createSystemMLSparkConf();
            }
            this._memExecutor = UtilFunctions.parseMemorySize(conf.get("spark.executor.memory", "1g"));
            double storageFraction = conf.getDouble("spark.storage.memoryFraction", 0.6d);
            this._memDataMinFrac = storageFraction;
            this._memDataMaxFrac = storageFraction;
            this._memBroadcastFrac = storageFraction * BROADCAST_DATA_FRACTION;
            analyzeSparkParallelismConfiguation(conf);
        }

        /**
         * Analysis for the non-legacy memory configuration: executor memory minus the
         * reserved bytes, with spark.memory.storageFraction as minimum and
         * spark.memory.fraction as maximum data fraction.
         *
         * @param sparkConf configuration to analyze; if null a fresh SystemML Spark
         *                  configuration is created
         */
        public void analyzeSparkConfiguation(SparkConf sparkConf) {
            SparkConf conf = sparkConf;
            if (conf == null) {
                conf = SparkExecutionContext.createSystemMLSparkConf();
            }
            long configuredMemory = UtilFunctions.parseMemorySize(conf.get("spark.executor.memory", "1g"));
            this._memExecutor = configuredMemory - RESERVED_SYSTEM_MEMORY_BYTES;
            this._memDataMinFrac = conf.getDouble("spark.memory.storageFraction", 0.5d);
            this._memDataMaxFrac = conf.getDouble("spark.memory.fraction", 0.6d);
            this._memBroadcastFrac = this._memDataMaxFrac * BROADCAST_DATA_FRACTION;
            analyzeSparkParallelismConfiguation(conf);
        }

        /**
         * Determines number of executors and default parallelism, from the configuration if
         * it specifies enough, otherwise from the (possibly newly created) Spark context.
         */
        private void analyzeSparkParallelismConfiguation(SparkConf sparkConf) {
            SparkConf conf = sparkConf;
            if (conf == null) {
                conf = SparkExecutionContext.createSystemMLSparkConf();
            }
            int instances = conf.getInt("spark.executor.instances", -1);
            int cores = conf.getInt("spark.executor.cores", -1);
            int parallelism = conf.getInt("spark.default.parallelism", -1);
            boolean sufficientConf = instances > 1 && (parallelism > 1 || cores > 1);
            if (!sufficientConf) {
                // fall back to the live context; this may create the Spark context
                JavaSparkContext sc = SparkExecutionContext.getSparkContextStatic();
                this._numExecutors = Math.max(sc.sc().getExecutorMemoryStatus().size() - 1, 1);
                this._defaultPar = sc.defaultParallelism().intValue();
                this._confOnly = false;
                return;
            }
            // configuration alone is sufficient; _confOnly keeps its current value
            this._numExecutors = instances;
            this._defaultPar = parallelism > 1 ? parallelism : instances * cores;
        }

        /** Version of the running Spark context if one exists, otherwise the library version. */
        private static String getSparkVersionString() {
            if (SparkExecutionContext.isSparkContextCreated()) {
                return SparkExecutionContext.getSparkContextStatic().version();
            }
            return package$.MODULE$.SPARK_VERSION();
        }

        @Override
        public String toString() {
            return new StringBuilder("SparkClusterConfig: \n")
                .append("-- legacyVersion    = ").append(this._legacyVersion)
                .append(" (").append(getSparkVersionString()).append(")\n")
                .append("-- confOnly         = ").append(this._confOnly).append("\n")
                .append("-- numExecutors     = ").append(this._numExecutors).append("\n")
                .append("-- defaultPar       = ").append(this._defaultPar).append("\n")
                .append("-- memExecutor      = ").append(this._memExecutor).append("\n")
                .append("-- memDataMinFrac   = ").append(this._memDataMinFrac).append("\n")
                .append("-- memDataMaxFrac   = ").append(this._memDataMaxFrac).append("\n")
                .append("-- memBroadcastFrac = ").append(this._memBroadcastFrac).append("\n")
                .toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public SparkExecutionContext(boolean z, Program program) {
        super(z, program);
        if (!LAZY_SPARKCTX_CREATION || DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SPARK) {
            initSparkContext();
        }
    }

    /**
     * Returns the shared Spark context; with lazy creation enabled the context is
     * initialized on demand first.
     *
     * @return the shared Spark context
     */
    public JavaSparkContext getSparkContext() {
        if (!LAZY_SPARKCTX_CREATION) {
            return _spctx;
        }
        initSparkContext();
        return _spctx;
    }

    /**
     * Static accessor for the shared Spark context; always ensures the context is
     * initialized before returning it.
     *
     * @return the shared Spark context
     */
    public static JavaSparkContext getSparkContextStatic() {
        // no-op if the context already exists
        initSparkContext();

        return _spctx;
    }

    /**
     * @return true if the shared Spark context handle is currently set
     */
    public static synchronized boolean isSparkContextCreated() {
        boolean created = (_spctx != null);
        return created;
    }

    /**
     * Drops the reference to the shared Spark context without stopping it.
     *
     * Synchronized on the class, like initSparkContext(), isSparkContextCreated() and
     * close(), so that the reset is visible to and ordered with those operations.
     */
    public static synchronized void resetSparkContextStatic() {
        _spctx = null;
    }

    /**
     * Stops the shared Spark context (if any) and clears the handle; guarded by the
     * class monitor.
     */
    public void close() {
        synchronized (SparkExecutionContext.class) {
            if (_spctx == null) {
                return;
            }
            _spctx.stop();
            _spctx = null;
        }
    }

    /**
     * @return the value of the lazy Spark context creation flag
     */
    public static boolean isLazySparkContextCreation() {
        final boolean lazy = LAZY_SPARKCTX_CREATION;
        return lazy;
    }

    /**
     * Initializes the shared Spark context if it does not exist yet. The context is taken
     * from the active MLContext if there is one; otherwise a new context is created (with
     * a local master if DMLScript.USE_LOCAL_SPARK_CONFIG is set) and the parallelized-RDD
     * bookkeeping is cleared. Afterwards the driver max result size is checked, the binary
     * block serialization framework is added to the Hadoop configuration, and (with
     * statistics enabled) the creation time is recorded.
     */
    private static synchronized void initSparkContext() {
        // nothing to do if a context already exists
        if (_spctx != null) {
            return;
        }
        long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;

        MLContext mlContext = MLContextProxy.getActiveMLContext();
        if (mlContext == null) {
            SparkConf conf;
            if (!DMLScript.USE_LOCAL_SPARK_CONFIG) {
                conf = createSystemMLSparkConf();
            } else {
                conf = createSystemMLSparkConf().setMaster("local[*]").setAppName("My local integration test app");
                // NOTE(review): the referenced constant presumably holds "false" (constant
                // re-resolved by the decompiler) - confirm before relying on it
                conf.set("spark.ui.enabled", ExternalFunctionStatement.DEFAULT_SIDE_EFFECTS);
            }
            _spctx = new JavaSparkContext(conf);
            // a fresh context invalidates all previously parallelized RDDs
            _parRDDs.clear();
        } else {
            // reuse the context owned by the active MLContext
            _spctx = MLContextUtil.getJavaSparkContext(mlContext);
        }

        // warn if collected results are limited below the local memory budget
        long maxResultSize = UtilFunctions.parseMemorySize(_spctx.getConf().get("spark.driver.maxResultSize", "1g"));
        if (maxResultSize != 0 && maxResultSize < OptimizerUtils.getLocalMemBudget() && !DMLScript.USE_LOCAL_SPARK_CONFIG) {
            String configured = UtilFunctions.formatMemorySize(maxResultSize);
            String budget = UtilFunctions.formatMemorySize((long) OptimizerUtils.getLocalMemBudget());
            LOG.warn("Configuration parameter spark.driver.maxResultSize set to " + configured + ". You can set it through Spark default configuration setting either to 0 (unlimited) or to available memory budget of size " + budget + Path.CUR_DIR);
        }

        MRJobConfiguration.addBinaryBlockSerializationFramework(_spctx.hadoopConfiguration());

        if (DMLScript.STATISTICS) {
            Statistics.setSparkCtxCreateTime(System.nanoTime() - startTime);
        }
    }

    /**
     * Creates a Spark configuration with the SystemML defaults: unlimited driver result
     * size, FAIR scheduling if FAIR_SCHEDULER_MODE is set, and - unless already present -
     * a locality wait of 5s and a maximum RPC message / akka frame size of 512.
     *
     * @return the new Spark configuration
     */
    public static SparkConf createSystemMLSparkConf() {
        SparkConf conf = new SparkConf();

        // always unlimited result size
        conf.set("spark.driver.maxResultSize", "0");

        if (FAIR_SCHEDULER_MODE) {
            conf.set("spark.scheduler.mode", "FAIR");
        }

        // only applied if not specified externally
        if (!conf.contains("spark.locality.wait")) {
            conf.set("spark.locality.wait", "5s");
        }

        // the configuration key depends on the Spark library version
        boolean beforeSpark2 = UtilFunctions.compareVersion(package$.MODULE$.SPARK_VERSION(), "2.0.0") < 0;
        String frameSizeKey;
        if (beforeSpark2) {
            frameSizeKey = "spark.akka.frameSize";
        } else {
            frameSizeKey = "spark.rpc.message.maxSize";
        }
        if (!conf.contains(frameSizeKey)) {
            conf.set(frameSizeKey, "512");
        }
        return conf;
    }

    /**
     * @return true if the shared Spark context (created on demand) reports a local master
     */
    public static boolean isLocalMaster() {
        JavaSparkContext sc = getSparkContextStatic();
        return sc.isLocal().booleanValue();
    }

    /**
     * Returns the binary-block RDD handle of the given variable.
     *
     * @param str variable name
     * @return RDD of matrix indexes and matrix blocks
     * @throws DMLRuntimeException if the RDD handle cannot be obtained
     */
    @SuppressWarnings("unchecked")
    public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlockRDDHandleForVariable(String str) throws DMLRuntimeException {
        // getRDDHandleForVariable returns a wildcard-typed RDD; the explicit (unchecked)
        // cast is required for the narrower return type
        return (JavaPairRDD<MatrixIndexes, MatrixBlock>) getRDDHandleForVariable(str, InputInfo.BinaryBlockInputInfo);
    }

    /**
     * Returns the binary-block RDD handle of the given frame variable.
     *
     * @param str variable name
     * @return RDD of row offsets and frame blocks
     * @throws DMLRuntimeException if the RDD handle cannot be obtained
     */
    @SuppressWarnings("unchecked")
    public JavaPairRDD<Long, FrameBlock> getFrameBinaryBlockRDDHandleForVariable(String str) throws DMLRuntimeException {
        // getRDDHandleForVariable returns a wildcard-typed RDD; the explicit (unchecked)
        // cast is required for the narrower return type
        return (JavaPairRDD<Long, FrameBlock>) getRDDHandleForVariable(str, InputInfo.BinaryBlockInputInfo);
    }

    /**
     * Obtains the RDD handle of a matrix or frame variable in the given input format.
     *
     * @param str       variable name
     * @param inputInfo requested input format
     * @return the RDD of the variable
     * @throws DMLRuntimeException if the variable is neither a matrix nor a frame, or if
     *                             the delegate fails
     */
    public JavaPairRDD<?, ?> getRDDHandleForVariable(String str, InputInfo inputInfo) throws DMLRuntimeException {
        Data data = getVariable(str);
        boolean isMatrix = data instanceof MatrixObject;
        if (!isMatrix && !(data instanceof FrameObject)) {
            throw new DMLRuntimeException("Failed to obtain RDD for data type other than matrix or frame.");
        }
        if (isMatrix) {
            return getRDDHandleForMatrixObject(getMatrixObject(str), inputInfo);
        }
        return getRDDHandleForFrameObject(getFrameObject(str), inputInfo);
    }

    /**
     * Obtains (and caches on the matrix object) an RDD for the given matrix.
     *
     * Three cases are distinguished: (1) an existing RDD handle is reused if it is a
     * checkpoint RDD or the matrix is not cached; (2) a dirty or cached matrix is either
     * parallelized from memory (if it passes the collect memory budget check and its size
     * can be reserved) or exported if necessary and read from its file; (3) otherwise the
     * matrix is read from its file according to the requested input format.
     *
     * @param matrixObject the matrix
     * @param inputInfo    requested input format
     * @return the RDD of the matrix
     * @throws DMLRuntimeException on an unsupported input format or if a delegate fails
     */
    public JavaPairRDD<?, ?> getRDDHandleForMatrixObject(MatrixObject matrixObject, InputInfo inputInfo) throws DMLRuntimeException {
        JavaSparkContext sc = getSparkContext();

        // case 1: reuse the existing rdd handle
        if (matrixObject.getRDDHandle() != null && (matrixObject.getRDDHandle().isCheckpointRDD() || !matrixObject.isCached(false))) {
            return matrixObject.getRDDHandle().getRDD();
        }

        JavaPairRDD<?, ?> rdd;
        RDDObject handle;
        if (matrixObject.isDirty() || matrixObject.isCached(false)) {
            // case 2: in-memory data, parallelize if small enough, otherwise go through the file
            MatrixCharacteristics mc = matrixObject.getMatrixCharacteristics();
            boolean fromFile;
            if (OptimizerUtils.checkSparkCollectMemoryBudget(mc, 0L) && _parRDDs.reserve(OptimizerUtils.estimatePartitionedSizeExactSparsity(mc))) {
                rdd = toMatrixJavaPairRDD(sc, matrixObject.acquireRead(), (int) matrixObject.getNumRowsPerBlock(), (int) matrixObject.getNumColumnsPerBlock());
                matrixObject.release();
                _parRDDs.registerRDD(rdd.id(), OptimizerUtils.estimatePartitionedSizeExactSparsity(mc), true);
                fromFile = false;
            } else {
                // write the data only if the file is missing or outdated
                if (matrixObject.isDirty() || !matrixObject.isHDFSFileExists()) {
                    matrixObject.exportData();
                }
                rdd = SparkUtils.copyBinaryBlockMatrix(sc.hadoopFile(matrixObject.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass));
                fromFile = true;
            }
            handle = new RDDObject(rdd);
            handle.setHDFSFile(fromFile);
            handle.setParallelizedRDD(!fromFile);
        } else {
            // case 3: read directly from the file in the requested format
            if (inputInfo == InputInfo.BinaryBlockInputInfo) {
                rdd = SparkUtils.copyBinaryBlockMatrix(sc.hadoopFile(matrixObject.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass));
            } else if (inputInfo == InputInfo.TextCellInputInfo || inputInfo == InputInfo.CSVInputInfo || inputInfo == InputInfo.MatrixMarketInputInfo) {
                rdd = sc.hadoopFile(matrixObject.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass).mapToPair(new CopyTextInputFunction());
            } else if (inputInfo == InputInfo.BinaryCellInputInfo) {
                rdd = sc.hadoopFile(matrixObject.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass).mapToPair(new CopyBinaryCellFunction());
            } else {
                throw new DMLRuntimeException("Incorrect input format in getRDDHandleForVariable");
            }
            handle = new RDDObject(rdd);
            handle.setHDFSFile(true);
        }

        // keep the handle to avoid repeated reads / parallelization
        matrixObject.setRDDHandle(handle);
        return rdd;
    }

    /**
     * Obtains (and caches on the frame object) an RDD for the given frame.
     *
     * Three cases are distinguished: (1) an existing RDD handle is reused if it is a
     * checkpoint RDD or the frame is not cached; (2) a dirty or cached frame is either
     * parallelized from memory (if it passes the collect memory budget check and its size
     * can be reserved) or exported if dirty and read from its file; (3) otherwise the frame
     * is read from its file according to the requested input format.
     *
     * @param frameObject the frame
     * @param inputInfo   requested input format; the binary block format is mapped to the
     *                    binary block frame format
     * @return the RDD of the frame
     * @throws DMLRuntimeException on an unsupported input format or if a delegate fails
     */
    public JavaPairRDD<?, ?> getRDDHandleForFrameObject(FrameObject frameObject, InputInfo inputInfo) throws DMLRuntimeException {
        JavaPairRDD<?, ?> mapToPair;
        InputInfo inputInfo2 = inputInfo == InputInfo.BinaryBlockInputInfo ? InputInfo.BinaryBlockFrameInputInfo : inputInfo;
        JavaSparkContext sparkContext = getSparkContext();
        if (frameObject.getRDDHandle() != null && (frameObject.getRDDHandle().isCheckpointRDD() || !frameObject.isCached(false))) {
            // case 1: reuse the existing rdd handle
            mapToPair = frameObject.getRDDHandle().getRDD();
        } else if (frameObject.isDirty() || frameObject.isCached(false)) {
            // case 2: in-memory data, parallelize if small enough, otherwise go through the file
            MatrixCharacteristics matrixCharacteristics = frameObject.getMatrixCharacteristics();
            boolean fromFile = false;
            if (OptimizerUtils.checkSparkCollectMemoryBudget(matrixCharacteristics, 0L) && _parRDDs.reserve(OptimizerUtils.estimatePartitionedSizeExactSparsity(matrixCharacteristics))) {
                mapToPair = toFrameJavaPairRDD(sparkContext, frameObject.acquireRead());
                frameObject.release();
                _parRDDs.registerRDD(mapToPair.id(), OptimizerUtils.estimatePartitionedSizeExactSparsity(matrixCharacteristics), true);
            } else {
                if (frameObject.isDirty()) {
                    frameObject.exportData();
                }
                mapToPair = sparkContext.hadoopFile(frameObject.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass).mapToPair(new CopyFrameBlockPairFunction());
                fromFile = true;
            }
            RDDObject rDDObject = new RDDObject(mapToPair);
            rDDObject.setHDFSFile(fromFile);
            // mark rdds that were registered with the parallelized-RDD bookkeeping, exactly
            // as done for matrices; without the flag the handle does not reveal that a
            // memory reservation is attached to it
            rDDObject.setParallelizedRDD(!fromFile);
            frameObject.setRDDHandle(rDDObject);
        } else {
            // case 3: read directly from the file in the requested format
            if (inputInfo2 == InputInfo.BinaryBlockFrameInputInfo) {
                mapToPair = sparkContext.hadoopFile(frameObject.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass).mapToPair(new CopyFrameBlockPairFunction());
            } else {
                if (inputInfo2 != InputInfo.TextCellInputInfo && inputInfo2 != InputInfo.CSVInputInfo && inputInfo2 != InputInfo.MatrixMarketInputInfo) {
                    if (inputInfo2 == InputInfo.BinaryCellInputInfo) {
                        throw new DMLRuntimeException("Binarycell not supported for frames.");
                    }
                    throw new DMLRuntimeException("Incorrect input format in getRDDHandleForVariable");
                }
                mapToPair = sparkContext.hadoopFile(frameObject.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass).mapToPair(new CopyTextInputFunction());
            }
            RDDObject rDDObject2 = new RDDObject(mapToPair);
            rDDObject2.setHDFSFile(true);
            frameObject.setRDDHandle(rDDObject2);
        }
        return mapToPair;
    }

    /**
     * Returns a partitioned broadcast for the given matrix variable. A valid existing
     * broadcast handle is reused; otherwise the matrix is read, split into partitions of
     * blocks, each partition is broadcast separately, and the new handle is attached to
     * the matrix object. The global broadcast size accounting is updated accordingly.
     *
     * @param str variable name
     * @return the partitioned broadcast of the matrix
     * @throws DMLRuntimeException if the matrix cannot be obtained or read
     */
    public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable(String str) throws DMLRuntimeException {
        long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        MatrixObject matrixObject = getMatrixObject(str);
        PartitionedBroadcast<MatrixBlock> partitionedBroadcast = null;
        // reuse an existing broadcast handle if it is still valid
        if (matrixObject.getBroadcastHandle() != null && matrixObject.getBroadcastHandle().isValid()) {
            partitionedBroadcast = matrixObject.getBroadcastHandle().getBroadcast();
        }
        if (partitionedBroadcast == null) {
            // an invalid old handle no longer counts towards the broadcast size
            if (matrixObject.getBroadcastHandle() != null) {
                CacheableData.addBroadcastSize(-matrixObject.getBroadcastHandle().getSize());
            }
            int numRowsPerBlock = (int) matrixObject.getNumRowsPerBlock();
            int numColumnsPerBlock = (int) matrixObject.getNumColumnsPerBlock();
            PartitionedBlock partitionedBlock = new PartitionedBlock(matrixObject.acquireRead(), numRowsPerBlock, numColumnsPerBlock);
            matrixObject.release();
            int computeBlocksPerPartition = PartitionedBroadcast.computeBlocksPerPartition(matrixObject.getNumRows(), matrixObject.getNumColumns(), numRowsPerBlock, numColumnsPerBlock);
            // divide in floating point: an integer quotient would already be truncated before
            // Math.ceil, yielding too few partitions and dropping the trailing blocks
            int ceil = (int) Math.ceil((double) (partitionedBlock.getNumRowBlocks() * partitionedBlock.getNumColumnBlocks()) / computeBlocksPerPartition);
            Broadcast[] broadcastArr = new Broadcast[ceil];
            if (ceil > 1) {
                // broadcast each partition of blocks separately
                for (int i = 0; i < ceil; i++) {
                    int i2 = i * computeBlocksPerPartition;
                    PartitionedBlock createPartition = partitionedBlock.createPartition(i2, Math.min(computeBlocksPerPartition, (partitionedBlock.getNumRowBlocks() * partitionedBlock.getNumColumnBlocks()) - i2), new MatrixBlock());
                    broadcastArr[i] = getSparkContext().broadcast(createPartition);
                    // driver-side copy is only cleared when not running with a local master
                    if (!isLocalMaster()) {
                        createPartition.clearBlocks();
                    }
                }
            } else {
                // single partition: broadcast the partitioned block as a whole
                broadcastArr[0] = getSparkContext().broadcast(partitionedBlock);
                if (!isLocalMaster()) {
                    partitionedBlock.clearBlocks();
                }
            }
            partitionedBroadcast = new PartitionedBroadcast<>(broadcastArr, matrixObject.getMatrixCharacteristics());
            BroadcastObject broadcastObject = new BroadcastObject(partitionedBroadcast, OptimizerUtils.estimatePartitionedSizeExactSparsity(matrixObject.getMatrixCharacteristics()));
            matrixObject.setBroadcastHandle(broadcastObject);
            CacheableData.addBroadcastSize(broadcastObject.getSize());
        }
        if (DMLScript.STATISTICS) {
            Statistics.accSparkBroadCastTime(System.nanoTime() - nanoTime);
            Statistics.incSparkBroadcastCount(1L);
        }
        return partitionedBroadcast;
    }

    /**
     * Returns a partitioned broadcast for the given frame variable. A valid existing
     * broadcast handle is reused; otherwise the frame is read, split into partitions of
     * blocks (default frame size rows by all columns), each partition is broadcast
     * separately, and the new handle is attached to the frame object. The global broadcast
     * size accounting is updated accordingly.
     *
     * @param str variable name
     * @return the partitioned broadcast of the frame
     * @throws DMLRuntimeException if the frame cannot be obtained or read
     */
    public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable(String str) throws DMLRuntimeException {
        long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        FrameObject frameObject = getFrameObject(str);
        PartitionedBroadcast<FrameBlock> partitionedBroadcast = null;
        // reuse an existing broadcast handle if it is still valid
        if (frameObject.getBroadcastHandle() != null && frameObject.getBroadcastHandle().isValid()) {
            partitionedBroadcast = frameObject.getBroadcastHandle().getBroadcast();
        }
        if (partitionedBroadcast == null) {
            // an invalid old handle no longer counts towards the broadcast size
            if (frameObject.getBroadcastHandle() != null) {
                CacheableData.addBroadcastSize(-frameObject.getBroadcastHandle().getSize());
            }
            int numColumns = (int) frameObject.getNumColumns();
            int defaultFrameSize = OptimizerUtils.getDefaultFrameSize();
            PartitionedBlock partitionedBlock = new PartitionedBlock(frameObject.acquireRead(), defaultFrameSize, numColumns);
            frameObject.release();
            int computeBlocksPerPartition = PartitionedBroadcast.computeBlocksPerPartition(frameObject.getNumRows(), frameObject.getNumColumns(), defaultFrameSize, numColumns);
            // divide in floating point: an integer quotient would already be truncated before
            // Math.ceil, yielding too few partitions and dropping the trailing blocks
            int ceil = (int) Math.ceil((double) (partitionedBlock.getNumRowBlocks() * partitionedBlock.getNumColumnBlocks()) / computeBlocksPerPartition);
            Broadcast[] broadcastArr = new Broadcast[ceil];
            if (ceil > 1) {
                // broadcast each partition of blocks separately
                for (int i = 0; i < ceil; i++) {
                    int i2 = i * computeBlocksPerPartition;
                    PartitionedBlock createPartition = partitionedBlock.createPartition(i2, Math.min(computeBlocksPerPartition, (partitionedBlock.getNumRowBlocks() * partitionedBlock.getNumColumnBlocks()) - i2), new FrameBlock());
                    broadcastArr[i] = getSparkContext().broadcast(createPartition);
                    // driver-side copy is only cleared when not running with a local master
                    if (!isLocalMaster()) {
                        createPartition.clearBlocks();
                    }
                }
            } else {
                // single partition: broadcast the partitioned block as a whole
                broadcastArr[0] = getSparkContext().broadcast(partitionedBlock);
                if (!isLocalMaster()) {
                    partitionedBlock.clearBlocks();
                }
            }
            partitionedBroadcast = new PartitionedBroadcast<>(broadcastArr, frameObject.getMatrixCharacteristics());
            BroadcastObject broadcastObject = new BroadcastObject(partitionedBroadcast, OptimizerUtils.estimatePartitionedSizeExactSparsity(frameObject.getMatrixCharacteristics()));
            frameObject.setBroadcastHandle(broadcastObject);
            CacheableData.addBroadcastSize(broadcastObject.getSize());
        }
        if (DMLScript.STATISTICS) {
            Statistics.accSparkBroadCastTime(System.nanoTime() - nanoTime);
            Statistics.incSparkBroadcastCount(1L);
        }
        return partitionedBroadcast;
    }

    /**
     * Wraps the given RDD into a new RDD handle and attaches it to the cacheable data
     * bound to the given variable.
     *
     * @param str         variable name
     * @param javaPairRDD the RDD to attach
     * @throws DMLRuntimeException if the variable cannot be obtained
     */
    public void setRDDHandleForVariable(String str, JavaPairRDD<?, ?> javaPairRDD) throws DMLRuntimeException {
        CacheableData<?> target = getCacheableData(str);
        target.setRDDHandle(new RDDObject(javaPairRDD));
    }

    /**
     * Converts an in-memory matrix block into a binary-block RDD. A matrix that fits into
     * a single block becomes a one-element RDD with index (1,1); a larger matrix is cut
     * into indexed blocks using a parallel stream over the block ids.
     *
     * @param javaSparkContext Spark context used for parallelization
     * @param matrixBlock      source matrix
     * @param i                number of rows per block
     * @param i2               number of columns per block
     * @return RDD of matrix indexes and matrix blocks
     * @throws DMLRuntimeException declared by the signature
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> toMatrixJavaPairRDD(JavaSparkContext javaSparkContext, MatrixBlock matrixBlock, int i, int i2) throws DMLRuntimeException {
        long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;

        List blocks;
        boolean singleBlock = !(matrixBlock.getNumRows() > i || matrixBlock.getNumColumns() > i2);
        if (singleBlock) {
            // the whole matrix is the one and only block
            blocks = Arrays.asList(new Tuple2(new MatrixIndexes(1L, 1L), matrixBlock));
        } else {
            MatrixCharacteristics mc = new MatrixCharacteristics(matrixBlock.getNumRows(), matrixBlock.getNumColumns(), i, i2, matrixBlock.getNonZeros());
            blocks = (List) LongStream.range(0L, mc.getNumBlocks())
                .parallel()
                .mapToObj(blockId -> createIndexedBlock(matrixBlock, mc, blockId))
                .collect(Collectors.toList());
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> result = javaSparkContext.parallelizePairs(blocks);

        if (DMLScript.STATISTICS) {
            Statistics.accSparkParallelizeTime(System.nanoTime() - startTime);
            Statistics.incSparkParallelizeCount(1L);
        }
        return result;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Extracts the block with the given linearized (row-major) block id from the source
     * matrix and pairs it with its 1-based block indexes.
     *
     * @param matrixBlock           source matrix
     * @param matrixCharacteristics dimensions and block sizes of the source matrix
     * @param j                     0-based block id in row-major order over the block grid
     * @return tuple of block indexes and the sliced block
     * @throws RuntimeException wrapping any DMLRuntimeException raised while slicing
     */
    public static Tuple2<MatrixIndexes, MatrixBlock> createIndexedBlock(MatrixBlock matrixBlock, MatrixCharacteristics matrixCharacteristics, long j) {
        try {
            // 0-based position of the block in the block grid
            long blockRow = j / matrixCharacteristics.getNumColBlocks();
            long blockCol = j % matrixCharacteristics.getNumColBlocks();

            // actual block dimensions (boundary blocks may be smaller)
            int blockRows = UtilFunctions.computeBlockSize(matrixCharacteristics.getRows(), blockRow + 1, matrixCharacteristics.getRowsPerBlock());
            int blockCols = UtilFunctions.computeBlockSize(matrixCharacteristics.getCols(), blockCol + 1, matrixCharacteristics.getColsPerBlock());

            // target block and cell offsets of the block within the source matrix
            MatrixBlock target = new MatrixBlock(blockRows, blockCols, matrixBlock.isInSparseFormat());
            int rowOffset = ((int) blockRow) * matrixCharacteristics.getRowsPerBlock();
            int colOffset = ((int) blockCol) * matrixCharacteristics.getColsPerBlock();
            int rowEnd = (rowOffset + blockRows) - 1;
            int colEnd = (colOffset + blockCols) - 1;

            MatrixIndexes indexes = new MatrixIndexes(blockRow + 1, blockCol + 1);
            return new Tuple2<>(indexes, matrixBlock.slice(rowOffset, rowEnd, colOffset, colEnd, (CacheBlock) target));
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Converts an in-memory frame block into an RDD of row-partitioned frame blocks. The
     * frame is cut into consecutive row blocks of the configured block size (the last one
     * may be smaller); each block is keyed by its 1-based starting row. The column
     * metadata is attached to the first block only.
     *
     * @param javaSparkContext Spark context used for parallelization
     * @param frameBlock       source frame
     * @return RDD of 1-based row offsets and frame blocks
     * @throws DMLRuntimeException if slicing the frame fails
     */
    public static JavaPairRDD<Long, FrameBlock> toFrameJavaPairRDD(JavaSparkContext javaSparkContext, FrameBlock frameBlock) throws DMLRuntimeException {
        long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        List<Tuple2<Long, FrameBlock>> blocks = new LinkedList<>();
        int blocksize = ConfigurationManager.getBlocksize();
        int numRows = frameBlock.getNumRows();
        // divide in floating point: an integer quotient would already be truncated before
        // Math.ceil, silently dropping the last partial row block
        int numBlocks = (int) Math.ceil((double) numRows / blocksize);
        for (int i = 0; i < numBlocks; i++) {
            int rowOffset = i * blocksize;
            // all blocks are full except possibly the last one
            int blockRows = Math.min(blocksize, numRows - rowOffset);
            FrameBlock block = new FrameBlock(frameBlock.getSchema());
            frameBlock.slice(rowOffset, (rowOffset + blockRows) - 1, 0, frameBlock.getNumColumns() - 1, (CacheBlock) block);
            if (rowOffset == 0) {
                block.setColumnMetadata(frameBlock.getColumnMetadata());
            }
            blocks.add(new Tuple2<>(Long.valueOf(rowOffset + 1), block));
        }
        JavaPairRDD<Long, FrameBlock> parallelizePairs = javaSparkContext.parallelizePairs(blocks);
        if (DMLScript.STATISTICS) {
            Statistics.accSparkParallelizeTime(System.nanoTime() - nanoTime);
            Statistics.incSparkParallelizeCount(1L);
        }
        return parallelizePairs;
    }

    /**
     * Convenience overload that unwraps the RDD of the given RDD handle and delegates to
     * the JavaPairRDD-based toMatrixBlock with otherwise unchanged arguments.
     *
     * @param rDDObject RDD handle holding a binary-block matrix RDD
     * @param i         forwarded unchanged
     * @param i2        forwarded unchanged
     * @param i3        forwarded unchanged
     * @param i4        forwarded unchanged
     * @param j         forwarded unchanged
     * @return the matrix block produced by the delegate
     * @throws DMLRuntimeException if the delegate fails
     */
    public static MatrixBlock toMatrixBlock(RDDObject rDDObject, int i, int i2, int i3, int i4, long j) throws DMLRuntimeException {
        JavaPairRDD<MatrixIndexes, MatrixBlock> rdd = (JavaPairRDD<MatrixIndexes, MatrixBlock>) rDDObject.getRDD();
        return toMatrixBlock(rdd, i, i2, i3, i4, j);
    }

    /**
     * Collects a binary-block RDD to the driver and assembles it into one matrix block.
     * <p>
     * When the matrix is larger than a single block in either dimension, every collected
     * block is copied (dense target) or appended (sparse target) at the offset derived
     * from its 1-based block index. Otherwise at most one block is expected and it is
     * returned directly; an empty RDD yields an empty sparse block.
     *
     * @param javaPairRDD RDD of (block index, block) pairs
     * @param i number of rows of the result
     * @param i2 number of columns of the result
     * @param i3 number of rows per block
     * @param i4 number of columns per block
     * @param j number of non-zeros, or a negative value if unknown (then {@code i * i2} is assumed)
     * @return the assembled matrix block
     * @throws DMLRuntimeException if more than one block is found for a single-block matrix
     */
    public static MatrixBlock toMatrixBlock(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, int i, int i2, int i3, int i4, long j) throws DMLRuntimeException {
        MatrixBlock result;
        long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (i > i3 || i2 > i4) {
            // Widen before multiplying: an int * int product overflows for large
            // dimensions and would yield a bogus (possibly negative) nnz estimate.
            long estimatedNnz = j >= 0 ? j : (long) i * i2;
            boolean sparse = MatrixBlock.evalSparseFormatInMemory(i, i2, estimatedNnz);
            result = new MatrixBlock(i, i2, sparse, estimatedNnz);
            long nnz = 0;
            for (Tuple2<MatrixIndexes, MatrixBlock> entry : javaPairRDD.collect()) {
                MatrixIndexes ix = entry._1();
                MatrixBlock block = entry._2();
                // Block indexes are 1-based; convert to 0-based cell offsets.
                int rowOffset = ((int) (ix.getRowIndex() - 1)) * i3;
                int colOffset = ((int) (ix.getColumnIndex() - 1)) * i4;
                int rows = block.getNumRows();
                int cols = block.getNumColumns();
                if (block instanceof CompressedMatrixBlock) {
                    block = ((CompressedMatrixBlock) block).decompress();
                }
                if (sparse) {
                    result.appendToSparse(block, rowOffset, colOffset, i2 > i4);
                } else {
                    result.copy(rowOffset, (rowOffset + rows) - 1, colOffset, (colOffset + cols) - 1, block, false);
                }
                nnz += block.getNonZeros();
            }
            // Sparse rows are sorted after appending when there are multiple column blocks.
            if (sparse && i2 > i4) {
                result.sortSparseRows();
            }
            result.setNonZeros(nnz);
            result.examSparsity();
        } else {
            List<Tuple2<MatrixIndexes, MatrixBlock>> collected = javaPairRDD.collect();
            if (collected.size() > 1) {
                throw new DMLRuntimeException("Expecting no more than one result block.");
            }
            result = collected.size() == 1 ? collected.get(0)._2() : new MatrixBlock(i, i2, true);
            result.examSparsity();
        }
        if (DMLScript.STATISTICS) {
            Statistics.accSparkCollectTime(System.nanoTime() - startTime);
            Statistics.incSparkCollectCount(1L);
        }
        return result;
    }

    /**
     * Convenience overload: unwraps the RDD held by the given RDD object and delegates
     * to the cell-based {@code toMatrixBlock} variant with the same arguments.
     *
     * @param rDDObject handle whose RDD is collected
     * @param i passed through as the first dimension argument
     * @param i2 passed through as the second dimension argument
     * @param j passed through as the non-zero count argument
     * @return the collected matrix block
     * @throws DMLRuntimeException if the delegate fails
     */
    public static MatrixBlock toMatrixBlock(RDDObject rDDObject, int i, int i2, long j) throws DMLRuntimeException {
        @SuppressWarnings("unchecked")
        JavaPairRDD<MatrixIndexes, MatrixCell> cells = (JavaPairRDD<MatrixIndexes, MatrixCell>) rDDObject.getRDD();
        return toMatrixBlock(cells, i, i2, j);
    }

    /**
     * Collects a cell-format RDD to the driver and assembles it into one matrix block.
     * Each collected cell is appended at its position (cell indexes are 1-based), sparse
     * rows are sorted afterwards, and the non-zero count is recomputed.
     *
     * @param javaPairRDD RDD of (cell index, cell value) pairs
     * @param i number of rows of the result
     * @param i2 number of columns of the result
     * @param j number of non-zeros, or a negative value if unknown (then {@code i * i2} is assumed)
     * @return the assembled matrix block
     * @throws DMLRuntimeException declared by the signature; propagated from the block operations
     */
    public static MatrixBlock toMatrixBlock(JavaPairRDD<MatrixIndexes, MatrixCell> javaPairRDD, int i, int i2, long j) throws DMLRuntimeException {
        long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        // Widen before multiplying: an int * int product overflows for large dimensions
        // and would corrupt the sparsity decision.
        long estimatedNnz = j >= 0 ? j : (long) i * i2;
        boolean sparse = MatrixBlock.evalSparseFormatInMemory(i, i2, estimatedNnz);
        MatrixBlock result = new MatrixBlock(i, i2, sparse);
        for (Tuple2<MatrixIndexes, MatrixCell> entry : javaPairRDD.collect()) {
            MatrixIndexes ix = entry._1();
            result.appendValue(((int) ix.getRowIndex()) - 1, ((int) ix.getColumnIndex()) - 1, entry._2().getValue());
        }
        // Cells arrive in arbitrary order, so sparse rows are sorted after appending.
        if (sparse) {
            result.sortSparseRows();
        }
        result.recomputeNonZeros();
        result.examSparsity();
        if (DMLScript.STATISTICS) {
            Statistics.accSparkCollectTime(System.nanoTime() - startTime);
            Statistics.incSparkCollectCount(1L);
        }
        return result;
    }

    /**
     * Collects a binary-block RDD to the driver and stores every block in a
     * {@link PartitionedBlock} under the row/column index taken from its key.
     *
     * @param javaPairRDD RDD of (block index, block) pairs
     * @param i first constructor argument of the partitioned block
     * @param i2 second constructor argument of the partitioned block
     * @param i3 third constructor argument of the partitioned block
     * @param i4 fourth constructor argument of the partitioned block
     * @param j not used by this method
     * @return the populated partitioned block
     * @throws DMLRuntimeException declared by the signature; propagated from the block operations
     */
    public static PartitionedBlock<MatrixBlock> toPartitionedMatrixBlock(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, int i, int i2, int i3, int i4, long j) throws DMLRuntimeException {
        final boolean collectStats = DMLScript.STATISTICS;
        long startTime = collectStats ? System.nanoTime() : 0L;
        PartitionedBlock<MatrixBlock> result = new PartitionedBlock<>(i, i2, i3, i4);
        for (Tuple2<MatrixIndexes, MatrixBlock> entry : javaPairRDD.collect()) {
            MatrixIndexes ix = entry._1();
            int blockRow = (int) ix.getRowIndex();
            int blockCol = (int) ix.getColumnIndex();
            result.setBlock(blockRow, blockCol, entry._2());
        }
        if (DMLScript.STATISTICS) {
            Statistics.accSparkCollectTime(System.nanoTime() - startTime);
            Statistics.incSparkCollectCount(1L);
        }
        return result;
    }

    /**
     * Convenience overload: unwraps the RDD held by the given RDD object and delegates
     * to the RDD-based {@code toFrameBlock} variant with the same arguments.
     *
     * @param rDDObject handle whose RDD is collected
     * @param valueTypeArr schema passed through to the delegate
     * @param i passed through as the first dimension argument
     * @param i2 passed through as the second dimension argument
     * @return the collected frame block
     * @throws DMLRuntimeException if the delegate fails
     */
    public static FrameBlock toFrameBlock(RDDObject rDDObject, Expression.ValueType[] valueTypeArr, int i, int i2) throws DMLRuntimeException {
        @SuppressWarnings("unchecked")
        JavaPairRDD<Long, FrameBlock> frames = (JavaPairRDD<Long, FrameBlock>) rDDObject.getRDD();
        return toFrameBlock(frames, valueTypeArr, i, i2);
    }

    /**
     * Collects a frame RDD to the driver and copies every row block into one frame at the
     * row offset given by its 1-based key. Column names and column metadata are taken
     * from the block that starts at row offset 0.
     *
     * @param javaPairRDD RDD of (1-based first row index, row block) pairs
     * @param valueTypeArr schema of the result; when {@code null}, {@code i2} STRING columns are used
     * @param i row count passed to {@code ensureAllocatedColumns}
     * @param i2 number of columns, used only to build the default schema
     * @return the assembled frame block
     * @throws DMLRuntimeException declared by the signature; propagated from the frame operations
     */
    public static FrameBlock toFrameBlock(JavaPairRDD<Long, FrameBlock> javaPairRDD, Expression.ValueType[] valueTypeArr, int i, int i2) throws DMLRuntimeException {
        long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        Expression.ValueType[] schema = valueTypeArr != null
                ? valueTypeArr
                : UtilFunctions.nCopies(i2, Expression.ValueType.STRING);
        FrameBlock result = new FrameBlock(schema);
        result.ensureAllocatedColumns(i);
        for (Tuple2<Long, FrameBlock> entry : javaPairRDD.collect()) {
            // Keys are 1-based row indexes; convert to a 0-based offset.
            int rowOffset = (int) (entry._1().longValue() - 1);
            FrameBlock part = entry._2();
            int lastRow = (rowOffset + part.getNumRows()) - 1;
            int lastCol = part.getNumColumns() - 1;
            result.copy(rowOffset, lastRow, 0, lastCol, part);
            if (rowOffset == 0) {
                result.setColumnNames(part.getColumnNames());
                result.setColumnMetadata(part.getColumnMetadata());
            }
        }
        if (DMLScript.STATISTICS) {
            Statistics.accSparkCollectTime(System.nanoTime() - startTime);
            Statistics.incSparkCollectCount(1L);
        }
        return result;
    }

    /**
     * Writes the RDD of the given handle as a Hadoop file and returns the value of the
     * accumulator that is fed by {@code ComputeBinaryBlockNnzFunction} during the write.
     *
     * @param rDDObject handle whose RDD is written
     * @param str target path
     * @param outputInfo supplies the key, value and output format classes
     * @return the accumulator value read after the save has run
     */
    public static long writeRDDtoHDFS(RDDObject rDDObject, String str, OutputInfo outputInfo) {
        JavaPairRDD<?, ?> data = rDDObject.getRDD();
        LongAccumulator nnzCounter = getSparkContextStatic().sc().longAccumulator(DataExpression.READNUMNONZEROPARAM);
        // The accumulator is read only after the save below has run.
        data.mapValues(new ComputeBinaryBlockNnzFunction(nnzCounter))
                .saveAsHadoopFile(str, outputInfo.outputKeyClass, outputInfo.outputValueClass, outputInfo.outputFormatClass);
        Long total = nnzCounter.value();
        return total.longValue();
    }

    /**
     * Writes the frame RDD of the given handle as a Hadoop file. For the binary-block
     * output info the RDD is first mapped with {@code LongFrameToLongWritableFrameFunction}
     * and the binary-block frame output info is used for the write instead.
     *
     * @param rDDObject handle whose RDD is written
     * @param str target path
     * @param outputInfo requested output info
     */
    public static void writeFrameRDDtoHDFS(RDDObject rDDObject, String str, OutputInfo outputInfo) {
        JavaPairRDD<?, ?> frames = rDDObject.getRDD();
        OutputInfo effectiveInfo = outputInfo;
        if (outputInfo == OutputInfo.BinaryBlockOutputInfo) {
            frames = frames.mapToPair(new FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction());
            effectiveInfo = OutputInfo.BinaryBlockFrameOutputInfo;
        }
        frames.saveAsHadoopFile(str, effectiveInfo.outputKeyClass, effectiveInfo.outputValueClass, effectiveInfo.outputFormatClass);
    }

    /**
     * Registers the RDD handle of variable {@code str2} as a lineage child of the RDD
     * handle of variable {@code str}.
     *
     * @param str name of the parent variable
     * @param str2 name of the child variable
     * @throws DMLRuntimeException if a variable lookup fails
     */
    public void addLineageRDD(String str, String str2) throws DMLRuntimeException {
        RDDObject parent = getCacheableData(str).getRDDHandle();
        RDDObject child = getCacheableData(str2).getRDDHandle();
        parent.addLineageChild(child);
    }

    /**
     * Registers the broadcast handle of variable {@code str2} as a lineage child of the
     * RDD handle of variable {@code str}.
     *
     * @param str name of the parent variable
     * @param str2 name of the child variable
     * @throws DMLRuntimeException if a variable lookup fails
     */
    public void addLineageBroadcast(String str, String str2) throws DMLRuntimeException {
        RDDObject parent = getCacheableData(str).getRDDHandle();
        LineageObject child = getCacheableData(str2).getBroadcastHandle();
        parent.addLineageChild(child);
    }

    /**
     * Adds a lineage dependency from variable {@code str} to variable {@code str2},
     * using the child's broadcast handle when {@code z} is set and its RDD handle otherwise.
     *
     * @param str name of the parent variable
     * @param str2 name of the child variable
     * @param z {@code true} to link the broadcast handle, {@code false} to link the RDD handle
     * @throws DMLRuntimeException if the delegate fails
     */
    public void addLineage(String str, String str2, boolean z) throws DMLRuntimeException {
        if (!z) {
            addLineageRDD(str, str2);
            return;
        }
        addLineageBroadcast(str, str2);
    }

    /**
     * Clears the given cacheable data and releases its Spark-side lineage objects, but
     * only if cleanup is enabled for it and no variable references it any more.
     * <p>
     * An existing HDFS file is deleted right away when there is no RDD handle; otherwise
     * the file name is recorded on the RDD handle, and the lineage cleanup deletes it once
     * that handle itself is cleaned up.
     *
     * @param cacheableData data object to clean up
     * @throws DMLRuntimeException wrapping any exception raised during cleanup
     */
    @Override // org.apache.sysml.runtime.controlprogram.context.ExecutionContext
    public void cleanupCacheableData(CacheableData<?> cacheableData) throws DMLRuntimeException {
        try {
            if (!cacheableData.isCleanupEnabled() || getVariables().hasReferences(cacheableData)) {
                return;
            }
            cacheableData.clearData();
            boolean hasHdfsFile = cacheableData.isHDFSFileExists() && cacheableData.getFileName() != null;
            if (hasHdfsFile) {
                if (cacheableData.getRDDHandle() != null) {
                    cacheableData.getRDDHandle().setHDFSFilename(cacheableData.getFileName());
                } else {
                    MapReduceTool.deleteFileWithMTDIfExistOnHDFS(cacheableData.getFileName());
                }
            }
            if (cacheableData.getRDDHandle() != null) {
                rCleanupLineageObject(cacheableData.getRDDHandle());
            }
            if (cacheableData.getBroadcastHandle() != null) {
                rCleanupLineageObject(cacheableData.getBroadcastHandle());
            }
        } catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Recursively releases a lineage object and its children. Nothing happens while the
     * object still has references or a back reference. Otherwise RDDs are cleaned up via
     * {@code cleanupRDDVariable} (plus deletion of a recorded HDFS file and deregistration
     * of parallelized RDDs), broadcasts are cleaned up via {@code cleanupBroadcastVariable},
     * and each child has its reference count decremented before being visited in turn.
     *
     * @param lineageObject lineage object to clean up
     * @throws IOException declared by the signature; presumably raised by the HDFS deletion
     */
    private void rCleanupLineageObject(LineageObject lineageObject) throws IOException {
        if (lineageObject.getNumReferences() > 0 || lineageObject.hasBackReference()) {
            return;
        }
        if (lineageObject instanceof RDDObject) {
            RDDObject rddHandle = (RDDObject) lineageObject;
            // Capture the id before cleaning up the RDD; it is needed for deregistration below.
            int rddId = rddHandle.getRDD().id();
            cleanupRDDVariable(rddHandle.getRDD());
            if (rddHandle.getHDFSFilename() != null) {
                MapReduceTool.deleteFileWithMTDIfExistOnHDFS(rddHandle.getHDFSFilename());
            }
            if (rddHandle.isParallelizedRDD()) {
                _parRDDs.deregisterRDD(rddId);
            }
        } else if (lineageObject instanceof BroadcastObject) {
            PartitionedBroadcast partitioned = ((BroadcastObject) lineageObject).getBroadcast();
            if (partitioned != null) {
                for (Broadcast part : partitioned.getBroadcasts()) {
                    cleanupBroadcastVariable(part);
                }
            }
            // Give the broadcast's size back to the global broadcast accounting.
            CacheableData.addBroadcastSize(-((BroadcastObject) lineageObject).getSize());
        }
        for (LineageObject child : lineageObject.getLineageChilds()) {
            child.decrementNumReferences();
            rCleanupLineageObject(child);
        }
    }

    /**
     * Destroys the given broadcast variable if it is still valid. The flag passed to
     * {@code destroy} is the negation of {@code ASYNCHRONOUS_VAR_DESTROY}.
     *
     * @param broadcast broadcast variable to destroy
     */
    public static void cleanupBroadcastVariable(Broadcast<?> broadcast) {
        if (!broadcast.isValid()) {
            return;
        }
        broadcast.destroy(!ASYNCHRONOUS_VAR_DESTROY);
    }

    /**
     * Unpersists the given RDD unless its storage level is {@code StorageLevel.NONE()}.
     * The flag passed to {@code unpersist} is the negation of {@code ASYNCHRONOUS_VAR_DESTROY}.
     *
     * @param javaPairRDD RDD to unpersist
     */
    public static void cleanupRDDVariable(JavaPairRDD<?, ?> javaPairRDD) {
        boolean notPersisted = javaPairRDD.getStorageLevel() == StorageLevel.NONE();
        if (notPersisted) {
            return;
        }
        javaPairRDD.unpersist(!ASYNCHRONOUS_VAR_DESTROY);
    }

    /**
     * Merges the blocks of a matrix variable by key, persists the result and installs it
     * as the variable's new RDD handle. Nothing is done unless
     * {@code OptimizerUtils.exceedsCachingThreshold} holds for the matrix.
     * <p>
     * If the current handle allows short-circuit reads and its RDD is marked for caching
     * but not cached yet, the RDD of the first lineage child is used as input instead and
     * is coalesced when it has more partitions than preferred.
     *
     * @param str name of the matrix variable
     * @throws DMLRuntimeException if the variable lookup or RDD creation fails
     */
    public void repartitionAndCacheMatrixObject(String str) throws DMLRuntimeException {
        MatrixObject mo = getMatrixObject(str);
        MatrixCharacteristics mc = mo.getMatrixCharacteristics();
        if (!OptimizerUtils.exceedsCachingThreshold(mo.getNumColumns(), OptimizerUtils.estimateSizeExactSparsity(mc))) {
            return;
        }
        JavaPairRDD<?, ?> input = getRDDHandleForMatrixObject(mo, InputInfo.BinaryBlockInputInfo);
        boolean bypassPendingCache = mo.getRDDHandle().allowsShortCircuitRead()
                && isRDDMarkedForCaching(input.id())
                && !isRDDCached(input.id());
        if (bypassPendingCache) {
            input = ((RDDObject) mo.getRDDHandle().getLineageChilds().get(0)).getRDD();
            int preferredPartitions = SparkUtils.getNumPreferredPartitions(mc, input);
            if (input.getNumPartitions() > preferredPartitions) {
                input = input.coalesce(preferredPartitions);
            }
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> merged = RDDAggregateUtils.mergeByKey(input, false);
        if (OptimizerUtils.checkSparseBlockCSRConversion(mc)) {
            merged = merged.mapValues(new CreateSparseBlockFunction(SparseBlock.Type.CSR));
        }
        // count() runs a job on the persisted RDD straight away.
        merged.persist(Checkpoint.DEFAULT_STORAGE_LEVEL).count();
        RDDObject previousHandle = mo.getRDDHandle();
        RDDObject checkpointHandle = new RDDObject(merged);
        checkpointHandle.setCheckpointRDD(true);
        checkpointHandle.addLineageChild(previousHandle);
        mo.setRDDHandle(checkpointHandle);
    }

    /**
     * Runs {@code count()} on the RDD of a matrix variable unless that RDD is already
     * cached. Nothing is done unless {@code OptimizerUtils.exceedsCachingThreshold}
     * holds for the matrix.
     *
     * @param str name of the matrix variable
     * @throws DMLRuntimeException if the variable lookup or RDD creation fails
     */
    public void cacheMatrixObject(String str) throws DMLRuntimeException {
        MatrixObject mo = getMatrixObject(str);
        if (!OptimizerUtils.exceedsCachingThreshold(mo.getNumColumns(), OptimizerUtils.estimateSizeExactSparsity(mo.getMatrixCharacteristics()))) {
            return;
        }
        JavaPairRDD<?, ?> input = getRDDHandleForMatrixObject(mo, InputInfo.BinaryBlockInputInfo);
        if (!isRDDCached(input.id())) {
            input.count();
        }
    }

    /**
     * Sets the Spark local property {@code spark.scheduler.pool} to the given pool name.
     * Does nothing unless {@code FAIR_SCHEDULER_MODE} is enabled.
     *
     * @param str scheduler pool name
     */
    public void setThreadLocalSchedulerPool(String str) {
        if (!FAIR_SCHEDULER_MODE) {
            return;
        }
        getSparkContext().sc().setLocalProperty("spark.scheduler.pool", str);
    }

    /**
     * Resets the Spark local property {@code spark.scheduler.pool} by setting it to
     * {@code null}. Does nothing unless {@code FAIR_SCHEDULER_MODE} is enabled.
     */
    public void cleanupThreadLocalSchedulerPool() {
        if (!FAIR_SCHEDULER_MODE) {
            return;
        }
        getSparkContext().sc().setLocalProperty("spark.scheduler.pool", (String) null);
    }

    /**
     * Checks whether the Spark context lists the given RDD id among its persistent RDDs.
     *
     * @param i RDD id
     * @return {@code true} if the id is contained in the persistent-RDD map
     */
    private boolean isRDDMarkedForCaching(int i) {
        JavaSparkContext jsc = getSparkContext();
        return jsc.sc().getPersistentRDDs().contains(Integer.valueOf(i));
    }

    /**
     * Checks whether the given RDD is cached: it must be listed among the persistent RDDs
     * and its storage info must report it as cached.
     *
     * @param i RDD id
     * @return {@code true} if storage info for the id exists and reports cached data,
     *         {@code false} otherwise
     */
    public boolean isRDDCached(int i) {
        JavaSparkContext jsc = getSparkContext();
        boolean cached = false;
        if (jsc.sc().getPersistentRDDs().contains(Integer.valueOf(i))) {
            for (RDDInfo info : jsc.sc().getRDDStorageInfo()) {
                if (info.id() == i) {
                    // Only the first matching entry is consulted.
                    cached = info.isCached();
                    break;
                }
            }
        }
        return cached;
    }

    /**
     * Returns the shared Spark cluster configuration, creating it on first use.
     *
     * @return the lazily created cluster configuration
     */
    public static SparkClusterConfig getSparkClusterConfig() {
        if (_sconf != null) {
            return _sconf;
        }
        _sconf = new SparkClusterConfig();
        return _sconf;
    }

    /**
     * Returns the broadcast memory budget reported by the shared cluster configuration.
     *
     * @return broadcast memory budget
     */
    public static double getBroadcastMemoryBudget() {
        SparkClusterConfig clusterConfig = getSparkClusterConfig();
        return clusterConfig.getBroadcastMemoryBudget();
    }

    /**
     * Returns the data memory budget reported by the shared cluster configuration.
     *
     * @param z first flag, passed through unchanged
     * @param z2 second flag, passed through unchanged
     * @return data memory budget
     */
    public static double getDataMemoryBudget(boolean z, boolean z2) {
        SparkClusterConfig clusterConfig = getSparkClusterConfig();
        return clusterConfig.getDataMemoryBudget(z, z2);
    }

    /**
     * Returns the number of executors reported by the shared cluster configuration.
     *
     * @return number of executors
     */
    public static int getNumExecutors() {
        SparkClusterConfig clusterConfig = getSparkClusterConfig();
        return clusterConfig.getNumExecutors();
    }

    /**
     * Returns the default parallelism reported by the shared cluster configuration.
     *
     * @param z flag, passed through unchanged
     * @return default parallelism
     */
    public static int getDefaultParallelism(boolean z) {
        SparkClusterConfig clusterConfig = getSparkClusterConfig();
        return clusterConfig.getDefaultParallelism(z);
    }
}
