package org.apache.sysml.scripts.nn.layers;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.nn.layers.fm.Backward_output;
import org.apache.sysml.scripts.nn.layers.fm.Init_output;

/* loaded from: input_file:org/apache/sysml/scripts/nn/layers/Fm.class */
/**
 * MLContext wrapper around the DML script {@code scripts/nn/layers/fm.dml}.
 *
 * <p>Constructing an instance loads the full DML text from the classpath and installs it as this
 * script's source. {@link #init}, {@link #forward} and {@link #backward} each build a one-off
 * {@link Script} that {@code source}s the same DML file, binds the given inputs, executes it and
 * returns the named outputs. The {@code *__docs} / {@code *__source} methods return the DML header
 * comment and the full DML text of the corresponding function.
 */
public class Fm extends Script {

    /** Classpath location (resolved via {@link Script}) of the DML script loaded by the constructor. */
    private static final String SCRIPT_RESOURCE = "/scripts/nn/layers/fm.dml";

    /**
     * Loads {@code /scripts/nn/layers/fm.dml} from the classpath and sets it as the script string.
     *
     * @throws IllegalStateException if the resource is not on the classpath
     * @throws UncheckedIOException if reading the resource fails
     */
    public Fm() {
        setScriptString(readResource(SCRIPT_RESOURCE));
    }

    /**
     * Reads a classpath resource, resolved relative to {@link Script}, fully as UTF-8 text.
     *
     * @param path absolute resource path (leading {@code /})
     * @return the complete resource contents
     */
    private static String readResource(String path) {
        InputStream in = Script.class.getResourceAsStream(path);
        if (in == null) {
            // getResourceAsStream signals "not found" with null; fail with a clear message
            // instead of an NPE from the InputStreamReader constructor.
            throw new IllegalStateException("DML script resource not found on classpath: " + path);
        }
        // try-with-resources guarantees the stream is closed. An IOException now propagates
        // instead of being printed and retried forever inside an unbounded loop.
        try (Reader reader = new InputStreamReader(in, StandardCharsets.UTF_8)) {
            char[] buffer = new char[1024];
            StringBuilder text = new StringBuilder();
            int count;
            while ((count = reader.read(buffer)) > 0) {
                text.append(buffer, 0, count);
            }
            return text.toString();
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read DML script resource: " + path, e);
        }
    }

    /**
     * Executes DML {@code [w0, W, V] = init(n, d, k)} from fm.dml.
     *
     * <p>Per {@link #init__source()}, the DML body only reads {@code d} and {@code k}; {@code n} is
     * accepted but unused there.
     *
     * @param n bound to DML input {@code n} (declared {@code int} in DML)
     * @param d bound to DML input {@code d} (number of features, per the DML docs)
     * @param k bound to DML input {@code k} (factorization dimensionality, per the DML docs)
     * @return the DML outputs {@code w0}, {@code W} and {@code V}
     */
    public Init_output init(Object n, Object d, Object k) {
        Script script = new Script("source('scripts/nn/layers/fm.dml') as mlcontextns;[w0, W, V] = mlcontextns::init(n, d, k);");
        script.in("n", n).in("d", d).in("k", k).out("w0").out("W").out("V");
        MLResults results = script.execute();
        return new Init_output(results.getMatrix("w0"), results.getMatrix("W"), results.getMatrix("V"));
    }

    /** Returns the signature and header comment of DML {@code init}. */
    public String init__docs() {
        return "init = function(int n, int d, int k)\n    return (matrix[double] w0, matrix[double] W, matrix[double] V) {\n  /*\n   * This function initializes the parameters.\n   *\n   * Inputs:\n   *  - d: the number of features, is an integer.\n   *  - k: the factorization dimensionality, is an integer.\n   *\n   * Outputs:\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   */\n";
    }

