package hivemall.mf;

import hivemall.UDTFWithOptions;
import hivemall.common.ConversionState;
import hivemall.mf.FactorizedModel;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.NioFixedSegment;
import hivemall.utils.lang.Primitives;
import it.unimi.dsi.fastutil.io.InspectableFileCachedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

/* loaded from: input_file:hivemall/mf/OnlineMatrixFactorizationUDTF.class */
public abstract class OnlineMatrixFactorizationUDTF extends UDTFWithOptions implements RatingInitializer {
    private static final Log logger = LogFactory.getLog(OnlineMatrixFactorizationUDTF.class);

    /** Bytes per buffered training example: int user (4) + int item (4) + double rating (8). */
    private static final int RECORD_BYTES = 16;

    /**
     * Capacity of the direct buffer that stages training examples before they are written to the
     * temporary segment file: 64 KiB, which is a multiple of {@link #RECORD_BYTES}.
     */
    private static final int INPUT_BUFFER_BYTES = 64 * 1024;

    // ---- hyperparameters (may be overridden by the options string) ----

    /** Number of latent factors in each user/item vector. */
    protected int factor = 10;
    /** L2 regularization coefficient applied to factors and biases. */
    protected float lambda = 0.03f;
    /** Initial global mean rating (mu) handed to the model. */
    protected float meanRating = 0.0f;
    /** Whether mu is updated during training and emitted as an extra output column. */
    protected boolean updateMeanRating = false;
    /** Number of training iterations; a value greater than 1 makes close() call runIterativeTraining. */
    protected int iterations = 1;
    /** Whether per-user/per-item bias terms are part of the model. */
    protected boolean useBiasClause = true;
    /** Initialization scheme for the latent factor vectors. */
    protected FactorizedModel.RankInitScheme rankInit;

    // ---- training state ----

    /** The model being trained; null before initialize() and after close(). */
    protected FactorizedModel model;
    /** Number of training rows processed so far. */
    protected long count;
    /** Accumulates absolute errors and losses; built from the cv options in processOptions. */
    protected ConversionState cvState;

    protected PrimitiveObjectInspector userOI;
    protected PrimitiveObjectInspector itemOI;
    protected PrimitiveObjectInspector ratingOI;

    /** Temporary on-disk record store; non-null only when examples are being buffered. */
    protected NioFixedSegment fileIO;
    /** Staging buffer for examples before they are written to fileIO; null when buffering is off. */
    protected ByteBuffer inputBuf;
    /** Record index passed to writeBuffer on the next flush of inputBuf. */
    private long lastWritePos;
    /** Scratch copy of the current user's weights, reused across rows. */
    private float[] userProbe;
    /** Scratch copy of the current item's weights, reused across rows. */
    private float[] itemProbe;

