package org.apache.sysml.api.ml;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.sysml.api.DMLException;
import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLOutput;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

/* loaded from: input_file:org/apache/sysml/api/ml/LogisticRegressionModel.class */
/**
 * Spark ML model wrapper that scores a {@link DataFrame} with SystemML DML scripts,
 * using a coefficient matrix ({@code B_full}) supplied at construction time.
 *
 * <p>{@link #transform(DataFrame)} is the only method with real logic. The remaining
 * overrides either delegate to the superclass or return their argument unchanged.
 */
public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector, LogisticRegressionModel> {
    private static final long serialVersionUID = -6464693773946415027L;

    /**
     * Inline DML executed by {@link #transform(DataFrame)} to derive the predicted class
     * ({@code rowIndexMax} over the probabilities) and the raw prediction (logistic
     * function of {@code X * t(B_full)}). The concatenated text is kept exactly as the
     * script has always been submitted.
     */
    private static final String PREDICTION_SCRIPT =
            "Prob = read(\"temp1\"); "
            + "Prediction = rowIndexMax(Prob); "
            + "write(Prediction, \"tempOut\", \"csv\")"
            + "X = read(\"temp2\");"
            + "B_full = read(\"temp3\");"
            + "rawPred = 1 / (1 + exp(- X * t(B_full)) );"
            + "write(rawPred, \"tempOut1\", \"csv\")";

    /** Coefficient matrix in binary-block form; registered with SystemML as {@code B_full}. */
    private JavaPairRDD<MatrixIndexes, MatrixBlock> b_out;
    /** Spark context used to create the MLContext, JavaSparkContext and SQLContext. */
    private SparkContext sc;
    /** Matrix characteristics that accompany {@link #b_out}. */
    private MatrixCharacteristics b_outMC;

    /** Spark function that wraps an {@link Integer} into a single-column {@link Row} holding a {@link Double}. */
    public static class ConvertIntToRow implements Function<Integer, Row> {
        private static final long serialVersionUID = -3480953015655773622L;

        /**
         * @param num the integer to convert; must not be {@code null}
         * @return a one-column row containing {@code num} as a {@code Double}
         */
        public Row call(Integer num) throws Exception {
            // Double.valueOf replaces the deprecated Double(double) constructor.
            return RowFactory.create(new Object[]{Double.valueOf(num.intValue())});
        }
    }

    /**
     * Returns this same instance; the supplied parameter map is ignored.
     *
     * <p>NOTE(review): the {@code m230} prefix is a decompiler rename of {@code copy}
     * (merged with bridge methods). The name is kept as-is here so existing references
     * stay valid - confirm before renaming.
     */
    public LogisticRegressionModel m230copy(ParamMap paramMap) {
        return this;
    }

    /**
     * Creates a model backed by the given coefficient matrix.
     *
     * @param javaPairRDD           coefficient matrix blocks, registered as {@code B_full} during scoring
     * @param matrixCharacteristics characteristics of {@code javaPairRDD}
     * @param sparkContext          Spark context used for all scoring work
     */
    public LogisticRegressionModel(JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRDD, MatrixCharacteristics matrixCharacteristics, SparkContext sparkContext) {
        this.b_out = javaPairRDD;
        this.b_outMC = matrixCharacteristics;
        this.sc = sparkContext;
    }

    /** Creates an empty model; no coefficient matrix or Spark context is set. */
    public LogisticRegressionModel() {
    }

    /** Creates an empty model; the argument is ignored. */
    public LogisticRegressionModel(String str) {
    }

    /** @return the decimal string form of {@code serialVersionUID} */
    public String uid() {
        return Long.toString(serialVersionUID);
    }

    /** @return the given vector, unchanged */
    public Vector raw2probabilityInPlace(Vector vector) {
        return vector;
    }

    /** @return always 2 */
    public int numClasses() {
        return 2;
    }

    /** @return the given vector, unchanged */
    public Vector predictRaw(Vector vector) {
        return vector;
    }

    /** Delegates to the superclass implementation. */
    public double predict(Vector vector) {
        return super.predict(vector);
    }

    /** Delegates to the superclass implementation. */
    public double raw2prediction(Vector vector) {
        return super.raw2prediction(vector);
    }

    /** Delegates to the superclass implementation. */
    public double probability2prediction(Vector vector) {
        return super.probability2prediction(vector);
    }

    /**
     * Scores the given data set and returns it joined with the scoring columns.
     *
     * <p>Steps: convert the {@code "features"} column to a binary-block matrix, run
     * {@code $SYSTEMML_HOME/algorithms/GLM-predict.dml} with {@code dfam=3} to obtain the
     * {@code means} output (exposed as {@code probability}), run
     * {@link #PREDICTION_SCRIPT} to obtain {@code prediction} and {@code rawPrediction},
     * then join everything on the {@code ID} column and order the result by {@code id}.
     *
     * @param dataFrame input data; a {@code "features"} column is read from it
     * @return the input rows (with an added {@code ID} column) joined with the
     *         {@code probability}, {@code prediction} and {@code rawPrediction} columns,
     *         or {@code null} if the {@code SYSTEMML_HOME} environment variable is not set
     * @throws RuntimeException wrapping any {@link IOException}, {@link DMLRuntimeException},
     *         {@link DMLException} or {@link ParseException} raised while scoring
     */
    public DataFrame transform(DataFrame dataFrame) {
        try {
            MatrixCharacteristics featuresMC = new MatrixCharacteristics();
            // A DMLRuntimeException raised here (or anywhere below) is no longer printed and
            // turned into a null result; it is rethrown with its cause like every other
            // checked failure of this method.
            JavaPairRDD<MatrixIndexes, MatrixBlock> featureBlocks = RDDConverterUtilsExt.vectorDataFrameToBinaryBlock(
                    new JavaSparkContext(this.sc), dataFrame, featuresMC, false, "features");

            // First pass: GLM-predict.dml produces the "means" matrix.
            MLContext glmContext = new MLContext(this.sc);
            glmContext.registerInput("X", featureBlocks, featuresMC);
            glmContext.registerInput("B_full", this.b_out, this.b_outMC);
            glmContext.registerOutput("means");
            HashMap<String, String> scriptArgs = new HashMap<>();
            scriptArgs.put("dfam", "3");
            String systemmlHome = System.getenv("SYSTEMML_HOME");
            if (systemmlHome == null) {
                // Existing contract kept for compatibility: report and return null.
                System.err.println("ERROR: The environment variable SYSTEMML_HOME is not set.");
                return null;
            }
            // NOTE(review): blank values look like placeholders because X and B_full are
            // supplied through registerInput - confirm against GLM-predict.dml.
            scriptArgs.put("X", " ");
            scriptArgs.put("B", " ");
            String glmScriptPath = systemmlHome + File.separator + "algorithms" + File.separator + "GLM-predict.dml";
            MLOutput glmOutput = glmContext.execute(glmScriptPath, scriptArgs);

            SQLContext sqlContext = new SQLContext(this.sc);
            DataFrame probabilityDF = glmOutput.getDF(sqlContext, "means", true)
                    .withColumnRenamed("C1", "probability");

            // Second pass: inline script computes the class prediction and the raw prediction.
            MLContext predictionContext = new MLContext(this.sc);
            predictionContext.registerInput("X", featureBlocks, featuresMC);
            predictionContext.registerInput("B_full", this.b_out, this.b_outMC);
            predictionContext.registerInput("Prob", glmOutput.getBinaryBlockedRDD("means"), glmOutput.getMatrixCharacteristics("means"));
            predictionContext.registerOutput("Prediction");
            predictionContext.registerOutput("rawPred");
            MLOutput predictionOutput = predictionContext.executeScript(PREDICTION_SCRIPT);

            // The ID columns are renamed so each join condition refers to distinct column names.
            DataFrame predictionDF = predictionOutput.getDF(sqlContext, "Prediction")
                    .withColumnRenamed("C1", "prediction")
                    .withColumnRenamed("ID", "ID1");
            DataFrame rawPredictionDF = predictionOutput.getDF(sqlContext, "rawPred", true)
                    .withColumnRenamed("C1", "rawPrediction")
                    .withColumnRenamed("ID", "ID2");

            DataFrame probAndPrediction = probabilityDF
                    .join(predictionDF, probabilityDF.col("ID").equalTo(predictionDF.col("ID1")))
                    .select("ID", new String[]{"probability", "prediction"});
            DataFrame scores = probAndPrediction
                    .join(rawPredictionDF, probAndPrediction.col("ID").equalTo(rawPredictionDF.col("ID2")))
                    .select("ID", new String[]{"probability", "prediction", "rawPrediction"});

            DataFrame inputWithId = RDDConverterUtilsExt.addIDToDataFrame(dataFrame, sqlContext, "ID");
            return inputWithId
                    .join(scores, inputWithId.col("ID").equalTo(scores.col("ID")))
                    .orderBy("id", new String[0]);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        } catch (DMLException e) {
            throw new RuntimeException(e);
        } catch (ParseException e) {
            throw new RuntimeException(e);
        }
    }
}
