package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.util.CollectionValuedMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/ChineseMaxentLexicon.class */
public class ChineseMaxentLexicon implements Lexicon {
    /** Explicit serialization version identifier. */
    private static final long serialVersionUID = 238834703409896852L;
    // Compile-time flags; neither is referenced by name in the code visible in this file.
    private static final boolean verbose = true;
    public static final boolean seenTagsOnly = false;
    /** Turns a word into classifier features; created in initializeTraining when still null. */
    private ChineseWordFeatureExtractor featExtractor;
    // Compile-time flag; not referenced by name in the code visible in this file.
    public static final boolean fixUnkFunctionWords = false;
    /** Classifier trained in finishTraining and queried in ensureProbs. */
    private LinearClassifier scorer;
    /** Laplace-smoothed distribution over training tags, built in finishTraining. */
    private Distribution<String> tagDist;
    /** Word index supplied to the constructor. */
    private final Index<String> wordIndex;
    /** Tag index supplied to the constructor. */
    private final Index<String> tagIndex;
    /** Per-tag scores of the most recently scored word; filled by ensureProbs. */
    private transient Counter<String> logProbs;
    // Not referenced in the code visible in this file.
    private static final String featureDir = "gbfeatures";
    // Compile-time flag; initializeTraining logs the value "false" as a literal.
    static final boolean tuneSigma = false;
    // Presumably the count cutoff finishTraining applies when trainOnLowCount is set -- TODO confirm.
    static final int trainCountThreshold = 5;
    /** Feature level handed to ChineseWordFeatureExtractor; set by the constructor. */
    final int featureLevel;
    /** Default feature level; main falls back to the same value (2). */
    static final int DEFAULT_FEATURE_LEVEL = 2;
    /** Parser parameters taken from the Options passed to the constructor. */
    private final TreebankLangParserParams tlpParams;
    /** Language pack used for basicCategory lookups. */
    private final TreebankLanguagePack ctlp;
    /** Options object passed to the constructor. */
    private final Options op;
    /** Weighted counts of tagged words seen during training; set to null by finishTraining. */
    transient IntCounter<TaggedWord> datumCounter;
    /** Logging channel used by verbose and main. */
    private static Redwood.RedwoodChannels log = Redwood.channels(ChineseMaxentLexicon.class);
    // Feature-name patterns, each paired below with a count threshold.
    private static final Pattern wordPattern = Pattern.compile(".*-W");
    private static final Pattern charPattern = Pattern.compile(".*-.C");
    private static final Pattern bigramPattern = Pattern.compile(".*-.B");
    private static final Pattern conjPattern = Pattern.compile(".*&&.*");
    // (pattern, threshold) pairs; applyThresholds only uses a pair whose threshold is > 0.
    private final Pair<Pattern, Integer> wordThreshold = new Pair<>(wordPattern, 0);
    private final Pair<Pattern, Integer> charThreshold = new Pair<>(charPattern, 2);
    private final Pair<Pattern, Integer> bigramThreshold = new Pair<>(bigramPattern, 3);
    private final Pair<Pattern, Integer> conjThreshold = new Pair<>(conjPattern, 3);
    /** Thresholds selected by applyThresholds and passed to the dataset. */
    private final List<Pair<Pattern, Integer>> featureThresholds = new ArrayList();
    // Not referenced in the code visible in this file.
    private final int universalThreshold = 0;
    /**
     * Maps a word to a fixed basic category. Such words bypass the classifier in ensureProbs and
     * are left out of the training data in finishTraining. Nothing visible in this file adds entries.
     */
    private Map<String, String> functionWordTags = Generics.newHashMap();
    /** Margin below the best tag score within which a tag is still proposed or scored. */
    private double iteratorCutoffFactor = 4.0d;
    /**
     * Index of the word whose scores are cached in logProbs; -1 on construction.
     * NOTE(review): as a transient field it is 0, not -1, after Java deserialization.
     */
    private transient int lastWord = -1;
    // The next two fields are not referenced in the code visible in this file.
    String initialWeightFile = null;
    boolean trainFloat = false;
    /** Tolerance passed to LinearClassifierFactory.setTol in finishTraining. */
    private double tol = 1.0E-4d;
    /** Sigma passed to LinearClassifierFactory.setSigma in finishTraining. */
    private double sigma = 0.4d;
    /** When true, finishTraining skips tagged words whose count exceeds the cutoff. */
    private boolean trainOnLowCount = false;
    /** When true, finishTraining gives every tagged word a weight of 1. */
    private boolean trainByType = false;
    /** Tags recorded for each word during training; isKnown tests its keys. */
    public CollectionValuedMap<String, String> tagsForWord = new CollectionValuedMap<>();

