package hivemall.fm;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

/**
 * Hive UDAF resolver for {@code fm_predict}.
 *
 * <p>It computes a Factorization Machine (FM) prediction for one example by aggregating that
 * example's per-feature rows. Each input row carries a linear weight {@code Wj}, a factor vector
 * {@code Vjf} and a feature value {@code Xj}. The aggregation accumulates:
 *
 * <pre>
 *   ret        += Wj * Xj          (or += Wj when Vjf is null or empty)
 *   sumVjXj[f] += Vjf[f] * Xj
 *   sumV2X2[f] += (Vjf[f] * Xj)^2
 * </pre>
 *
 * The result is {@code ret + 0.5 * sum_f (sumVjXj[f]^2 - sumV2X2[f])}.
 *
 * <p>A row without factors contributes {@code Wj} unscaled by {@code Xj}. This is presumably how
 * the global bias term is fed in; confirm with callers.
 */
@Description(name = "fm_predict", value = "_FUNC_(Float Wj, array<float> Vjf, float Xj) - Returns a prediction value in Double")
public final class FMPredictGenericUDAF extends AbstractGenericUDAFResolver {

    private FMPredictGenericUDAF() {
    }

    /**
     * Validates the argument types and returns a new evaluator.
     *
     * @param typeInfo argument types; expected to be (number Wj, array&lt;number&gt; Vjf, number Xj)
     * @return a fresh {@link Evaluator}
     * @throws UDFArgumentLengthException if there are not exactly three arguments
     * @throws UDFArgumentTypeException if any argument has an unexpected type
     */
    @Override
    public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException {
        if (typeInfo.length != 3) {
            throw new UDFArgumentLengthException(
                "Expected argument length is 3 but given argument length was " + typeInfo.length);
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[0])) {
            throw new UDFArgumentTypeException(0,
                "Number type is expected for the first argument Wj: " + typeInfo[0].getTypeName());
        }
        if (typeInfo[1].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(1,
                "List type is expected for the second argument Vjf: " + typeInfo[1].getTypeName());
        }
        final ListTypeInfo vTypeInfo = (ListTypeInfo) typeInfo[1];
        if (!HiveUtils.isNumberTypeInfo(vTypeInfo.getListElementTypeInfo())) {
            throw new UDFArgumentTypeException(1,
                "Number type is expected for the element type of list Vjf: " + vTypeInfo.getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfo[2])) {
            throw new UDFArgumentTypeException(2,
                "Number type is expected for the third argument Xj: " + typeInfo[2].getTypeName());
        }
        return new Evaluator();
    }

    /** Evaluator that feeds rows or partial results into an {@link FMPredictAggregationBuffer}. */
    public static class Evaluator extends GenericUDAFEvaluator {

        // Inspectors for raw input rows (PARTIAL1 / COMPLETE).
        private PrimitiveObjectInspector wOI;
        private ListObjectInspector vOI;
        private PrimitiveObjectInspector vElemOI;
        private PrimitiveObjectInspector xOI;

        // Inspectors for the partial-aggregation struct (PARTIAL2 / FINAL).
        private StructObjectInspector internalMergeOI;
        private StructField retField;
        private StructField sumVjXjField;
        private StructField sumV2X2Field;
        private WritableDoubleObjectInspector retOI;
        private StandardListObjectInspector sumVjXjOI;
        private StandardListObjectInspector sumV2X2OI;

        /**
         * Binds the object inspectors for the given mode.
         *
         * @param mode aggregation mode chosen by Hive
         * @param parameters the three raw-argument inspectors in PARTIAL1/COMPLETE; otherwise
         *        {@code parameters[0]} is the partial-aggregation struct inspector
         * @return the partial struct inspector in PARTIAL1/PARTIAL2, otherwise a writable double
         *         inspector for the final prediction
         */
        @Override
        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
            super.init(mode, parameters);

            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
                // Only raw-input modes receive (Wj, Vjf, Xj). PARTIAL2/FINAL receive the
                // partial struct, so the arity check must not run there.
                assert (parameters.length == 3) : "Expected 3 arguments but got " + parameters.length;
                this.wOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
                this.vOI = HiveUtils.asListOI(parameters[1]);
                this.vElemOI = HiveUtils.asDoubleCompatibleOI(vOI.getListElementObjectInspector());
                this.xOI = HiveUtils.asDoubleCompatibleOI(parameters[2]);
            } else {
                final StructObjectInspector soi = (StructObjectInspector) parameters[0];
                this.internalMergeOI = soi;
                this.retField = soi.getStructFieldRef("ret");
                this.sumVjXjField = soi.getStructFieldRef("sumVjXj");
                this.sumV2X2Field = soi.getStructFieldRef("sumV2X2");
                this.retOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
                this.sumVjXjOI = ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
                this.sumV2X2OI = ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            }

            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
                return internalMergeOI();
            }
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /**
         * Builds the inspector for the partial result. It is a struct of
         * {@code ret: double, sumVjXj: array<double>, sumV2X2: array<double>}, matching the
         * {@code Object[]} layout emitted by {@link #terminatePartial}.
         */
        private static StructObjectInspector internalMergeOI() {
            final List<String> fieldNames = new ArrayList<String>();
            final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

            fieldNames.add("ret");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("sumVjXj");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
            fieldNames.add("sumV2X2");
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));

            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        /** Returns a newly allocated, reset aggregation buffer. */
        @Override
        public FMPredictAggregationBuffer getNewAggregationBuffer() throws HiveException {
            final FMPredictAggregationBuffer buf = new FMPredictAggregationBuffer();
            buf.reset();
            return buf;
        }

        /** Clears the buffer so it can be reused for another group. */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((FMPredictAggregationBuffer) agg).reset();
        }

        /**
         * Accumulates one (Wj, Vjf, Xj) row.
         *
         * <p>Rows with a null Wj are skipped. If Vjf is null or empty, only Wj is added (unscaled).
         * Otherwise Xj must be non-null, and Wj*Xj plus the factor sums are accumulated.
         *
         * @throws UDFArgumentException if Vjf is non-empty but Xj is null
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters[0] == null) {
                return;
            }
            final FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg;

            final double w = PrimitiveObjectInspectorUtils.getDouble(parameters[0], wOI);
            if (parameters[1] == null || vOI.getListLength(parameters[1]) == 0) {
                buf.iterate(w);
                return;
            }
            if (parameters[2] == null) {
                throw new UDFArgumentException("The third argument Xj must not be null");
            }
            final double x = PrimitiveObjectInspectorUtils.getDouble(parameters[2], xOI);
            buf.iterate(w, x, parameters[1], vOI, vElemOI);
        }

        /**
         * Emits the partial state as {@code [ret, sumVjXj, sumV2X2]}. Both lists are null when no
         * factor vector has been seen yet.
         */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            final FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg;

            final Object[] partial = new Object[3];
            partial[0] = new DoubleWritable(buf.ret);
            if (buf.sumVjXj != null) {
                partial[1] = WritableUtils.toWritableList(buf.sumVjXj);
                partial[2] = WritableUtils.toWritableList(buf.sumV2X2);
            }
            return partial;
        }

        /** Folds a partial result produced by {@link #terminatePartial} into the buffer. */
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            final FMPredictAggregationBuffer buf = (FMPredictAggregationBuffer) agg;

            final double ret = retOI.get(internalMergeOI.getStructFieldData(partial, retField));
            Object sumVjXj = internalMergeOI.getStructFieldData(partial, sumVjXjField);
            Object sumV2X2 = internalMergeOI.getStructFieldData(partial, sumV2X2Field);
            // Lazy-binary partials arrive as LazyBinaryArray. Unwrap them to java.util.List so the
            // StandardListObjectInspectors used below can read them.
            if (sumVjXj instanceof LazyBinaryArray) {
                sumVjXj = ((LazyBinaryArray) sumVjXj).getList();
            }
            if (sumV2X2 instanceof LazyBinaryArray) {
                sumV2X2 = ((LazyBinaryArray) sumV2X2).getList();
            }
            buf.merge(ret, sumVjXj, sumV2X2, sumVjXjOI, sumV2X2OI);
        }

        /** Returns the final FM prediction for the group. */
        @Override
        public DoubleWritable terminate(AggregationBuffer agg) throws HiveException {
            return new DoubleWritable(((FMPredictAggregationBuffer) agg).getPrediction());
        }
    }

    /**
     * Running state of one FM prediction.
     *
     * <p>{@code ret} holds the linear part. {@code sumVjXj} and {@code sumV2X2} hold per-factor
     * sums. Both arrays are null until the first factor vector is seen, and they always share the
     * same length.
     */
    @GenericUDAFEvaluator.AggregationType(estimable = true)
    public static class FMPredictAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        private double ret;
        private double[] sumVjXj;
        private double[] sumV2X2;

        FMPredictAggregationBuffer() {
        }

        /** Resets to the empty state: zero linear sum and no factor sums. */
        void reset() {
            this.ret = 0.d;
            this.sumVjXj = null;
            this.sumV2X2 = null;
        }

        /** Adds a factor-less contribution {@code w} directly to the linear sum. */
        void iterate(double w) {
            this.ret += w;
        }

        /**
         * Adds one feature row: {@code w * x} goes to the linear sum, and {@code v[f] * x} and
         * its square go to the per-factor sums.
         *
         * @throws HiveException if {@code v} is empty, has a null element, or its length differs
         *         from previously seen vectors
         */
        void iterate(double w, double x, @Nonnull Object v, @Nonnull ListObjectInspector vOI,
                @Nonnull PrimitiveObjectInspector vElemOI) throws HiveException {
            this.ret += w * x;

            final int factors = vOI.getListLength(v);
            if (factors < 1) {
                throw new HiveException("# of Factor should be more than 0: " + factors);
            }
            ensureFactorArrays(factors);

            for (int f = 0; f < factors; f++) {
                final Object vf = vOI.getListElement(v, f);
                if (vf == null) {
                    throw new HiveException("Vj" + f + " should not be null");
                }
                final double vx = PrimitiveObjectInspectorUtils.getDouble(vf, vElemOI) * x;
                sumVjXj[f] += vx;
                sumV2X2[f] += vx * vx;
            }
        }

        /**
         * Merges a partial result.
         *
         * @param ret partial linear sum
         * @param partialSumVjXj partial per-factor sums, or null if the partial had no factors
         * @param partialSumV2X2 partial per-factor squared sums; must be non-null, with the same
         *        length, whenever {@code partialSumVjXj} is non-null
         * @throws HiveException on null, mismatched-length or null-element partial lists
         */
        void merge(double ret, @Nullable Object partialSumVjXj, @Nullable Object partialSumV2X2,
                @Nonnull StandardListObjectInspector sumVjXjOI,
                @Nonnull StandardListObjectInspector sumV2X2OI) throws HiveException {
            this.ret += ret;
            if (partialSumVjXj == null) {
                return;
            }
            if (partialSumV2X2 == null) {
                throw new HiveException("o_sumV2X2 should not be null");
            }

            final int factors = sumVjXjOI.getListLength(partialSumVjXj);
            final int factors2 = sumV2X2OI.getListLength(partialSumV2X2);
            if (factors != factors2) {
                throw new HiveException("Mismatch in the number of factors between sumVjXj ("
                        + factors + ") and sumV2X2 (" + factors2 + ")");
            }
            ensureFactorArrays(factors);

            final WritableDoubleObjectInspector doubleOI =
                    PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
            for (int f = 0; f < factors; f++) {
                final Object o1 = sumVjXjOI.getListElement(partialSumVjXj, f);
                final Object o2 = sumV2X2OI.getListElement(partialSumV2X2, f);
                if (o1 == null || o2 == null) {
                    throw new HiveException("Partial factor sum at index " + f + " should not be null");
                }
                sumVjXj[f] += doubleOI.get(o1);
                sumV2X2[f] += doubleOI.get(o2);
            }
        }

        /**
         * Allocates the factor arrays on first use. Otherwise it checks that {@code factors}
         * matches their existing length.
         */
        private void ensureFactorArrays(int factors) throws HiveException {
            if (sumVjXj == null) {
                this.sumVjXj = new double[factors];
                this.sumV2X2 = new double[factors];
            } else if (sumVjXj.length != factors) {
                throw new HiveException("Mismatch in the number of factors");
            }
        }

        /** Returns {@code ret + 0.5 * sum_f (sumVjXj[f]^2 - sumV2X2[f])}. */
        double getPrediction() {
            double prediction = ret;
            if (sumVjXj != null) {
                for (int f = 0; f < sumVjXj.length; f++) {
                    final double s = sumVjXj[f];
                    prediction += 0.5d * ((s * s) - sumV2X2[f]);
                }
            }
            return prediction;
        }

        /**
         * Rough in-memory size in bytes, used by Hive's estimable aggregation.
         *
         * <p>The constants approximate the fixed fields plus two {@code double[]} arrays (per-array
         * overhead plus 8 bytes per element). They are estimates, not measured sizes.
         */
        @Override
        public int estimate() {
            if (sumVjXj == null) {
                return 24;
            }
            return 8 + (2 * (32 + (8 * sumVjXj.length)));
        }
    }
}
