package hivemall.optimizer;

import hivemall.model.IWeightValue;
import hivemall.model.WeightValue;
import hivemall.utils.lang.Primitives;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;

/* loaded from: input_file:hivemall/optimizer/Optimizer.class */
public interface Optimizer {

    /* loaded from: input_file:hivemall/optimizer/Optimizer$AdaDelta.class */
    /**
     * AdaDelta update rule.
     *
     * <p>Per weight, two decayed accumulators are kept on the {@link IWeightValue}:
     * <ul>
     * <li>one of squared gradients, stored divided by {@code scale};</li>
     * <li>one of squared deltas.</li>
     * </ul>
     * The gradient is multiplied by {@code sqrt((sumSqDeltaX + eps) / (sumSqGrad * scale + eps))}.
     *
     * <p>Options read: {@code decay} (default 0.95), {@code eps} (default 1e-6),
     * {@code scale} (default 100).
     */
    public static abstract class AdaDelta extends OptimizerBase {
        private final float decay;
        private final float eps;
        private final float scale;

        public AdaDelta(@Nonnull Map<String, String> options) {
            super(options);
            this.decay = Primitives.parseFloat(options.get("decay"), 0.95f);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0E-6f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /**
         * Advances both decayed accumulators stored on {@code weight} and returns the step.
         *
         * @param weight per-feature state; its squared-gradient and squared-delta accumulators are
         *        overwritten
         * @param gradient the (already regularized) gradient
         * @return the delta to be scaled by the learning rate and subtracted from the weight
         */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float prevScaledSqGrad = weight.getSumOfSquaredGradients();
            final float prevSqDeltaX = weight.getSumOfSquaredDeltaX();
            final float keep = decay;
            final float blend = 1.0f - decay;

            // gradient^2 is accumulated divided by scale; it is multiplied back inside the sqrt
            final float nextScaledSqGrad = (keep * prevScaledSqGrad) + (blend * gradient * (gradient / scale));
            final float ratio = (prevSqDeltaX + eps) / ((nextScaledSqGrad * scale) + eps);
            final float step = ((float) Math.sqrt(ratio)) * gradient;
            final float nextSqDeltaX = (keep * prevSqDeltaX) + (blend * step * step);

            weight.setSumOfSquaredGradients(nextScaledSqGrad);
            weight.setSumOfSquaredDeltaX(nextSqDeltaX);
            return step;
        }

        @Override // hivemall.optimizer.Optimizer
        public String getOptimizerName() {
            return "adadelta";
        }
    }

    /* loaded from: input_file:hivemall/optimizer/Optimizer$AdaGrad.class */
    /**
     * AdaGrad update rule.
     *
     * <p>Per weight, a running sum of squared gradients is kept on the {@link IWeightValue},
     * stored divided by {@code scale}. The gradient is divided by
     * {@code sqrt(sum * scale) + eps}. Dividing by {@code scale} presumably keeps the stored
     * accumulator small; this is an assumption, so confirm it with the original authors.
     *
     * <p>Options read: {@code eps} (default 1.0), {@code scale} (default 100).
     */
    public static abstract class AdaGrad extends OptimizerBase {
        private final float eps;
        private final float scale;

        public AdaGrad(@Nonnull Map<String, String> options) {
            super(options);
            this.eps = Primitives.parseFloat(options.get("eps"), 1.0f);
            this.scale = Primitives.parseFloat(options.get("scale"), 100.0f);
        }

        /**
         * Accumulates {@code gradient^2 / scale} into {@code weight} and returns the adapted step.
         *
         * @param weight per-feature state; its squared-gradient accumulator is overwritten
         * @param gradient the (already regularized) gradient
         * @return {@code gradient / (sqrt(newSum * scale) + eps)}
         */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float newScaledSumSqGrad =
                    weight.getSumOfSquaredGradients() + (gradient * (gradient / this.scale));
            weight.setSumOfSquaredGradients(newScaledSumSqGrad);
            // The denominator must use the freshly accumulated value. The decompiled code
            // referred to it through an undefined register name ("r0") and did not compile.
            return gradient / (((float) Math.sqrt(newScaledSumSqGrad * this.scale)) + this.eps);
        }