    /**
     * Tells whether the word stored at the given position of the word index was seen in training.
     *
     * @param i position of the word in {@code wordIndex}
     * @return {@code true} if that word is a key of {@code tagsForWord}
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public boolean isKnown(int i) {
        String word = wordIndex.get(i);
        return isKnown(word);
    }

    /**
     * Tells whether the given word was seen in training.
     *
     * @param str the word to look up
     * @return {@code true} if the word is a key of {@code tagsForWord}
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public boolean isKnown(String str) {
        boolean seenInTraining = tagsForWord.containsKey(str);
        return seenInTraining;
    }

    /**
     * Collects the distinct results of applying a mapping to every tag in the tag index.
     *
     * @param function mapping applied to each tag
     * @return a new set holding the mapped tags
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public Set<String> tagSet(Function<String, String> function) {
        Set<String> mappedTags = new HashSet<>();
        for (String tag : tagIndex.objectsList()) {
            mappedTags.add(function.apply(tag));
        }
        return mappedTags;
    }

    /** Whether the scores cached in {@code logProbs} had the log tag prior subtracted. */
    private transient boolean lastWordPriorSubtracted;

    /**
     * Makes sure {@code logProbs} holds the prior-adjusted tag scores of the given word.
     *
     * @param i position of the word in {@code wordIndex}
     */
    private void ensureProbs(int i) {
        ensureProbs(i, true);
    }

    /**
     * Makes sure {@code logProbs} holds the tag scores of the given word.
     *
     * <p>Words listed in {@code functionWordTags} get 0 for every tag whose basic category equals
     * the mapped category and negative infinity for all other tags. Every other word is scored by
     * the classifier; when {@code z} is true the log of the tag's probability under
     * {@code tagDist} is subtracted from each score.
     *
     * <p>The previous result is reused only when it exists, belongs to the same word and was
     * computed with the same {@code z}. Checking for an existing result matters because both
     * cache fields are transient: after deserialization {@code lastWord} is 0 and
     * {@code logProbs} is null. The cache is replaced only after the new scores were computed
     * completely, so a failure cannot leave it describing the wrong word.
     *
     * @param i position of the word in {@code wordIndex}
     * @param z whether to subtract the log tag prior from the classifier scores
     */
    private void ensureProbs(int i, boolean z) {
        if (this.logProbs != null && i == this.lastWord && z == this.lastWordPriorSubtracted) {
            return;
        }
        String word = this.wordIndex.get(i);
        Counter<String> scores;
        if (this.functionWordTags.containsKey(word)) {
            String category = this.functionWordTags.get(word);
            scores = new ClassicCounter<>();
            for (String tag : this.tagIndex.objectsList()) {
                if (this.ctlp.basicCategory(tag).equals(category)) {
                    scores.setCount(tag, 0.0d);
                } else {
                    scores.setCount(tag, Double.NEGATIVE_INFINITY);
                }
            }
        } else {
            scores = this.scorer.logProbabilityOf(new BasicDatum(this.featExtractor.makeFeatures(word)));
            if (z) {
                for (String tag : scores.keySet()) {
                    scores.incrementCount(tag, -Math.log(this.tagDist.probabilityOf(tag)));
                }
            }
        }
        // Publish the new cache entry only now that it is complete.
        this.logProbs = scores;
        this.lastWord = i;
        this.lastWordPriorSubtracted = z;
    }