    /** Returns the full DML text of {@code init}. */
    public String init__source() {
        return "init = function(int n, int d, int k)\n    return (matrix[double] w0, matrix[double] W, matrix[double] V) {\n  /*\n   * This function initializes the parameters.\n   *\n   * Inputs:\n   *  - d: the number of features, is an integer.\n   *  - k: the factorization dimensionality, is an integer.\n   *\n   * Outputs:\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   */\n  w0 = matrix(0, rows=1, cols=1)\n  W  = matrix(0, rows=d, cols=1)\n  V  = rand(rows=d, cols=k, min=0.0, max=1.0, pdf=\"uniform\", sparsity=.08)\n}\n";
    }

    /**
     * Executes DML {@code out = forward(X, w0, W, V)} from fm.dml.
     *
     * @param x bound to DML input {@code X}
     * @param w0 bound to DML input {@code w0}
     * @param w bound to DML input {@code W}
     * @param v bound to DML input {@code V}
     * @return the DML output {@code out}
     */
    public Matrix forward(Object x, Object w0, Object w, Object v) {
        Script script = new Script("source('scripts/nn/layers/fm.dml') as mlcontextns;out = mlcontextns::forward(X, w0, W, V);");
        script.in("X", x).in("w0", w0).in("W", w).in("V", v).out("out");
        return script.execute().getMatrix("out");
    }

    /** Returns the signature and header comment of DML {@code forward}. */
    public String forward__docs() {
        return "forward = function(matrix[double] X, matrix[double] w0, matrix[double] W, matrix[double] V)\n    return (matrix[double] out) {\n  /*\n   * Computes the model.\n   *\n   * Reference:\n   *  - Factorization Machines, Steffen Rendle.\n   *\n   * Inputs:\n   *  - X : n examples with d features, of shape (n, d).\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   *\n   * Outputs:\n   *  - out : target vector, of shape (n, 1).\n   */\n";
    }

    /** Returns the full DML text of {@code forward}. */
    public String forward__source() {
        return "forward = function(matrix[double] X, matrix[double] w0, matrix[double] W, matrix[double] V)\n    return (matrix[double] out) {\n  /*\n   * Computes the model.\n   *\n   * Reference:\n   *  - Factorization Machines, Steffen Rendle.\n   *\n   * Inputs:\n   *  - X : n examples with d features, of shape (n, d).\n   *  - w0: the global bias, of shape (1,).\n   *  - W : the strength of each feature, of shape (d, 1).\n   *  - V : factorized interaction terms, of shape (d, k).\n   *\n   * Outputs:\n   *  - out : target vector, of shape (n, 1).\n   */\n  out = (X %*% W) + (0.5 * rowSums((X %*% V)^2 - (X^2 %*% V^2)) ) + w0  # shape (n, 1)\n}\n";
    }

    /**
     * Executes DML {@code [dw0, dW, dV] = backward(dout, X, w0, W, V)} from fm.dml.
     *
     * @param dout bound to DML input {@code dout}
     * @param x bound to DML input {@code X}
     * @param w0 bound to DML input {@code w0}
     * @param w bound to DML input {@code W}
     * @param v bound to DML input {@code V}
     * @return the DML outputs {@code dw0}, {@code dW} and {@code dV}
     */
    public Backward_output backward(Object dout, Object x, Object w0, Object w, Object v) {
        Script script = new Script("source('scripts/nn/layers/fm.dml') as mlcontextns;[dw0, dW, dV] = mlcontextns::backward(dout, X, w0, W, V);");
        script.in("dout", dout).in("X", x).in("w0", w0).in("W", w).in("V", v).out("dw0").out("dW").out("dV");
        MLResults results = script.execute();
        return new Backward_output(results.getMatrix("dw0"), results.getMatrix("dW"), results.getMatrix("dV"));
    }

