package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.util.BinaryHeapPriorityQueue;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.text.NumberFormat;
import java.util.List;
import org.apache.tika.parser.microsoft.onenote.OneNotePtr;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/stats/MultiClassAccuracyStats.class */
/**
 * Scorer that evaluates a {@link ProbabilisticClassifier} on a {@link GeneralDataset} and
 * records the plain accuracy, the total log-likelihood of the gold labels, and per-example
 * confidence/correctness information ordered by the classifier's confidence.
 *
 * <p>The statistics are only available after {@link #initMC} has run (directly, through
 * {@link #score(ProbabilisticClassifier, GeneralDataset)}, or through one of the constructors
 * taking a classifier). Before that, {@code scores} and {@code isCorrect} are {@code null}.
 *
 * <p>NOTE: {@code saveFile} and {@code saveIndex} are static, so they are shared by every
 * instance; the constructors that take a file name overwrite the shared value.
 *
 * @param <L> the label type
 */
public class MultiClassAccuracyStats<L> implements Scorer<L> {
    /** Log-probability of the predicted label per example, in the order produced by {@link #initMC}. */
    double[] scores;
    /** Whether the prediction was correct per example, parallel to {@link #scores}. */
    boolean[] isCorrect;
    /** Sum over all examples of the log-probability assigned to the gold label. */
    double logLikelihood;
    /** Fraction of examples whose predicted label matches the gold label. */
    double accuracy;
    /** Base file name for accuracy dumps written by {@link #getDescription}; {@code null} disables saving. */
    static String saveFile = null;
    /** Suffix counter appended to {@link #saveFile}; incremented after every dump. */
    static int saveIndex = 1;
    /** Score type: {@link #score()} returns the accuracy. */
    public static final int USE_ACCURACY = 1;
    /** Score type: {@link #score()} returns the log-likelihood. */
    public static final int USE_LOGLIKELIHOOD = 2;
    private int scoreType = USE_ACCURACY;
    /** Number of correctly classified examples in the last evaluation. */
    int correct = 0;
    /** Number of examples seen in the last evaluation. */
    int total = 0;

    /** Creates a scorer that reports accuracy. */
    public MultiClassAccuracyStats() {
    }

    /**
     * Creates a scorer with the given score type.
     *
     * @param i {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
     */
    public MultiClassAccuracyStats(int i) {
        this.scoreType = i;
    }

    /**
     * Creates a scorer that reports accuracy and dumps accuracy info to files.
     *
     * @param str base name of the dump files (stored in the shared static {@code saveFile})
     */
    public MultiClassAccuracyStats(String str) {
        this(str, USE_ACCURACY);
    }

    /**
     * Creates a scorer with the given score type that dumps accuracy info to files.
     *
     * @param str base name of the dump files (stored in the shared static {@code saveFile})
     * @param i {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
     */
    public MultiClassAccuracyStats(String str, int i) {
        saveFile = str;
        this.scoreType = i;
    }

