package org.apache.sysml.api.javaml;

import java.io.File;
import java.util.HashMap;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegressionParams;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLOutput;
import org.apache.sysml.api.ml.functions.ConvertSingleColumnToString;
import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

/* loaded from: input_file:org/apache/sysml/api/javaml/LogisticRegression.class */
public class LogisticRegression extends ProbabilisticClassifier<Vector, LogisticRegression, LogisticRegressionModel> implements LogisticRegressionParams {
    private static final long serialVersionUID = 7763813395635870734L;
    private SparkContext sc;
    private SQLContext sqlContext;
    private HashMap<String, String> cmdLineParams;
    private IntParam icpt;
    private DoubleParam reg;
    private DoubleParam tol;
    private IntParam moi;
    private IntParam mii;
    private IntParam labelIndex;
    private StringArrayParam inputCol;
    private StringArrayParam outputCol;
    private int intMin;
    private int li;
    private String[] icname;
    private String[] ocname;

    public LogisticRegression() {
        this.sc = null;
        this.sqlContext = null;
        this.cmdLineParams = new HashMap<>();
        this.icpt = new IntParam(this, "icpt", "Value of intercept");
        this.reg = new DoubleParam(this, "reg", "Value of regularization parameter");
        this.tol = new DoubleParam(this, "tol", "Value of tolerance");
        this.moi = new IntParam(this, "moi", "Max outer iterations");
        this.mii = new IntParam(this, "mii", "Max inner iterations");
        this.labelIndex = new IntParam(this, "li", "Index of the label column");
        this.inputCol = new StringArrayParam(this, "icname", "Feature column name");
        this.outputCol = new StringArrayParam(this, "ocname", "Label column name");
        this.intMin = Integer.MIN_VALUE;
        this.li = 0;
        this.icname = new String[1];
        this.ocname = new String[1];
    }

