package edu.stanford.nlp.coref.fastneural;

import edu.stanford.nlp.coref.CorefAlgorithm;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.statistical.Compressor;
import edu.stanford.nlp.coref.statistical.DocumentExamples;
import edu.stanford.nlp.coref.statistical.Example;
import edu.stanford.nlp.coref.statistical.FeatureExtractor;
import edu.stanford.nlp.coref.statistical.StatisticalCorefProperties;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/coref/fastneural/FastNeuralCorefAlgorithm.class */
/**
 * Coreference algorithm that scores candidate mention pairs with a {@link FastNeuralCorefModel}
 * and greedily links each mention to its single highest-scoring candidate antecedent, provided
 * that candidate beats the mention's (greedyness-adjusted) "start a new cluster" score.
 */
public class FastNeuralCorefAlgorithm implements CorefAlgorithm {
    private static final Redwood.RedwoodChannels log = Redwood.channels(FastNeuralCorefAlgorithm.class);

    /**
     * Placeholder antecedent id. It is used both as the first element of the key under which a
     * mention's stand-alone score is stored, and as the "nothing chosen yet" value during linking.
     */
    private static final int NO_ANTECEDENT = -1;

    /** Multiplier applied to {@code (greedyness - 0.5)} when offsetting the stand-alone score. */
    private static final double GREEDYNESS_SCALE = 50.0;

    private final double greedyness;
    private final int maxMentionDistance;
    private final int maxMentionDistanceWithStringMatch;
    private final FeatureExtractor featureExtractor;
    private final FastNeuralCorefModel model;

