package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.TwoDimensionalCounter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.regex.Pattern;
import org.eclipse.jgit.transport.WalkEncryption;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.0.0.jar:edu/stanford/nlp/classify/Dataset.class */
public class Dataset<L, F> extends GeneralDataset<L, F> {
    /** Serialization version identifier for this class. */
    private static final long serialVersionUID = -3883164942879961091L;
    /** Redwood logging channels for this class. */
    static final Redwood.RedwoodChannels logger = Redwood.channels(Dataset.class);
    /** Multiplier that converts a natural logarithm into a base-2 logarithm (1 / ln 2). */
    private static final double LN_TO_LOG2 = 1.0d / Math.log(2.0d);
    /**
     * Running count of lines handed to {@code svmLightLineToDatum}; it is only used to label
     * the log message for malformed entries. Static, unsynchronized, and never reset in this file.
     */
    private static int line1 = 0;

    /** Creates an empty dataset by delegating to {@code Dataset(int)} with an argument of 10. */
    public Dataset() {
        this(10);
    }

    /**
     * Creates an empty dataset.
     *
     * @param i value passed to {@code initialize(int)}, which uses it as the initial length of
     *          the label and data arrays
     */
    public Dataset(int i) {
        initialize(i);
    }

    /**
     * Creates an empty dataset that uses the supplied indices.
     *
     * @param i      value passed to {@code initialize(int)} (initial length of the label and
     *               data arrays)
     * @param index  feature index to use
     * @param index2 label index to use
     */
    public Dataset(int i, Index<F> index, Index<L> index2) {
        initialize(i);
        // initialize() installs fresh indices; replace them with the caller's.
        this.featureIndex = index;
        this.labelIndex = index2;
    }

    /**
     * Creates an empty dataset that uses the supplied indices, delegating to
     * {@code Dataset(int, Index, Index)} with a first argument of 10.
     *
     * @param index  feature index to use
     * @param index2 label index to use
     */
    public Dataset(Index<F> index, Index<L> index2) {
        this(10, index, index2);
    }

    /**
     * Creates a dataset over existing arrays, taking the dataset size to be the length of the
     * data array.
     *
     * @param index  label index
     * @param iArr   label id per datum
     * @param index2 feature index
     * @param iArr2  feature ids per datum; its length is used as the size
     */
    public Dataset(Index<L> index, int[] iArr, Index<F> index2, int[][] iArr2) {
        this(index, iArr, index2, iArr2, iArr2.length);
    }

    /**
     * Creates a dataset directly over the supplied indices and arrays. The arrays and indices
     * are stored by reference; nothing is copied.
     *
     * @param index  label index
     * @param iArr   label id per datum
     * @param index2 feature index
     * @param iArr2  feature ids per datum
     * @param i      number of datums to record as the dataset size
     */
    public Dataset(Index<L> index, int[] iArr, Index<F> index2, int[][] iArr2, int i) {
        // Indices first, then the parallel arrays, then the element count.
        this.featureIndex = index2;
        this.labelIndex = index;
        this.data = iArr2;
        this.labels = iArr;
        this.size = i;
    }

