package org.apache.sysml.scripts.nn.layers;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;

/* loaded from: input_file:org/apache/sysml/scripts/nn/layers/Log_loss.class */
/**
 * MLContext wrapper around the DML script {@code scripts/nn/layers/log_loss.dml}.
 *
 * <p>The constructor loads the script text from the classpath. {@link #forward} and
 * {@link #backward} each build and execute a small DML snippet that sources the script and
 * calls the DML function of the same name. The {@code *__docs} and {@code *__source} methods
 * return the DML text of each function as string constants.
 */
public class Log_loss extends Script {

    /** Classpath location (resolved via {@code Script.class}) of the wrapped DML script. */
    private static final String SCRIPT_RESOURCE = "/scripts/nn/layers/log_loss.dml";

    /**
     * Loads {@code log_loss.dml} from the classpath and installs it as this script's text.
     *
     * <p>The resource is decoded as UTF-8 and its stream is always closed.
     *
     * @throws NullPointerException if the resource cannot be found on the classpath
     * @throws UncheckedIOException if reading the resource fails
     */
    public Log_loss() {
        StringBuilder sb = new StringBuilder();
        char[] buf = new char[1024];
        try (InputStream in = Objects.requireNonNull(
                     Script.class.getResourceAsStream(SCRIPT_RESOURCE),
                     "DML resource not found on classpath: " + SCRIPT_RESOURCE);
             Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            int n;
            while ((n = reader.read(buf)) != -1) {
                sb.append(buf, 0, n);
            }
        } catch (IOException e) {
            // Fail fast: retrying the read in a loop could spin forever on a persistent
            // error, and continuing would install a truncated script.
            throw new UncheckedIOException("Failed to read DML resource " + SCRIPT_RESOURCE, e);
        }
        setScriptString(sb.toString());
    }

    /**
     * Executes the DML {@code forward} function of {@code log_loss.dml}.
     *
     * <p>A new {@link Script} is built and executed on every call.
     *
     * @param pred value bound to the DML input {@code pred}; per the DML docs, predicted
     *     probabilities of {@code y=1}, of shape (N, 1)
     * @param y value bound to the DML input {@code y}; per the DML docs, binary targets
     *     (0 or 1), of shape (N, 1)
     * @return the DML output {@code loss} (the average log loss)
     */
    public double forward(Object pred, Object y) {
        Script script = new Script("source('scripts/nn/layers/log_loss.dml') as mlcontextns;loss = mlcontextns::forward(pred, y);");
        script.in("pred", pred).in("y", y).out("loss");
        return script.execute().getDouble("loss");
    }

    /** Returns the header and doc comment of the DML {@code forward} function. */
    public String forward__docs() {
        return "forward = function(matrix[double] pred, matrix[double] y)\n    return (double loss) {\n  /*\n   * Computes the forward pass for a log loss function.\n   *\n   *   ```\n   *   L_i = -y_i*log(pred_i) - (1-y_i)*log(1-pred_i)\n   *   L = (1/N) sum(L_i) for i=1 to N\n   *   ```\n   *\n   * In these equations, `L` is the total loss, `L_i` is the loss for\n   * example `i`, `y_i` is the binary target, `pred_i` is probability\n   * of the true class (i.e. `y=1`), and `N` is the number of examples.\n   *\n   * This can be interpreted as the negative log-likelihood assuming\n   * a Bernoulli distribution.\n   *\n   * Inputs:\n   *  - pred: Predictions, of shape (N, 1).\n   *      Predictions should be probabilities of the true\n   *      class (i.e. probability of `y=1`).\n   *  - y: Targets, of shape (N, 1).\n   *      Targets should be binary in the set {0, 1}.\n   *\n   * Outputs:\n   *  - loss: Average loss.\n   */\n";
    }

    /** Returns the full DML source of the {@code forward} function. */
    public String forward__source() {
        return "forward = function(matrix[double] pred, matrix[double] y)\n    return (double loss) {\n  /*\n   * Computes the forward pass for a log loss function.\n   *\n   *   ```\n   *   L_i = -y_i*log(pred_i) - (1-y_i)*log(1-pred_i)\n   *   L = (1/N) sum(L_i) for i=1 to N\n   *   ```\n   *\n   * In these equations, `L` is the total loss, `L_i` is the loss for\n   * example `i`, `y_i` is the binary target, `pred_i` is probability\n   * of the true class (i.e. `y=1`), and `N` is the number of examples.\n   *\n   * This can be interpreted as the negative log-likelihood assuming\n   * a Bernoulli distribution.\n   *\n   * Inputs:\n   *  - pred: Predictions, of shape (N, 1).\n   *      Predictions should be probabilities of the true\n   *      class (i.e. probability of `y=1`).\n   *  - y: Targets, of shape (N, 1).\n   *      Targets should be binary in the set {0, 1}.\n   *\n   * Outputs:\n   *  - loss: Average loss.\n   */\n  N = nrow(y)\n  losses = -y*log(pred) - (1-y)*log(1-pred)\n  loss = sum(losses) / N\n}\n";
    }

    /**
     * Executes the DML {@code backward} function of {@code log_loss.dml}.
     *
     * <p>A new {@link Script} is built and executed on every call.
     *
     * @param pred value bound to the DML input {@code pred}; per the DML docs, predicted
     *     probabilities of {@code y=1}, of shape (N, 1)
     * @param y value bound to the DML input {@code y}; per the DML docs, binary targets
     *     (0 or 1), of shape (N, 1)
     * @return the DML output {@code dpred}, the gradient with respect to {@code pred}
     */
    public Matrix backward(Object pred, Object y) {
        Script script = new Script("source('scripts/nn/layers/log_loss.dml') as mlcontextns;dpred = mlcontextns::backward(pred, y);");
        script.in("pred", pred).in("y", y).out("dpred");
        return script.execute().getMatrix("dpred");
    }

    /** Returns the header and doc comment of the DML {@code backward} function. */
    public String backward__docs() {
        return "backward = function(matrix[double] pred, matrix[double] y)\n    return (matrix[double] dpred) {\n  /*\n   * Computes the backward pass for a log loss function.\n   *\n   * Inputs:\n   *  - pred: Predictions, of shape (N, 1).\n   *      Predictions should be probabilities of the true\n   *      class (i.e. probability of `y=1`).\n   *  - y: Targets, of shape (N, 1).\n   *      Targets should be binary in the set {0, 1}.\n   *\n   * Outputs:\n   *  - dpred: Gradient wrt `pred`, of shape (N, 1).\n   */\n";
    }

    /** Returns the full DML source of the {@code backward} function. */
    public String backward__source() {
        return "backward = function(matrix[double] pred, matrix[double] y)\n    return (matrix[double] dpred) {\n  /*\n   * Computes the backward pass for a log loss function.\n   *\n   * Inputs:\n   *  - pred: Predictions, of shape (N, 1).\n   *      Predictions should be probabilities of the true\n   *      class (i.e. probability of `y=1`).\n   *  - y: Targets, of shape (N, 1).\n   *      Targets should be binary in the set {0, 1}.\n   *\n   * Outputs:\n   *  - dpred: Gradient wrt `pred`, of shape (N, 1).\n   */\n  N = nrow(y)\n  dpred = (1/N) * (pred-y) / (pred*(1-pred))\n}\n";
    }
}
