package hivemall.ftvec.ranking;

import hivemall.UDTFWithOptions;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.BitUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;

/**
 * Table-generating function registered as {@code bpr_sampling}.
 *
 * <p>Each input row carries a user id and a list of item ids that the user gave positive
 * feedback to. Rows are buffered in a {@link PositiveOnlyFeedback} while {@link #process} runs;
 * the actual sampling happens in {@link #close}, which forwards rows of
 * {@code (user, pos_item, neg_item)} where {@code neg_item} is an item id that does not occur in
 * the user's positive item list.
 *
 * <p>Options (third, constant string argument):
 * <ul>
 * <li>{@code -sampling / --sampling_rate <float>}: number of samples as a ratio of the total
 * number of feedbacks (default 1.0)</li>
 * <li>{@code -without_replacement}: sample without replacement (default off)</li>
 * <li>{@code -uniform_pair_sampling / --pair_sampling}: sample feedback events uniformly instead
 * of sampling users uniformly (default off)</li>
 * <li>{@code -maxcol / --max_itemid <int>}: lower bound for the largest item id (default -1,
 * i.e. derived from the data)</li>
 * </ul>
 *
 * <p>Item ids are treated as the contiguous range {@code [0, maxItemId]}.
 */
@Description(name = "bpr_sampling", value = "_FUNC_(int userId, List<int> posItems [, const string options])- Returns a relation consists of <int user, int pos_item, int neg_item>")
public final class BprSamplingUDTF extends UDTFWithOptions {
    /** Fixed seed used by every sampler. */
    private static final long SEED = 31L;

    private PrimitiveObjectInspector userOI;
    private ListObjectInspector itemListOI;
    private PrimitiveObjectInspector itemElemOI;

    /** Buffered feedback; created lazily by the first {@link #process} call. */
    @Nullable
    private transient PositiveOnlyFeedback feedback;
    private int maxItemId;
    private float samplingRate;
    private boolean withoutReplacement;
    private boolean pairSampling;
    /** Reused output row: {userId, posItemId, negItemId}. */
    private Object[] forwardObjs;
    private IntWritable userId;
    private IntWritable posItemId;
    private IntWritable negItemId;

