package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import java.util.Arrays;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/classify/BiasedLogConditionalObjectiveFunction.class */
/**
 * Objective for training a multiclass log-linear (softmax) classifier whose training labels are
 * observed through a confusion matrix. The objective is the negative conditional log-likelihood
 * plus a {@link LogPrior} term.
 *
 * <p>For a datum with active features {@code F} and observed label {@code y}, the model score of
 * class {@code c} is {@code s_c = sum over f in F of w[indexOf(f, c)]}, and {@code p(c)} is the
 * softmax of those scores. The datum contributes
 * {@code -log(sum_c confusionMatrix[y][c] * p(c))} to the objective.
 * NOTE(review): this matches a noisy-label model in which {@code confusionMatrix[y][c]} is the
 * probability of observing label {@code y} when the true class is {@code c}; confirm with callers.
 *
 * <p>Weights are stored flat and feature-major: the weight for (feature {@code f}, class {@code c})
 * lives at index {@code f * numClasses + c} (see {@link #indexOf(int, int)}).
 */
public class BiasedLogConditionalObjectiveFunction extends AbstractCachingDiffFunction {
    /** Prior whose value and gradient are added to the data term in {@link #calculate}. */
    protected LogPrior prior;
    /** Number of distinct feature indices. */
    protected int numFeatures;
    /** Number of classes. */
    protected int numClasses;
    /** For each datum, the indices of its active features (each contributes its weight once). */
    protected int[][] data;
    /** For each datum, its observed label; used as the row index into the confusion matrix. */
    protected int[] labels;
    /** Row = observed label, column = model class; must cover numClasses x numClasses entries. */
    private double[][] confusionMatrix;

    /**
     * Replaces the prior used by subsequent objective evaluations.
     *
     * <p>NOTE(review): if the base class caches results per evaluation point, a point that was
     * already evaluated may not reflect the new prior until the cache is invalidated; confirm.
     *
     * @param logPrior the new prior
     */
    public void setPrior(LogPrior logPrior) {
        this.prior = logPrior;
    }

    /**
     * Returns the length of the flat weight vector: one weight per (feature, class) pair.
     *
     * @return {@code numFeatures * numClasses}
     */
    @Override
    public int domainDimension() {
        return numFeatures * numClasses;
    }

    /** Inverse of {@link #indexOf}: the class component of flat index {@code index}. */
    int classOf(int index) {
        return index % numClasses;
    }

    /** Inverse of {@link #indexOf}: the feature component of flat index {@code index}. */
    int featureOf(int index) {
        return index / numClasses;
    }

    /**
     * Maps a (feature, class) pair to its position in the flat weight vector.
     *
     * @param feature feature index
     * @param clazz class index
     * @return {@code feature * numClasses + clazz}
     */
    protected int indexOf(int feature, int clazz) {
        return (feature * numClasses) + clazz;
    }

    /**
     * Reshapes a flat weight vector into a {@code [numFeatures][numClasses]} array.
     *
     * @param weights flat weights laid out as described by {@link #indexOf}
     * @return a new array with {@code result[f][c] == weights[indexOf(f, c)]}
     */
    public double[][] to2D(double[] weights) {
        double[][] result = new double[numFeatures][numClasses];
        for (int f = 0; f < numFeatures; f++) {
            for (int c = 0; c < numClasses; c++) {
                result[f][c] = weights[indexOf(f, c)];
            }
        }
        return result;
    }