    /**
     * Declares the options accepted in the optional fourth (constant string) argument.
     *
     * @return the option definitions parsed by {@link #processOptions(ObjectInspector[])}
     */
    @Override // hivemall.UDTFWithOptions
    public Options getOptions() {
        Options opts = new Options();
        opts.addOption("k", "factor", true, "The number of latent factor [default: 10]  Note this is alias for `factors` option.");
        opts.addOption("f", "factors", true, "The number of latent factor [default: 10]");
        opts.addOption("r", "lambda", true, "The regularization factor [default: 0.03]");
        opts.addOption("mu", "mean_rating", true, "The mean rating [default: 0.0]");
        opts.addOption("update_mean", "update_mu", false, "Whether update (and return) the mean rating or not");
        opts.addOption("rankinit", true, "Initialization strategy of rank matrix [random, gaussian] (default: random)");
        opts.addOption("maxval", "max_init_value", true, "The maximum initial value in the rank matrix [default: 1.0]");
        opts.addOption("min_init_stddev", true, "The minimum standard deviation of initial rank matrix [default: 0.1]");
        opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]");
        opts.addOption("iter", true, "The number of iterations [default: 1] Alias for `-iterations`");
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause");
        return opts;
    }

    /**
     * Parses the optional options string (argument index 3) into the hyperparameter fields, then
     * builds {@link #rankInit} and {@link #cvState}. When no options string is given, the field
     * defaults and the local defaults below are used.
     *
     * @param argOIs the UDTF argument inspectors
     * @return the parsed command line, or {@code null} when no options string was supplied
     * @throws UDFArgumentException if an option value is invalid, or if {@code -update_mean} is
     *         combined with {@code -no_bias}
     */
    @Override // hivemall.UDTFWithOptions
    public CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        CommandLine cl = null;
        String rankInitOpt = null;
        float maxInitValue = 1.0f;
        double minInitStdDev = 0.1d;
        boolean conversionCheck = true;
        double convergenceRate = 0.005d;

        if (argOIs.length >= 4) {
            cl = parseOptions(HiveUtils.getConstString(argOIs[3]));

            // "-factors" takes precedence over its alias "-factor".
            if (cl.hasOption("factors")) {
                this.factor = Primitives.parseInt(cl.getOptionValue("factors"), 10);
            } else {
                this.factor = Primitives.parseInt(cl.getOptionValue("factor"), 10);
            }
            // A non-positive factor count would make the probe arrays invalid and the
            // init stddev (1 / factor) infinite.
            if (this.factor < 1) {
                throw new UDFArgumentException("'-factors' must be greater than or equal to 1: " + this.factor);
            }

            this.lambda = Primitives.parseFloat(cl.getOptionValue("lambda"), 0.03f);
            this.meanRating = Primitives.parseFloat(cl.getOptionValue("mu"), 0.0f);
            this.updateMeanRating = cl.hasOption("update_mean");
            rankInitOpt = cl.getOptionValue("rankinit");
            maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), 1.0f);
            minInitStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), 0.1d);

            // "-iter" takes precedence over "-iterations"/"-iters".
            if (cl.hasOption("iter")) {
                this.iterations = Primitives.parseInt(cl.getOptionValue("iter"), 1);
            } else {
                this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
            }
            if (this.iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equal to 1: " + this.iterations);
            }

            conversionCheck = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), 0.005d);

            final boolean noBias = cl.hasOption("no_bias");
            this.useBiasClause = !noBias;
            // mu is only emitted alongside the bias columns, so the two flags conflict.
            if (noBias && this.updateMeanRating) {
                throw new UDFArgumentException("Cannot set both `update_mean` and `no_bias` option");
            }
        }

        this.rankInit = FactorizedModel.RankInitScheme.resolve(rankInitOpt);
        this.rankInit.setMaxInitValue(maxInitValue);
        this.rankInit.setInitStdDev(Math.max(minInitStdDev, 1.0d / this.factor));
        this.cvState = new ConversionState(conversionCheck, convergenceRate);
        return cl;
    }

    /**
     * Validates the arguments {@code (user, item, rating [, options])}, sets up the model and,
     * when {@code iterations > 1} and a MapredContext is available, the temporary example buffer.
     *
     * <p>The output schema is {@code idx, Pu, Qi}, followed by {@code Bu, Bi} when bias is enabled,
     * followed by {@code mu} when {@code -update_mean} is set.
     *
     * @param argOIs the UDTF argument inspectors
     * @return the struct inspector describing the emitted rows
     * @throws UDFArgumentException on a wrong argument count or types, bad options, or failure to
     *         create the temporary file
     */
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 3) {
            throw new UDFArgumentException("_FUNC_ takes 3 arguments: INT user, INT item, FLOAT rating [, CONSTANT STRING options]");
        }
        this.userOI = HiveUtils.asIntCompatibleOI(argOIs[0]);
        this.itemOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
        this.ratingOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]);

        processOptions(argOIs);

        this.model = new FactorizedModel(this, this.factor, this.meanRating, this.rankInit);
        this.count = 0L;
        this.lastWritePos = 0L;
        this.userProbe = new float[this.factor];
        this.itemProbe = new float[this.factor];

        // Examples are only buffered to disk when extra iterations will replay them.
        if (this.mapredContext != null && this.iterations > 1) {
            this.fileIO = createTempSegment();
            this.inputBuf = ByteBuffer.allocateDirect(INPUT_BUFFER_BYTES);
        }

        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("idx");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("Pu");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        fieldNames.add("Qi");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));
        if (this.useBiasClause) {
            fieldNames.add("Bu");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            fieldNames.add("Bi");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            if (this.updateMeanRating) {
                fieldNames.add("mu");
                fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            }
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Creates a writable temporary segment file for buffered training examples. The file is
     * scheduled for deletion when the JVM exits.
     *
     * @return a segment with {@link #RECORD_BYTES}-sized records
     * @throws UDFArgumentException if the file cannot be created, is not writable, or the segment
     *         cannot be opened
     */
    @Nonnull
    private static NioFixedSegment createTempSegment() throws UDFArgumentException {
        final File file;
        try {
            file = File.createTempFile("hivemall_mf", ".sgmt");
        } catch (IOException e) {
            throw new UDFArgumentException(e);
        }
        file.deleteOnExit();
        // Thrown directly rather than from inside a catch-all, so the message is not re-wrapped.
        if (!file.canWrite()) {
            throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
        }
        try {
            return new NioFixedSegment(file, RECORD_BYTES, false);
        } catch (Exception e) {
            // NOTE(review): the constructor's checked exceptions are not visible here; catching
            // Exception (not Throwable) keeps JVM Errors from being reported as argument errors.
            throw new UDFArgumentException(e);
        }
    }

    /**
     * Creates a rating holder with the given initial weight; used by the model as its
     * initializer (this instance is passed to the FactorizedModel constructor).
     *
     * @param f the initial weight
     * @return a new Rating
     */
    public Rating newRating(float f) {
        return new Rating(f);
    }

    /**
     * Consumes one {@code (user, item, rating)} row. The row is offered to
     * {@link #beforeTrain} for buffering, and then one training step is applied.
     *
     * @param args the row values; at least three are expected
     * @throws HiveException if the user or item index is negative, or training fails
     */
    public final void process(Object[] args) throws HiveException {
        assert args.length >= 3 : args.length;

        final int user = PrimitiveObjectInspectorUtils.getInt(args[0], this.userOI);
        if (user < 0) {
            throw new HiveException("Illegal user index: " + user);
        }
        final int item = PrimitiveObjectInspectorUtils.getInt(args[1], this.itemOI);
        if (item < 0) {
            throw new HiveException("Illegal item index: " + item);
        }
        final double rating = PrimitiveObjectInspectorUtils.getDouble(args[2], this.ratingOI);

        // beforeTrain receives the 0-based row number of this example.
        beforeTrain(this.count, user, item, rating);
        this.count++;
        train(user, item, rating);
    }

    /**
     * Copies the first {@code factor} weights of {@code ratings} into the reusable user probe.
     *
     * @param ratings the user's latent vector
     * @return the shared probe array, which is overwritten on the next call
     */
    @Nonnull
    protected final float[] copyToUserProbe(@Nonnull Rating[] ratings) {
        return copyWeights(ratings, this.userProbe);
    }

    /**
     * Copies the first {@code factor} weights of {@code ratings} into the reusable item probe.
     *
     * @param ratings the item's latent vector
     * @return the shared probe array, which is overwritten on the next call
     */
    @Nonnull
    protected final float[] copyToItemProbe(@Nonnull Rating[] ratings) {
        return copyWeights(ratings, this.itemProbe);
    }

    /** Copies {@code factor} weights from {@code ratings} into {@code dst}, and returns {@code dst}. */
    @Nonnull
    private float[] copyWeights(@Nonnull Rating[] ratings, @Nonnull float[] dst) {
        final int k = this.factor;
        for (int f = 0; f < k; f++) {
            dst[f] = ratings[f].getWeight();
        }
        return dst;
    }

    /**
     * Applies one gradient-style update for a single observed rating. The prediction error and
     * the squared error are accumulated in {@link #cvState}. Each latent factor is then moved
     * using snapshot (pre-update) values of both vectors. Biases and mu are updated when enabled.
     *
     * @param user the user index (must be non-negative)
     * @param item the item index (must be non-negative)
     * @param rating the observed rating
     * @throws HiveException propagated from {@link #onUpdate}
     */
    protected void train(int user, int item, double rating) throws HiveException {
        final Rating[] userVector = this.model.getUserVector(user, true);
        assert userVector != null;
        final Rating[] itemVector = this.model.getItemVector(item, true);
        assert itemVector != null;

        // Snapshots, so that the user and item updates both use the pre-update weights.
        final float[] pu = copyToUserProbe(userVector);
        final float[] qi = copyToItemProbe(itemVector);

        final double err = rating - predict(user, item, pu, qi);
        this.cvState.incrError(Math.abs(err));
        this.cvState.incrLoss(err * err);

        final float eta = eta();
        final int k = this.factor;
        for (int f = 0; f < k; f++) {
            final float puF = pu[f];
            final float qiF = qi[f];
            updateItemRating(itemVector[f], puF, qiF, err, eta);
            updateUserRating(userVector[f], puF, qiF, err, eta);
        }
        if (this.useBiasClause) {
            updateBias(user, item, err, eta);
            if (this.updateMeanRating) {
                updateMeanRating(err, eta);
            }
        }
        onUpdate(user, item, userVector, itemVector, err);
    }

    /**
     * Hook invoked for every row before training. When buffering is enabled ({@link #inputBuf}
     * non-null), the example is appended to the staging buffer. The buffer is first flushed to
     * {@link #fileIO} if it cannot hold another record.
     *
     * @param rowNum 0-based index of this example
     * @param user the user index
     * @param item the item index
     * @param rating the observed rating
     * @throws HiveException if flushing the buffer fails
     */
    protected void beforeTrain(long rowNum, int user, int item, double rating) throws HiveException {
        final ByteBuffer buf = this.inputBuf;
        if (buf == null) {
            return;
        }
        assert this.fileIO != null;

        if (buf.remaining() < RECORD_BYTES) {
            writeBuffer(buf, this.fileIO, this.lastWritePos);
            // The current example becomes the first record of the next flush.
            this.lastWritePos = rowNum;
        }
        buf.putInt(user);
        buf.putInt(item);
        buf.putDouble(rating);
    }

    /**
     * Hook invoked after each training step; a no-op by default.
     *
     * @param user the user index
     * @param item the item index
     * @param userVector the (already updated) user latent vector
     * @param itemVector the (already updated) item latent vector
     * @param err the prediction error computed before the update
     * @throws HiveException at the discretion of subclasses
     */
    protected void onUpdate(int user, int item, Rating[] userVector, Rating[] itemVector, double err) throws HiveException {
    }

    /**
     * Predicts a rating as {@code bias(user, item) + sum_f pu[f] * qi[f]}.
     *
     * @param user the user index
     * @param item the item index
     * @param pu the user's factor weights (at least {@code factor} long)
     * @param qi the item's factor weights (at least {@code factor} long)
     * @return the predicted rating
     */
    protected double predict(int user, int item, float[] pu, float[] qi) {
        double ret = bias(user, item);
        for (int f = 0; f < this.factor; f++) {
            ret += pu[f] * qi[f];
        }
        return ret;
    }

    /**
     * Predicts a rating from the model's current user and item vectors.
     *
     * @param user the user index
     * @param item the item index
     * @return {@code bias(user, item) + dot(userVector, itemVector)}
     * @throws HiveException if the user or item has no vector in the model
     */
    protected double predict(int user, int item) throws HiveException {
        final Rating[] userVector = this.model.getUserVector(user);
        if (userVector == null) {
            throw new HiveException("User rating is not found: " + user);
        }
        final Rating[] itemVector = this.model.getItemVector(item);
        if (itemVector == null) {
            throw new HiveException("Item rating is not found: " + item);
        }
        double ret = bias(user, item);
        for (int f = 0; f < this.factor; f++) {
            ret += userVector[f].getWeight() * itemVector[f].getWeight();
        }
        return ret;
    }

    /**
     * Returns the prediction baseline: mu, plus the user and item biases when bias is enabled.
     *
     * @param user the user index
     * @param item the item index
     * @return the baseline rating
     */
    protected double bias(int user, int item) {
        if (!this.useBiasClause) {
            return this.model.getMeanRating();
        }
        return this.model.getMeanRating() + this.model.getUserBias(user) + this.model.getItemBias(item);
    }

    /**
     * Returns the learning rate for the current step; a constant 1.0 by default.
     *
     * @return the learning rate
     */
    protected float eta() {
        return 1.0f;
    }

    /**
     * Updates one item factor: {@code qi += eta * (err * pu - lambda * qi)}. The L2 penalty of the
     * old weight is also added to the loss.
     *
     * @param rating the item factor to update
     * @param pu the user's factor value (pre-update)
     * @param qi the item's factor value (pre-update)
     * @param err the prediction error
     * @param eta the learning rate
     */
    protected void updateItemRating(Rating rating, float pu, float qi, double err, float eta) {
        rating.setWeight(qi + ((float) (eta * ((err * pu) - (this.lambda * qi)))));
        this.cvState.incrLoss(this.lambda * qi * qi);
    }

    /**
     * Updates one user factor: {@code pu += eta * (err * qi - lambda * pu)}. The L2 penalty of the
     * old weight is also added to the loss.
     *
     * @param rating the user factor to update
     * @param pu the user's factor value (pre-update)
     * @param qi the item's factor value (pre-update)
     * @param err the prediction error
     * @param eta the learning rate
     */
    protected void updateUserRating(Rating rating, float pu, float qi, double err, float eta) {
        rating.setWeight(pu + ((float) (eta * ((err * qi) - (this.lambda * pu)))));
        this.cvState.incrLoss(this.lambda * pu * pu);
    }

    /**
     * Updates the global mean: {@code mu += eta * err}. Only valid when {@code -update_mean} is set.
     *
     * @param err the prediction error
     * @param eta the learning rate
     */
    protected void updateMeanRating(double err, float eta) {
        assert this.updateMeanRating;
        this.model.setMeanRating((float) (this.model.getMeanRating() + (eta * err)));
    }

    /**
     * Updates the user and item biases: {@code b += eta * (err - lambda * b)}. The L2 penalty of
     * each old bias is added to the loss. Only valid when the bias clause is enabled.
     *
     * @param user the user index
     * @param item the item index
     * @param err the prediction error
     * @param eta the learning rate
     */
    protected void updateBias(int user, int item, double err, float eta) {
        assert this.useBiasClause;

        final float bu = this.model.getUserBias(user);
        this.model.setUserBias(user, (float) (bu + (eta * (err - (this.lambda * bu)))));
        this.cvState.incrLoss(this.lambda * bu * bu);

        final float bi = this.model.getItemBias(item);
        this.model.setItemBias(item, (float) (bi + (eta * (err - (this.lambda * bi)))));
        this.cvState.incrLoss(this.lambda * bi * bi);
    }

    /**
     * Finishes training and forwards the model as one row per index in
     * {@code [model.getMinIndex(), model.getMaxIndex()]}. Pu/Qi are null for indices that have no
     * user/item vector. Runs {@link #runIterativeTraining} first when {@code iterations > 1}.
     * Does nothing if the model was never initialized or has already been emitted.
     *
     * @throws HiveException if iterative training or forwarding fails
     */
    public void close() throws HiveException {
        if (this.model == null) {
            return;
        }
        if (this.count == 0) {
            // NOTE(review): fileIO/inputBuf (if allocated) are not released on this path; the temp
            // file is only removed by deleteOnExit.
            this.model = null;
            return;
        }
        if (this.iterations > 1) {
            runIterativeTraining(this.iterations);
        }

        final IntWritable idx = new IntWritable();
        final FloatWritable[] pu = HiveUtils.newFloatArray(this.factor, 0.0f);
        final FloatWritable[] qi = HiveUtils.newFloatArray(this.factor, 0.0f);
        final FloatWritable bu = new FloatWritable();
        final FloatWritable bi = new FloatWritable();

        // The row layout must match the schema built in initialize().
        final Object[] row;
        if (this.updateMeanRating) {
            assert this.useBiasClause;
            row = new Object[] {idx, pu, qi, bu, bi, new FloatWritable(this.model.getMeanRating())};
        } else if (this.useBiasClause) {
            row = new Object[] {idx, pu, qi, bu, bi};
        } else {
            row = new Object[] {idx, pu, qi};
        }

        int forwarded = 0;
        final int maxIdx = this.model.getMaxIndex();
        for (int i = this.model.getMinIndex(); i <= maxIdx; i++) {
            idx.set(i);

            final Rating[] userVector = this.model.getUserVector(i);
            if (userVector == null) {
                row[1] = null;
            } else {
                row[1] = pu;
                copyTo(userVector, pu);
            }

            final Rating[] itemVector = this.model.getItemVector(i);
            if (itemVector == null) {
                row[2] = null;
            } else {
                row[2] = qi;
                copyTo(itemVector, qi);
            }

            if (this.useBiasClause) {
                bu.set(this.model.getUserBias(i));
                bi.set(this.model.getItemBias(i));
            }
            forward(row);
            forwarded++;
        }
        this.model = null;
        logger.info("Forwarded the prediction model of " + forwarded + " rows. [totalErrors=" + this.cvState.getTotalErrors() + ", lastLosses=" + this.cvState.getCumulativeLoss() + ", #trainingExamples=" + this.count + "]");
    }

    /**
     * Flips {@code buf}, writes its records to {@code segment} starting at record index
     * {@code startIdx}, and clears the buffer for reuse.
     *
     * @param buf the staging buffer (in write mode)
     * @param segment the destination segment
     * @param startIdx the record index of the first record in {@code buf}
     * @throws HiveException wrapping any IOException from the write
     */
    protected static void writeBuffer(@Nonnull ByteBuffer buf, @Nonnull NioFixedSegment segment, long startIdx) throws HiveException {
        buf.flip();
        try {
            segment.writeRecords(startIdx, buf);
            buf.clear();
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing records to : " + startIdx, e);
        }
    }

    /**
     * Runs the additional training iterations requested with {@code -iterations}. Presumably this
     * replays the examples buffered by {@link #beforeTrain} (to be confirmed against the original
     * source).
     *
     * <p>NOTE(review): the decompiler could not recover this method's body, so it currently always
     * throws {@link UnsupportedOperationException}, and {@link #close()} calls it whenever
     * {@code iterations > 1}. The fragments the decompiler did recover show
     * {@code cvState.multiplyLoss(0.5d)} followed by a {@code cvState.isConverged(...)} check inside
     * the iteration loop. Restore the original implementation before relying on
     * multi-iteration training.
     *
     * @param numIterations the total number of iterations
     * @throws HiveException as declared by the original signature
     */
    protected final void runIterativeTraining(@javax.annotation.Nonnegative int numIterations) throws HiveException {
        throw new UnsupportedOperationException("Method not decompiled: hivemall.mf.OnlineMatrixFactorizationUDTF.runIterativeTraining(int):void");
    }

    /** Copies every weight in {@code src} into the corresponding writable in {@code dst}. */
    private static void copyTo(@Nonnull Rating[] src, @Nonnull FloatWritable[] dst) {
        final int n = src.length;
        for (int i = 0; i < n; i++) {
            dst[i].set(src[i].getWeight());
        }
    }
}
