package org.apache.joshua.decoder.hypergraph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:org/apache/joshua/decoder/hypergraph/DefaultInsideOutside.class */
/**
 * Inside-outside computation over a {@link HyperGraph}, carried out in log space.
 *
 * <p>Subclasses supply the log score of each hyperedge through
 * {@link #getHyperedgeLogProb(HyperEdge, HGNode)}. After
 * {@link #runInsideOutside(HyperGraph, int, int, double)} the inside and outside log scores of
 * every node reachable from the goal node are cached, and node/edge posteriors can be queried
 * until {@link #clearState()} or the next run.
 *
 * <p>"Multiplication" is addition of log values. "Addition" is selected by the add mode:
 * {@code 0} = log-sum-exp, {@code 1} = keep the smaller value, {@code 2} = keep the larger value.
 * Only semiring id {@code 1} (the log semiring) is supported.
 *
 * <p>Instances hold unsynchronized mutable state.
 */
public abstract class DefaultInsideOutside {
    /** Add mode 0: log-sum-exp; the only mode in which edge posteriors of a node sum to its posterior. */
    private static final int ADD_MODE_SUM = 0;
    /** Add mode 1: the smaller of two log values wins. */
    private static final int ADD_MODE_MIN = 1;
    /** Add mode 2: the larger of two log values wins. */
    private static final int ADD_MODE_MAX = 2;

    /** Factor multiplied into every hyperedge log score during the passes. */
    double scaling_factor;
    /** Active add mode (0, 1 or 2); set by runInsideOutside. */
    int ADD_MODE = ADD_MODE_SUM;
    /** Identifier of the log semiring, the only supported semiring. */
    final int LOG_SEMIRING = 1;
    /** Active semiring identifier. */
    int SEMIRING = LOG_SEMIRING;
    /** Additive identity under the active add mode. */
    double ZERO_IN_SEMIRING = Double.NEGATIVE_INFINITY;
    /** Multiplicative identity: 0 in log space. */
    double ONE_IN_SEMIRING = 0.0d;
    /** Inside log score of every node visited by the inside pass. */
    private final Map<HGNode, Double> tbl_inside_prob = new HashMap<>();
    /** Outside log score of every node reached by the outside pass. */
    private final Map<HGNode, Double> tbl_outside_prob = new HashMap<>();
    /** Inside log score of the goal node from the last run. */
    double normalizationConstant = this.ONE_IN_SEMIRING;
    /**
     * Remaining references to each node from hyperedge tails (the goal node gets one from the
     * initial call). Incremented by the inside pass, decremented by the outside pass.
     */
    private final Map<HGNode, Integer> tbl_num_parent_deductions = new HashMap<>();
    /** Nodes already visited by the running sanity check; null when no check is running. */
    private Set<HGNode> sanity_checked_nodes = null;

    /**
     * Returns the unscaled log score of {@code hyperEdge}.
     *
     * @param hyperEdge the hyperedge to score
     * @param hGNode the node whose {@code hyperedges} list contains {@code hyperEdge}
     */
    protected abstract double getHyperedgeLogProb(HyperEdge hyperEdge, HGNode hGNode);

    /**
     * Returns the log score of {@code hyperEdge} multiplied by {@code d}.
     *
     * @param hyperEdge the hyperedge to score
     * @param hGNode the node whose {@code hyperedges} list contains {@code hyperEdge}
     * @param d multiplier applied to the unscaled log score
     */
    protected double getHyperedgeLogProb(HyperEdge hyperEdge, HGNode hGNode, double d) {
        return getHyperedgeLogProb(hyperEdge, hGNode) * d;
    }