    public LogisticRegression(String str) {
        this.sc = null;
        this.sqlContext = null;
        this.cmdLineParams = new HashMap<>();
        this.icpt = new IntParam(this, "icpt", "Value of intercept");
        this.reg = new DoubleParam(this, "reg", "Value of regularization parameter");
        this.tol = new DoubleParam(this, "tol", "Value of tolerance");
        this.moi = new IntParam(this, "moi", "Max outer iterations");
        this.mii = new IntParam(this, "mii", "Max inner iterations");
        this.labelIndex = new IntParam(this, "li", "Index of the label column");
        this.inputCol = new StringArrayParam(this, "icname", "Feature column name");
        this.outputCol = new StringArrayParam(this, "ocname", "Label column name");
        this.intMin = Integer.MIN_VALUE;
        this.li = 0;
        this.icname = new String[1];
        this.ocname = new String[1];
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public LogisticRegression m226copy(ParamMap paramMap) {
        try {
            LogisticRegression logisticRegression = new LogisticRegression(this.sc, this.sqlContext);
            logisticRegression.cmdLineParams.put(this.icpt.name(), paramMap.getOrElse(this.icpt, 0).toString());
            logisticRegression.cmdLineParams.put(this.reg.name(), paramMap.getOrElse(this.reg, Float.valueOf(0.0f)).toString());
            logisticRegression.cmdLineParams.put(this.tol.name(), paramMap.getOrElse(this.tol, Float.valueOf(1.0E-6f)).toString());
            logisticRegression.cmdLineParams.put(this.moi.name(), paramMap.getOrElse(this.moi, 100).toString());
            logisticRegression.cmdLineParams.put(this.mii.name(), paramMap.getOrElse(this.mii, 0).toString());
            return logisticRegression;
        } catch (DMLRuntimeException e) {
            e.printStackTrace();
            return null;
        }
    }

    public LogisticRegression(SparkContext sparkContext, SQLContext sQLContext) throws DMLRuntimeException {
        this.sc = null;
        this.sqlContext = null;
        this.cmdLineParams = new HashMap<>();
        this.icpt = new IntParam(this, "icpt", "Value of intercept");
        this.reg = new DoubleParam(this, "reg", "Value of regularization parameter");
        this.tol = new DoubleParam(this, "tol", "Value of tolerance");
        this.moi = new IntParam(this, "moi", "Max outer iterations");
        this.mii = new IntParam(this, "mii", "Max inner iterations");
        this.labelIndex = new IntParam(this, "li", "Index of the label column");
        this.inputCol = new StringArrayParam(this, "icname", "Feature column name");
        this.outputCol = new StringArrayParam(this, "ocname", "Label column name");
        this.intMin = Integer.MIN_VALUE;
        this.li = 0;
        this.icname = new String[1];
        this.ocname = new String[1];
        this.sc = sparkContext;
        this.sqlContext = sQLContext;
        setDefault(intercept(), 0);
        this.cmdLineParams.put(this.icpt.name(), "0");
        setDefault(regParam(), Float.valueOf(0.0f));
        this.cmdLineParams.put(this.reg.name(), "0.0f");
        setDefault(tol(), Float.valueOf(1.0E-6f));
        this.cmdLineParams.put(this.tol.name(), "0.000001f");
        setDefault(maxOuterIter(), 100);
        this.cmdLineParams.put(this.moi.name(), "100");
        setDefault(maxInnerIter(), 0);
        this.cmdLineParams.put(this.mii.name(), "0");
        setDefault(labelIdx(), Integer.valueOf(this.intMin));
        this.li = this.intMin;
        setDefault(inputCol(), this.icname);
        this.icname[0] = "";
        setDefault(outputCol(), this.ocname);
        this.ocname[0] = "";
    }

    public LogisticRegression(SparkContext sparkContext, SQLContext sQLContext, int i, double d, double d2, int i2, int i3) throws DMLRuntimeException {
        this.sc = null;
        this.sqlContext = null;
        this.cmdLineParams = new HashMap<>();
        this.icpt = new IntParam(this, "icpt", "Value of intercept");
        this.reg = new DoubleParam(this, "reg", "Value of regularization parameter");
        this.tol = new DoubleParam(this, "tol", "Value of tolerance");
        this.moi = new IntParam(this, "moi", "Max outer iterations");
        this.mii = new IntParam(this, "mii", "Max inner iterations");
        this.labelIndex = new IntParam(this, "li", "Index of the label column");
        this.inputCol = new StringArrayParam(this, "icname", "Feature column name");
        this.outputCol = new StringArrayParam(this, "ocname", "Label column name");
        this.intMin = Integer.MIN_VALUE;
        this.li = 0;
        this.icname = new String[1];
        this.ocname = new String[1];
        this.sc = sparkContext;
        this.sqlContext = sQLContext;
        setDefault(intercept(), Integer.valueOf(i));
        this.cmdLineParams.put(this.icpt.name(), Integer.toString(i));
        setDefault(regParam(), Double.valueOf(d));
        this.cmdLineParams.put(this.reg.name(), Double.toString(d));
        setDefault(tol(), Double.valueOf(d2));
        this.cmdLineParams.put(this.tol.name(), Double.toString(d2));
        setDefault(maxOuterIter(), Integer.valueOf(i2));
        this.cmdLineParams.put(this.moi.name(), Integer.toString(i2));
        setDefault(maxInnerIter(), Integer.valueOf(i3));
        this.cmdLineParams.put(this.mii.name(), Integer.toString(i3));
        setDefault(labelIdx(), Integer.valueOf(this.intMin));
        this.li = this.intMin;
        setDefault(inputCol(), this.icname);
        this.icname[0] = "";
        setDefault(outputCol(), this.ocname);
        this.ocname[0] = "";
    }

    public String uid() {
        return Long.toString(serialVersionUID);
    }

    public LogisticRegression setRegParam(double d) {
        this.cmdLineParams.put(this.reg.name(), Double.toString(d));
        return setDefault(this.reg, Double.valueOf(d));
    }

    public StructType validateAndTransformSchema(StructType structType, boolean z, DataType dataType) {
        return null;
    }

    public double getRegParam() {
        return Double.parseDouble(this.cmdLineParams.get(this.reg.name()));
    }

    public void org$apache$spark$ml$param$shared$HasRegParam$_setter_$regParam_$eq(DoubleParam doubleParam) {
    }

    public DoubleParam regParam() {
        return this.reg;
    }

    public DoubleParam elasticNetParam() {
        return null;
    }

    public double getElasticNetParam() {
        return DataExpression.DEFAULT_DELIM_FILL_VALUE;
    }

    public void org$apache$spark$ml$param$shared$HasElasticNetParam$_setter_$elasticNetParam_$eq(DoubleParam doubleParam) {
    }

    public int getMaxIter() {
        return 0;
    }

    public IntParam maxIter() {
        return null;
    }

    public LogisticRegression setMaxOuterIter(int i) {
        this.cmdLineParams.put(this.moi.name(), Integer.toString(i));
        return setDefault(this.moi, Integer.valueOf(i));
    }

    public int getMaxOuterIter() {
        return Integer.parseInt(this.cmdLineParams.get(this.moi.name()));
    }

    public IntParam maxOuterIter() {
        return this.moi;
    }

    public LogisticRegression setMaxInnerIter(int i) {
        this.cmdLineParams.put(this.mii.name(), Integer.toString(i));
        return setDefault(this.mii, Integer.valueOf(i));
    }

    public int getMaxInnerIter() {
        return Integer.parseInt(this.cmdLineParams.get(this.mii.name()));
    }

    public IntParam maxInnerIter() {
        return this.mii;
    }

    public void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam intParam) {
    }