    /**
     * Creates a scorer reporting accuracy and immediately evaluates the classifier.
     *
     * @param probabilisticClassifier classifier to evaluate
     * @param generalDataset dataset to evaluate on
     * @param str base name of the dump files (stored in the shared static {@code saveFile})
     */
    public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L, F> probabilisticClassifier, GeneralDataset<L, F> generalDataset, String str) {
        this(probabilisticClassifier, generalDataset, str, USE_ACCURACY);
    }

    /**
     * Creates a scorer with the given score type and immediately evaluates the classifier.
     *
     * @param probabilisticClassifier classifier to evaluate
     * @param generalDataset dataset to evaluate on
     * @param str base name of the dump files (stored in the shared static {@code saveFile})
     * @param i {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
     */
    public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L, F> probabilisticClassifier, GeneralDataset<L, F> generalDataset, String str, int i) {
        saveFile = str;
        this.scoreType = i;
        initMC(probabilisticClassifier, generalDataset);
    }

    /**
     * Evaluates the classifier on the dataset and returns the configured score.
     *
     * @return the accuracy or the log-likelihood, depending on the score type
     * @throws RuntimeException if the score type is neither of the two known constants
     */
    @Override // edu.stanford.nlp.stats.Scorer
    public <F> double score(ProbabilisticClassifier<L, F> probabilisticClassifier, GeneralDataset<L, F> generalDataset) {
        initMC(probabilisticClassifier, generalDataset);
        return score();
    }

    /**
     * Returns the score selected by the score type, using the most recent evaluation.
     *
     * @return the accuracy or the log-likelihood
     * @throws RuntimeException if the score type is neither of the two known constants
     */
    public double score() {
        if (this.scoreType == USE_ACCURACY) {
            return this.accuracy;
        }
        if (this.scoreType == USE_LOGLIKELIHOOD) {
            return this.logLikelihood;
        }
        throw new RuntimeException("Unknown score type: " + this.scoreType);
    }

    /**
     * Returns the number of evaluated examples. Requires a prior evaluation; otherwise
     * {@code scores} is {@code null} and this throws a {@link NullPointerException}.
     */
    public int numSamples() {
        return this.scores.length;
    }

    /**
     * Averages, over k = 1..n, the fraction {@code numCorrect(k) / k}.
     *
     * @return the confidence-weighted accuracy; {@code NaN} when there are no samples
     */
    public double confidenceWeightedAccuracy() {
        int n = numSamples();
        double sum = 0.0d;
        int correctSoFar = 0;
        for (int k = 1; k <= n; k++) {
            // Running equivalent of numCorrect(k): extend the window by one entry from the end.
            if (this.isCorrect[n - k]) {
                correctSoFar++;
            }
            // Floating-point division: integer division would truncate every term to 0 or 1.
            sum += (double) correctSoFar / k;
        }
        return sum / n;
    }

    /**
     * Runs the classifier over every datum of the dataset and recomputes all statistics:
     * {@code total}, {@code correct}, {@code accuracy}, {@code logLikelihood}, and the
     * {@code scores}/{@code isCorrect} arrays.
     *
     * <p>Each example is queued with priority equal to the negated log-probability of its
     * predicted label; the arrays follow the order returned by
     * {@code BinaryHeapPriorityQueue.toSortedList()}.
     *
     * @param probabilisticClassifier classifier to evaluate
     * @param generalDataset dataset to evaluate on
     */
    public <F> void initMC(ProbabilisticClassifier<L, F> probabilisticClassifier, GeneralDataset<L, F> generalDataset) {
        BinaryHeapPriorityQueue<Pair<Integer, Pair<Double, Boolean>>> queue = new BinaryHeapPriorityQueue<>();
        this.total = 0;
        this.correct = 0;
        this.logLikelihood = 0.0d;
        for (int i = 0; i < generalDataset.size(); i++) {
            RVFDatum<L, F> datum = generalDataset.getRVFDatum(i);
            Counter<L> logProbs = probabilisticClassifier.logProbabilityOf(datum);
            L predicted = Counters.argmax(logProbs);
            L gold = datum.label();
            double predictedLogProb = logProbs.getCount(predicted);
            double goldLogProb = logProbs.getCount(gold);
            // Labels are compared through their positions in the dataset's label index.
            int predictedIndex = generalDataset.labelIndex().indexOf(predicted);
            int goldIndex = generalDataset.labelIndex().indexOf(gold);
            boolean hit = predictedIndex == goldIndex;
            this.total++;
            if (hit) {
                this.correct++;
            }
            this.logLikelihood += goldLogProb;
            queue.add(new Pair<>(i, new Pair<>(predictedLogProb, hit)), -predictedLogProb);
        }
        // Floating-point division: int/int would truncate to 0 or 1 and throw on an empty dataset.
        this.accuracy = (double) this.correct / this.total;
        List<Pair<Integer, Pair<Double, Boolean>>> sorted = queue.toSortedList();
        this.scores = new double[sorted.size()];
        this.isCorrect = new boolean[sorted.size()];
        for (int j = 0; j < sorted.size(); j++) {
            Pair<Double, Boolean> entry = sorted.get(j).second();
            this.scores[j] = entry.first();
            this.isCorrect[j] = entry.second();
        }
    }

    /**
     * Counts the correct predictions among the last {@code i} entries of the sorted arrays.
     *
     * @param i window size; must not exceed {@link #numSamples()}, otherwise an
     *     {@link ArrayIndexOutOfBoundsException} is thrown
     * @return the number of correct entries in that window
     */
    public int numCorrect(int i) {
        int hits = 0;
        for (int idx = this.scores.length - 1; idx >= this.scores.length - i; idx--) {
            if (this.isCorrect[idx]) {
                hits++;
            }
        }
        return hits;
    }

    /**
     * Returns, for every k = 1..n, the value of {@code numCorrect(k)} at position {@code k - 1}.
     *
     * @return a new array of length {@link #numSamples()}
     */
    public int[] getAccCoverage() {
        int n = numSamples();
        int[] coverage = new int[n];
        int correctSoFar = 0;
        for (int k = 1; k <= n; k++) {
            // Running equivalent of numCorrect(k).
            if (this.isCorrect[n - k]) {
                correctSoFar++;
            }
            coverage[k - 1] = correctSoFar;
        }
        return coverage;
    }

    /**
     * Builds a human-readable summary of the last evaluation. When the shared static
     * {@code saveFile} is set, also writes the accuracy/coverage array to
     * {@code <saveFile>-<saveIndex>.accuracy} and increments {@code saveIndex}.
     *
     * @param i maximum number of fraction digits used when formatting the accuracies
     * @return the multi-line description
     */
    @Override // edu.stanford.nlp.stats.Scorer
    public String getDescription(int i) {
        NumberFormat formatter = NumberFormat.getNumberInstance();
        formatter.setMaximumFractionDigits(i);
        StringBuilder sb = new StringBuilder();
        double weighted = confidenceWeightedAccuracy();
        sb.append("--- Accuracy Stats ---").append("\n");
        sb.append("accuracy: ").append(formatter.format(this.accuracy)).append(" (").append(this.correct).append("/").append(this.total).append(")\n");
        sb.append("confidence weighted accuracy :").append(formatter.format(weighted)).append("\n");
        sb.append("log-likelihood: ").append(this.logLikelihood).append("\n");
        if (saveFile != null) {
            String base = saveFile + "-" + saveIndex;
            sb.append("saving accuracy info to ").append(base).append(".accuracy\n");
            StringUtils.printToFile(base + ".accuracy", AccuracyStats.toStringArr(getAccCoverage()));
            saveIndex++;
        }
        return sb.toString();
    }

    /** Returns {@code MultiClassAccuracyStats(<score type name>)}. */
    @Override
    public String toString() {
        String accuracyType;
        if (this.scoreType == USE_ACCURACY) {
            accuracyType = "classification_accuracy";
        } else if (this.scoreType == USE_LOGLIKELIHOOD) {
            accuracyType = "log_likelihood";
        } else {
            accuracyType = "unknown";
        }
        return "MultiClassAccuracyStats(" + accuracyType + ")";
    }
}
