package org.apache.sysml.api;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.hadoop.yarn.util.Apps;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysml/api/MLOutput.class */
/**
 * Named matrix outputs held as binary-block RDDs ({@code MatrixIndexes -> MatrixBlock}) together
 * with their {@link MatrixCharacteristics}, plus conversions of those outputs into DataFrames,
 * string RDDs and {@link MLMatrix} objects.
 */
public class MLOutput {
    /** Binary-block RDD of each output variable, keyed by variable name. */
    HashMap<String, JavaPairRDD<MatrixIndexes, MatrixBlock>> _outputs;

    /** Dimensions and block sizes of each output variable, keyed by variable name. */
    private HashMap<String, MatrixCharacteristics> _outMetadata;

    /**
     * Spark function that turns one grouped row into a {@link Row} of the form
     * {@code [rowIndex, one field per requested column range]}. A grouped row is a row index plus
     * its partial rows, as produced by {@link ProjectRows} followed by {@code groupByKey}.
     *
     * <p>Each range is a 1-based, inclusive column interval. A single-column range yields a
     * {@code Double}; a wider range yields a {@link DenseVector}.
     */
    public static class ConvertDoubleArrayToRangeRows implements Function<Tuple2<Long, Iterable<Tuple2<Long, Double[]>>>, Row> {
        private static final long serialVersionUID = 4441184411670316972L;
        int bclen;
        long clen;
        ArrayList<Tuple2<String, Tuple2<Long, Long>>> range;

        /**
         * @param clen  number of columns of the matrix
         * @param bclen number of columns per block
         * @param range (output column name, (first column, last column)) pairs, 1-based and inclusive;
         *              output fields follow the list order
         */
        public ConvertDoubleArrayToRangeRows(long clen, int bclen, ArrayList<Tuple2<String, Tuple2<Long, Long>>> range) {
            this.bclen = bclen;
            this.clen = clen;
            this.range = range;
        }

