package edu.stanford.nlp.trees.ud;

import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher;
import edu.stanford.nlp.semgraph.semgrex.SemgrexPattern;
import edu.stanford.nlp.semgraph.semgrex.ssurgeon.SsurgeonPattern;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalRelations;
import edu.stanford.nlp.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/trees/ud/UniversalGappingEnhancer.class */
/**
 * Adds enhanced dependencies for gapping: conjuncts whose predicate is elided and whose remaining
 * dependents hang off one of them with the {@code orphan} relation
 * (for example "John likes apples and Mary pears").
 *
 * <p>For each gapped conjunct, candidate argument sequences of the full conjunct are aligned with
 * the arguments of the gapped conjunct. The alignment is a Needleman-Wunsch style global alignment
 * with a linear gap penalty. It is scored by the distance between averaged word embeddings (when
 * embeddings are supplied) and by agreement of coarsened UPOS tags.
 *
 * <p>Using the best alignment, the elided predicate (and any intermediate nodes on the path to an
 * aligned argument) is copied into the gapped conjunct, and the {@code orphan} edges are replaced
 * by edges from those copies. Finally, {@code cc} dependents and certain later conjuncts are
 * re-attached to the predicate copy.
 */
public class UniversalGappingEnhancer {

    /** Score contribution of every argument that is aligned to a gap. */
    private static final double GAP_PENALTY = -10.0;

    /** Score contribution of aligning two heads whose coarsened UPOS tags differ. */
    private static final double POS_MISMATCH_PENALTY = -2.0;

    /** Weight attached to every edge added by this class. */
    private static final double EDGE_WEIGHT = Double.NEGATIVE_INFINITY;

    /**
     * Bound on gap-resolution rounds in {@link #addEnhancements}.
     *
     * <p>The round that reaches this count is not executed; the graph is printed to stderr instead.
     */
    private static final int MAX_ITERATIONS = 10;

    /** Tags collapsed onto NOUN before two heads' tags are compared. */
    private static final Map<String, String> COARSER_UPOS_MAP;

    static {
        Map<String, String> map = new HashMap<>();
        map.put("PROPN", "NOUN");
        map.put("PRON", "NOUN");
        map.put("NUM", "NOUN");
        map.put("DET", "NOUN");
        COARSER_UPOS_MAP = Collections.unmodifiableMap(map);
    }

    /** Matches a node with an {@code orphan} dependent ("orphangov") and one of its governors ("conjgov"). */
    private static final SemgrexPattern ORPHAN_PATTERN = SemgrexPattern.compile("{}=orphangov < {}=conjgov >orphan {}");

    /** Relations whose dependents are treated as alignable arguments. */
    private static final Pattern ARGUMENT_PATTERN = Pattern.compile("^(i?obj|(n|c)subj.*|(x|c)comp|nmod(:tmod|:npadvmod)?|obl.*|advcl|acl|compound:prt)$");

    /** Clausal argument relations; arguments nested under these are collected as extra candidates. */
    private static final Pattern CLAUSAL_ARGUMENT_PATTERN = Pattern.compile("^(csubj.*|(x|c)comp|advcl|acl)$");

    /** Relations whose yields are folded into the orphan governor's own argument sequence. */
    private static final Pattern MODIFIER_PATTERN = Pattern.compile("^(amod|advmod|nmod|obl|acl|mark|case|compound|flat)$");

    /** Relations copied from an original node onto its copy when the copy does not already have them. */
    private static final Pattern CORE_ARGUMENTS_PATTERN = Pattern.compile("^((n|c)subj.*|(x|c)comp|i?obj|expl|compound:prt)$");

    /**
     * Matches a "predicate" with two dependents, "arg1" and "arg2", where "arg1" has a {@code conj}
     * dependent "conjdep".
     */
    private static final SemgrexPattern CONJ_PATTERN = SemgrexPattern.compile("{}=predicate > ({}=arg1 >conj {}=conjdep) > {}=arg2");

    /**
     * One alignable argument: its head word plus the words used to represent it when averaging
     * embeddings.
     *
     * <p>The words are usually the yield of the head's subtree.
     */
    public static class ArgumentSequence {
        final IndexedWord head;
        final List<IndexedWord> sequence;

        private ArgumentSequence(IndexedWord head, List<IndexedWord> sequence) {
            this.head = head;
            this.sequence = sequence;
        }