    /**
     * Lists the taggings of a word whose score lies within {@code iteratorCutoffFactor} of the
     * best score for that word.
     *
     * @param i position of the word in {@code wordIndex}
     * @param i2 not used by this implementation
     * @param str not used by this implementation
     * @return iterator over the retained word/tag pairs, in tag-index order
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public Iterator<IntTaggedWord> ruleIteratorByWord(int i, int i2, String str) {
        ensureProbs(i);
        final double best = Counters.max(this.logProbs);
        List<IntTaggedWord> kept = new ArrayList<>();
        for (int tag = 0; tag < this.tagIndex.size(); tag++) {
            double tagScore = this.logProbs.getCount(this.tagIndex.get(tag));
            if (tagScore > best - this.iteratorCutoffFactor) {
                kept.add(new IntTaggedWord(i, tag));
            }
        }
        return kept.iterator();
    }

    /**
     * String variant: looks the word up in {@code wordIndex} and delegates to the index variant.
     * NOTE(review): the lookup uses {@code indexOf}, which presumably yields a negative index for
     * a word missing from the index -- confirm that callers only pass indexed words.
     *
     * @param str the word
     * @param i passed through unchanged
     * @param str2 passed through unchanged
     * @return iterator over the retained word/tag pairs
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public Iterator<IntTaggedWord> ruleIteratorByWord(String str, int i, String str2) {
        int wordId = this.wordIndex.indexOf(str);
        return ruleIteratorByWord(wordId, i, str2);
    }

    /**
     * Counts the word/tag pairs that {@code ruleIteratorByWord} yields over the whole word index.
     *
     * @return total number of retained taggings
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public int numRules() {
        int total = 0;
        final int wordCount = this.wordIndex.size();
        for (int word = 0; word < wordCount; word++) {
            for (Iterator<IntTaggedWord> rules = ruleIteratorByWord(word, 0, (String) null); rules.hasNext(); rules.next()) {
                total++;
            }
        }
        return total;
    }

    /**
     * Returns the highest-scoring tag for a word, without the tag-prior adjustment.
     * The word is added to {@code wordIndex} if necessary.
     *
     * @param str the word to tag
     * @return the tag with the largest score in {@code logProbs}
     */
    private String getTag(String str) {
        int wordId = this.wordIndex.addToIndex(str);
        ensureProbs(wordId, false);
        return Counters.argmax(this.logProbs);
    }

    /**
     * Writes a progress message to the class logging channel at info level.
     *
     * @param str the message to log
     */
    private void verbose(String str) {
        log.info(str);
    }

    /**
     * Creates an untrained lexicon.
     *
     * @param options parser options; its {@code tlpParams} supplies the language pack
     * @param index word index shared with the caller
     * @param index2 tag index shared with the caller
     * @param i feature level later handed to the feature extractor
     */
    public ChineseMaxentLexicon(Options options, Index<String> index, Index<String> index2, int i) {
        this.featureLevel = i;
        this.wordIndex = index;
        this.tagIndex = index2;
        this.op = options;
        this.tlpParams = options.tlpParams;
        this.ctlp = this.tlpParams.treebankLanguagePack();
    }

