package org.apache.sysml.runtime.instructions.spark.utils;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.hadoop.io.Text;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.api.MLOutput;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.functions.ConvertMatrixBlockToIJVLines;
import org.apache.sysml.runtime.io.IOUtilFunctions;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.ReblockBuffer;
import org.apache.sysml.runtime.util.FastStringTokenizer;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;
import scala.collection.JavaConversions;

/* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.class */
public class RDDConverterUtilsExt {

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$AddRowID.class */
    public static class AddRowID implements Function<Tuple2<Row, Long>, Row> {
        private static final long serialVersionUID = -3733816995375745659L;

        public Row call(Tuple2<Row, Long> tuple2) throws Exception {
            int length = ((Row) tuple2._1).length();
            Object[] objArr = new Object[length + 1];
            for (int i = 0; i < length; i++) {
                objArr[i] = ((Row) tuple2._1).get(i);
            }
            objArr[length] = new Double(((Long) tuple2._2).longValue() + 1);
            return RowFactory.create(objArr);
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$CSVToBinaryBlockFunction.class */
    private static class CSVToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Text, Long>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1501589201971233542L;
        private RowToBinaryBlockFunctionHelper helper;

        public CSVToBinaryBlockFunction(MatrixCharacteristics matrixCharacteristics, String str, boolean z, double d) {
            this.helper = null;
            this.helper = new RowToBinaryBlockFunctionHelper(matrixCharacteristics, str, z, d);
        }

        public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<Text, Long>> it) throws Exception {
            return this.helper.convertToBinaryBlock(it, RDDConverterTypes.TEXT_TO_DOUBLEARR);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$DataFrameAnalysisFunction.class */
    public static class DataFrameAnalysisFunction implements Function<Row, Row> {
        private static final long serialVersionUID = 5705371332119770215L;
        private RowAnalysisFunctionHelper helper;
        boolean isVectorBasedRDD;

        public DataFrameAnalysisFunction(Accumulator<Double> accumulator, boolean z) {
            this.helper = null;
            this.helper = new RowAnalysisFunctionHelper(accumulator);
            this.isVectorBasedRDD = z;
        }

        public Row call(Row row) throws Exception {
            return this.isVectorBasedRDD ? this.helper.analyzeVector(row) : this.helper.analyzeRow(row);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$DataFrameToBinaryBlockFunction.class */
    public static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row, Long>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 653447740362447236L;
        private RowToBinaryBlockFunctionHelper helper;
        boolean isVectorBasedDF;

        public DataFrameToBinaryBlockFunction(MatrixCharacteristics matrixCharacteristics, boolean z) {
            this.helper = null;
            this.helper = new RowToBinaryBlockFunctionHelper(matrixCharacteristics);
            this.isVectorBasedDF = z;
        }

        public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<Row, Long>> it) throws Exception {
            return this.isVectorBasedDF ? this.helper.convertToBinaryBlock(it, RDDConverterTypes.VECTOR_TO_DOUBLEARR) : this.helper.convertToBinaryBlock(it, RDDConverterTypes.ROW_TO_DOUBLEARR);
        }
    }

    /**
     * Serializable helper that turns partitions of individual (i, j, v) cells into binary blocks.
     *
     * <p>Cells are appended to a {@link ReblockBuffer}; whenever the buffer reaches capacity it is
     * flushed into {@link MatrixBlock}s, and a final flush happens once the partition is exhausted.
     */
    public static class IJVToBinaryBlockFunctionHelper implements Serializable {
        private static final long serialVersionUID = -7952801318564745821L;
        /** Upper bound on the number of cells held in the reblock buffer at once. */
        private static final int BUFFER_SIZE = 4000000;
        private int _bufflen = -1;
        private long _rlen = -1L;
        private long _clen = -1L;
        private int _brlen = -1;
        private int _bclen = -1;

        /**
         * @param matrixCharacteristics dimensions and block sizes of the target matrix
         * @throws DMLRuntimeException if the dimensions are not known
         */
        public IJVToBinaryBlockFunctionHelper(MatrixCharacteristics matrixCharacteristics) throws DMLRuntimeException {
            if (!matrixCharacteristics.dimsKnown()) {
                throw new DMLRuntimeException("The dimensions need to be known in given MatrixCharacteristics for given input RDD");
            }
            this._rlen = matrixCharacteristics.getRows();
            this._clen = matrixCharacteristics.getCols();
            this._brlen = matrixCharacteristics.getRowsPerBlock();
            this._bclen = matrixCharacteristics.getColsPerBlock();
            // Saturate instead of overflowing: rlen * clen can exceed Long.MAX_VALUE for very large
            // (typically ultra-sparse) matrices, which would otherwise yield a negative buffer length.
            long numCells = (this._clen > 0 && this._rlen > Long.MAX_VALUE / this._clen)
                    ? Long.MAX_VALUE
                    : this._rlen * this._clen;
            this._bufflen = (int) Math.min(numCells, BUFFER_SIZE);
        }

        /**
         * Parses one space-separated "i j v" text line into a cell.
         *
         * @param text the input line
         * @return the parsed cell, or {@code null} for lines starting with '%' (skipped)
         */
        public Tuple2<MatrixIndexes, MatrixCell> textToMatrixCell(Text text) {
            String line = text.toString();
            if (line.startsWith("%")) {
                return null;
            }
            FastStringTokenizer tokenizer = new FastStringTokenizer(' ');
            tokenizer.reset(line);
            long row = tokenizer.nextLong();
            long col = tokenizer.nextLong();
            double value = tokenizer.nextDouble();
            return new Tuple2<>(new MatrixIndexes(row, col), new MatrixCell(value));
        }

        /** Wraps an MLlib {@link MatrixEntry} as an (indexes, cell) pair. */
        public Tuple2<MatrixIndexes, MatrixCell> matrixEntryToMatrixCell(MatrixEntry matrixEntry) {
            return new Tuple2<>(new MatrixIndexes(matrixEntry.i(), matrixEntry.j()), new MatrixCell(matrixEntry.value()));
        }

        /**
         * Converts all cells of a partition into binary blocks.
         *
         * @param obj an {@link Iterator} over {@link MatrixEntry} or {@link Text} records
         * @param rDDConverterTypes which of the two record kinds {@code obj} yields
         * @return the produced blocks; the same block index may appear more than once, since each
         *         buffer flush emits its own blocks
         * @throws Exception for an unsupported converter type or a failing flush
         */
        Iterable<Tuple2<MatrixIndexes, MatrixBlock>> convertToBinaryBlock(Object obj, RDDConverterTypes rDDConverterTypes) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> arrayList = new ArrayList<>();
            ReblockBuffer reblockBuffer = new ReblockBuffer(this._bufflen, this._rlen, this._clen, this._brlen, this._bclen);
            Iterator<?> it = (Iterator<?>) obj;
            while (it.hasNext()) {
                Tuple2<MatrixIndexes, MatrixCell> cell;
                switch (rDDConverterTypes) {
                    case MATRIXENTRY_TO_MATRIXCELL:
                        cell = matrixEntryToMatrixCell((MatrixEntry) it.next());
                        break;
                    case TEXT_TO_MATRIX_CELL:
                        cell = textToMatrixCell((Text) it.next());
                        break;
                    default:
                        throw new Exception("Invalid converter for IJV data:" + rDDConverterTypes.toString());
                }
                if (cell == null) {
                    continue; // skipped comment line
                }
                if (reblockBuffer.getSize() >= reblockBuffer.getCapacity()) {
                    flushBufferToList(reblockBuffer, arrayList);
                }
                reblockBuffer.appendCell(cell._1().getRowIndex(), cell._1().getColumnIndex(), cell._2().getValue());
            }
            flushBufferToList(reblockBuffer, arrayList);
            return arrayList;
        }

        /** Flushes the buffered cells into binary blocks and appends them to {@code arrayList}. */
        private void flushBufferToList(ReblockBuffer reblockBuffer, ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> arrayList) throws IOException, DMLRuntimeException {
            ArrayList<IndexedMatrixValue> blocks = new ArrayList<>();
            reblockBuffer.flushBufferToBinaryBlocks(blocks);
            arrayList.addAll(SparkUtils.fromIndexedMatrixBlock(blocks));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$MatrixEntryToBinaryBlockFunction.class */
    public static class MatrixEntryToBinaryBlockFunction implements PairFlatMapFunction<Iterator<MatrixEntry>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 4907483236186747224L;
        private IJVToBinaryBlockFunctionHelper helper;

        public MatrixEntryToBinaryBlockFunction(MatrixCharacteristics matrixCharacteristics) throws DMLRuntimeException {
            this.helper = null;
            this.helper = new IJVToBinaryBlockFunctionHelper(matrixCharacteristics);
        }

        public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<MatrixEntry> it) throws Exception {
            return this.helper.convertToBinaryBlock(it, RDDConverterTypes.MATRIXENTRY_TO_MATRIXCELL);
        }
    }

