package hivemall.regression;

import hivemall.common.OnlineVariance;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionResult;
import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Online passive-aggressive (PA) regression learner, registered as {@code train_pa1_regr}.
 *
 * <p>For every training example the learner computes an epsilon-insensitive loss between the
 * target and the current prediction. When that loss is positive the weights are updated with a
 * signed step size {@code eta} (see {@link #eta(float, PredictionResult)}) pointing toward the
 * target; otherwise the model is left untouched (the "passive" case).
 *
 * <p>Nested variants:
 * <ul>
 * <li>{@link PA1a} - PA-1 whose epsilon is scaled by the running standard deviation of targets</li>
 * <li>{@link PA2} - PA-2 step size, default C = 1</li>
 * <li>{@link PA2a} - PA-2 with the same target-scaled epsilon as PA1a</li>
 * </ul>
 */
@Description(name = "train_pa1_regr", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - Returns a relation consists of <{int|bigint|string} feature, float weight>")
public class PassiveAggressiveRegressionUDTF extends RegressionBaseUDTF {

    /** Default value of the {@code -epsilon} option. */
    private static final float DEFAULT_EPSILON = 0.1f;

    /** Aggressiveness parameter C, assigned in {@link #processOptions(ObjectInspector[])}. */
    protected float c;
    /** Epsilon of the epsilon-insensitive loss, assigned in {@link #processOptions(ObjectInspector[])}. */
    protected float epsilon;

    /**
     * PA-1 regression whose epsilon is multiplied by the running standard deviation of the
     * targets observed so far, registered as {@code train_pa1a_regr}.
     */
    @Description(name = "train_pa1a_regr", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - Returns a relation consists of <{int|bigint|string} feature, float weight>")
    public static final class PA1a extends PassiveAggressiveRegressionUDTF {

        /** Running statistics of training targets; (re)created on every {@code initialize}. */
        private OnlineVariance targetStdDev;

        @Override
        public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
            this.targetStdDev = new OnlineVariance();
            return super.initialize(argOIs);
        }

        /** Feeds the target into the running statistics before the loss is computed. */
        @Override
        protected void preTrain(float target) {
            this.targetStdDev.handle(target);
        }

        @Override
        protected float loss(float target, float predicted) {
            return adaptiveEpsilonLoss(target, predicted, this.epsilon, this.targetStdDev);
        }
    }

    /**
     * PA-2 regression, registered as {@code train_pa2_regr}. Uses the step size
     * {@code loss / (squaredNorm + 1 / (2C))} and defaults C to 1.
     */
    @Description(name = "train_pa2_regr", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - Returns a relation consists of <{int|bigint|string} feature, float weight>")
    public static class PA2 extends PassiveAggressiveRegressionUDTF {

        /** PA-2 defaults C to 1 unless {@code -c} overrides it. */
        @Override
        protected float aggressiveness() {
            return 1.0f;
        }

        /** Returns {@code loss / (squaredNorm + 0.5 / C)}; finite for any C > 0. */
        @Override
        protected float eta(float loss, PredictionResult margin) {
            return loss / (margin.getSquaredNorm() + (0.5f / this.c));
        }
    }

    /**
     * PA-2 regression whose epsilon is multiplied by the running standard deviation of the
     * targets observed so far, registered as {@code train_pa2a_regr}.
     */
    @Description(name = "train_pa2a_regr", value = "_FUNC_(array<int|bigint|string> features, float target [, constant string options]) - Returns a relation consists of <{int|bigint|string} feature, float weight>")
    public static final class PA2a extends PA2 {

        /** Running statistics of training targets; (re)created on every {@code initialize}. */
        private OnlineVariance targetStdDev;

        @Override
        public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
            this.targetStdDev = new OnlineVariance();
            return super.initialize(argOIs);
        }

        /** Feeds the target into the running statistics before the loss is computed. */
        @Override
        protected void preTrain(float target) {
            this.targetStdDev.handle(target);
        }

        @Override
        protected float loss(float target, float predicted) {
            return adaptiveEpsilonLoss(target, predicted, this.epsilon, this.targetStdDev);
        }
    }

    /**
     * Epsilon-insensitive loss with epsilon scaled by the standard deviation reported by
     * {@code targetStats}. Shared by {@link PA1a} and {@link PA2a}.
     */
    private static float adaptiveEpsilonLoss(float target, float predicted, float epsilon,
            OnlineVariance targetStats) {
        final float scaledEpsilon = epsilon * ((float) targetStats.stddev());
        return LossFunctions.epsilonInsensitiveLoss(predicted, target, scaledEpsilon);
    }

    /**
     * Validates the argument count (features, target, and optional options string).
     *
     * @throws UDFArgumentException if fewer than 2 or more than 3 arguments are given
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgs = argOIs.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("_FUNC_ takes arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
        }
        return super.initialize(argOIs);
    }

    /**
     * Adds the {@code -c/-aggressiveness} and {@code -e/-epsilon} options to the inherited ones.
     * The reported default of C follows {@link #aggressiveness()}, so subclasses that override it
     * (e.g. {@link PA2}) advertise their real default.
     */
    @Override
    public Options getOptions() {
        final Options opts = super.getOptions();
        final float defaultC = aggressiveness();
        final String defaultCText =
                (defaultC == Float.MAX_VALUE) ? "Float.MAX_VALUE" : Float.toString(defaultC);
        opts.addOption("c", "aggressiveness", true, "Aggressiveness parameter [default "
                + defaultCText + "]");
        opts.addOption("e", "epsilon", true, "Sensitivity to prediction mistakes [default "
                + DEFAULT_EPSILON + "]");
        return opts;
    }

    /**
     * Parses the inherited options plus C and epsilon, then stores them in {@link #c} and
     * {@link #epsilon}. Missing options fall back to {@link #aggressiveness()} and 0.1.
     *
     * @throws UDFArgumentException if C is not strictly positive (NaN included) or epsilon is
     *         negative or NaN
     */
    @Override
    public CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);

        float cParam = aggressiveness();
        float epsParam = DEFAULT_EPSILON;
        if (cl != null) {
            final String cValue = cl.getOptionValue("c");
            if (cValue != null) {
                cParam = Float.parseFloat(cValue);
                // Negated comparison so that NaN is rejected too ("NaN <= 0" is false).
                if (!(cParam > 0.0f)) {
                    throw new UDFArgumentException("Aggressiveness parameter C must be C > 0: " + cParam);
                }
            }
            final String epsValue = cl.getOptionValue("epsilon");
            if (epsValue != null) {
                epsParam = Float.parseFloat(epsValue);
                // A NaN epsilon would make every loss NaN and silently disable training;
                // a negative one would force updates even on exact predictions.
                if (!(epsParam >= 0.0f)) {
                    throw new UDFArgumentException("Epsilon must be epsilon >= 0: " + epsParam);
                }
            }
        }
        this.c = cParam;
        this.epsilon = epsParam;
        return cl;
    }

    /** Default aggressiveness C used when {@code -c} is absent; PA-1 is effectively unbounded. */
    protected float aggressiveness() {
        return Float.MAX_VALUE;
    }

    /**
     * Performs one passive-aggressive update. The weights change only when the loss is strictly
     * positive; the step is signed toward the target.
     */
    @Override
    protected void train(@Nonnull FeatureValue[] features, float target) {
        preTrain(target);

        final PredictionResult margin = calcScoreAndNorm(features);
        final float predicted = margin.getScore();
        final float loss = loss(target, predicted);
        if (loss > 0.0f) { // also false for NaN, so a NaN loss never updates the model
            final float sign = (target - predicted > 0.0f) ? 1.0f : -1.0f;
            final float eta = sign * eta(loss, margin);
            if (Float.isInfinite(eta)) {
                // Can happen e.g. for PA-1 with C = +Infinity and a zero squared norm;
                // skip rather than write infinite weights.
                return;
            }
            onlineUpdate(features, eta);
        }
    }

    /** Hook invoked with each target before scoring; no-op by default. */
    protected void preTrain(float target) {}

    /** Epsilon-insensitive loss between {@code target} and {@code predicted} with fixed epsilon. */
    protected float loss(float target, float predicted) {
        return LossFunctions.epsilonInsensitiveLoss(predicted, target, this.epsilon);
    }

    /** PA-1 step size: {@code min(C, loss / squaredNorm)}. */
    protected float eta(float loss, PredictionResult margin) {
        return Math.min(this.c, loss / margin.getSquaredNorm());
    }
}
