package hivemall.mf;

import hivemall.optimizer.EtaEstimator;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;

/**
 * Hive UDTF registered as {@code train_mf_sgd}.
 *
 * <p>Most behavior is inherited from {@link OnlineMatrixFactorizationUDTF}. This subclass only
 * adds learning-rate options, builds an {@link EtaEstimator} from the parsed command line, and
 * answers {@link #eta()} through that estimator.
 */
@Description(name = "train_mf_sgd", value = "_FUNC_(INT user, INT item, FLOAT rating [, CONSTANT STRING options]) - Returns a relation consists of <int idx, array<float> Pu, array<float> Qi [, float Bu, float Bi [, float mu]]>")
public final class MatrixFactorizationSGDUDTF extends OnlineMatrixFactorizationUDTF {

    /**
     * Learning-rate schedule. It is assigned in {@link #processOptions(ObjectInspector[])} and is
     * null until then.
     */
    private EtaEstimator etaEstimator;

    /**
     * Returns the superclass options extended with the learning-rate options {@code eta},
     * {@code eta0}, {@code t}/{@code total_steps} and {@code power_t}.
     *
     * <p>NOTE(review): {@code eta} and {@code eta0} are both described as "The initial learning
     * rate" but list different defaults. Confirm against {@code EtaEstimator.get} which one is a
     * fixed rate and which is the schedule's starting rate, then reword the help texts.
     *
     * @return the option set returned by {@code super.getOptions()}, with the options above
     *         registered on it
     */
    @Override
    public Options getOptions() {
        // Options.addOption returns the receiver, so the registrations can be chained onto the
        // instance produced by the superclass.
        return super.getOptions()
            .addOption("eta", true, "The initial learning rate [default: 0.001]")
            .addOption("eta0", true, "The initial learning rate [default 0.2]")
            .addOption("t", "total_steps", true, "The total number of training examples")
            .addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]");
    }

    /**
     * Parses the options via the superclass, then derives the learning-rate schedule from the
     * resulting command line.
     *
     * @param argOIs object inspectors of the UDTF arguments, forwarded unchanged to the superclass
     * @return the command line parsed by the superclass
     * @throws UDFArgumentException if the superclass or {@code EtaEstimator.get} rejects the
     *         options
     */
    @Override
    public CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);
        etaEstimator = EtaEstimator.get(cl);
        return cl;
    }

    /**
     * Returns the current learning rate by asking the estimator for its value at the inherited
     * {@code count}.
     *
     * <p>Assumes {@link #processOptions(ObjectInspector[])} has already run. Otherwise
     * {@code etaEstimator} is null and this throws {@link NullPointerException}. Presumably
     * {@code count} is the number of training examples seen so far; TODO confirm in the
     * superclass.
     */
    @Override
    protected float eta() {
        return etaEstimator.eta(count);
    }
}