    /* loaded from: input_file:org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt$RDDConverterTypes.class */
    public enum RDDConverterTypes {
        TEXT_TO_MATRIX_CELL,
        MATRIXENTRY_TO_MATRIXCELL,
        TEXT_TO_DOUBLEARR,
        ROW_TO_DOUBLEARR,
        VECTOR_TO_DOUBLEARR
    }

    /**
     * Serializable helper that counts the non-zero values of text lines, DataFrame rows or vector
     * rows and adds each count to a shared {@link Accumulator}. Every analyze method returns its
     * input unchanged so it can be used inside a pass-through map.
     */
    public static class RowAnalysisFunctionHelper implements Serializable {
        private static final long serialVersionUID = 2310303223289674477L;
        private Accumulator<Double> _aNnz = null;
        private String _delim = null;

        /** Creates a helper for row/vector analysis (no text delimiter). */
        public RowAnalysisFunctionHelper(Accumulator<Double> accumulator) {
            this(accumulator, null);
        }

        /** Creates a helper that splits text lines on {@code str}. */
        public RowAnalysisFunctionHelper(Accumulator<Double> accumulator, String str) {
            _aNnz = accumulator;
            _delim = str;
        }

        /**
         * Counts tokens of a delimited line that are neither empty nor the literals "0" / "0.0".
         *
         * @return the line as a string
         */
        public String analyzeText(Text text) throws Exception {
            String line = text.toString();
            String[] tokens = IOUtilFunctions.split(line, _delim);
            long nonZeros = 0;
            for (int k = 0; k < tokens.length; k++) {
                String token = tokens[k];
                boolean zeroOrEmpty = token.isEmpty() || token.equals("0") || token.equals("0.0");
                if (!zeroOrEmpty) {
                    nonZeros++;
                }
            }
            _aNnz.add(Double.valueOf(nonZeros));
            return line;
        }

