package hivemall.smile.classification;

import hivemall.annotations.VisibleForTesting;
import hivemall.math.matrix.Matrix;
import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
import hivemall.smile.data.Attribute;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.ObjectUtils;
import hivemall.utils.lang.StringUtils;
import hivemall.utils.lang.mutable.MutableInt;
import hivemall.utils.sampling.IntReservoirSampler;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.roaringbitmap.IntConsumer;
import org.roaringbitmap.RoaringBitmap;
import smile.classification.Classifier;
import smile.math.Math;

/* loaded from: input_file:hivemall/smile/classification/DecisionTree.class */
public final class DecisionTree implements Classifier<Vector> {

    /** Per-column attribute types, resolved by SmileExtUtils.attributeTypes from the user-supplied attributes and the training matrix. */
    @Nonnull
    private final Attribute[] _attributes;
    /** True when at least one attribute is NUMERIC; when false, per-row sample counts (bagsToSamples) are not computed. */
    private final boolean _hasNumericType;

    /** Variable importance: each successful split adds its split score to the entry of the split feature. */
    @Nonnull
    private final Vector _importance;

    /** Root of the trained tree. */
    @Nonnull
    private final Node _root;
    /** Maximum depth; a node at this depth (the root is depth 1) is never split further. */
    private final int _maxDepth;

    /** Impurity measure used to score candidate splits. */
    @Nonnull
    private final SplitRule _rule;
    /** Number of classes, i.e. the largest class label plus one. */
    private final int _k;
    /** Number of candidate features sampled (reservoir sampling) at each node. */
    private final int _numVars;
    /** Minimum number of samples on each side for a split to be considered. */
    private final int _minSplit;
    /** Minimum number of samples a child must receive for a split to be kept. */
    private final int _minLeafSize;

    /** Per-column sample order used to scan numeric features; NOTE(review): presumably sorted by value (built by SmileExtUtils.sort when not supplied). */
    @Nonnull
    private final ColumnMajorIntMatrix _order;

    /** Random number generator used for feature sub-sampling. */
    @Nonnull
    private final PRNG _rnd;

    /**
     * A node of a trained decision tree.
     *
     * <p>A node without children is a leaf whose {@link #output} is the predicted class label. An
     * internal node tests feature {@link #splitFeature}: for a {@code NOMINAL} feature a sample
     * follows {@link #trueChild} when its value equals {@link #splitValue}; for a {@code NUMERIC}
     * feature it follows {@link #trueChild} when its value is {@code <= splitValue}. Otherwise it
     * follows {@link #falseChild}. Feature values are read with a {@code NaN} default, and NaN
     * never satisfies either test, so such samples take the false branch.
     *
     * <p>For serialization a node is written as a leaf iff {@link #posteriori} is non-null.
     */
    public static final class Node implements Externalizable {
        /** Predicted class label; {@code -1} until assigned. */
        int output;

        /** Class posterior probabilities of a leaf; set to {@code null} once the node is split. */
        @Nullable
        double[] posteriori;
        /** Index of the feature tested by this node, or {@code -1} when no split is chosen. */
        int splitFeature;
        /** Type of {@link #splitFeature}; selects an equality test (NOMINAL) or a {@code <=} test (NUMERIC). */
        Attribute.AttributeType splitFeatureType;
        /** Category value (NOMINAL) or threshold (NUMERIC) compared against the feature. */
        double splitValue;
        /** Impurity reduction achieved by the chosen split; {@code 0.0} means no split. */
        double splitScore;
        Node trueChild;
        Node falseChild;
        /** Majority class of the samples routed to the true branch by the chosen split. */
        int trueChildOutput;
        /** Majority class of the samples routed to the false branch by the chosen split. */
        int falseChildOutput;

        /** Creates an empty, unsplit node with no output and no posteriors. */
        public Node() {
            this.output = -1;
            this.posteriori = null;
            this.splitFeature = -1;
            this.splitFeatureType = null;
            this.splitValue = Double.NaN;
            this.splitScore = 0.0d; // "no split yet"; any positive gain improves on it
            this.trueChild = null;
            this.falseChild = null;
            this.trueChildOutput = -1;
            this.falseChildOutput = -1;
        }

        /**
         * Creates a leaf node.
         *
         * @param output predicted class label
         * @param posteriori class posterior probabilities
         */
        public Node(int output, @Nonnull double[] posteriori) {
            this();
            this.output = output;
            this.posteriori = posteriori;
        }

        private boolean isLeaf() {
            return this.posteriori != null;
        }

        /** True when this node has no children, i.e. prediction stops here. */
        private boolean isTerminal() {
            return trueChild == null && falseChild == null;
        }