    /**
     * Splits the dataset at a fraction of its size by delegating to {@code split(0, end)},
     * where {@code end} is {@code d * size()} truncated to an int.
     *
     * @param d fraction of the dataset that forms the leading slice
     * @return the pair produced by {@code split(int, int)} for the range {@code [0, end)}
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split(double d) {
        final int end = (int) (d * size());
        return split(0, end);
    }

    /* JADX WARN: Type inference failed for: r0v10, types: [java.lang.Object, int[], int[][]] */
    /* JADX WARN: Type inference failed for: r0v6, types: [java.lang.Object, int[], int[][]] */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> split(int i, int i2) {
        int i3 = i2 - i;
        int size = size() - i3;
        ?? r0 = new int[i3];
        int[] iArr = new int[i3];
        ?? r02 = new int[size];
        int[] iArr2 = new int[size];
        synchronized (System.class) {
            System.arraycopy(this.data, i, r0, 0, i3);
            System.arraycopy(this.labels, i, iArr, 0, i3);
            System.arraycopy(this.data, 0, r02, 0, i);
            System.arraycopy(this.data, i2, r02, i, size() - i2);
            System.arraycopy(this.labels, 0, iArr2, 0, i);
            System.arraycopy(this.labels, i2, iArr2, i, size() - i2);
        }
        if (!(this instanceof WeightedDataset)) {
            return new Pair<>(new Dataset(this.labelIndex, iArr2, this.featureIndex, r02, size), new Dataset(this.labelIndex, iArr, this.featureIndex, r0, i3));
        }
        float[] fArr = new float[size];
        float[] fArr2 = new float[i3];
        WeightedDataset weightedDataset = (WeightedDataset) this;
        synchronized (System.class) {
            System.arraycopy(weightedDataset.weights, i, fArr2, 0, i3);
            System.arraycopy(weightedDataset.weights, 0, fArr, 0, i);
            System.arraycopy(weightedDataset.weights, i2, fArr, i, size() - i2);
        }
        return new Pair<>(new WeightedDataset(this.labelIndex, iArr2, this.featureIndex, r02, size, fArr), new WeightedDataset(this.labelIndex, iArr, this.featureIndex, r0, i3, fArr2));
    }

    /* JADX WARN: Type inference failed for: r0v10, types: [int[], int[][]] */
    public Dataset<L, F> getRandomSubDataset(double d, int i) {
        int size = (int) (d * size());
        Set newHashSet = Generics.newHashSet();
        Random random = new Random(i);
        int size2 = size();
        while (newHashSet.size() < size) {
            newHashSet.add(Integer.valueOf(random.nextInt(size2)));
        }
        ?? r0 = new int[size];
        int[] iArr = new int[size];
        int i2 = 0;
        Iterator it = newHashSet.iterator();
        while (it.hasNext()) {
            int intValue = ((Integer) it.next()).intValue();
            r0[i2] = this.data[intValue];
            iArr[i2] = this.labels[intValue];
            i2++;
        }
        return new Dataset<>(this.labelIndex, iArr, this.featureIndex, r0);
    }

    /**
     * This dataset stores feature ids only, so no values array is provided.
     *
     * @return always {@code null}
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public double[][] getValuesArray() {
        return null;
    }

    /**
     * Reads an SVM-light formatted file into a new dataset with fresh feature and label indices.
     *
     * @param str path of the file to read
     * @return the dataset built from the file
     */
    public static Dataset<String, String> readSVMLightFormat(String str) {
        return readSVMLightFormat(str, new HashIndex(), new HashIndex());
    }

    /**
     * Reads an SVM-light formatted file into a new dataset with fresh indices, also collecting
     * the raw lines.
     *
     * @param str  path of the file to read
     * @param list receives every line read, in order
     * @return the dataset built from the file
     */
    public static Dataset<String, String> readSVMLightFormat(String str, List<String> list) {
        return readSVMLightFormat(str, new HashIndex(), new HashIndex(), list);
    }

    /**
     * Reads an SVM-light formatted file into a new dataset that uses the supplied indices,
     * without collecting the raw lines.
     *
     * @param str    path of the file to read
     * @param index  feature index to use
     * @param index2 label index to use
     * @return the dataset built from the file
     */
    public static Dataset<String, String> readSVMLightFormat(String str, Index<String> index, Index<String> index2) {
        return readSVMLightFormat(str, index, index2, null);
    }

