package org.apache.sysml.runtime.instructions.spark.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysml.lops.Checkpoint;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.udf.ExternalFunctionInvocationInstruction;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.class */
/**
 * Static helpers shared by the Spark instructions: conversions between Spark tuples and
 * indexed values, partitioning of in-memory blocks, parsing of Spark RDD debug strings,
 * construction of (empty) block RDDs, caching helpers, and characteristic / non-zero
 * analysis of matrix RDDs.
 */
public class SparkUtils {
    /** Storage level used when persisting temporary RDDs (same as the checkpoint default). */
    public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;

    /**
     * Reduce function merging two partial characteristics: rows/cols become the element-wise
     * maximum, non-zeros are summed, and the block sizes of the first argument are kept.
     */
    private static class AggregateMatrixCharacteristics implements Function2<MatrixCharacteristics, MatrixCharacteristics, MatrixCharacteristics> {
        private static final long serialVersionUID = 4263886749699779994L;

        private AggregateMatrixCharacteristics() {
        }

        @Override
        public MatrixCharacteristics call(MatrixCharacteristics mc1, MatrixCharacteristics mc2) throws Exception {
            return new MatrixCharacteristics(
                    Math.max(mc1.getRows(), mc2.getRows()),
                    Math.max(mc1.getCols(), mc2.getCols()),
                    mc1.getRowsPerBlock(),
                    mc1.getColsPerBlock(),
                    mc1.getNonZeros() + mc2.getNonZeros());
        }
    }

    /**
     * Maps one (index, block) pair to the characteristics implied by that block alone: the
     * extent it covers (offset of its first cell plus its actual size) and its non-zeros.
     */
    private static class AnalyzeBlockMatrixCharacteristics implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixCharacteristics> {
        private static final long serialVersionUID = -1857049501217936951L;
        private final int _brlen;
        private final int _bclen;

        /**
         * @param i  number of rows per block
         * @param i2 number of columns per block
         */
        public AnalyzeBlockMatrixCharacteristics(int i, int i2) {
            this._brlen = i;
            this._bclen = i2;
        }

