package hivemall.tools.vector;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

@UDFType(deterministic = true, stateful = false)
@Description(name = "vector_dot", value = "_FUNC_(array<NUMBER> x, array<NUMBER> y) - Performs vector dot product.", extended = "SELECT vector_dot(array(1.0,2.0,3.0),array(2.0,3.0,4.0));\n20\n\nSELECT vector_dot(array(1.0,2.0,3.0),2);\n[2.0,4.0,6.0]")
/**
 * Hive UDF {@code vector_dot(x, y)}.
 *
 * <p>The first argument must be an array of numbers. The second argument selects the operation:
 * <ul>
 * <li>{@code array<number>}: the dot product of {@code x} and {@code y}, returned as a double;</li>
 * <li>{@code number}: each element of {@code x} multiplied by {@code y}, returned as
 * {@code array<double>}.</li>
 * </ul>
 * If either argument is NULL the result is NULL.
 */
public final class VectorDotUDF extends GenericUDF {

    /** Strategy chosen in {@link #initialize} from the type of the second argument. */
    private Evaluator evaluator;

    /**
     * Computes the dot product of two numeric arrays that must have the same length.
     * A position where either element is NULL contributes nothing to the sum.
     */
    static final class Dot2DVectors implements Evaluator {
        private static final long serialVersionUID = -8783159823009951347L;

        private final ListObjectInspector xListOI;
        private final ListObjectInspector yListOI;
        private final PrimitiveObjectInspector xElemOI;
        private final PrimitiveObjectInspector yElemOI;

        /**
         * @param xListOI inspector of the first array argument
         * @param yListOI inspector of the second array argument
         * @throws UDFArgumentTypeException if either element type is not numeric
         */
        Dot2DVectors(@Nonnull ListObjectInspector xListOI, @Nonnull ListObjectInspector yListOI)
                throws UDFArgumentTypeException {
            this.xListOI = xListOI;
            this.yListOI = yListOI;
            this.xElemOI = HiveUtils.asNumberOI(xListOI.getListElementObjectInspector());
            this.yElemOI = HiveUtils.asNumberOI(yListOI.getListElementObjectInspector());
        }

        /**
         * @return the sum of x[i] * y[i] over positions where both elements are non-NULL
         * @throws HiveException if the two arrays have different lengths
         */
        @Override
        public Double dot(@Nonnull Object x, @Nonnull Object y) throws HiveException {
            final int size = xListOI.getListLength(x);
            if (size != yListOI.getListLength(y)) {
                throw new HiveException("vector lengths do not match. x=" + xListOI.getList(x)
                        + ", y=" + yListOI.getList(y));
            }

            double sum = 0.0d;
            for (int i = 0; i < size; i++) {
                Object xi = xListOI.getListElement(x, i);
                Object yi = yListOI.getListElement(y, i);
                // NULL elements are skipped, i.e. treated as contributing 0 to the product.
                if (xi != null && yi != null) {
                    sum += PrimitiveObjectInspectorUtils.getDouble(xi, xElemOI)
                            * PrimitiveObjectInspectorUtils.getDouble(yi, yElemOI);
                }
            }
            return Double.valueOf(sum);
        }
    }

    /** Evaluation strategy; invoked only with non-NULL arguments. */
    interface Evaluator extends Serializable {

        /**
         * @param x value of the first argument (non-NULL)
         * @param y value of the second argument (non-NULL)
         * @return the UDF result for this row
         * @throws HiveException on invalid input
         */
        @Nonnull
        Object dot(@Nonnull Object x, @Nonnull Object y) throws HiveException;
    }

    /**
     * Multiplies every element of a numeric array by a numeric scalar.
     * NULL elements stay NULL in the output array.
     */
    static final class Multiply2D1D implements Evaluator {
        private static final long serialVersionUID = -9090211833041797311L;

        private final ListObjectInspector xListOI;
        private final PrimitiveObjectInspector xElemOI;
        private final PrimitiveObjectInspector yOI;

        /**
         * @param xListOI inspector of the array argument
         * @param yOI inspector of the scalar argument
         * @throws UDFArgumentTypeException if the element type or the scalar type is not numeric
         */
        Multiply2D1D(@Nonnull ListObjectInspector xListOI, @Nonnull ObjectInspector yOI)
                throws UDFArgumentTypeException {
            this.xListOI = xListOI;
            this.xElemOI = HiveUtils.asNumberOI(xListOI.getListElementObjectInspector());
            this.yOI = HiveUtils.asNumberOI(yOI);
        }

        /**
         * @return a list of the same length as {@code x} where entry i is x[i] * y, or NULL when
         *         x[i] is NULL
         */
        @Override
        public List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException {
            final double scalar = PrimitiveObjectInspectorUtils.getDouble(y, yOI);
            final int size = xListOI.getListLength(x);

            final Double[] result = new Double[size];
            for (int i = 0; i < size; i++) {
                Object xi = xListOI.getListElement(x, i);
                if (xi != null) {
                    result[i] = Double.valueOf(PrimitiveObjectInspectorUtils.getDouble(xi, xElemOI) * scalar);
                }
            }
            return Arrays.asList(result);
        }
    }

    /**
     * Validates the two arguments and picks the evaluator.
     *
     * @return a double inspector for array-by-array calls, or a standard list-of-double
     *         inspector for array-by-scalar calls
     * @throws UDFArgumentLengthException if the argument count is not 2
     * @throws UDFArgumentException if the argument types are not supported
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
        }

        ObjectInspector argOI0 = argOIs[0];
        if (!HiveUtils.isNumberListOI(argOI0)) {
            throw new UDFArgumentException("Expected array<number> for the first argument: " + argOI0.getTypeName());
        }
        ListObjectInspector xListOI = HiveUtils.asListOI(argOI0);

        ObjectInspector argOI1 = argOIs[1];
        if (HiveUtils.isNumberListOI(argOI1)) {
            this.evaluator = new Dot2DVectors(xListOI, HiveUtils.asListOI(argOI1));
            return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        }
        if (!HiveUtils.isNumberOI(argOI1)) {
            throw new UDFArgumentException("Expected array<number> or number for the second argument: " + argOI1.getTypeName());
        }
        this.evaluator = new Multiply2D1D(xListOI, argOI1);
        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
    }

    /**
     * @return NULL if either argument is NULL; otherwise the selected evaluator's result
     */
    @Override
    public Object evaluate(GenericUDF.DeferredObject[] args) throws HiveException {
        Object x = args[0].get();
        Object y = args[1].get();
        if (x == null || y == null) {
            return null;
        }
        return evaluator.dot(x, y);
    }

    @Override
    public String getDisplayString(String[] children) {
        return "vector_dot(" + StringUtils.join((Object[]) children, ',') + ")";
    }
}
