package hivemall.ensemble.bagging;

import javax.annotation.Nullable;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;

@Description(name = "voted_avg", value = "_FUNC_(double value) - Returns an averaged value by bagging for classification")
public final class VotedAvgUDAF extends UDAF {

    /**
     * Evaluator that averages the values on the side (positive or negative) that received more
     * votes.
     * <p>
     * Each non-null input counts as one vote. Strictly positive values are summed into a positive
     * bucket and strictly negative values into a negative bucket. Zero and NaN are not counted for
     * either side. {@link #terminate()} returns the mean of the bucket with more votes; on a tie
     * the negative bucket's mean is returned.
     */
    public static class Evaluator implements UDAFEvaluator {

        /** Aggregation state; {@code null} until the first non-null value or partial arrives. */
        private PartialResult partial;

        /**
         * Intermediate state returned by {@link Evaluator#terminatePartial()} and consumed by
         * {@link Evaluator#merge(PartialResult)}.
         */
        public static class PartialResult {
            /** Sum of all strictly positive values seen. */
            double positiveSum;
            /** Number of strictly positive values seen. */
            int positiveCnt;
            /** Sum of all strictly negative values seen. */
            double negativeSum;
            /** Number of strictly negative values seen. */
            int negativeCnt;

            /** Resets every sum and count to zero. */
            void init() {
                this.positiveSum = 0.d;
                this.positiveCnt = 0;
                this.negativeSum = 0.d;
                this.negativeCnt = 0;
            }
        }

        /** Discards any accumulated state. */
        public void init() {
            this.partial = null;
        }

        /** Lazily creates a zeroed {@link PartialResult} if none exists yet. */
        private void ensurePartial() {
            if (partial == null) {
                this.partial = new PartialResult();
                partial.init();
            }
        }

        /**
         * Records one vote.
         *
         * @param value the input value; {@code null} is skipped without creating state
         * @return always {@code true}
         */
        public boolean iterate(@Nullable DoubleWritable value) {
            if (value == null) {
                return true;
            }
            ensurePartial();
            final double v = value.get();
            if (v > 0.d) {
                partial.positiveSum += v;
                partial.positiveCnt++;
            } else if (v < 0.d) {
                partial.negativeSum += v;
                partial.negativeCnt++;
            }
            // Zero and NaN reach neither branch (every comparison with NaN is false),
            // so they are not counted as votes and cannot poison either sum.
            return true;
        }

        /**
         * Returns the current intermediate state.
         *
         * @return the partial result, or {@code null} if nothing has been aggregated
         */
        @Nullable
        public PartialResult terminatePartial() {
            return partial;
        }

        /**
         * Folds another evaluator's partial state into this one.
         *
         * @param other the partial result to merge; {@code null} is ignored
         * @return always {@code true}
         */
        public boolean merge(@Nullable PartialResult other) {
            if (other == null) {
                return true;
            }
            ensurePartial();
            partial.positiveSum += other.positiveSum;
            partial.positiveCnt += other.positiveCnt;
            partial.negativeSum += other.negativeSum;
            partial.negativeCnt += other.negativeCnt;
            return true;
        }

        /**
         * Produces the final voted average.
         *
         * @return {@code null} if no state was ever created. Otherwise: the positive mean if
         *         positive votes outnumber negative ones; else the negative mean; or {@code 0.0}
         *         if neither side received a vote.
         */
        @Nullable
        public DoubleWritable terminate() {
            if (partial == null) {
                return null;
            }
            if (partial.positiveCnt > partial.negativeCnt) {
                return new DoubleWritable(partial.positiveSum / partial.positiveCnt);
            }
            if (partial.negativeCnt == 0) {
                // No negative votes and not more positive votes than that: nothing was counted.
                assert partial.negativeSum == 0.d : partial.negativeSum;
                return new DoubleWritable(0.d);
            }
            return new DoubleWritable(partial.negativeSum / partial.negativeCnt);
        }
    }
}