    /**
     * Reads the algorithm's settings from {@code properties}, builds the feature extractor and
     * loads the scoring model from {@link FastNeuralCorefProperties#modelPath(Properties)}.
     *
     * <p>NOTE(review): the IOUtils method name suggests the model is read with Java object
     * deserialization, so only trusted model files should be configured. Confirm this.
     *
     * @param properties coref configuration (greedyness, mention-distance limits, model and
     *     word-count paths)
     * @param dictionaries dictionaries handed to the {@link FeatureExtractor}
     */
    public FastNeuralCorefAlgorithm(Properties properties, Dictionaries dictionaries) {
        this.greedyness = FastNeuralCorefProperties.greedyness(properties);
        this.maxMentionDistance = CorefProperties.maxMentionDistance(properties);
        this.maxMentionDistanceWithStringMatch = CorefProperties.maxMentionDistanceWithStringMatch(properties);
        // No compressor at construction time; a fresh one is supplied per document in runCoref.
        this.featureExtractor = new FeatureExtractor(properties, dictionaries, (Compressor<String>) null,
                StatisticalCorefProperties.wordCountsPath(properties));
        this.model = (FastNeuralCorefModel) IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
                log, "Loading coref model...", FastNeuralCorefProperties.modelPath(properties));
    }

    /**
     * Runs coreference on {@code document}.
     *
     * <p>The method works in three steps:
     * <ol>
     *   <li>It filters candidate antecedents for each mention by distance and string match.</li>
     *   <li>It scores every (antecedent, anaphor) pair, and each anaphor on its own.</li>
     *   <li>It merges each anaphor's cluster with that of its best antecedent, if any.</li>
     * </ol>
     * The document's clusters are modified in place.
     *
     * @param document the document whose predicted mentions are clustered
     * @throws RuntimeInterruptedException if the current thread is interrupted while scoring
     */
    @Override // edu.stanford.nlp.coref.CorefAlgorithm
    public void runCoref(Document document) {
        // anaphor id -> ids of the candidate antecedents that survive the heuristic filter
        Map<Integer, List<Integer>> candidateAntecedents = CorefUtils.heuristicFilter(
                CorefUtils.getSortedMentions(document), maxMentionDistance, maxMentionDistanceWithStringMatch);
        Compressor<String> compressor = new Compressor<>();
        DocumentExamples examples =
                featureExtractor.extract(0, document, toMentionPairs(candidateAntecedents), compressor);
        ClassicCounter<Pair<Integer, Integer>> scores =
                scoreMentions(document, candidateAntecedents, examples, compressor);
        linkToBestAntecedents(document, candidateAntecedents, scores);
    }

    /** Flattens the candidate map into (antecedent, anaphor) pairs, each mapped to {@code true}. */
    private static Map<Pair<Integer, Integer>, Boolean> toMentionPairs(
            Map<Integer, List<Integer>> candidateAntecedents) {
        Map<Pair<Integer, Integer>, Boolean> mentionPairs = new HashMap<>();
        for (Map.Entry<Integer, List<Integer>> entry : candidateAntecedents.entrySet()) {
            for (int antecedent : entry.getValue()) {
                mentionPairs.put(new Pair<>(antecedent, entry.getKey()), Boolean.TRUE);
            }
        }
        return mentionPairs;
    }

    /**
     * Scores every extracted example pair under key {@code (mentionId1, mentionId2)}. It then
     * scores every anaphor alone under key {@code (NO_ANTECEDENT, anaphor)}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private ClassicCounter<Pair<Integer, Integer>> scoreMentions(Document document,
            Map<Integer, List<Integer>> candidateAntecedents, DocumentExamples examples,
            Compressor<String> compressor) {
        ClassicCounter<Pair<Integer, Integer>> scores = new ClassicCounter<>();
        // Two maps passed positionally to every model.score call for this document; their
        // contents are managed by the model. NOTE(review): kept raw because their value type
        // belongs to the model and is not imported in this file.
        Map modelCache1 = new HashMap();
        Map modelCache2 = new HashMap();

        for (Example example : examples.examples) {
            checkInterrupted();
            int first = example.mentionId1;
            int second = example.mentionId2;
            double score = model.score(
                    document.predictedMentionsByID.get(first),
                    document.predictedMentionsByID.get(second),
                    compressor.uncompress(examples.mentionFeatures.get(first)),
                    compressor.uncompress(examples.mentionFeatures.get(second)),
                    compressor.uncompress(example.pairwiseFeatures),
                    modelCache1, modelCache2);
            scores.incrementCount(new Pair<>(first, second), score);
        }

        for (int anaphor : candidateAntecedents.keySet()) {
            checkInterrupted();
            double score = model.score(
                    null,
                    document.predictedMentionsByID.get(anaphor),
                    null,
                    compressor.uncompress(examples.mentionFeatures.get(anaphor)),
                    null,
                    modelCache1, modelCache2);
            scores.incrementCount(new Pair<>(NO_ANTECEDENT, anaphor), score);
        }
        return scores;
    }

    /**
     * For each anaphor, picks the candidate with the strictly highest pair score that also beats
     * the anaphor's stand-alone score minus the greedyness offset. It then merges the two clusters.
     * Pairs missing from {@code scores} count as 0 (the counter's default).
     */
    private void linkToBestAntecedents(Document document, Map<Integer, List<Integer>> candidateAntecedents,
            ClassicCounter<Pair<Integer, Integer>> scores) {
        // A greedyness above 0.5 lowers the bar candidates must clear, so more mentions get linked.
        double newClusterOffset = GREEDYNESS_SCALE * (greedyness - 0.5);
        for (Map.Entry<Integer, List<Integer>> entry : candidateAntecedents.entrySet()) {
            int anaphor = entry.getKey();
            int bestAntecedent = NO_ANTECEDENT;
            double bestScore = scores.getCount(new Pair<>(NO_ANTECEDENT, anaphor)) - newClusterOffset;
            for (int candidate : entry.getValue()) {
                double score = scores.getCount(new Pair<>(candidate, anaphor));
                if (score > bestScore) {
                    bestScore = score;
                    bestAntecedent = candidate;
                }
            }
            // Compare against the sentinel itself. The previous "> 0" test silently dropped a
            // winning antecedent whose mention id is 0.
            if (bestAntecedent != NO_ANTECEDENT) {
                CorefUtils.mergeCoreferenceClusters(new Pair<>(bestAntecedent, anaphor), document);
            }
        }
    }

    /** Throws if the current thread was interrupted; like the original, this clears the flag. */
    private static void checkInterrupted() {
        if (Thread.interrupted()) {
            throw new RuntimeInterruptedException();
        }
    }
}