    /**
     * Runs the inside and outside passes over {@code hyperGraph}, rooted at its goal node, and
     * records the goal node's inside score as the normalization constant.
     *
     * @param addMode add mode: 0 = log-sum-exp, 1 = min, 2 = max
     * @param semiring semiring identifier; only 1 (log semiring) is supported
     * @param scalingFactor multiplier applied to every hyperedge log score
     * @throws RuntimeException if the semiring or add mode is unsupported, or if the posterior
     *     sanity check fails (performed in add mode 0 only)
     */
    public void runInsideOutside(HyperGraph hyperGraph, int addMode, int semiring, double scalingFactor) {
        setup_semiring(semiring, addMode);
        this.scaling_factor = scalingFactor;
        inside_estimation_hg(hyperGraph);
        outside_estimation_hg(hyperGraph);
        this.normalizationConstant = this.tbl_inside_prob.get(hyperGraph.goalNode);
        System.out.println("normalization constant is " + this.normalizationConstant);
        this.tbl_num_parent_deductions.clear();
        // The check asserts that a node's edge posteriors sum to its posterior, which only holds
        // for log-sum-exp. Under min/max a node's score is its best single edge, so the check
        // would reject every node that has more than one finite-scoring edge.
        if (this.ADD_MODE == ADD_MODE_SUM) {
            sanityCheckHG(hyperGraph);
        }
    }

    /** Discards all cached scores from the previous run. */
    public void clearState() {
        this.tbl_num_parent_deductions.clear();
        this.tbl_inside_prob.clear();
        this.tbl_outside_prob.clear();
        this.sanity_checked_nodes = null;
    }

    /** Returns the goal node's inside log score recorded by the last run. */
    public double getLogNormalizationConstant() {
        return this.normalizationConstant;
    }

    /**
     * Returns the log score of all derivations that use {@code hyperEdge}: the tails' inside
     * scores times the head's outside score times the scaled edge score.
     *
     * @param hGNode the node whose {@code hyperedges} list contains {@code hyperEdge}
     */
    public double getEdgeUnormalizedPosteriorLogProb(HyperEdge hyperEdge, HGNode hGNode) {
        double outside = this.tbl_outside_prob.get(hGNode);
        double logProb = this.ONE_IN_SEMIRING;
        if (hyperEdge.getTailNodes() != null) {
            for (HGNode tail : hyperEdge.getTailNodes()) {
                logProb = multi_in_semiring(logProb, this.tbl_inside_prob.get(tail));
            }
        }
        logProb = multi_in_semiring(logProb, outside);
        return multi_in_semiring(logProb, getHyperedgeLogProb(hyperEdge, hGNode, this.scaling_factor));
    }

    /**
     * Returns the posterior probability of {@code hyperEdge}.
     *
     * @throws RuntimeException if the semiring is not the log semiring or the result exceeds 1.01
     */
    public double getEdgePosteriorProb(HyperEdge hyperEdge, HGNode hGNode) {
        requireLogSemiring();
        return toCheckedProb(getEdgeUnormalizedPosteriorLogProb(hyperEdge, hGNode));
    }

    /** Returns inside times outside log score of {@code hGNode}. */
    public double getNodeUnnormalizedPosteriorLogProb(HGNode hGNode) {
        return multi_in_semiring(this.tbl_inside_prob.get(hGNode), this.tbl_outside_prob.get(hGNode));
    }

    /**
     * Returns the posterior probability of {@code hGNode}.
     *
     * @throws RuntimeException if the semiring is not the log semiring or the result exceeds 1.01
     */
    public double getNodePosteriorProb(HGNode hGNode) {
        requireLogSemiring();
        return toCheckedProb(getNodeUnnormalizedPosteriorLogProb(hGNode));
    }

    /** Throws unless the active semiring is the log semiring. */
    private void requireLogSemiring() {
        if (this.SEMIRING != LOG_SEMIRING) {
            throw new RuntimeException("not implemented");
        }
    }

    /** Normalizes an unnormalized log posterior into a probability and rejects values above 1.01. */
    private double toCheckedProb(double unnormalizedLogProb) {
        double prob = Math.exp(unnormalizedLogProb - getLogNormalizationConstant());
        // exp() is never negative, so only the upper bound needs checking; the 0.01 slack
        // absorbs floating-point error.
        if (prob > 1.01d) {
            throw new RuntimeException("res is not within [0,1], must be wrong value: " + prob);
        }
        return prob;
    }

