package hivemall.fm;

import hivemall.fm.FMHyperParameters;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.math.MathUtils;
import it.unimi.dsi.fastutil.ints.Int2LongMap;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

/**
 * Hive UDTF behind {@code train_ffm}: trains a field-aware factorization machine (FFM) from
 * {@code (array<string> x, double y [, options])} rows and, when closed, forwards the learned model
 * as rows of {@code (model_id, i, Wi, Vi)}.
 *
 * <p>The generic training loop is inherited from {@link FactorizationMachineUDTF}; this subclass
 * supplies the FFM-specific options, model, feature parsing, per-sample update and model output.
 */
@Description(name = "train_ffm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
public final class FieldAwareFactorizationMachineUDTF extends FactorizationMachineUDTF {
    private static final Log LOG = LogFactory.getLog(FieldAwareFactorizationMachineUDTF.class);

    /** Whether w0 is updated; copied from the hyperparameters in {@link #processOptions}. */
    private boolean _globalBias;
    /** Whether Wi is updated; copied from the hyperparameters in {@link #processOptions}. */
    private boolean _linearCoeff;
    /** {@code numFeatures} hyperparameter, forwarded to {@code Feature.parseFFMFeatures}. */
    private int _numFeatures;
    /** {@code numFields} hyperparameter, forwarded to {@code Feature.parseFFMFeatures}. */
    private int _numFields;
    /** Typed reference to the model created by {@link #initModel}. */
    protected transient FFMStringFeatureMapModel _ffmModel;
    /** Per-sample buffer of each feature's field; filled by getFieldList, cleared after use. */
    private transient IntArrayList _fieldList;

    /** Buffer handed back to {@code sumVfX} on the next sample; null until the first update. */
    @Nullable
    private transient DoubleArray3D _sumVfX;

    /**
     * Returns the superclass options extended with the FFM-specific ones: global bias, linear
     * term, normalization, feature hashing, number of fields and optimizer (FTRL/AdaGrad)
     * settings.
     */
    @Override // hivemall.fm.FactorizationMachineUDTF, hivemall.UDTFWithOptions
    public Options getOptions() {
        Options options = super.getOptions();
        options.addOption("w0", "global_bias", false, "Whether to include global bias term w0 [default: OFF]");
        options.addOption("enable_wi", "linear_term", false, "Include linear term [default: OFF]");
        options.addOption("no_norm", "disable_norm", false, "Disable instance-wise L2 normalization");
        options.addOption("feature_hashing", true, "The number of bits for feature hashing in range [18,31] [default: -1]. No feature hashing for -1.");
        options.addOption("num_fields", true, "The number of fields [default: 256]");
        options.addOption("opt", "optimizer", true, "Gradient Descent optimizer [default: ftrl, adagrad, sgd]");
        options.addOption("eps", true, "A constant used in the denominator of AdaGrad [default: 1.0]");
        options.addOption("alpha", "alphaFTRL", true, "Alpha value (learning rate) of Follow-The-Regularized-Leader [default: 0.5]");
        options.addOption("beta", "betaFTRL", true, "Beta value (a learning smoothing parameter) of Follow-The-Regularized-Leader [default: 1.0]");
        options.addOption("l1", "lambda1", true, "L1 regularization value of Follow-The-Regularized-Leader that controls model Sparseness [default: 0.0002]");
        options.addOption("l2", "lambda2", true, "L2 regularization value of Follow-The-Regularized-Leader [default: 0.0001]");
        return options;
    }

    /** Adaptive regularization is not available for FFM; always returns {@code false}. */
    @Override // hivemall.fm.FactorizationMachineUDTF
    protected boolean isAdaptiveRegularizationSupported() {
        return false;
    }

    /** Creates the FFM-specific hyperparameter holder. */
    @Override // hivemall.fm.FactorizationMachineUDTF
    public FMHyperParameters.FFMHyperParameters newHyperParameters() {
        return new FMHyperParameters.FFMHyperParameters();
    }

    /**
     * Parses options via the superclass, then caches the FFM hyperparameters used on the hot path
     * (global bias, linear term, number of features, number of fields) in local fields.
     */
    @Override // hivemall.fm.FactorizationMachineUDTF, hivemall.UDTFWithOptions
    public CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(argOIs);
        final FMHyperParameters.FFMHyperParameters params =
                (FMHyperParameters.FFMHyperParameters) this._params;
        this._globalBias = params.globalBias;
        this._linearCoeff = params.linearCoeff;
        this._numFeatures = params.numFeatures;
        this._numFields = params.numFields;
        return cl;
    }

    /** Initializes via the superclass and allocates the reusable per-sample field list. */
    @Override // hivemall.fm.FactorizationMachineUDTF
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        final StructObjectInspector outputOI = super.initialize(argOIs);
        this._fieldList = new IntArrayList();
        return outputOI;
    }

    /**
     * Describes the output rows: {@code model_id string, i int, Wi float, Vi array<float>}.
     */
    @Override // hivemall.fm.FactorizationMachineUDTF
    protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters params) {
        final List<String> colNames = new ArrayList<String>();
        final List<ObjectInspector> colOIs = new ArrayList<ObjectInspector>();
        colNames.add("model_id");
        colOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        colNames.add("i");
        colOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        colNames.add("Wi");
        colOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        colNames.add("Vi");
        colOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(colNames, colOIs);
    }

    /** Creates the FFM model and keeps a typed reference to it in {@link #_ffmModel}. */
    @Override // hivemall.fm.FactorizationMachineUDTF
    public FFMStringFeatureMapModel initModel(@Nonnull FMHyperParameters params) throws UDFArgumentException {
        final FFMStringFeatureMapModel model =
                new FFMStringFeatureMapModel((FMHyperParameters.FFMHyperParameters) params);
        this._ffmModel = model;
        return model;
    }

    /**
     * Parses one row's feature array into FFM features, then L2-normalizes them in place when the
     * {@code l2norm} hyperparameter is set.
     */
    @Override // hivemall.fm.FactorizationMachineUDTF
    protected Feature[] parseFeatures(@Nonnull Object arg) throws HiveException {
        final Feature[] features =
                Feature.parseFFMFeatures(arg, _xOI, _probes, _numFeatures, _numFields);
        if (_params.l2norm) {
            Feature.l2normalize(features);
        }
        return features;
    }

    /** Accumulates the validation loss of a held-out sample; a no-op unless early stopping is on. */
    @Override // hivemall.fm.FactorizationMachineUDTF
    protected void processValidationSample(@Nonnull Feature[] x, double y) throws HiveException {
        if (_earlyStopping) {
            _validationState.incrLoss(_lossFunction.loss(_model.predict(x), y));
        }
    }

    /**
     * Performs one update on a single training sample.
     *
     * <p>The method first records the training loss. It returns early when the loss gradient is
     * close to zero. Otherwise it updates w0 (if enabled), and then Wi (if enabled) and V for
     * every non-zero feature.
     *
     * @param x features of the sample
     * @param y target value
     */
    @Override // hivemall.fm.FactorizationMachineUDTF
    protected void trainTheta(@Nonnull Feature[] x, double y) throws HiveException {
        final double p = _ffmModel.predict(x);
        final double lossGrad = _ffmModel.dloss(p, y);
        _cvState.incrLoss(_lossFunction.loss(p, y));
        if (MathUtils.closeToZero(lossGrad, 1.0E-9d)) {
            return; // nothing to update for a (near) zero gradient
        }

        if (_globalBias) {
            _ffmModel.updateW0(lossGrad, _etaEstimator.eta(_t));
        }

        final IntArrayList fieldList = getFieldList(x);
        // sumVfX is read as [feature index in x][index into fieldList][latent factor]
        final DoubleArray3D sumVfX = _ffmModel.sumVfX(x, fieldList, _sumVfX);
        final int numFieldEntries = fieldList.size();
        final int factors = _factors;
        for (int i = 0; i < x.length; i++) {
            final Feature x_i = x[i];
            if (x_i.value == 0.d) {
                continue; // zero-valued features are skipped
            }
            if (_linearCoeff) {
                _ffmModel.updateWi(lossGrad, x_i, _t);
            }
            for (int fieldIndex = 0; fieldIndex < numFieldEntries; fieldIndex++) {
                final int yField = fieldList.get(fieldIndex);
                for (int f = 0; f < factors; f++) {
                    final double sumViX = sumVfX.get(i, fieldIndex, f);
                    if (MathUtils.closeToZero(sumViX)) {
                        continue; // skip near-zero sums
                    }
                    _ffmModel.updateV(lossGrad, x_i, yField, f, sumViX, _t);
                }
            }
        }

        // Keep the cleared buffers so the next sample can reuse them.
        sumVfX.clear();
        this._sumVfX = sumVfX;
        fieldList.clear();
    }

    /** Appends the field of every feature in {@code x} to the shared field list and returns it. */
    @Nonnull
    private IntArrayList getFieldList(@Nonnull Feature[] x) {
        for (Feature e : x) {
            _fieldList.add(e.getField());
        }
        return _fieldList;
    }

    /** Deserializes a feature from {@code buf} as an {@link IntFeature}. */
    @Override // hivemall.fm.FactorizationMachineUDTF
    public IntFeature instantiateFeature(@Nonnull ByteBuffer buf) {
        return new IntFeature(buf);
    }

    /**
     * Logs model statistics before and after {@code super.close()}, disables V initialization on
     * the model first, and drops the model reference afterwards.
     */
    @Override // hivemall.fm.FactorizationMachineUDTF
    public void close() throws HiveException {
        if (LOG.isInfoEnabled()) {
            LOG.info(_ffmModel.getStatistics());
        }
        // NOTE(review): presumably prevents V entries from being (re)initialized on lookups made
        // during super.close() - confirm against FFMStringFeatureMapModel#disableInitV.
        _ffmModel.disableInitV();
        super.close();
        if (LOG.isInfoEnabled()) {
            LOG.info(_ffmModel.getStatistics());
        }
        this._ffmModel = null;
    }

    /**
     * Emits the learned model as {@code (model_id, i, Wi, Vi)} rows.
     *
     * <p>W0 is emitted first under index 0. Then each map entry is emitted under its own key: a W
     * entry becomes a row with only Wi set, and a V entry becomes a row with only Vi set. W entries
     * whose weight is exactly zero are skipped.
     */
    @Override // hivemall.fm.FactorizationMachineUDTF
    protected void forwardModel() throws HiveException {
        // release training-only state before emitting the model
        this._model = null;
        this._fieldList = null;
        this._sumVfX = null;

        final int factors = _factors;
        final IntWritable idx = new IntWritable();
        final FloatWritable wi = new FloatWritable(0.0f);
        final FloatWritable[] vi = HiveUtils.newFloatArray(factors, 0.0f);
        final List<FloatWritable> viList = Arrays.asList(vi);
        // columns: model_id, i, Wi, Vi; the same writables are reused for every row
        final Object[] row = {new Text(HadoopUtils.getUniqueTaskIdString()), idx, wi, null};

        idx.set(0);
        wi.set(_ffmModel.getW0());
        forward(row);

        final Entry entryW = new Entry(_ffmModel._buf, 1);
        final Entry entryV = new Entry(_ffmModel._buf, factors);
        final float[] vf = new float[factors];
        for (Int2LongMap.Entry e : Fastutil.fastIterable(_ffmModel._map)) {
            final int key = e.getIntKey();
            final long offset = e.getLongValue();
            if (Entry.isEntryW(key)) {
                entryW.setOffset(offset);
                final float w = entryW.getW();
                if (w == 0.0f) {
                    // Skip zero weights: forwarding here would re-emit the Wi/Vi left in `row` by
                    // the previous row under this key.
                    continue;
                }
                wi.set(w);
                row[2] = wi;
                row[3] = null;
            } else {
                entryV.setOffset(offset);
                entryV.getV(vf);
                for (int f = 0; f < factors; f++) {
                    vi[f].set(vf[f]);
                }
                row[2] = null;
                row[3] = viList;
            }
            idx.set(key);
            forward(row);
        }
    }
}