        /**
         * Evaluates this node's split test on {@code vector}.
         *
         * @return true when the sample should follow {@link #trueChild}
         * @throws IllegalStateException if the split feature type is neither NOMINAL nor NUMERIC
         */
        private boolean goesTrue(@Nonnull Vector vector) {
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                return vector.get(splitFeature, Double.NaN) == splitValue;
            }
            if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                return vector.get(splitFeature, Double.NaN) <= splitValue;
            }
            throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
        }

        /** Walks from this node down to the terminal node reached by {@code vector}. */
        @Nonnull
        private Node descend(@Nonnull Vector vector) {
            // Iterative rather than recursive so very deep trees cannot overflow the stack.
            Node current = this;
            while (!current.isTerminal()) {
                current = current.goesTrue(vector) ? current.trueChild : current.falseChild;
            }
            return current;
        }

        @VisibleForTesting
        public int predict(@Nonnull double[] dArr) {
            return predict(new DenseVector(dArr));
        }

        /** Returns the class label of the terminal node reached by {@code vector}. */
        public int predict(@Nonnull Vector vector) {
            return descend(vector).output;
        }

        /** Passes the output and posteriors of the terminal node reached by {@code vector} to {@code handler}. */
        public void predict(@Nonnull Vector vector, @Nonnull PredictionHandler handler) {
            final Node terminal = descend(vector);
            handler.handle(terminal.output, terminal.posteriori);
        }

        /**
         * Appends this subtree as nested JavaScript {@code if/else} statements.
         *
         * @param sb output buffer
         * @param featureNames optional feature names; when null features are rendered as {@code x[i]}
         * @param classNames optional names used to render leaf outputs
         * @param depth indentation level
         */
        public void exportJavascript(@Nonnull StringBuilder sb, @Nullable String[] featureNames,
                @Nullable String[] classNames, int depth) {
            DecisionTree.indent(sb, depth);
            if (isTerminal()) {
                sb.append(SmileExtUtils.resolveName(output, classNames)).append(";\n");
                return;
            }
            final String operator;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                operator = " == ";
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                operator = " <= ";
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            sb.append("if( ");
            if (featureNames == null) {
                sb.append("x[").append(splitFeature).append(']');
            } else {
                sb.append(SmileExtUtils.resolveFeatureName(splitFeature, featureNames));
            }
            sb.append(operator).append(splitValue).append(" ) {\n");
            trueChild.exportJavascript(sb, featureNames, classNames, depth + 1);
            DecisionTree.indent(sb, depth);
            sb.append("} else  {\n");
            falseChild.exportJavascript(sb, featureNames, classNames, depth + 1);
            DecisionTree.indent(sb, depth);
            sb.append("}\n");
        }

        /**
         * Appends this subtree as Graphviz node and edge statements.
         *
         * <p>Numbers are formatted with {@link Locale#ROOT} so that the output (in particular the
         * HSV fill colour) stays valid Graphviz syntax regardless of the JVM default locale.
         *
         * @param sb output buffer
         * @param featureNames optional feature names
         * @param classNames optional class names for leaf labels
         * @param outputName label prefix used for leaves
         * @param colorBrew optional per-class hue values; leaves without a valid entry are transparent
         * @param nodeIdGenerator counter whose current value is this node's id; advanced for each child
         * @param parentNodeId id of the parent node (equal to this node's id for the root)
         */
        public void exportGraphviz(@Nonnull StringBuilder sb, @Nullable String[] featureNames,
                @Nullable String[] classNames, @Nonnull String outputName,
                @Nullable double[] colorBrew, @Nonnull MutableInt nodeIdGenerator,
                int parentNodeId) {
            final int myNodeId = nodeIdGenerator.getValue();
            if (isTerminal()) {
                final String color =
                        (colorBrew == null || output < 0 || output >= colorBrew.length)
                                ? "#00000000"
                                : String.format(Locale.ROOT, "%.4f,1.000,1.000", colorBrew[output]);
                sb.append(String.format(Locale.ROOT,
                    " %d [label=<%s = %s>, fillcolor=\"%s\", shape=ellipse];\n", myNodeId,
                    outputName, SmileExtUtils.resolveName(output, classNames), color));
                appendGraphvizEdge(sb, parentNodeId, myNodeId);
                return;
            }
            final String template;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                template = " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n";
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                template = " %d [label=<%s &le; %s>, fillcolor=\"#00000000\"];\n";
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            sb.append(String.format(Locale.ROOT, template, myNodeId,
                SmileExtUtils.resolveFeatureName(splitFeature, featureNames),
                Double.toString(splitValue)));
            appendGraphvizEdge(sb, parentNodeId, myNodeId);

            nodeIdGenerator.addValue(1);
            trueChild.exportGraphviz(sb, featureNames, classNames, outputName, colorBrew,
                nodeIdGenerator, myNodeId);
            nodeIdGenerator.addValue(1);
            falseChild.exportGraphviz(sb, featureNames, classNames, outputName, colorBrew,
                nodeIdGenerator, myNodeId);
        }

        /**
         * Appends the edge {@code parentId -> nodeId}; nothing for the root (ids equal). Edges
         * leaving node 0 are labelled True (to node 1) or False (otherwise).
         */
        private static void appendGraphvizEdge(@Nonnull StringBuilder sb, int parentId,
                int nodeId) {
            if (nodeId == parentId) {
                return;
            }
            sb.append(' ').append(parentId).append(" -> ").append(nodeId);
            if (parentId == 0) {
                if (nodeId == 1) {
                    sb.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]");
                } else {
                    sb.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]");
                }
            }
            sb.append(";\n");
        }

        /**
         * Emits a stack-machine program for this subtree into {@code list}.
         *
         * @param list program being built; this subtree's instructions are appended
         * @param i index in {@code list} where this subtree's first instruction lands
         * @return number of instructions emitted for this subtree
         */
        @Deprecated
        public int opCodegen(@Nonnull List<String> list, int i) {
            if (isTerminal()) {
                list.add("push " + output);
                list.add("goto last");
                return 2;
            }
            final String jump;
            if (splitFeatureType == Attribute.AttributeType.NOMINAL) {
                jump = "ifeq ";
            } else if (splitFeatureType == Attribute.AttributeType.NUMERIC) {
                jump = "ifle ";
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType);
            }
            list.add("push x[" + splitFeature + "]");
            list.add("push " + splitValue);
            list.add(jump); // jump target patched once the true branch length is known
            final int trueStart = i + 3;
            final int trueLength = trueChild.opCodegen(list, trueStart);
            list.set(trueStart - 1, jump + (trueStart + trueLength));
            final int falseLength = falseChild.opCodegen(list, trueStart + trueLength);
            return 3 + trueLength + falseLength;
        }

        /**
         * Serializes this subtree: split feature, type id (-1 for none), split value, then either
         * a leaf record (output and posteriors) or presence flags followed by the children.
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeInt(splitFeature);
            if (splitFeatureType == null) {
                out.writeByte(-1);
            } else {
                out.writeByte(splitFeatureType.getTypeId());
            }
            out.writeDouble(splitValue);
            if (isLeaf()) {
                out.writeBoolean(true);
                out.writeInt(output);
                out.writeInt(posteriori.length);
                for (double p : posteriori) {
                    out.writeDouble(p);
                }
                return;
            }
            out.writeBoolean(false);
            writeChild(out, trueChild);
            writeChild(out, falseChild);
        }

        private static void writeChild(@Nonnull ObjectOutput out, @Nullable Node child)
                throws IOException {
            if (child == null) {
                out.writeBoolean(false);
            } else {
                out.writeBoolean(true);
                child.writeExternal(out);
            }
        }

        /** Restores a subtree written by {@link #writeExternal}. */
        @Override
        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            this.splitFeature = in.readInt();
            final byte typeId = in.readByte();
            if (typeId == -1) {
                this.splitFeatureType = null;
            } else {
                this.splitFeatureType = Attribute.AttributeType.resolve(typeId);
            }
            this.splitValue = in.readDouble();
            if (in.readBoolean()) {
                this.output = in.readInt();
                final int size = in.readInt();
                final double[] probs = new double[size];
                for (int c = 0; c < size; c++) {
                    probs[c] = in.readDouble();
                }
                this.posteriori = probs;
            } else {
                this.trueChild = readChild(in);
                this.falseChild = readChild(in);
            }
        }

        @Nullable
        private static Node readChild(@Nonnull ObjectInput in)
                throws IOException, ClassNotFoundException {
            if (!in.readBoolean()) {
                return null;
            }
            final Node child = new Node();
            child.readExternal(in);
            return child;
        }
    }

    /** Impurity measure used to score candidate splits of a node. */
    public enum SplitRule {
        /** Gini impurity, {@code 1 - sum(p^2)}. */
        GINI,
        /** Information entropy, {@code -sum(p * log2(p))}. */
        ENTROPY,
        /** Classification error, {@code 1 - max(p)}. */
        CLASSIFICATION_ERROR;
    }

    /**
     * Training-time state for one {@link Node}: the shared training data plus the row indices
     * ({@link #bags}) of the training samples that reached the node. Ordering is by descending
     * split score, so a {@link PriorityQueue} yields the most promising node first.
     */
    private final class TrainNode implements Comparable<TrainNode> {
        /** The tree node being trained. */
        final Node node;
        /** Training feature matrix shared by all nodes. */
        final Matrix x;
        /** Class labels of all training rows, shared by all nodes. */
        final int[] y;
        /** Row indices reaching this node (NOTE(review): presumably may repeat for bootstrap bags); released after split. */
        int[] bags;
        /** Depth of this node; the root is at depth 1. */
        final int depth;

        public TrainNode(Node node, Matrix x, int[] y, int[] bags, int depth) {
            this.node = node;
            this.x = x;
            this.y = y;
            this.bags = bags;
            this.depth = depth;
        }

        /** Higher split score sorts first. */
        @Override
        public int compareTo(TrainNode other) {
            return Double.compare(other.node.splitScore, this.node.splitScore);
        }

        /**
         * Searches the sampled candidate features for the split with the largest impurity
         * reduction and records it on {@link #node}.
         *
         * @return true if a split was found
         */
        public boolean findBestSplit() {
            if (depth >= _maxDepth) {
                return false;
            }
            final int numSamples = bags.length;
            if (numSamples <= _minSplit) {
                return false;
            }
            final int[] count = new int[_k];
            if (sampleCount(count)) {
                return false; // pure node: nothing to gain
            }
            final double parentImpurity = DecisionTree.impurity(count, numSamples, _rule);
            // Per-row sample weights are only needed for the numeric scan.
            final int[] samples =
                    _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows()) : null;
            final int[] falseCount = new int[_k];
            for (int j : variableIndex(x, bags)) {
                final Node candidate =
                        findBestSplit(numSamples, count, falseCount, parentImpurity, j, samples);
                if (candidate.splitScore > node.splitScore) {
                    node.splitFeature = candidate.splitFeature;
                    node.splitFeatureType = candidate.splitFeatureType;
                    node.splitValue = candidate.splitValue;
                    node.splitScore = candidate.splitScore;
                    node.trueChildOutput = candidate.trueChildOutput;
                    node.falseChildOutput = candidate.falseChildOutput;
                }
            }
            return node.splitFeature != -1;
        }

        /**
         * Draws up to {@code _numVars} candidate feature indices. For sparse input only columns
         * present in the bagged rows are candidates; for dense input all attributes are.
         */
        @Nonnull
        private int[] variableIndex(@Nonnull Matrix x, @Nonnull int[] bags) {
            final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong());
            if (x.isSparse()) {
                final RoaringBitmap columns = new RoaringBitmap();
                final VectorProcedure collector = new VectorProcedure() {
                    @Override
                    public void apply(int col) {
                        columns.add(col);
                    }
                };
                for (int row : bags) {
                    x.eachColumnIndexInRow(row, collector);
                }
                columns.forEach(new IntConsumer() {
                    @Override
                    public void accept(int col) {
                        sampler.add(col);
                    }
                });
            } else {
                for (int col = 0, numColumns = _attributes.length; col < numColumns; col++) {
                    sampler.add(col);
                }
            }
            return sampler.getSample();
        }

        /**
         * Accumulates per-class counts of the bagged samples into {@code count}.
         *
         * @return true if all bagged samples share a single label
         */
        private boolean sampleCount(@Nonnull int[] count) {
            int firstLabel = -1;
            boolean pure = true;
            for (int row : bags) {
                final int label = y[row];
                count[label]++;
                if (firstLabel == -1) {
                    firstLabel = label;
                } else if (label != firstLabel) {
                    pure = false;
                }
            }
            return pure;
        }

        /** Impurity reduction of splitting {@code n} samples into true/false partitions. */
        private double splitGain(double parentImpurity, int n, @Nonnull int[] trueCount, int tc,
                @Nonnull int[] falseCount, int fc) {
            // Cast before dividing: int/int would truncate the weights to 0.
            return parentImpurity
                    - ((double) tc / n) * DecisionTree.impurity(trueCount, tc, _rule)
                    - ((double) fc / n) * DecisionTree.impurity(falseCount, fc, _rule);
        }

        /**
         * Finds the best split on feature {@code j}.
         *
         * @param n number of samples at this node
         * @param count per-class counts at this node
         * @param falseCount scratch buffer for false-branch counts
         * @param parentImpurity impurity of this node
         * @param j candidate feature index
         * @param samples per-row sample weights (required for NUMERIC features)
         * @return a node holding the best split for {@code j}, or {@code splitFeature == -1} if none
         */
        private Node findBestSplit(final int n, final int[] count, final int[] falseCount,
                final double parentImpurity, final int j, @Nullable final int[] samples) {
            final Node best = new Node();
            final Attribute attribute = _attributes[j];
            if (attribute.type == Attribute.AttributeType.NOMINAL) {
                final int numCategories = attribute.getSize();
                final int[][] trueCount = new int[numCategories][_k];
                for (int row : bags) {
                    final double value = x.get(row, j, Double.NaN);
                    if (!Double.isNaN(value)) {
                        trueCount[(int) value][y[row]]++;
                    }
                }
                for (int category = 0; category < numCategories; category++) {
                    final int tc = Math.sum(trueCount[category]);
                    final int fc = n - tc;
                    if (tc < _minSplit || fc < _minSplit) {
                        continue;
                    }
                    for (int c = 0; c < _k; c++) {
                        falseCount[c] = count[c] - trueCount[category][c];
                    }
                    final double gain =
                            splitGain(parentImpurity, n, trueCount[category], tc, falseCount, fc);
                    if (gain > best.splitScore) {
                        best.splitFeature = j;
                        best.splitFeatureType = Attribute.AttributeType.NOMINAL;
                        best.splitValue = category;
                        best.splitScore = gain;
                        best.trueChildOutput = Math.whichMax(trueCount[category]);
                        best.falseChildOutput = Math.whichMax(falseCount);
                    }
                }
            } else if (attribute.type == Attribute.AttributeType.NUMERIC) {
                final int[] trueCount = new int[_k];
                // Visit samples in the order stored in _order for column j and consider a cut
                // between consecutive distinct values whose labels differ.
                _order.eachNonNullInColumn(j, new VectorProcedure() {
                    double prevx = Double.NaN;
                    int prevy = -1;

                    @Override
                    public void apply(int ignored, int row) {
                        final int weight = samples[row];
                        if (weight == 0) {
                            return;
                        }
                        final double value = TrainNode.this.x.get(row, j, Double.NaN);
                        if (Double.isNaN(value)) {
                            return;
                        }
                        final int label = TrainNode.this.y[row];
                        if (!Double.isNaN(prevx) && value != prevx && label != prevy) {
                            evaluateCut(value);
                        }
                        prevx = value;
                        prevy = label;
                        trueCount[label] += weight;
                    }

                    /** Scores the cut between {@code prevx} and {@code value}. */
                    private void evaluateCut(double value) {
                        final int tc = Math.sum(trueCount);
                        final int fc = n - tc;
                        if (tc < _minSplit || fc < _minSplit) {
                            return;
                        }
                        for (int c = 0; c < _k; c++) {
                            falseCount[c] = count[c] - trueCount[c];
                        }
                        final double gain =
                                splitGain(parentImpurity, n, trueCount, tc, falseCount, fc);
                        if (gain > best.splitScore) {
                            best.splitFeature = j;
                            best.splitFeatureType = Attribute.AttributeType.NUMERIC;
                            best.splitValue = (value + prevx) / 2.0d;
                            best.splitScore = gain;
                            best.trueChildOutput = Math.whichMax(trueCount);
                            best.falseChildOutput = Math.whichMax(falseCount);
                        }
                    }
                });
            } else {
                throw new IllegalStateException("Unsupported attribute type: " + attribute.type);
            }
            return best;
        }

        /**
         * Splits this node on the chosen feature and creates its two children. Children that can
         * be split further are queued in {@code nextSplits}, or split immediately (depth-first)
         * when it is null.
         *
         * @return false if a child would be smaller than {@code _minLeafSize} (the split is undone)
         */
        public boolean split(@Nullable PriorityQueue<TrainNode> nextSplits) {
            if (node.splitFeature < 0) {
                throw new IllegalStateException("Split a node with invalid feature.");
            }
            final int capacity = (int) (bags.length * 0.4d);
            final IntArrayList trueBags = new IntArrayList(capacity);
            final IntArrayList falseBags = new IntArrayList(capacity);
            final double[] truePosteriori = new double[_k];
            final double[] falsePosteriori = new double[_k];
            final int tc = splitSamples(trueBags, falseBags, truePosteriori, falsePosteriori);
            final int fc = bags.length - tc;
            bags = null; // no longer needed; release memory

            if (tc < _minLeafSize || fc < _minLeafSize) {
                node.splitFeature = -1;
                node.splitFeatureType = null;
                node.splitValue = Double.NaN;
                node.splitScore = 0.0d;
                return false;
            }
            for (int c = 0; c < _k; c++) {
                truePosteriori[c] /= tc;
                falsePosteriori[c] /= fc;
            }

            node.trueChild = new Node(node.trueChildOutput, truePosteriori);
            scheduleSplit(new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1), tc,
                nextSplits);
            node.falseChild = new Node(node.falseChildOutput, falsePosteriori);
            scheduleSplit(new TrainNode(node.falseChild, x, y, falseBags.toArray(), depth + 1), fc,
                nextSplits);

            _importance.incr(node.splitFeature, node.splitScore);
            node.posteriori = null; // marks this node as internal
            return true;
        }

        /** Queues {@code child} (or splits it right away) if it has a valid split. */
        private void scheduleSplit(@Nonnull TrainNode child, int childSize,
                @Nullable PriorityQueue<TrainNode> nextSplits) {
            if (childSize >= _minSplit && child.findBestSplit()) {
                if (nextSplits != null) {
                    nextSplits.add(child);
                } else {
                    child.split(null);
                }
            }
        }

        /**
         * Routes every bagged sample to the true or false side of this node's split, collecting
         * row indices and per-class counts for each side.
         *
         * @return number of samples routed to the true side
         */
        private int splitSamples(@Nonnull IntArrayList trueBags, @Nonnull IntArrayList falseBags,
                @Nonnull double[] trueCount, @Nonnull double[] falseCount) {
            final boolean nominal;
            if (node.splitFeatureType == Attribute.AttributeType.NOMINAL) {
                nominal = true;
            } else if (node.splitFeatureType == Attribute.AttributeType.NUMERIC) {
                nominal = false;
            } else {
                throw new IllegalStateException(
                    "Unsupported attribute type: " + node.splitFeatureType);
            }
            final int j = node.splitFeature;
            final double threshold = node.splitValue;
            int trueSize = 0;
            for (int row : bags) {
                final double value = x.get(row, j, Double.NaN);
                final boolean goesTrue = nominal ? value == threshold : value <= threshold;
                if (goesTrue) {
                    trueBags.add(row);
                    trueCount[y[row]] += 1.0d;
                    trueSize++;
                } else {
                    falseBags.add(row);
                    falseCount[y[row]] += 1.0d;
                }
            }
            return trueSize;
        }
    }

    /**
     * Appends {@code i} levels of two-space indentation to {@code sb}; nothing when
     * {@code i <= 0}.
     */
    public static void indent(StringBuilder sb, int i) {
        for (int remaining = i; remaining > 0; remaining--) {
            sb.append("  ");
        }
    }

    /**
     * Computes the impurity of a node from its per-class sample counts.
     *
     * <p>With {@code p = count[c] / n} over classes with a non-zero count:
     * GINI is {@code 1 - sum(p^2)}, ENTROPY is {@code -sum(p * log2(p))} and
     * CLASSIFICATION_ERROR is {@code |1 - max(p)|}.
     *
     * @param count per-class sample counts
     * @param n total number of samples the proportions are taken over
     * @param rule impurity measure
     * @return the impurity; {@code 0.0} for a rule not handled here
     */
    public static double impurity(@Nonnull int[] count, int n, @Nonnull SplitRule rule) {
        double impurity = 0.0d;
        switch (rule) {
            case GINI: {
                impurity = 1.0d;
                for (int c : count) {
                    if (c > 0) {
                        // Cast first: int/int would truncate every proportion to 0 or 1.
                        final double p = (double) c / n;
                        impurity -= p * p;
                    }
                }
                break;
            }
            case ENTROPY: {
                for (int c : count) {
                    if (c > 0) {
                        final double p = (double) c / n;
                        impurity -= p * Math.log2(p);
                    }
                }
                break;
            }
            case CLASSIFICATION_ERROR: {
                double maxP = 0.0d;
                for (int c : count) {
                    if (c > 0) {
                        maxP = Math.max(maxP, (double) c / n);
                    }
                }
                impurity = Math.abs(1.0d - maxP);
                break;
            }
        }
        return impurity;
    }

    /**
     * Builds a Gini-based tree that considers every column at each node, with unlimited depth,
     * {@code minSplits = 2}, {@code minLeafSize = 1} and a default random number generator.
     *
     * @param attributes per-column attribute types, or null
     * @param x training samples
     * @param y class labels, one per row of {@code x}
     * @param maxLeafs maximum number of leaf nodes
     */
    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
            int maxLeafs) {
        this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 2, 1, null, null,
            SplitRule.GINI, null);
    }

    /**
     * Builds a Gini-based tree that considers every column at each node, with unlimited depth,
     * {@code minSplits = 2} and {@code minLeafSize = 1}.
     *
     * @param attributes per-column attribute types, or null
     * @param x training samples; must not be null (its column count is read here)
     * @param y class labels, one per row of {@code x}; must not be null
     * @param maxLeafs maximum number of leaf nodes
     * @param rand random number generator, or null to create a default one
     */
    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
            int maxLeafs, @Nullable PRNG rand) {
        this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 2, 1, null, null,
            SplitRule.GINI, rand);
    }

    /**
     * Trains a classification tree.
     *
     * @param attributes per-column attribute types, or null; resolved against {@code x} by
     *        SmileExtUtils.attributeTypes
     * @param x training samples
     * @param y class labels, one per row of {@code x}
     * @param numVars number of candidate features sampled at each node
     * @param maxDepth maximum tree depth (the root is depth 1)
     * @param maxLeafs maximum number of leaf nodes; {@link Integer#MAX_VALUE} grows the tree
     *        depth-first without a leaf budget, otherwise best-first by split score
     * @param minSplits minimum number of samples on each side for a split to be considered
     * @param minLeafSize minimum number of samples a child must receive
     * @param bags row indices to train on, or null to use every row once
     * @param order precomputed per-column sample order, or null to compute it
     * @param rule impurity measure
     * @param rand random number generator, or null to create a default one
     * @throws IllegalArgumentException on invalid arguments, fewer than two classes, or when the
     *         resolved attributes do not match the number of columns
     */
    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y,
            int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize,
            @Nullable int[] bags, @Nullable ColumnMajorIntMatrix order, @Nonnull SplitRule rule,
            @Nullable PRNG rand) {
        checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize);

        this._k = Math.max(y) + 1;
        if (this._k < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }
        this._attributes = SmileExtUtils.attributeTypes(attributes, x);
        // Validate the resolved attributes: the user-supplied array may legitimately be null.
        if (this._attributes.length != x.numColumns()) {
            throw new IllegalArgumentException(
                "-attrs option is invalid: " + Arrays.toString(attributes));
        }
        this._hasNumericType = SmileExtUtils.containsNumericType(this._attributes);
        this._numVars = numVars;
        this._maxDepth = maxDepth;
        this._minSplit = minSplits;
        this._minLeafSize = minLeafSize;
        this._rule = rule;
        this._order = (order == null) ? SmileExtUtils.sort(this._attributes, x) : order;
        this._importance =
                x.isSparse() ? new SparseVector() : new DenseVector(this._attributes.length);
        this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand;

        final int[] count = new int[this._k];
        if (bags == null) {
            final int numRows = y.length;
            bags = new int[numRows];
            for (int row = 0; row < numRows; row++) {
                bags[row] = row;
                count[y[row]]++;
            }
        } else {
            for (int row : bags) {
                count[y[row]]++;
            }
        }

        // Normalise by the number of bagged samples (the sum of count), matching how child
        // posteriors are normalised in TrainNode.split; cast to avoid integer division.
        final int numSamples = bags.length;
        final double[] posteriori = new double[this._k];
        if (numSamples > 0) {
            for (int c = 0; c < this._k; c++) {
                posteriori[c] = (double) count[c] / numSamples;
            }
        }
        this._root = new Node(Math.whichMax(count), posteriori);

        final TrainNode rootTrainer = new TrainNode(this._root, x, y, bags, 1);
        if (maxLeafs == Integer.MAX_VALUE) {
            // No leaf budget: split recursively (depth-first).
            if (rootTrainer.findBestSplit()) {
                rootTrainer.split(null);
            }
            return;
        }
        // Leaf budget: always split the pending node with the highest split score next.
        final PriorityQueue<TrainNode> nextSplits = new PriorityQueue<>();
        if (rootTrainer.findBestSplit()) {
            nextSplits.add(rootTrainer);
        }
        for (int leaves = 1; leaves < maxLeafs; leaves++) {
            final TrainNode next = nextSplits.poll();
            if (next == null) {
                break;
            }
            next.split(nextSplits);
        }
    }

    /**
     * Gives tests access to the root node built by the constructor.
     *
     * @return the root node of this tree
     */
    @VisibleForTesting
    Node getRootNode() {
        return _root;
    }

    /**
     * Validates the constructor arguments and throws on the first violation, checked in this order:
     * row count vs. label count, split-candidate count, depth, leaf budget, split size, leaf size.
     *
     * @param x training instances
     * @param y class labels, one per row of {@code x}
     * @param numVars split candidates per node; must be in {@code [1, x.numColumns()]}
     * @param maxDepth maximum tree depth; must be at least 2
     * @param maxLeafs maximum number of leaves; must be at least 2
     * @param minSplit minimum samples to split a node; must be at least 2
     * @param minLeafSize minimum samples per leaf; must be at least 1
     * @throws IllegalArgumentException describing the first invalid argument
     */
    private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafs, int minSplit, int minLeafSize) {
        if (x.numRows() != y.length) {
            final String msg = String.format("The sizes of X and Y don't match: %d != %d", x.numRows(), y.length);
            throw new IllegalArgumentException(msg);
        }
        if (numVars < 1 || numVars > x.numColumns()) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + numVars);
        }
        if (maxDepth <= 1) {
            throw new IllegalArgumentException("maxDepth should be greater than 1: " + maxDepth);
        }
        if (maxLeafs <= 1) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + maxLeafs);
        }
        if (minSplit <= 1) {
            throw new IllegalArgumentException("Invalid minimum number of samples required to split an internal node: " + minSplit);
        }
        if (minLeafSize <= 0) {
            throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + minLeafSize);
        }
    }

    /**
     * Returns the variable-importance vector of this tree.
     *
     * <p>The vector is allocated in the constructor: sparse when the training matrix is sparse,
     * dense otherwise. The internal instance is returned, not a copy.
     *
     * @return the importance vector
     */
    @Nonnull
    public Vector importance() {
        return _importance;
    }

    /**
     * Classifies a dense feature array by wrapping it in a {@link DenseVector}.
     *
     * @param dArr feature values of one instance
     * @return the predicted class index
     */
    @VisibleForTesting
    public int predict(@Nonnull double[] dArr) {
        final Vector features = new DenseVector(dArr);
        return predict(features);
    }

    /**
     * Classifies one instance by delegating to the root node.
     *
     * @param vector feature vector of the instance
     * @return the class index returned by the root node
     */
    @Override // smile.classification.Classifier
    public int predict(@Nonnull Vector vector) {
        final Node root = _root;
        return root.predict(vector);
    }

    /**
     * Posterior-returning prediction is not implemented by this tree.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // smile.classification.Classifier
    public int predict(Vector vector, double[] dArr) {
        final String reason = "Not supported.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Renders the tree as JavaScript source via {@code Node.exportJavascript}.
     *
     * <p>Both arguments may be {@code null}: {@link #toString()} in this class calls
     * {@code predictJsCodegen(null, null)}.
     *
     * @param strArr names passed through to {@code exportJavascript} (presumably feature names — confirm), or {@code null}
     * @param strArr2 names passed through to {@code exportJavascript} (presumably class names — confirm), or {@code null}
     * @return the generated script
     */
    @Nonnull
    public String predictJsCodegen(@Nullable String[] strArr, @Nullable String[] strArr2) {
        final StringBuilder script = new StringBuilder(1024);
        this._root.exportJavascript(script, strArr, strArr2, 0);
        return script.toString();
    }

    /**
     * Emits the tree as a list of opcodes, terminated by {@code "call end"} and joined with
     * {@code str}.
     *
     * @param str separator placed between opcodes
     * @return the joined opcode script
     * @deprecated kept for existing callers of the opcode-based predictor
     */
    @Nonnull
    @Deprecated
    public String predictOpCodegen(@Nonnull String str) {
        // Typed list instead of the raw ArrayList to avoid unchecked operations.
        final List<String> opcodes = new ArrayList<>();
        this._root.opCodegen(opcodes, 0);
        opcodes.add("call end");
        return StringUtils.concat(opcodes, str);
    }

    /**
     * Serializes the root node (and hence the tree it references) to bytes.
     *
     * @param z whether to produce compressed bytes via {@code ObjectUtils.toCompressedBytes}
     * @return the serialized form of the root node
     * @throws HiveException wrapping any failure during serialization
     */
    @Nonnull
    public byte[] serialize(boolean z) throws HiveException {
        try {
            final Externalizable root = (Externalizable) this._root;
            if (z) {
                return ObjectUtils.toCompressedBytes(root);
            }
            return ObjectUtils.toBytes(root);
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while serializing DecisionTree object", ioe);
        } catch (Exception ex) {
            throw new HiveException("Exception cause while serializing DecisionTree object", ex);
        }
    }

    /**
     * Rebuilds a root {@link Node} from bytes produced by {@link #serialize(boolean)}.
     *
     * @param bArr buffer holding the serialized node
     * @param i number of bytes of {@code bArr} to read
     * @param z whether the bytes are compressed
     * @return the deserialized root node
     * @throws HiveException wrapping any failure during deserialization
     */
    @Nonnull
    public static Node deserialize(@Nonnull byte[] bArr, int i, boolean z) throws HiveException {
        final Node root = new Node();
        try {
            if (!z) {
                ObjectUtils.readObject(bArr, i, root);
            } else {
                ObjectUtils.readCompressedObject(bArr, 0, i, root);
            }
        } catch (IOException ioe) {
            throw new HiveException("IOException cause while deserializing DecisionTree object", ioe);
        } catch (Exception ex) {
            throw new HiveException("Exception cause while deserializing DecisionTree object", ex);
        }
        return root;
    }

    /**
     * Returns the tree rendered as JavaScript with no feature or class names supplied.
     *
     * <p>The {@code null} check on the root is defensive. {@code _root} is final and is assigned
     * in the constructor before any split work starts.
     *
     * @return the generated script, or an empty string if there is no root
     */
    @Override
    public String toString() {
        if (this._root == null) {
            return "";
        }
        return predictJsCodegen(null, null);
    }
}
