package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasIndex;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.process.TSVSentenceProcessor;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher;
import edu.stanford.nlp.semgraph.semgrex.SemgrexPattern;
import edu.stanford.nlp.trees.PennTreeReader;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalStructureFactory;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.ArrayCoreMap;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Interval;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/* loaded from: input_file:edu/stanford/nlp/naturalli/CreateClauseDataset.class */
/**
 * Builds a training set for the {@link ClauseSplitter} from Penn-Treebank-style tree files.
 *
 * <p>For every tree read, the class converts it to a dependency graph, runs the natural logic
 * annotator over it, and collects (subject span, object span) pairs, using the trace
 * co-indexing found in the tree labels to recover subjects of verb phrases matched by
 * {@code RelationTripleSegmenter.VP_PATTERNS}.
 */
public class CreateClauseDataset implements TSVSentenceProcessor {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CreateClauseDataset.class);

    // Not final: declared as an ArgumentParser option, so it is left assignable.
    @ArgumentParser.Option(name = "in", gloss = "The input to read from")
    private static InputStream in = System.in;

    /** Matches labels of the form {@code NP-<anything>-<digits>}; group 2 is the numeric suffix. */
    private static final Pattern TRACE_TARGET_PATTERN = Pattern.compile("(NP-.*)-([0-9]+)");

    /** Matches labels ending in {@code *-<digits>}; group 1 is the numeric suffix. */
    private static final Pattern TRACE_SOURCE_PATTERN = Pattern.compile(".*\\*-([0-9]+)");

    private static final UniversalEnglishGrammaticalStructureFactory parser =
            new UniversalEnglishGrammaticalStructureFactory();
    private static final RelationTripleSegmenter segmenter = new RelationTripleSegmenter();
    private static final NaturalLogicAnnotator natlog = new NaturalLogicAnnotator();

    /**
     * Computes the span covering a list of indexed elements.
     *
     * <p>The start is the smallest {@code index() - 1} and the end is the largest
     * {@code index()} over the list.
     *
     * @param chunk the elements to cover; with assertions enabled an empty list fails the
     *              second assertion because the end stays at -1
     * @return a new {@link Span} from the computed start to the computed end
     */
    private static Span toSpan(List<? extends HasIndex> chunk) {
        int min = Integer.MAX_VALUE;
        int max = -1;
        for (HasIndex word : chunk) {
            min = Math.min(word.index() - 1, min);
            max = Math.max(word.index(), max);
        }
        assert min >= 0;
        assert max < Integer.MAX_VALUE && max > 0;
        return new Span(min, max);
    }

    /**
     * Logs the text of the first sentence of the annotation and collects subject spans from
     * its basic dependency graph.
     *
     * <p>A node is considered when its tag starts with {@code "N"} or equals {@code "PRP"}.
     * Its chunk is accepted only if none of its tokens has already been claimed by a
     * previously accepted chunk. The collected spans are local to this method; nothing is
     * returned or stored.
     *
     * @param j          the id passed in by the caller (unused here)
     * @param annotation the document; only its first sentence is read
     */
    @Override // edu.stanford.nlp.process.TSVSentenceProcessor
    public void process(long j, Annotation annotation) {
        CoreMap sentence = annotation.get(CoreAnnotations.SentencesAnnotation.class).get(0);
        SemanticGraph dependencies = sentence.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
        log.info("| " + sentence.get(CoreAnnotations.TextAnnotation.class));

        BitSet consumedAsSubjects = new BitSet();
        List<Span> subjectSpans = new ArrayList<>();
        for (IndexedWord head : dependencies.topologicalSort()) {
            String tag = head.tag();
            if (!(tag.startsWith("N") || tag.equals("PRP"))) {
                continue;
            }
            Optional<List<IndexedWord>> subjectChunk =
                    segmenter.getValidChunk(dependencies, head, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
            if (!subjectChunk.isPresent()) {
                continue;
            }
            List<IndexedWord> tokens = subjectChunk.get();

            // Skip this head if any of its tokens already belongs to an accepted chunk.
            boolean alreadyConsumed = false;
            for (IndexedWord token : tokens) {
                if (consumedAsSubjects.get(token.index())) {
                    alreadyConsumed = true;
                    break;
                }
            }
            if (alreadyConsumed) {
                continue;
            }

            // Claim the tokens and record the span exactly once, then move on to the next head.
            for (IndexedWord token : tokens) {
                consumedAsSubjects.set(token.index());
            }
            subjectSpans.add(toSpan(tokens));
        }
    }

    /**
     * Converts a constituency tree into a dependency graph built from the collapsed typed
     * dependencies produced by the grammatical structure factory.
     */
    private static SemanticGraph parse(Tree tree) {
        return new SemanticGraph(parser.newGrammaticalStructure(tree).typedDependenciesCollapsed());
    }

    /** Returns true if any outgoing edge of {@code node} has a relation whose name contains "subj". */
    private static boolean hasSubjectArc(SemanticGraph graph, IndexedWord node) {
        for (SemanticGraphEdge edge : graph.outgoingEdgeIterable(node)) {
            if (edge.getRelation().toString().contains("subj")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Extracts (subject span, object span) pairs from a dependency graph.
     *
     * <p>Two passes are made. The first uses the segmenter's VP patterns: for a matched
     * verb/object where neither node has an outgoing "subj" relation, the subject is looked
     * up through the trace maps. The second uses the segmenter's verb patterns, where both
     * subject and object are matched directly.
     *
     * @param graph        the dependency graph to match against
     * @param tokens       the sentence tokens (not read by this method)
     * @param traceTargets trace id to span, as produced by {@link #findTraceTargets(Tree)}
     * @param traceSources trace id to token position, as produced by {@link #findTraceSources(Tree)}
     * @return the extracted pairs, in match order; possibly empty
     */
    private static Collection<Pair<Span, Span>> subjectObjectPairs(SemanticGraph graph, List<CoreLabel> tokens,
            Map<Integer, Span> traceTargets, Map<Integer, Integer> traceSources) {
        List<Pair<Span, Span>> pairs = new ArrayList<>();

        // Pass 1: verb phrases without an explicit subject; recover the subject via traces.
        for (SemgrexPattern vpPattern : segmenter.VP_PATTERNS) {
            SemgrexMatcher matcher = vpPattern.matcher(graph);
            while (matcher.find()) {
                IndexedWord verb = matcher.getNode("verb");
                IndexedWord object = matcher.getNode("object");
                if (verb == null || object == null) {
                    continue;
                }
                if (hasSubjectArc(graph, verb) || hasSubjectArc(graph, object)) {
                    continue;
                }
                Optional<List<IndexedWord>> verbChunk =
                        segmenter.getValidChunk(graph, verb, segmenter.VALID_ADVERB_ARCS, Optional.empty(), true);
                Optional<List<IndexedWord>> objectChunk =
                        segmenter.getValidChunk(graph, object, segmenter.VALID_OBJECT_ARCS, Optional.empty(), true);
                if (!verbChunk.isPresent() || !objectChunk.isPresent()) {
                    continue;
                }
                // Integer.compare instead of subtraction, which can overflow.
                verbChunk.get().sort((a, b) -> Integer.compare(a.index(), b.index()));
                objectChunk.get().sort((a, b) -> Integer.compare(a.index(), b.index()));

                // Look for a trace source inside the verb span widened by one position on
                // each side. If several match, the last one in map iteration order wins.
                int traceId = -1;
                Span verbSpan = toSpan(verbChunk.get());
                Span traceSpan = Span.fromValues(verbSpan.start() - 1, verbSpan.end() + 1);
                for (Map.Entry<Integer, Integer> entry : traceSources.entrySet()) {
                    int sourcePosition = entry.getValue();
                    if (traceSpan.contains(sourcePosition)) {
                        traceId = entry.getKey();
                    }
                }
                if (traceId >= 0) {
                    Span subjectSpan = traceTargets.get(traceId);
                    Span objectSpan = toSpan(objectChunk.get());
                    if (subjectSpan != null) {
                        pairs.add(Pair.makePair(subjectSpan, objectSpan));
                    }
                }
            }
        }

        // Pass 2: patterns that match subject and object directly.
        for (SemgrexPattern verbPattern : segmenter.VERB_PATTERNS) {
            SemgrexMatcher matcher = verbPattern.matcher(graph);
            while (matcher.find()) {
                IndexedWord subject = matcher.getNode("subject");
                IndexedWord object = matcher.getNode("object");
                if (subject == null || object == null) {
                    continue;
                }
                Optional<List<IndexedWord>> subjectChunk =
                        segmenter.getValidChunk(graph, subject, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
                Optional<List<IndexedWord>> objectChunk =
                        segmenter.getValidChunk(graph, object, segmenter.VALID_OBJECT_ARCS, Optional.empty(), true);
                if (subjectChunk.isPresent() && objectChunk.isPresent()) {
                    pairs.add(Pair.makePair(toSpan(subjectChunk.get()), toSpan(objectChunk.get())));
                }
            }
        }
        return pairs;
    }

    /**
     * Recursively collects trace targets from a tree: for every node whose label value
     * matches {@link #TRACE_TARGET_PATTERN}, maps the numeric suffix to the node's span
     * converted with {@code toExclusive()}.
     *
     * <p>Entries found in later children overwrite earlier ones with the same id.
     */
    private static Map<Integer, Span> findTraceTargets(Tree tree) {
        Map<Integer, Span> targets = new HashMap<>(4);
        String labelValue = tree.label().value();
        Matcher matcher = TRACE_TARGET_PATTERN.matcher(labelValue == null ? "NULL" : labelValue);
        if (matcher.matches()) {
            targets.put(Integer.parseInt(matcher.group(2)), Span.fromPair(tree.getSpan()).toExclusive());
        }
        for (Tree child : tree.children()) {
            targets.putAll(findTraceTargets(child));
        }
        return targets;
    }

    /**
     * Recursively collects trace sources from a tree: for every node whose label value
     * matches {@link #TRACE_SOURCE_PATTERN}, maps the numeric suffix to the label's
     * {@code index() - 1}.
     *
     * <p>Matching nodes are expected to carry a {@link CoreLabel}; any other label type
     * causes a {@link ClassCastException}.
     */
    private static Map<Integer, Integer> findTraceSources(Tree tree) {
        Map<Integer, Integer> sources = new HashMap<>(4);
        String labelValue = tree.label().value();
        Matcher matcher = TRACE_SOURCE_PATTERN.matcher(labelValue == null ? "NULL" : labelValue);
        if (matcher.matches()) {
            sources.put(Integer.parseInt(matcher.group(1)), ((CoreLabel) tree.label()).index() - 1);
        }
        for (Tree child : tree.children()) {
            sources.putAll(findTraceSources(child));
        }
        return sources;
    }

    /** Returns the total number of (subject, object) pairs across all sentences in the data. */
    private static int countDatums(List<Pair<CoreMap, Collection<Pair<Span, Span>>>> data) {
        int total = 0;
        for (Pair<CoreMap, Collection<Pair<Span, Span>>> datum : data) {
            total += datum.second.size();
        }
        return total;
    }

    /**
     * Turns a single tree into a training datum: a sentence map holding the tokens and the
     * dependency graph (registered under the basic, enhanced and enhanced++ keys), annotated
     * by the natural logic annotator, paired with the extracted subject/object spans.
     */
    private static Pair<CoreMap, Collection<Pair<Span, Span>>> toDatum(Tree tree) {
        tree.indexSpans();
        tree.setSpans();
        List<CoreLabel> tokens = tree.getLeaves().stream()
                .map(leaf -> (CoreLabel) leaf.label())
                .collect(Collectors.toList());
        SemanticGraph graph = parse(tree);
        Map<Integer, Span> traceTargets = findTraceTargets(tree);
        Map<Integer, Integer> traceSources = findTraceSources(tree);

        CoreMap sentence = new ArrayCoreMap(4);
        sentence.set(CoreAnnotations.TokensAnnotation.class, tokens);
        sentence.set(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class, graph);
        sentence.set(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class, graph);
        sentence.set(SemanticGraphCoreAnnotations.EnhancedPlusPlusDependenciesAnnotation.class, graph);
        natlog.doOneSentence(null, sentence);

        return Pair.makePair(sentence, subjectObjectPairs(graph, tokens, traceTargets, traceSources));
    }

    /**
     * Reads every tree file under a directory (recursively) and converts each tree into a
     * training datum.
     *
     * <p>A tree that fails to convert is reported on standard error and skipped; it does not
     * abort the run and is not counted as processed.
     *
     * @param str  a display name used in the log messages
     * @param file the directory to scan
     * @return one datum per successfully processed tree
     * @throws IOException if a file cannot be opened, read or closed
     */
    private static List<Pair<CoreMap, Collection<Pair<Span, Span>>>> processDirectory(String str, File file) throws IOException {
        Redwood.Util.forceTrack("Processing " + str);
        Iterable<File> treeFiles = IOUtils.iterFilesRecursive(file, Treebank.DEFAULT_TREE_FILE_SUFFIX);
        int treesProcessed = 0;
        // The constant is used purely as an initial-capacity hint, as in the original code.
        List<Pair<CoreMap, Collection<Pair<Span, Span>>>> data = new ArrayList<>(Interval.REL_FLAGS_ES_AFTER);

        for (File treeFile : treeFiles) {
            // try-with-resources so each file's reader is closed before the next one is opened.
            try (PennTreeReader reader = new PennTreeReader(IOUtils.readerFromFile(treeFile))) {
                Tree tree;
                // readTree() returns null at end of input; that ends this file.
                while ((tree = reader.readTree()) != null) {
                    try {
                        data.add(toDatum(tree));
                        treesProcessed++;
                        if (treesProcessed % 100 == 0) {
                            Redwood.Util.log("[" + new DecimalFormat("00000").format(treesProcessed) + "] "
                                    + countDatums(data) + " known extractions");
                        }
                    } catch (Throwable th) {
                        // Deliberately best-effort: Throwable (not Exception) so that assertion
                        // failures on one bad tree do not stop the whole run.
                        th.printStackTrace();
                    }
                }
            }
        }
        Redwood.Util.log("" + treesProcessed + " trees processed yielding " + countDatums(data) + " known extractions");
        Redwood.Util.endTrack("Processing " + str);
        return data;
    }

    /**
     * Processes the WSJ and Brown treebank directories at their hard-coded locations and
     * trains a clause splitter on the result, writing the model and the training data dump
     * to hard-coded paths.
     *
     * @param strArr command-line arguments (not read)
     * @throws IOException if reading the treebanks fails
     */
    public static void main(String[] strArr) throws IOException {
        Redwood.Util.forceTrack("Processing treebanks");
        List<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData = new ArrayList<>();
        trainingData.addAll(processDirectory("WSJ", new File("/home/gabor/lib/data/penn_treebank/wsj")));
        trainingData.addAll(processDirectory("Brown", new File("/home/gabor/lib/data/penn_treebank/brown")));
        Redwood.Util.endTrack("Processing treebanks");
        Redwood.Util.forceTrack("Training");
        Redwood.Util.log("dataset size: " + trainingData.size());
        ClauseSplitter.train(trainingData.stream(), new File("/home/gabor/tmp/clauseSearcher.ser.gz"), new File("/home/gabor/tmp/clauseSearcherData.tab.gz"));
        Redwood.Util.endTrack("Training");
    }
}
