package hivemall.evaluation;

import hivemall.utils.hadoop.HiveUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

/**
 * Hive UDAF {@code ndcg} that averages per-row nDCG values over all aggregated rows.
 *
 * <p>Every row {@code ndcg(rankItems, correctItems [, recommendSize])} yields one nDCG value; the
 * aggregate returns the mean of those values, or 0.0 when no row contributed. Rows whose
 * {@code correctItems} is NULL are skipped; a NULL {@code rankItems} is treated as empty.
 *
 * <p>If the elements of {@code rankItems} are primitives, the per-row value is computed by
 * {@code BinaryResponsesMeasures.nDCG}. If they are structs, field 0 of every struct and every
 * element of {@code correctItems} are read as doubles and passed to
 * {@code GradedResponsesMeasures.nDCG}.
 */
@Description(name = "ndcg", value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size]) - Returns nDCG")
public final class NDCGUDAF extends AbstractGenericUDAFResolver {

    /**
     * Evaluator accumulating a running (sum, count) of per-row nDCG values.
     *
     * <p>The partial aggregation is a struct {@code <sum: double, count: bigint>}; the final
     * result is {@code sum / count} as a double.
     */
    public static class Evaluator extends GenericUDAFEvaluator {

        // Inputs used in PARTIAL1 / COMPLETE (raw rows).
        private ListObjectInspector recommendListOI;
        private ListObjectInspector truthListOI;
        private PrimitiveObjectInspector recommendSizeOI;

        // Only non-null when rankItems elements are structs (graded responses); resolved once in
        // init() instead of on every row.
        private StructObjectInspector recommendElementOI;
        private StructField recommendScoreField;
        private PrimitiveObjectInspector recommendScoreOI;
        private PrimitiveObjectInspector truthScoreOI;

        // Inputs used in PARTIAL2 / FINAL (partial aggregations).
        private StructObjectInspector internalMergeOI;
        private StructField countField;
        private StructField sumField;
        private PrimitiveObjectInspector sumOI;
        private PrimitiveObjectInspector countOI;

        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            assert parameters.length >= 1 && parameters.length <= 3 : parameters.length;
            super.init(mode, parameters);

            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
                this.recommendListOI = (ListObjectInspector) parameters[0];
                this.truthListOI = (ListObjectInspector) parameters[1];
                if (parameters.length == 3) {
                    this.recommendSizeOI = HiveUtils.asIntegerOI(parameters[2]);
                }