        @Override
        public String toString() {
            return sequence.toString();
        }

        /**
         * Returns the (embeddingSize x 1) average of the embeddings of the words in this sequence.
         *
         * <p>Words missing from {@code embeddings} add nothing to the sum but still count in the
         * denominator.
         *
         * @param embeddings the embeddings to look words up in (lower-cased, locale-independent)
         */
        private SimpleMatrix getAverageEmbeddings(Embedding embeddings) {
            SimpleMatrix sum = new SimpleMatrix(new double[embeddings.getEmbeddingSize()][1]);
            for (IndexedWord word : sequence) {
                // Locale.ROOT so that e.g. a Turkish default locale does not break lookups of "I...".
                SimpleMatrix vector = embeddings.get(word.word().toLowerCase(Locale.ROOT));
                if (vector != null) {
                    sum = sum.plus(vector);
                }
            }
            return sum.divide(sequence.size());
        }
    }

    /** Maps PROPN/PRON/NUM/DET onto NOUN; every other tag (including {@code null}) is returned unchanged. */
    private static String coarsenUPOSTag(String tag) {
        return COARSER_UPOS_MAP.getOrDefault(tag, tag);
    }

    /** Average embedding of every sequence in {@code sequences}, in order. */
    private static SimpleMatrix[] averageEmbeddings(List<ArgumentSequence> sequences, Embedding embeddings) {
        SimpleMatrix[] vectors = new SimpleMatrix[sequences.size()];
        for (int k = 0; k < vectors.length; k++) {
            vectors[k] = sequences.get(k).getAverageEmbeddings(embeddings);
        }
        return vectors;
    }

    /** Coarsened coarse (UPOS) tag of every sequence head in {@code sequences}, in order. */
    private static String[] coarseTags(List<ArgumentSequence> sequences) {
        String[] tags = new String[sequences.size()];
        for (int k = 0; k < tags.length; k++) {
            tags[k] = coarsenUPOSTag(sequences.get(k).head.get(CoreAnnotations.CoarseTagAnnotation.class));
        }
        return tags;
    }

