package org.apache.sysml.scripts.utils;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import org.apache.sysml.api.mlcontext.Script;

/* loaded from: input_file:org/apache/sysml/scripts/utils/Metrics.class */
/**
 * MLContext {@link Script} wrapping the DML file {@code scripts/utils/metrics.dml}.
 *
 * <p>Constructing an instance loads the full text of that DML file from the classpath and installs
 * it as this script's script string. Individual DML functions from the file are exposed as Java
 * methods that build and execute a small driver script.
 */
public class Metrics extends Script {
    /** Absolute classpath location of the wrapped DML file. */
    private static final String SCRIPT_RESOURCE = "/scripts/utils/metrics.dml";

    /**
     * DML source of {@code classification_report}. The generated wrapper returns the same text
     * from both {@link #classification_report__docs()} and {@link #classification_report__source()},
     * so it is kept in one place.
     */
    private static final String CLASSIFICATION_REPORT_SOURCE = "classification_report = function(matrix[double] y_true, matrix[double] y_pred, matrix[double] labels) return (string out) {\n\tnum_rows_error_measures = nrow(labels)\n\terror_measures = matrix(0, rows=num_rows_error_measures, cols=5)\n\tfor(i in 1:num_rows_error_measures) {\n\t\tclass_i = labels[i,1]\n        tp = sum( (y_true == y_pred) * (y_true == class_i) )\n        tp_plus_fp = sum( (y_pred == class_i) )\n        tp_plus_fn = sum( (y_true == class_i) )\n        precision = tp / tp_plus_fp\n        recall = tp / tp_plus_fn\n        f1Score = 2*precision*recall / (precision+recall)\n        error_measures[i,1] = class_i\n        error_measures[i,2] = precision\n        error_measures[i,3] = recall\n        error_measures[i,4] = f1Score\n        error_measures[i,5] = tp_plus_fn\n\t}\n\t# Added num_true_labels to debug whether the input data was randomized or now, which is common requirement of SGD-style algorithms.\n\t# Also, helps debug class-skew related problems.\n\tout = \"class    \\tprecision\\trecall  \\tf1-score\\tnum_true_labels\\n\" + toString(error_measures, decimal=7, sep=\"\\t\")\n}\n";

    /**
     * Creates the script and sets its script string to the contents of
     * {@code /scripts/utils/metrics.dml}.
     *
     * @throws NullPointerException if the DML resource is not on the classpath
     * @throws UncheckedIOException if reading the DML resource fails
     */
    public Metrics() {
        setScriptString(readResource(SCRIPT_RESOURCE));
    }

    /**
     * Reads a classpath resource (resolved via {@code Script.class}) fully into a string as UTF-8.
     *
     * @param path absolute resource path, starting with {@code /}
     * @return the complete resource text
     * @throws NullPointerException if no resource exists at {@code path}
     * @throws UncheckedIOException if reading or closing the resource fails
     */
    private static String readResource(String path) {
        InputStream in = Script.class.getResourceAsStream(path);
        // Same exception type as before (the InputStreamReader ctor threw NPE), but with context.
        Objects.requireNonNull(in, "DML script resource not found on classpath: " + path);
        StringBuilder sb = new StringBuilder();
        char[] buf = new char[1024];
        // try-with-resources closes the stream; a failed read now aborts instead of retrying forever.
        try (Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            int n;
            while ((n = reader.read(buf)) > 0) {
                sb.append(buf, 0, n);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DML script resource " + path, e);
        }
        return sb.toString();
    }

    /**
     * Runs the DML function {@code classification_report} from {@code scripts/utils/metrics.dml}.
     *
     * <p>The DML function computes per-class precision, recall, F1 score and count of true labels
     * for every class listed in {@code labels}.
     *
     * @param obj value bound to the DML input {@code y_true} (declared {@code matrix[double]} in DML)
     * @param obj2 value bound to the DML input {@code y_pred} (declared {@code matrix[double]} in DML)
     * @param obj3 value bound to the DML input {@code labels}; column 1 holds the class values
     * @return the DML output string {@code out}: a tab-separated header line followed by one row per
     *     class
     */
    public String classification_report(Object obj, Object obj2, Object obj3) {
        Script script = new Script("source('scripts/utils/metrics.dml') as mlcontextns;out = mlcontextns::classification_report(y_true, y_pred, labels);");
        script.in("y_true", obj).in("y_pred", obj2).in("labels", obj3).out("out");
        return script.execute().getString("out");
    }

    /**
     * Returns the documentation text for {@code classification_report}.
     *
     * @return the function's DML source, identical to {@link #classification_report__source()}
     */
    public String classification_report__docs() {
        return CLASSIFICATION_REPORT_SOURCE;
    }

    /**
     * Returns the DML source of {@code classification_report}.
     *
     * @return the full DML function definition
     */
    public String classification_report__source() {
        return CLASSIFICATION_REPORT_SOURCE;
    }
}