        /**
         * Counts the columns of {@code row} whose numeric value differs from
         * {@code DataExpression.DEFAULT_DELIM_FILL_VALUE}.
         *
         * @throws Exception if {@code row} is null or a column is not numeric
         */
        public Row analyzeRow(Row row) throws Exception {
            if (row == null) {
                throw new Exception("Error while analyzing row");
            }
            long nonZeros = 0;
            int numCols = row.length();
            for (int col = 0; col < numCols; col++) {
                double value = RowToBinaryBlockFunctionHelper.getDoubleValue(row, col);
                if (value != DataExpression.DEFAULT_DELIM_FILL_VALUE) {
                    nonZeros++;
                }
            }
            _aNnz.add(Double.valueOf(nonZeros));
            return row;
        }

        /**
         * Counts the entries of the {@link Vector} in the first column of {@code row} that differ
         * from {@code DataExpression.DEFAULT_DELIM_FILL_VALUE}.
         */
        public Row analyzeVector(Row row) {
            Vector vec = (Vector) row.get(0);
            long nonZeros = 0;
            int size = vec.size();
            for (int pos = 0; pos < size; pos++) {
                if (vec.apply(pos) != DataExpression.DEFAULT_DELIM_FILL_VALUE) {
                    nonZeros++;
                }
            }
            _aNnz.add(Double.valueOf(nonZeros));
            return row;
        }
    }

    /**
     * Serializable helper that converts row-oriented records (delimited text lines, DataFrame rows
     * or vector rows) paired with their 0-based row position into binary matrix blocks.
     *
     * <p>Rows are assumed to arrive grouped by row block within a partition: all column blocks of
     * the current row block are kept open and flushed when a row from a different row block
     * appears, and once more at the end of the partition.
     */
    public static class RowToBinaryBlockFunctionHelper implements Serializable {
        private static final long serialVersionUID = -4948430402942717043L;
        private long _rlen = -1L;
        private long _clen = -1L;
        private int _brlen = -1;
        private int _bclen = -1;
        private String _delim = null;
        private boolean _fill = false;
        private double _fillValue = DataExpression.DEFAULT_DELIM_FILL_VALUE;
        boolean emptyFound = false;

        /** Creates a helper for DataFrame/vector input (no text parsing options). */
        public RowToBinaryBlockFunctionHelper(MatrixCharacteristics matrixCharacteristics) {
            this._rlen = matrixCharacteristics.getRows();
            this._clen = matrixCharacteristics.getCols();
            this._brlen = matrixCharacteristics.getRowsPerBlock();
            this._bclen = matrixCharacteristics.getColsPerBlock();
        }

