package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Triple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.5.6.jar:edu/stanford/nlp/classify/GeneralizedExpectationObjectiveFunction.class */
/**
 * Differentiable objective that ties the predictions of a {@link LinearClassifier} on unlabeled
 * data to label distributions observed on labeled data, for a chosen set of "GE" features.
 *
 * <p>For every GE feature the function:
 * <ol>
 *   <li>estimates an empirical label distribution from the labeled datums that contain the
 *       feature (computed once, at construction time);</li>
 *   <li>averages the classifier's class probabilities over the unlabeled datums that contain the
 *       feature (the "active" datums), then smooths that average;</li>
 *   <li>adds the cross-entropy {@code -sum_c empirical[c] * log(model[c])} to the value, and the
 *       corresponding partial derivatives to the gradient.</li>
 * </ol>
 *
 * <p>Weights are laid out as a flat array indexed by {@code feature * numClasses + class}.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class GeneralizedExpectationObjectiveFunction<L, F> extends AbstractCachingDiffFunction {
    /** Constant added to every entry of a distribution before it is renormalized. */
    private static final double SMOOTHING_EPSILON = 1.0E-6d;

    private final GeneralDataset<L, F> labeledDataset;
    private final List<? extends Datum<L, F>> unlabeledDataList;
    private final List<F> geFeatures;
    private final LinearClassifier<L, F> classifier;
    /** Row n holds the smoothed empirical label distribution of GE feature n. */
    private double[][] geFeature2EmpiricalDist;
    /** Entry n lists the indices (into {@link #unlabeledDataList}) of datums containing GE feature n. */
    private List<List<Integer>> geFeature2DatumList;
    private final int numFeatures;
    private final int numClasses;

    /**
     * Builds the objective and precomputes the per-GE-feature empirical distributions and the
     * inverted index from GE features to unlabeled datums.
     *
     * @param labeledDataset labeled data; supplies the feature index, label index and the
     *     empirical label statistics
     * @param unlabeledDataList unlabeled datums on which model expectations are computed
     * @param geFeatures features for which expectation constraints are imposed
     */
    public GeneralizedExpectationObjectiveFunction(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, List<F> geFeatures) {
        System.out.println("Number of labeled examples:" + labeledDataset.size() + "\nNumber of unlabeled examples:" + unlabeledDataList.size());
        System.out.println("Number of GE features:" + geFeatures.size());
        this.numFeatures = labeledDataset.numFeatures();
        this.numClasses = labeledDataset.numClasses();
        this.labeledDataset = labeledDataset;
        this.unlabeledDataList = unlabeledDataList;
        this.geFeatures = geFeatures;
        // Weights start out null; calculate() installs them before the classifier is queried.
        this.classifier = new LinearClassifier<>((double[][]) null, labeledDataset.featureIndex, labeledDataset.labelIndex);
        computeEmpiricalStatistics(geFeatures);
    }

    /** @return the number of weights, i.e. {@code numFeatures * numClasses} */
    @Override // edu.stanford.nlp.optimization.Function
    public int domainDimension() {
        return this.numFeatures * this.numClasses;
    }

    /** @return the class component of a flat weight index */
    int classOf(int index) {
        return index % this.numClasses;
    }

    /** @return the feature component of a flat weight index */
    int featureOf(int index) {
        return index / this.numClasses;
    }

    /** @return the flat weight index of the given (feature, class) pair */
    protected int indexOf(int f, int c) {
        return (f * this.numClasses) + c;
    }

    /**
     * Reshapes a flat weight vector into a {@code [numFeatures][numClasses]} matrix.
     *
     * @param x flat weights laid out as described by {@link #indexOf(int, int)}
     * @return a newly allocated matrix holding the same values
     */
    public double[][] to2D(double[] x) {
        double[][] weights = new double[this.numFeatures][this.numClasses];
        for (int f = 0; f < this.numFeatures; f++) {
            for (int c = 0; c < this.numClasses; c++) {
                weights[f][c] = x[indexOf(f, c)];
            }
        }
        return weights;
    }

    /**
     * Computes {@code value} and {@code derivative} at the point {@code x}.
     *
     * <p>Side effect: the weights of the internal classifier are replaced by {@code x}.
     *
     * @param x flat weight vector
     */
    @Override // edu.stanford.nlp.optimization.AbstractCachingDiffFunction
    protected void calculate(double[] x) {
        this.classifier.setWeights(to2D(x));
        if (this.derivative == null) {
            this.derivative = new double[x.length];
        } else {
            Arrays.fill(this.derivative, 0.0d);
        }
        this.value = 0.0d;

        for (int n = 0; n < this.geFeatures.size(); n++) {
            List<Integer> activeData = this.geFeature2DatumList.get(n);
            if (activeData.isEmpty()) {
                // No unlabeled datum contains this GE feature: it contributes nothing.
                continue;
            }
            int activeCount = activeData.size();

            // A fresh counter per GE feature: the derivative terms of one feature's active datums
            // must not leak into the gradient contribution of another feature.
            Counter<Triple<Integer, Integer, Integer>> pairDerivatives = new ClassicCounter<>();
            double[] modelDist = new double[this.numClasses];
            for (Integer datumIndex : activeData) {
                Datum<L, F> datum = this.unlabeledDataList.get(datumIndex);
                double[] probs = getModelProbs(datum);
                for (int c = 0; c < this.numClasses; c++) {
                    modelDist[c] += probs[c];
                }
                updateDerivative(datum, probs, pairDerivatives);
            }
            for (int c = 0; c < this.numClasses; c++) {
                modelDist[c] /= activeCount;
            }
            smoothDistribution(modelDist);

            double[] empiricalDist = this.geFeature2EmpiricalDist[n];
            for (int c = 0; c < this.numClasses; c++) {
                this.value += (-empiricalDist[c]) * Math.log(modelDist[c]);
            }

            // Chain rule through the averaged distribution. The 1/activeCount factor of the average
            // is applied to this GE feature's contribution only, never to the running total.
            // The smoothed distribution is used in the denominator; the smoothing step itself is
            // not differentiated.
            int featureCount = this.labeledDataset.featureIndex().size();
            for (int f = 0; f < featureCount; f++) {
                for (int c = 0; c < this.numClasses; c++) {
                    double contribution = 0.0d;
                    for (int cPrime = 0; cPrime < this.numClasses; cPrime++) {
                        contribution += (pairDerivatives.getCount(new Triple<>(f, c, cPrime)) * empiricalDist[cPrime]) / modelDist[cPrime];
                    }
                    this.derivative[indexOf(f, c)] += contribution / activeCount;
                }
            }
        }
    }

    /**
     * Adds, for every indexed feature of {@code datum} and every class pair {@code (c, cPrime)},
     * the negated partial derivative of the datum's probability of {@code cPrime} with respect to
     * the weight of (feature, {@code c}), scaled by the feature's value in the datum.
     *
     * @param datum the unlabeled datum
     * @param probs class probabilities of {@code datum}, indexed by class
     * @param pairDerivatives accumulator keyed by (feature index, c, cPrime)
     */
    private void updateDerivative(Datum<L, F> datum, double[] probs, Counter<Triple<Integer, Integer, Integer>> pairDerivatives) {
        for (F feature : datum.asFeatures()) {
            int fIndex = this.labeledDataset.featureIndex.indexOf(feature);
            if (fIndex < 0) {
                // Feature not present in the labeled dataset's index: it has no weight.
                continue;
            }
            double featureValue = valueOfFeature(feature, datum);
            for (int c = 0; c < this.numClasses; c++) {
                for (int cPrime = 0; cPrime < this.numClasses; cPrime++) {
                    double term;
                    if (cPrime == c) {
                        term = (-probs[c]) * (1.0d - probs[c]) * featureValue;
                    } else {
                        term = probs[c] * probs[cPrime] * featureValue;
                    }
                    pairDerivatives.incrementCount(new Triple<>(fIndex, c, cPrime), term);
                }
            }
        }
    }

    /**
     * @return the count stored for {@code feature} when {@code datum} is an {@link RVFDatum},
     *     otherwise {@code 1.0}
     */
    private double valueOfFeature(F feature, Datum<L, F> datum) {
        if (datum instanceof RVFDatum) {
            return ((RVFDatum<L, F>) datum).asFeaturesCounter().getCount(feature);
        }
        return 1.0d;
    }

    /**
     * Fills {@link #geFeature2EmpiricalDist} from the labeled dataset and
     * {@link #geFeature2DatumList} from the unlabeled data, then prints the number of unlabeled
     * datums that contain at least one GE feature.
     *
     * @param geFeatures the GE features; if a feature occurs more than once, the last position wins
     */
    private void computeEmpiricalStatistics(List<F> geFeatures) {
        int geFeatureCount = geFeatures.size();
        this.geFeature2EmpiricalDist = new double[geFeatureCount][this.labeledDataset.labelIndex.size()];
        this.geFeature2DatumList = new ArrayList<>(geFeatureCount);
        Map<F, Integer> geFeatureToIndex = Generics.newHashMap();
        for (int n = 0; n < geFeatureCount; n++) {
            this.geFeature2DatumList.add(new ArrayList<Integer>());
            geFeatureToIndex.put(geFeatures.get(n), n);
        }

        // Count, per GE feature, how often each label co-occurs with it in the labeled data.
        for (int d = 0; d < this.labeledDataset.size(); d++) {
            Datum<L, F> datum = this.labeledDataset.getDatum(d);
            int labelIndex = this.labeledDataset.labelIndex.indexOf(datum.label());
            for (F feature : datum.asFeatures()) {
                Integer geIndex = geFeatureToIndex.get(feature);
                if (geIndex != null) {
                    this.geFeature2EmpiricalDist[geIndex][labelIndex] += 1.0d;
                }
            }
        }
        for (int n = 0; n < geFeatureCount; n++) {
            ArrayMath.normalize(this.geFeature2EmpiricalDist[n]);
            smoothDistribution(this.geFeature2EmpiricalDist[n]);
        }

        // Inverted index: GE feature -> unlabeled datums that contain it.
        Set<Integer> activeUnlabeled = Generics.newHashSet();
        for (int d = 0; d < this.unlabeledDataList.size(); d++) {
            for (F feature : this.unlabeledDataList.get(d).asFeatures()) {
                Integer geIndex = geFeatureToIndex.get(feature);
                if (geIndex != null) {
                    this.geFeature2DatumList.get(geIndex).add(d);
                    activeUnlabeled.add(d);
                }
            }
        }
        System.out.println("Number of active unlabeled examples:" + activeUnlabeled.size());
    }

    /** Adds {@link #SMOOTHING_EPSILON} to every entry of {@code dist} in place, then renormalizes it. */
    private static void smoothDistribution(double[] dist) {
        for (int i = 0; i < dist.length; i++) {
            dist[i] += SMOOTHING_EPSILON;
        }
        ArrayMath.normalize(dist);
    }

    /**
     * @return the values reported by the classifier's {@code probabilityOf(datum)}, arranged by
     *     the label index of the labeled dataset; labels absent from the result stay at 0
     */
    private double[] getModelProbs(Datum<L, F> datum) {
        double[] probs = new double[this.labeledDataset.numClasses()];
        Counter<L> labelProbs = this.classifier.probabilityOf(datum);
        for (L label : labelProbs.keySet()) {
            probs[this.labeledDataset.labelIndex.indexOf(label)] = labelProbs.getCount(label);
        }
        return probs;
    }
}