    /**
     * Verifies, for every node reachable from the goal, that its edge posteriors sum to its node
     * posterior (within 0.001). Only meaningful in add mode 0.
     *
     * @throws RuntimeException on the first node that fails the check
     */
    public void sanityCheckHG(HyperGraph hyperGraph) {
        this.sanity_checked_nodes = new HashSet<>();
        try {
            sanity_check_item(hyperGraph.goalNode);
        } finally {
            // Drop node references once the traversal is over.
            this.sanity_checked_nodes = null;
        }
        System.out.println("survied sanity check!!!!");
    }

    /** Checks {@code hGNode} once, then recurses into the tails of each of its hyperedges. */
    private void sanity_check_item(HGNode hGNode) {
        if (!this.sanity_checked_nodes.add(hGNode)) {
            return;
        }
        double probSum = 0.0d;
        for (HyperEdge hyperEdge : hGNode.hyperedges) {
            probSum += getEdgePosteriorProb(hyperEdge, hGNode);
            if (hyperEdge.getTailNodes() != null) {
                for (HGNode tail : hyperEdge.getTailNodes()) {
                    sanity_check_item(tail);
                }
            }
        }
        double supposedSum = getNodePosteriorProb(hGNode);
        if (Math.abs(probSum - supposedSum) > 0.001d) {
            throw new RuntimeException("prob_sum=" + probSum + "; supposed_sum=" + supposedSum + "; sanity check fail!!!!");
        }
    }

    /** Recomputes inside scores and tail-reference counts for every node reachable from the goal. */
    private void inside_estimation_hg(HyperGraph hyperGraph) {
        this.tbl_inside_prob.clear();
        this.tbl_num_parent_deductions.clear();
        inside_estimation_item(hyperGraph.goalNode);
    }

    /** Returns the (memoized) inside score of {@code hGNode}, counting one more reference to it. */
    private double inside_estimation_item(HGNode hGNode) {
        // Each call is one reference from a hyperedge tail (or the initial goal call). The
        // outside pass expands a node only after consuming all of these references.
        this.tbl_num_parent_deductions.merge(hGNode, 1, Integer::sum);
        Double cached = this.tbl_inside_prob.get(hGNode);
        if (cached != null) {
            return cached;
        }
        double inside = this.ZERO_IN_SEMIRING;
        for (HyperEdge hyperEdge : hGNode.hyperedges) {
            inside = add_in_semiring(inside, inside_estimation_deduction(hyperEdge, hGNode));
        }
        this.tbl_inside_prob.put(hGNode, inside);
        return inside;
    }

    /** Returns the product of the tails' inside scores and the scaled score of {@code hyperEdge}. */
    private double inside_estimation_deduction(HyperEdge hyperEdge, HGNode hGNode) {
        double product = this.ONE_IN_SEMIRING;
        if (hyperEdge.getTailNodes() != null) {
            for (HGNode tail : hyperEdge.getTailNodes()) {
                product = multi_in_semiring(product, inside_estimation_item(tail));
            }
        }
        return multi_in_semiring(product, getHyperedgeLogProb(hyperEdge, hGNode, this.scaling_factor));
    }

    /** Recomputes outside scores top-down, starting from the goal node with the semiring's one. */
    private void outside_estimation_hg(HyperGraph hyperGraph) {
        this.tbl_outside_prob.clear();
        this.tbl_outside_prob.put(hyperGraph.goalNode, this.ONE_IN_SEMIRING);
        for (HyperEdge hyperEdge : hyperGraph.goalNode.hyperedges) {
            outside_estimation_deduction(hyperEdge, hyperGraph.goalNode);
        }
    }

