package edu.stanford.nlp.sentiment;

import com.ibm.icu.text.PluralRules;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.neural.rnn.TopNGramRecord;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ConfusionMatrix;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.logging.Redwood;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.apache.xmpbox.schema.XMPBasicSchema;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/sentiment/AbstractEvaluate.class */
/**
 * Base class for evaluating predicted sentiment labels against gold labels on trees.
 *
 * <p>Subclasses implement {@link #populatePredictedLabels(List)} to annotate each tree with
 * predictions; this class then tallies node-level, root-level and per-length accuracy, builds
 * confusion matrices (rows indexed by gold class, columns by predicted class), and optionally
 * records top n-grams.
 */
public abstract class AbstractEvaluate {
    String[] equivalenceClassNames;
    int labelsCorrect;
    int labelsIncorrect;
    /** Node-level confusion counts, indexed {@code [gold][predicted]}. */
    int[][] labelConfusion;
    int rootLabelsCorrect;
    int rootLabelsIncorrect;
    /** Root-level confusion counts, indexed {@code [gold][predicted]}. */
    int[][] rootLabelConfusion;
    /** Correct-prediction counts keyed by span length (number of pre-terminals under a node). */
    IntCounter<Integer> lengthLabelsCorrect;
    /** Incorrect-prediction counts keyed by span length (number of pre-terminals under a node). */
    IntCounter<Integer> lengthLabelsIncorrect;
    /** N-gram recorder; {@code null} when {@code testOptions.ngramRecordSize <= 0} at reset time. */
    TopNGramRecord ngrams;
    static final int NUM_NGRAMS = 5;
    /** Groups of class indices treated as one class for the "approximate" accuracy figures. */
    int[][] equivalenceClasses;
    private final RNNOptions op;
    private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractEvaluate.class);
    protected static final NumberFormat NF = new DecimalFormat("0.000000");

    /**
     * Creates an evaluator and initializes all counters via {@link #reset()}.
     *
     * @param rNNOptions options supplying the number of classes, equivalence classes and test options
     */
    public AbstractEvaluate(RNNOptions rNNOptions) {
        this.op = rNNOptions;
        reset();
    }

    /**
     * Logs a confusion matrix.
     *
     * @param str name used as the header prefix (logged as "{str} confusion matrix")
     * @param iArr counts indexed {@code [gold][predicted]}; cell {@code [row][col]} is passed to
     *     {@code ConfusionMatrix.add(col, row, count)}
     */
    protected static void printConfusionMatrix(String str, int[][] iArr) {
        log.info(str + " confusion matrix");
        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        confusionMatrix.setUseRealLabels(true);
        for (int row = 0; row < iArr.length; row++) {
            for (int col = 0; col < iArr[row].length; col++) {
                confusionMatrix.add(Integer.valueOf(col), Integer.valueOf(row), iArr[row][col]);
            }
        }
        log.info("\n" + confusionMatrix);
    }

    /**
     * Computes, for each equivalence class, the fraction of gold instances of that class whose
     * prediction also fell anywhere inside the same class.
     *
     * @param iArr confusion counts indexed {@code [gold][predicted]}
     * @param iArr2 equivalence classes; each entry lists the class indices it groups together
     * @return one accuracy per equivalence class; {@code NaN} for a class with no gold instances
     */
    protected static double[] approxAccuracy(int[][] iArr, int[][] iArr2) {
        int[] correct = new int[iArr2.length];
        int[] total = new int[iArr2.length];
        double[] results = new double[iArr2.length];
        for (int i = 0; i < iArr2.length; i++) {
            for (int gold : iArr2[i]) {
                for (int predicted : iArr2[i]) {
                    correct[i] += iArr[gold][predicted];
                }
                for (int count : iArr[gold]) {
                    total[i] += count;
                }
            }
            // Floating-point division: integer division would truncate every ratio < 1 to 0
            // and throw ArithmeticException on an empty class.
            results[i] = (double) correct[i] / total[i];
        }
        return results;
    }

    /**
     * Computes a single accuracy over all equivalence classes combined: predictions landing in
     * the same equivalence class as the gold label count as correct.
     *
     * @param iArr confusion counts indexed {@code [gold][predicted]}
     * @param iArr2 equivalence classes; each entry lists the class indices it groups together
     * @return combined approximate accuracy; {@code NaN} when there are no gold instances
     */
    protected static double approxCombinedAccuracy(int[][] iArr, int[][] iArr2) {
        int correct = 0;
        int total = 0;
        for (int[] group : iArr2) {
            for (int gold : group) {
                for (int predicted : group) {
                    correct += iArr[gold][predicted];
                }
                for (int count : iArr[gold]) {
                    total += count;
                }
            }
        }
        // Floating-point division (see approxAccuracy).
        return (double) correct / total;
    }

    /**
     * Clears all counters and confusion matrices and re-reads equivalence classes and n-gram
     * settings from the options.
     */
    public void reset() {
        this.labelsCorrect = 0;
        this.labelsIncorrect = 0;
        this.labelConfusion = new int[this.op.numClasses][this.op.numClasses];
        this.rootLabelsCorrect = 0;
        this.rootLabelsIncorrect = 0;
        this.rootLabelConfusion = new int[this.op.numClasses][this.op.numClasses];
        this.lengthLabelsCorrect = new IntCounter<>();
        this.lengthLabelsIncorrect = new IntCounter<>();
        this.equivalenceClasses = this.op.equivalenceClasses;
        this.equivalenceClassNames = this.op.equivalenceClassNames;
        if (this.op.testOptions.ngramRecordSize > 0) {
            this.ngrams = new TopNGramRecord(this.op.numClasses, this.op.testOptions.ngramRecordSize, this.op.testOptions.ngramRecordMaximumLength);
        } else {
            this.ngrams = null;
        }
    }

    /**
     * Predicts labels for every tree (via {@link #populatePredictedLabels(List)}) and then
     * accumulates statistics for each one.
     *
     * @param list trees carrying gold-class annotations
     */
    public void eval(List<Tree> list) {
        populatePredictedLabels(list);
        for (Tree tree : list) {
            eval(tree);
        }
    }

    /**
     * Accumulates statistics for one tree whose predicted labels are already set.
     *
     * @param tree tree carrying both gold and predicted class annotations
     */
    public void eval(Tree tree) {
        countTree(tree);
        countRoot(tree);
        countLengthAccuracy(tree);
        if (this.ngrams != null) {
            this.ngrams.countTree(tree);
        }
    }

    /**
     * Recursively tallies correct/incorrect predictions keyed by span length.
     *
     * @param tree subtree to count
     * @return the span length of {@code tree}: 0 for a leaf, 1 for a pre-terminal, otherwise the
     *     sum over its children
     */
    protected int countLengthAccuracy(Tree tree) {
        if (tree.isLeaf()) {
            return 0;
        }
        int gold = RNNCoreAnnotations.getGoldClass(tree);
        int predicted = RNNCoreAnnotations.getPredictedClass(tree);
        int length;
        if (tree.isPreTerminal()) {
            length = 1;
        } else {
            length = 0;
            for (Tree child : tree.children()) {
                length += countLengthAccuracy(child);
            }
        }
        // A negative gold class marks a node without a gold label; skip it.
        if (gold >= 0) {
            if (gold == predicted) {
                this.lengthLabelsCorrect.incrementCount(Integer.valueOf(length));
            } else {
                this.lengthLabelsIncorrect.incrementCount(Integer.valueOf(length));
            }
        }
        return length;
    }

    /**
     * Recursively tallies node-level correctness and the node confusion matrix over every
     * non-leaf node of {@code tree} that has a non-negative gold class.
     *
     * @param tree subtree to count
     */
    protected void countTree(Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        for (Tree child : tree.children()) {
            countTree(child);
        }
        int gold = RNNCoreAnnotations.getGoldClass(tree);
        int predicted = RNNCoreAnnotations.getPredictedClass(tree);
        if (gold >= 0) {
            if (gold == predicted) {
                this.labelsCorrect++;
            } else {
                this.labelsIncorrect++;
            }
            this.labelConfusion[gold][predicted]++;
        }
    }

    /**
     * Tallies root-level correctness and the root confusion matrix for the root of {@code tree}
     * when it has a non-negative gold class.
     *
     * @param tree tree whose root node is counted
     */
    protected void countRoot(Tree tree) {
        int gold = RNNCoreAnnotations.getGoldClass(tree);
        int predicted = RNNCoreAnnotations.getPredictedClass(tree);
        if (gold >= 0) {
            if (gold == predicted) {
                this.rootLabelsCorrect++;
            } else {
                this.rootLabelsIncorrect++;
            }
            this.rootLabelConfusion[gold][predicted]++;
        }
    }

    /**
     * Returns the fraction of counted nodes whose predicted label matches the gold label.
     *
     * @return node accuracy; {@code NaN} if no nodes have been counted
     */
    public double exactNodeAccuracy() {
        // Floating-point division: integer division would return 0 for any accuracy < 1.
        return (double) this.labelsCorrect / (this.labelsCorrect + this.labelsIncorrect);
    }

    /**
     * Returns the fraction of counted roots whose predicted label matches the gold label.
     *
     * @return root accuracy; {@code NaN} if no roots have been counted
     */
    public double exactRootAccuracy() {
        // Floating-point division: integer division would return 0 for any accuracy < 1.
        return (double) this.rootLabelsCorrect / (this.rootLabelsCorrect + this.rootLabelsIncorrect);
    }

    /**
     * Builds a counter mapping each observed span length to the accuracy at that length.
     *
     * @return accuracy per span length
     */
    public Counter<Integer> lengthAccuracies() {
        Set<Integer> lengths = Generics.newHashSet();
        lengths.addAll(this.lengthLabelsCorrect.keySet());
        lengths.addAll(this.lengthLabelsIncorrect.keySet());
        ClassicCounter<Integer> accuracies = new ClassicCounter<>();
        for (Integer length : lengths) {
            double correct = this.lengthLabelsCorrect.getCount(length);
            double incorrect = this.lengthLabelsIncorrect.getCount(length);
            accuracies.setCount(length, correct / (correct + incorrect));
        }
        return accuracies;
    }

    /** Logs the per-length accuracies in increasing order of span length. */
    public void printLengthAccuracies() {
        Counter<Integer> lengthAccuracies = lengthAccuracies();
        TreeSet<Integer> sortedLengths = Generics.newTreeSet();
        sortedLengths.addAll(lengthAccuracies.keySet());
        log.info("Label accuracy at various lengths:");
        for (Integer length : sortedLengths) {
            log.info(StringUtils.padLeft(Integer.toString(length.intValue()), 4) + ": " + NF.format(lengthAccuracies.getCount(length)));
        }
    }

    /**
     * Logs counts, accuracies, confusion matrices, optional equivalence-class accuracies,
     * recorded n-grams, and (if enabled) per-length accuracies.
     */
    public void printSummary() {
        log.info("EVALUATION SUMMARY");
        log.info("Tested " + (this.labelsCorrect + this.labelsIncorrect) + " labels");
        log.info("  " + this.labelsCorrect + " correct");
        log.info("  " + this.labelsIncorrect + " incorrect");
        log.info("  " + NF.format(exactNodeAccuracy()) + " accuracy");
        log.info("Tested " + (this.rootLabelsCorrect + this.rootLabelsIncorrect) + " roots");
        log.info("  " + this.rootLabelsCorrect + " correct");
        log.info("  " + this.rootLabelsIncorrect + " incorrect");
        log.info("  " + NF.format(exactRootAccuracy()) + " accuracy");
        printConfusionMatrix("Label", this.labelConfusion);
        printConfusionMatrix("Root label", this.rootLabelConfusion);
        if (this.equivalenceClasses != null && this.equivalenceClassNames != null) {
            double[] approxAccuracy = approxAccuracy(this.labelConfusion, this.equivalenceClasses);
            for (int i = 0; i < this.equivalenceClassNames.length; i++) {
                log.info("Approximate " + this.equivalenceClassNames[i] + " label accuracy: " + NF.format(approxAccuracy[i]));
            }
            log.info("Combined approximate label accuracy: " + NF.format(approxCombinedAccuracy(this.labelConfusion, this.equivalenceClasses)));
            double[] approxRootAccuracy = approxAccuracy(this.rootLabelConfusion, this.equivalenceClasses);
            for (int i = 0; i < this.equivalenceClassNames.length; i++) {
                log.info("Approximate " + this.equivalenceClassNames[i] + " root label accuracy: " + NF.format(approxRootAccuracy[i]));
            }
            log.info("Combined approximate root label accuracy: " + NF.format(approxCombinedAccuracy(this.rootLabelConfusion, this.equivalenceClasses)));
        }
        // Guard on the recorder itself (as eval(Tree) does) so a null recorder is never logged.
        if (this.ngrams != null) {
            log.info(this.ngrams);
        }
        if (this.op.testOptions.printLengthAccuracies) {
            printLengthAccuracies();
        }
    }

    /**
     * Sets the predicted-class annotation on every node of every tree in {@code list}.
     *
     * @param list trees to annotate with predictions
     */
    public abstract void populatePredictedLabels(List<Tree> list);
}