    /**
     * Declares the command line options understood by this UDTF.
     *
     * @return the option set handed to the option parser
     */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("sampling", "sampling_rate", true, "Sampling rates of positive items [default: 1.0]");
        opts.addOption("without_replacement", false, "Do sampling without-replacement sampling [default: false]");
        opts.addOption("uniform_pair_sampling", "pair_sampling", false, "Sampling pairs uniform from feedbacks [default: false]");
        opts.addOption("maxcol", "max_itemid", true, "Max item id index [default: -1]");
        return opts;
    }

    /**
     * Parses the optional third argument and stores the resulting settings in this instance.
     * When no option string is given, the defaults are used.
     *
     * @param objectInspectorArr argument inspectors passed to {@link #initialize}
     * @return the parsed command line, or {@code null} when no option string was supplied
     * @throws UDFArgumentException if a sampling rate above 1 is combined with
     *         without-replacement sampling
     */
    @Override // hivemall.UDTFWithOptions
    protected CommandLine processOptions(@Nonnull ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        CommandLine cl = null;
        int maxItemIdOpt = -1;
        float rate = 1.0f;
        boolean noReplacement = false;
        boolean perPair = false;

        if (objectInspectorArr.length >= 3) {
            String rawOptions = HiveUtils.getConstString(objectInspectorArr[2]);
            cl = parseOptions(rawOptions);
            maxItemIdOpt = Primitives.parseInt(cl.getOptionValue("max_itemid"), -1);
            noReplacement = cl.hasOption("without_replacement");
            perPair = cl.hasOption("uniform_pair_sampling");
            rate = Primitives.parseFloat(cl.getOptionValue("sampling_rate"), 1.0f);
            // Without replacement there cannot be more samples than feedbacks.
            if (noReplacement && rate > 1.0f) {
                throw new UDFArgumentException("sampling_rate MUST be in less than or equals to 1 where without-replacement is true: " + rate);
            }
        }

        this.maxItemId = maxItemIdOpt;
        this.samplingRate = rate;
        this.withoutReplacement = noReplacement;
        this.pairSampling = perPair;
        return cl;
    }

    /**
     * Validates the argument types, parses the options and declares the output schema
     * {@code (user int, pos_item int, neg_item int)}.
     *
     * @param objectInspectorArr inspectors for (userId, posItems [, options])
     * @return struct inspector describing the forwarded rows
     * @throws UDFArgumentException if the number of arguments is neither two nor three, or an
     *         argument or option is rejected by the helpers called here
     */
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 2 && numArgs != 3) {
            throw new UDFArgumentException("_FUNC_(int userid, array<int> itemid, [, const string options]) takes at least two arguments");
        }
        this.userOI = HiveUtils.asIntegerOI(objectInspectorArr[0]);
        this.itemListOI = HiveUtils.asListOI(objectInspectorArr[1]);
        this.itemElemOI = HiveUtils.asIntegerOI(this.itemListOI.getListElementObjectInspector());
        processOptions(objectInspectorArr);

        this.userId = new IntWritable();
        this.posItemId = new IntWritable();
        this.negItemId = new IntWritable();
        this.forwardObjs = new Object[]{this.userId, this.posItemId, this.negItemId};

        ArrayList<String> fieldNames = new ArrayList<String>(3);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);
        fieldNames.add("user");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pos_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("neg_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one input row; nothing is forwarded until {@link #close}.
     *
     * @param objArr row values: {@code objArr[0]} is the user id, {@code objArr[1]} the list of
     *        positive item ids
     * @throws HiveException if the user id or an item id is negative
     */
    public void process(@Nonnull Object[] objArr) throws HiveException {
        if (this.feedback == null) {
            // The feedback container type depends on the sampling mode chosen via options.
            this.feedback = this.pairSampling ? new PerEventPositiveOnlyFeedback(this.maxItemId) : new PositiveOnlyFeedback(this.maxItemId);
        }
        int user = PrimitiveObjectInspectorUtils.getInt(objArr[0], this.userOI);
        validateIndex(user);
        addFeedback(user, objArr[1]);
    }

    /**
     * Copies the non-null item ids of one row into the feedback buffer and raises the buffer's
     * max item id if a larger id was seen.
     *
     * <p>A NULL list, an empty list, and a list that only contains NULL elements are all ignored,
     * so that no user is ever registered with an empty item list.
     *
     * @param user user id of the row (already validated)
     * @param itemList list object readable through {@code itemListOI}
     * @throws UDFArgumentException if an item id is negative
     */
    private void addFeedback(int user, @Nonnull Object itemList) throws UDFArgumentException {
        final int listLength = this.itemListOI.getListLength(itemList);
        // Hive list inspectors report -1 for a NULL list, so "<= 0" covers NULL and empty.
        if (listLength <= 0) {
            return;
        }
        int maxSeen = this.feedback.getMaxItemId();
        final IntArrayList items = new IntArrayList(listLength);
        for (int pos = 0; pos < listLength; pos++) {
            Object element = this.itemListOI.getListElement(itemList, pos);
            if (element == null) {
                continue;
            }
            int item = PrimitiveObjectInspectorUtils.getInt(element, this.itemElemOI);
            validateIndex(item);
            maxSeen = Math.max(item, maxSeen);
            items.add(item);
        }
        if (items.isEmpty()) {
            // Every element was NULL; an empty item list would break the samplers later on.
            return;
        }
        this.feedback.addFeedback(user, items);
        this.feedback.setMaxItemId(maxSeen);
    }

    /**
     * Runs the sampler selected by the options over the buffered feedback and forwards the
     * resulting rows. Does nothing when no row was processed or no feedback was buffered.
     *
     * @throws HiveException if sampling or forwarding fails
     */
    public void close() throws HiveException {
        final PositiveOnlyFeedback collected = this.feedback;
        if (collected == null) {
            // process() was never called (empty input), so there is nothing to sample from.
            return;
        }
        final int totalFeedbacks = collected.getTotalFeedbacks();
        if (totalFeedbacks == 0) {
            return;
        }
        final int numSamples = (int) (totalFeedbacks * this.samplingRate);

        if (this.pairSampling) {
            // In pair-sampling mode process() always creates the per-event container.
            PerEventPositiveOnlyFeedback perEvent = (PerEventPositiveOnlyFeedback) collected;
            if (this.withoutReplacement) {
                uniformPairSamplingWithoutReplacement(perEvent, numSamples);
            } else {
                uniformPairSamplingWithReplacement(perEvent, numSamples);
            }
        } else if (this.withoutReplacement) {
            uniformUserSamplingWithoutReplacement(collected, numSamples);
        } else {
            uniformUserSamplingWithReplacement(collected, numSamples);
        }
    }

    /**
     * Forwards one {@code (user, pos_item, neg_item)} row, reusing the writable holders.
     *
     * @param user user id
     * @param posItem sampled positive item id
     * @param negItem sampled negative item id
     * @throws HiveException if forwarding fails
     */
    private void forward(int user, int posItem, int negItem) throws HiveException {
        assert user >= 0 : user;
        assert posItem >= 0 : posItem;
        assert negItem >= 0 : negItem;
        this.userId.set(user);
        this.posItemId.set(posItem);
        this.negItemId.set(negItem);
        forward(this.forwardObjs);
    }

    /**
     * Tells whether at least one id in {@code [0, numItems)} is absent from {@code posItems}.
     *
     * @param posItems positive item ids of one user (may contain duplicates)
     * @param numItems size of the item id range
     * @return {@code true} if a negative item can be drawn for this user
     */
    private static boolean hasNegativeCandidate(@Nonnull IntArrayList posItems, int numItems) {
        if (posItems.size() < numItems) {
            // Fewer entries than ids: some id must be missing, no scan needed.
            return true;
        }
        // Only reached for (nearly) saturated users; duplicates make the size alone inconclusive.
        for (int item = 0; item < numItems; item++) {
            if (!posItems.contains(item)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Draws an item id uniformly from {@code [0, numItems)} that is not in {@code posItems}.
     * Callers must have checked {@link #hasNegativeCandidate} first, otherwise this never returns.
     *
     * @param rand random source
     * @param posItems positive item ids of one user
     * @param numItems size of the item id range (max item id + 1)
     * @return an item id not contained in {@code posItems}
     */
    private static int drawNegativeItem(@Nonnull Random rand, @Nonnull IntArrayList posItems, int numItems) {
        int candidate;
        do {
            // The bound is numItems (not the max id) so that the largest id can be drawn too.
            candidate = rand.nextInt(numItems);
        } while (posItems.contains(candidate));
        return candidate;
    }

    /**
     * Emits {@code numSamples} rows by repeatedly picking a user uniformly at random, then one of
     * the user's positive items, then a negative item. Users for whom no negative item exists are
     * retried with another user; if every user turns out to be such a user, sampling stops.
     *
     * @param positiveOnlyFeedback buffered feedback
     * @param numSamples number of rows to emit
     * @throws HiveException if the max item id is not positive or forwarding fails
     */
    private void uniformUserSamplingWithReplacement(@Nonnull PositiveOnlyFeedback positiveOnlyFeedback, int numSamples) throws HiveException {
        final int numUsers = positiveOnlyFeedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        final int maxId = positiveOnlyFeedback.getMaxItemId();
        if (maxId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxId);
        }
        final int numItems = maxId + 1;
        final int[] users = positiveOnlyFeedback.getUsers();
        assert users.length == numUsers;

        final Random rand = new Random(SEED);
        // Remembers which user slots were found to have no negative candidate, so the loop can
        // stop instead of spinning forever once that is true for all of them.
        final BitSet exhausted = new BitSet(numUsers);
        int numExhausted = 0;
        int emitted = 0;
        while (emitted < numSamples) {
            int slot = rand.nextInt(numUsers);
            int user = users[slot];
            IntArrayList posItems = positiveOnlyFeedback.getItems(user, true);
            assert posItems != null : user;
            int size = posItems.size();
            assert size > 0 : size;

            if (!hasNegativeCandidate(posItems, numItems)) {
                if (!exhausted.get(slot)) {
                    exhausted.set(slot);
                    numExhausted++;
                    if (numExhausted == numUsers) {
                        return;
                    }
                }
                continue; // try another user; this attempt does not count as a sample
            }

            int posItem = posItems.fastGet(rand.nextInt(size));
            int negItem = drawNegativeItem(rand, posItems, numItems);
            forward(user, posItem, negItem);
            emitted++;
        }
    }

    /**
     * Emits up to {@code numSamples} rows like the with-replacement variant, but removes each
     * sampled positive item from the user's list and drops users whose list became empty, so no
     * (user, positive item) entry is emitted twice. Users for whom no negative item exists are
     * dropped from the candidate set.
     *
     * @param positiveOnlyFeedback buffered feedback; it is modified by this method
     * @param numSamples maximum number of rows to emit
     * @throws HiveException if the max item id is not positive, a user cannot be located, or
     *         forwarding fails
     */
    private void uniformUserSamplingWithoutReplacement(@Nonnull PositiveOnlyFeedback positiveOnlyFeedback, int numSamples) throws HiveException {
        int numUsers = positiveOnlyFeedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        final int maxId = positiveOnlyFeedback.getMaxItemId();
        if (maxId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxId);
        }
        final int numItems = maxId + 1;
        // One set bit per user id that can still be sampled.
        final BitSet candidates = new BitSet(numUsers);
        positiveOnlyFeedback.getUsers(candidates);

        final Random rand = new Random(SEED);
        int emitted = 0;
        while (emitted < numSamples && numUsers > 0) {
            int nth = rand.nextInt(numUsers);
            int user = BitUtils.indexOfSetBit(candidates, nth);
            if (user == -1) {
                throw new HiveException("Cannot find " + nth + "-th user among " + numUsers + " users");
            }
            IntArrayList posItems = positiveOnlyFeedback.getItems(user, true);
            assert posItems != null : user;
            int size = posItems.size();
            assert size > 0 : size;

            if (!hasNegativeCandidate(posItems, numItems)) {
                // This user can never yield a sample; drop it so the loop is sure to terminate.
                candidates.clear(user);
                numUsers--;
                continue;
            }

            int posIndex = rand.nextInt(size);
            int posItem = posItems.fastGet(posIndex);
            int negItem = drawNegativeItem(rand, posItems, numItems);
            posItems.remove(posIndex);
            if (posItems.isEmpty()) {
                positiveOnlyFeedback.removeFeedback(user);
                candidates.clear(user);
                numUsers--;
            }
            forward(user, posItem, negItem);
            emitted++;
        }
    }

    /**
     * Performs {@code numSamples} draws of a feedback event chosen uniformly at random and emits
     * the event's user and positive item together with a negative item. A draw that hits a user
     * for whom no negative item exists is skipped, so fewer rows may be emitted.
     *
     * @param perEventPositiveOnlyFeedback buffered per-event feedback
     * @param numSamples number of draws
     * @throws HiveException if the max item id is not positive or forwarding fails
     */
    private void uniformPairSamplingWithReplacement(@Nonnull PerEventPositiveOnlyFeedback perEventPositiveOnlyFeedback, int numSamples) throws HiveException {
        final int totalFeedbacks = perEventPositiveOnlyFeedback.getTotalFeedbacks();
        if (totalFeedbacks == 0) {
            return;
        }
        final int maxId = perEventPositiveOnlyFeedback.getMaxItemId();
        if (maxId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxId);
        }
        final int numItems = maxId + 1;

        final Random rand = new Random(SEED);
        for (int draw = 0; draw < numSamples; draw++) {
            int event = rand.nextInt(totalFeedbacks);
            int user = perEventPositiveOnlyFeedback.getUser(event);
            int posItem = perEventPositiveOnlyFeedback.getPositiveItem(event);
            IntArrayList posItems = perEventPositiveOnlyFeedback.getItems(user, true);
            assert posItems != null : user;
            if (!hasNegativeCandidate(posItems, numItems)) {
                continue; // no negative item exists for this user
            }
            forward(user, posItem, drawNegativeItem(rand, posItems, numItems));
        }
    }

    /**
     * Walks the feedback events in the order given by
     * {@code PerEventPositiveOnlyFeedback#getRandomIndex} and emits at most {@code numSamples}
     * rows, each with the event's user and positive item plus a negative item. Events of users
     * for whom no negative item exists are skipped.
     *
     * @param perEventPositiveOnlyFeedback buffered per-event feedback
     * @param numSamples maximum number of rows to emit
     * @throws HiveException if the max item id is not positive or forwarding fails
     */
    private void uniformPairSamplingWithoutReplacement(@Nonnull PerEventPositiveOnlyFeedback perEventPositiveOnlyFeedback, int numSamples) throws HiveException {
        if (perEventPositiveOnlyFeedback.getTotalFeedbacks() == 0) {
            return;
        }
        final int maxId = perEventPositiveOnlyFeedback.getMaxItemId();
        if (maxId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxId);
        }
        final int numItems = maxId + 1;

        final Random rand = new Random(SEED);
        // NOTE(review): presumably a shuffled permutation of all event indexes -- confirm
        // against PerEventPositiveOnlyFeedback.
        final int[] order = perEventPositiveOnlyFeedback.getRandomIndex(rand);
        int emitted = 0;
        for (int event : order) {
            if (emitted >= numSamples) {
                break; // honour the sampling rate instead of always emitting every event
            }
            int user = perEventPositiveOnlyFeedback.getUser(event);
            int posItem = perEventPositiveOnlyFeedback.getPositiveItem(event);
            IntArrayList posItems = perEventPositiveOnlyFeedback.getItems(user, true);
            assert posItems != null : user;
            if (!hasNegativeCandidate(posItems, numItems)) {
                continue; // no negative item exists for this user
            }
            forward(user, posItem, drawNegativeItem(rand, posItems, numItems));
            emitted++;
        }
    }

    /**
     * Rejects negative user or item ids.
     *
     * @param index id to check
     * @throws UDFArgumentException if {@code index} is negative
     */
    private static void validateIndex(int index) throws UDFArgumentException {
        if (index < 0) {
            throw new UDFArgumentException("Negative index is not allowed: " + index);
        }
    }
}