        /**
         * Creates a helper for delimited text input.
         *
         * @param str field delimiter
         * @param z whether empty fields are replaced by {@code d}
         * @param d value used for empty fields when filling is enabled
         */
        public RowToBinaryBlockFunctionHelper(MatrixCharacteristics matrixCharacteristics, String str, boolean z, double d) {
            this(matrixCharacteristics);
            this._delim = str;
            this._fill = z;
            this._fillValue = d;
        }

        /**
         * Splits a delimited line into doubles. Empty fields receive the fill value; when filling
         * is disabled they additionally set {@link #emptyFound} so that the caller can raise the
         * descriptive empty-field error instead of a bare NumberFormatException.
         */
        public double[] textToDoubleArray(Text text) {
            String[] parts = IOUtilFunctions.split(text.toString(), this._delim);
            double[] values = new double[parts.length];
            for (int i = 0; i < parts.length; i++) {
                String part = parts[i];
                if (part.isEmpty()) {
                    this.emptyFound |= !this._fill;
                    values[i] = this._fillValue;
                } else {
                    values[i] = Double.parseDouble(part);
                }
            }
            return values;
        }

        /** Reads every column of {@code row} as a double (see {@link #getDoubleValue}). */
        public double[] rowToDoubleArray(Row row) throws Exception {
            double[] values = new double[row.length()];
            for (int i = 0; i < values.length; i++) {
                values[i] = getDoubleValue(row, i);
            }
            return values;
        }

        /** Returns the dense values of {@code vector}. */
        public double[] vectorToDoubleArray(Vector vector) throws Exception {
            return vector.toDense().values();
        }

        /**
         * Converts all (record, 0-based row position) tuples of a partition into binary blocks.
         *
         * @param obj an {@link Iterator} over {@code Tuple2<record, Long>}
         * @param rDDConverterTypes how each record is turned into a row of doubles
         * @return the produced (index, block) pairs
         * @throws Exception for an unsupported converter, unparsable values, empty CSV fields
         *         without fill, or rows with fewer values than the matrix has columns
         */
        public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> convertToBinaryBlock(Object obj, RDDConverterTypes rDDConverterTypes) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> arrayList = new ArrayList<>();
            int numColBlocks = numColumnBlocks();
            MatrixIndexes[] matrixIndexesArr = new MatrixIndexes[numColBlocks];
            MatrixBlock[] matrixBlockArr = new MatrixBlock[numColBlocks];
            Iterator<?> it = (Iterator<?>) obj;
            while (it.hasNext()) {
                Tuple2<?, ?> tuple = (Tuple2<?, ?>) it.next();
                long rowIndex = ((Long) tuple._2()).longValue() + 1; // 1-based matrix row
                long rowBlockIndex = UtilFunctions.computeBlockIndex(rowIndex, this._brlen);
                int rowInBlock = UtilFunctions.computeCellInBlock(rowIndex, this._brlen);

                // Starting a new row block: emit the previous one and allocate fresh blocks.
                if (matrixIndexesArr[0] == null || matrixIndexesArr[0].getRowIndex() != rowBlockIndex) {
                    if (matrixIndexesArr[0] != null) {
                        flushBlocksToList(matrixIndexesArr, matrixBlockArr, arrayList);
                    }
                    createBlocks(rowIndex, UtilFunctions.computeBlockSize(this._rlen, rowBlockIndex, this._brlen), matrixIndexesArr, matrixBlockArr);
                }

                this.emptyFound = false;
                double[] values = toDoubleArray(tuple._1(), rDDConverterTypes);
                if (rDDConverterTypes == RDDConverterTypes.TEXT_TO_DOUBLEARR) {
                    IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(((Text) tuple._1()).toString(), this._fill, this.emptyFound);
                }
                if (values.length < this._clen) {
                    throw new Exception("Row " + rowIndex + " has " + values.length
                            + " values, but the matrix has " + this._clen + " columns.");
                }

                // Scatter the row's values across the column blocks of the current row block.
                int pos = 0;
                for (int cb = 1; cb <= numColBlocks; cb++) {
                    int blockCols = UtilFunctions.computeBlockSize(this._clen, cb, this._bclen);
                    for (int c = 0; c < blockCols; c++) {
                        matrixBlockArr[cb - 1].appendValue(rowInBlock, c, values[pos++]);
                    }
                }
            }
            flushBlocksToList(matrixIndexesArr, matrixBlockArr, arrayList);
            return arrayList;
        }

