package org.apache.joshua.decoder.hypergraph;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.BLEU;
import org.apache.joshua.decoder.JoshuaConfiguration;
import org.apache.joshua.decoder.StructuredTranslation;
import org.apache.joshua.decoder.StructuredTranslationFactory;
import org.apache.joshua.decoder.ff.FeatureFunction;
import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.ff.fragmentlm.Tree;
import org.apache.joshua.decoder.ff.lm.ArpaNgram;
import org.apache.joshua.decoder.ff.state_maintenance.DPState;
import org.apache.joshua.decoder.ff.tm.Rule;
import org.apache.joshua.decoder.io.DeNormalize;
import org.apache.joshua.decoder.segment_file.Sentence;
import org.apache.joshua.decoder.segment_file.Token;
import org.apache.joshua.util.Constants;
import org.apache.joshua.util.FormatUtils;

/* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor.class */
public class KBestExtractor {
    /** Decoder configuration; this class reads its outputFormat, use_unique_nbest, rescoreForest, moses and project_case settings. */
    private final JoshuaConfiguration joshuaConfiguration;
    /** Output template copied from the configuration in the constructor and expanded per hypothesis by getKthHyp. */
    private final String outputFormat;
    /** Per-node k-best bookkeeping, filled on demand by getVirtualNode and emptied by resetState. */
    private final HashMap<HGNode, VirtualNode> virtualNodesTable = new HashMap<>();
    /** Literal "ROOT" symbol; not referenced elsewhere in this class. */
    static final String rootSym = "ROOT";
    /** Vocabulary id registered for "ROOT"; resolved once when the class is initialised. */
    static final int rootID = Vocabulary.id("ROOT");
    /** When true, a node's n-best list only admits derivations whose hypothesis string has not been seen at that node. */
    private final boolean extractUniqueNbest;
    /** Side rendered by the no-argument DerivationState.getHypothesis(). */
    private final Side defaultSide;
    /** Input sentence; used for its length, id, tokens and (when rescoring) references. */
    private final Sentence sentence;
    /** Model weights; used for inner products with extracted features and for the sparse "BLEU" weight. */
    private final FeatureVector weights;
    /** Feature functions handed to every FeatureVectorExtractor created here. */
    private final List<FeatureFunction> featureFunctions;
    /** BLEU references; only populated by the constructor when rescoreForest is enabled, otherwise null. */
    private BLEU.References references;

    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$DerivationExtractor.class */
    /**
     * Visitor that renders a derivation as text, one line per rule-bearing edge,
     * indented by two spaces per level of depth. The accumulated text is
     * returned by {@link #toString()}.
     */
    public class DerivationExtractor implements DerivationVisitor {
        /** Receives one formatted line for every visited edge that carries a rule. */
        final StringBuffer sb = new StringBuffer();

        public DerivationExtractor() {
        }

        /**
         * Appends the line for {@code derivationState}: span, rule (LHS, source
         * words, target words), the head node's DP states, the features
         * extracted for this edge alone, their inner product with the model
         * weights and, when the rule has one, its alignment.
         * Edges without a rule contribute nothing.
         *
         * @param derivationState state being entered
         * @param i depth of the state; determines the indentation
         * @param i2 tail-node index of the state, forwarded to the feature extractor
         */
        @Override // org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationVisitor
        public void before(DerivationState derivationState, int i, int i2) {
            final Rule appliedRule = derivationState.edge.getRule();
            if (appliedRule == null) {
                return;
            }

            // Two spaces of indentation per depth level.
            int remainingPad = i * 2;
            while (remainingPad > 0) {
                sb.append(" ");
                remainingPad--;
            }

            // A fresh extractor is used so that only this edge's features are collected.
            FeatureVectorExtractor edgeExtractor =
                    new FeatureVectorExtractor(KBestExtractor.this.featureFunctions, KBestExtractor.this.sentence);
            edgeExtractor.before(derivationState, i, i2);
            final FeatureVector edgeFeatures = edgeExtractor.getFeatures();

            final HGNode head = derivationState.parentNode;
            sb.append(String.format("%d-%d", Integer.valueOf(head.i), Integer.valueOf(head.j)));

            sb.append(" ||| ");
            sb.append(Vocabulary.word(appliedRule.getLHS()));
            sb.append(" -> ");
            sb.append(Vocabulary.getWords(appliedRule.getFrench()));
            sb.append(" /// ");
            sb.append(appliedRule.getEnglishWords());

            sb.append(" |||");
            for (DPState dpState : head.getDPStates()) {
                sb.append(" ").append(dpState);
            }

            sb.append(" ||| ").append(edgeFeatures);
            sb.append(" ||| ").append(KBestExtractor.this.weights.innerProduct(edgeFeatures));

            if (appliedRule.getAlignment() != null) {
                sb.append(" ||| ").append(Arrays.toString(appliedRule.getAlignment()));
            }
            sb.append("\n");
        }