    /**
     * Reads an SVM-light formatted file line by line, converting each line with
     * {@code svmLightLineToDatum} and adding it to a new dataset.
     *
     * @param str    path of the file to read
     * @param index  feature index for the new dataset
     * @param index2 label index for the new dataset
     * @param list   if non-null, receives every line read, in order
     * @return the dataset built from the file
     * @throws RuntimeException wrapping any exception raised while reading or parsing
     */
    public static Dataset<String, String> readSVMLightFormat(String str, Index<String> index, Index<String> index2, List<String> list) {
        try {
            Dataset<String, String> result = new Dataset<>(10, index, index2);
            final boolean keepLines = list != null;
            for (String line : ObjectBank.getLineIterator(new File(str))) {
                if (keepLines) {
                    list.add(line);
                }
                Datum<String, String> parsed = svmLightLineToDatum(line);
                result.add(parsed);
            }
            return result;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Parses one SVM-light line of the form {@code label name:value name:value ... # comment}.
     *
     * <p>Everything from the first {@code #} on is discarded. The first whitespace-separated
     * token is the label. Each following {@code name:value} entry adds {@code name} to the
     * feature list {@code (int) value} times. A final feature whose name is the decimal string
     * of {@code Integer.MAX_VALUE} is always appended.
     *
     * <p>An entry that does not split into exactly two parts on {@code ':'} is reported on the
     * logger; if it has fewer than two parts it is then skipped, otherwise its first two parts
     * are used.
     *
     * @param str one line of SVM-light text
     * @return a datum holding the parsed label and (repeated) features
     * @throws NumberFormatException if an entry's value is not parseable as a double
     */
    public static Datum<String, String> svmLightLineToDatum(String str) {
        line1++;
        // Strip the trailing comment, then split on runs of whitespace.
        String[] fields = str.replaceAll("#.*", "").split("\\s+");
        List<String> features = new ArrayList<>();
        for (int k = 1; k < fields.length; k++) {
            String[] nameAndValue = fields[k].split(":");
            if (nameAndValue.length != 2) {
                logger.info("Dataset error: line " + line1);
                if (nameAndValue.length < 2) {
                    // No value part to read; indexing [1] would throw.
                    continue;
                }
            }
            int repeats = (int) Double.parseDouble(nameAndValue[1]);
            for (int r = 0; r < repeats; r++) {
                features.add(nameAndValue[0]);
            }
        }
        features.add(String.valueOf(Integer.MAX_VALUE));
        return new BasicDatum<>(features, fields[0]);
    }

    /**
     * Counts, for every feature, the number of datums in which it occurs at least once.
     * Repeated occurrences inside a single datum are collapsed by a set before counting.
     *
     * @return a counter mapping each feature to the number of datums containing it
     */
    public Counter<F> getFeatureCounter() {
        ClassicCounter<F> datumsPerFeature = new ClassicCounter<>();
        for (int row = 0; row < size(); row++) {
            BasicDatum<L, F> datum = (BasicDatum<L, F>) getDatum(row);
            Set<F> distinctFeatures = Generics.newHashSet(datum.asFeatures());
            for (F feature : distinctFeatures) {
                datumsPerFeature.incrementCount(feature, 1.0d);
            }
        }
        return datumsPerFeature;
    }

    /**
     * Converts a datum into a real-valued datum of TF-IDF weights that are divided by their sum.
     *
     * <p>Term frequency is the number of times a feature appears in the datum; features absent
     * from {@code counter} are ignored. Each frequency is multiplied by
     * {@code log((size() + 1) / (counter count + 0.5))}, and every weight is then divided by
     * the sum of all weights.
     *
     * @param datum   the datum to convert
     * @param counter per-feature counts used in the inverse-frequency term
     * @return a datum with the same label and the normalized weights
     */
    public RVFDatum<L, F> getL1NormalizedTFIDFDatum(Datum<L, F> datum, Counter<F> counter) {
        ClassicCounter<F> weights = new ClassicCounter<>();
        for (F feature : datum.asFeatures()) {
            if (!counter.containsKey(feature)) {
                continue;
            }
            weights.incrementCount(feature, 1.0d);
        }

        double total = 0.0d;
        for (F feature : weights.keySet()) {
            double idf = Math.log((size() + 1) / (counter.getCount(feature) + 0.5d));
            double tf = weights.getCount(feature);
            double tfidf = tf * idf;
            weights.setCount(feature, tfidf);
            total += tfidf;
        }

        for (F feature : weights.keySet()) {
            double normalized = weights.getCount(feature) / total;
            weights.setCount(feature, normalized);
        }
        return new RVFDatum<>(weights, datum.label());
    }

    /**
     * Builds a real-valued dataset in which every datum of this dataset has been converted by
     * {@code getL1NormalizedTFIDFDatum}, using the counts from {@code getFeatureCounter()}.
     *
     * @return the converted dataset, created with this dataset's feature and label indices
     */
    public RVFDataset<L, F> getL1NormalizedTFIDFDataset() {
        Counter<F> datumsPerFeature = getFeatureCounter();
        RVFDataset<L, F> converted = new RVFDataset<>(size(), this.featureIndex, this.labelIndex);
        int row = 0;
        while (row < size()) {
            Datum<L, F> original = getDatum(row);
            converted.add(getL1NormalizedTFIDFDatum(original, datumsPerFeature));
            row++;
        }
        return converted;
    }

    /**
     * Adds a datum by passing its features and label to {@code add(Collection, L)}.
     *
     * @param datum the datum to add
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public void add(Datum<L, F> datum) {
        // The label is passed as-is; it is an L, not a collection of features.
        add(datum.asFeatures(), datum.label());
    }

    /**
     * Adds a datum given as features and a label, delegating to
     * {@code add(Collection, L, boolean)} with {@code true} as the flag.
     *
     * @param collection the datum's features
     * @param l          the datum's label
     */
    public void add(Collection<F> collection, L l) {
        add(collection, l, true);
    }

    /**
     * Adds a datum given as features and a label.
     *
     * @param collection the datum's features
     * @param l          the datum's label
     * @param z          forwarded to {@code addFeatures(Collection, boolean)}; when true, each
     *                   feature is added to the feature index before being looked up
     */
    public void add(Collection<F> collection, L l, boolean z) {
        // Make room first; the helpers below write at position this.size.
        ensureSize();
        addLabel(l);
        addFeatures(collection, z);
        this.size++;
    }

    /**
     * Adds a datum that is already expressed as feature ids and a label id.
     *
     * @param iArr feature ids; the array is stored by reference
     * @param i    label id
     */
    public void add(int[] iArr, int i) {
        // Make room first; the helpers below write at position this.size.
        ensureSize();
        addLabelIndex(i);
        addFeatureIndices(iArr);
        this.size++;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Type inference failed for: r0v10, types: [java.lang.Object, int[], int[][]] */
    public void ensureSize() {
        if (this.labels.length == this.size) {
            int[] iArr = new int[this.size * 2];
            ?? r0 = new int[this.size * 2];
            synchronized (System.class) {
                System.arraycopy(this.labels, 0, iArr, 0, this.size);
                System.arraycopy(this.data, 0, r0, 0, this.size);
            }
            this.labels = iArr;
            this.data = r0;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Registers the label in the label index and records its id at position {@code size} of
     * the label array.
     *
     * @param l the label of the datum being added
     */
    public void addLabel(L l) {
        this.labelIndex.add(l);
        final int labelId = this.labelIndex.indexOf(l);
        this.labels[this.size] = labelId;
    }

    /**
     * Records an already-resolved label id at position {@code size} of the label array.
     *
     * @param i the label id
     */
    protected void addLabelIndex(int i) {
        this.labels[this.size] = i;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Stores the features of the datum being added, delegating to
     * {@code addFeatures(Collection, boolean)} with {@code true} as the flag.
     *
     * @param collection the datum's features
     */
    public void addFeatures(Collection<F> collection) {
        addFeatures(collection, true);
    }

    /**
     * Converts the given features to ids and stores them as the row at position {@code size}.
     *
     * @param collection the datum's features
     * @param z          when true, each feature is added to the feature index before lookup;
     *                   when false, features whose lookup returns a negative id are dropped
     */
    protected void addFeatures(Collection<F> collection, boolean z) {
        int[] buffer = new int[collection.size()];
        int used = 0;
        for (F feature : collection) {
            if (z) {
                this.featureIndex.add(feature);
            }
            int featureId = this.featureIndex.indexOf(feature);
            if (featureId >= 0) {
                buffer[used] = featureId;
                used++;
            }
        }
        // Store a row trimmed to the number of ids actually kept.
        int[] row = new int[used];
        this.data[this.size] = row;
        synchronized (System.class) {
            System.arraycopy(buffer, 0, row, 0, used);
        }
    }

    /**
     * Stores an already-resolved array of feature ids as the row at position {@code size}.
     * The array is stored by reference, not copied.
     *
     * @param iArr the feature ids
     */
    protected void addFeatureIndices(int[] iArr) {
        this.data[this.size] = iArr;
    }

    /**
     * Resets this dataset to empty: fresh label and feature indices, label and data arrays of
     * the requested length, and a size of 0.
     *
     * @param i length of the newly allocated label and data arrays
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    protected final void initialize(int i) {
        this.labelIndex = new HashIndex<>();
        this.featureIndex = new HashIndex<>();
        this.labels = new int[i];
        // data holds one int[] row per datum, so allocate an array of rows.
        this.data = new int[i][];
        this.size = 0;
    }

    /**
     * Reconstructs the datum at the given position by mapping its stored feature ids and label
     * id back through the indices.
     *
     * @param i position of the datum
     * @return a new datum holding the features and label of row {@code i}
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public Datum<L, F> getDatum(int i) {
        Collection<F> features = this.featureIndex.objects(this.data[i]);
        L label = this.labelIndex.get(this.labels[i]);
        return new BasicDatum<>(features, label);
    }

    /**
     * Reconstructs the datum at the given position as a real-valued datum in which each
     * feature's value is the number of times it occurs in the row.
     *
     * @param i position of the datum
     * @return a new real-valued datum with the label of row {@code i}
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public RVFDatum<L, F> getRVFDatum(int i) {
        ClassicCounter<F> occurrences = new ClassicCounter<>();
        for (F feature : this.featureIndex.objects(this.data[i])) {
            occurrences.incrementCount(feature);
        }
        L label = this.labelIndex.get(this.labels[i]);
        return new RVFDatum<>(occurrences, label);
    }

    /** Writes the text produced by {@code toSummaryStatistics()} to the class logger. */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public void summaryStatistics() {
        logger.info(toSummaryStatistics());
    }

    /**
     * Renders a multi-line summary: datum count, datums per label, all labels, and the number
     * of features together with the first five of them (followed by {@code ", ..."} when there
     * are more).
     *
     * @return the summary text, without a trailing newline
     */
    public String toSummaryStatistics() {
        StringBuilder out = new StringBuilder();
        out.append("numDatums: ").append(this.size).append('\n');
        out.append("numDatumsPerLabel: ").append(numDatumsPerLabel()).append('\n');

        out.append("numLabels: ").append(this.labelIndex.size()).append(" [");
        String labelSeparator = "";
        for (L label : this.labelIndex) {
            out.append(labelSeparator).append(label);
            labelSeparator = ", ";
        }
        out.append("]\n");

        out.append("numFeatures (Phi(X) types): ").append(this.featureIndex.size()).append(" [");
        int shown = Math.min(5, this.featureIndex.size());
        for (int k = 0; k < shown; k++) {
            if (k != 0) {
                out.append(", ");
            }
            out.append(this.featureIndex.get(k));
        }
        if (this.featureIndex.size() > shown) {
            out.append(", ...");
        }
        out.append(']');
        return out.toString();
    }

    /**
     * Drops features whose count is below a pattern-specific threshold and renumbers the
     * remaining features.
     *
     * <p>For each feature, the first pattern in {@code list} that fully matches the feature's
     * {@code toString()} decides: the feature is kept only if its count from
     * {@code getFeatureCounts()} is at least that pattern's threshold. A feature that matches
     * no pattern is always kept. Every stored row is rewritten with the new ids, with dropped
     * features removed, and the feature index is replaced.
     *
     * @param list (pattern, minimum count) pairs, consulted in order
     */
    public void applyFeatureCountThreshold(List<Pair<Pattern, Integer>> list) {
        float[] featureCounts = getFeatureCounts();
        HashIndex<F> keptFeatures = new HashIndex<>();
        for (F feature : this.featureIndex) {
            boolean keep = true; // features matching no pattern are kept
            for (Pair<Pattern, Integer> threshold : list) {
                if (threshold.first().matcher(feature.toString()).matches()) {
                    keep = featureCounts[this.featureIndex.indexOf(feature)] >= threshold.second.intValue();
                    break; // only the first matching pattern decides
                }
            }
            if (keep) {
                keptFeatures.add(feature);
            }
        }

        // Old feature id -> new feature id (negative when the feature was dropped).
        int[] oldToNew = new int[this.featureIndex.size()];
        for (int oldId = 0; oldId < oldToNew.length; oldId++) {
            oldToNew[oldId] = keptFeatures.indexOf(this.featureIndex.get(oldId));
        }

        for (int row = 0; row < this.size; row++) {
            int[] oldRow = this.data[row];
            int[] buffer = new int[oldRow.length];
            int used = 0;
            for (int oldId : oldRow) {
                int newId = oldToNew[oldId];
                if (newId >= 0) {
                    buffer[used++] = newId;
                }
            }
            this.data[row] = Arrays.copyOf(buffer, used);
        }
        this.featureIndex = keptFeatures;
    }

    /**
     * Prints the dataset as a dense tab-separated 0/1 matrix.
     *
     * <p>The first line lists every feature, each preceded by a tab. Each following line holds
     * one datum: its label, then for every feature a tab and {@code 1} if the datum contains
     * that feature or {@code 0} otherwise.
     *
     * @param printWriter destination for the matrix
     */
    public void printFullFeatureMatrix(PrintWriter printWriter) {
        int featureCount = this.featureIndex.size();
        for (int f = 0; f < featureCount; f++) {
            printWriter.print("\t" + this.featureIndex.get(f));
        }
        printWriter.println();

        // Only the first `size` slots hold datums; the arrays may have spare capacity.
        for (int row = 0; row < this.size; row++) {
            // Translate the row's label id to its label, as printSparseFeatureMatrix does.
            printWriter.print(this.labelIndex.get(this.labels[row]));
            boolean[] present = new boolean[featureCount];
            for (int featureId : this.data[row]) {
                present[featureId] = true;
            }
            for (int f = 0; f < featureCount; f++) {
                printWriter.print(present[f] ? "\t1" : "\t0");
            }
            printWriter.println();
        }
    }

    /** Prints the sparse feature matrix to {@code System.out} through an auto-flushing writer. */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public void printSparseFeatureMatrix() {
        printSparseFeatureMatrix(new PrintWriter((OutputStream) System.out, true));
    }

    /**
     * Prints one line per datum: the label followed by each of the datum's features, every
     * feature preceded by a tab.
     *
     * @param printWriter destination for the output
     */
    @Override // edu.stanford.nlp.classify.GeneralDataset
    public void printSparseFeatureMatrix(PrintWriter printWriter) {
        int row = 0;
        while (row < this.size) {
            printWriter.print(this.labelIndex.get(this.labels[row]));
            int[] featureIds = this.data[row];
            for (int k = 0; k < featureIds.length; k++) {
                printWriter.print("\t" + this.featureIndex.get(featureIds[k]));
            }
            printWriter.println();
            row++;
        }
    }

    /**
     * Re-expresses every stored label id in terms of a different label index and then adopts
     * that index. The label array is first replaced by the result of {@code trimToSize}.
     *
     * @param index the label index to switch to
     */
    public void changeLabelIndex(Index<L> index) {
        int[] trimmed = trimToSize(this.labels);
        this.labels = trimmed;
        for (int row = 0; row < trimmed.length; row++) {
            // Resolve through the current index, then look the label up in the new one.
            L label = this.labelIndex.get(trimmed[row]);
            trimmed[row] = index.indexOf(label);
        }
        this.labelIndex = index;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v5, types: [int[], int[][]] */
    public void changeFeatureIndex(Index<F> index) {
        this.data = trimToSize(this.data);
        this.labels = trimToSize(this.labels);
        ?? r0 = new int[this.data.length];
        for (int i = 0; i < this.data.length; i++) {
            int[] iArr = new int[this.data[i].length];
            int i2 = 0;
            for (int i3 = 0; i3 < this.data[i].length; i3++) {
                int indexOf = index.indexOf(this.featureIndex.get(this.data[i][i3]));
                if (indexOf >= 0) {
                    int i4 = i2;
                    i2++;
                    iArr[i4] = indexOf;
                }
            }
            r0[i] = new int[i2];
            synchronized (System.class) {
                System.arraycopy(iArr, 0, r0[i], 0, i2);
            }
        }
        this.data = r0;
        this.featureIndex = index;
    }

    /**
     * Keeps the highest-scoring features, using the scores from {@code getInformationGains()}.
     *
     * @param i maximum number of features to keep
     */
    public void selectFeaturesBinaryInformationGain(int i) {
        selectFeatures(i, getInformationGains());
    }

    /**
     * Keeps at most {@code i} features, chosen by descending score, and rewrites every stored
     * row in terms of a new feature index that contains only the kept features.
     *
     * @param i    maximum number of features to keep
     * @param dArr score for each feature, indexed by current feature id
     */
    public void selectFeatures(int i, double[] dArr) {
        List<ScoredObject<F>> ranked = new ArrayList<>();
        for (int featureId = 0; featureId < dArr.length; featureId++) {
            ranked.add(new ScoredObject<>(this.featureIndex.get(featureId), dArr[featureId]));
        }
        Collections.sort(ranked, ScoredComparator.DESCENDING_COMPARATOR);

        HashIndex<F> selected = new HashIndex<>();
        int limit = Math.min(ranked.size(), i);
        for (int rank = 0; rank < limit; rank++) {
            selected.add(ranked.get(rank).object());
        }

        for (int row = 0; row < this.size; row++) {
            int[] oldRow = this.data[row];
            int[] buffer = new int[oldRow.length];
            int used = 0;
            for (int k = 0; k < oldRow.length; k++) {
                int newId = selected.indexOf(this.featureIndex.get(oldRow[k]));
                if (newId == -1) {
                    continue;
                }
                buffer[used] = newId;
                used++;
            }
            int[] newRow = new int[used];
            synchronized (System.class) {
                System.arraycopy(buffer, 0, newRow, 0, used);
            }
            this.data[row] = newRow;
        }
        this.featureIndex = selected;
    }

    /**
     * Computes a base-2 information-gain score for every feature, treating each feature as
     * binary (present or absent in a datum).
     *
     * <p>Each score starts at the label entropy and then has the weighted sums of
     * {@code p * log2(p)} terms for the "feature present" and "feature absent" cases added.
     * A feature that occurs in no datum or in every datum scores 0. The label array is first
     * replaced by the result of {@code trimToSize}.
     *
     * @return one score per feature, indexed by feature id
     */
    public double[] getInformationGains() {
        this.labels = trimToSize(this.labels);

        ClassicCounter<F> featureCounts = new ClassicCounter<>();
        ClassicCounter<L> labelCounts = new ClassicCounter<>();
        TwoDimensionalCounter<F, L> jointCounts = new TwoDimensionalCounter<>();

        // Pass 1: count labels, features (once per datum), and feature/label co-occurrences.
        for (int row = 0; row < this.labels.length; row++) {
            labelCounts.incrementCount(this.labelIndex.get(this.labels[row]));
            boolean[] present = new boolean[this.featureIndex.size()];
            for (int featureId : this.data[row]) {
                present[featureId] = true;
            }
            for (int featureId = 0; featureId < present.length; featureId++) {
                if (!present[featureId]) {
                    continue;
                }
                featureCounts.incrementCount(this.featureIndex.get(featureId));
                jointCounts.incrementCount(
                        this.featureIndex.get(featureId), this.labelIndex.get(this.labels[row]), 1.0d);
            }
        }

        // Entropy of the label distribution, in bits.
        double labelEntropy = 0.0d;
        for (int labelId = 0; labelId < this.labelIndex.size(); labelId++) {
            double labelProb = labelCounts.getCount(this.labelIndex.get(labelId)) / size();
            labelEntropy -= (labelProb * Math.log(labelProb)) * LN_TO_LOG2;
        }

        double[] gains = new double[this.featureIndex.size()];
        Arrays.fill(gains, labelEntropy);

        // Pass 2: add the weighted sums of p*log2(p) terms for each feature.
        for (int featureId = 0; featureId < this.featureIndex.size(); featureId++) {
            F feature = this.featureIndex.get(featureId);
            double withCount = featureCounts.getCount(feature);
            double withoutCount = size() - withCount;
            double probWith = withCount / size();
            double probWithout = 1.0d - probWith;

            if (withCount == 0.0d || withoutCount == 0.0d) {
                gains[featureId] = 0.0d;
                continue;
            }

            double sumWith = 0.0d;
            double sumWithout = 0.0d;
            for (int labelId = 0; labelId < this.labelIndex.size(); labelId++) {
                double jointWith = jointCounts.getCount(feature, this.labelIndex.get(labelId));
                // NOTE(review): the complement is taken against the dataset size, not the
                // label's own count - confirm this is intended.
                double jointWithout = size() - jointWith;
                double condWith = jointWith / withCount;
                double condWithout = jointWithout / withoutCount;
                if (jointWith != 0.0d) {
                    sumWith += condWith * Math.log(condWith) * LN_TO_LOG2;
                }
                if (jointWithout != 0.0d) {
                    sumWithout += condWithout * Math.log(condWithout) * LN_TO_LOG2;
                }
            }
            gains[featureId] = gains[featureId] + (probWith * sumWith) + (probWithout * sumWithout);
        }
        return gains;
    }

    /**
     * Replaces the label array with the supplied one. The array is stored by reference.
     *
     * @param iArr new label ids, one per datum
     * @throws IllegalArgumentException if the array length differs from {@code size()}
     */
    public void updateLabels(int[] iArr) {
        if (iArr.length == size()) {
            this.labels = iArr;
            return;
        }
        throw new IllegalArgumentException("size of labels array does not match dataset size");
    }

    /** Returns a short description consisting of the text "Dataset of size " and the size. */
    public String toString() {
        return "Dataset of size " + this.size;
    }

    /**
     * Renders a three-line summary: number of data points, number of active feature tokens,
     * and number of active feature types. Each line ends with the platform line separator.
     *
     * @return the summary text
     */
    public String toSummaryString() {
        StringWriter buffer = new StringWriter();
        PrintWriter printWriter = new PrintWriter(buffer);
        printWriter.println("Number of data points: " + size());
        printWriter.println("Number of active feature tokens: " + numFeatureTokens());
        printWriter.println("Number of active feature types:" + numFeatureTypes());
        printWriter.flush();
        // The text lives in the StringWriter; PrintWriter.toString() is only Object.toString().
        return buffer.toString();
    }

    /**
     * Prints one SVM-light line: the class number, then for each key of the counter in
     * ascending order the text {@code (key + 1):count}, every entry followed by a space.
     *
     * @param printWriter    destination for the line
     * @param classicCounter counts keyed by zero-based feature id
     * @param i              class number printed at the start of the line
     */
    public static void printSVMLightFormat(PrintWriter printWriter, ClassicCounter<Integer> classicCounter, int i) {
        List<Integer> sortedKeys = new ArrayList<>(classicCounter.keySet());
        Collections.sort(sortedKeys);

        StringBuilder line = new StringBuilder();
        line.append(i).append(' ');
        for (Integer key : sortedKeys) {
            int featureId = key.intValue();
            // Keys are shifted by one in the output.
            line.append(featureId + 1);
            line.append(':');
            line.append(classicCounter.getCount(key));
            line.append(' ');
        }
        printWriter.println(line.toString());
    }
}
