package hivemall.dataset;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Comparator;
import java.util.Random;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Table-generating function that emits a synthetic dataset for logistic regression.
 *
 * <p>Each emitted row has two columns: a float {@code label} and a {@code features} list. In
 * dense mode the list holds {@code n_features} floats; in sparse mode it holds
 * {@code n_features} strings of the form {@code "index:value"} whose indices are distinct
 * values drawn from {@code [0, n_dims)}.
 *
 * <p>For every example a uniform value in {@code [0, 1)} is drawn. It is "positive" when that
 * value is {@code <= prob_one}. In classification mode the label is 1.0 / 0.0 accordingly,
 * otherwise the raw uniform value is used as the label. Feature values are Gaussian samples,
 * shifted by {@code eps} for positive examples.
 *
 * <p>Rows are accumulated in fixed-size buffers and forwarded in batches of {@link #N_BUFFERS}.
 */
@Description(name = "lr_datagen", value = "_FUNC_(options string) - Generates a logistic regression dataset", extended = "WITH dual AS (SELECT 1) SELECT lr_datagen('-n_examples 1k -n_features 10') FROM dual;")
public final class LogisticRegressionDataGeneratorUDTF extends UDTFWithOptions {
    /** Number of rows buffered before they are forwarded. */
    private static final int N_BUFFERS = 1000;

    /** Random re-draws attempted on an index collision before falling back to a linear scan. */
    private static final int MAX_RETRIES = 3;

    /** Orders sparse features ("index:value") by their numeric index. */
    private static final Comparator<String> FEATURE_INDEX_ORDER = new Comparator<String>() {
        @Override
        public int compare(String lhs, String rhs) {
            return Primitives.compare(Integer.parseInt(lhs.split(":")[0]), Integer.parseInt(rhs.split(":")[0]));
        }
    };

    // Buffer state: 'position' is the next free slot in the buffers below.
    private int position;
    private float[] labels;
    private String[][] featuresArray; // used only for sparse output
    private Float[][] featuresFloatArray; // used only for dense output

    // Options (see getOptions / processOptions).
    private int n_examples;
    private int n_features;
    private int n_dimensions;
    private float eps;
    private float prob_one;
    private long r_seed;
    private boolean dense;
    private boolean sort;
    private boolean classification;

    // Created lazily on the first call to process(): rnd1 draws labels, rnd2 draws features.
    private Random rnd1 = null;
    private Random rnd2 = null;

    /**
     * Declares the command-line style options accepted in the single string argument.
     *
     * @return the option definitions
     */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        Options options = new Options();
        options.addOption("ne", "n_examples", true, "Number of training examples created for each task [DEFAULT: 1000]");
        options.addOption("nf", "n_features", true, "Number of features contained for each example [DEFAULT: 10]");
        options.addOption("nd", "n_dims", true, "The size of feature dimensions [DEFAULT: 200]");
        options.addOption("eps", true, "eps Epsilon factor by which positive examples are scaled [DEFAULT: 3.0]");
        options.addOption("p1", "prob_one", true, " Probability in [0, 1.0) that a label is 1 [DEFAULT: 0.6]");
        options.addOption("seed", true, "The seed value for random number generator [DEFAULT: 43L]");
        // Help text fixed: the n_features == n_dims constraint applies to dense datasets
        // (see the validation in processOptions), not to sparse ones.
        options.addOption("dense", false, "Make a dense dataset or not. If not specified, a sparse dataset is generated.\nFor sparse, n_dims should be much larger than n_features. For dense, n_features must be equals to n_dims ");
        options.addOption("sort", false, "Sort features if specified (used only for sparse dataset)");
        options.addOption("cl", "classification", false, "Toggle this option on to generate a classification dataset");
        return options;
    }

    /**
     * Parses the option string and stores the resulting settings in this instance.
     *
     * @param objectInspectorArr inspectors of the UDTF arguments; exactly one constant string is expected
     * @return the parsed command line
     * @throws UDFArgumentException if the argument count is not 1, if {@code n_features} exceeds
     *         {@code n_dims}, or if a dense dataset is requested with {@code n_features != n_dims}
     */
    @Override // hivemall.UDTFWithOptions
    protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length != 1) {
            throw new UDFArgumentException("Expected number of arguments is 1: " + objectInspectorArr.length);
        }
        final CommandLine cl = parseOptions(HiveUtils.getConstString(objectInspectorArr[0]));

        this.n_examples = NumberUtils.parseInt(cl.getOptionValue("n_examples"), 1000);
        this.n_features = NumberUtils.parseInt(cl.getOptionValue("n_features"), 10);
        this.n_dimensions = NumberUtils.parseInt(cl.getOptionValue("n_dims"), 200);
        this.eps = Primitives.parseFloat(cl.getOptionValue("eps"), 3.0f);
        this.prob_one = Primitives.parseFloat(cl.getOptionValue("prob_one"), 0.6f);
        this.r_seed = Primitives.parseLong(cl.getOptionValue("seed"), 43L);
        this.dense = cl.hasOption("dense");
        this.sort = cl.hasOption("sort");
        this.classification = cl.hasOption("classification");

        if (this.n_features > this.n_dimensions) {
            // Message fixed: the violated constraint is n_features <= n_dimensions.
            throw new UDFArgumentException("n_features '" + this.n_features + "' should be less than or equals to n_dimensions '" + this.n_dimensions + "'");
        }
        if (this.dense && this.n_features != this.n_dimensions) {
            throw new UDFArgumentException("n_features '" + this.n_features + "' must be equals to n_dimensions '" + this.n_dimensions + "' when making a dense dataset");
        }
        return cl;
    }

    /**
     * Parses the options, allocates the row buffers and describes the output schema.
     *
     * @param objectInspectorArr inspectors of the UDTF arguments
     * @return a struct inspector with fields {@code label} (float) and {@code features}
     *         (list of float when dense, list of string otherwise)
     * @throws UDFArgumentException if the options are invalid
     */
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        processOptions(objectInspectorArr);
        init();

        final ArrayList<String> fieldNames = new ArrayList<String>(2);
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);

        fieldNames.add("label");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector);

        fieldNames.add("features");
        if (this.dense) {
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector));
        } else {
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector));
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /** Allocates the label buffer and the feature buffer matching the output mode; resets the cursor. */
    private void init() {
        this.labels = new float[N_BUFFERS];
        if (this.dense) {
            this.featuresFloatArray = new Float[N_BUFFERS][this.n_features];
        } else {
            this.featuresArray = new String[N_BUFFERS][this.n_features];
        }
        this.position = 0;
    }

    /**
     * Generates {@code n_examples} rows, forwarding them whenever the buffer fills up. Rows left
     * in a partially filled buffer are forwarded by {@link #close()}.
     *
     * <p>On the first call the two random generators are seeded from the {@code seed} option,
     * offset by the task id when {@code HadoopUtils.getTaskId(-1)} returns something other
     * than -1.
     *
     * @param objArr the input row; not read
     * @throws HiveException if forwarding a row fails
     */
    public void process(Object[] objArr) throws HiveException {
        if (this.rnd1 == null) {
            assert this.rnd2 == null;
            final int taskId = HadoopUtils.getTaskId(-1);
            final long seed = (taskId == -1) ? this.r_seed : this.r_seed + taskId;
            this.rnd1 = new Random(seed);
            this.rnd2 = new Random(seed + 1);
        }
        for (int i = 0; i < this.n_examples; i++) {
            if (this.dense) {
                generateDenseData();
            } else {
                generateSparseData();
            }
            this.position++;
            if (this.position == N_BUFFERS) {
                flushBuffered(this.position);
                this.position = 0;
            }
        }
    }

    /**
     * Fills the buffer slot at {@code position} with one sparse example: {@code n_features}
     * entries of the form {@code "index:value"} with pairwise distinct indices.
     */
    private void generateSparseData() throws HiveException {
        final float label = this.rnd1.nextFloat();
        final float sign = (label <= this.prob_one) ? 1.0f : 0.0f;
        this.labels[this.position] = this.classification ? sign : label;

        final String[] features = this.featuresArray[this.position];
        assert features != null;

        final BitSet used = new BitSet(this.n_dimensions);
        int searchClearBitsFrom = 0;
        int filled = 0;
        int retries = 0;
        while (filled < this.n_features) {
            int index = this.rnd2.nextInt(this.n_dimensions);
            if (used.get(index)) {
                if (retries < MAX_RETRIES) {
                    // Collision: draw another index instead of emitting a duplicate.
                    retries++;
                    continue;
                }
                // Too many collisions: take the lowest unused index. One always exists below
                // n_dimensions because fewer than n_features (<= n_dimensions) bits are set here.
                searchClearBitsFrom = used.nextClearBit(searchClearBitsFrom);
                index = searchClearBitsFrom;
            }
            used.set(index);
            features[filled] = index + ":" + (((float) this.rnd2.nextGaussian()) + (sign * this.eps));
            filled++;
            retries = 0;
        }
        if (this.sort) {
            Arrays.sort(features, FEATURE_INDEX_ORDER);
        }
    }

    /** Fills the buffer slot at {@code position} with one dense example of {@code n_features} floats. */
    private void generateDenseData() {
        final float label = this.rnd1.nextFloat();
        // Same test as the sparse path, so that prob_one is the probability of label 1.
        final float sign = (label <= this.prob_one) ? 1.0f : 0.0f;
        this.labels[this.position] = this.classification ? sign : label;

        final Float[] features = this.featuresFloatArray[this.position];
        assert features != null;

        for (int i = 0; i < this.n_features; i++) {
            features[i] = Float.valueOf(((float) this.rnd2.nextGaussian()) + (sign * this.eps));
        }
    }

    /**
     * Forwards the first {@code i} buffered rows.
     *
     * @param i number of buffered rows to forward
     * @throws HiveException if forwarding a row fails
     */
    private void flushBuffered(int i) throws HiveException {
        // A single row holder is reused for every forwarded row; the features column is a
        // list view over the buffer array (Arrays.asList does not copy).
        final Object[] row = new Object[2];
        for (int k = 0; k < i; k++) {
            row[0] = Float.valueOf(this.labels[k]);
            if (this.dense) {
                row[1] = Arrays.asList(this.featuresFloatArray[k]);
            } else {
                row[1] = Arrays.asList(this.featuresArray[k]);
            }
            forward(row);
        }
    }

    /**
     * Forwards any rows still held in the buffer and releases the buffers.
     *
     * @throws HiveException if forwarding a row fails
     */
    public void close() throws HiveException {
        if (this.position > 0) {
            flushBuffered(this.position);
        }
        this.labels = null;
        this.featuresArray = null;
        this.featuresFloatArray = null;
    }
}
