package edu.stanford.nlp.parser.metrics;

import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.trees.Constituent;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.regex.Pattern;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.0.0.jar:edu/stanford/nlp/parser/metrics/EvalbByCat.class */
/**
 * Labeled-bracketing evaluation broken down by constituent category.
 *
 * <p>Constituents of the guess and gold trees are grouped by {@link Constituent#label()}. Two
 * families of per-category statistics are accumulated:
 *
 * <ul>
 *   <li>sentence-averaged scores ({@code precisions}, {@code recalls}, {@code f1s}): one score is
 *       added per evaluated sentence and the sum is divided by {@code num} when printed;
 *   <li>evalb-style pooled scores: {@code precisions2}/{@code recalls2} hold the per-sentence
 *       score weighted by the number of guessed/gold constituents, and {@code pnums2}/{@code
 *       rnums2} hold those constituent counts.
 * </ul>
 *
 * <p>Constituent extraction is delegated to an internal {@link Evalb} instance.
 */
public class EvalbByCat extends AbstractEval {
    private static final Redwood.RedwoodChannels log = Redwood.channels(EvalbByCat.class);

    /** Supplies the constituent sets for each tree. */
    private final Evalb evalb;
    /** Optional regex restricting which categories appear in {@link #display}; null = all. */
    private final Pattern pLabelFilter;

    // Sentence-averaged sums, one increment per sentence per category.
    private final Counter<Label> precisions;
    private final Counter<Label> recalls;
    private final Counter<Label> f1s;

    // Pooled (evalb-style) sums: score * set size, and the set sizes themselves.
    private final Counter<Label> precisions2;
    private final Counter<Label> recalls2;
    private final Counter<Label> pnums2;  // total guessed constituents per category
    private final Counter<Label> rnums2;  // total gold constituents per category

    /**
     * Creates an evaluator that reports every category.
     *
     * @param str name passed to {@link AbstractEval} and used as the F1 line prefix
     * @param runningAverages whether {@link #evaluate} prints running per-sentence statistics
     */
    public EvalbByCat(String str, boolean runningAverages) {
        this(str, runningAverages, null);
    }

    /**
     * Creates an evaluator whose final report is restricted to categories whose label value fully
     * matches {@code labelRegex}.
     *
     * @param str name passed to {@link AbstractEval} and used as the F1 line prefix
     * @param runningAverages whether {@link #evaluate} prints running per-sentence statistics
     * @param labelRegex category filter for {@link #display}, trimmed before compiling; null
     *     disables filtering
     */
    public EvalbByCat(String str, boolean runningAverages, String labelRegex) {
        super(str, runningAverages);
        this.evalb = new Evalb(str, false);
        this.precisions = new ClassicCounter<>();
        this.recalls = new ClassicCounter<>();
        this.f1s = new ClassicCounter<>();
        this.precisions2 = new ClassicCounter<>();
        this.recalls2 = new ClassicCounter<>();
        this.pnums2 = new ClassicCounter<>();
        this.rnums2 = new ClassicCounter<>();
        this.pLabelFilter = (labelRegex == null) ? null : Pattern.compile(labelRegex.trim());
    }

    /** Delegates constituent extraction to the wrapped {@link Evalb}. */
    @Override // edu.stanford.nlp.parser.metrics.AbstractEval
    protected Set<Constituent> makeObjects(Tree tree) {
        return this.evalb.makeObjects(tree);
    }

    /** Groups the constituents of {@code tree} by their label. */
    private Map<Label, Set<Constituent>> makeObjectsByCat(Tree tree) {
        Map<Label, Set<Constituent>> byCat = Generics.newHashMap();
        for (Constituent constituent : makeObjects(tree)) {
            byCat.computeIfAbsent(constituent.label(), k -> Generics.newHashSet()).add(constituent);
        }
        return byCat;
    }