        /** Returns everything appended so far. */
        @Override
        public String toString() {
            return sb.toString();
        }

        /** No work is done when leaving a state. */
        @Override // org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationVisitor
        public void after(DerivationState derivationState, int i, int i2) {
        }
    }

    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$DerivationState.class */
    /**
     * One derivation rooted at a hypergraph node: a hyperedge of that node plus,
     * for every tail node of the edge, the 1-based rank of the sub-derivation
     * picked from that tail node's n-best list.
     */
    public class DerivationState {
        /** Hyperedge applied at the head node. */
        public final HyperEdge edge;
        /** Node this derivation is rooted at. */
        public final HGNode parentNode;
        /** Position of {@link #edge} in the iteration order of the head node's hyperedges. */
        public final int edgePos;
        /** 1-based n-best rank per tail node; null when the edge has no tail-node list. */
        public final int[] ranks;
        /** Model cost of this derivation (see {@link #getModelCost()}). */
        private float cost;
        /** BLEU score used by {@link #getCost()}; only assigned by VirtualNode when rescoreForest is on. */
        private float bleu;
        /** Cached BLEU statistics, filled by the first call to {@link #computeBLEU()}. */
        BLEU.Stats stats = null;

        /**
         * @param hGNode head node of the derivation
         * @param hyperEdge edge applied at the head node
         * @param iArr 1-based ranks into each tail node's n-best list, or null for an edge without tail nodes
         * @param f model cost
         * @param i position of the edge among the head node's hyperedges
         */
        public DerivationState(HGNode hGNode, HyperEdge hyperEdge, int[] iArr, float f, int i) {
            this.parentNode = hGNode;
            this.edge = hyperEdge;
            this.ranks = iArr;
            this.cost = f;
            this.edgePos = i;
            // NOTE(review): the decompiler emitted this constant for the initial value (and emitted the
            // assignment twice); it is presumably a plain 0.0f in the original source - confirm.
            this.bleu = ArpaNgram.DEFAULT_BACKOFF;
        }

        /**
         * Returns the BLEU score of this derivation. On the first call the
         * statistics are computed for this edge and the statistics of every
         * selected child derivation are added to them; later calls reuse the
         * cached statistics.
         */
        public float computeBLEU() {
            if (this.stats == null) {
                // Second argument: share of the input sentence covered by the head node's span.
                float spanFraction = (1.0f * (this.parentNode.j - this.parentNode.i)) / KBestExtractor.this.sentence.length();
                this.stats = BLEU.compute(this.edge, spanFraction, KBestExtractor.this.references);
                if (this.edge.getTailNodes() != null) {
                    for (int i = 0; i < this.edge.getTailNodes().size(); i++) {
                        this.stats.add(getChildDerivationState(this.edge, i).stats);
                    }
                }
            }
            return BLEU.score(this.stats);
        }

        public void setCost(float f) {
            this.cost = f;
        }

        /** Returns the cost exactly as stored, without the BLEU term. */
        public float getModelCost() {
            return this.cost;
        }