    /**
     * Globally aligns the full conjunct's arguments with the gapped conjunct's arguments.
     *
     * @param full arguments of the full conjunct, in sentence order
     * @param gapped arguments of the gapped conjunct, in sentence order
     * @param embeddings embeddings used to score matches, or {@code null} to score on tags only
     * @return the alignment score, and for each index {@code j} of {@code gapped} the index into
     *     {@code full} it is aligned to, or -1 if it is aligned to a gap
     */
    private static Pair<Double, List<Integer>> align(List<ArgumentSequence> full, List<ArgumentSequence> gapped, Embedding embeddings) {
        int n = full.size();
        int m = gapped.size();
        double[][] score = new double[n + 1][m + 1];
        // back[i][j] holds the {row, column} of the predecessor of cell (i, j) on the best path.
        int[][][] back = new int[n + 1][m + 1][2];
        for (int i = 0; i <= n; i++) {
            score[i][0] = i * GAP_PENALTY;
            back[i][0][0] = i - 1;
            back[i][0][1] = 0;
        }
        for (int j = 0; j <= m; j++) {
            score[0][j] = j * GAP_PENALTY;
            back[0][j][0] = 0;
            back[0][j][1] = j - 1;
        }

        // Averages and tags only depend on the sequence, so compute each once instead of per cell.
        SimpleMatrix[] fullVectors = embeddings != null ? averageEmbeddings(full, embeddings) : null;
        SimpleMatrix[] gappedVectors = embeddings != null ? averageEmbeddings(gapped, embeddings) : null;
        String[] fullTags = coarseTags(full);
        String[] gappedTags = coarseTags(gapped);

        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                double distance = fullVectors != null
                        ? fullVectors[i - 1].minus(gappedVectors[j - 1]).normF()
                        : 0.0;
                double posScore = Objects.equals(fullTags[i - 1], gappedTags[j - 1]) ? 0.0 : POS_MISMATCH_PENALTY;
                double match = (score[i - 1][j - 1] - distance) + posScore;
                double skipFull = score[i - 1][j] + GAP_PENALTY;   // full argument i-1 left unaligned
                double skipGapped = score[i][j - 1] + GAP_PENALTY; // gapped argument j-1 left unaligned
                // Tie-breaking: prefer a match, then skipping a full argument, then skipping a gapped one.
                if (match >= skipFull && match >= skipGapped) {
                    back[i][j][0] = i - 1;
                    back[i][j][1] = j - 1;
                } else if (skipFull > match && skipFull >= skipGapped) {
                    back[i][j][0] = i - 1;
                    back[i][j][1] = j;
                } else {
                    back[i][j][0] = i;
                    back[i][j][1] = j - 1;
                }
                score[i][j] = Math.max(match, Math.max(skipFull, skipGapped));
            }
        }

        // Trace back from (n, m). Every gapped index is consumed exactly once, either by a diagonal
        // (match) step or a horizontal (gap) step; only matches overwrite the default of -1.
        List<Integer> alignment = new ArrayList<>(Collections.nCopies(m, -1));
        int i = n;
        int j = m;
        while (i > 0 || j > 0) {
            int prevI = back[i][j][0];
            int prevJ = back[i][j][1];
            if (prevI == i - 1 && prevJ == j - 1) {
                alignment.set(prevJ, prevI);
            }
            i = prevI;
            j = prevJ;
        }
        return new Pair<>(score[n][m], alignment);
    }

    /**
     * Finds the leftmost node (by index) heading {@code orphan} dependents, with one of its governors.
     *
     * @return (conjGov, orphanGov), or {@code null} if the graph has no {@code orphan} edges
     */
    private static Pair<IndexedWord, IndexedWord> getConjGovOrphanGovPair(SemanticGraph semanticGraph) {
        SemgrexMatcher matcher = ORPHAN_PATTERN.matcher(semanticGraph);
        IndexedWord conjGov = null;
        IndexedWord orphanGov = null;
        while (matcher.find()) {
            IndexedWord candidate = matcher.getNode("orphangov");
            if (orphanGov == null || orphanGov.index() > candidate.index()) {
                conjGov = matcher.getNode("conjgov");
                orphanGov = candidate;
            }
        }
        return orphanGov == null ? null : new Pair<>(conjGov, orphanGov);
    }

    /** True if {@code word} has at least one outgoing {@code orphan} edge. */
    private static boolean hasOrphanDependent(SemanticGraph semanticGraph, IndexedWord word) {
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(word)) {
            if (edge.getRelation().equals(UniversalEnglishGrammaticalRelations.ORPHAN)) {
                return true;
            }
        }
        return false;
    }

    /** True if the edge's relation is an argument relation and its dependent heads no orphans. */
    private static boolean isArgument(SemanticGraph semanticGraph, SemanticGraphEdge semanticGraphEdge) {
        return ARGUMENT_PATTERN.matcher(semanticGraphEdge.getRelation().toString()).matches()
                && !hasOrphanDependent(semanticGraph, semanticGraphEdge.getDependent());
    }

    /** True if the edge's relation is a clausal argument relation and its dependent heads no orphans. */
    private static boolean isClausalArgument(SemanticGraph semanticGraph, SemanticGraphEdge semanticGraphEdge) {
        return CLAUSAL_ARGUMENT_PATTERN.matcher(semanticGraphEdge.getRelation().toString()).matches()
                && !hasOrphanDependent(semanticGraph, semanticGraphEdge.getDependent());
    }

    /**
     * Appends to {@code out} every argument of {@code word}.
     *
     * <p>Recursion descends into arguments that are themselves clausal arguments.
     */
    private static void getArgumentSubsequences(SemanticGraph semanticGraph, IndexedWord word, List<ArgumentSequence> out) {
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(word)) {
            if (isArgument(semanticGraph, edge)) {
                IndexedWord dependent = edge.getDependent();
                out.add(new ArgumentSequence(dependent, semanticGraph.yield(dependent)));
                if (isClausalArgument(semanticGraph, edge)) {
                    getArgumentSubsequences(semanticGraph, dependent, out);
                }
            }
        }
    }

    /**
     * Builds one group per argument of {@code conjGov} that precedes {@code orphanGov}.
     *
     * <p>Each group holds the argument itself followed by the arguments nested under it (see
     * {@link #getArgumentSubsequences}). Groups are sorted by the index of their first head.
     */
    private static List<List<ArgumentSequence>> getFullConjunctArgumentsHelper(SemanticGraph semanticGraph, IndexedWord conjGov, IndexedWord orphanGov) {
        List<List<ArgumentSequence>> groups = new ArrayList<>();
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(conjGov)) {
            IndexedWord dependent = edge.getDependent();
            if (isArgument(semanticGraph, edge) && dependent.pseudoPosition() < orphanGov.pseudoPosition()) {
                List<ArgumentSequence> group = new ArrayList<>();
                group.add(new ArgumentSequence(dependent, semanticGraph.yield(dependent)));
                getArgumentSubsequences(semanticGraph, dependent, group);
                groups.add(group);
            }
        }
        groups.sort((a, b) -> Integer.compare(a.get(0).head.index(), b.get(0).head.index()));
        return groups;
    }

    /**
     * Appends to {@code out} the Cartesian product of {@code groups[groupIndex..]}, each prefixed
     * by {@code prefix}.
     *
     * <p>Exactly one element is chosen from each group.
     */
    private static void buildAllArgumentSequences(int groupIndex, List<ArgumentSequence> prefix, List<List<ArgumentSequence>> groups, List<List<ArgumentSequence>> out) {
        int groupCount = groups.size();
        for (ArgumentSequence choice : groups.get(groupIndex)) {
            List<ArgumentSequence> extended = new ArrayList<>(groupCount);
            extended.addAll(prefix);
            extended.add(choice);
            if (groupIndex + 1 == groupCount) {
                out.add(extended);
            } else {
                buildAllArgumentSequences(groupIndex + 1, extended, groups, out);
            }
        }
    }

    /**
     * Returns every candidate argument sequence of the full conjunct headed by {@code conjGov}.
     *
     * <p>Candidates are limited to arguments before {@code orphanGov}. The result is empty when
     * there are no such arguments.
     */
    private static List<List<ArgumentSequence>> getFullConjunctArguments(SemanticGraph semanticGraph, IndexedWord conjGov, IndexedWord orphanGov) {
        List<List<ArgumentSequence>> groups = getFullConjunctArgumentsHelper(semanticGraph, conjGov, orphanGov);
        List<List<ArgumentSequence>> combinations = new ArrayList<>();
        if (!groups.isEmpty()) {
            buildAllArgumentSequences(0, new ArrayList<>(), groups, combinations);
        }
        return combinations;
    }

    /** True if the edge's relation is one whose yield belongs to the orphan governor's own sequence. */
    private static boolean isModifier(SemanticGraphEdge semanticGraphEdge) {
        return MODIFIER_PATTERN.matcher(semanticGraphEdge.getRelation().toString()).matches();
    }

    /** The orphan governor plus the yields of its modifier dependents, sorted. */
    private static ArgumentSequence getOrphanGovSequence(SemanticGraph semanticGraph, IndexedWord orphanGov) {
        List<IndexedWord> words = new ArrayList<>();
        words.add(orphanGov);
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(orphanGov)) {
            if (isModifier(edge)) {
                words.addAll(semanticGraph.yield(edge.getDependent()));
            }
        }
        Collections.sort(words);
        return new ArgumentSequence(orphanGov, words);
    }

    /** The gapped conjunct's arguments (each orphan plus the orphan governor itself), sorted by head. */
    private static List<ArgumentSequence> getGappedConjunctArguments(SemanticGraph semanticGraph, IndexedWord orphanGov) {
        List<ArgumentSequence> arguments = new ArrayList<>();
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(orphanGov)) {
            if (edge.getRelation().equals(UniversalEnglishGrammaticalRelations.ORPHAN)) {
                IndexedWord dependent = edge.getDependent();
                arguments.add(new ArgumentSequence(dependent, semanticGraph.yield(dependent)));
            }
        }
        arguments.add(getOrphanGovSequence(semanticGraph, orphanGov));
        arguments.sort((a, b) -> a.head.compareTo(b.head));
        return arguments;
    }

    /** Makes a soft copy of {@code word} whose pseudo-position is offset by copyCount / 10. */
    private static IndexedWord makePositionedCopy(IndexedWord word) {
        IndexedWord copy = word.makeSoftCopy();
        copy.setPseudoPosition(copy.pseudoPosition() + (copy.copyCount() / 10.0));
        return copy;
    }

    /** Returns the {@code orphan} edge from {@code gov} to {@code dep}, or {@code null} if there is none. */
    private static SemanticGraphEdge findOrphanEdge(SemanticGraph semanticGraph, IndexedWord gov, IndexedWord dep) {
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(gov)) {
            if (edge.getDependent().equals(dep) && edge.getRelation().equals(UniversalEnglishGrammaticalRelations.ORPHAN)) {
                return edge;
            }
        }
        return null;
    }

    /**
     * Mirrors {@code path} (from the full conjunct's predicate to one of its arguments) inside the
     * gapped conjunct, then attaches {@code gappedHead} to the copy of the last governor.
     *
     * <p>Nodes are copied at most once; {@code copies} maps originals to their copies. The last
     * edge uses the relation of the final path edge. An intermediate edge is only added when this
     * step created a new copy, so edges between previously copied nodes are not duplicated.
     */
    private static void copyPath(SemanticGraph semanticGraph, List<SemanticGraphEdge> path, IndexedWord gappedHead, Map<IndexedWord, IndexedWord> copies) {
        for (int step = 0; step < path.size(); step++) {
            SemanticGraphEdge edge = path.get(step);
            boolean isLast = step == path.size() - 1;
            boolean createdCopy = false;

            IndexedWord govCopy = copies.get(edge.getGovernor());
            if (govCopy == null) {
                govCopy = makePositionedCopy(edge.getGovernor());
                copies.put(edge.getGovernor(), govCopy);
                createdCopy = true;
            }

            // The final dependent is the gapped argument itself, not a copy of the aligned argument.
            IndexedWord depCopy = isLast ? gappedHead : copies.get(edge.getDependent());
            if (depCopy == null) {
                depCopy = makePositionedCopy(edge.getDependent());
                copies.put(edge.getDependent(), depCopy);
                createdCopy = true;
            }

            if (isLast || createdCopy) {
                semanticGraph.addEdge(govCopy, depCopy, edge.getRelation(), EDGE_WEIGHT, false);
            }
        }
    }

    /**
     * Rewrites one gapped conjunct in place according to {@code alignment}.
     *
     * @param conjGov predicate of the full conjunct
     * @param orphanGov node of the gapped conjunct that heads the {@code orphan} edges
     * @param fullArguments the chosen argument sequence of the full conjunct
     * @param gappedArguments arguments of the gapped conjunct
     * @param alignment for each gapped argument, the index of its aligned full argument or -1
     */
    private static void doEnhancement(SemanticGraph semanticGraph, IndexedWord conjGov, IndexedWord orphanGov, List<ArgumentSequence> fullArguments, List<ArgumentSequence> gappedArguments, List<Integer> alignment) {
        // Original node -> copy in the gapped conjunct. A HashMap, as before, so iteration order is unchanged.
        Map<IndexedWord, IndexedWord> copies = new HashMap<>();

        // 1. Replace conjGov -> orphanGov with conjGov -> copy of conjGov, using the same relation.
        IndexedWord predicateCopy = makePositionedCopy(conjGov);
        SemanticGraphEdge conjEdge = semanticGraph.getEdge(conjGov, orphanGov);
        semanticGraph.removeEdge(conjEdge);
        semanticGraph.addEdge(conjGov, predicateCopy, conjEdge.getRelation(), EDGE_WEIGHT, false);
        copies.put(conjGov, predicateCopy);

        // 2. Detach each gapped argument from its orphan edge and re-attach it inside the copy.
        for (int k = 0; k < gappedArguments.size(); k++) {
            IndexedWord gappedHead = gappedArguments.get(k).head;
            SemanticGraphEdge orphanEdge = findOrphanEdge(semanticGraph, orphanGov, gappedHead);
            if (orphanEdge != null) {
                semanticGraph.removeEdge(orphanEdge);
            }

            int aligned = alignment.get(k);
            List<SemanticGraphEdge> path = aligned < 0
                    ? null
                    : semanticGraph.getShortestDirectedPathEdges(conjGov, fullArguments.get(aligned).head);
            if (path == null || path.isEmpty()) {
                // Unaligned (or unreachable) arguments hang off the predicate copy with a generic relation.
                semanticGraph.addEdge(predicateCopy, gappedHead, GrammaticalRelation.DEPENDENT, EDGE_WEIGHT, false);
            } else {
                copyPath(semanticGraph, path, gappedHead, copies);
            }
        }

        // 3. Give every copy the core-argument dependents of its original that it still lacks.
        for (Map.Entry<IndexedWord, IndexedWord> entry : copies.entrySet()) {
            IndexedWord copy = entry.getValue();
            for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeIterable(entry.getKey())) {
                GrammaticalRelation relation = edge.getRelation();
                if (CORE_ARGUMENTS_PATTERN.matcher(relation.toString()).matches()
                        && !semanticGraph.hasChildWithReln(copy, relation)) {
                    semanticGraph.addEdge(copy, edge.getDependent(), relation, EDGE_WEIGHT, false);
                }
            }
        }

        // 4. Move conjuncts that were attached to orphanGov, and that follow another dependent of
        //    the copy, up to the copy. Matching runs on a snapshot because the live graph is edited.
        SemanticGraph snapshot = semanticGraph.makeSoftCopy();
        for (IndexedWord copy : copies.values()) {
            SemgrexMatcher matcher = CONJ_PATTERN.matcher(snapshot, copy);
            while (matcher.find()) {
                IndexedWord predicate = matcher.getNode("predicate");
                IndexedWord arg1 = matcher.getNode("arg1");
                IndexedWord conjDep = matcher.getNode("conjdep");
                IndexedWord arg2 = matcher.getNode("arg2");
                if (predicate == copy && arg1 == orphanGov
                        && arg2.pseudoPosition() > arg1.pseudoPosition()
                        && conjDep.pseudoPosition() > arg2.pseudoPosition()) {
                    // Several arg2 can match the same (arg1, conjDep); move the edge only once.
                    SemanticGraphEdge movedEdge = semanticGraph.getEdge(arg1, conjDep);
                    if (movedEdge != null) {
                        semanticGraph.removeEdge(movedEdge);
                        semanticGraph.addEdge(predicate, conjDep, UniversalEnglishGrammaticalRelations.CONJUNCT, EDGE_WEIGHT, false);
                    }
                }
            }
        }

        // 5. Move cc dependents of orphanGov to the predicate copy (outgoingEdgeList is a copy, so removal is safe).
        for (SemanticGraphEdge edge : semanticGraph.outgoingEdgeList(orphanGov)) {
            if (edge.getRelation().getShortName().equals("cc")) {
                semanticGraph.removeEdge(edge);
                semanticGraph.addEdge(predicateCopy, edge.getDependent(), edge.getRelation(), EDGE_WEIGHT, false);
            }
        }
    }

    /**
     * Resolves gapping constructions in {@code semanticGraph} in place, leftmost orphan governor
     * first.
     *
     * <p>At most {@code MAX_ITERATIONS - 1} gaps are resolved. The graph is printed to stderr if
     * more remain, or if a gap has no candidate alignment (its full conjunct has no arguments before
     * the orphan governor).
     *
     * @param semanticGraph the graph to enhance; modified in place
     * @param embedding word embeddings used to score argument alignments, or {@code null} to score
     *     on coarsened UPOS tags alone
     */
    public static final void addEnhancements(SemanticGraph semanticGraph, Embedding embedding) {
        int iterations = 0;
        boolean unresolvable = false;
        Pair<IndexedWord, IndexedWord> pair;
        while ((pair = getConjGovOrphanGovPair(semanticGraph)) != null) {
            iterations++;
            if (iterations >= MAX_ITERATIONS) {
                break;
            }
            IndexedWord conjGov = pair.first();
            IndexedWord orphanGov = pair.second();
            List<List<ArgumentSequence>> candidates = getFullConjunctArguments(semanticGraph, conjGov, orphanGov);
            List<ArgumentSequence> gappedArguments = getGappedConjunctArguments(semanticGraph, orphanGov);

            List<ArgumentSequence> bestFull = null;
            List<Integer> bestAlignment = null;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (List<ArgumentSequence> candidate : candidates) {
                Pair<Double, List<Integer>> result = align(candidate, gappedArguments, embedding);
                if (result.first() > bestScore) {
                    bestScore = result.first();
                    bestAlignment = result.second();
                    bestFull = candidate;
                }
            }

            if (bestFull == null) {
                // The graph would be unchanged, so every further round would find this same gap again.
                unresolvable = true;
                break;
            }
            doEnhancement(semanticGraph, conjGov, orphanGov, bestFull, gappedArguments, bestAlignment);
        }
        if (unresolvable || iterations == MAX_ITERATIONS) {
            System.err.println("Problem with graph:");
            System.err.println(semanticGraph.toString(SemanticGraph.OutputFormat.READABLE));
        }
    }
}