    /**
     * Scores one guess/gold tree pair per category and accumulates the statistics.
     *
     * <p>Running averages are printed only when {@code pw} is non-null and {@code runningAverages}
     * is set. A null tree is reported on stderr and ignored.
     *
     * @param guess the tree being evaluated (its constituent counts are reported as "guessed")
     * @param gold the reference tree (its constituent counts are reported as "gold")
     * @param pw destination for running statistics, may be null
     */
    @Override // edu.stanford.nlp.parser.metrics.AbstractEval, edu.stanford.nlp.parser.metrics.Eval
    public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
        if (gold == null || guess == null) {
            System.err.printf("%s: Cannot compare against a null gold or guess tree!%n", getClass().getName());
            return;
        }
        Map<Label, Set<Constituent>> guessByCat = makeObjectsByCat(guess);
        Map<Label, Set<Constituent>> goldByCat = makeObjectsByCat(gold);
        Set<Label> allLabels = Generics.newHashSet(guessByCat.keySet());
        allLabels.addAll(goldByCat.keySet());

        boolean printing = pw != null && this.runningAverages;
        if (printing) {
            pw.println("========================================");
            pw.println("Labeled Bracketed Evaluation by Category");
            pw.println("========================================");
        }
        this.num += 1.0d;
        for (Label label : allLabels) {
            // A category missing from one tree is scored against an empty (fresh) set.
            Set<Constituent> guessed = guessByCat.getOrDefault(label, Generics.newHashSet());
            Set<Constituent> golds = goldByCat.getOrDefault(label, Generics.newHashSet());
            // NOTE(review): assumes AbstractEval.precision(a, b) is the fraction of a found in b,
            // so (guess, gold) gives precision and (gold, guess) gives recall -- confirm.
            double precision = precision(guessed, golds);
            double recall = precision(golds, guessed);
            double f1 = (precision <= 0.0d || recall <= 0.0d) ? 0.0d : 2.0d / ((1.0d / precision) + (1.0d / recall));

            this.precisions.incrementCount(label, precision);
            this.recalls.incrementCount(label, recall);
            this.f1s.incrementCount(label, f1);
            this.precisions2.incrementCount(label, guessed.size() * precision);
            this.pnums2.incrementCount(label, guessed.size());
            this.recalls2.incrementCount(label, golds.size() * recall);
            this.rnums2.incrementCount(label, golds.size());

            if (printing) {
                printRunningAverages(pw, label, precision, recall, f1);
            }
        }
        if (printing) {
            pw.println("========================================");
        }
    }

    /** Prints this sentence's scores for {@code label} next to the sentence-averaged and pooled ones. */
    private void printRunningAverages(PrintWriter pw, Label label, double precision, double recall, double f1) {
        double pooledF1 = 2.0d / ((this.rnums2.getCount(label) / this.recalls2.getCount(label))
            + (this.pnums2.getCount(label) / this.precisions2.getCount(label)));
        pw.println(label + "\tP: " + truncPct(precision, 1.0d)
            + " (sent ave " + truncPct(this.precisions.getCount(label), this.num)
            + ") (evalb " + truncPct(this.precisions2.getCount(label), this.pnums2.getCount(label)) + ")");
        pw.println("\tR: " + truncPct(recall, 1.0d)
            + " (sent ave " + truncPct(this.recalls.getCount(label), this.num)
            + ") (evalb " + truncPct(this.recalls2.getCount(label), this.rnums2.getCount(label)) + ")");
        pw.println(this.str + " F1: " + truncPct(f1, 1.0d)
            + " (sent ave " + truncPct(this.f1s.getCount(label), this.num)
            + ", evalb " + truncPct(pooledF1, 1.0d) + ")");
    }

    /**
     * Converts {@code numerator / denominator} to a percentage truncated (not rounded) to two
     * decimals; NaN truncates to 0 via the int cast.
     */
    private static double truncPct(double numerator, double denominator) {
        return ((int) ((numerator * 10000.0d) / denominator)) / 100.0d;
    }

    /**
     * Returns a copy of {@code labels}, keeping only those whose value fully matches the label
     * filter when one was configured.
     */
    private Set<Label> getEvalLabelSet(Set<Label> labels) {
        if (this.pLabelFilter == null) {
            return Generics.newHashSet(labels);
        }
        Set<Label> filtered = Generics.newHashSet(labels.size());
        for (Label label : labels) {
            if (this.pLabelFilter.matcher(label.value()).matches()) {
                filtered.add(label);
            }
        }
        return filtered;
    }

    /**
     * Pooled F1 (as a fraction) used only to order the final report; NaN (e.g. from 0/0 counts)
     * maps to -1 so such categories sort first.
     */
    private double evalbF1SortKey(Label label) {
        double precision = this.precisions2.getCount(label) / this.pnums2.getCount(label);
        double recall = this.recalls2.getCount(label) / this.rnums2.getCount(label);
        double f1 = 2.0d / ((1.0d / precision) + (1.0d / recall));
        return Double.isNaN(f1) ? -1.0d : f1;
    }

    /**
     * Returns the reported categories ordered by ascending pooled F1, with ties broken
     * deterministically by label value, so no category is ever dropped or shuffled.
     */
    private List<Label> sortedEvalLabels() {
        List<Label> labels = new ArrayList<>(getEvalLabelSet(this.precisions.keySet()));
        Map<Label, Double> sortKeys = Generics.newHashMap();
        for (Label label : labels) {
            sortKeys.put(label, evalbF1SortKey(label));
        }
        labels.sort(Comparator.comparingDouble((Label l) -> sortKeys.get(l))
            .thenComparing(l -> String.valueOf(l.value())));
        return labels;
    }

    /**
     * Prints the final per-category and total labeled precision, recall and F1 to {@code pw}.
     *
     * @param verbose ignored by this implementation
     * @param pw destination for the report (must be non-null)
     */
    @Override // edu.stanford.nlp.parser.metrics.AbstractEval, edu.stanford.nlp.parser.metrics.Eval
    public void display(boolean verbose, PrintWriter pw) {
        if (this.precisions.keySet().size() != this.recalls.keySet().size()) {
            log.error("Different counts for precisions and recalls!");
            return;
        }
        List<Label> orderedLabels = sortedEvalLabels();
        pw.println("============================================================");
        pw.println("Labeled Bracketed Evaluation by Category -- final statistics");
        pw.println("============================================================");
        double totalPrecNumerator = 0.0d;
        double totalGuessed = 0.0d;
        double totalRecallNumerator = 0.0d;
        double totalGold = 0.0d;
        for (Label label : orderedLabels) {
            double guessed = this.pnums2.getCount(label);
            double gold = this.rnums2.getCount(label);
            double lp = (this.precisions2.getCount(label) / guessed) * 100.0d;
            double lr = (this.recalls2.getCount(label) / gold) * 100.0d;
            double f1 = 2.0d / ((1.0d / lp) + (1.0d / lr));
            totalPrecNumerator += this.precisions2.getCount(label);
            totalGuessed += guessed;
            totalRecallNumerator += this.recalls2.getCount(label);
            totalGold += gold;
            pw.printf("%s\tLP: %s\tguessed: %d\tLR: %s\tgold: %d\t F1: %s%n",
                label.value(),
                guessed == 0.0d ? "N/A" : String.format("%.2f", lp),
                (int) guessed,
                gold == 0.0d ? "N/A" : String.format("%.2f", lr),
                (int) gold,
                (guessed == 0.0d || gold == 0.0d) ? "N/A" : String.format("%.2f", f1));
        }
        pw.println("============================================================");
        double totalPrecision = totalPrecNumerator / totalGuessed;
        double totalRecall = totalRecallNumerator / totalGold;
        // Precision and recall both zero: F1 is 0 by convention (as in evaluate), not 0/0 = NaN.
        double totalF1 = (totalPrecision + totalRecall == 0.0d)
            ? 0.0d
            : ((2.0d * totalPrecision) * totalRecall) / (totalPrecision + totalRecall);
        pw.printf("Total\tLP: %.2f\tguessed: %d\tLR: %.2f\tgold: %d\t F1: %.2f%n",
            totalPrecision * 100.0d, (int) totalGuessed, totalRecall * 100.0d, (int) totalGold, totalF1 * 100.0d);
        pw.println("============================================================");
    }
}