        /** Returns the stored cost minus the sparse "BLEU" weight times the BLEU score of this state. */
        public float getCost() {
            return this.cost - (KBestExtractor.this.weights.getSparse("BLEU") * this.bleu);
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder(String.format("DS[[ %s (%d,%d)/%d ||| ", Vocabulary.word(this.parentNode.lhs), Integer.valueOf(this.parentNode.i), Integer.valueOf(this.parentNode.j), Integer.valueOf(this.edgePos)));
            text.append("ranks=[ ");
            if (this.ranks != null) {
                for (int rank : this.ranks) {
                    text.append(rank).append(' ');
                }
            }
            text.append("] ||| ").append(String.format("%.5f ]]", Float.valueOf(this.cost)));
            return text.toString();
        }

        /**
         * Two states are equal when they use the same edge position and the same
         * rank vector. Two null rank vectors (edges without tail nodes) compare
         * equal, which keeps the relation reflexive; the previous implementation
         * reported such a state as different from itself.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof DerivationState)) {
                return false;
            }
            DerivationState other = (DerivationState) obj;
            return this.edgePos == other.edgePos && Arrays.equals(this.ranks, other.ranks);
        }

        /**
         * Hash over the edge position and the rank values. The previous version
         * mixed in the loop index instead of the rank, so every state with the
         * same edge and arity landed in the same bucket.
         */
        @Override
        public int hashCode() {
            return 31 * this.edgePos + Arrays.hashCode(this.ranks);
        }

        /** Walks the whole derivation below this state, starting at depth 0 and tail index 0. */
        private DerivationVisitor visit(DerivationVisitor derivationVisitor) {
            return visit(derivationVisitor, 0, 0);
        }

        /**
         * Depth-first traversal: calls {@code before}, recurses into the selected
         * child derivations, then calls {@code after}.
         *
         * @param i depth of this state
         * @param i2 index of this state among its parent's tail nodes
         * @return the visitor that was passed in
         */
        private DerivationVisitor visit(DerivationVisitor derivationVisitor, int i, int i2) {
            derivationVisitor.before(this, i, i2);
            Rule rule = this.edge.getRule();
            List<HGNode> tailNodes = this.edge.getTailNodes();
            if (rule == null) {
                // An edge without a rule is followed through its first tail node only.
                getChildDerivationState(this.edge, 0).visit(derivationVisitor, i + 1, 0);
            } else if (tailNodes != null) {
                for (int i3 = 0; i3 < tailNodes.size(); i3++) {
                    getChildDerivationState(this.edge, i3).visit(derivationVisitor, i + 1, i3);
                }
            }
            derivationVisitor.after(this, i, i2);
            return derivationVisitor;
        }

        public String getWordAlignment() {
            return visit(new WordAlignmentExtractor()).toString();
        }

        public List<List<Integer>> getWordAlignmentList() {
            WordAlignmentExtractor wordAlignmentExtractor = new WordAlignmentExtractor();
            visit(wordAlignmentExtractor);
            return wordAlignmentExtractor.getFinalWordAlignments();
        }

        public String getTree() {
            return visit(new TreeExtractor()).toString();
        }

        /** Renders the hypothesis for the extractor's default side. */
        public String getHypothesis() {
            return getHypothesis(KBestExtractor.this.defaultSide);
        }

        /* JADX INFO: Access modifiers changed from: private */
        public String getHypothesis(Side side) {
            return visit(new OutputStringExtractor(side.equals(Side.SOURCE))).toString();
        }

        /** Collects the features of the whole derivation below this state. */
        public FeatureVector getFeatures() {
            FeatureVectorExtractor featureVectorExtractor = new FeatureVectorExtractor(KBestExtractor.this.featureFunctions, KBestExtractor.this.sentence);
            visit(featureVectorExtractor);
            return featureVectorExtractor.getFeatures();
        }

        public String getDerivation() {
            return visit(new DerivationExtractor()).toString();
        }

