package hivemall.smile.classification;

import hivemall.UDTFWithOptions;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.builders.MatrixBuilder;
import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.Vector;
import hivemall.smile.data.Attribute;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.Counters;
import org.apache.hadoop.mapred.Reporter;
import smile.math.Math;

@Description(name = "train_gradient_tree_boosting_classifier", value = "_FUNC_(array<double|string> features, int label [, string options]) - Returns a relation consists of <int iteration, int model_type, array<string> pred_models, double intercept, double shrinkage, array<double> var_importance, float oob_error_rate>")
/* loaded from: input_file:hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.class */
public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions {
    /** Class-wide logger. */
    private static final Log logger = LogFactory.getLog(GradientTreeBoostingClassifierUDTF.class);
    // Inspector for the first argument (the feature array); assigned in initialize().
    private ListObjectInspector featureListOI;
    // Inspector for individual feature elements (double-compatible or string, see initialize()).
    private PrimitiveObjectInspector featureElemOI;
    // Inspector for the second argument (the int-compatible class label).
    private PrimitiveObjectInspector labelOI;
    // true when feature elements are numeric, false when they are strings.
    private boolean denseInput;
    // Accumulates one row per process() call; released in close() once the matrix is built.
    private MatrixBuilder matrixBuilder;
    // Labels collected by process(), in the same order as the matrix rows.
    private IntArrayList labels;
    // Option "num_trees": number of boosting iterations (default 500).
    private int _numTrees;
    // Option "learning_rate": shrinkage, validated to (0, 1] in checkOptions() (default 0.05).
    private double _eta;
    // Option "subsample": fraction of rows drawn per tree, validated to (0, 1] (default 0.7).
    private double _subsample = 0.7d;
    // Option "num_variables": handed to SmileExtUtils.computeNumInputVars (default -1).
    private float _numVars;
    // Option "max_depth" (default 8).
    private int _maxDepth;
    // Option "max_leaf_nodes" (default Integer.MAX_VALUE).
    private int _maxLeafNodes;
    // Option "min_split" (default 5).
    private int _minSamplesSplit;
    // Option "min_samples_leaf" (default 1).
    private int _minSamplesLeaf;
    // Option "seed"; -1 makes the trainers draw a seed via SmileExtUtils.generateSeed().
    private long _seed;
    // Attribute types resolved from option "attribute_types"; completed in train() from the data.
    private Attribute[] _attributes;

    // Reporter obtained in close(); null when getReporter() returns none.
    @Nullable
    private Reporter _progressReporter;

    // Counter incremented once per forwarded iteration; null when no reporter is available.
    @Nullable
    private Counters.Counter _iterationCounter;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hivemall/smile/classification/GradientTreeBoostingClassifierUDTF$L2NodeOutput.class */
    /**
     * Leaf-value calculator used by the two-class trainer.
     *
     * <p>For the rows selected by the sample array it returns
     * {@code sum(y_i) / sum(|y_i| * (2 - |y_i|))}, where {@code y} is the shared response array
     * handed to the constructor (the array is referenced, not copied).
     */
    public static final class L2NodeOutput implements RegressionTree.NodeOutput {
        final double[] y;

        /**
         * @param dArr response array, read on every {@link #calculate(int[])} call
         */
        public L2NodeOutput(double[] dArr) {
            this.y = dArr;
        }

        /**
         * Computes the leaf output over the rows whose entry in {@code iArr} is positive.
         *
         * @param iArr per-row markers; only entries greater than zero contribute
         *        (presumably sample counts -- confirm against RegressionTree)
         * @return ratio of the summed responses to the summed {@code |y| * (2 - |y|)} terms;
         *         the division is unguarded, so an all-zero denominator yields NaN/Infinity
         */
        @Override // hivemall.smile.regression.RegressionTree.NodeOutput
        public double calculate(int[] iArr) {
            double numerator = 0.0d;
            double denominator = 0.0d;
            int row = 0;
            for (final int marker : iArr) {
                if (marker > 0) {
                    final double residual = this.y[row];
                    final double magnitude = Math.abs(residual);
                    numerator += residual;
                    denominator += magnitude * (2.0d - magnitude);
                }
                row++;
            }
            return numerator / denominator;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hivemall/smile/classification/GradientTreeBoostingClassifierUDTF$LKNodeOutput.class */
    /**
     * Leaf-value calculator used by the multi-class (k &gt; 2) trainer.
     *
     * <p>Holds a reference to the response array of one class and the number of classes.
     */
    public static final class LKNodeOutput implements RegressionTree.NodeOutput {
        final double[] y;
        final double k;

        /**
         * @param dArr response array for one class (referenced, not copied)
         * @param i number of classes, stored as a double
         */
        public LKNodeOutput(double[] dArr, int i) {
            this.y = dArr;
            this.k = i;
        }

        /**
         * Computes the leaf output over the rows whose entry in {@code iArr} is positive.
         *
         * @param iArr per-row markers; only entries greater than zero contribute
         * @return {@code ((k-1)/k) * sum(y) / sum(|y| * (1 - |y|))}, or the plain mean of the
         *         selected responses when that denominator is below {@code 1e-10}
         */
        @Override // hivemall.smile.regression.RegressionTree.NodeOutput
        public double calculate(int[] iArr) {
            int selected = 0;
            double numerator = 0.0d;
            double denominator = 0.0d;
            int row = 0;
            for (final int marker : iArr) {
                if (marker > 0) {
                    selected++;
                    final double residual = this.y[row];
                    final double magnitude = Math.abs(residual);
                    numerator += residual;
                    denominator += magnitude * (1.0d - magnitude);
                }
                row++;
            }
            if (denominator < 1.0E-10d) {
                // Denominator is numerically zero: fall back to the mean response.
                return numerator / selected;
            }
            return ((this.k - 1.0d) / this.k) * (numerator / denominator);
        }
    }

    /**
     * Declares the command-line style options accepted in the optional third argument.
     *
     * @return the option set; every option takes a value
     */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        final Options opts = new Options();

        // Boosting procedure
        opts.addOption("trees", "num_trees", true, "The number of trees for each task [default: 500]");
        opts.addOption("eta", "learning_rate", true, "The learning rate (0, 1]  of procedure [default: 0.05]");
        opts.addOption("subsample", "sampling_frac", true, "The fraction of samples to be used for fitting the individual base learners [default: 0.7]");

        // Shape of the individual trees
        opts.addOption("vars", "num_variables", true, "The number of random selected features [default: ceil(sqrt(x[0].length))]. int(num_variables * x[0].length) is considered if num_variable is (0,1]");
        opts.addOption("depth", "max_depth", true, "The maximum number of the tree depth [default: 8]");
        opts.addOption("leafs", "max_leaf_nodes", true, "The maximum number of leaf nodes [default: Integer.MAX_VALUE]");
        opts.addOption("splits", "min_split", true, "A node that has greater than or equals to `min_split` examples will split [default: 5]");
        opts.addOption("min_samples_leaf", true, "The minimum number of samples in a leaf node [default: 1]");

        // Reproducibility and input description
        opts.addOption("seed", true, "seed value in long [default: -1 (random)]");
        opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types (Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");

        return opts;
    }

    /**
     * Parses the optional third argument and stores the resulting hyper-parameters in the
     * instance fields. When fewer than three arguments are given, the defaults are stored.
     *
     * <p>Fields are only assigned after every option has been parsed, so a parse failure leaves
     * the instance untouched.
     *
     * @param objectInspectorArr inspectors of the UDTF arguments; index 2, if present, must be a
     *        constant string
     * @return the parsed command line, or {@code null} when no option string was supplied
     * @throws UDFArgumentException declared by the overridden method
     * @throws IllegalArgumentException if {@code num_trees} is smaller than 1
     */
    @Override // hivemall.UDTFWithOptions
    protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        int numTrees = 500;
        double eta = 0.05d;
        double subsample = 0.7d;
        float numVars = -1.0f;
        int maxDepth = 8;
        int maxLeafNodes = Integer.MAX_VALUE;
        int minSamplesSplit = 5;
        int minSamplesLeaf = 1;
        long seed = -1L;
        Attribute[] attrs = null;

        CommandLine cl = null;
        if (objectInspectorArr.length >= 3) {
            final String rawOptions = HiveUtils.getConstString(objectInspectorArr[2]);
            cl = parseOptions(rawOptions);

            numTrees = Primitives.parseInt(cl.getOptionValue("num_trees"), 500);
            if (numTrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + numTrees);
            }
            eta = Primitives.parseDouble(cl.getOptionValue("learning_rate"), 0.05d);
            subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), 0.7d);
            numVars = Primitives.parseFloat(cl.getOptionValue("num_variables"), -1.0f);
            maxDepth = Primitives.parseInt(cl.getOptionValue("max_depth"), 8);
            maxLeafNodes = Primitives.parseInt(cl.getOptionValue("max_leaf_nodes"), Integer.MAX_VALUE);
            minSamplesSplit = Primitives.parseInt(cl.getOptionValue("min_split"), 5);
            minSamplesLeaf = Primitives.parseInt(cl.getOptionValue("min_samples_leaf"), 1);
            seed = Primitives.parseLong(cl.getOptionValue("seed"), -1L);
            attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
        }

        // Publish the (parsed or default) values only after parsing succeeded.
        this._numTrees = numTrees;
        this._eta = eta;
        this._subsample = subsample;
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        this._maxLeafNodes = maxLeafNodes;
        this._minSamplesSplit = minSamplesSplit;
        this._minSamplesLeaf = minSamplesLeaf;
        this._seed = seed;
        this._attributes = attrs;
        return cl;
    }

    /**
     * Validates the argument inspectors, prepares the row accumulators and declares the output
     * schema.
     *
     * <p>Numeric feature arrays are accumulated with a dense row-major builder; string feature
     * arrays with a CSR (sparse) builder. The returned struct has six columns, in this order:
     * {@code iteration} (int), {@code pred_models} (list of string), {@code intercept} (double),
     * {@code shrinkage} (double), {@code var_importance} (list of double) and
     * {@code oob_error_rate} (float).
     *
     * @param objectInspectorArr inspectors for {@code features}, {@code label} and the optional
     *        constant option string
     * @return inspector describing the rows this UDTF forwards
     * @throws UDFArgumentException if the argument count is not 2 or 3, or the feature elements
     *         are neither numeric nor string
     */
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length != 2 && objectInspectorArr.length != 3) {
            throw new UDFArgumentException(getClass().getSimpleName() + " takes 2 or 3 arguments: array<double|string> features, int label [, const string options]: " + objectInspectorArr.length);
        }

        final ListObjectInspector listOI = HiveUtils.asListOI(objectInspectorArr[0]);
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        this.featureListOI = listOI;
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
            this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192);
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
            this.matrixBuilder = new CSRMatrixBuilder(8192);
        } else {
            throw new UDFArgumentException("_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName());
        }
        this.labelOI = HiveUtils.asIntCompatibleOI(objectInspectorArr[1]);

        processOptions(objectInspectorArr);

        this.labels = new IntArrayList(1024);

        // Parameterized lists (previously raw types) so the struct factory call is type-checked.
        final ArrayList<String> fieldNames = new ArrayList<String>(6);
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
        fieldNames.add("iteration");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pred_models");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
        fieldNames.add("intercept");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("shrinkage");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
        fieldNames.add("var_importance");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        fieldNames.add("oob_error_rate");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Buffers one training example: appends the feature array as a matrix row and records its
     * label. No output is produced here; training happens in {@code close()}.
     *
     * @param objArr UDTF arguments; index 0 is the feature array, index 1 the label
     * @throws HiveException if the feature array is {@code null}
     */
    public void process(Object[] objArr) throws HiveException {
        final Object features = objArr[0];
        if (features == null) {
            throw new HiveException("array<double> features was null");
        }
        parseFeatures(features, this.matrixBuilder);

        final int label = PrimitiveObjectInspectorUtils.getInt(objArr[1], this.labelOI);
        this.labels.add(label);
    }

    /**
     * Appends one feature array to the matrix builder as a new row.
     *
     * <p>Null elements are skipped. Dense (numeric) input is written positionally with the
     * element's list index as column; sparse (string) input passes each element's string form to
     * the builder, which is responsible for interpreting it.
     *
     * @param obj the Hive list object holding the features
     * @param matrixBuilder builder receiving the columns; {@code nextRow()} is always called
     */
    private void parseFeatures(@Nonnull Object obj, @Nonnull MatrixBuilder matrixBuilder) {
        final int size = this.featureListOI.getListLength(obj);
        for (int col = 0; col < size; col++) {
            final Object elem = this.featureListOI.getListElement(obj, col);
            if (elem == null) {
                continue;
            }
            if (this.denseInput) {
                final double value = PrimitiveObjectInspectorUtils.getDouble(elem, this.featureElemOI);
                matrixBuilder.nextColumn(col, value);
            } else {
                matrixBuilder.nextColumn(elem.toString());
            }
        }
        matrixBuilder.nextRow();
    }

    /**
     * Runs the training over all buffered rows and then drops the references held by this
     * instance.
     *
     * <p>The reporter and iteration counter are looked up first so that progress can be reported
     * during training. When no row was buffered, nothing is trained or forwarded.
     *
     * @throws HiveException if training or forwarding fails
     */
    public void close() throws HiveException {
        final Reporter reporter = getReporter();
        this._progressReporter = reporter;
        if (reporter == null) {
            this._iterationCounter = null;
        } else {
            this._iterationCounter = reporter.getCounter("hivemall.smile.GradientTreeBoostingClassifier$Counter", "iteration");
        }
        reportProgress(reporter);

        if (!this.labels.isEmpty()) {
            final Matrix x = this.matrixBuilder.buildMatrix();
            this.matrixBuilder = null; // release the builder before training
            final int[] y = this.labels.toArray();
            this.labels = null;
            train(x, y);
        }

        // Clean up references that are no longer needed.
        this.featureListOI = null;
        this.featureElemOI = null;
        this.labelOI = null;
        this._attributes = null;
    }

    /**
     * Validates the hyper-parameters before training starts.
     *
     * <p>The range checks are written in the positive form {@code !(v > 0 && v <= 1)} so that a
     * NaN value (which {@code Double.parseDouble} accepts for the input {@code "NaN"}) is
     * rejected as well; a {@code v <= 0 || v > 1} test is false for NaN and would let it through.
     *
     * @throws HiveException if the learning rate or sampling fraction is outside (0, 1] or NaN,
     *         if {@code min_split} is not positive, or if {@code max_depth} is smaller than 1
     */
    private void checkOptions() throws HiveException {
        if (!(this._eta > 0.0d && this._eta <= 1.0d)) {
            throw new HiveException("Invalid shrinkage: " + this._eta);
        }
        if (!(this._subsample > 0.0d && this._subsample <= 1.0d)) {
            throw new HiveException("Invalid sampling fraction: " + this._subsample);
        }
        if (this._minSamplesSplit <= 0) {
            throw new HiveException("Invalid minSamplesSplit: " + this._minSamplesSplit);
        }
        if (this._maxDepth < 1) {
            throw new HiveException("Invalid maxDepth: " + this._maxDepth);
        }
    }

    /**
     * Entry point of the training: validates the inputs, resolves the attribute types, shuffles
     * the data and dispatches to the two-class or the k-class trainer.
     *
     * <p>The number of classes is taken as {@code max(label) + 1}. For two classes the labels are
     * recoded so that label 1 becomes +1 and every other label becomes -1.
     *
     * @param matrix feature matrix, one row per example
     * @param iArr class labels, one per row of {@code matrix}
     * @throws HiveException if the sizes differ, an option is invalid, or fewer than two classes
     *         are detected
     */
    private void train(@Nonnull Matrix matrix, @Nonnull int[] iArr) throws HiveException {
        final int numInstances = matrix.numRows();
        if (numInstances != iArr.length) {
            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", Integer.valueOf(numInstances), Integer.valueOf(iArr.length)));
        }
        checkOptions();

        this._attributes = SmileExtUtils.attributeTypes(this._attributes, matrix);
        final Matrix shuffled = SmileExtUtils.shuffle(matrix, iArr, this._seed);

        final int numClasses = Math.max(iArr) + 1;
        if (numClasses < 2) {
            throw new UDFArgumentException("Only one class or negative class labels.");
        }

        if (numClasses == 2) {
            // Recode {1, other} as {+1, -1} for the two-class trainer.
            final int[] signedLabels = new int[numInstances];
            for (int row = 0; row < numInstances; row++) {
                signedLabels[row] = (iArr[row] == 1) ? 1 : -1;
            }
            train2(shuffled, signedLabels);
        } else {
            traink(shuffled, iArr, numClasses);
        }
    }

    /**
     * Trains a two-class gradient tree boosting model and forwards one row per iteration.
     *
     * <p>Each iteration draws a bag of {@code round(n * subsample)} rows, fits a regression tree
     * to the current pseudo-responses, updates the running score {@code h} of every row by
     * {@code eta * tree.predict(row)} and estimates the error on the rows left out of the bag.
     *
     * <p>Fixes relative to the previous implementation:
     * <ul>
     * <li>The permutation now covers all rows (as in {@code traink}); it used to cover only
     * {@code 0..numSamples-1}, so every tree was fitted on the same first rows and the
     * out-of-bag set never changed.</li>
     * <li>The out-of-bag prediction is encoded as -1/+1 to match the labels passed by
     * {@code train()}; it used to be 0/1, so every negative row was counted as an error.</li>
     * <li>The out-of-bag error rate uses floating-point division; integer division truncated it
     * to 0 or 1.</li>
     * </ul>
     *
     * @param matrix feature matrix (already shuffled by the caller)
     * @param iArr labels recoded to -1 / +1, one per row
     * @throws HiveException if forwarding a result row fails
     */
    private void train2(@Nonnull Matrix matrix, @Nonnull int[] iArr) throws HiveException {
        final int numVars = SmileExtUtils.computeNumInputVars(this._numVars, matrix);
        if (logger.isInfoEnabled()) {
            logger.info("k: 2, numTrees: " + this._numTrees + ", shrinkage: " + this._eta + ", subsample: " + this._subsample + ", numVars: " + numVars + ", maxDepth: " + this._maxDepth + ", minSamplesSplit: " + this._minSamplesSplit + ", maxLeafs: " + this._maxLeafNodes + ", seed: " + this._seed);
        }

        final int numInstances = matrix.numRows();
        final int numSamples = (int) Math.round(numInstances * this._subsample);

        final double[] h = new double[numInstances]; // current score F(x_i) of every row
        final double[] response = new double[numInstances]; // regression target of the next tree

        final double mu = Math.mean(iArr);
        final double intercept = 0.5d * Math.log((1.0d + mu) / (1.0d - mu));
        Arrays.fill(h, intercept);

        final ColumnMajorIntMatrix order = SmileExtUtils.sort(this._attributes, matrix);
        // The node output shares the response array, so it sees each iteration's targets.
        final L2NodeOutput output = new L2NodeOutput(response);

        final BitSet sampled = new BitSet(numInstances);
        final int[] bag = new int[numSamples];
        // Permutation over ALL rows; the first numSamples entries after a shuffle form the bag.
        final int[] perm = MathUtils.permutation(numInstances);

        final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(this._seed == -1 ? SmileExtUtils.generateSeed() : RandomNumberGeneratorFactory.createPRNG(this._seed).nextLong());
        final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());

        final Vector probe = matrix.rowVector();
        for (int m = 0; m < this._numTrees; m++) {
            reportProgress(this._progressReporter);

            SmileExtUtils.shuffle(perm, rnd1);
            for (int s = 0; s < numSamples; s++) {
                final int row = perm[s];
                bag[s] = row;
                sampled.set(row);
            }

            for (int row = 0; row < numInstances; row++) {
                response[row] = (2.0d * iArr[row]) / (1.0d + Math.exp((2.0d * iArr[row]) * h[row]));
            }

            final RegressionTree tree = new RegressionTree(this._attributes, matrix, response, numVars, this._maxDepth, this._maxLeafNodes, this._minSamplesSplit, this._minSamplesLeaf, order, bag, output, rnd2);

            for (int row = 0; row < numInstances; row++) {
                matrix.getRow(row, probe);
                h[row] += this._eta * tree.predict(probe);
            }

            // Out-of-bag error estimate over the rows that were not drawn for this tree.
            int oobTests = 0;
            int oobErrors = 0;
            for (int row = sampled.nextClearBit(0); row < numInstances; row = sampled.nextClearBit(row + 1)) {
                oobTests++;
                final int predicted = (h[row] > 0.0d) ? 1 : -1; // same -1/+1 encoding as iArr
                if (predicted != iArr[row]) {
                    oobErrors++;
                }
            }
            final float oobErrorRate = (oobTests > 0) ? ((float) oobErrors) / oobTests : 0.0f;

            forward(m + 1, intercept, this._eta, oobErrorRate, tree);
            sampled.clear();
        }
    }

    /**
     * Trains a k-class gradient tree boosting model and forwards one row (holding k trees) per
     * iteration.
     *
     * <p>Each iteration converts the running scores into class probabilities with a
     * max-shifted softmax, then, for every class, fits a regression tree to
     * {@code indicator(y == class) - probability} on a freshly drawn bag and adds
     * {@code eta * tree.predict(row)} to that class's score.
     *
     * <p>Fixes relative to the previous implementation:
     * <ul>
     * <li>The score update added {@code h + eta * prediction} onto {@code h}, doubling the
     * running score on every iteration; it now adds only {@code eta * prediction}.</li>
     * <li>The bag bookkeeping marked the loop position instead of the sampled row index, so the
     * out-of-bag set was always the tail of the data regardless of what was sampled.</li>
     * <li>The predicted class was chosen against a single maximum shared by all rows; it is now
     * the arg-max of the class scores of each individual row.</li>
     * <li>The out-of-bag error rate uses floating-point division instead of integer
     * division.</li>
     * <li>The zero intercept is passed as a plain literal instead of an unrelated optimizer
     * constant.</li>
     * </ul>
     *
     * @param matrix feature matrix (already shuffled by the caller)
     * @param iArr class labels, compared against class indices {@code 0..i-1}
     * @param i number of classes
     * @throws HiveException if forwarding a result row fails
     */
    private void traink(Matrix matrix, int[] iArr, int i) throws HiveException {
        final int k = i;
        final int numVars = SmileExtUtils.computeNumInputVars(this._numVars, matrix);
        if (logger.isInfoEnabled()) {
            logger.info("k: " + k + ", numTrees: " + this._numTrees + ", shrinkage: " + this._eta + ", subsample: " + this._subsample + ", numVars: " + numVars + ", minSamplesSplit: " + this._minSamplesSplit + ", maxDepth: " + this._maxDepth + ", maxLeafs: " + this._maxLeafNodes + ", seed: " + this._seed);
        }

        final int numInstances = matrix.numRows();
        final int numSamples = (int) Math.round(numInstances * this._subsample);

        final double[][] h = new double[k][numInstances]; // running score per class and row
        final double[][] p = new double[k][numInstances]; // class probabilities (softmax of h)
        final double[][] response = new double[k][numInstances]; // regression targets per class

        final ColumnMajorIntMatrix order = SmileExtUtils.sort(this._attributes, matrix);
        final LKNodeOutput[] outputs = new LKNodeOutput[k];
        for (int j = 0; j < k; j++) {
            // Each node output shares the response array of its class.
            outputs[j] = new LKNodeOutput(response[j], k);
        }

        final BitSet sampled = new BitSet(numInstances);
        final int[] bag = new int[numSamples];
        final int[] perm = MathUtils.permutation(numInstances);

        final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(this._seed == -1 ? SmileExtUtils.generateSeed() : RandomNumberGeneratorFactory.createPRNG(this._seed).nextLong());
        final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong());

        final Vector probe = matrix.rowVector();
        for (int m = 0; m < this._numTrees; m++) {
            // Softmax over the class scores of every row; subtracting the row maximum first
            // keeps exp() from overflowing.
            for (int row = 0; row < numInstances; row++) {
                double max = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    final double score = h[j][row];
                    if (max < score) {
                        max = score;
                    }
                }
                double z = 0.0d;
                for (int j = 0; j < k; j++) {
                    final double e = Math.exp(h[j][row] - max);
                    p[j][row] = e;
                    z += e;
                }
                for (int j = 0; j < k; j++) {
                    p[j][row] /= z;
                }
            }

            final RegressionTree[] trees = new RegressionTree[k];
            for (int j = 0; j < k; j++) {
                reportProgress(this._progressReporter);

                final double[] responseJ = response[j];
                final double[] pJ = p[j];
                final double[] hJ = h[j];

                for (int row = 0; row < numInstances; row++) {
                    final double indicator = (iArr[row] == j) ? 1.0d : 0.0d;
                    responseJ[row] = indicator - pJ[row];
                }

                SmileExtUtils.shuffle(perm, rnd1);
                for (int s = 0; s < numSamples; s++) {
                    final int row = perm[s];
                    bag[s] = row;
                    sampled.set(row); // mark the sampled ROW, not the loop position
                }

                final RegressionTree tree = new RegressionTree(this._attributes, matrix, responseJ, numVars, this._maxDepth, this._maxLeafNodes, this._minSamplesSplit, this._minSamplesLeaf, order, bag, outputs[j], rnd2);
                trees[j] = tree;

                for (int row = 0; row < numInstances; row++) {
                    matrix.getRow(row, probe);
                    hJ[row] += this._eta * tree.predict(probe);
                }
            }

            // Out-of-bag error estimate: rows not drawn for any of this iteration's k trees,
            // predicted as the arg-max of their own class scores.
            int oobTests = 0;
            int oobErrors = 0;
            for (int row = sampled.nextClearBit(0); row < numInstances; row = sampled.nextClearBit(row + 1)) {
                oobTests++;
                int predicted = -1;
                double best = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < k; j++) {
                    final double score = h[j][row];
                    if (score > best) {
                        best = score;
                        predicted = j;
                    }
                }
                if (predicted != iArr[row]) {
                    oobErrors++;
                }
            }
            sampled.clear();
            final float oobErrorRate = (oobTests > 0) ? ((float) oobErrors) / oobTests : 0.0f;

            // The k-class model has no intercept term.
            forward(m + 1, 0.0d, this._eta, oobErrorRate, trees);
        }
    }

    /**
     * Emits the result row of one boosting iteration and updates progress/counter/log.
     *
     * <p>The variable importance is the element-wise sum of the importances of all given trees.
     *
     * @param i 1-based iteration number
     * @param d intercept of the model
     * @param d2 shrinkage (learning rate)
     * @param f out-of-bag error rate of this iteration
     * @param regressionTreeArr the tree(s) fitted in this iteration
     * @throws HiveException if serializing a tree or forwarding the row fails
     */
    private void forward(int i, double d, double d2, float f, @Nonnull RegressionTree... regressionTreeArr) throws HiveException {
        final Text[] models = getModel(regressionTreeArr);

        final double[] totalImportance = new double[this._attributes.length];
        for (final RegressionTree tree : regressionTreeArr) {
            final double[] treeImportance = tree.importance();
            for (int col = 0; col < treeImportance.length; col++) {
                totalImportance[col] += treeImportance[col];
            }
        }

        final Object[] row = new Object[6];
        row[0] = new IntWritable(i);
        row[1] = models;
        row[2] = new DoubleWritable(d);
        row[3] = new DoubleWritable(d2);
        row[4] = WritableUtils.toWritableList(totalImportance);
        row[5] = new FloatWritable(f);
        forward(row);

        reportProgress(this._progressReporter);
        incrCounter(this._iterationCounter, 1L);
        logger.info("Forwarded the output of " + i + "-th Boosting iteration out of " + this._numTrees);
    }

    /**
     * Serializes every tree and wraps the Base91-encoded bytes in a {@link Text}.
     *
     * @param regressionTreeArr trees to serialize
     * @return one {@code Text} per tree, in the same order
     * @throws HiveException if serialization fails
     */
    @Nonnull
    private static Text[] getModel(@Nonnull RegressionTree[] regressionTreeArr) throws HiveException {
        final Text[] encoded = new Text[regressionTreeArr.length];
        int slot = 0;
        for (final RegressionTree tree : regressionTreeArr) {
            encoded[slot++] = new Text(Base91.encode(tree.serialize(true)));
        }
        return encoded;
    }
}
