package org.apache.sysml.runtime.instructions.spark.utils;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.Checkpoint;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Static helpers used by the Spark instructions: conversions between Spark {@link Tuple2}s and
 * SystemML index/value pairs, partition-count heuristics, block copy/caching utilities, parsing of
 * RDD debug-string lines, and construction of all-empty block RDDs.
 */
public class SparkUtils {
    /** Storage level used when persisting temporary RDDs (same as {@link Checkpoint#DEFAULT_STORAGE_LEVEL}). */
    public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;

    /**
     * Reduce function merging two partial characteristics.
     *
     * <p>It takes the maximum of the row and column counts and the sum of the non-zero counts. It
     * takes the block sizes from the first argument.
     */
    private static class AggregateMatrixCharacteristics
            implements Function2<MatrixCharacteristics, MatrixCharacteristics, MatrixCharacteristics> {
        private static final long serialVersionUID = 4263886749699779994L;

        private AggregateMatrixCharacteristics() {
        }

        @Override
        public MatrixCharacteristics call(MatrixCharacteristics left, MatrixCharacteristics right) throws Exception {
            return new MatrixCharacteristics(
                    Math.max(left.getRows(), right.getRows()),
                    Math.max(left.getCols(), right.getCols()),
                    left.getRowsPerBlock(),
                    left.getColsPerBlock(),
                    left.getNonZeros() + right.getNonZeros());
        }
    }

    /**
     * Maps a single cell to a partial characteristics object.
     *
     * <p>The rows/cols are set to the cell's row/column index and the block sizes to 0. The
     * non-zero count is 1 if the cell value is non-zero, otherwise 0.
     */
    private static class AnalyzeCellMatrixCharacteristics
            implements Function<Tuple2<MatrixIndexes, MatrixCell>, MatrixCharacteristics> {
        private static final long serialVersionUID = 8899395272683723008L;

        private AnalyzeCellMatrixCharacteristics() {
        }

        @Override
        public MatrixCharacteristics call(Tuple2<MatrixIndexes, MatrixCell> cell) throws Exception {
            MatrixIndexes ix = cell._1();
            long nnz = (cell._2().getValue() != 0) ? 1L : 0L;
            return new MatrixCharacteristics(ix.getRowIndex(), ix.getColumnIndex(), 0, 0, nnz);
        }
    }

    /**
     * Generates the empty blocks for one range of linearized block ids.
     *
     * <p>Given a start offset, it emits up to {@code pNumBlocks} consecutive blocks, capped at
     * {@code mc.getNumBlocks()}. Block id {@code i} maps to the 1-based index
     * {@code (1 + i / numColBlocks, 1 + i % numColBlocks)}. Each block is a {@link MatrixBlock}
     * sized via {@link UtilFunctions#computeBlockSize} and created with the sparse flag set.
     */
    public static class GenerateEmptyBlocks implements PairFlatMapFunction<Long, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 630129586089106855L;
        private final MatrixCharacteristics _mc;
        private final long _pNumBlocks;