        /**
         * Returns the sub-derivation selected for tail node {@code i} of
         * {@code hyperEdge}: entry {@code ranks[i] - 1} of that tail node's
         * n-best list.
         */
        public DerivationState getChildDerivationState(HyperEdge hyperEdge, int i) {
            return KBestExtractor.this.getVirtualNode(hyperEdge.getTailNodes().get(i)).nbests.get(this.ranks[i] - 1);
        }
    }

    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$DerivationStateComparator.class */
    /**
     * Orders derivation states so that the one with the larger
     * {@link DerivationState#getCost()} comes first.
     */
    public static class DerivationStateComparator implements Comparator<DerivationState> {
        @Override // java.util.Comparator
        public int compare(DerivationState derivationState, DerivationState derivationState2) {
            final float first = derivationState.getCost();
            final float second = derivationState2.getCost();
            if (first > second) {
                return -1;
            }
            if (first == second) {
                return 0;
            }
            return 1;
        }
    }

    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$DerivationVisitor.class */
    /**
     * Callback pair invoked by DerivationState's depth-first traversal of a
     * derivation.
     */
    public interface DerivationVisitor {
        /**
         * Called when the traversal enters a state, before any of its children.
         *
         * @param derivationState the state being entered
         * @param i depth: 0 for the state the traversal started at, one more per level below it
         * @param i2 index of this state among its parent's tail nodes (0 for the starting state)
         */
        void before(DerivationState derivationState, int i, int i2);

        /**
         * Called when the traversal leaves a state, after all of its children
         * have been visited. Receives the same arguments as the matching
         * {@code before} call.
         */
        void after(DerivationState derivationState, int i, int i2);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$Side.class */
    /**
     * Selects which side of a derivation getHypothesis renders. The constructor
     * of the enclosing class picks the default from its boolean flag.
     */
    public enum Side {
        SOURCE,
        TARGET
    }

    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$TreeExtractor.class */
    /**
     * Visitor that assembles a {@link Tree} for a derivation by grafting one
     * fragment per rule-bearing edge onto the tree built so far.
     */
    public class TreeExtractor implements DerivationVisitor {
        /** Tree built so far; null until the first rule-bearing edge is visited. */
        private Tree tree = null;

        public TreeExtractor() {
        }

        /**
         * Builds the fragment for the rule at {@code derivationState} and merges
         * it into the tree. Edges without a rule are skipped.
         */
        @Override // org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationVisitor
        public void before(DerivationState derivationState, int i, int i2) {
            Rule rule = derivationState.edge.getRule();
            if (rule == null) {
                return;
            }
            String lhs = Vocabulary.word(rule.getLHS());
            // Drops the first and last character of the LHS symbol
            // (presumably the surrounding brackets, e.g. "[X]" -> "X" - confirm).
            String label = lhs.substring(1, lhs.length() - 1);
            Tree fragment = Tree.getFragmentFromYield(rule.getEnglishWords());
            if (fragment == null) {
                // No stored fragment for this yield: build a flat one labelled with the LHS and the head span.
                fragment = Tree.fromString(String.format("(%s{%d-%d} %s)", label, Integer.valueOf(derivationState.parentNode.i), Integer.valueOf(derivationState.parentNode.j), quoteTerminals(rule.getEnglishWords())));
            }
            merge(fragment);
        }

        /**
         * Wraps every token of {@code str} in double quotes, except tokens that
         * start with '[' and end with ']', which are copied unchanged. Tokens
         * are joined with single spaces.
         *
         * @return the joined tokens, or an empty string when {@code str} yields
         *         no tokens (the previous version threw
         *         StringIndexOutOfBoundsException in that case)
         */
        private String quoteTerminals(String str) {
            StringBuilder joined = new StringBuilder();
            for (String token : str.split(Constants.spaceSeparator)) {
                // Every emitted token is at least two characters long, so a
                // non-empty buffer reliably means "not the first token".
                if (joined.length() > 0) {
                    joined.append(' ');
                }
                if (token.startsWith("[") && token.endsWith("]")) {
                    joined.append(token);
                } else {
                    joined.append('"').append(token).append('"');
                }
            }
            return joined.toString();
        }

        /** No work is done when leaving a state. */
        @Override // org.apache.joshua.decoder.hypergraph.KBestExtractor.DerivationVisitor
        public void after(DerivationState derivationState, int i, int i2) {
        }

        /**
         * Returns the unquoted string form of the tree, or an empty string when
         * no rule-bearing edge has been visited (previously a
         * NullPointerException).
         */
        @Override
        public String toString() {
            return this.tree == null ? "" : this.tree.unquotedString();
        }