        /**
         * @param groupedRow (row index, partial rows keyed by 1-based column-block index)
         * @return the row index as a {@code Double}, followed by one field per configured range
         * @throws Exception if the partial rows are inconsistent or a range is invalid
         */
        public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> groupedRow) throws Exception {
            double[] row = assembleRow(groupedRow._2(), this.clen, this.bclen);
            Object[] fields = new Object[this.range.size() + 1];
            fields[0] = Double.valueOf(groupedRow._1().doubleValue());
            int pos = 1;
            for (Tuple2<String, Tuple2<Long, Long>> column : this.range) {
                long first = column._2()._1().longValue();
                long last = column._2()._2().longValue();
                if (last < first) {
                    throw new Exception("Incorrect range:" + last + "<" + first);
                }
                // Without this check an out-of-bounds range surfaces as an ArrayIndexOutOfBoundsException.
                if (first < 1 || last > row.length) {
                    throw new Exception("Range [" + first + ", " + last + "] for column " + column._1()
                            + " is outside the valid column range [1, " + row.length + "]");
                }
                if (first == last) {
                    fields[pos] = Double.valueOf(row[(int) (first - 1)]);
                } else {
                    double[] slice = new double[(int) (last - first + 1)];
                    System.arraycopy(row, (int) (first - 1), slice, 0, slice.length);
                    fields[pos] = new DenseVector(slice);
                }
                pos++;
            }
            return RowFactory.create(fields);
        }
    }

    /**
     * Spark function that turns one grouped row into a {@link Row}. A grouped row is a row index
     * plus its partial rows, as produced by {@link ProjectRows} followed by {@code groupByKey}.
     *
     * <p>If {@code outputVector} is false, the row is {@code [rowIndex, c1, c2, ..., cN]} with one
     * {@code Double} per matrix column. Otherwise it is {@code [rowIndex, DenseVector(c1..cN)]}.
     */
    public static class ConvertDoubleArrayToRows implements Function<Tuple2<Long, Iterable<Tuple2<Long, Double[]>>>, Row> {
        private static final long serialVersionUID = 4441184411670316972L;
        int bclen;
        long clen;
        boolean outputVector;

        /**
         * @param clen         number of columns of the matrix
         * @param bclen        number of columns per block
         * @param outputVector whether to emit the values as a single {@link DenseVector} field
         */
        public ConvertDoubleArrayToRows(long clen, int bclen, boolean outputVector) {
            this.bclen = bclen;
            this.clen = clen;
            this.outputVector = outputVector;
        }

        /**
         * @param groupedRow (row index, partial rows keyed by 1-based column-block index)
         * @return the row index as a {@code Double}, followed by the row values
         * @throws Exception if the partial rows are inconsistent
         */
        public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> groupedRow) throws Exception {
            double[] row = assembleRow(groupedRow._2(), this.clen, this.bclen);
            Double rowId = Double.valueOf(groupedRow._1().doubleValue());
            if (this.outputVector) {
                return RowFactory.create(rowId, new DenseVector(row));
            }
            // Object[] (not Double[]) so the array has the same element type in both output modes.
            Object[] fields = new Object[row.length + 1];
            fields[0] = rowId;
            for (int k = 0; k < row.length; k++) {
                fields[k + 1] = Double.valueOf(row[k]);
            }
            return RowFactory.create(fields);
        }
    }

    /**
     * Spark function that splits a matrix block into one partial row per local row. Each output
     * pair has the global row index as its key. Its value is (column-block index of the input
     * block, that row's values within the block).
     *
     * <p>Row indexes are {@code (blockRowIndex - 1) * brlen + localRow}. That makes them 0-based,
     * assuming 1-based block indexes as the column-block loops in this class expect.
     */
    public static class ProjectRows implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, Tuple2<Long, Double[]>> {
        private static final long serialVersionUID = -4792573268900472749L;
        long rlen;
        long clen;
        int brlen;
        int bclen;

        /**
         * @param rlen  number of rows of the matrix
         * @param clen  number of columns of the matrix
         * @param brlen number of rows per block
         * @param bclen number of columns per block
         */
        public ProjectRows(long rlen, long clen, int brlen, int bclen) {
            this.rlen = rlen;
            this.clen = clen;
            this.brlen = brlen;
            this.bclen = bclen;
        }

        /**
         * @param kv (block indexes, block)
         * @return one (row index, (column-block index, row values)) pair per row of the block
         */
        public Iterable<Tuple2<Long, Tuple2<Long, Double[]>>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
            MatrixIndexes ix = kv._1();
            MatrixBlock blk = kv._2();
            long blockRowIndex = ix.getRowIndex();
            long blockColIndex = ix.getColumnIndex();
            // Boundary blocks may be smaller than brlen x bclen.
            int localRows = UtilFunctions.computeBlockSize(this.rlen, blockRowIndex, this.brlen);
            int localCols = UtilFunctions.computeBlockSize(this.clen, blockColIndex, this.bclen);
            // Rows are blocked by brlen; using bclen here gives wrong (and colliding) row indexes
            // whenever brlen != bclen.
            long startRowIndex = (blockRowIndex - 1) * this.brlen;

            ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>> partialRows =
                    new ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>>(Math.max(localRows, 0));
            for (int i = 0; i < localRows; i++) {
                Double[] values = new Double[localCols];
                for (int j = 0; j < localCols; j++) {
                    values[j] = Double.valueOf(blk.getValue(i, j));
                }
                partialRows.add(new Tuple2<Long, Tuple2<Long, Double[]>>(
                        Long.valueOf(startRowIndex + i),
                        new Tuple2<Long, Double[]>(Long.valueOf(blockColIndex), values)));
            }
            return partialRows;
        }
    }

    /**
     * @param hashMap  binary-block RDD of each output variable, keyed by variable name
     * @param hashMap2 matrix characteristics of each output variable, keyed by variable name
     */
    public MLOutput(HashMap<String, JavaPairRDD<MatrixIndexes, MatrixBlock>> hashMap, HashMap<String, MatrixCharacteristics> hashMap2) {
        this._outputs = hashMap;
        this._outMetadata = hashMap2;
    }

    /**
     * Returns the binary-block RDD of an output variable.
     *
     * @throws DMLRuntimeException if {@code varName} is not an output
     */
    public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
        if (!this._outputs.containsKey(varName)) {
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
        return this._outputs.get(varName);
    }

    /**
     * Returns the matrix characteristics of an output variable.
     *
     * @throws DMLRuntimeException if {@code varName} is not an output
     */
    public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
        if (!this._outputs.containsKey(varName)) {
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
        return this._outMetadata.get(varName);
    }

    /**
     * Converts an output variable into a DataFrame via
     * {@code RDDConverterUtilsExt.binaryBlockToDataFrame}.
     *
     * @throws DMLRuntimeException if {@code sqlContext} is null or {@code varName} is not an output
     */
    public DataFrame getDF(SQLContext sqlContext, String varName) throws DMLRuntimeException {
        if (sqlContext == null) {
            throw new DMLRuntimeException("SQLContext is not created.");
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getBinaryBlockedRDD(varName);
        if (blocks == null) {
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
        return RDDConverterUtilsExt.binaryBlockToDataFrame(blocks, this._outMetadata.get(varName), sqlContext);
    }

    /**
     * Converts an output variable into a DataFrame. If {@code outputVector} is true this uses
     * {@code RDDConverterUtilsExt.binaryBlockToVectorDataFrame}; otherwise it behaves like
     * {@link #getDF(SQLContext, String)}.
     *
     * @throws DMLRuntimeException if {@code sqlContext} is null or {@code varName} is not an output
     */
    public DataFrame getDF(SQLContext sqlContext, String varName, boolean outputVector) throws DMLRuntimeException {
        if (sqlContext == null) {
            throw new DMLRuntimeException("SQLContext is not created.");
        }
        if (!outputVector) {
            return getDF(sqlContext, varName);
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getBinaryBlockedRDD(varName);
        if (blocks == null) {
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
        return RDDConverterUtilsExt.binaryBlockToVectorDataFrame(blocks, this._outMetadata.get(varName), sqlContext);
    }

    /**
     * Converts an output variable into a DataFrame with a row-index column plus one column per
     * requested range.
     *
     * <p>Each entry of {@code range} maps an output column name to a 1-based, inclusive
     * (first, last) column interval. Single-column intervals become {@code DoubleType} fields;
     * wider intervals become {@code VectorUDT} fields.
     *
     * @throws DMLRuntimeException if {@code sqlContext} is null, {@code varName} is not an output,
     *                             the column count is unknown, or a range is invalid
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // schema list is left raw to avoid naming StructField
    public DataFrame getDF(SQLContext sqlContext, String varName, HashMap<String, Tuple2<Long, Long>> range) throws DMLRuntimeException {
        if (sqlContext == null) {
            throw new DMLRuntimeException("SQLContext is not created.");
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getBinaryBlockedRDD(varName);
        if (blocks == null) {
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
        MatrixCharacteristics mc = this._outMetadata.get(varName);
        long rlen = mc.getRows();
        long clen = mc.getCols();
        int brlen = mc.getRowsPerBlock();
        int bclen = mc.getColsPerBlock();
        // Compare the long directly: an (int) cast would wrap for very large column counts.
        if (clen <= 0) {
            throw new DMLRuntimeException("Output dimensions unknown after executing the script and hence cannot create the dataframe");
        }

        // Snapshot the ranges into a list so the row converter and the schema use the same column
        // order. Validate them here so bad ranges fail fast instead of inside Spark tasks.
        ArrayList<Tuple2<String, Tuple2<Long, Long>>> columnRanges = new ArrayList<Tuple2<String, Tuple2<Long, Long>>>();
        for (Map.Entry<String, Tuple2<Long, Long>> entry : range.entrySet()) {
            validateRange(entry.getKey(), entry.getValue(), clen);
            columnRanges.add(new Tuple2<String, Tuple2<Long, Long>>(entry.getKey(), entry.getValue()));
        }

        JavaRDD<Row> rows = blocks
                .flatMapToPair(new ProjectRows(rlen, clen, brlen, bclen))
                .groupByKey()
                .map(new ConvertDoubleArrayToRangeRows(clen, bclen, columnRanges));

        ArrayList fields = new ArrayList();
        // NOTE(review): the index column name comes from Hadoop's Apps.ID constant (presumably a
        // decompiler-inlined literal); consider replacing it with a local constant.
        fields.add(DataTypes.createStructField(Apps.ID, DataTypes.DoubleType, false));
        for (Tuple2<String, Tuple2<Long, Long>> column : columnRanges) {
            if (column._2()._1().longValue() != column._2()._2().longValue()) {
                fields.add(DataTypes.createStructField(column._1(), new VectorUDT(), false));
            } else {
                fields.add(DataTypes.createStructField(column._1(), DataTypes.DoubleType, false));
            }
        }
        return sqlContext.createDataFrame(rows.rdd(), DataTypes.createStructType(fields));
    }

    /**
     * Converts an output variable into an RDD of strings. Only the {@code "text"} format is
     * supported.
     *
     * @throws DMLRuntimeException if the format is not {@code "text"} (including null) or
     *                             {@code varName} is not an output
     */
    public JavaRDD<String> getStringRDD(String varName, String format) throws DMLRuntimeException {
        // Constant-first equals so a null format yields the error below instead of an NPE.
        if (!"text".equals(format)) {
            throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
        }
        return RDDConverterUtilsExt.binaryBlockToStringRDD(getBinaryBlockedRDD(varName), getMatrixCharacteristics(varName), format);
    }

    /**
     * Wraps an output variable into an {@link MLMatrix}. The binary blocks are mapped with
     * {@link GetMLBlock} into a DataFrame using {@code MLBlock.getDefaultSchemaForBinaryBlock()}.
     *
     * @throws DMLRuntimeException if either context is null or {@code varName} is not an output
     */
    public MLMatrix getMLMatrix(MLContext mlContext, SQLContext sqlContext, String varName) throws DMLRuntimeException {
        if (sqlContext == null) {
            throw new DMLRuntimeException("SQLContext is not created.");
        }
        if (mlContext == null) {
            throw new DMLRuntimeException("MLContext is not created.");
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getBinaryBlockedRDD(varName);
        if (blocks == null) {
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
        DataFrame blockDF = sqlContext.createDataFrame(blocks.map(new GetMLBlock()).rdd(), MLBlock.getDefaultSchemaForBinaryBlock());
        return new MLMatrix(blockDF, getMatrixCharacteristics(varName), mlContext);
    }

    /**
     * Reassembles one full matrix row from its partial rows. Every column block
     * {@code 1..ceil(clen / bclen)} must be present exactly once, with exactly
     * {@code UtilFunctions.computeBlockSize(clen, j, bclen)} values.
     *
     * @param partialRows (1-based column-block index, values) pairs belonging to one row
     * @param clen        number of columns of the matrix
     * @param bclen       number of columns per block
     * @return dense row of length {@code clen}
     * @throws Exception if the dimensions are invalid, or a column block is missing, duplicated,
     *                   out of range, or has the wrong length
     */
    private static double[] assembleRow(Iterable<Tuple2<Long, Double[]>> partialRows, long clen, int bclen) throws Exception {
        if (clen < 0 || clen > Integer.MAX_VALUE || bclen <= 0) {
            throw new Exception("Invalid matrix dimensions for row assembly: clen=" + clen + ", bclen=" + bclen);
        }
        HashMap<Long, Double[]> blocksByColIndex = new HashMap<Long, Double[]>();
        for (Tuple2<Long, Double[]> partialRow : partialRows) {
            if (blocksByColIndex.put(partialRow._1(), partialRow._2()) != null) {
                throw new Exception("Duplicate partial row for column block index " + partialRow._1());
            }
        }

        long numColBlocks = (clen + bclen - 1) / bclen;
        double[] row = new double[(int) clen];
        // Iterate over the expected block count (not the map size) so missing trailing blocks are caught too.
        for (long j = 1; j <= numColBlocks; j++) {
            Double[] block = blocksByColIndex.get(Long.valueOf(j));
            if (block == null) {
                throw new Exception("The block for column index " + j + " is missing. Make sure the last instruction is not returning empty blocks");
            }
            int expectedLength = UtilFunctions.computeBlockSize(clen, j, bclen);
            if (block.length != expectedLength) {
                throw new Exception("Incorrect double array provided by ProjectRows");
            }
            int offset = (int) ((j - 1) * bclen);
            for (int k = 0; k < expectedLength; k++) {
                row[offset + k] = block[k].doubleValue();
            }
        }
        // All expected blocks were found above, so any surplus entries have out-of-range indexes.
        if (blocksByColIndex.size() != numColBlocks) {
            throw new Exception("Column block index out of range: expected " + numColBlocks
                    + " column blocks but found " + blocksByColIndex.size());
        }
        return row;
    }

    /**
     * Checks that a requested column range is non-null, ordered, and lies within {@code [1, clen]}.
     *
     * @throws DMLRuntimeException if the range is invalid
     */
    private static void validateRange(String name, Tuple2<Long, Long> bounds, long clen) throws DMLRuntimeException {
        if (bounds == null || bounds._1() == null || bounds._2() == null) {
            throw new DMLRuntimeException("Missing range bounds for column " + name);
        }
        long first = bounds._1().longValue();
        long last = bounds._2().longValue();
        if (last < first) {
            throw new DMLRuntimeException("Incorrect range:" + last + "<" + first);
        }
        if (first < 1 || last > clen) {
            throw new DMLRuntimeException("Range [" + first + ", " + last + "] for column " + name
                    + " is outside the valid column range [1, " + clen + "]");
        }
    }
}
