package hivemall.xgboost.tools;

import hivemall.utils.lang.Preconditions;
import hivemall.xgboost.XGBoostPredictUDTF;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Multiclass variant of {@link XGBoostPredictUDTF} ({@code xgboost_multiclass_predict}).
 *
 * <p>Instead of one row per input, it forwards one row per (input row, class index) pair. Each
 * forwarded row has the shape {@code (rowid, label, probability)}:
 * <ul>
 * <li>{@code label} is the class index rendered as a string ({@code "0"}, {@code "1"}, ...).</li>
 * <li>{@code probability} is the model output for that class.</li>
 * </ul>
 */
@Description(name = "xgboost_multiclass_predict", value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) - Returns a prediction result as (string rowid, string label, float probability)")
public final class XGBoostMulticlassPredictUDTF extends XGBoostPredictUDTF {

    /**
     * Builds the schema of the forwarded rows.
     *
     * <p>The fields are {@code rowid} (string), {@code label} (string) and {@code probability}
     * (float).
     *
     * @return a standard struct object inspector with those three fields, in that order
     */
    @Override
    protected StructObjectInspector getReturnOI() {
        final List<String> fieldNames = new ArrayList<>(3);
        final List<ObjectInspector> fieldOIs = new ArrayList<>(3);

        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("label");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("probability");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Forwards one output row per class for every row of the batch.
     *
     * @param rowBatch the input rows; only their row ids are read here
     * @param predicted per-row model outputs. {@code predicted[i][j]} is forwarded as the
     *        probability of class {@code j} for {@code rowBatch.get(i)}.
     * @throws HiveException if {@code forward} fails. A size mismatch between {@code predicted}
     *         and {@code rowBatch} is passed to {@code Preconditions.checkArgument} together with
     *         {@code HiveException.class}, so it is presumably reported as a HiveException as well.
     */
    @Override
    protected void forwardPredicted(
            @Nonnull final List<XGBoostPredictUDTF.LabeledPointWithRowId> rowBatch,
            @Nonnull final float[][] predicted) throws HiveException {
        Preconditions.checkArgument(predicted.length == rowBatch.size(), HiveException.class);

        // A single output array is reused for every forward() call.
        final Object[] forwardObj = new Object[3];
        for (int i = 0, size = rowBatch.size(); i < size; i++) {
            final float[] classScores = predicted[i];
            forwardObj[0] = rowBatch.get(i).getRowId();

            // A multiclass model should yield more than one score per row. This is only checked
            // when assertions are enabled (-ea).
            assert classScores.length > 1;

            for (int label = 0; label < classScores.length; label++) {
                forwardObj[1] = String.valueOf(label);
                forwardObj[2] = Float.valueOf(classScores[label]);
                forward(forwardObj);
            }
        }
    }
}
