package hivemall.classifier;

import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

@Description(name = "train_pa", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>", extended = "Build a prediction model by Passive-Aggressive (PA) binary classifier")
/**
 * Passive-Aggressive (PA) binary classifier exposed as a Hive UDTF.
 *
 * <p>For each training row the model is left untouched when the hinge loss is zero and is
 * otherwise updated with a step size computed by {@link #eta(float, PredictionResult)}.
 * The nested {@link PA1} and {@link PA2} classes override only the step-size rule.
 */
public class PassiveAggressiveUDTF extends BinaryOnlineClassifierUDTF {

    @Description(name = "train_pa1", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>", extended = "Build a prediction model by Passive-Aggressive 1 (PA-1) binary classifier")
    /**
     * PA-1 variant: the step size is capped by the aggressiveness parameter {@code C},
     * configurable through the {@code -c} / {@code -aggressiveness} option (default 1.0).
     */
    public static class PA1 extends PassiveAggressiveUDTF {
        /** Aggressiveness parameter C; assigned by {@link #processOptions(ObjectInspector[])}. */
        protected float c;

        /**
         * Adds the {@code -c} / {@code -aggressiveness} option to the options inherited from
         * the superclass.
         *
         * <p>NOTE(review): the decompiler reported that this method's access modifier was
         * changed from {@code protected}; it is kept {@code public} here so the override stays
         * compatible with the decompiled superclass. Confirm against the real sources.
         *
         * @return the inherited options plus the aggressiveness option
         */
        @Override
        public Options getOptions() {
            final Options options = super.getOptions();
            options.addOption("c", "aggressiveness", true, "Aggressiveness parameter C [default 1.0]");
            return options;
        }

        /**
         * Parses the inherited options and then reads the aggressiveness parameter C.
         *
         * <p>NOTE(review): the decompiler reported that this method's access modifier was
         * changed from {@code protected}; see {@link #getOptions()}.
         *
         * @param objectInspectorArr the argument inspectors, forwarded to the superclass
         * @return whatever the superclass returned (may be {@code null}, in which case C keeps
         *         its default of 1.0)
         * @throws UDFArgumentException if C is not a number, is NaN, or is not strictly positive
         */
        @Override
        public CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
            final CommandLine commandLine = super.processOptions(objectInspectorArr);
            float aggressiveness = 1.0f;
            final String rawValue = commandLine == null ? null : commandLine.getOptionValue("c");
            if (rawValue != null) {
                try {
                    aggressiveness = Float.parseFloat(rawValue);
                } catch (NumberFormatException e) {
                    // Report malformed input through the declared exception type rather than
                    // leaking an unchecked NumberFormatException; the offending text is in
                    // the message.
                    throw new UDFArgumentException("Aggressiveness parameter C must be a number: " + rawValue);
                }
                // Written as !(x > 0) so that NaN (for which every comparison is false) is
                // rejected together with zero and negative values.
                if (!(aggressiveness > 0.0f)) {
                    throw new UDFArgumentException("Aggressiveness parameter C must be C > 0: " + aggressiveness);
                }
            }
            this.c = aggressiveness;
            return commandLine;
        }

        /**
         * PA-1 step size: {@code min(C, loss / squaredNorm)}.
         *
         * @param f the loss of the current example
         * @param predictionResult supplies the squared norm of the current example
         * @return the step size, never larger than C
         */
        @Override
        protected float eta(float f, PredictionResult predictionResult) {
            final float uncapped = f / predictionResult.getSquaredNorm();
            return Math.min(this.c, uncapped);
        }
    }

    @Description(name = "train_pa2", value = "_FUNC_(list<string|int|bigint> features, int label [, const string options]) - Returns a relation consists of <string|int|bigint feature, float weight>", extended = "Build a prediction model by Passive-Aggressive 2 (PA-2) binary classifier")
    /**
     * PA-2 variant: reuses the aggressiveness option of {@link PA1} but uses it to enlarge the
     * denominator of the step size instead of capping the result.
     */
    public static class PA2 extends PA1 {
        /**
         * PA-2 step size: {@code loss / (squaredNorm + 1 / (2C))}.
         *
         * @param f the loss of the current example
         * @param predictionResult supplies the squared norm of the current example
         * @return the step size
         */
        @Override
        protected float eta(float f, PredictionResult predictionResult) {
            final float denominator = predictionResult.getSquaredNorm() + (0.5f / this.c);
            return f / denominator;
        }
    }

    /**
     * Validates the argument count before delegating to the superclass.
     *
     * @param objectInspectorArr inspectors for the UDTF arguments; exactly 2 or 3 are accepted
     * @return the output inspector produced by the superclass
     * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("_FUNC_ takes 2 or 3 arguments: List<Text|Int|BigInt> features, int label [, constant string options]");
        }
        return super.initialize(objectInspectorArr);
    }

    /**
     * Performs one training step on a single example.
     *
     * <p>The label is mapped to +1 (when {@code i > 0}) or -1 (otherwise). If
     * {@code LossFunctions.hingeLoss} returns a positive value for the current score, the model
     * is updated with coefficient {@code eta * y}; otherwise nothing changes.
     *
     * @param featureValueArr the features of the example
     * @param i the raw label; any positive value is treated as the positive class
     */
    @Override
    protected void train(@Nonnull FeatureValue[] featureValueArr, int i) {
        final float y = i > 0 ? 1.0f : -1.0f;
        final PredictionResult margin = calcScoreAndNorm(featureValueArr);
        final float loss = LossFunctions.hingeLoss(margin.getScore(), y);
        if (!(loss > 0.0f)) {
            return;
        }
        final float stepSize = eta(loss, margin);
        // A zero squared norm (e.g. an empty or all-zero feature vector) makes the plain PA
        // step size loss / 0 = +Infinity. Such an example carries no direction to move in,
        // and a non-finite coefficient would presumably corrupt the weights inside update()
        // (not visible here), so the example is skipped.
        if (Float.isNaN(stepSize) || Float.isInfinite(stepSize)) {
            return;
        }
        update(featureValueArr, stepSize * y);
    }

    /**
     * Plain PA step size: {@code loss / squaredNorm}.
     *
     * <p>The result is non-finite when the squared norm is zero; {@link #train} skips the
     * update in that case.
     *
     * @param f the loss of the current example
     * @param predictionResult supplies the squared norm of the current example
     * @return the step size
     */
    protected float eta(float f, PredictionResult predictionResult) {
        return f / predictionResult.getSquaredNorm();
    }
}
