package hivemall.mf;

import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
import hivemall.mf.FactorizedModel;
import hivemall.optimizer.EtaEstimator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.NioFixedSegment;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import it.unimi.dsi.fastutil.io.InspectableFileCachedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * Hive UDTF {@code train_bprmf}: trains a Bayesian Personalized Ranking (BPR) matrix factorization
 * model from (user, positive item, negative item) triples using stochastic gradient updates.
 *
 * <p>The first epoch runs as rows arrive in {@link #process(Object[])}. When more than one
 * iteration is requested and a {@code MapredContext} is available, every triple is also staged in a
 * direct {@link ByteBuffer}, which is spilled to a temporary {@link NioFixedSegment} file whenever it
 * fills up, so that {@link #close()} can run the remaining epochs before forwarding one row per
 * model index.
 */
@Description(name = "train_bprmf", value = "_FUNC_(INT user, INT posItem, INT negItem [, String options]) - Returns a relation <INT idx, ARRAY<FLOAT> Pu, ARRAY<FLOAT> Qi [, FLOAT Bi]>")
public final class BPRMatrixFactorizationUDTF extends UDTFWithOptions implements RatingInitializer {
    private static final Log LOG = LogFactory.getLog(BPRMatrixFactorizationUDTF.class);

    /** Bytes per staged training triple: three 4-byte ints (u, i, j). */
    private static final int RECORD_BYTES = 12;
    /** Capacity of the direct buffer that stages training triples (64 KiB). */
    private static final int INPUT_BUFFER_BYTES = 64 * 1024;

    protected int factor = 10;
    protected float regU = 0.0025f;
    protected float regI = 0.0025f;
    protected float regJ = 0.00125f;
    protected float regBias = 0.01f;
    protected boolean useBiasClause = true;
    protected int iterations = 30;
    protected LossFunction lossFunction;
    protected FactorizedModel.RankInitScheme rankInit;
    protected EtaEstimator etaEstimator;
    /** Number of training updates performed so far; also the record index of the next triple. */
    protected long count;
    protected ConversionState cvState;
    protected FactorizedModel model;
    protected PrimitiveObjectInspector userOI;
    protected PrimitiveObjectInspector posItemOI;
    protected PrimitiveObjectInspector negItemOI;
    /** Spill file for staged triples; null unless iterative training was set up in initialize. */
    protected NioFixedSegment fileIO;
    /** In-memory staging buffer for triples; null unless iterative training was set up. */
    protected ByteBuffer inputBuf;
    /** Record index at which the current contents of {@link #inputBuf} will be written. */
    private long lastWritePos;
    // Scratch copies of the factor weights read before an update, reused across rows.
    private float[] uProbe;
    private float[] iProbe;
    private float[] jProbe;

    /** Loss functions selectable with {@code -loss}. */
    public enum LossFunction {
        sigmoid, logistic, lnLogistic;

        /**
         * Resolves a loss-function name case-insensitively.
         *
         * @param str the option value; {@code null} selects {@link #lnLogistic}
         * @return the matching loss function
         * @throws IllegalArgumentException if {@code str} names no known loss function
         */
        @Nonnull
        public static LossFunction resolve(@Nullable String str) {
            if (str == null || str.equalsIgnoreCase("lnLogistic")) {
                return lnLogistic;
            }
            if (str.equalsIgnoreCase("logistic")) {
                return logistic;
            }
            if (str.equalsIgnoreCase("sigmoid")) {
                return sigmoid;
            }
            throw new IllegalArgumentException("Unexpected loss function: " + str);
        }
    }

    @Override
    protected Options getOptions() {
        Options options = new Options();
        options.addOption("k", "factor", true, "The number of latent factor [default: 10] Alias for `-factors`");
        options.addOption("f", "factors", true, "The number of latent factor [default: 10]");
        options.addOption("iters", "iterations", true, "The number of iterations [default: 30]");
        options.addOption("iter", true, "The number of iterations [default: 30] Alias for `-iterations");
        options.addOption("loss", "loss_function", true, "Loss function [default: lnLogistic, logistic, sigmoid]");
        options.addOption("rankinit", true, "Initialization strategy of rank matrix [random, gaussian] (default: random)");
        options.addOption("maxval", "max_init_value", true, "The maximum initial value in the rank matrix [default: 1.0]");
        options.addOption("min_init_stddev", true, "The minimum standard deviation of initial rank matrix [default: 0.1]");
        options.addOption("reg", "lambda", true, "The regularization factor [default: 0.0025]");
        options.addOption("reg_u", "reg_user", true, "The regularization factor for user [default: 0.0025 (reg)]");
        options.addOption("reg_i", "reg_item", true, "The regularization factor for positive item [default: 0.0025 (reg)]");
        options.addOption("reg_j", true, "The regularization factor for negative item [default: 0.00125 (reg_i/2) ]");
        options.addOption("reg_bias", true, "The regularization factor for bias clause [default: 0.01]");
        options.addOption("disable_bias", "no_bias", false, "Turn off bias clause");
        options.addOption("eta", true, "The initial learning rate [default: 0.001]");
        options.addOption("eta0", true, "The initial learning rate [default 0.3]");
        options.addOption("t", "total_steps", true, "The total number of training examples");
        options.addOption("power_t", true, "The exponent for inverse scaling learning rate [default 0.1]");
        options.addOption("boldDriver", "bold_driver", false, "Whether to use Bold Driver for learning rate [default: false]");
        options.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]");
        options.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        return options;
    }

    /**
     * Parses the optional 4th (constant string) argument and configures the trainer.
     *
     * @param argOIs the UDTF argument inspectors; options are read only when there are at least 4
     * @return the parsed command line, or {@code null} when no options argument was given
     * @throws UDFArgumentException if {@code -factors} or {@code -iterations} is less than 1
     */
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        String lossName = null;
        String rankInitName = null;
        float maxInitValue = 1.0f;
        double minInitStdDev = 0.1d;
        boolean convergenceCheck = true;
        double convergenceRate = 0.005d;

        if (argOIs.length >= 4) {
            cl = parseOptions(HiveUtils.getConstString(argOIs[3]));
            if (cl.hasOption("factor")) {
                this.factor = Primitives.parseInt(cl.getOptionValue("factor"), this.factor);
            } else {
                this.factor = Primitives.parseInt(cl.getOptionValue("factors"), this.factor);
            }
            // factor sizes the probe arrays and divides the init stddev; it must be positive.
            if (this.factor < 1) {
                throw new UDFArgumentException("'-factors' must be greater than or equals to 1: " + this.factor);
            }
            if (cl.hasOption("iter")) {
                this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), this.iterations);
            } else {
                this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), this.iterations);
            }
            if (this.iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equals to 1: " + this.iterations);
            }
            lossName = cl.getOptionValue("loss_function");
            // -reg is the shared default for the per-role regularizers; -reg_j defaults to reg_i / 2.
            float reg = Primitives.parseFloat(cl.getOptionValue("reg"), 0.0025f);
            this.regU = Primitives.parseFloat(cl.getOptionValue("reg_u"), reg);
            this.regI = Primitives.parseFloat(cl.getOptionValue("reg_i"), reg);
            this.regJ = Primitives.parseFloat(cl.getOptionValue("reg_j"), this.regI / 2.0f);
            this.regBias = Primitives.parseFloat(cl.getOptionValue("reg_bias"), this.regBias);
            rankInitName = cl.getOptionValue("rankinit");
            maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.0f);
            minInitStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);
            convergenceCheck = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), 0.005d);
            this.useBiasClause = !cl.hasOption("no_bias");
        }

        this.lossFunction = LossFunction.resolve(lossName);
        this.rankInit = FactorizedModel.RankInitScheme.resolve(rankInitName);
        this.rankInit.setMaxInitValue(maxInitValue);
        this.rankInit.setInitStdDev(Math.max(minInitStdDev, 1.0d / this.factor));
        this.etaEstimator = EtaEstimator.get(cl);
        this.cvState = new ConversionState(convergenceCheck, convergenceRate);
        return cl;
    }

    /**
     * Validates arguments, builds the model, optionally prepares the staging buffer and spill file,
     * and declares the output schema {@code <idx INT, Pu ARRAY<FLOAT>, Qi ARRAY<FLOAT> [, Bi FLOAT]>}.
     *
     * @throws UDFArgumentException on a wrong argument count, bad options, or temp-file failure
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 3 && argOIs.length != 4) {
            throw new UDFArgumentException(getClass().getSimpleName() + " takes 3 or 4 arguments: INT user, INT posItem, INT negItem [, CONSTANT STRING options]");
        }
        this.userOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
        this.posItemOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
        this.negItemOI = HiveUtils.asIntCompatibleOI(argOIs[2]);

        processOptions(argOIs);

        this.model = new FactorizedModel(this, this.factor, this.rankInit);
        this.count = 0L;
        this.lastWritePos = 0L;
        this.uProbe = new float[this.factor];
        this.iProbe = new float[this.factor];
        this.jProbe = new float[this.factor];

        // Extra epochs need the triples again, so stage them (only when running inside a task).
        if (this.mapredContext != null && this.iterations > 1) {
            try {
                File tmpFile = File.createTempFile("hivemall_bprmf", ".sgmt");
                tmpFile.deleteOnExit();
                if (!tmpFile.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + tmpFile.getAbsolutePath());
                }
                this.fileIO = new NioFixedSegment(tmpFile, RECORD_BYTES, false);
                this.inputBuf = ByteBuffer.allocateDirect(INPUT_BUFFER_BYTES);
            } catch (UDFArgumentException e) {
                throw e; // already descriptive; do not wrap it a second time
            } catch (IOException e) {
                throw new UDFArgumentException(e);
            } catch (Throwable t) {
                throw new UDFArgumentException(t);
            }
        }

        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("idx");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("Pu");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        fieldNames.add("Qi");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        if (this.useBiasClause) {
            fieldNames.add("Bi");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Reads one (user, posItem, negItem) triple, stages it for later epochs, and trains on it.
     *
     * @throws HiveException if any index is negative or staging to the spill file fails
     */
    @Override
    public void process(Object[] args) throws HiveException {
        assert args.length >= 3 : args.length;

        int u = PrimitiveObjectInspectorUtils.getInt(args[0], this.userOI);
        int i = PrimitiveObjectInspectorUtils.getInt(args[1], this.posItemOI);
        int j = PrimitiveObjectInspectorUtils.getInt(args[2], this.negItemOI);
        validateInput(u, i, j);

        beforeTrain(this.count, u, i, j);
        this.count++;
        train(u, i, j);
    }

    /**
     * Appends the triple to the staging buffer, spilling the full buffer to the file first.
     *
     * @param rowNum the record index of this triple (the current {@link #count})
     */
    protected void beforeTrain(long rowNum, int u, int i, int j) throws HiveException {
        if (this.inputBuf == null) {
            return; // single-iteration (or no task context): nothing to stage
        }
        assert this.fileIO != null;

        ByteBuffer buf = this.inputBuf;
        if (buf.remaining() < RECORD_BYTES) {
            // The buffer holds records [lastWritePos, rowNum); the next batch starts at rowNum.
            writeBuffer(buf, this.fileIO, this.lastWritePos);
            this.lastWritePos = rowNum;
        }
        buf.putInt(u);
        buf.putInt(i);
        buf.putInt(j);
    }

    /**
     * Performs one gradient update for user {@code u} preferring item {@code i} over item {@code j}.
     * Weights are copied into the probe arrays first so every factor update uses pre-update values.
     */
    protected void train(int u, int i, int j) {
        Rating[] userRatings = this.model.getUserVector(u, true);
        Rating[] posRatings = this.model.getItemVector(i, true);
        Rating[] negRatings = this.model.getItemVector(j, true);
        copyToProbe(userRatings, this.uProbe);
        copyToProbe(posRatings, this.iProbe);
        copyToProbe(negRatings, this.jProbe);

        double xuij = predict(u, i, this.uProbe, this.iProbe) - predict(u, j, this.uProbe, this.jProbe);
        double dloss = dloss(xuij, this.lossFunction);
        float eta = eta();

        for (int f = 0; f < this.factor; f++) {
            float wu = this.uProbe[f];
            float hi = this.iProbe[f];
            float hj = this.jProbe[f];
            updateUserRating(userRatings[f], wu, hi, hj, dloss, eta);
            updateItemRating(posRatings[f], wu, hi, dloss, eta, this.regI);
            updateItemRating(negRatings[f], wu, hj, -dloss, eta, this.regJ);
        }
        if (this.useBiasClause) {
            updateBias(i, j, dloss, eta);
        }
    }

    /**
     * Scores an item for a user as {@code itemBias(item) + dot(userFactors, itemFactors)}.
     *
     * @throws IllegalStateException if the score is NaN or infinite
     */
    protected double predict(int user, int item, @Nonnull float[] userFactors, @Nonnull float[] itemFactors) {
        double score = this.model.getItemBias(item);
        for (int f = 0; f < this.factor; f++) {
            score += userFactors[f] * itemFactors[f];
        }
        if (!NumberUtils.isFinite(score)) {
            throw new IllegalStateException("Detected " + score + " in predict where user=" + user + " and item=" + item);
        }
        return score;
    }

    /**
     * Returns the gradient scale for the score difference {@code x = x_ui - x_uj}.
     *
     * <p>{@code sigmoid} and {@code lnLogistic} both evaluate {@code 1 / (1 + e^x)}, which equals
     * {@code e^-x / (1 + e^-x)}. That is the derivative of {@code ln(sigmoid(x))}. This form is used
     * because {@code e^-x / (1 + e^-x)} evaluates to {@code Inf/Inf = NaN} once {@code x < ~-709}.
     * {@code logistic} returns {@code sigmoid(x) * (1 - sigmoid(x))}.
     */
    protected double dloss(double x, @Nonnull LossFunction lossFunction) {
        switch (lossFunction) {
            case sigmoid:
                return 1.0d / (1.0d + Math.exp(x));
            case logistic: {
                double s = MathUtils.sigmoid(x);
                return s * (1.0d - s);
            }
            case lnLogistic:
                // Overflow-safe form: exp(x) -> Inf yields 0 rather than NaN.
                return 1.0d / (1.0d + Math.exp(x));
            default:
                throw new IllegalStateException("Unexpected loss function: " + lossFunction);
        }
    }

    /** Returns the learning rate for the current update count. */
    protected float eta() {
        return this.etaEstimator.eta(this.count);
    }

    /**
     * Updates one user factor: {@code wu += eta * (dloss * (hi - hj) - regU * wu)}, and adds the
     * pre-update regularization term {@code regU * wu^2} to the loss.
     */
    protected void updateUserRating(Rating rating, float wu, float hi, float hj, double dloss, float eta) {
        float newWu = wu + ((float) (eta * ((dloss * (hi - hj)) - (this.regU * wu))));
        if (!NumberUtils.isFinite(newWu)) {
            throw new IllegalStateException("Detected " + newWu + " for w_uf");
        }
        rating.setWeight(newWu);
        this.cvState.incrLoss(this.regU * wu * wu);
    }

    /**
     * Updates one item factor: {@code h += eta * (dloss * wu - reg * h)}, and adds the pre-update
     * regularization term {@code reg * h^2} to the loss. Callers pass {@code -dloss} for the
     * negative item.
     */
    protected void updateItemRating(Rating rating, float wu, float h, double dloss, float eta, float reg) {
        float newH = h + ((float) (eta * ((dloss * wu) - (reg * h))));
        if (!NumberUtils.isFinite(newH)) {
            throw new IllegalStateException("Detected " + newH + " for h_f");
        }
        rating.setWeight(newH);
        this.cvState.incrLoss(reg * h * h);
    }

    /**
     * Updates the positive item bias by {@code +dloss} and the negative item bias by {@code -dloss},
     * each with L2 regularization {@code regBias}. The regularization loss is accumulated on the
     * updated bias values.
     */
    protected void updateBias(int i, int j, double dloss, float eta) {
        float oldBi = this.model.getItemBias(i);
        float newBi = (float) (oldBi + (eta * (dloss - (this.regBias * oldBi))));
        if (!NumberUtils.isFinite(newBi)) {
            throw new IllegalStateException("Detected " + newBi + " for Bi");
        }
        this.model.setItemBias(i, newBi);
        this.cvState.incrLoss(this.regBias * newBi * newBi);

        // Read after the write above so that i == j sees the updated bias.
        float oldBj = this.model.getItemBias(j);
        float newBj = (float) (oldBj + (eta * ((-dloss) - (this.regBias * oldBj))));
        if (!NumberUtils.isFinite(newBj)) {
            throw new IllegalStateException("Detected " + newBj + " for Bj");
        }
        this.model.setItemBias(j, newBj);
        this.cvState.incrLoss(this.regBias * newBj * newBj);
    }

    /**
     * Runs the remaining epochs (when examples were staged) and forwards one row per model index in
     * {@code [minIndex, maxIndex]}. Pu and Qi are null for indexes that have no user or item vector.
     */
    @Override
    public void close() throws HiveException {
        if (this.model == null) {
            return;
        }
        if (this.count == 0) {
            this.model = null; // nothing was trained; emit nothing
            return;
        }
        if (this.iterations > 1) {
            if (this.inputBuf != null) {
                runIterativeTraining(this.iterations);
            } else {
                LOG.warn("Skipping iterations 2.." + this.iterations
                        + " because training examples were not buffered (no MapredContext)");
            }
        }

        IntWritable idx = new IntWritable();
        FloatWritable[] pu = HiveUtils.newFloatArray(this.factor, 0.0f);
        FloatWritable[] qi = HiveUtils.newFloatArray(this.factor, 0.0f);
        FloatWritable bi = null;
        // The row width must match the struct declared in initialize (3 or 4 fields).
        Object[] row;
        if (this.useBiasClause) {
            bi = new FloatWritable();
            row = new Object[] {idx, pu, qi, bi};
        } else {
            row = new Object[] {idx, pu, qi};
        }

        int numForwarded = 0;
        int maxIndex = this.model.getMaxIndex();
        for (int index = this.model.getMinIndex(); index <= maxIndex; index++) {
            idx.set(index);
            Rating[] userRatings = this.model.getUserVector(index);
            if (userRatings == null) {
                row[1] = null;
            } else {
                row[1] = pu;
                copyTo(userRatings, pu);
            }
            Rating[] itemRatings = this.model.getItemVector(index);
            if (itemRatings == null) {
                row[2] = null;
            } else {
                row[2] = qi;
                copyTo(itemRatings, qi);
            }
            if (bi != null) {
                bi.set(this.model.getItemBias(index));
            }
            forward(row);
            numForwarded++;
        }
        this.model = null;
        LOG.info("Forwarded the prediction model of " + numForwarded + " rows. [lastLosses=" + this.cvState.getCumulativeLoss() + ", #trainingExamples=" + this.count + "]");
    }

    /**
     * Runs epochs 2..{@code iterations} over the staged training triples.
     *
     * <p>NOTE(review): the decompiler could not recover this method's body (968 bytecode
     * instructions), so it currently always throws. Consequently {@link #close()} fails whenever
     * {@code iterations > 1} and examples were staged.
     *
     * <p>The recovered fragments suggest the following per-epoch logic; verify it against the
     * original source:
     * <ul>
     *   <li>each epoch calls {@code cvState.multiplyLoss(0.5d)};</li>
     *   <li>training stops once {@code cvState.isConverged(...)} holds;</li>
     *   <li>otherwise {@code etaEstimator.update(1.1f)} or {@code etaEstimator.update(0.5f)} is
     *       chosen depending on {@code cvState.isLossIncreased()}.</li>
     * </ul>
     *
     * <p>TODO restore the real implementation from source control.
     */
    private final void runIterativeTraining(@javax.annotation.Nonnegative int iterations) throws org.apache.hadoop.hive.ql.metadata.HiveException {
        throw new UnsupportedOperationException("Method not decompiled: hivemall.mf.BPRMatrixFactorizationUDTF.runIterativeTraining(int):void");
    }

    @Override
    public Rating newRating(float f) {
        return new Rating(f);
    }

    /** Rejects negative user or item indexes. */
    private static void validateInput(int u, int i, int j) throws HiveException {
        if (u < 0) {
            throw new HiveException("Illegal u index: " + u);
        }
        if (i < 0) {
            throw new HiveException("Illegal i index: " + i);
        }
        if (j < 0) {
            throw new HiveException("Illegal j index: " + j);
        }
    }

    /**
     * Flips {@code buf}, writes its records to {@code dst} starting at record index {@code idx},
     * then clears it for reuse.
     */
    private static void writeBuffer(@Nonnull ByteBuffer buf, @Nonnull NioFixedSegment dst, long idx) throws HiveException {
        buf.flip();
        try {
            dst.writeRecords(idx, buf);
            buf.clear();
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing records to : " + idx, e);
        }
    }

    /** Copies the first {@link #factor} weights of {@code ratings} into {@code probe}. */
    private final void copyToProbe(@Nonnull Rating[] ratings, @Nonnull float[] probe) {
        for (int f = 0; f < this.factor; f++) {
            probe[f] = ratings[f].getWeight();
        }
    }

    /** Copies every rating weight into the corresponding writable. */
    private static void copyTo(@Nonnull Rating[] ratings, @Nonnull FloatWritable[] dst) {
        for (int f = 0; f < ratings.length; f++) {
            dst[f].set(ratings[f].getWeight());
        }
    }
}