        /** Dispatches one record to the parser selected by {@code type}. */
        private double[] toDoubleArray(Object record, RDDConverterTypes type) throws Exception {
            switch (type) {
                case TEXT_TO_DOUBLEARR:
                    return textToDoubleArray((Text) record);
                case ROW_TO_DOUBLEARR:
                    return rowToDoubleArray((Row) record);
                case VECTOR_TO_DOUBLEARR:
                    return vectorToDoubleArray((Vector) ((Row) record).get(0));
                default:
                    throw new Exception("Invalid converter for row-based data:" + type.toString());
            }
        }

        /**
         * Number of column blocks. Uses floating-point division: {@code _clen / _bclen} on
         * integers truncates before {@code ceil}, which lost the trailing partial block.
         */
        private int numColumnBlocks() {
            return (int) Math.ceil((double) this._clen / this._bclen);
        }

        /** Allocates one empty block per column block for the row block containing {@code rowIndex}. */
        private void createBlocks(long rowIndex, int rowsInBlock, MatrixIndexes[] matrixIndexesArr, MatrixBlock[] matrixBlockArr) {
            long rowBlockIndex = UtilFunctions.computeBlockIndex(rowIndex, this._brlen);
            int numColBlocks = numColumnBlocks();
            for (int cb = 1; cb <= numColBlocks; cb++) {
                int blockCols = UtilFunctions.computeBlockSize(this._clen, cb, this._bclen);
                matrixIndexesArr[cb - 1] = new MatrixIndexes(rowBlockIndex, cb);
                matrixBlockArr[cb - 1] = new MatrixBlock(rowsInBlock, blockCols, false);
            }
        }

        /** Appends all allocated blocks to {@code arrayList} and lets each pick its representation. */
        private void flushBlocksToList(MatrixIndexes[] matrixIndexesArr, MatrixBlock[] matrixBlockArr, ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> arrayList) throws DMLRuntimeException {
            for (int i = 0; i < matrixIndexesArr.length; i++) {
                if (matrixBlockArr[i] != null) {
                    arrayList.add(new Tuple2<>(matrixIndexesArr[i], matrixBlockArr[i]));
                    matrixBlockArr[i].examSparsity();
                }
            }
        }

        /**
         * Reads column {@code i} of {@code row} as a double, falling back to parsing its string form.
         *
         * @throws Exception if the value is neither a double nor parsable as one (cause preserved)
         */
        public static double getDoubleValue(Row row, int i) throws Exception {
            try {
                return row.getDouble(i);
            } catch (Exception e) {
                try {
                    return Double.parseDouble(row.get(i).toString());
                } catch (Exception e2) {
                    throw new Exception("Only double types are supported as input to SystemML. The input argument is '" + row.get(i) + "'", e2);
                }
            }
        }
    }

    /**
     * Converts an MLlib {@link CoordinateMatrix} into a binary-block RDD.
     *
     * @param javaSparkContext context used to create empty blocks if requested
     * @param coordinateMatrix the input entries
     * @param matrixCharacteristics dimensions and block sizes of the output (must be known)
     * @param z when true and the matrix might have empty blocks, empty blocks from
     *          {@code SparkUtils.getEmptyBlockRDD} are unioned in before merging
     * @return blocks merged by key
     * @throws DMLRuntimeException if the dimensions are unknown
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(JavaSparkContext javaSparkContext, CoordinateMatrix coordinateMatrix, MatrixCharacteristics matrixCharacteristics, boolean z) throws DMLRuntimeException {
        JavaRDD<MatrixEntry> entries = coordinateMatrix.entries().toJavaRDD();
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = entries.mapPartitionsToPair(new MatrixEntryToBinaryBlockFunction(matrixCharacteristics));
        boolean addEmptyBlocks = z && matrixCharacteristics.mightHaveEmptyBlocks();
        JavaPairRDD<MatrixIndexes, MatrixBlock> allBlocks = addEmptyBlocks
                ? blocks.union(SparkUtils.getEmptyBlockRDD(javaSparkContext, matrixCharacteristics))
                : blocks;
        return RDDAggregateUtils.mergeByKey(allBlocks);
    }

    /**
     * {@link SparkContext} variant of
     * {@code coordinateMatrixToBinaryBlock(JavaSparkContext, CoordinateMatrix, MatrixCharacteristics, boolean)}.
     *
     * @param z whether empty blocks may be added for block positions without entries; this flag
     *          is forwarded unchanged (it was previously ignored and always treated as true)
     * @throws DMLRuntimeException if the dimensions are unknown
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(SparkContext sparkContext, CoordinateMatrix coordinateMatrix, MatrixCharacteristics matrixCharacteristics, boolean z) throws DMLRuntimeException {
        return coordinateMatrixToBinaryBlock(new JavaSparkContext(sparkContext), coordinateMatrix, matrixCharacteristics, z);
    }

    /**
     * Serializes binary blocks to lines of text. Only the text-cell format
     * ({@code DataExpression.FORMAT_TYPE_VALUE_TEXT}) is supported.
     *
     * @param str requested output format
     * @throws DMLRuntimeException for any other format
     */
    public static JavaRDD<String> binaryBlockToStringRDD(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, MatrixCharacteristics matrixCharacteristics, String str) throws DMLRuntimeException {
        boolean isTextCell = str.equals(DataExpression.FORMAT_TYPE_VALUE_TEXT);
        if (!isTextCell) {
            throw new DMLRuntimeException("The output format:" + str + " is not implemented yet.");
        }
        int brlen = matrixCharacteristics.getRowsPerBlock();
        int bclen = matrixCharacteristics.getColsPerBlock();
        return javaPairRDD.flatMap(new ConvertMatrixBlockToIJVLines(brlen, bclen));
    }