    /**
     * Prepares for training: logs the settings, creates the feature extractor if there is none
     * yet, and starts a fresh tagged-word counter.
     *
     * @param d not used by this implementation
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void initializeTraining(double d) {
        verbose("Training ChineseMaxentLexicon.");
        StringBuilder settings = new StringBuilder();
        settings.append("trainOnLowCount = ").append(this.trainOnLowCount);
        settings.append(", trainByType = ").append(this.trainByType);
        settings.append(", featureLevel = ").append(this.featureLevel);
        settings.append(", tuneSigma = false");
        verbose(settings.toString());
        verbose("Making dataset...");
        if (this.featExtractor == null) {
            this.featExtractor = new ChineseWordFeatureExtractor(this.featureLevel);
        }
        this.datumCounter = new IntCounter<>();
    }

    /**
     * Trains on every tree of the collection with a weight of 1.
     *
     * @param collection the training trees
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public final void train(Collection<Tree> collection) {
        final double unitWeight = 1.0d;
        train(collection, unitWeight);
    }

    /**
     * Trains on every tree of the collection with the given weight.
     *
     * @param collection the training trees
     * @param d weight applied to each tree
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void train(Collection<Tree> collection, double d) {
        for (Tree tree : collection) {
            train(tree, d);
        }
    }

    /**
     * Trains on the tagged yield of a single tree.
     *
     * @param tree the training tree
     * @param d weight applied to each tagged word of the tree
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void train(Tree tree, double d) {
        List<TaggedWord> taggedWords = tree.taggedYield();
        train(taggedWords, d);
    }

    /**
     * Trains on a tagged sentence: passes it to the feature extractor, adds the weight of each
     * tagged word to {@code datumCounter}, and records the tag under the word in
     * {@code tagsForWord}.
     *
     * @param list the tagged words of one sentence
     * @param d weight added for each tagged word
     * @throws IllegalStateException if {@link #initializeTraining(double)} has not set up the
     *     feature extractor and the counter
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void train(List<TaggedWord> list, double d) {
        if (this.featExtractor == null || this.datumCounter == null) {
            throw new IllegalStateException("initializeTraining(double) must be called before train");
        }
        this.featExtractor.train(list, d);
        for (TaggedWord taggedWord : list) {
            // The tagged word itself is the counter key; it must not be cast to the counter type.
            this.datumCounter.incrementCount(taggedWord, d);
            this.tagsForWord.add(taggedWord.word(), taggedWord.tag());
        }
    }

    /**
     * Not supported by this lexicon.
     *
     * @param list ignored
     * @param d ignored
     * @throws UnsupportedOperationException always
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void trainUnannotated(List<TaggedWord> list, double d) {
        throw new UnsupportedOperationException("This version of the parser does not support non-tree training data");
    }

    /**
     * Not supported by this lexicon.
     *
     * @param d ignored
     * @throws UnsupportedOperationException always
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void incrementTreesRead(double d) {
        throw new UnsupportedOperationException();
    }

    /**
     * Training on a single tagged word is not supported by this lexicon.
     *
     * @param taggedWord ignored
     * @param i ignored
     * @param d ignored
     * @throws UnsupportedOperationException always
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void train(TaggedWord taggedWord, int i, double d) {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds the weighted dataset from the counted tagged words, estimates the tag distribution,
     * applies the feature thresholds and trains the classifier.
     *
     * <p>A tagged word is left out when {@code trainOnLowCount} is set and its count exceeds
     * {@code trainCountThreshold}, or when its word is a key of {@code functionWordTags}. With
     * {@code trainByType} every remaining tagged word gets weight 1 instead of its count.
     * {@code datumCounter} is released once the dataset has been built.
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void finishTraining() {
        IntCounter<String> tagCounts = new IntCounter<>();
        WeightedDataset dataset = new WeightedDataset(this.datumCounter.size());
        for (TaggedWord datum : this.datumCounter.keySet()) {
            int weight = this.datumCounter.getIntCount(datum);
            if (this.trainOnLowCount && weight > trainCountThreshold) {
                continue;
            }
            if (this.functionWordTags.containsKey(datum.word())) {
                continue;
            }
            tagCounts.incrementCount(datum.tag());
            if (this.trainByType) {
                weight = 1;
            }
            dataset.add(new BasicDatum(this.featExtractor.makeFeatures(datum.word()), datum.tag()), weight);
        }
        this.datumCounter = null;

        this.tagDist = Distribution.laplaceSmoothedDistribution(tagCounts, tagCounts.size(), 0.5d);
        applyThresholds(dataset);

        verbose("Making classifier...");
        QNMinimizer minimizer = new QNMinimizer();
        LinearClassifierFactory factory = new LinearClassifierFactory(minimizer);
        factory.setTol(this.tol);
        factory.setSigma(this.sigma);
        this.scorer = factory.trainClassifier((GeneralDataset) dataset);
        verbose("Done training.");
    }

    /**
     * Selects the feature thresholds that are enabled, applies them to the dataset and logs how
     * many feature types were removed.
     *
     * <p>A threshold is selected when its count is positive and, for character, bigram and
     * conjunction features, when the feature extractor has that feature family switched on.
     *
     * @param weightedDataset the dataset to prune in place
     */
    private void applyThresholds(WeightedDataset weightedDataset) {
        final boolean usesConjunctions = this.featExtractor.conjunctions || this.featExtractor.mildConjunctions;

        if (this.wordThreshold.second > 0) {
            this.featureThresholds.add(this.wordThreshold);
        }
        if (this.featExtractor.chars && this.charThreshold.second > 0) {
            this.featureThresholds.add(this.charThreshold);
        }
        if (this.featExtractor.bigrams && this.bigramThreshold.second > 0) {
            this.featureThresholds.add(this.bigramThreshold);
        }
        if (usesConjunctions && this.conjThreshold.second > 0) {
            this.featureThresholds.add(this.conjThreshold);
        }

        final int typesBefore = weightedDataset.numFeatureTypes();
        if (!this.featureThresholds.isEmpty()) {
            weightedDataset.applyFeatureCountThreshold(this.featureThresholds);
        }
        final int removed = typesBefore - weightedDataset.numFeatureTypes();
        if (removed > 0) {
            verbose("Thresholding removed " + removed + " features.");
        }
    }