        /**
         * @param mc         characteristics of the matrix whose empty blocks are generated
         * @param pNumBlocks maximum number of blocks generated per input offset
         */
        public GenerateEmptyBlocks(MatrixCharacteristics mc, long pNumBlocks) {
            _mc = mc;
            _pNumBlocks = pNumBlocks;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Long offset) throws Exception {
            final long numColBlocks = _mc.getNumColBlocks();
            final long start = offset;
            final long end = Math.min(start + _pNumBlocks, _mc.getNumBlocks());
            List<Tuple2<MatrixIndexes, MatrixBlock>> blocks = new ArrayList<>();
            for (long id = start; id < end; id++) {
                long rowIx = 1 + id / numColBlocks;
                long colIx = 1 + id % numColBlocks;
                MatrixBlock block = new MatrixBlock(
                        UtilFunctions.computeBlockSize(_mc.getRows(), rowIx, _mc.getRowsPerBlock()),
                        UtilFunctions.computeBlockSize(_mc.getCols(), colIx, _mc.getColsPerBlock()),
                        true);
                blocks.add(new Tuple2<>(new MatrixIndexes(rowIx, colIx), block));
            }
            return blocks.iterator();
        }
    }

    /** Wraps a Spark (indexes, block) tuple into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(Tuple2<MatrixIndexes, MatrixBlock> in) {
        return new IndexedMatrixValue(in._1(), in._2());
    }

    /** Wraps the given indexes and block into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(MatrixIndexes ix, MatrixBlock mb) {
        return new IndexedMatrixValue(ix, mb);
    }

    /** Converts an {@link IndexedMatrixValue} into a Spark tuple; the value is cast to {@link MatrixBlock}. */
    public static Tuple2<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlock(IndexedMatrixValue in) {
        return new Tuple2<>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /** Converts each {@link IndexedMatrixValue} of the list into a Spark tuple, preserving order. */
    public static ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlock(ArrayList<IndexedMatrixValue> in) {
        ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> out = new ArrayList<>(in.size());
        for (IndexedMatrixValue imv : in) {
            out.add(fromIndexedMatrixBlock(imv));
        }
        return out;
    }

    /** Converts an {@link IndexedMatrixValue} into a {@link Pair}; the value is cast to {@link MatrixBlock}. */
    public static Pair<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlockToPair(IndexedMatrixValue in) {
        return new Pair<>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /** Converts each {@link IndexedMatrixValue} of the list into a {@link Pair}, preserving order. */
    public static ArrayList<Pair<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlockToPair(ArrayList<IndexedMatrixValue> in) {
        ArrayList<Pair<MatrixIndexes, MatrixBlock>> out = new ArrayList<>(in.size());
        for (IndexedMatrixValue imv : in) {
            out.add(fromIndexedMatrixBlockToPair(imv));
        }
        return out;
    }

    /** Converts a (row offset, frame block) {@link Pair} into a Spark tuple. */
    public static Tuple2<Long, FrameBlock> fromIndexedFrameBlock(Pair<Long, FrameBlock> in) {
        return new Tuple2<>(in.getKey(), in.getValue());
    }

    /** Converts each (row offset, frame block) {@link Pair} into a Spark tuple, preserving order. */
    public static ArrayList<Tuple2<Long, FrameBlock>> fromIndexedFrameBlock(ArrayList<Pair<Long, FrameBlock>> in) {
        ArrayList<Tuple2<Long, FrameBlock>> out = new ArrayList<>(in.size());
        for (Pair<Long, FrameBlock> pair : in) {
            out.add(fromIndexedFrameBlock(pair));
        }
        return out;
    }

    /** Converts a list of Spark (long, long) tuples into {@link Pair}s, preserving order. */
    public static ArrayList<Pair<Long, Long>> toIndexedLong(List<Tuple2<Long, Long>> in) {
        ArrayList<Pair<Long, Long>> out = new ArrayList<>(in.size());
        for (Tuple2<Long, Long> t : in) {
            out.add(new Pair<>(t._1(), t._2()));
        }
        return out;
    }

    /** Converts a Spark (row offset, frame block) tuple into a {@link Pair}. */
    public static Pair<Long, FrameBlock> toIndexedFrameBlock(Tuple2<Long, FrameBlock> in) {
        return new Pair<>(in._1(), in._2());
    }

    /** Returns true iff the RDD has a partitioner and that partitioner is a {@link HashPartitioner}. */
    public static boolean isHashPartitioned(JavaPairRDD<?, ?> in) {
        return !in.rdd().partitioner().isEmpty()
                && in.rdd().partitioner().get() instanceof HashPartitioner;
    }

    /**
     * Returns the preferred number of partitions for a matrix.
     *
     * <p>If {@code mc.dimsKnown(true)} is false and an input RDD is given, the RDD's current
     * partition count is returned. Otherwise the result is
     * {@link #getNumPreferredPartitions(MatrixCharacteristics)}.
     */
    public static int getNumPreferredPartitions(MatrixCharacteristics mc, JavaPairRDD<?, ?> in) {
        if (!mc.dimsKnown(true) && in != null) {
            return in.getNumPartitions();
        }
        return getNumPreferredPartitions(mc);
    }

    /**
     * Returns the preferred number of partitions for a matrix.
     *
     * <p>If the dimensions are unknown, the default Spark parallelism is returned. Otherwise the
     * result is {@code ceil(estimated partitioned size / HDFS block size)}, and at least 1.
     */
    public static int getNumPreferredPartitions(MatrixCharacteristics mc) {
        if (!mc.dimsKnown()) {
            return SparkExecutionContext.getDefaultParallelism(true);
        }
        // floating-point division so that ceil actually rounds up
        double partitionedSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
        double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
        return (int) Math.max(Math.ceil(partitionedSize / hdfsBlockSize), 1.0d);
    }

    /** Copies a binary-block RDD with {@code deep == true}; see {@link #copyBinaryBlockMatrix(JavaPairRDD, boolean)}. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return copyBinaryBlockMatrix(in, true);
    }

    /**
     * Copies the blocks of a binary-block RDD.
     *
     * @param in   input RDD
     * @param deep if true, applies {@code CopyBlockPairFunction(true)} via
     *             {@code mapPartitionsToPair}, preserving partitioning. If false, applies
     *             {@code CopyBlockFunction(false)} via {@code mapValues}; presumably this is a
     *             shallow copy (TODO confirm against CopyBlockFunction).
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deep) {
        if (deep) {
            return in.mapPartitionsToPair(new CopyBlockPairFunction(true), true);
        }
        return in.mapValues(new CopyBlockFunction(false));
    }

    /**
     * Returns the text between character 4 and the first ':' of an RDD debug-string line.
     *
     * <p>NOTE(review): this presumes the first four characters are a fixed-width prefix such as a
     * partition marker. Lines shorter than 4 characters throw
     * {@link StringIndexOutOfBoundsException}.
     */
    public static String getStartLineFromSparkDebugInfo(String line) throws DMLRuntimeException {
        String withoutPrefix = line.substring(4);
        return withoutPrefix.split(":")[0];
    }

    /**
     * Rebuilds the indentation prefix of an RDD debug-string line.
     *
     * <p>The line is split on {@code |} or {@code +-}. All segments except the last are rejoined
     * with {@link ProgramConverter#DATA_FIELD_DELIM}. Then {@code "+- "} is appended if the line
     * contains {@code +-}; otherwise the delimiter followed by two spaces is appended.
     */
    public static String getPrefixFromSparkDebugInfo(String line) {
        String[] parts = line.split("\\||\\+-");
        StringBuilder prefix = new StringBuilder(parts[0]);
        for (int i = 1; i < parts.length - 1; i++) {
            prefix.append(ProgramConverter.DATA_FIELD_DELIM).append(parts[i]);
        }
        if (line.contains("+-")) {
            prefix.append("+- ");
        } else {
            prefix.append(ProgramConverter.DATA_FIELD_DELIM).append("  ");
        }
        return prefix.toString();
    }

    /**
     * Creates an RDD that contains one empty block for every block of a matrix with the given
     * characteristics.
     *
     * <p>The degree of parallelism {@code par} is the maximum of the default parallelism and the
     * estimated total empty-block size divided by the HDFS block size. It is capped at the number
     * of blocks and is at least 1.
     *
     * <p>The block ids {@code [0, numBlocks)} are split into {@code par} ranges of
     * {@code ceil(numBlocks / par)} ids each. Each range is expanded by {@link GenerateEmptyBlocks}.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getEmptyBlockRDD(JavaSparkContext sc, MatrixCharacteristics mc) {
        final long numBlocks = mc.getNumBlocks();
        double size = (double) numBlocks * OptimizerUtils.estimateSizeEmptyBlock(
                Math.min(Math.max(mc.getRows(), 1L), mc.getRowsPerBlock()),
                Math.min(Math.max(mc.getCols(), 1L), mc.getColsPerBlock()));
        double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
        // at least one partition, so that a matrix without blocks yields an empty RDD
        // instead of failing with a division by zero / zero-slice parallelize
        final int par = (int) Math.max(1, Math.min(
                Math.max(SparkExecutionContext.getDefaultParallelism(true), Math.ceil(size / hdfsBlockSize)),
                numBlocks));
        // Floating-point division is required here: integer division truncates, so ceil
        // would be a no-op and the trailing (numBlocks % par) blocks would never be generated.
        final long pNumBlocks = (long) Math.ceil((double) numBlocks / par);

        List<Long> offsets = LongStream.iterate(0L, off -> off + pNumBlocks)
                .limit(par)
                .boxed()
                .collect(Collectors.toList());
        return sc.parallelize(offsets, par)
                .flatMapToPair(new GenerateEmptyBlocks(mc, pNumBlocks));
    }

    /**
     * Returns the input if it is already persisted with {@link #DEFAULT_TMP}.
     *
     * <p>Otherwise it returns a copy of the input, made with {@code CopyBinaryCellFunction} and
     * persisted with {@link #DEFAULT_TMP}.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixCell> cacheBinaryCellRDD(JavaPairRDD<MatrixIndexes, MatrixCell> in) {
        if (in.getStorageLevel().equals(DEFAULT_TMP)) {
            return in;
        }
        return in.mapToPair(new CopyBinaryCellFunction()).persist(DEFAULT_TMP);
    }

    /**
     * Computes matrix characteristics from a binary-cell RDD.
     *
     * <p>Rows and cols are the maximum cell indexes, nnz is the count of non-zero cells, and the
     * block sizes are 0. Like {@code reduce}, this throws on an empty RDD.
     */
    public static MatrixCharacteristics computeMatrixCharacteristics(JavaPairRDD<MatrixIndexes, MatrixCell> in) {
        return in.map(new AnalyzeCellMatrixCharacteristics())
                .reduce(new AggregateMatrixCharacteristics());
    }

    /**
     * Returns the sum of {@link MatrixBlock#getNonZeros()} over all blocks of the RDD.
     *
     * <p>An empty RDD yields 0. The method uses {@code fold} with zero value 0 instead of
     * {@code reduce}, which would throw on an empty collection.
     */
    public static long getNonZeros(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        // Note: the compiler synthesizes $deserializeLambda$ for these serializable lambdas;
        // it must not be declared by hand.
        return in.values()
                .map(mb -> mb.getNonZeros())
                .fold(0L, (a, b) -> a + b);
    }
}
