package edu.stanford.nlp.coref.misc;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogisticClassifier;
import edu.stanford.nlp.classify.LogisticClassifierFactory;
import edu.stanford.nlp.coref.data.CorefCluster;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.DocumentMaker;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.dcoref.Constants;
import edu.stanford.nlp.ie.pascal.PascalTemplate;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.pipeline.DefaultPaths;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.tagger.maxent.TaggerConfig;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;

/* loaded from: input_file:edu/stanford/nlp/coref/misc/SingletonPredictor.class */
/**
 * Command-line trainer for the "singleton predictor": a logistic classifier over the
 * feature lists produced by {@link Mention#getSingletonFeatures(Dictionaries)}.
 *
 * <p>Training examples come from two sources per document: mentions that belong to a gold
 * coreference cluster, and predicted mentions whose head word is not the head of any gold
 * mention. The two groups receive different labels. The trained classifier is written out
 * with Java serialization.
 *
 * <p>This class only exposes static members and cannot be instantiated.
 */
public class SingletonPredictor {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SingletonPredictor.class);

    /** Property naming the file the trained model is written to. */
    private static final String OUTPUT_MODEL_PROP = "singleton.predictor.output";

    // NOTE(review): both label constants are borrowed from unrelated classes (tagger
    // configuration / Pascal template). Presumably they merely share their values with the
    // plain label literals used by the original author -- confirm the values before replacing
    // them with literals. They are kept as-is here so the emitted labels do not change.

    /** Label given to mentions taken from the gold coreference clusters. */
    private static final String GOLD_CLUSTER_LABEL = TaggerConfig.NTHREADS;

    /** Label given to predicted mentions whose head word matches no gold mention head. */
    private static final String UNMATCHED_PREDICTED_LABEL = PascalTemplate.BACKGROUND_SYMBOL;

    private SingletonPredictor() {
    }

    /**
     * Numbers every token of the document consecutively, starting at 0 and continuing across
     * sentence boundaries, storing the number under {@code TokenBeginAnnotation}.
     *
     * @param document the document whose tokens are re-indexed in place
     */
    private static void setTokenIndices(Document document) {
        int nextIndex = 0;
        List<CoreMap> sentences = document.annotation.get(CoreAnnotations.SentencesAnnotation.class);
        for (CoreMap sentence : sentences) {
            for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) {
                token.set(CoreAnnotations.TokenBeginAnnotation.class, Integer.valueOf(nextIndex));
                nextIndex++;
            }
        }
    }

    /**
     * Builds the training set by draining every document the {@link DocumentMaker} yields.
     *
     * @param properties configuration handed to {@link Dictionaries} and {@link DocumentMaker}
     * @return the labelled dataset; its summary statistics are printed before returning
     * @throws Exception propagated from dictionary loading or document construction
     */
    private static GeneralDataset<String, String> generateFeatureVectors(Properties properties) throws Exception {
        Dataset<String, String> dataset = new Dataset<String, String>();
        Dictionaries dictionaries = new Dictionaries(properties);
        DocumentMaker documentMaker = new DocumentMaker(properties, dictionaries);

        Document document;
        while ((document = documentMaker.nextDoc()) != null) {
            setTokenIndices(document);
            addGoldClusterExamples(document, dictionaries, dataset);
            addUnmatchedPredictedExamples(document, dictionaries, dataset);
        }

        dataset.summaryStatistics();
        return dataset;
    }

    /** Adds one example per gold-cluster mention whose head is usable (see inline conditions). */
    private static void addGoldClusterExamples(Document document, Dictionaries dictionaries,
                                               Dataset<String, String> dataset) {
        for (CorefCluster cluster : document.goldCorefClusters.values()) {
            for (Mention mention : cluster.getCorefMentions()) {
                // Skip heads whose tag starts with "V" (presumably verbal heads) and heads
                // that have no node in the mention's dependency graph.
                if (mention.headWord.tag().startsWith("V")) {
                    continue;
                }
                if (mention.enhancedDependency.getNodeByIndexSafe(mention.headWord.index()) == null) {
                    continue;
                }
                dataset.add(new BasicDatum<String, String>(
                        mention.getSingletonFeatures(dictionaries), GOLD_CLUSTER_LABEL));
            }
        }
    }

    /** Adds one example per predicted mention whose head word is not a gold mention head. */
    private static void addUnmatchedPredictedExamples(Document document, Dictionaries dictionaries,
                                                      Dataset<String, String> dataset) {
        // Kept as a List so membership uses exactly the same equals-based scan as before.
        List<CoreLabel> goldHeads = new ArrayList<CoreLabel>();
        for (Mention goldMention : document.goldMentionsByID.values()) {
            goldHeads.add(goldMention.headWord);
        }

        for (Mention predicted : document.predictedMentionsByID.values()) {
            SemanticGraph dependencies = predicted.enhancedDependency;
            IndexedWord headNode = dependencies.getNodeByIndexSafe(predicted.headWord.index());
            boolean headInGraph = headNode != null && dependencies.vertexSet().contains(headNode);
            if (headInGraph
                    && !predicted.headWord.tag().startsWith("V")
                    && !goldHeads.contains(predicted.headWord)) {
                dataset.add(new BasicDatum<String, String>(
                        predicted.getSingletonFeatures(dictionaries), UNMATCHED_PREDICTED_LABEL));
            }
        }
    }

    /**
     * Trains a logistic classifier on the given dataset using a default-constructed
     * {@link LogisticClassifierFactory}.
     *
     * @param generalDataset labelled training data
     * @return the trained classifier
     */
    public static LogisticClassifier<String, String> train(GeneralDataset<String, String> generalDataset) {
        LogisticClassifierFactory<String, String> factory = new LogisticClassifierFactory<String, String>();
        return factory.trainClassifier(generalDataset);
    }

    /**
     * Serializes the classifier to the given location.
     *
     * @param classifier the model to write
     * @param path       destination passed to {@link IOUtils#writeStreamFromString(String)}
     * @throws RuntimeIOException if opening, writing or closing the stream fails
     */
    private static void saveToSerialized(LogisticClassifier<String, String> classifier, String path) {
        log.info("Writing singleton predictor in serialized format to file " + path + ' ');
        // try-with-resources: the stream is closed even when writeObject fails.
        try (ObjectOutputStream out = IOUtils.writeStreamFromString(path)) {
            out.writeObject(classifier);
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
        log.info("done.");
    }

    /**
     * Looks up {@code coref.path.singletonPredictor}, falling back to
     * {@link DefaultPaths#DEFAULT_DCOREF_SINGLETON_MODEL} when the property is absent.
     */
    private static String getPathSingletonPredictor(Properties properties) {
        return PropertiesUtils.getString(properties, "coref.path.singletonPredictor", DefaultPaths.DEFAULT_DCOREF_SINGLETON_MODEL);
    }

    /**
     * Entry point: builds the dataset, trains the classifier and serializes it.
     *
     * <p>Both the {@link Constants#CONLL2011_PROP} property and
     * {@code singleton.predictor.output} must be supplied; otherwise a message is logged and
     * nothing is trained. The model is written to the value of
     * {@code singleton.predictor.output}. Previously that required option was ignored and the
     * model went to the {@code coref.path.singletonPredictor} location instead; that location
     * is now used only if the output property has no string value.
     *
     * @param strArr command-line arguments, parsed with {@link StringUtils#argsToProperties}
     * @throws Exception propagated from dataset generation
     */
    public static void main(String[] strArr) throws Exception {
        Properties props = strArr.length > 0 ? StringUtils.argsToProperties(strArr) : new Properties();

        if (!props.containsKey(Constants.CONLL2011_PROP)) {
            log.info("-dcoref.conll2011 [input_CoNLL_corpus]: was not specified");
            return;
        }
        if (!props.containsKey(OUTPUT_MODEL_PROP)) {
            log.info("-singleton.predictor.output [output_model_file]: was not specified");
            return;
        }

        String outputPath = PropertiesUtils.getString(props, OUTPUT_MODEL_PROP, getPathSingletonPredictor(props));
        GeneralDataset<String, String> trainingData = generateFeatureVectors(props);
        saveToSerialized(train(trainingData), outputPath);
    }
}