        /**
         * The first fragment becomes the tree. Every later fragment replaces the
         * label and children of the first node in the current tree's
         * nonterminal yield.
         */
        private void merge(Tree tree) {
            if (this.tree == null) {
                this.tree = tree;
                return;
            }
            Tree site = this.tree.getNonterminalYield().get(0);
            site.setLabel(Vocabulary.word(tree.getLabel()));
            site.setChildren(tree.getChildren());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/joshua/decoder/hypergraph/KBestExtractor$VirtualNode.class */
    /**
     * Lazy k-best bookkeeping for a single hypergraph node: the derivations
     * extracted so far, the heap of candidates still to be considered and the
     * sets used to avoid duplicates.
     */
    public class VirtualNode {
        /** Hypergraph node this bookkeeping belongs to. */
        HGNode node;
        /** Derivations extracted so far, best first; entry k-1 is the k-th best. */
        public final List<DerivationState> nbests = new ArrayList<>();
        /** Candidate derivations not yet extracted; created on first use. */
        private PriorityQueue<DerivationState> candHeap = null;
        /** Every derivation ever placed on the heap, to avoid queueing one twice. */
        private HashSet<DerivationState> derivationTable = null;
        /** Hypothesis strings already in {@link #nbests}; only allocated in unique-n-best mode. */
        private HashSet<String> uniqueStringsTable = null;

        public VirtualNode(HGNode hGNode) {
            this.node = hGNode;
        }

        /**
         * Returns the {@code i}-th best derivation of this node (1-based),
         * extending the n-best list from the candidate heap as needed.
         *
         * @return the derivation, or null when the node has fewer than {@code i} derivations
         * @throws RuntimeException in non-unique mode when more than one candidate
         *         has to be popped within a single call
         */
        /* JADX INFO: Access modifiers changed from: private */
        public DerivationState lazyKBestExtractOnNode(KBestExtractor kBestExtractor, int i) {
            if (i <= this.nbests.size()) {
                return this.nbests.get(i - 1);
            }
            if (this.candHeap == null) {
                getCandidates(kBestExtractor);
            }

            final boolean uniqueOnly = KBestExtractor.this.extractUniqueNbest;
            DerivationState popped = null;
            int pops = 0;
            while (this.nbests.size() < i && !this.candHeap.isEmpty()) {
                popped = this.candHeap.poll();
                if (!uniqueOnly) {
                    this.nbests.add(popped);
                } else {
                    // Only the first derivation producing a given string is kept.
                    String surface = popped.getHypothesis();
                    if (this.uniqueStringsTable.add(surface)) {
                        this.nbests.add(popped);
                    }
                }
                // Queue the successors of the popped state, whether or not it was kept.
                lazyNext(kBestExtractor, popped);
                pops++;
                if (!uniqueOnly && pops > 1) {
                    throw new RuntimeException("In lazyKBestExtractOnNode, add more than one time, k is " + i);
                }
            }
            return this.nbests.size() < i ? null : popped;
        }

        /**
         * Queues the successors of {@code derivationState}: for each tail node,
         * the state obtained by advancing that tail node's rank by one, provided
         * the tail node actually has that many derivations and the successor has
         * not been queued before.
         */
        private void lazyNext(KBestExtractor kBestExtractor, DerivationState derivationState) {
            final List<HGNode> tails = derivationState.edge.getTailNodes();
            if (tails == null) {
                return;
            }
            for (int pos = 0; pos < tails.size(); pos++) {
                VirtualNode child = kBestExtractor.getVirtualNode(tails.get(pos));

                int[] advanced = derivationState.ranks.clone();
                advanced[pos] = derivationState.ranks[pos] + 1;
                DerivationState successor = new DerivationState(derivationState.parentNode, derivationState.edge, advanced, ArpaNgram.DEFAULT_BACKOFF, derivationState.edgePos);
                if (this.derivationTable.contains(successor)) {
                    continue;
                }

                child.lazyKBestExtractOnNode(kBestExtractor, advanced[pos]);
                if (advanced[pos] > child.nbests.size()) {
                    // The tail node has no derivation at the advanced rank.
                    continue;
                }

                // Swap the old child's model cost for the new child's.
                float previousChild = child.nbests.get(derivationState.ranks[pos] - 1).getModelCost();
                float nextChild = child.nbests.get(advanced[pos] - 1).getModelCost();
                successor.setCost((derivationState.getModelCost() - previousChild) + nextChild);
                if (KBestExtractor.this.joshuaConfiguration.rescoreForest) {
                    successor.bleu = successor.computeBLEU();
                }
                this.candHeap.add(successor);
                this.derivationTable.add(successor);
            }
        }

        /**
         * Creates the heap and duplicate tables and seeds the heap with the best
         * derivation of every hyperedge of the node.
         *
         * @throws RuntimeException if two hyperedges yield equal derivation states
         */
        private void getCandidates(KBestExtractor kBestExtractor) {
            this.candHeap = new PriorityQueue<>(11, new DerivationStateComparator());
            this.derivationTable = new HashSet<>();
            if (KBestExtractor.this.extractUniqueNbest) {
                this.uniqueStringsTable = new HashSet<>();
            }
            int edgeIndex = 0;
            for (HyperEdge hyperEdge : this.node.hyperedges) {
                DerivationState seed = getBestDerivation(kBestExtractor, this.node, hyperEdge, edgeIndex);
                if (this.derivationTable.contains(seed)) {
                    throw new RuntimeException("get duplicate derivation in get_candidates, this should not happen\nsignature is " + seed + "\nl_hyperedge size is " + this.node.hyperedges.size());
                }
                this.candHeap.add(seed);
                this.derivationTable.add(seed);
                edgeIndex++;
            }
        }

        /**
         * Builds the best derivation through {@code hyperEdge}: rank 1 at every
         * tail node (forcing each tail node to extract its best derivation) and
         * the edge's best derivation score as cost.
         */
        private DerivationState getBestDerivation(KBestExtractor kBestExtractor, HGNode hGNode, HyperEdge hyperEdge, int i) {
            int[] firstRanks = null;
            if (hyperEdge.getTailNodes() != null) {
                firstRanks = new int[hyperEdge.getTailNodes().size()];
                for (int pos = 0; pos < hyperEdge.getTailNodes().size(); pos++) {
                    firstRanks[pos] = 1;
                    VirtualNode child = kBestExtractor.getVirtualNode(hyperEdge.getTailNodes().get(pos));
                    child.lazyKBestExtractOnNode(kBestExtractor, firstRanks[pos]);
                }
            }
            DerivationState best = new DerivationState(hGNode, hyperEdge, firstRanks, hyperEdge.getBestDerivationScore(), i);
            if (KBestExtractor.this.joshuaConfiguration.rescoreForest) {
                best.bleu = best.computeBLEU();
            }
            return best;
        }
    }

    /**
     * Creates an extractor for one sentence.
     *
     * @param sentence input sentence
     * @param list feature functions used when extracting feature vectors
     * @param featureVector model weights
     * @param z true to render hypotheses from the source side by default, false for the target side
     * @param joshuaConfiguration decoder configuration; supplies the output format and the
     *        unique-n-best and forest-rescoring switches
     */
    public KBestExtractor(Sentence sentence, List<FeatureFunction> list, FeatureVector featureVector, boolean z, JoshuaConfiguration joshuaConfiguration) {
        this.sentence = sentence;
        this.featureFunctions = list;
        this.weights = featureVector;
        this.joshuaConfiguration = joshuaConfiguration;
        this.outputFormat = joshuaConfiguration.outputFormat;
        this.extractUniqueNbest = joshuaConfiguration.use_unique_nbest;
        if (z) {
            this.defaultSide = Side.SOURCE;
        } else {
            this.defaultSide = Side.TARGET;
        }
        // References are only needed (and only loaded) when the forest is rescored with BLEU.
        this.references = joshuaConfiguration.rescoreForest ? new BLEU.References(sentence.references()) : null;
    }

    /**
     * Returns the {@code i}-th best derivation (1-based) rooted at
     * {@code hGNode}, or null when the node has fewer derivations.
     */
    public DerivationState getKthDerivation(HGNode hGNode, int i) {
        VirtualNode bookkeeping = getVirtualNode(hGNode);
        return bookkeeping.lazyKBestExtractOnNode(this, i);
    }

    /**
     * Returns the {@code i}-th best derivation at {@code hGNode} wrapped as a
     * structured translation, or null when there is no such derivation.
     */
    public StructuredTranslation getKthStructuredTranslation(HGNode hGNode, int i) {
        DerivationState derivation = getKthDerivation(hGNode, i);
        if (derivation == null) {
            return null;
        }
        return StructuredTranslationFactory.fromKBestDerivation(this.sentence, derivation);
    }

    /**
     * Extracts up to {@code i} best translations from the goal node of
     * {@code hyperGraph}. Cached k-best state is reset first.
     *
     * @return the translations, best first; empty when the hypergraph or its
     *         goal node is null; shorter than {@code i} when the hypergraph
     *         holds fewer derivations
     */
    public List<StructuredTranslation> KbestExtractOnHG(HyperGraph hyperGraph, int i) {
        resetState();
        if (hyperGraph == null || hyperGraph.goalNode == null) {
            return Collections.emptyList();
        }
        List<StructuredTranslation> translations = new ArrayList<>(i);
        for (int rank = 1; rank <= i; rank++) {
            StructuredTranslation next = getKthStructuredTranslation(hyperGraph.goalNode, rank);
            if (next == null) {
                // Fewer than i derivations are available.
                break;
            }
            translations.add(next);
        }
        return translations;
    }

    /**
     * Formats the {@code i}-th best derivation at {@code hGNode} according to
     * the configured output format.
     *
     * <p>Recognised placeholders:
     * <ul>
     *   <li>{@code %k} - the rank {@code i}</li>
     *   <li>{@code %s} - hypothesis with sentence markers removed, special symbols unescaped
     *       and case projected if configured</li>
     *   <li>{@code %S} - the {@code %s} value passed through {@code DeNormalize.processSingleLine}</li>
     *   <li>{@code %i} - sentence id</li>
     *   <li>{@code %f} - feature vector (Moses format when {@code moses} is configured)</li>
     *   <li>{@code %c} - derivation cost with three decimals</li>
     *   <li>{@code %t} - derivation tree</li>
     *   <li>{@code %e} - source-side hypothesis with sentence markers removed</li>
     *   <li>{@code %d} - textual derivation</li>
     *   <li>{@code %a} - word alignment</li>
     * </ul>
     * Any other text, including an unrecognised {@code %x}, is copied through.
     *
     * <p>The template is expanded in a single left-to-right pass. The earlier
     * implementation chained {@code String.replace} calls, so a placeholder
     * sequence occurring inside an already substituted value (for instance a
     * hypothesis containing the characters {@code %c}) was itself replaced.
     *
     * @return the formatted line, or null when there is no {@code i}-th derivation
     */
    public String getKthHyp(HGNode hGNode, int i) {
        DerivationState derivation = getKthDerivation(hGNode, i);
        if (derivation == null) {
            return null;
        }

        String hypothesis = maybeProjectCase(FormatUtils.unescapeSpecialSymbols(FormatUtils.removeSentenceMarkers(derivation.getHypothesis())), derivation);

        // Feature extraction is kept under the same condition as before (%f or %d present).
        FeatureVector features = new FeatureVector();
        if (this.outputFormat.contains("%f") || this.outputFormat.contains("%d")) {
            features = derivation.getFeatures();
        }

        final String format = this.outputFormat;
        StringBuilder line = new StringBuilder(format.length() + hypothesis.length());
        int pos = 0;
        while (pos < format.length()) {
            char current = format.charAt(pos);
            if (current != '%' || pos + 1 >= format.length()) {
                line.append(current);
                pos++;
                continue;
            }
            switch (format.charAt(pos + 1)) {
                case 'k':
                    line.append(Integer.toString(i));
                    break;
                case 's':
                    line.append(hypothesis);
                    break;
                case 'S':
                    line.append(DeNormalize.processSingleLine(hypothesis));
                    break;
                case 'i':
                    line.append(Integer.toString(this.sentence.id()));
                    break;
                case 'f':
                    line.append(this.joshuaConfiguration.moses ? features.mosesString() : features.toString());
                    break;
                case 'c':
                    line.append(String.format("%.3f", Float.valueOf(derivation.cost)));
                    break;
                case 't':
                    line.append(derivation.getTree());
                    break;
                case 'e':
                    line.append(FormatUtils.removeSentenceMarkers(derivation.getHypothesis(Side.SOURCE)));
                    break;
                case 'd':
                    line.append(derivation.getDerivation());
                    break;
                case 'a':
                    line.append(derivation.getWordAlignment());
                    break;
                default:
                    // Not a placeholder: emit the '%' and re-examine the next character on its own.
                    line.append(current);
                    pos++;
                    continue;
            }
            pos += 2;
        }
        return line.toString();
    }

    /**
     * When case projection is configured, copies the "lettercase" annotation of
     * aligned source tokens onto the words of {@code str}: "upper" capitalises
     * the word, "all-upper" upper-cases it. Otherwise {@code str} is returned
     * unchanged.
     *
     * @param str hypothesis text
     * @param derivationState derivation that supplies the word alignment
     * @return the hypothesis with projected case, words joined by single spaces
     */
    private String maybeProjectCase(String str, DerivationState derivationState) {
        if (!this.joshuaConfiguration.project_case) {
            return str;
        }

        String[] words = str.split(Constants.spaceSeparator);
        List<List<Integer>> alignment = derivationState.getWordAlignmentList();
        for (int target = 0; target < alignment.size(); target++) {
            for (Integer alignedSource : alignment.get(target)) {
                int source = alignedSource.intValue();
                // The token list is indexed one past the alignment index
                // (presumably to skip a sentence-start token - confirm).
                Token token = this.sentence.getTokens().get(source + 1);
                String letterCase = "";
                if (token != null && token.getAnnotation("lettercase") != null) {
                    letterCase = token.getAnnotation("lettercase");
                }
                // "upper" on the first source word is not projected.
                if (source != 0 && letterCase.equals("upper")) {
                    words[target] = FormatUtils.capitalize(words[target]);
                } else if (letterCase.equals("all-upper")) {
                    words[target] = words[target].toUpperCase();
                }
            }
        }
        return String.join(" ", words);
    }

    /**
     * Writes up to {@code i} best hypotheses of {@code hyperGraph} to standard
     * output, one per line.
     *
     * @throws IOException if writing fails
     */
    public void lazyKBestExtractOnHG(HyperGraph hyperGraph, int i) throws IOException {
        BufferedWriter stdout = new BufferedWriter(new OutputStreamWriter(System.out));
        lazyKBestExtractOnHG(hyperGraph, i, stdout);
    }

    /**
     * Writes up to {@code i} best hypotheses of {@code hyperGraph} to
     * {@code bufferedWriter}, one per line, flushing after each line. Cached
     * k-best state is reset first. Nothing is written when the hypergraph or
     * its goal node is null; a null hypergraph used to raise a
     * NullPointerException here although KbestExtractOnHG accepts it.
     *
     * @param hyperGraph hypergraph to extract from; may be null
     * @param i maximum number of hypotheses to write
     * @param bufferedWriter destination; it is flushed but not closed
     * @throws IOException if writing fails
     */
    public void lazyKBestExtractOnHG(HyperGraph hyperGraph, int i, BufferedWriter bufferedWriter) throws IOException {
        resetState();
        if (hyperGraph == null || hyperGraph.goalNode == null) {
            return;
        }
        for (int rank = 1; rank <= i; rank++) {
            String line = getKthHyp(hyperGraph.goalNode, rank);
            if (line == null) {
                // Fewer than i derivations are available.
                break;
            }
            bufferedWriter.write(line);
            bufferedWriter.write("\n");
            bufferedWriter.flush();
        }
    }

    /** Discards all cached per-node k-best bookkeeping. */
    public void resetState() {
        virtualNodesTable.clear();
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the k-best bookkeeping for {@code hGNode}, creating and caching
     * it on first request.
     */
    public VirtualNode getVirtualNode(HGNode hGNode) {
        return this.virtualNodesTable.computeIfAbsent(hGNode, node -> new VirtualNode(node));
    }
}
