package hivemall.classifier.multiclass;

import hivemall.model.FeatureValue;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;

/**
 * Hive UDTF registered as {@code train_multiclass_perceptron}; per its own
 * {@code @Description}, it builds a prediction model with a Perceptron
 * multiclass classifier.
 *
 * <p>This class only contributes the argument-count check and the per-example
 * training step. Classification and the weight update itself are inherited
 * from {@link MulticlassOnlineClassifierUDTF}.
 */
@Description(name = "train_multiclass_perceptron", value = "_FUNC_(list<string|int|bigint> features, {int|string} label [, const string options]) - Returns a relation consists of <{int|string} label, {string|int|bigint} feature, float weight>", extended = "Build a prediction model by Perceptron multiclass classifier")
public final class MulticlassPerceptronUDTF extends MulticlassOnlineClassifierUDTF {

    /**
     * Checks that the function was called with 2 or 3 arguments and then
     * delegates to the parent initializer.
     *
     * @param objectInspectorArr inspectors for the call arguments; only the
     *        array length is examined here, the elements are passed through
     *        to {@code super.initialize} untouched
     * @return whatever the parent {@code initialize} returns
     * @throws UDFArgumentException if the number of arguments is neither 2 nor 3
     *         (the parent initializer may also throw it)
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("MulticlassPerceptronUDTF takes 2 or 3 arguments: List<Text|Int|BigInt> features, {Int|Text} label [, constant text options]");
        }
        return super.initialize(objectInspectorArr);
    }

    /**
     * Trains on a single example: classifies the features and, only when the
     * predicted label differs from the given one, invokes the inherited
     * {@code update}.
     *
     * @param featureValueArr features of the training example
     * @param obj the label supplied with the training example
     */
    @Override
    protected void train(@Nonnull FeatureValue[] featureValueArr, @Nonnull Object obj) {
        final Object predictedLabel = classify(featureValueArr).getLabel();
        // Mistake-driven: a correct prediction leaves the model untouched.
        if (obj.equals(predictedLabel)) {
            return;
        }
        // NOTE(review): 1.0f is presumably a constant update coefficient; confirm
        // against MulticlassOnlineClassifierUDTF.update. predictedLabel is passed
        // through as-is, including when getLabel() returned null.
        update(featureValueArr, 1.0f, obj, predictedLabel);
    }
}