    /**
     * {@link SparkContext} variant of the vector-DataFrame conversion; wraps the context and
     * delegates to the {@link JavaSparkContext} overload.
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> vectorDataFrameToBinaryBlock(SparkContext sparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z, String str) throws DMLRuntimeException {
        JavaSparkContext jsc = new JavaSparkContext(sparkContext);
        return vectorDataFrameToBinaryBlock(jsc, dataFrame, matrixCharacteristics, z, str);
    }

    /**
     * Converts a DataFrame whose column {@code str} holds MLlib {@link Vector}s into a
     * binary-block RDD.
     *
     * <p>If the dimensions/nnz in {@code matrixCharacteristics} are not fully known they are
     * computed with one pass over the data and written back into {@code matrixCharacteristics}.
     *
     * @param z whether the DataFrame has an "ID" column to sort by and then drop
     * @param str name of the vector column
     * @throws DMLRuntimeException if the "ID" column is missing when {@code z} is true, or if the
     *         number of columns must be inferred from an empty DataFrame
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> vectorDataFrameToBinaryBlock(JavaSparkContext javaSparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z, String str) throws DMLRuntimeException {
        if (z) {
            dataFrame = dropColumn(dataFrame.sort("ID", new String[0]), "ID");
        }
        DataFrame select = dataFrame.select(str, new String[0]);
        if (!matrixCharacteristics.dimsKnown(true)) {
            Accumulator<Double> nnzAccumulator = javaSparkContext.accumulator(0.0);
            long rows = select.javaRDD().map(new DataFrameAnalysisFunction(nnzAccumulator, true)).count();
            // Read the accumulator right after the counting action and take the vector length from
            // the un-mapped DataFrame, so no further evaluation of the analysis map can add to it.
            long nnz = UtilFunctions.toLong(nnzAccumulator.value().doubleValue());
            if (rows == 0) {
                throw new DMLRuntimeException("Cannot determine the number of columns of an empty vector DataFrame.");
            }
            long cols = ((Vector) select.first().get(0)).size();
            matrixCharacteristics.set(rows, cols, matrixCharacteristics.getRowsPerBlock(), matrixCharacteristics.getColsPerBlock(), nnz);
        }
        return RDDAggregateUtils.mergeByKey(select.javaRDD().zipWithIndex().mapPartitionsToPair(new DataFrameToBinaryBlockFunction(matrixCharacteristics, true)));
    }

    /**
     * Returns a projection of {@code dataFrame} without every column named {@code str}.
     *
     * @throws DMLRuntimeException if no column is named {@code str}, or if it is the only column
     */
    public static DataFrame dropColumn(DataFrame dataFrame, String str) throws DMLRuntimeException {
        ArrayList<String> kept = new ArrayList<>();
        boolean found = false;
        for (String column : dataFrame.columns()) {
            if (str.equals(column)) {
                found = true;
                continue;
            }
            kept.add(column);
        }
        if (!found) {
            throw new DMLRuntimeException("The column \"" + str + "\" is not present in the dataframe.");
        }
        if (kept.isEmpty()) {
            throw new DMLRuntimeException("No column other than \"" + str + "\" present in the dataframe.");
        }
        String first = kept.get(0);
        return dataFrame.select(first, JavaConversions.asScalaBuffer(kept.subList(1, kept.size())).toList());
    }