                ObjectInspector recommendElemOI = recommendListOI.getListElementObjectInspector();
                if (HiveUtils.isStructOI(recommendElemOI)) {
                    StructObjectInspector structOI = (StructObjectInspector) recommendElemOI;
                    this.recommendElementOI = structOI;
                    this.recommendScoreField = structOI.getAllStructFieldRefs().get(0);
                    this.recommendScoreOI =
                            HiveUtils.asDoubleCompatibleOI(recommendScoreField.getFieldObjectInspector());
                    this.truthScoreOI =
                            HiveUtils.asDoubleCompatibleOI(truthListOI.getListElementObjectInspector());
                } else {
                    this.recommendElementOI = null;
                    this.recommendScoreField = null;
                    this.recommendScoreOI = null;
                    this.truthScoreOI = null;
                }
            } else {
                StructObjectInspector mergeOI = (StructObjectInspector) parameters[0];
                this.internalMergeOI = mergeOI;
                this.countField = mergeOI.getStructFieldRef("count");
                this.sumField = mergeOI.getStructFieldRef("sum");
                // Read partial values through the OIs Hive actually supplied rather than assuming
                // a particular writable representation.
                this.sumOI = (PrimitiveObjectInspector) sumField.getFieldObjectInspector();
                this.countOI = (PrimitiveObjectInspector) countField.getFieldObjectInspector();
            }

            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
                return createInternalMergeOI();
            }
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /**
         * Builds the object inspector of the partial aggregation. The field order (sum, count)
         * must match the array produced by {@link #terminatePartial}.
         */
        @Nonnull
        private static StructObjectInspector createInternalMergeOI() {
            List<String> fieldNames = new ArrayList<String>(2);
            List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
            fieldNames.add("sum");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("count");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @Override
        public NDCGAggregationBuffer getNewAggregationBuffer() throws HiveException {
            NDCGAggregationBuffer buffer = new NDCGAggregationBuffer();
            reset(buffer);
            return buffer;
        }

        @SuppressWarnings("deprecation")
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((NDCGAggregationBuffer) agg).reset();
        }

        /**
         * Computes the nDCG of one row and adds it to the buffer.
         *
         * @throws UDFArgumentException if {@code recommendSize} is negative, or (struct form) a
         *         rankItems element / its field 0 / a ground-truth element is NULL
         */
        @SuppressWarnings("deprecation")
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            NDCGAggregationBuffer buffer = (NDCGAggregationBuffer) agg;

            List<?> recommendList = recommendListOI.getList(parameters[0]);
            if (recommendList == null) {
                recommendList = Collections.emptyList();
            }
            List<?> truthList = truthListOI.getList(parameters[1]);
            if (truthList == null) {
                return; // rows without ground truth do not contribute to the average
            }

            int recommendSize = recommendList.size();
            if (parameters.length == 3) {
                recommendSize = PrimitiveObjectInspectorUtils.getInt(parameters[2], recommendSizeOI);
                if (recommendSize < 0) {
                    throw new UDFArgumentException("The third argument `int recommendSize` must be in greater than or equals to 0: " + recommendSize);
                }
            }

            final double ndcg;
            if (recommendElementOI == null) {
                ndcg = BinaryResponsesMeasures.nDCG(recommendList, truthList, recommendSize);
            } else {
                // Arguments are evaluated left to right, so rankItems is validated first.
                ndcg = GradedResponsesMeasures.nDCG(toRecommendScores(recommendList),
                    toTruthScores(truthList), recommendSize);
            }
            buffer.iterate(ndcg);
        }

        /** Reads field 0 of every struct element of {@code recommendList} as a double. */
        @Nonnull
        private List<Double> toRecommendScores(@Nonnull List<?> recommendList)
                throws UDFArgumentException {
            List<Double> scores = new ArrayList<Double>(recommendList.size());
            for (Object element : recommendList) {
                if (element == null) {
                    throw new UDFArgumentException("Found null in the recommended items: " + recommendList);
                }
                Object score = recommendElementOI.getStructFieldData(element, recommendScoreField);
                if (score == null) {
                    throw new UDFArgumentException("Field 0 of a struct field is null: "
                            + recommendElementOI.getStructFieldsDataAsList(element));
                }
                scores.add(PrimitiveObjectInspectorUtils.getDouble(score, recommendScoreOI));
            }
            return scores;
        }

        /** Reads every element of {@code truthList} as a double. */
        @Nonnull
        private List<Double> toTruthScores(@Nonnull List<?> truthList) throws UDFArgumentException {
            List<Double> scores = new ArrayList<Double>(truthList.size());
            for (Object truth : truthList) {
                if (truth == null) {
                    throw new UDFArgumentException("Found null in the ground truth: " + truthList);
                }
                scores.add(PrimitiveObjectInspectorUtils.getDouble(truth, truthScoreOI));
            }
            return scores;
        }

        /** Emits {@code [sum, count]}, matching the field order of the merge struct. */
        @SuppressWarnings("deprecation")
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            NDCGAggregationBuffer buffer = (NDCGAggregationBuffer) agg;
            return new Object[] {new DoubleWritable(buffer.sum), new LongWritable(buffer.count)};
        }

        @SuppressWarnings("deprecation")
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            Object sumData = internalMergeOI.getStructFieldData(partial, sumField);
            Object countData = internalMergeOI.getStructFieldData(partial, countField);
            double partialSum = PrimitiveObjectInspectorUtils.getDouble(sumData, sumOI);
            long partialCount = PrimitiveObjectInspectorUtils.getLong(countData, countOI);
            ((NDCGAggregationBuffer) agg).merge(partialSum, partialCount);
        }

        /** Returns the mean nDCG, or 0.0 when nothing was aggregated. */
        @SuppressWarnings("deprecation")
        @Override
        public DoubleWritable terminate(AggregationBuffer agg) throws HiveException {
            return new DoubleWritable(((NDCGAggregationBuffer) agg).get());
        }
    }

    /** Running (sum, count) of per-row nDCG values. */
    public static class NDCGAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        double sum;
        long count;

        void reset() {
            this.sum = 0.d;
            this.count = 0L;
        }

        void merge(double partialSum, long partialCount) {
            this.sum += partialSum;
            this.count += partialCount;
        }

        /** Mean of the accumulated values; 0.0 when the buffer is empty. */
        double get() {
            return count == 0L ? 0.d : sum / count;
        }

        void iterate(double ndcg) {
            this.sum += ndcg;
            this.count++;
        }
    }

    /**
     * Validates argument types and returns an {@link Evaluator}.
     *
     * @param typeInfo argument types: {@code array<primitive|struct>}, {@code array<primitive>},
     *        and optionally an int recommend size
     * @throws UDFArgumentTypeException if the argument count or list element types are invalid
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 2 && typeInfo.length != 3) {
            throw new UDFArgumentTypeException(typeInfo.length - 1, "_FUNC_ takes two or three arguments");
        }

        ListTypeInfo rankItemsType = HiveUtils.asListTypeInfo(typeInfo[0]);
        TypeInfo rankElementType = rankItemsType.getListElementTypeInfo();
        if (!HiveUtils.isPrimitiveTypeInfo(rankElementType)
                && !HiveUtils.isStructTypeInfo(rankElementType)) {
            throw new UDFArgumentTypeException(0, "The first argument `array rankItems` is invalid form: " + typeInfo[0]);
        }

        ListTypeInfo correctItemsType = HiveUtils.asListTypeInfo(typeInfo[1]);
        if (!HiveUtils.isPrimitiveTypeInfo(correctItemsType.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(1, "The second argument `array correctItems` is invalid form: " + typeInfo[1]);
        }

        return new Evaluator();
    }
}