    /**
     * Adds the contribution of {@code hyperEdge} (whose owner is {@code parent}) to the outside
     * score of its tail {@code hGNode}; once every reference to {@code hGNode} has been consumed,
     * its outside score is final and its own hyperedges are expanded.
     *
     * @param edgeLogProb scaled log score of {@code hyperEdge}
     */
    private void outside_estimation_item(HGNode hGNode, HGNode parent, HyperEdge hyperEdge, double edgeLogProb) {
        Integer num = this.tbl_num_parent_deductions.get(hGNode);
        if (num == null || num == 0) {
            throw new RuntimeException("un-expected call, must be wrong");
        }
        int remaining = num - 1;
        this.tbl_num_parent_deductions.put(hGNode, remaining);

        double previousOutside = this.tbl_outside_prob.getOrDefault(hGNode, this.ZERO_IN_SEMIRING);
        // Contribution through this edge: edge score x inside scores of the sibling tails
        // x outside score of the parent.
        double viaEdge = multi_in_semiring(this.ONE_IN_SEMIRING, edgeLogProb);
        if (hyperEdge.getTailNodes() != null && hyperEdge.getTailNodes().size() > 1) {
            for (HGNode sibling : hyperEdge.getTailNodes()) {
                if (sibling != hGNode) {
                    viaEdge = multi_in_semiring(viaEdge, this.tbl_inside_prob.get(sibling));
                }
            }
        }
        viaEdge = multi_in_semiring(viaEdge, this.tbl_outside_prob.get(parent));
        this.tbl_outside_prob.put(hGNode, add_in_semiring(viaEdge, previousOutside));

        if (remaining <= 0) {
            for (HyperEdge childEdge : hGNode.hyperedges) {
                outside_estimation_deduction(childEdge, hGNode);
            }
        }
    }

    /** Pushes outside mass from {@code hGNode} through {@code hyperEdge} to each of its tails. */
    private void outside_estimation_deduction(HyperEdge hyperEdge, HGNode hGNode) {
        if (hyperEdge.getTailNodes() != null) {
            double edgeLogProb = getHyperedgeLogProb(hyperEdge, hGNode, this.scaling_factor);
            for (HGNode tail : hyperEdge.getTailNodes()) {
                outside_estimation_item(tail, hGNode, hyperEdge, edgeLogProb);
            }
        }
    }

    /**
     * Validates and installs the semiring and add mode together with their identities. On an
     * invalid argument nothing is modified.
     *
     * @throws RuntimeException if {@code semiring} is not 1 or {@code addMode} is not 0, 1 or 2
     */
    private void setup_semiring(int semiring, int addMode) {
        if (semiring != LOG_SEMIRING) {
            throw new RuntimeException("un-supported semiring");
        }
        if (addMode != ADD_MODE_SUM && addMode != ADD_MODE_MIN && addMode != ADD_MODE_MAX) {
            throw new RuntimeException("invalid add mode");
        }
        this.SEMIRING = semiring;
        this.ADD_MODE = addMode;
        // Under "min" the neutral element is +inf; under log-sum-exp and "max" it is -inf.
        this.ZERO_IN_SEMIRING = addMode == ADD_MODE_MIN ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
        this.ONE_IN_SEMIRING = 0.0d;
    }

    /** Semiring multiplication. */
    private double multi_in_semiring(double d, double d2) {
        if (this.SEMIRING == LOG_SEMIRING) {
            return multi_in_log_semiring(d, d2);
        }
        throw new RuntimeException("un-supported semiring");
    }

    /** Semiring addition. */
    private double add_in_semiring(double d, double d2) {
        if (this.SEMIRING == LOG_SEMIRING) {
            return add_in_log_semiring(d, d2);
        }
        throw new RuntimeException("un-supported semiring");
    }

    /** Log-space multiplication: adds the two log values. */
    private double multi_in_log_semiring(double d, double d2) {
        return d + d2;
    }

    /** Log-space addition according to the active add mode. */
    private double add_in_log_semiring(double d, double d2) {
        switch (this.ADD_MODE) {
            case ADD_MODE_SUM:
                if (d == Double.NEGATIVE_INFINITY) {
                    return d2;
                }
                if (d2 == Double.NEGATIVE_INFINITY) {
                    return d;
                }
                // Factor out the larger term so exp() cannot overflow; log1p stays accurate
                // when the smaller term is negligible.
                return d2 <= d
                        ? d + Math.log1p(Math.exp(d2 - d))
                        : d2 + Math.log1p(Math.exp(d - d2));
            case ADD_MODE_MIN:
                return d <= d2 ? d : d2;
            case ADD_MODE_MAX:
                return d >= d2 ? d : d2;
            default:
                throw new RuntimeException("invalid add mode");
        }
    }
}