    /**
     * Returns the signature and header comment of DML {@code backward}.
     *
     * <p>NOTE(review): the header lists a {@code dX} output, but the DML function returns only
     * {@code dw0}, {@code dW} and {@code dV}. The text mirrors fm.dml and should be fixed there.
     */
    public String backward__docs() {
        return "backward = function(matrix[double] dout, matrix[double] X, matrix[double] w0, matrix[double] W,\n                    matrix[double] V)\n    return (matrix[double] dw0, matrix[double] dW, matrix[double] dV) {\n  /*\n   * This function accepts the upstream gradients w.r.t. output target\n   * vector, and returns the gradients of the loss w.r.t. the\n   * parameters.\n   *\n   * Inputs:\n   *  - dout : the gradient of the loss function w.r.t y, of\n   *     shape (n, 1).\n   *  - X, w0, W, V are as mentioned in the above forward function.\n   *\n   * Outputs:\n   *  - dX : the gradient of loss function w.r.t  X, of shape (n, d).\n   *  - dw0: the gradient of loss function w.r.t w0, of shape (1,).\n   *  - dW : the gradient of loss function w.r.t  W, of shape (d, 1).\n   *  - dV : the gradient of loss function w.r.t  V, of shape (d, k).\n   */\n";
    }

    /**
     * Returns the full DML text of {@code backward}.
     *
     * <p>NOTE(review): inside {@code for (j in 2:k)} the DML reads {@code xv[,k]} on every
     * iteration, where {@code xv[,j]} appears to be intended. As written, every column of
     * {@code g_V1} after the first is the same. This string mirrors fm.dml, so the fix belongs in
     * the script resource (and then here); confirm against the upstream fm.dml.
     */
    public String backward__source() {
        return "backward = function(matrix[double] dout, matrix[double] X, matrix[double] w0, matrix[double] W,\n                    matrix[double] V)\n    return (matrix[double] dw0, matrix[double] dW, matrix[double] dV) {\n  /*\n   * This function accepts the upstream gradients w.r.t. output target\n   * vector, and returns the gradients of the loss w.r.t. the\n   * parameters.\n   *\n   * Inputs:\n   *  - dout : the gradient of the loss function w.r.t y, of\n   *     shape (n, 1).\n   *  - X, w0, W, V are as mentioned in the above forward function.\n   *\n   * Outputs:\n   *  - dX : the gradient of loss function w.r.t  X, of shape (n, d).\n   *  - dw0: the gradient of loss function w.r.t w0, of shape (1,).\n   *  - dW : the gradient of loss function w.r.t  W, of shape (d, 1).\n   *  - dV : the gradient of loss function w.r.t  V, of shape (d, k).\n   */\n  n = nrow(X)\n  d = ncol(X)\n  k = ncol(V)\n\n  # 1. gradient of target vector w.r.t. w0\n  g_w0 = as.matrix(1)  # shape (1, 1)\n\n  ## gradient of loss function w.r.t. w0\n  dw0  = colSums(dout)  # shape (1, 1)\n\n  # 2. gradient target vector w.r.t. W\n  g_W = X  # shape (n, d)\n\n  ## gradient of loss function w.r.t. W\n  dW  =  t(g_W) %*% dout  # shape (d, 1)\n\n  # TODO: VECTORIZE THE FOLLOWING CODE (https://issues.apache.org/jira/browse/SYSTEMML-2102)\n  # 3. gradient of target vector w.r.t. V\n  # First term -> g_V1 = t(X) %*% (X %*% V)  # shape (d, k)\n\n  ## gradient of loss function w.r.t. V\n  # First term -> t(X) %*% X %*% V\n\n\n  # Second term -> V(i,f) * (X(i))^2\n  Xt = t( X^2 ) %*% dout  # shape (d,1)\n\n  g_V2 = Xt[1,] %*% V[1,]\n\n  for (i in 2:d) {\n    tmp = Xt[i,] %*% V[i,]\n    g_V2 = rbind(g_V2, tmp)\n  }\n\n  xv = X %*% V\n\n  g_V1 = dout[,1] * xv[,1]\n\n  for (j in 2:k) {\n    tmp1 = dout[,1] * xv[,k]\n    g_V1 = cbind(g_V1, tmp1)\n  }\n\n  dV = (t(X) %*% g_V1) - g_V2\n  # dV = mean(dout) * (t(X) %*% X %*%V) - g_V2\n}\n";
    }
}
