package hivemall.classifier;

import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.lang.Primitives;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Hive UDTF {@code train_adagrad_rda}: trains a binary linear classifier with AdaGrad combined
 * with Regularized Dual Averaging (RDA) and an L1 penalty {@code lambda}.
 *
 * <p>Labels greater than zero are treated as {@code +1}, all others as {@code -1}. For every
 * example with a positive hinge loss, each non-null feature's weight is recomputed from the
 * accumulated sum of gradients {@code u} and the accumulated sum of squared gradients {@code G}
 * at step {@code t}:
 *
 * <pre>
 *   w = -sign(u) * eta * t / sqrt(G) * (|u| / t - lambda)   if |u| / t - lambda &gt;= 0
 *   entry removed from the model                            otherwise
 * </pre>
 *
 * <p>Gradients are accumulated pre-multiplied by {@code -scale}, so the stored sum carries one
 * factor of {@code scale} and the stored sum of squares carries two; both are divided back out
 * before the weight is computed. (An earlier version multiplied by {@code scale} again instead of
 * dividing, which inflated weights by {@code sqrt(scale)} and shrank the effective
 * {@code lambda} by {@code scale^2}.)
 *
 * @deprecated This UDTF is marked deprecated; its replacement is not visible from this file.
 */
@Description(name = "train_adagrad_rda", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>", extended = "Build a prediction model by Adagrad+RDA regularization binary classifier")
@Deprecated
public final class AdaGradRDAUDTF extends BinaryOnlineClassifierUDTF {

    private static final float DEFAULT_ETA = 0.1f;
    private static final float DEFAULT_LAMBDA = 1.0E-6f;
    private static final float DEFAULT_SCALING = 100.0f;

    /** Learning rate, option {@code -eta}/{@code -eta0}. */
    private float eta;
    /** L1 regularization constant of RDA, option {@code -lambda}. */
    private float lambda;
    /** Factor applied to gradients before they are accumulated, option {@code -scale}. */
    private float scaling;

    /**
     * Validates the argument count and configures the model to keep per-feature statistics.
     *
     * @param argOIs object inspectors for {@code features}, {@code label} and optional options
     * @return the output object inspector produced by the superclass
     * @throws UDFArgumentException if there are not exactly 2 or 3 arguments
     */
    @Override // hivemall.classifier.BinaryOnlineClassifierUDTF
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgs = argOIs.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("_FUNC_ takes 2 or 3 arguments: List<Text|Int|BigInt> features, int label [, constant string options]");
        }
        final StructObjectInspector outputOI = super.initialize(argOIs);
        // NOTE(review): presumably enables storage of the sum of squared gradients and the sum of
        // gradients read back in updateWeight -- confirm against configureParams.
        this.model.configureParams(true, false, true);
        return outputOI;
    }

    /**
     * Adds the {@code -eta}, {@code -lambda} and {@code -scale} options to the inherited ones.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public Options getOptions() {
        final Options opts = super.getOptions();
        opts.addOption("eta", "eta0", true, "The learning rate \\eta [default 0.1]");
        opts.addOption("lambda", true, "lambda constant of RDA [default: 1E-6f]");
        opts.addOption("scale", true, "Internal scaling/descaling factor for cumulative weights [default: 100]");
        return opts;
    }

    /**
     * Reads {@code -eta}, {@code -lambda} and {@code -scale}, falling back to the defaults when
     * no options (or an individual option) are given.
     *
     * @param argOIs the UDTF argument inspectors
     * @return the parsed command line from the superclass, possibly {@code null}
     * @throws UDFArgumentException if option parsing fails or {@code -scale} is not positive
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);
        if (cl == null) {
            this.eta = DEFAULT_ETA;
            this.lambda = DEFAULT_LAMBDA;
            this.scaling = DEFAULT_SCALING;
        } else {
            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), DEFAULT_ETA);
            this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), DEFAULT_LAMBDA);
            this.scaling = Primitives.parseFloat(cl.getOptionValue("scale"), DEFAULT_SCALING);
        }
        // updateWeight divides the accumulated sums by the scaling factor, so it must be strictly
        // positive; the negated comparison also rejects NaN.
        if (!(this.scaling > 0.0f)) {
            throw new UDFArgumentException("-scale must be a positive number: " + this.scaling);
        }
        return cl;
    }

    /**
     * Performs one training step: skips the example when its hinge loss is not positive,
     * otherwise updates every feature's weight at the current step count.
     *
     * @param features feature vector; {@code null} entries are ignored
     * @param label class label; values greater than zero map to {@code +1}, others to {@code -1}
     */
    @Override // hivemall.classifier.BinaryOnlineClassifierUDTF
    protected void train(@Nonnull FeatureValue[] features, int label) {
        final float y = label > 0 ? 1.0f : -1.0f;
        final float loss = LossFunctions.hingeLoss(predict(features), y);
        if (loss <= 0.0f) {
            return; // no hinge loss on this example: nothing to learn from it
        }
        update(features, y, this.count);
    }

    /**
     * Applies {@link #updateWeight} to every non-null feature of the example.
     *
     * @param features feature vector; {@code null} entries are skipped
     * @param y label mapped to {@code +1}/{@code -1}
     * @param t current step count passed through to the per-feature update
     */
    protected void update(@Nonnull FeatureValue[] features, float y, int t) {
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            updateWeight(fv.getFeature(), fv.getValueAsFloat(), y, t);
        }
    }

    /**
     * Accumulates this example's gradient for one feature and recomputes its AdaGrad-RDA weight.
     *
     * @param feature feature key in the model
     * @param xi feature value
     * @param y label mapped to {@code +1}/{@code -1}
     * @param t current step count; used as the divisor of the averaged gradient
     */
    protected void updateWeight(@Nonnull Object feature, float xi, float y, float t) {
        // Hinge-loss subgradient (-y * xi), stored pre-multiplied by the scaling factor.
        final float scaledGradient = -y * xi * this.scaling;

        float scaledSumSqGrad = 0.0f;
        float scaledSumGrad = 0.0f;
        final IWeightValue old = this.model.get(feature);
        if (old != null) {
            scaledSumSqGrad = old.getSumOfSquaredGradients();
            scaledSumGrad = old.getSumOfGradients();
        }
        scaledSumGrad += scaledGradient;
        scaledSumSqGrad += scaledGradient * scaledGradient;

        if (scaledSumSqGrad == 0.0f) {
            // No (representable) squared-gradient mass yet: the AdaGrad denominator sqrt(G) would
            // be zero and produce a NaN/Infinity weight. A zero weight is represented by removal.
            this.model.delete(feature);
            return;
        }

        // Undo the scaling: the sum of gradients carries one factor of `scaling`,
        // the sum of squared gradients carries two.
        final float sumGrad = scaledSumGrad / this.scaling;
        final double sumSqGrad = (double) scaledSumSqGrad / ((double) this.scaling * this.scaling);

        // sign(u_{t,i})
        final float sign = sumGrad > 0.0f ? 1.0f : -1.0f;
        // |u_{t,i}| / t - lambda
        final float truncatedMeanGrad = (sign * sumGrad) / t - this.lambda;
        if (truncatedMeanGrad < 0.0f) {
            // L1 truncation: w_{t,i} = 0.
            // NOTE(review): removing the entry presumably also discards the accumulated sums, so
            // the next update of this feature restarts from zero history -- confirm this is intended.
            this.model.delete(feature);
        } else {
            // w_{t,i} = -sign(u_{t,i}) * eta * t / sqrt(G_{t,ii}) * (|u_{t,i}| / t - lambda)
            final float weight = (-sign * this.eta * t * truncatedMeanGrad) / (float) Math.sqrt(sumSqGrad);
            // The sums are kept in their scaled form so later updates continue accumulating them.
            this.model.set(feature, new WeightValue.WeightValueParamsF2(weight, scaledSumSqGrad, scaledSumGrad));
        }
    }
}