        @Override
        public MatrixCharacteristics call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) throws Exception {
            MatrixIndexes ix = tuple2._1();
            MatrixBlock blk = tuple2._2();
            // Block indexes are 1-based; use the block's real size because the last
            // block in a row/column may be partial.
            long rows = (ix.getRowIndex() - 1) * this._brlen + blk.getNumRows();
            long cols = (ix.getColumnIndex() - 1) * this._bclen + blk.getNumColumns();
            return new MatrixCharacteristics(rows, cols, this._brlen, this._bclen, blk.getNonZeros());
        }
    }

    /**
     * Maps one (index, cell) pair to characteristics whose dimensions are the cell's position
     * and whose non-zero count is 1 if the cell value is non-zero, else 0.
     */
    private static class AnalyzeCellMatrixCharacteristics implements Function<Tuple2<MatrixIndexes, MatrixCell>, MatrixCharacteristics> {
        private static final long serialVersionUID = 8899395272683723008L;

        private AnalyzeCellMatrixCharacteristics() {
        }

        @Override
        public MatrixCharacteristics call(Tuple2<MatrixIndexes, MatrixCell> tuple2) throws Exception {
            MatrixIndexes ix = tuple2._1();
            long nnz = (tuple2._2().getValue() != 0) ? 1L : 0L;
            return new MatrixCharacteristics(ix.getRowIndex(), ix.getColumnIndex(), 0, 0, nnz);
        }
    }

    /** Wraps a Spark (index, block) tuple into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(Tuple2<MatrixIndexes, MatrixBlock> tuple2) {
        return new IndexedMatrixValue(tuple2._1(), tuple2._2());
    }

    /** Wraps an index and a block into an {@link IndexedMatrixValue}. */
    public static IndexedMatrixValue toIndexedMatrixBlock(MatrixIndexes matrixIndexes, MatrixBlock matrixBlock) {
        return new IndexedMatrixValue(matrixIndexes, matrixBlock);
    }

    /** Converts an {@link IndexedMatrixValue} (whose value must be a MatrixBlock) to a Spark tuple. */
    public static Tuple2<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlock(IndexedMatrixValue indexedMatrixValue) {
        return new Tuple2<>(indexedMatrixValue.getIndexes(), (MatrixBlock) indexedMatrixValue.getValue());
    }

    /** Converts a list of indexed values to a new list of Spark tuples, preserving order. */
    public static ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlock(ArrayList<IndexedMatrixValue> arrayList) {
        ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(arrayList.size());
        for (IndexedMatrixValue imv : arrayList) {
            ret.add(fromIndexedMatrixBlock(imv));
        }
        return ret;
    }

    /** Converts a (row offset, frame block) pair to a Spark tuple. */
    public static Tuple2<Long, FrameBlock> fromIndexedFrameBlock(Pair<Long, FrameBlock> pair) {
        return new Tuple2<>(pair.getKey(), pair.getValue());
    }

    /** Converts a list of (row offset, frame block) pairs to a new list of Spark tuples, preserving order. */
    public static ArrayList<Tuple2<Long, FrameBlock>> fromIndexedFrameBlock(ArrayList<Pair<Long, FrameBlock>> arrayList) {
        ArrayList<Tuple2<Long, FrameBlock>> ret = new ArrayList<>(arrayList.size());
        for (Pair<Long, FrameBlock> p : arrayList) {
            ret.add(fromIndexedFrameBlock(p));
        }
        return ret;
    }

    /**
     * Splits a block into consecutive row slices of {@code i} rows each; the last slice holds
     * the remaining rows if the row count is not a multiple of {@code i}.
     *
     * @param matrixBlock block to split
     * @param i           rows per slice (must be positive)
     * @return the slices, in row order
     * @throws DMLRuntimeException if slicing fails
     */
    public static MatrixBlock[] partitionIntoRowBlocks(MatrixBlock matrixBlock, int i) throws DMLRuntimeException {
        final int numRows = matrixBlock.getNumRows();
        // Exact ceiling: integer division followed by Math.ceil would drop the partial slice.
        final int numSlices = (int) ceilDiv(numRows, i);
        MatrixBlock[] ret = new MatrixBlock[numSlices];
        for (int b = 0; b < numSlices; b++) {
            int rl = b * i;
            int ru = Math.min(rl + i, numRows) - 1;
            MatrixBlock part = new MatrixBlock();
            matrixBlock.sliceOperations(rl, ru, 0, matrixBlock.getNumColumns() - 1, part);
            ret[b] = part;
        }
        return ret;
    }

    /**
     * Splits a block into consecutive column slices of {@code i} columns each; the last slice
     * holds the remaining columns if the column count is not a multiple of {@code i}.
     *
     * @param matrixBlock block to split
     * @param i           columns per slice (must be positive)
     * @return the slices, in column order
     * @throws DMLRuntimeException if slicing fails
     */
    public static MatrixBlock[] partitionIntoColumnBlocks(MatrixBlock matrixBlock, int i) throws DMLRuntimeException {
        final int numColumns = matrixBlock.getNumColumns();
        final int numSlices = (int) ceilDiv(numColumns, i);
        MatrixBlock[] ret = new MatrixBlock[numSlices];
        for (int b = 0; b < numSlices; b++) {
            int cl = b * i;
            int cu = Math.min(cl + i, numColumns) - 1;
            MatrixBlock part = new MatrixBlock();
            matrixBlock.sliceOperations(0, matrixBlock.getNumRows() - 1, cl, cu, part);
            ret[b] = part;
        }
        return ret;
    }

    /**
     * Extracts the start of a Spark RDD debug-string line: drops the first 4 characters and
     * returns the text before the first {@code ELEMENT_DELIM} separator.
     */
    public static String getStartLineFromSparkDebugInfo(String str) throws DMLRuntimeException {
        return str.substring(4).split(ExternalFunctionInvocationInstruction.ELEMENT_DELIM)[0];
    }

    /**
     * Rebuilds the tree-drawing prefix of a Spark RDD debug-string line. The line is split on
     * {@code "|"} and {@code "+-"}; all pieces but the last are re-joined with
     * {@code DATA_FIELD_DELIM}, then {@code "+- "} is appended if the line contained
     * {@code "+-"}, otherwise {@code DATA_FIELD_DELIM} followed by two spaces.
     */
    public static String getPrefixFromSparkDebugInfo(String str) {
        String[] parts = str.split("\\||\\+-");
        StringBuilder sb = new StringBuilder(parts[0]);
        for (int k = 1; k < parts.length - 1; k++) {
            sb.append(ProgramConverter.DATA_FIELD_DELIM).append(parts[k]);
        }
        if (str.contains("+-")) {
            sb.append("+- ");
        } else {
            sb.append(ProgramConverter.DATA_FIELD_DELIM).append("  ");
        }
        return sb.toString();
    }

    /**
     * Returns the global cell index of the first cell of block {@code j} for block size
     * {@code i}. NOTE: {@code j2} (presumably the maximum dimension) is currently ignored.
     */
    public static long getStartGlobalIndex(long j, int i, long j2) {
        return UtilFunctions.computeCellIndex(j, i, 0);
    }

    /**
     * Returns {@code javaPairRDD} unioned with empty blocks for every block index of a
     * {@code j x j2} matrix (block size {@code i x i2}) that is missing from it. The keys of
     * the input are collected to the driver to determine which blocks are missing.
     *
     * @throws DMLRuntimeException if the input has more (or unexpected) block indexes than
     *                             the dimensions allow
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDWithEmptyBlocks(JavaSparkContext javaSparkContext, JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, long j, long j2, int i, int i2) throws DMLRuntimeException {
        ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> emptyBlocks = getEmptyBlocks(javaPairRDD.keys().collect(), j, j2, i, i2);
        if (emptyBlocks.isEmpty()) {
            return javaPairRDD;
        }
        return javaSparkContext.parallelizePairs(emptyBlocks).union(javaPairRDD);
    }

    /**
     * Computes the empty blocks needed to complete the given set of block indexes for a
     * {@code j x j2} matrix with block size {@code i x i2}.
     *
     * @param list existing block indexes (not modified; a sorted copy is used)
     * @return the missing blocks as empty sparse blocks of the correct size; empty if none
     * @throws DMLRuntimeException if there are more indexes than blocks, or an index does not
     *                             belong to the block grid
     */
    private static ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> getEmptyBlocks(List<MatrixIndexes> list, long j, long j2, int i, int i2) throws DMLRuntimeException {
        final long numRowBlocks = ceilDiv(j, i);
        final long numColBlocks = ceilDiv(j2, i2);
        final long expected = numRowBlocks * numColBlocks;
        if (expected == list.size()) {
            return new ArrayList<>();
        }
        if (expected < list.size()) {
            throw new DMLRuntimeException("Error: Incorrect number of indexes in ReblockSPInstruction:" + list.size());
        }
        // Sort a private copy: the collected list may be unmodifiable and belongs to the caller.
        List<MatrixIndexes> sorted = new ArrayList<>(list);
        Collections.sort(sorted);

        ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
        int pos = 0; // next unmatched entry in 'sorted'
        for (long r = 1; r <= numRowBlocks; r++) {
            for (long c = 1; c <= numColBlocks; c++) {
                boolean present = pos < sorted.size()
                        && sorted.get(pos).getRowIndex() == r
                        && sorted.get(pos).getColumnIndex() == c;
                if (present) {
                    pos++;
                } else {
                    ret.add(new Tuple2<>(new MatrixIndexes(r, c), new MatrixBlock(
                            UtilFunctions.computeBlockSize(j, r, i),
                            UtilFunctions.computeBlockSize(j2, c, i2), true)));
                }
            }
        }
        // Any unmatched existing index lies outside the block grid.
        if (pos != sorted.size()) {
            throw new DMLRuntimeException("Unexpected error while adding empty blocks");
        }
        return ret;
    }

    /**
     * Creates an RDD containing one empty sparse block (of the correct, possibly partial,
     * size) for every block index of the matrix described by {@code matrixCharacteristics}.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getEmptyBlockRDD(JavaSparkContext javaSparkContext, MatrixCharacteristics matrixCharacteristics) {
        final long rows = matrixCharacteristics.getRows();
        final long cols = matrixCharacteristics.getCols();
        final int brlen = matrixCharacteristics.getRowsPerBlock();
        final int bclen = matrixCharacteristics.getColsPerBlock();
        final long numRowBlocks = ceilDiv(rows, brlen);
        final long numColBlocks = ceilDiv(cols, bclen);

        List<Tuple2<MatrixIndexes, MatrixBlock>> blocks = new ArrayList<>();
        for (long r = 1; r <= numRowBlocks; r++) {
            for (long c = 1; c <= numColBlocks; c++) {
                blocks.add(new Tuple2<>(new MatrixIndexes(r, c), new MatrixBlock(
                        UtilFunctions.computeBlockSize(rows, r, brlen),
                        UtilFunctions.computeBlockSize(cols, c, bclen), true)));
            }
        }
        return javaSparkContext.parallelizePairs(blocks);
    }

    /**
     * Persists a deep copy of the given binary-cell RDD at {@link #DEFAULT_TMP}. If the input
     * is already persisted at that level it is returned unchanged (never {@code null}).
     */
    public static JavaPairRDD<MatrixIndexes, MatrixCell> cacheBinaryCellRDD(JavaPairRDD<MatrixIndexes, MatrixCell> javaPairRDD) {
        if (javaPairRDD.getStorageLevel().equals(DEFAULT_TMP)) {
            return javaPairRDD;
        }
        return javaPairRDD.mapToPair(new CopyBinaryCellFunction()).persist(DEFAULT_TMP);
    }

    /**
     * Persists a copy of the given binary-block RDD at {@link #DEFAULT_TMP}. If the input is
     * already persisted at that level it is returned unchanged (never {@code null}).
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> cacheBinaryBlockRDD(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD) {
        if (javaPairRDD.getStorageLevel().equals(DEFAULT_TMP)) {
            return javaPairRDD;
        }
        return javaPairRDD.mapValues(new CopyBlockFunction(false)).persist(DEFAULT_TMP);
    }

    /**
     * Computes dimensions (max cell position) and non-zeros of a binary-cell RDD.
     * Fails on an empty RDD (Spark {@code reduce} requires at least one element).
     */
    public static MatrixCharacteristics computeMatrixCharacteristics(JavaPairRDD<MatrixIndexes, MatrixCell> javaPairRDD) {
        return javaPairRDD.map(new AnalyzeCellMatrixCharacteristics()).reduce(new AggregateMatrixCharacteristics());
    }

    /**
     * Computes dimensions and non-zeros of a binary-block RDD with block size {@code i x i2}.
     * Fails on an empty RDD (Spark {@code reduce} requires at least one element).
     */
    public static MatrixCharacteristics computeMatrixCharacteristics(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, int i, int i2) {
        return javaPairRDD.map(new AnalyzeBlockMatrixCharacteristics(i, i2)).reduce(new AggregateMatrixCharacteristics());
    }

    /** Counts the cells of a binary-cell RDD whose value is non-zero. */
    public static long computeNNZFromCells(JavaPairRDD<MatrixIndexes, MatrixCell> javaPairRDD) {
        return javaPairRDD.values().filter(new Function<MatrixCell, Boolean>() {
            private static final long serialVersionUID = -6550193680630537857L;

            @Override
            public Boolean call(MatrixCell matrixCell) throws Exception {
                return matrixCell.getValue() != 0;
            }
        }).count();
    }

    /** Sums the per-block non-zero counts of a binary-block RDD (0 for an empty RDD). */
    public static long computeNNZFromBlocks(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD) {
        Long nnz = javaPairRDD.values().aggregate(0L, new Function2<Long, MatrixBlock, Long>() {
            private static final long serialVersionUID = 4907645080949985267L;

            @Override
            public Long call(Long acc, MatrixBlock matrixBlock) throws Exception {
                return acc + matrixBlock.getNonZeros();
            }
        }, new Function2<Long, Long, Long>() {
            private static final long serialVersionUID = 333028431986883739L;

            @Override
            public Long call(Long l, Long l2) throws Exception {
                return l + l2;
            }
        });
        return nnz;
    }

    /**
     * Returns ceil(a / b) for positive {@code b} using exact integer arithmetic. Integer
     * division truncates toward zero, so one is added only for a positive remainder.
     */
    private static long ceilDiv(long a, long b) {
        return a / b + ((a % b > 0) ? 1 : 0);
    }
}
