package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.ie.util.RelationTriple;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Trilean;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.0.0.jar:edu/stanford/nlp/naturalli/ClauseSplitter.class */
public interface ClauseSplitter extends BiFunction<SemanticGraph, Boolean, ClauseSplitterSearchProblem> {
    public static final Redwood.RedwoodChannels log = Redwood.channels(ClauseSplitter.class);

    /* loaded from: input_file:BOOT-INF/lib/stanford-corenlp-4.0.0.jar:edu/stanford/nlp/naturalli/ClauseSplitter$ClauseClassifierLabel.class */
    public enum ClauseClassifierLabel {
        CLAUSE_SPLIT(2),
        CLAUSE_INTERM(1),
        NOT_A_CLAUSE(0);

        public final byte index;

        ClauseClassifierLabel(int i) {
            this.index = (byte) i;
        }

        @Override // java.lang.Enum
        public String toString() {
            return name();
        }

        public static ClauseClassifierLabel fromIndex(int i) {
            switch (i) {
                case 0:
                    return NOT_A_CLAUSE;
                case 1:
                    return CLAUSE_INTERM;
                case 2:
                    return CLAUSE_SPLIT;
                default:
                    throw new IllegalArgumentException("Not a valid index: " + i);
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> stream, Optional<File> optional, Optional<File> optional2, ClauseSplitterSearchProblem.Featurizer featurizer) {
        LinearClassifierFactory linearClassifierFactory = new LinearClassifierFactory();
        OpenIE openIE = new OpenIE(PropertiesUtils.asProperties("splitter.nomodel", "true", "optimizefor", "GENERAL"));
        WeightedDataset weightedDataset = new WeightedDataset();
        AtomicInteger atomicInteger = new AtomicInteger(0);
        Optional<U> map = optional2.map(file -> {
            try {
                return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream((File) optional2.get()))));
            } catch (IOException e) {
                throw new RuntimeIOException(e);
            }
        });
        Redwood.Util.forceTrack("Training inference");
        stream.forEach(pair -> {
            CoreMap coreMap = (CoreMap) pair.first;
            Collection collection = (Collection) pair.second;
            List list = (List) coreMap.get(CoreAnnotations.TokensAnnotation.class);
            new ClauseSplitterSearchProblem((SemanticGraph) coreMap.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class), true).search(triple -> {
                List list2 = (List) triple.second;
                HashSet hashSet = new HashSet(openIE.relationsInFragments(openIE.entailmentsFromClause((SentenceFragment) ((Supplier) triple.third).get())));
                Trilean trilean = Trilean.FALSE;
                Iterator it = hashSet.iterator();
                loop0: while (true) {
                    if (!it.hasNext()) {
                        break;
                    }
                    RelationTriple relationTriple = (RelationTriple) it.next();
                    Span fromValues = Span.fromValues(relationTriple.subject.get(0).index() - 1, relationTriple.subject.get(relationTriple.subject.size() - 1).index());
                    Span fromValues2 = Span.fromValues(relationTriple.object.get(0).index() - 1, relationTriple.object.get(relationTriple.object.size() - 1).index());
                    Iterator it2 = collection.iterator();
                    while (it2.hasNext()) {
                        Pair pair = (Pair) it2.next();
                        Span span = (Span) pair.first;
                        Span span2 = (Span) pair.second;
                        if ((!fromValues.equals(span) || !fromValues2.equals(span2)) && (!fromValues.equals(span2) || !fromValues2.equals(span))) {
                            if ((!Util.nerOverlap(list, span, fromValues) || !Util.nerOverlap(list, span2, fromValues2)) && (!Util.nerOverlap(list, span, fromValues2) || !Util.nerOverlap(list, span2, fromValues))) {
                                if (!trilean.isTrue()) {
                                    trilean = Trilean.UNKNOWN;
                                    break loop0;
                                }
                            } else {
                                if (!trilean.isTrue()) {
                                    trilean = Trilean.TRUE;
                                    break loop0;
                                }
                            }
                        }
                    }
                }
                trilean = Trilean.TRUE;
                if (list2.isEmpty()) {
                    return true;
                }
                ArrayList<Pair> arrayList = new ArrayList();
                if (trilean.isTrue()) {
                    for (int i = 0; i < list2.size(); i++) {
                        if (i == list2.size() - 1) {
                            arrayList.add(Pair.makePair(list2.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                        } else {
                            arrayList.add(Pair.makePair(list2.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                        }
                    }
                } else if (trilean.isFalse()) {
                    arrayList.add(Pair.makePair(list2.get(list2.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
                } else if (trilean.isUnknown()) {
                    boolean z = false;
                    Iterator it3 = list2.iterator();
                    while (true) {
                        if (!it3.hasNext()) {
                            break;
                        }
                        if (featurizer.isSimpleSplit((Counter) it3.next())) {
                            z = true;
                            break;
                        }
                    }
                    if (z) {
                        for (int i2 = 0; i2 < list2.size(); i2++) {
                            if (i2 == list2.size() - 1) {
                                arrayList.add(Pair.makePair(list2.get(i2), ClauseClassifierLabel.CLAUSE_SPLIT));
                            } else {
                                arrayList.add(Pair.makePair(list2.get(i2), ClauseClassifierLabel.CLAUSE_INTERM));
                            }
                        }
                    }
                }
                for (Pair pair2 : arrayList) {
                    RVFDatum rVFDatum = new RVFDatum((Counter) pair2.first);
                    rVFDatum.setLabel(pair2.second);
                    if (map.isPresent()) {
                        ((PrintWriter) map.get()).println(pair2.second + "\t" + StringUtils.join(((Counter) pair2.first).entrySet().stream().map(entry -> {
                            return ((String) entry.getKey()) + "->" + entry.getValue();
                        }), ";"));
                    }
                    weightedDataset.add(rVFDatum);
                }
                return true;
            }, new LinearClassifier(new ClassicCounter()), Collections.emptyMap(), featurizer, 10000);
            if (atomicInteger.incrementAndGet() % 100 == 0) {
                Redwood.Util.log("processed " + atomicInteger + " training sentences: " + weightedDataset.size() + " datums");
            }
        });
        Redwood.Util.endTrack("Training inference");
        if (map.isPresent()) {
            ((PrintWriter) map.get()).close();
        }
        Redwood.Util.forceTrack("Training");
        LinearClassifier trainClassifier = linearClassifierFactory.trainClassifier((GeneralDataset) weightedDataset);
        Redwood.Util.endTrack("Training");
        if (optional.isPresent()) {
            try {
                IOUtils.writeObjectToFile(Pair.makePair(trainClassifier, featurizer), optional.get());
                Redwood.Util.log("SUCCESS: wrote model to " + optional.get().getPath());
            } catch (IOException e) {
                Redwood.Util.log("ERROR: failed to save model to path: " + optional.get().getPath());
                Redwood.Util.err(e);
            }
        }
        Redwood.Util.forceTrack("Training accuracy");
        weightedDataset.randomize(42L);
        Util.dumpAccuracy(trainClassifier, weightedDataset);
        Redwood.Util.endTrack("Training accuracy");
        Redwood.Util.forceTrack("5 fold cross-validation");
        for (int i = 0; i < 5; i++) {
            Redwood.Util.forceTrack("Fold " + (i + 1));
            Redwood.Util.forceTrack("Training");
            Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> splitOutFold = weightedDataset.splitOutFold(i, 5);
            LinearClassifier trainClassifier2 = linearClassifierFactory.trainClassifier((GeneralDataset) splitOutFold.first);
            Redwood.Util.endTrack("Training");
            Redwood.Util.forceTrack("Test");
            Util.dumpAccuracy(trainClassifier2, (GeneralDataset) splitOutFold.second);
            Redwood.Util.endTrack("Test");
            Redwood.Util.endTrack("Fold " + (i + 1));
        }
        Redwood.Util.endTrack("5 fold cross-validation");
        return (semanticGraph, bool) -> {
            return new ClauseSplitterSearchProblem(semanticGraph, bool.booleanValue(), Optional.of(trainClassifier), Optional.of(featurizer));
        };
    }

    static ClauseSplitter train(Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> stream, File file, File file2) {
        return train(stream, Optional.of(file), Optional.of(file2), ClauseSplitterSearchProblem.DEFAULT_FEATURIZER);
    }

    static ClauseSplitter load(String str) throws IOException {
        try {
            long currentTimeMillis = System.currentTimeMillis();
            Pair pair = (Pair) IOUtils.readObjectFromURLOrClasspathOrFileSystem(str);
            ClauseSplitter clauseSplitter = (semanticGraph, bool) -> {
                return new ClauseSplitterSearchProblem(semanticGraph, bool.booleanValue(), Optional.of(pair.first), Optional.of(pair.second));
            };
            log.info("Loading clause splitter from " + str + " ... done [" + Redwood.formatTimeDifference(System.currentTimeMillis() - currentTimeMillis) + "]");
            return clauseSplitter;
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Invalid model at path: " + str, e);
        }
    }
}
