package hivemall.model;

import hivemall.model.WeightValue;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Copyable;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import javax.annotation.Nonnull;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * A prediction model that stores every per-feature parameter in a primitive array indexed directly
 * by an integer feature id (keys are converted with {@code HiveUtils.parseInt}).
 *
 * <p>The weight array is always present. The following optional per-feature state is allocated
 * only when requested:
 * <ul>
 * <li>covariance, via the two-argument constructor;</li>
 * <li>sum of squared gradients, sum of squared delta-x and sum of gradients, via
 * {@link #configureParams(boolean, boolean, boolean)};</li>
 * <li>clocks and delta-update counters, via {@link #configureClock()}.</li>
 * </ul>
 * Every allocated array is grown together when a feature id at or beyond the current capacity is
 * written. Covariance slots that were never written hold {@code 1.0}.
 */
public final class DenseModel extends AbstractPredictionModel {

    private static final Log logger = LogFactory.getLog(DenseModel.class);

    /** Current capacity (length) of every allocated per-feature array. */
    private int size;
    private float[] weights;
    /** Per-feature covariance; {@code null} unless the model was created with covariance. */
    private float[] covars;
    private float[] sum_of_squared_gradients;
    private float[] sum_of_squared_delta_x;
    private float[] sum_of_gradients;
    /** Per-feature update clocks; {@code null} until {@link #configureClock()} is called. */
    private short[] clocks;
    /** Per-feature delta-update counters; allocated together with {@link #clocks}. */
    private byte[] deltaUpdates;

    /**
     * Iterator over every slot {@code 0 .. size-1}, including slots whose weight is zero. The key
     * is the slot index. Values are built from the backing arrays on demand.
     */
    private final class Itr implements IMapIterator<Number, IWeightValue> {

        /** Current slot; starts before the first slot so the first {@link #next()} yields 0. */
        private int cursor = -1;
        /** Scratch value reused by {@link #getValue(Copyable)} to avoid a per-entry allocation. */
        private final WeightValue.WeightValueWithCovar tmpWeight =
                new WeightValue.WeightValueWithCovar();

        private Itr() {}

        @Override // hivemall.utils.collections.IMapIterator
        public boolean hasNext() {
            return cursor < size;
        }

        /**
         * Advances to the next slot.
         *
         * @return the new slot index, or {@code -1} once every slot has been visited
         */
        @Override // hivemall.utils.collections.IMapIterator
        public int next() {
            ++cursor;
            if (!hasNext()) {
                return -1;
            }
            return cursor;
        }

        @Override // hivemall.utils.collections.IMapIterator
        public Number getKey() {
            return Integer.valueOf(cursor);
        }

        /**
         * Returns a freshly allocated value for the current slot. The value is marked touched
         * unless it still holds the defaults (weight {@code 0.0}, covariance {@code 1.0}).
         */
        @Override // hivemall.utils.collections.IMapIterator
        public IWeightValue getValue() {
            final float weight = weights[cursor];
            if (covars == null) {
                WeightValue value = new WeightValue(weight);
                value.setTouched(weight != 0.f);
                return value;
            }
            final float covar = covars[cursor];
            WeightValue.WeightValueWithCovar value =
                    new WeightValue.WeightValueWithCovar(weight, covar);
            value.setTouched(weight != 0.f || covar != 1.f);
            return value;
        }

        /**
         * Copies the current slot into {@code probe} through a reused scratch value. This avoids
         * the allocation that {@link #getValue()} performs.
         */
        @Override // hivemall.utils.collections.IMapIterator
        public <T extends Copyable<IWeightValue>> void getValue(@Nonnull final T probe) {
            final float weight = weights[cursor];
            tmpWeight.value = weight;
            float covar = 1.f;
            if (covars != null) {
                covar = covars[cursor];
                tmpWeight.setCovariance(covar);
            }
            tmpWeight.setTouched(weight != 0.f || covar != 1.f);
            probe.copyFrom(tmpWeight);
        }
    }

    /**
     * Creates a model without covariance.
     *
     * @param ndims largest feature id expected up front; {@code ndims + 1} slots are allocated
     */
    public DenseModel(int ndims) {
        this(ndims, false);
    }

    /**
     * Creates a model, optionally with a per-feature covariance array initialised to 1.0.
     *
     * @param ndims largest feature id expected up front; {@code ndims + 1} slots are allocated
     * @param withCovar whether to allocate the covariance array
     */
    public DenseModel(int ndims, boolean withCovar) {
        final int initialSize = ndims + 1;
        this.size = initialSize;
        this.weights = new float[initialSize];
        if (withCovar) {
            float[] initialCovars = new float[initialSize];
            Arrays.fill(initialCovars, 1.f);
            this.covars = initialCovars;
        } else {
            this.covars = null;
        }
        this.sum_of_squared_gradients = null;
        this.sum_of_squared_delta_x = null;
        this.sum_of_gradients = null;
        this.clocks = null;
        this.deltaUpdates = null;
    }

    @Override // hivemall.model.AbstractPredictionModel
    protected boolean isDenseModel() {
        return true;
    }

    @Override // hivemall.model.PredictionModel
    public boolean hasCovariance() {
        return covars != null;
    }

    /**
     * Allocates optional optimizer-state arrays at the current capacity. Each array is allocated
     * only for a {@code true} flag. A {@code true} flag replaces any existing array with a
     * zero-filled one. A {@code false} flag leaves the existing array untouched.
     */
    @Override // hivemall.model.PredictionModel
    public void configureParams(boolean withSumOfSquaredGradients,
            boolean withSumOfSquaredDeltaX, boolean withSumOfGradients) {
        if (withSumOfSquaredGradients) {
            this.sum_of_squared_gradients = new float[size];
        }
        if (withSumOfSquaredDeltaX) {
            this.sum_of_squared_delta_x = new float[size];
        }
        if (withSumOfGradients) {
            this.sum_of_gradients = new float[size];
        }
    }

    /** Allocates the clock and delta-update arrays once; later calls are no-ops. */
    @Override // hivemall.model.PredictionModel
    public void configureClock() {
        if (clocks == null) {
            this.clocks = new short[size];
            this.deltaUpdates = new byte[size];
        }
    }

    @Override // hivemall.model.PredictionModel
    public boolean hasClock() {
        return clocks != null;
    }

    /**
     * Zeroes the delta-update counter of slot {@code feature}. Requires {@link #configureClock()}
     * to have been called; otherwise the counter array is {@code null}.
     */
    @Override // hivemall.model.AbstractPredictionModel, hivemall.model.PredictionModel
    public void resetDeltaUpdates(int feature) {
        deltaUpdates[feature] = 0;
    }

    /**
     * Grows every allocated array so that {@code index} is a valid slot. The new capacity is
     * {@code (1 << MathUtils.bitsRequired(index)) + 1}. New covariance slots are filled with
     * 1.0; all other new slots are zero.
     */
    private void ensureCapacity(final int index) {
        if (index < size) {
            return;
        }
        // NOTE(review): assumes bitsRequired(index) is the bit length of index; for index >= 2^30
        // the shift overflows int and the array allocation fails.
        final int bits = MathUtils.bitsRequired(index);
        final int newSize = (1 << bits) + 1;
        final int oldSize = size;
        logger.info("Expands internal array size from " + oldSize + " to " + newSize + " ("
                + bits + " bits)");
        this.size = newSize;
        this.weights = Arrays.copyOf(weights, newSize);
        if (covars != null) {
            this.covars = Arrays.copyOf(covars, newSize);
            Arrays.fill(covars, oldSize, newSize, 1.f);
        }
        if (sum_of_squared_gradients != null) {
            this.sum_of_squared_gradients = Arrays.copyOf(sum_of_squared_gradients, newSize);
        }
        if (sum_of_squared_delta_x != null) {
            this.sum_of_squared_delta_x = Arrays.copyOf(sum_of_squared_delta_x, newSize);
        }
        if (sum_of_gradients != null) {
            this.sum_of_gradients = Arrays.copyOf(sum_of_gradients, newSize);
        }
        if (clocks != null) {
            this.clocks = Arrays.copyOf(clocks, newSize);
            this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
        }
    }

    /**
     * Returns a newly allocated snapshot of the parameters stored for {@code feature}.
     *
     * <p>The concrete type depends on which optional arrays are allocated:
     * <ul>
     * <li>sum of squared gradients present: {@code WeightValueParamsF2} paired with the sum of
     * squared delta-x (preferred) or the sum of gradients, else {@code WeightValueParamsF1};</li>
     * <li>otherwise, with covariance: {@code WeightValueWithCovar};</li>
     * <li>otherwise: a plain {@code WeightValue}.</li>
     * </ul>
     *
     * @return the snapshot, or {@code null} if the feature id is beyond the current capacity
     */
    @SuppressWarnings("unchecked")
    @Override // hivemall.model.PredictionModel
    public <T extends IWeightValue> T get(@Nonnull final Object feature) {
        final int i = HiveUtils.parseInt(feature);
        if (i >= size) {
            return null;
        }
        if (sum_of_squared_gradients != null) {
            if (sum_of_squared_delta_x != null) {
                return (T) new WeightValue.WeightValueParamsF2(weights[i],
                    sum_of_squared_gradients[i], sum_of_squared_delta_x[i]);
            } else if (sum_of_gradients != null) {
                return (T) new WeightValue.WeightValueParamsF2(weights[i],
                    sum_of_squared_gradients[i], sum_of_gradients[i]);
            } else {
                return (T) new WeightValue.WeightValueParamsF1(weights[i],
                    sum_of_squared_gradients[i]);
            }
        } else if (covars != null) {
            return (T) new WeightValue.WeightValueWithCovar(weights[i], covars[i]);
        } else {
            return (T) new WeightValue(weights[i]);
        }
    }

    /**
     * Stores {@code value} into slot {@code feature}, growing the arrays if needed. When clocks
     * are configured and {@code value} is touched, the slot's clock and delta counter are
     * incremented. The update is then reported through {@code onUpdate}.
     *
     * <p>NOTE(review): throws {@link NullPointerException} if {@code value} carries a covariance
     * but this model was created without one.
     */
    @Override // hivemall.model.PredictionModel
    public <T extends IWeightValue> void set(@Nonnull final Object feature, @Nonnull final T value) {
        final int i = HiveUtils.parseInt(feature);
        ensureCapacity(i);
        final float weight = value.get();
        weights[i] = weight;
        float covar = 1.f;
        final boolean hasCovar = value.hasCovariance();
        if (hasCovar) {
            covar = value.getCovariance();
            covars[i] = covar;
        }
        if (sum_of_squared_gradients != null) {
            sum_of_squared_gradients[i] = value.getSumOfSquaredGradients();
        }
        if (sum_of_squared_delta_x != null) {
            sum_of_squared_delta_x[i] = value.getSumOfSquaredDeltaX();
        }
        if (sum_of_gradients != null) {
            sum_of_gradients[i] = value.getSumOfGradients();
        }
        short clock = 0;
        int delta = 0;
        if (clocks != null && value.isTouched()) {
            clock = (short) (clocks[i] + 1);
            clocks[i] = clock;
            delta = deltaUpdates[i] + 1;
            // the counter is stored as a byte; a non-positive value here means it overflowed
            assert delta > 0 : delta;
            deltaUpdates[i] = (byte) delta;
        }
        onUpdate(i, weight, covar, clock, delta, hasCovar);
    }

    /**
     * Resets the weight, covariance and accumulated optimizer state of slot {@code feature} to
     * their initial values. Clocks and delta counters are left unchanged. Ids beyond the current
     * capacity are ignored.
     */
    @Override // hivemall.model.PredictionModel
    public void delete(@Nonnull final Object feature) {
        final int i = HiveUtils.parseInt(feature);
        if (i >= size) {
            return;
        }
        weights[i] = 0.f;
        if (covars != null) {
            covars[i] = 1.f;
        }
        if (sum_of_squared_gradients != null) {
            sum_of_squared_gradients[i] = 0.f;
        }
        if (sum_of_squared_delta_x != null) {
            sum_of_squared_delta_x[i] = 0.f;
        }
        if (sum_of_gradients != null) {
            sum_of_gradients[i] = 0.f;
        }
    }

    /** Returns the weight of {@code feature}, or {@code 0.0} if it is beyond the capacity. */
    @Override // hivemall.model.PredictionModel
    public float getWeight(@Nonnull final Object feature) {
        final int i = HiveUtils.parseInt(feature);
        if (i >= size) {
            return 0.f;
        }
        return weights[i];
    }

    /** Not supported by this model; use {@code set} instead. */
    @Override // hivemall.model.PredictionModel
    public void setWeight(@Nonnull final Object feature, final float value) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the covariance of {@code feature}. Returns {@code 1.0} if the id is beyond the
     * capacity or the model has no covariance array.
     */
    @Override // hivemall.model.PredictionModel
    public float getCovariance(@Nonnull final Object feature) {
        final int i = HiveUtils.parseInt(feature);
        // a model without covariance behaves as if every covariance were 1.0, as in Itr/delete
        if (covars == null || i >= size) {
            return 1.f;
        }
        return covars[i];
    }

    /**
     * Overwrites weight and clock of slot {@code feature} and clears its delta counter. Requires
     * {@link #configureClock()} to have been called.
     */
    @Override // hivemall.model.AbstractPredictionModel
    protected void _set(@Nonnull final Object feature, final float weight, final short clock) {
        final int i = HiveUtils.parseInt(feature);
        ensureCapacity(i);
        weights[i] = weight;
        clocks[i] = clock;
        deltaUpdates[i] = 0;
    }

    /**
     * Overwrites weight, covariance and clock of slot {@code feature} and clears its delta
     * counter. Requires covariance and {@link #configureClock()} to be enabled.
     */
    @Override // hivemall.model.AbstractPredictionModel
    protected void _set(@Nonnull final Object feature, final float weight, final float covar,
            final short clock) {
        final int i = HiveUtils.parseInt(feature);
        ensureCapacity(i);
        weights[i] = weight;
        covars[i] = covar;
        clocks[i] = clock;
        deltaUpdates[i] = 0;
    }

    /** Returns the current capacity (number of slots), not the number of non-zero weights. */
    @Override // hivemall.model.PredictionModel
    public int size() {
        return size;
    }

    /** Returns {@code true} if {@code feature} is within capacity and has a non-zero weight. */
    @Override // hivemall.model.PredictionModel
    public boolean contains(@Nonnull final Object feature) {
        final int i = HiveUtils.parseInt(feature);
        return i < size && weights[i] != 0.f;
    }

    /**
     * Returns an iterator over every slot. Keys are {@code Integer} slot indices and values are
     * {@code IWeightValue}s; callers must request compatible {@code K}/{@code V}.
     */
    @SuppressWarnings("unchecked")
    @Override // hivemall.model.PredictionModel
    public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
        return (IMapIterator<K, V>) new Itr();
    }
}