    /**
     * Trains a lexicon on one part of a treebank and reports tagging accuracy on another part.
     *
     * <p>Arguments: {@code strArr[0]} is the treebank path, {@code strArr[1]} the file-number
     * ranges used for training, {@code strArr[2]} the file-number ranges used for testing, and
     * the optional {@code strArr[3]} the feature level (default {@code DEFAULT_FEATURE_LEVEL}).
     *
     * @param strArr command-line arguments as described above
     * @throws IllegalArgumentException if fewer than three arguments are given
     */
    public static void main(String[] strArr) {
        if (strArr.length < 3) {
            throw new IllegalArgumentException(
                    "usage: ChineseMaxentLexicon treebankPath trainFileRanges testFileRanges [featureLevel]");
        }
        ChineseTreebankParserParams params = new ChineseTreebankParserParams();
        // Result unused; the call is kept from the original in case it has side effects not visible here.
        params.treebankLanguagePack();
        Options options = new Options(params);
        TreeAnnotator annotator = new TreeAnnotator(params.headFinder(), params, options);

        log.info("Reading Trees...");
        NumberRangesFileFilter trainFilter = new NumberRangesFileFilter(strArr[1], true);
        MemoryTreebank trainTreebank = params.memoryTreebank();
        trainTreebank.loadPath(strArr[0], trainFilter);

        log.info("Annotating trees...");
        List<Tree> annotatedTrees = new ArrayList<>();
        for (Tree tree : trainTreebank) {
            annotatedTrees.add(annotator.transformTree(tree));
        }

        log.info("Training lexicon...");
        Index<String> wordIndex = new HashIndex<>();
        Index<String> tagIndex = new HashIndex<>();
        int featureLevel = DEFAULT_FEATURE_LEVEL;
        if (strArr.length > 3) {
            featureLevel = Integer.parseInt(strArr[3]);
        }
        ChineseMaxentLexicon lexicon = new ChineseMaxentLexicon(options, wordIndex, tagIndex, featureLevel);
        lexicon.initializeTraining(annotatedTrees.size());
        lexicon.train(annotatedTrees);
        lexicon.finishTraining();

        log.info("Testing");
        NumberRangesFileFilter testFilter = new NumberRangesFileFilter(strArr[2], true);
        MemoryTreebank testTreebank = params.memoryTreebank();
        testTreebank.loadPath(strArr[0], testFilter);
        List<TaggedWord> testWords = new ArrayList<>();
        for (Tree tree : testTreebank) {
            testWords.addAll(tree.taggedYield());
        }

        int[] totalAndCorrect = lexicon.testOnTreebank(testWords);
        log.info("done.");
        // Divide as double: integer division would print 0 or 1 and fail on an empty test set.
        double accuracy = ((double) totalAndCorrect[1]) / totalAndCorrect[0];
        System.out.println(totalAndCorrect[1] + " correct out of " + totalAndCorrect[0] + " -- ACC: " + accuracy);
    }

    /**
     * Tags every word with {@code getTag} and compares the basic category of the predicted tag
     * with the gold tag.
     *
     * @param collection gold tagged words
     * @return a two-element array: number of words tested, number tagged correctly
     */
    private int[] testOnTreebank(Collection<TaggedWord> collection) {
        int total = 0;
        int correct = 0;
        for (TaggedWord gold : collection) {
            String goldTag = gold.tag();
            String predicted = this.ctlp.basicCategory(getTag(gold.word()));
            total++;
            if (goldTag.equals(predicted)) {
                correct++;
            }
        }
        return new int[] {total, correct};
    }

    /**
     * Scores a word/tag pair.
     *
     * @param intTaggedWord the word/tag pair to score
     * @param i not used by this implementation
     * @param str not used by this implementation
     * @param str2 not used by this implementation
     * @return the tag's score when it lies within {@code iteratorCutoffFactor} of the best score
     *     for the word, otherwise {@code Float.NEGATIVE_INFINITY}
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public float score(IntTaggedWord intTaggedWord, int i, String str, String str2) {
        ensureProbs(intTaggedWord.word());
        final double best = Counters.max(this.logProbs);
        final double tagScore = this.logProbs.getCount(intTaggedWord.tagString(this.tagIndex));
        boolean withinCutoff = tagScore > best - this.iteratorCutoffFactor;
        return withinCutoff ? (float) tagScore : Float.NEGATIVE_INFINITY;
    }

    /**
     * Not supported by this lexicon.
     *
     * @param writer ignored
     * @throws UnsupportedOperationException always
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void writeData(Writer writer) throws IOException {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this lexicon.
     *
     * @param bufferedReader ignored
     * @throws UnsupportedOperationException always
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void readData(BufferedReader bufferedReader) throws IOException {
        throw new UnsupportedOperationException();
    }

    /**
     * This lexicon keeps no unknown-word model.
     *
     * @return always {@code null}
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public UnknownWordModel getUnknownWordModel() {
        return null;
    }

    /**
     * Does nothing; the supplied model is not stored.
     *
     * @param unknownWordModel ignored
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void setUnknownWordModel(UnknownWordModel unknownWordModel) {
    }

    /**
     * Trains on the first collection only, by delegating to {@code train(Collection)}.
     *
     * @param collection the training trees
     * @param collection2 ignored by this implementation
     */
    @Override // edu.stanford.nlp.parser.lexparser.Lexicon
    public void train(Collection<Tree> collection, Collection<Tree> collection2) {
        train(collection);
    }
}