    /**
     * Computes the objective value and its gradient at {@code x}.
     * The results are stored in the inherited {@code value} and {@code derivative} fields.
     *
     * <p>For each datum, the gradient with respect to weight (f, c), for each active feature f, is
     * {@code p(c) - q(c)}. Here {@code q(c)} is proportional to
     * {@code confusionMatrix[y][c] * p(c)}, normalized over c.
     *
     * @param x flat weight vector of length {@link #domainDimension()}
     */
    @Override
    protected void calculate(double[] x) {
        if (derivative == null) {
            derivative = new double[x.length];
        } else {
            Arrays.fill(derivative, 0.0);
        }
        value = 0.0;

        // Scratch buffers reused across data points; both are fully overwritten per datum.
        double[] scores = new double[numClasses];
        double[] weightedScores = new double[numClasses];

        for (int d = 0; d < data.length; d++) {
            int[] features = data[d];
            double[] confusionRow = confusionMatrix[labels[d]];

            // Unnormalized log-score of every class for this datum.
            Arrays.fill(scores, 0.0);
            for (int c = 0; c < numClasses; c++) {
                for (int feature : features) {
                    scores[c] += x[indexOf(feature, c)];
                }
            }
            double logZ = ArrayMath.logSum(scores);

            // Log-scores reweighted by the confusion-matrix row of the observed label.
            for (int c = 0; c < numClasses; c++) {
                weightedScores[c] = Math.log(confusionRow[c]) + scores[c];
            }
            double logWeightedZ = ArrayMath.logSum(weightedScores);

            for (int c = 0; c < numClasses; c++) {
                double modelProb = Math.exp(scores[c] - logZ);
                double weightedProb = Math.exp(weightedScores[c] - logWeightedZ);
                double grad = modelProb - weightedProb;
                for (int feature : features) {
                    derivative[indexOf(feature, c)] += grad;
                }
            }

            // log(sum_c M[y][c] * p(c)) == logWeightedZ - logZ. Staying in log space avoids the
            // underflow to log(0) = -Infinity that summing exp(scores[c] - logZ) directly can hit.
            value -= logWeightedZ - logZ;
        }

        value += prior.compute(x, derivative);
    }

    /**
     * Builds the objective from a dataset, using a quadratic prior.
     *
     * @param dataset training data
     * @param confusionMatrix confusion weights indexed {@code [observedLabel][class]}
     */
    public BiasedLogConditionalObjectiveFunction(GeneralDataset<?, ?> dataset, double[][] confusionMatrix) {
        this(dataset, confusionMatrix, new LogPrior(LogPrior.LogPriorType.QUADRATIC));
    }

    /**
     * Builds the objective from a dataset with an explicit prior.
     *
     * @param dataset training data
     * @param confusionMatrix confusion weights indexed {@code [observedLabel][class]}
     * @param logPrior prior added to the data term
     */
    public BiasedLogConditionalObjectiveFunction(GeneralDataset<?, ?> dataset, double[][] confusionMatrix, LogPrior logPrior) {
        this(dataset.numFeatures(), dataset.numClasses(), dataset.getDataArray(), dataset.getLabelsArray(), confusionMatrix, logPrior);
    }

    /**
     * Builds the objective from raw arrays, using a quadratic prior.
     *
     * @param numFeatures number of distinct features
     * @param numClasses number of classes
     * @param data active feature indices per datum
     * @param labels observed label per datum
     * @param confusionMatrix confusion weights indexed {@code [observedLabel][class]}
     */
    public BiasedLogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, double[][] confusionMatrix) {
        this(numFeatures, numClasses, data, labels, confusionMatrix, new LogPrior(LogPrior.LogPriorType.QUADRATIC));
    }

    /**
     * Builds the objective from raw arrays with an explicit prior. The arrays are stored by
     * reference, not copied.
     *
     * @param numFeatures number of distinct features
     * @param numClasses number of classes
     * @param data active feature indices per datum
     * @param labels observed label per datum
     * @param confusionMatrix confusion weights indexed {@code [observedLabel][class]}
     * @param logPrior prior added to the data term
     */
    public BiasedLogConditionalObjectiveFunction(int numFeatures, int numClasses, int[][] data, int[] labels, double[][] confusionMatrix, LogPrior logPrior) {
        this.numFeatures = numFeatures;
        this.numClasses = numClasses;
        this.data = data;
        this.labels = labels;
        this.prior = logPrior;
        this.confusionMatrix = confusionMatrix;
    }
}