    /**
     * Projects {@code dataFrame} onto the given columns, in the given order.
     *
     * @param arrayList column names to keep; must contain at least one name
     * @throws DMLRuntimeException if {@code arrayList} is null or empty
     */
    public static DataFrame projectColumns(DataFrame dataFrame, ArrayList<String> arrayList) throws DMLRuntimeException {
        if (arrayList == null || arrayList.isEmpty()) {
            throw new DMLRuntimeException("At least one column name is required to project the dataframe.");
        }
        // DataFrame.select(String, Seq<String>) takes the first column separately from the rest.
        ArrayList<String> remaining = new ArrayList<>(arrayList.subList(1, arrayList.size()));
        return dataFrame.select(arrayList.get(0), JavaConversions.asScalaBuffer(remaining).toList());
    }

    /** Converts all columns of {@code dataFrame}; {@code z} indicates an "ID" column to sort by and drop. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(SparkContext sparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z) throws DMLRuntimeException {
        return dataFrameToBinaryBlock(sparkContext, dataFrame, matrixCharacteristics, z, (ArrayList<String>) null);
    }

    /** Converts only the columns named in {@code strArr}; the DataFrame has no "ID" column. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(SparkContext sparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, String[] strArr) throws DMLRuntimeException {
        return dataFrameToBinaryBlock(sparkContext, dataFrame, matrixCharacteristics, false, strArr);
    }

    /** Converts only the columns named in {@code arrayList}; the DataFrame has no "ID" column. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(SparkContext sparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, ArrayList<String> arrayList) throws DMLRuntimeException {
        return dataFrameToBinaryBlock(sparkContext, dataFrame, matrixCharacteristics, false, arrayList);
    }

    /** Converts only the columns named in {@code strArr}; {@code z} indicates an "ID" column. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(SparkContext sparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z, String[] strArr) throws DMLRuntimeException {
        ArrayList<String> columns = new ArrayList<String>(Arrays.asList(strArr));
        return dataFrameToBinaryBlock(sparkContext, dataFrame, matrixCharacteristics, z, columns);
    }

    /** Wraps {@code sparkContext} and delegates to the {@link JavaSparkContext} implementation. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(SparkContext sparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z, ArrayList<String> arrayList) throws DMLRuntimeException {
        JavaSparkContext jsc = new JavaSparkContext(sparkContext);
        return dataFrameToBinaryBlock(jsc, dataFrame, matrixCharacteristics, z, arrayList);
    }

    /** Converts all columns of {@code dataFrame}; {@code z} indicates an "ID" column to sort by and drop. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(JavaSparkContext javaSparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z) throws DMLRuntimeException {
        ArrayList<String> allColumns = null;
        return dataFrameToBinaryBlock(javaSparkContext, dataFrame, matrixCharacteristics, z, allColumns);
    }

    /** Converts only the columns named in {@code arrayList}; the DataFrame has no "ID" column. */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(JavaSparkContext javaSparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, ArrayList<String> arrayList) throws DMLRuntimeException {
        boolean hasIdColumn = false;
        return dataFrameToBinaryBlock(javaSparkContext, dataFrame, matrixCharacteristics, hasIdColumn, arrayList);
    }

