package org.apache.sysml.api.ml;

import org.apache.spark.SparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptFactory;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import scala.Predef$;
import scala.Tuple2;

/* compiled from: PredictionUtils.scala */
/* loaded from: input_file:org/apache/sysml/api/ml/PredictionUtils$.class */
/**
 * Prediction helpers for the SystemML ML API: GLM scoring-script construction, ID-based
 * DataFrame joins, and conversion of class probabilities into predicted labels.
 *
 * <p>This is the Java view of a Scala {@code object}; the single instance is exposed as
 * {@link #MODULE$}.
 */
public final class PredictionUtils$ {
    /**
     * The singleton instance, created once during class initialization. It is assigned here
     * rather than in the constructor because a {@code static final} field cannot be reassigned
     * from an instance constructor.
     */
    public static final PredictionUtils$ MODULE$ = new PredictionUtils$();

    private PredictionUtils$() {
        // Singleton: obtain the instance through MODULE$.
    }

    /**
     * Builds the DML script used to score a GLM / logistic-regression model.
     *
     * <p>The script text is loaded from {@code LogisticRegressionModel$.MODULE$.scriptPath()}.
     * {@code $X} and {@code $B} are bound to a single space, {@code $dfam} to {@code dfam}, and
     * {@code "means"} is registered as the output. The coefficient matrix is bound as
     * {@code "B_full"}.
     *
     * @param bFull the model coefficient matrix
     * @param isSingleNode if {@code true}, bind {@code bFull} as a local MatrixBlock plus its
     *     metadata; otherwise bind the {@link Matrix} itself
     * @param dfam value bound to the script's {@code $dfam} argument
     * @return the configured script paired with {@code "X"}; presumably the name of the script
     *     variable the caller must bind the feature matrix to (TODO confirm against callers)
     */
    public Tuple2<Script, String> getGLMPredictionScript(Matrix bFull, boolean isSingleNode, Integer dfam) {
        // NOTE(review): $X and $B look like file-path placeholders that are unused because the
        // data is bound in memory; confirm against the DML script before changing them.
        Script script = ScriptFactory.dml(
                ScriptsUtils$.MODULE$.getDMLScript(LogisticRegressionModel$.MODULE$.scriptPath()))
            .in("$X", " ")
            .in("$B", " ")
            .in("$dfam", dfam)
            .out("means");
        Script bound = isSingleNode
            ? script.in("B_full", bFull.toMatrixBlock(), bFull.getMatrixMetadata())
            : script.in("B_full", bFull);
        return new Tuple2<>(bound, "X");
    }

    /**
     * Default for the third ({@code dfam}) parameter of
     * {@link #getGLMPredictionScript(Matrix, boolean, Integer)}, as generated for the Scala
     * default argument.
     *
     * @return the boxed integer {@code 1}
     */
    public Integer getGLMPredictionScript$default$3() {
        return Integer.valueOf(1);
    }

    /**
     * Joins two DataFrames on the SystemML row-ID column
     * ({@code RDDConverterUtils.DF_ID_COLUMN}).
     *
     * @param left the left-hand DataFrame
     * @param right the right-hand DataFrame
     * @return the joined DataFrame
     */
    public Dataset<Row> joinUsingID(Dataset<Row> left, Dataset<Row> right) {
        return left.join(right, RDDConverterUtils.DF_ID_COLUMN);
    }

    /**
     * Converts a matrix of class probabilities into predicted class labels using
     * {@code rowIndexMax}, i.e. the index of the largest value in each row.
     *
     * @param mlscoreOutput results that contain the probability matrix
     * @param isSingleNode if {@code true}, bind the probabilities as a local MatrixBlock plus
     *     metadata; otherwise bind the {@link Matrix} itself
     * @param sc Spark context used to create the {@link MLContext} that runs the script
     * @param inProbVar name of the probability matrix inside {@code mlscoreOutput}
     * @return the execution results, exposing the {@code "Prediction"} output
     */
    public MLResults computePredictedClassLabelsFromProbability(MLResults mlscoreOutput, boolean isSingleNode, SparkContext sc, String inProbVar) {
        // NOTE(review): a fresh MLContext is created on every call; consider reusing one if this
        // runs frequently.
        MLContext ml = new MLContext(sc);
        // NOTE(review): the read/write statements presumably get replaced by the in-memory
        // "Prob" input and "Prediction" output bindings; confirm no "tempOut" file is written.
        Script script = ScriptFactory.dml("\n        Prob = read(\"temp1\");\n        Prediction = rowIndexMax(Prob); # assuming one-based label mapping\n        write(Prediction, \"tempOut\", \"csv\");\n        ").out("Prediction");
        Matrix probabilities = mlscoreOutput.getMatrix(inProbVar);
        Script bound = isSingleNode
            ? script.in("Prob", probabilities.toMatrixBlock(), probabilities.getMatrixMetadata())
            : script.in("Prob", probabilities);
        return ml.execute(bound);
    }
}
