package hivemall.xgboost.tools;

import hivemall.utils.lang.Preconditions;
import hivemall.xgboost.XGBoostPredictUDTF;

import java.util.ArrayList;
import java.util.List;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Hive UDTF {@code xgboost_predict}: for every input row it forwards one output row
 * {@code (string rowid, float predicted)}, where {@code predicted} is the single score the
 * underlying XGBoost model produced for that row.
 */
@Description(name = "xgboost_predict",
        value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options])"
                + " - Returns a prediction result as (string rowid, float predicted)")
public final class XGBoostPredictUDTF extends hivemall.xgboost.XGBoostPredictUDTF {

    /**
     * Builds the output row schema {@code struct<rowid:string, predicted:float>}.
     *
     * @return a standard struct object inspector with the two output fields
     */
    @Override
    protected StructObjectInspector getReturnOI() {
        final List<String> fieldNames = new ArrayList<String>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("rowid");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("predicted");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Forwards one {@code (rowid, predicted)} row per input row.
     *
     * @param testData the rows that were scored; {@code testData.get(i)} supplies the rowid
     * @param predicted model output, one score array per row, aligned index-by-index with
     *        {@code testData}; only the first element of each array is emitted
     * @throws HiveException if {@code predicted} and {@code testData} differ in length, or if
     *         {@code forward} fails
     */
    @Override
    protected void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> testData,
            @Nonnull final float[][] predicted) throws HiveException {
        final int size = testData.size();
        if (predicted.length != size) {
            throw new HiveException("Number of predictions (" + predicted.length
                    + ") does not match the number of input rows (" + size + ")");
        }

        // The same array is handed to every forward() call to avoid a per-row allocation.
        final Object[] forwardObj = new Object[2];
        for (int i = 0; i < size; i++) {
            final float[] scores = predicted[i];
            // This UDTF emits a single float per row; checked only when assertions are enabled.
            assert scores.length == 1 : "Expected exactly one prediction per row, got "
                    + scores.length;
            forwardObj[0] = testData.get(i).getRowId();
            forwardObj[1] = Float.valueOf(scores[0]);
            forward(forwardObj);
        }
    }
}