    /**
     * Converts a DataFrame of scalar (double-convertible) columns into a binary-block RDD.
     *
     * <p>If the dimensions/nnz in {@code matrixCharacteristics} are not fully known they are
     * computed with one pass over the data and written back into {@code matrixCharacteristics}.
     *
     * @param z whether the DataFrame has an "ID" column to sort by and then drop
     * @param arrayList optional column names to project onto first ({@code null} keeps all)
     * @throws DMLRuntimeException if projection or "ID" removal fails
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlock(JavaSparkContext javaSparkContext, DataFrame dataFrame, MatrixCharacteristics matrixCharacteristics, boolean z, ArrayList<String> arrayList) throws DMLRuntimeException {
        if (arrayList != null) {
            dataFrame = projectColumns(dataFrame, arrayList);
        }
        if (z) {
            dataFrame = dropColumn(dataFrame.sort("ID", new String[0]), "ID");
        }
        if (!matrixCharacteristics.dimsKnown(true)) {
            Accumulator<Double> nnzAccumulator = javaSparkContext.accumulator(0.0);
            long rows = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(nnzAccumulator, false)).count();
            // Any "ID" column was already dropped above, so every remaining column is a data column.
            // (Subtracting one again for z == true silently discarded the last column.)
            long cols = dataFrame.columns().length;
            long nnz = UtilFunctions.toLong(nnzAccumulator.value().doubleValue());
            matrixCharacteristics.set(rows, cols, matrixCharacteristics.getRowsPerBlock(), matrixCharacteristics.getColsPerBlock(), nnz);
        }
        return RDDAggregateUtils.mergeByKey(dataFrame.javaRDD().zipWithIndex().mapPartitionsToPair(new DataFrameToBinaryBlockFunction(matrixCharacteristics, false)));
    }

    /**
     * Converts binary blocks into a DataFrame with a double "ID" column and a single vector
     * column "C1" holding each matrix row.
     *
     * @throws DMLRuntimeException if the number of columns is not positive (unknown dimensions)
     */
    public static DataFrame binaryBlockToVectorDataFrame(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, MatrixCharacteristics matrixCharacteristics, SQLContext sQLContext) throws DMLRuntimeException {
        long rlen = matrixCharacteristics.getRows();
        long clen = matrixCharacteristics.getCols();
        int brlen = matrixCharacteristics.getRowsPerBlock();
        int bclen = matrixCharacteristics.getColsPerBlock();
        JavaRDD rowsRDD = javaPairRDD
                .flatMapToPair(new MLOutput.ProjectRows(rlen, clen, brlen, bclen))
                .groupByKey()
                .map(new MLOutput.ConvertDoubleArrayToRows(clen, bclen, true));
        int numCols = (int) clen;
        if (numCols <= 0) {
            throw new DMLRuntimeException("Output dimensions unknown after executing the script and hence cannot create the dataframe");
        }
        ArrayList<StructField> schemaFields = new ArrayList<>();
        schemaFields.add(DataTypes.createStructField("ID", DataTypes.DoubleType, false));
        schemaFields.add(DataTypes.createStructField("C1", new VectorUDT(), false));
        StructType schema = DataTypes.createStructType(schemaFields);
        return sQLContext.createDataFrame(rowsRDD.rdd(), schema);
    }

    /**
     * Returns a copy of {@code dataFrame} with an extra non-nullable double column {@code str}
     * holding each row's 1-based position.
     */
    public static DataFrame addIDToDataFrame(DataFrame dataFrame, SQLContext sQLContext, String str) {
        StructField[] originalFields = dataFrame.schema().fields();
        StructField[] extendedFields = Arrays.copyOf(originalFields, originalFields.length + 1);
        extendedFields[originalFields.length] = DataTypes.createStructField(str, DataTypes.DoubleType, false);
        JavaRDD<Row> rowsWithId = dataFrame.javaRDD().zipWithIndex().map(new AddRowID());
        return sQLContext.createDataFrame(rowsWithId, new StructType(extendedFields));
    }

    /**
     * Converts binary blocks into a DataFrame with a double "ID" column followed by one double
     * column per matrix column, named "C1" ... "Cn".
     *
     * @throws DMLRuntimeException if the number of columns is not positive (unknown dimensions)
     */
    public static DataFrame binaryBlockToDataFrame(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, MatrixCharacteristics matrixCharacteristics, SQLContext sQLContext) throws DMLRuntimeException {
        long rlen = matrixCharacteristics.getRows();
        long clen = matrixCharacteristics.getCols();
        int brlen = matrixCharacteristics.getRowsPerBlock();
        int bclen = matrixCharacteristics.getColsPerBlock();
        JavaRDD rowsRDD = javaPairRDD
                .flatMapToPair(new MLOutput.ProjectRows(rlen, clen, brlen, bclen))
                .groupByKey()
                .map(new MLOutput.ConvertDoubleArrayToRows(clen, bclen, false));
        int numCols = (int) clen;
        if (numCols <= 0) {
            throw new DMLRuntimeException("Output dimensions unknown after executing the script and hence cannot create the dataframe");
        }
        ArrayList<StructField> schemaFields = new ArrayList<>();
        schemaFields.add(DataTypes.createStructField("ID", DataTypes.DoubleType, false));
        int col = 1;
        while (col <= numCols) {
            schemaFields.add(DataTypes.createStructField("C" + col, DataTypes.DoubleType, false));
            col++;
        }
        return sQLContext.createDataFrame(rowsRDD.rdd(), DataTypes.createStructType(schemaFields));
    }
}
