package edu.stanford.nlp.coref.neural;

import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/coref/neural/EmbeddingExtractor.class */
public class EmbeddingExtractor implements Serializable {
    private static final long serialVersionUID = -663338564691488202L;
    private final boolean conll;
    private final Embedding staticWordEmbeddings;
    private final Embedding tunedWordEmbeddings;
    private final String naEmbedding;

    public EmbeddingExtractor(boolean z, Embedding embedding, Embedding embedding2, String str) {
        this.conll = z;
        this.staticWordEmbeddings = embedding;
        this.tunedWordEmbeddings = embedding2;
        this.naEmbedding = str;
    }

    public boolean isConll() {
        return this.conll;
    }

    public Embedding getStaticWordEmbeddings() {
        return this.staticWordEmbeddings;
    }

    public Embedding getTunedWordEmbeddings() {
        return this.tunedWordEmbeddings;
    }

    public String getNAEmbedding() {
        return this.naEmbedding;
    }

    public SimpleMatrix getDocumentEmbedding(Document document) {
        if (!this.conll) {
            return new SimpleMatrix(this.staticWordEmbeddings.getEmbeddingSize(), 1);
        }
        ArrayList arrayList = new ArrayList();
        HashSet hashSet = new HashSet();
        for (Mention mention : document.predictedMentionsByID.values()) {
            if (!hashSet.contains(Integer.valueOf(mention.sentNum))) {
                hashSet.add(Integer.valueOf(mention.sentNum));
                arrayList.addAll(mention.sentenceWords);
            }
        }
        return getAverageEmbedding(arrayList);
    }

    public SimpleMatrix getMentionEmbeddingsForFast(Mention mention) {
        Iterator<SemanticGraphEdge> incomingEdgeIterator = mention.enhancedDependency.incomingEdgeIterator(mention.headIndexedWord);
        SemanticGraphEdge next = incomingEdgeIterator.hasNext() ? incomingEdgeIterator.next() : null;
        return NeuralUtils.concatenate(getWordEmbedding(mention.sentenceWords, mention.startIndex - 2), getWordEmbedding(mention.sentenceWords, mention.startIndex - 1), getWordEmbedding(mention.sentenceWords, mention.startIndex), getWordEmbedding(mention.sentenceWords, mention.headIndex), getWordEmbedding(mention.sentenceWords, mention.endIndex - 1), getWordEmbedding(mention.sentenceWords, mention.endIndex), getWordEmbedding(mention.sentenceWords, mention.endIndex + 1), getWordEmbedding(next == null ? "<missing>" : next.getSource().word()), getAverageEmbedding(mention.sentenceWords.subList(mention.startIndex, Math.min(mention.endIndex, mention.startIndex + 10))));
    }

    public SimpleMatrix getMentionEmbeddings(Mention mention, SimpleMatrix simpleMatrix) {
        Iterator<SemanticGraphEdge> incomingEdgeIterator = mention.enhancedDependency.incomingEdgeIterator(mention.headIndexedWord);
        SemanticGraphEdge next = incomingEdgeIterator.hasNext() ? incomingEdgeIterator.next() : null;
        SimpleMatrix[] simpleMatrixArr = new SimpleMatrix[13];
        simpleMatrixArr[0] = getAverageEmbedding(mention.sentenceWords, mention.startIndex, mention.endIndex);
        simpleMatrixArr[1] = getAverageEmbedding(mention.sentenceWords, mention.startIndex - 5, mention.startIndex);
        simpleMatrixArr[2] = getAverageEmbedding(mention.sentenceWords, mention.endIndex, mention.endIndex + 5);
        simpleMatrixArr[3] = getAverageEmbedding(mention.sentenceWords.subList(0, mention.sentenceWords.size() - 1));
        simpleMatrixArr[4] = simpleMatrix;
        simpleMatrixArr[5] = getWordEmbedding(mention.sentenceWords, mention.headIndex);
        simpleMatrixArr[6] = getWordEmbedding(mention.sentenceWords, mention.startIndex);
        simpleMatrixArr[7] = getWordEmbedding(mention.sentenceWords, mention.endIndex - 1);
        simpleMatrixArr[8] = getWordEmbedding(mention.sentenceWords, mention.startIndex - 1);
        simpleMatrixArr[9] = getWordEmbedding(mention.sentenceWords, mention.endIndex);
        simpleMatrixArr[10] = getWordEmbedding(mention.sentenceWords, mention.startIndex - 2);
        simpleMatrixArr[11] = getWordEmbedding(mention.sentenceWords, mention.endIndex + 1);
        simpleMatrixArr[12] = getWordEmbedding(next == null ? null : next.getSource().word());
        return NeuralUtils.concatenate(simpleMatrixArr);
    }

    private SimpleMatrix getAverageEmbedding(List<CoreLabel> list) {
        Embedding embedding = this.staticWordEmbeddings == null ? this.tunedWordEmbeddings : this.staticWordEmbeddings;
        SimpleMatrix simpleMatrix = new SimpleMatrix(embedding.getEmbeddingSize(), 1);
        Iterator<CoreLabel> it = list.iterator();
        while (it.hasNext()) {
            simpleMatrix = simpleMatrix.plus(embedding.get(normalizeWord(it.next().word())));
        }
        return simpleMatrix.divide(Math.max(1, list.size()));
    }

    private SimpleMatrix getAverageEmbedding(List<CoreLabel> list, int i, int i2) {
        return getAverageEmbedding(list.subList(Math.max(Math.min(i, list.size() - 1), 0), Math.max(Math.min(i2, list.size() - 1), 0)));
    }

    private SimpleMatrix getWordEmbedding(List<CoreLabel> list, int i) {
        return getWordEmbedding((i < 0 || i >= list.size()) ? this.naEmbedding : list.get(i).word());
    }

    public SimpleMatrix getWordEmbedding(String str) {
        String normalizeWord = normalizeWord(str);
        if (this.staticWordEmbeddings != null && !this.tunedWordEmbeddings.containsWord(normalizeWord)) {
            return this.staticWordEmbeddings.get(normalizeWord);
        }
        return this.tunedWordEmbeddings.get(normalizeWord);
    }

    private static String normalizeWord(String str) {
        if (str == null) {
            return "<missing>";
        }
        if (str.equals("/.")) {
            return ".";
        }
        if (str.equals("/?")) {
            return "?";
        }
        if (str.equals("-LRB-")) {
            return "(";
        }
        if (str.equals("-RRB-")) {
            return ")";
        }
        if (str.equals("-LCB-")) {
            return "{";
        }
        if (str.equals("-RCB-")) {
            return "}";
        }
        if (str.equals("-LSB-")) {
            return "[";
        }
        if (str.equals("-RSB-")) {
            return "]";
        }
        if (str.equals("''")) {
            str = "\"";
        } else if (str.startsWith("%") && str.length() > 1) {
            str = str.substring(1);
        }
        return str.replaceAll("\\d", "0").toLowerCase();
    }
}