        @Override // hivemall.optimizer.Optimizer
        public String getOptimizerName() {
            return "adagrad";
        }
    }

    /* loaded from: input_file:hivemall/optimizer/Optimizer$AdagradRDA.class */
    /**
     * AdaGrad-RDA: a dual-averaging style update that delegates the adaptive step to an
     * {@link AdaGrad} instance and truncates small averaged gradients to zero.
     *
     * <p>The running sum of gradients is kept on the {@link IWeightValue}. When
     * {@code |sum| / t - lambda} is negative, the weight and both accumulators are reset to zero.
     *
     * <p>Options read: {@code lambda} (default 1e-6).
     */
    public static abstract class AdagradRDA extends OptimizerBase {

        @Nonnull
        private final AdaGrad optimizerImpl;
        private final float lambda;

        public AdagradRDA(@Nonnull AdaGrad adaGrad, @Nonnull Map<String, String> options) {
            super(options);
            this.optimizerImpl = adaGrad;
            this.lambda = Primitives.parseFloat(options.get("lambda"), 1.0E-6f);
        }

        /**
         * Computes and stores the new weight directly, rather than as {@code old - eta * delta}.
         *
         * @param weight per-feature state; weight, sum of gradients and (via the AdaGrad delegate)
         *        sum of squared gradients are overwritten
         * @param gradient the raw gradient of this step
         * @return the new weight value, also written into {@code weight}
         */
        /* JADX INFO: Access modifiers changed from: protected */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        public float update(@Nonnull IWeightValue weight, float gradient) {
            final float totalGrad = weight.getSumOfGradients() + gradient;
            // sign(u); a zero total is treated as negative
            final float direction = (totalGrad > 0.0f) ? 1.0f : -1.0f;
            // |u| / t - lambda
            final float excess = ((direction * totalGrad) / ((float) _numStep)) - lambda;

            if (!(excess < 0.0f)) {
                final float rate = _eta.eta(_numStep);
                final float adaptive = optimizerImpl.computeDelta(weight, excess);
                final float nextWeight = (-1.0f) * direction * rate * ((float) _numStep) * adaptive;
                weight.set(nextWeight);
                weight.setSumOfGradients(totalGrad);
                return nextWeight;
            }

            // truncated: reset weight and both accumulators
            weight.set(0.0f);
            weight.setSumOfSquaredGradients(0.0f);
            weight.setSumOfGradients(0.0f);
            return 0.0f;
        }

        @Override // hivemall.optimizer.Optimizer
        public String getOptimizerName() {
            return "adagrad_rda";
        }
    }

    /* loaded from: input_file:hivemall/optimizer/Optimizer$Adam.class */
    /**
     * Adam update rule.
     *
     * <p>Per weight, first ({@code m}) and second ({@code v}) moment estimates are kept on the
     * {@link IWeightValue}. The returned step is {@code mHat / (sqrt(vHat) + eps_hat)}, where both
     * moments are bias-corrected by {@code 1 - beta^t} and {@code 1 - gamma^t} respectively.
     *
     * <p>Options read: {@code beta} (default 0.9), {@code gamma} (default 0.999),
     * {@code eps_hat} (default 1e-8).
     */
    public static abstract class Adam extends OptimizerBase {
        private final float beta;
        private final float gamma;
        private final float eps_hat;

        public Adam(@Nonnull Map<String, String> options) {
            super(options);
            this.beta = Primitives.parseFloat(options.get("beta"), 0.9f);
            this.gamma = Primitives.parseFloat(options.get("gamma"), 0.999f);
            this.eps_hat = Primitives.parseFloat(options.get("eps_hat"), 1.0E-8f);
        }

        /**
         * Updates the moment estimates stored on {@code weight} and returns the Adam step.
         *
         * @param weight per-feature state; its {@code m} and {@code v} values are overwritten
         * @param gradient the (already regularized) gradient
         * @return the bias-corrected step for the current step counter
         */
        @Override // hivemall.optimizer.Optimizer.OptimizerBase
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            final float firstMoment = (beta * weight.getM()) + ((1.0f - beta) * gradient);
            final float secondMoment = (gamma * weight.getV()) + ((float) ((1.0f - gamma) * Math.pow(gradient, 2.0d)));

            // bias correction for the moments, which start at zero
            final float firstHat = firstMoment / ((float) (1.0d - Math.pow(beta, _numStep)));
            final float secondHat = secondMoment / ((float) (1.0d - Math.pow(gamma, _numStep)));
            final float step = firstHat / ((float) (Math.sqrt(secondHat) + eps_hat));

            weight.setM(firstMoment);
            weight.setV(secondMoment);
            return step;
        }

        @Override // hivemall.optimizer.Optimizer
        public String getOptimizerName() {
            return "adam";
        }
    }

    /**
     * Common base for the optimizers in this file.
     *
     * <p>It holds three pieces of state:
     * <ul>
     * <li>the learning-rate estimator built from the option map;</li>
     * <li>the regularization built from the option map;</li>
     * <li>a step counter that starts at 1 and is advanced by {@link #proceedStep()}.</li>
     * </ul>
     * The default {@link #computeDelta} returns the gradient unchanged. Subclasses override it to
     * supply an adaptive step.
     */
    @NotThreadSafe
    public static abstract class OptimizerBase implements Optimizer {

        @Nonnull
        protected final EtaEstimator _eta;

        @Nonnull
        protected final Regularization _reg;

        @Nonnegative
        protected long _numStep = 1;

        public OptimizerBase(@Nonnull Map<String, String> options) {
            this._eta = EtaEstimator.get(options);
            this._reg = Regularization.get(options);
        }

        @Override // hivemall.optimizer.Optimizer
        public void proceedStep() {
            _numStep += 1;
        }

        /**
         * Applies {@code w' = w - eta(t) * computeDelta(weight, regularize(w, gradient))}.
         *
         * @param weight per-feature state; its value is overwritten with the result
         * @param gradient the raw gradient of this step
         * @return the new weight value
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public float update(@Nonnull IWeightValue weight, float gradient) {
            final float current = weight.get();
            final float rate = _eta.eta(_numStep);
            final float regularized = _reg.regularize(current, gradient);
            final float next = current - (rate * computeDelta(weight, regularized));
            weight.set(next);
            return next;
        }

        /**
         * Returns the step for {@code gradient}. The base implementation returns the gradient
         * as-is, which makes the update plain gradient descent.
         */
        protected float computeDelta(@Nonnull IWeightValue weight, float gradient) {
            return gradient;
        }
    }

    /* loaded from: input_file:hivemall/optimizer/Optimizer$SGD.class */
    /**
     * Plain stochastic gradient descent. It uses the identity {@code computeDelta} from
     * {@link OptimizerBase}.
     *
     * <p>A single {@link WeightValue} is reused as scratch state for every call, so instances must
     * not be shared across threads.
     */
    public static final class SGD extends OptimizerBase {
        private final IWeightValue weightValueReused = new WeightValue(0.0f);

        public SGD(@Nonnull Map<String, String> options) {
            super(options);
        }

        /**
         * Performs one update step on a bare weight value.
         *
         * @param feature unused by this optimizer
         * @param weight the current weight
         * @param gradient the gradient for this step
         * @return the weight after the update
         */
        @Override // hivemall.optimizer.Optimizer
        public float update(@Nonnull Object feature, float weight, float gradient) {
            final IWeightValue scratch = this.weightValueReused;
            scratch.set(weight);
            update(scratch, gradient);
            return scratch.get();
        }

        @Override // hivemall.optimizer.Optimizer
        public String getOptimizerName() {
            return "sgd";
        }
    }

    /**
     * Computes an updated weight.
     *
     * <p>In the {@link SGD} implementation in this file:
     * <ul>
     * <li>{@code f} is the current weight;</li>
     * <li>{@code f2} is the gradient;</li>
     * <li>{@code obj} is not used;</li>
     * <li>the return value is the weight after one update step.</li>
     * </ul>
     * NOTE(review): other implementations may use {@code obj} as a per-feature key; confirm this.
     */
    float update(@Nonnull Object obj, float f, float f2);

    /** Advances the step counter. In {@link OptimizerBase} it increments {@code _numStep}. */
    void proceedStep();

    /** Returns the short name of this optimizer, e.g. {@code "sgd"} or {@code "adagrad"}. */
    @Nonnull
    String getOptimizerName();
}