    public LogisticRegression setIntercept(int i) {
        this.cmdLineParams.put(this.icpt.name(), Integer.toString(i));
        return setDefault(this.icpt, Integer.valueOf(i));
    }

    public int getIntercept() {
        return Integer.parseInt(this.cmdLineParams.get(this.icpt.name()));
    }

    public IntParam intercept() {
        return this.icpt;
    }

    public BooleanParam fitIntercept() {
        return null;
    }

    public boolean getFitIntercept() {
        return false;
    }

    public void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam booleanParam) {
    }

    public LogisticRegression setTol(double d) {
        this.cmdLineParams.put(this.tol.name(), Double.toString(d));
        return setDefault(this.tol, Double.valueOf(d));
    }

    public double getTol() {
        return Double.parseDouble(this.cmdLineParams.get(this.tol.name()));
    }

    public void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam doubleParam) {
    }

    public DoubleParam tol() {
        return this.tol;
    }

    public double getThreshold() {
        return DataExpression.DEFAULT_DELIM_FILL_VALUE;
    }

    public void org$apache$spark$ml$param$shared$HasThreshold$_setter_$threshold_$eq(DoubleParam doubleParam) {
    }

    public DoubleParam threshold() {
        return null;
    }

    public LogisticRegression setLabelIndex(int i) {
        this.li = i;
        return setDefault(this.labelIndex, Integer.valueOf(i));
    }

    public int getLabelIndex() {
        return Integer.parseInt(this.cmdLineParams.get(this.labelIndex.name()));
    }

    public IntParam labelIdx() {
        return this.labelIndex;
    }

    public LogisticRegression setInputCol(String[] strArr) {
        this.icname[0] = strArr[0];
        return setDefault(this.inputCol, strArr);
    }

    public String getInputCol() {
        return this.icname[0];
    }

    public StringArrayParam inputCol() {
        return this.inputCol;
    }

    public LogisticRegression setOutputCol(String[] strArr) {
        this.ocname[0] = strArr[0];
        return setDefault(this.outputCol, strArr);
    }

    public String getOutputCol() {
        return this.ocname[0];
    }

    public StringArrayParam outputCol() {
        return this.outputCol;
    }

    /* renamed from: train, reason: merged with bridge method [inline-methods] */
    public LogisticRegressionModel m223train(DataFrame dataFrame) {
        MLOutput execute;
        try {
            MLContext mLContext = new MLContext(this.sc);
            MatrixCharacteristics matrixCharacteristics = new MatrixCharacteristics();
            try {
                JavaPairRDD<MatrixIndexes, MatrixBlock> vectorDataFrameToBinaryBlock = RDDConverterUtilsExt.vectorDataFrameToBinaryBlock(new JavaSparkContext(this.sc), dataFrame, matrixCharacteristics, false, "features");
                JavaRDD<String> map = dataFrame.select("label", new String[0]).rdd().toJavaRDD().map(new ConvertSingleColumnToString());
                try {
                    mLContext.registerInput("X", vectorDataFrameToBinaryBlock, matrixCharacteristics);
                    mLContext.registerInput("Y_vec", map, DataExpression.FORMAT_TYPE_VALUE_CSV);
                    mLContext.registerOutput("B_out");
                    this.cmdLineParams.put("X", " ");
                    this.cmdLineParams.put("Y", " ");
                    this.cmdLineParams.put("B", " ");
                    String str = System.getenv("SYSTEMML_HOME");
                    if (str == null) {
                        System.err.println("ERROR: The environment variable SYSTEMML_HOME is not set.");
                        return null;
                    }
                    String str2 = str + File.separator + "algorithms" + File.separator + "MultiLogReg.dml";
                    synchronized (MLContext.class) {
                        execute = mLContext.execute(str2, this.cmdLineParams);
                    }
                    return new LogisticRegressionModel(execute.getBinaryBlockedRDD("B_out"), execute.getMatrixCharacteristics("B_out"), this.sc).setParent(this);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            } catch (DMLRuntimeException e2) {
                e2.printStackTrace();
                return null;
            }
        } catch (DMLRuntimeException e3) {
            e3.printStackTrace();
            return null;
        }
    }
}
