package edu.stanford.nlp.tagger.maxent;

import edu.stanford.nlp.io.EncodingPrintWriter;
import edu.stanford.nlp.io.PrintFile;
import edu.stanford.nlp.ling.HasOffset;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.SentenceUtils;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.naturalli.demo.CORSFilter;
import edu.stanford.nlp.sequences.ExactBestSequenceFinder;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.ConfusionMatrix;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.UnsupportedEncodingException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/tagger/maxent/TestSentence.class */
/**
 * Tags one sentence with a {@link MaxentTagger}.
 * <p>
 * The sentence is exposed as a {@link SequenceModel} (candidate tags per position plus
 * feature-weight scores for a tag assignment) and decoded with an
 * {@link ExactBestSequenceFinder}. When gold tags are supplied through
 * {@link #setCorrectTags(List)}, the same instance can also update accuracy counters and print
 * error / top-3 diagnostics.
 * <p>
 * Instances keep per-sentence mutable state (the current words, score caches, final tags and
 * running counters), so a single instance should tag one sentence at a time.
 */
public class TestSentence implements SequenceModel {

    private static final Redwood.RedwoodChannels log = Redwood.channels(TestSentence.class);

    protected final boolean VERBOSE;

    /** Placeholder tag for positions outside the sentence and for missing (null) tags. */
    protected static final String naTag = "NA";

    /** Candidate-tag array returned by {@link #stringTagsAt(int)} for positions outside the sentence. */
    private static final String[] naTagArr = { naTag };

    protected static final boolean DBG = false;

    /** Number of hypotheses scored by {@link #calculateProbs}; each one currently reuses {@link #finalTags}. */
    protected static final int kBestSize = 1;

    /** Labels printed when the gold tag is ranked 1st, 2nd or 3rd among the top candidates. */
    private static final String[] RANK_LABELS = { "Correct", "2nd", "3rd" };

    protected final String tagSeparator;
    protected final String encoding;

    /** Words being tagged (after the tagger's {@code wordFunction}), followed by the end-of-sentence word {@code ".$."}. */
    protected List<String> sent;
    /** Pre-assigned tags (one per word, then {@code ".$$."}) when tagging with {@code reuseTags}; otherwise null. */
    private List<String> originalTags;
    /** Copy of the words passed to {@link #tagSentence}. */
    protected List<HasWord> origWords;
    /** Number of entries in {@link #sent}: the word count plus one for the end-of-sentence word. */
    protected int size;
    private String[] correctTags;
    protected String[] finalTags;

    int numRight;
    int numWrong;
    int numUnknown;
    int numWrongUnknown;

    /** One past the last position currently in use in {@link #pairs}; reset by {@link #revert(int)}. */
    private int endSizePairs;

    private volatile History history;
    /** Per-position cache of local-context scores, with the local scores already added in. */
    private volatile double[][] localContextScores;
    protected final MaxentTagger maxentTagger;
    protected final PairsHolder pairs = new PairsHolder();
    /** Per-word cache of local feature scores; persists across sentences tagged by this instance. */
    private volatile Map<String, double[]> localScores = Generics.newHashMap();

    /**
     * Creates a single-sentence tagger front-end.
     *
     * @param maxentTagger trained tagger providing the dictionary, tag set, extractors and
     *     weights; its {@code config}, when non-null, supplies the tag separator, encoding and
     *     verbosity (otherwise the default separator, {@code "utf-8"} and non-verbose are used)
     */
    public TestSentence(MaxentTagger maxentTagger) {
        assert maxentTagger != null;
        assert maxentTagger.getLambdaSolve() != null;
        this.maxentTagger = maxentTagger;
        if (maxentTagger.config != null) {
            this.tagSeparator = maxentTagger.config.getTagSeparator();
            this.encoding = maxentTagger.config.getEncoding();
            this.VERBOSE = maxentTagger.config.getVerbose();
        } else {
            this.tagSeparator = TaggerConfig.getDefaultTagSeparator();
            this.encoding = "utf-8";
            this.VERBOSE = false;
        }
        this.history = new History(this.pairs, maxentTagger.extractors);
    }

    /**
     * Records the gold tags used by the evaluation / diagnostic methods.
     *
     * @param sentence words carrying their correct tags, in sentence order
     */
    public void setCorrectTags(List<? extends HasTag> sentence) {
        this.correctTags = new String[sentence.size()];
        int i = 0;
        for (HasTag word : sentence) {
            this.correctTags[i++] = word.tag();
        }
    }

    /**
     * Tags a sentence.
     * <p>
     * The words are copied into {@link #sent} (after applying the tagger's {@code wordFunction},
     * if one is set) and the end-of-sentence word {@code ".$."} is appended before decoding.
     *
     * @param sentence  the words to tag
     * @param reuseTags if true, any word that is a {@link HasTag} with a non-null tag keeps that
     *     tag as its only candidate; all other words are tagged normally
     * @return one {@link TaggedWord} per input word, carrying the caller's word text and, for
     *     words implementing {@link HasOffset}, their begin/end positions
     */
    public ArrayList<TaggedWord> tagSentence(List<? extends HasWord> sentence, boolean reuseTags) {
        this.origWords = new ArrayList<HasWord>(sentence);
        int numWords = sentence.size();
        this.sent = new ArrayList<String>(numWords + 1);
        for (HasWord w : sentence) {
            if (this.maxentTagger.wordFunction != null) {
                this.sent.add(this.maxentTagger.wordFunction.apply(w.word()));
            } else {
                this.sent.add(w.word());
            }
        }
        this.sent.add(".$.");
        if (reuseTags) {
            this.originalTags = new ArrayList<String>(numWords + 1);
            for (HasWord w : sentence) {
                this.originalTags.add(w instanceof HasTag ? ((HasTag) w).tag() : null);
            }
            this.originalTags.add(".$$.");
        } else {
            // Drop tags left over from an earlier reuseTags call on this instance; stringTagsAt
            // would otherwise consult them for this (different) sentence.
            this.originalTags = null;
        }
        this.size = numWords + 1;
        if (this.VERBOSE) {
            log.info("Sentence is " + SentenceUtils.listToString(this.sent, false, this.tagSeparator));
        }
        init();
        ArrayList<TaggedWord> tagged = testTagInference();
        if (this.maxentTagger.wordFunction != null) {
            // Tagging used the transformed words; hand back the caller's original word forms.
            for (int i = 0; i < numWords; i++) {
                tagged.get(i).setWord(this.origWords.get(i).word());
            }
        }
        return tagged;
    }

    /** Resets the end of the used region of {@link #pairs} to {@code prevSize}. */
    protected void revert(int prevSize) {
        this.endSizePairs = prevSize;
    }

    /**
     * Prepares per-sentence state: allocates an empty local-context score cache (one slot per
     * position) and counts unknown words, excluding the end-of-sentence word.
     */
    protected void init() {
        this.localContextScores = new double[this.size][];
        for (int i = 0; i < this.size - 1; i++) {
            if (this.maxentTagger.dict.isUnknown(this.sent.get(i))) {
                this.numUnknown++;
            }
        }
    }

    /**
     * Returns the tagged words as {@code word<separator>tag}, each followed by a space (including
     * the last one). Null words or tags are printed as {@value #naTag}.
     */
    public String getTaggedNice() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.size - 1; i++) {
            sb.append(toNice(this.sent.get(i))).append(this.tagSeparator).append(toNice(this.finalTags[i]));
            sb.append(' ');
        }
        return sb.toString();
    }

    /**
     * Builds the tagged sentence from {@link #sent} and {@link #finalTags}, excluding the
     * end-of-sentence word. Character offsets are copied from every original word that
     * implements {@link HasOffset}.
     */
    ArrayList<TaggedWord> getTaggedSentence() {
        ArrayList<TaggedWord> result = new ArrayList<>();
        for (int i = 0; i < this.size - 1; i++) {
            TaggedWord tw = new TaggedWord(this.sent.get(i), this.finalTags[i]);
            // Check each word rather than only the first, so mixed lists cannot cause a ClassCastException.
            HasWord orig = (this.origWords != null && i < this.origWords.size()) ? this.origWords.get(i) : null;
            if (orig instanceof HasOffset) {
                HasOffset offsets = (HasOffset) orig;
                tw.setBeginPosition(offsets.beginPosition());
                tw.setEndPosition(offsets.endPosition());
            }
            result.add(tw);
        }
        return result;
    }

    /** Returns {@code s}, or {@value #naTag} when {@code s} is null. */
    public static String toNice(String s) {
        return s == null ? naTag : s;
    }

    /**
     * Scores every candidate tag at every position, using {@link #finalTags} as the tag context.
     * <p>
     * Entries of {@code probabilities[position][hypothesis][tagIndex]} for candidate tags receive
     * log-normalized scores; all other entries are set to negative infinity. The third dimension
     * is indexed by global tag index in both scoring modes. NOTE(review): this presumes the
     * caller sizes it with {@code numTags()} and that this equals {@code ySize} — confirm.
     *
     * @param probabilities output array of shape [size][kBestSize][numTags]
     */
    protected void calculateProbs(double[][][] probabilities) {
        ArrayUtils.fill(probabilities, Double.NEGATIVE_INFINITY);
        boolean approximate = this.maxentTagger.hasApproximateScoring();
        for (int hyp = 0; hyp < kBestSize; hyp++) {
            // Load the whole sentence with its final tags into pairs, after the region in use.
            this.pairs.setSize(this.size);
            for (int i = 0; i < this.size; i++) {
                this.pairs.setWord(i, this.sent.get(i));
                this.pairs.setTag(i, this.finalTags[i]);
            }
            int start = this.endSizePairs;
            int end = start + this.size - 1;
            this.endSizePairs += this.size;
            for (int current = 0; current < this.size; current++) {
                History h = new History(start, end, current + start, this.pairs, this.maxentTagger.extractors);
                String[] tags = stringTagsAt(h.current - h.start + leftWindow());
                double[] probs = getHistories(tags, h);
                ArrayMath.logNormalize(probs);
                for (int j = 0; j < tags.length; j++) {
                    int tagIndex = this.maxentTagger.tags.getIndex(tags[j]);
                    // Approximate scores are parallel to `tags`; exact scores cover every tag
                    // and are already indexed by global tag index (as in getExactScores).
                    probabilities[current][hyp][tagIndex] = probs[approximate ? j : tagIndex];
                }
            }
        }
        revert(0);
    }

    /**
     * Compares the guessed tags against the gold tags, updates {@link #numRight},
     * {@link #numWrong} and {@link #numWrongUnknown}, and optionally reports the results.
     *
     * @param tags           guessed tag for each word
     * @param pf             if non-null, receives {@code word<sep>tag}, then {@code |goldTag} for
     *                       errors and {@code *} for errors on unknown words
     * @param verboseResults if true, errors are written to {@code EncodingPrintWriter.err} and the
     *                       tagged sentence is written to standard output
     */
    public void writeTagsAndErrors(String[] tags, PrintFile pf, boolean verboseResults) {
        StringWriter sw = new StringWriter(200);
        for (int i = 0; i < this.correctTags.length; i++) {
            String word = this.sent.get(i);
            sw.write(toNice(word));
            sw.write(this.tagSeparator);
            sw.write(tags[i]);
            sw.write(' ');
            if (pf != null) {
                pf.print(toNice(word));
                pf.print(this.tagSeparator);
                pf.print(tags[i]);
            }
            if (this.correctTags[i].equals(tags[i])) {
                this.numRight++;
            } else {
                this.numWrong++;
                boolean unknown = this.maxentTagger.dict.isUnknown(word);
                if (pf != null) {
                    pf.print("|" + this.correctTags[i]);
                }
                if (verboseResults) {
                    EncodingPrintWriter.err.println((unknown ? "Unk" : "") + "Word: " + word + "; correct: " + this.correctTags[i] + "; guessed: " + tags[i], this.encoding);
                }
                if (unknown) {
                    this.numWrongUnknown++;
                    if (pf != null) {
                        pf.print("*");
                    }
                }
            }
            if (pf != null) {
                pf.print(' ');
            }
        }
        if (pf != null) {
            pf.println();
        }
        if (verboseResults) {
            PrintWriter out;
            try {
                out = new PrintWriter(new OutputStreamWriter(System.out, this.encoding), true);
            } catch (UnsupportedEncodingException e) {
                // Unknown encoding name: fall back to the platform default rather than failing.
                out = new PrintWriter(new OutputStreamWriter(System.out), true);
            }
            // Deliberately not closed (that would close System.out); auto-flush emits the line.
            out.println(sw);
        }
    }

    /** Adds one (guessed, gold) observation per word to {@code confusionMatrix}. */
    public void updateConfusionMatrix(String[] tags, ConfusionMatrix<String> confusionMatrix) {
        for (int i = 0; i < this.correctTags.length; i++) {
            confusionMatrix.add(tags[i], this.correctTags[i]);
        }
    }

    /** Decodes the current sentence and returns it as tagged words. */
    private ArrayList<TaggedWord> testTagInference() {
        runTagInference();
        return getTaggedSentence();
    }

    /**
     * Runs exact best-sequence decoding and stores the result in {@link #finalTags}. The thread's
     * interrupt flag is checked before and after decoding.
     *
     * @throws RuntimeInterruptedException if the current thread was interrupted
     */
    private void runTagInference() {
        initializeScorer();
        if (Thread.interrupted()) {
            throw new RuntimeInterruptedException();
        }
        int[] bestTags = new ExactBestSequenceFinder().bestSequence(this);
        // The decoded sequence includes window padding; only the first `size` slots are filled.
        this.finalTags = new String[bestTags.length];
        for (int i = 0; i < this.size; i++) {
            this.finalTags[i] = this.maxentTagger.tags.getTag(bestTags[i + leftWindow()]);
        }
        if (Thread.interrupted()) {
            throw new RuntimeInterruptedException();
        }
        cleanUpScorer();
    }

    /**
     * Copies the tags within the model window around padded position {@code pos} from
     * {@code tags} (padded tag indices) into {@code h}, skipping positions outside the sentence.
     */
    private void setHistory(int pos, History h, int[] tags) {
        int left = leftWindow();
        int right = rightWindow();
        for (int j = pos - left; j <= pos + right; j++) {
            if (j < left) {
                continue;  // before the sentence start
            }
            if (j >= this.size + left) {
                return;  // past the sentence end
            }
            h.setTag(j - left, this.maxentTagger.tags.getTag(tags[j]));
        }
    }

    /** Loads the sentence words into {@link #pairs} and marks that region as in use. */
    protected void initializeScorer() {
        this.pairs.setSize(this.size);
        for (int i = 0; i < this.size; i++) {
            this.pairs.setWord(i, this.sent.get(i));
        }
        this.endSizePairs += this.size;
    }

    /** Releases the region of {@link #pairs} claimed by {@link #initializeScorer()}. */
    protected void cleanUpScorer() {
        revert(0);
    }

    /** Scores the candidate tags at {@code h.current}, one score per candidate. */
    private double[] getScores(History h) {
        return this.maxentTagger.hasApproximateScoring() ? getApproximateScores(h) : getExactScores(h);
    }

    /**
     * Exact scoring: log-normalizes over every tag in the tag set, then picks out the entries of
     * the candidate tags.
     */
    private double[] getExactScores(History h) {
        String[] tags = stringTagsAt(h.current - h.start + leftWindow());
        double[] allScores = getHistories(tags, h);
        ArrayMath.logNormalize(allScores);
        double[] scores = new double[tags.length];
        for (int i = 0; i < tags.length; i++) {
            scores[i] = allScores[this.maxentTagger.tags.getIndex(tags[i])];
        }
        return scores;
    }

    /**
     * Approximate scoring: only the candidate tags are scored. When normalizing, the remaining
     * {@code ySize - candidates} tags contribute the tagger's default inactive-tag score.
     */
    private double[] getApproximateScores(History h) {
        String[] tags = stringTagsAt(h.current - h.start + leftWindow());
        double[] scores = getHistories(tags, h);
        double inactive = this.maxentTagger.getInactiveTagDefaultScore(this.maxentTagger.ySize - tags.length);
        ArrayMath.addInPlace(scores, -SloppyMath.logAdd(ArrayMath.logSum(scores), inactive));
        return scores;
    }

    /**
     * Sums local, local-context and dynamic feature weights at {@code h.current}.
     * <p>
     * Local scores are cached per word in {@link #localScores}. Local-context scores, with the
     * local scores added in, are cached per position in {@link #localContextScores}. Rare-word
     * extractors are added when the current word is rare. NOTE(review): the per-word cache
     * assumes local features depend only on the current word — confirm for custom extractors.
     *
     * @param tags candidate tags at this position
     * @return one score per candidate (approximate mode) or per tag in the tag set (exact mode)
     */
    protected double[] getHistories(String[] tags, History h) {
        boolean rare = this.maxentTagger.isRare(ExtractorFrames.cWord.extract(h));
        Extractors ex = this.maxentTagger.extractors;
        Extractors exRare = this.maxentTagger.extractorsRare;
        String word = this.pairs.getWord(h.current);
        // Exact mode always scores the full tag set, so cached arrays have ySize entries there.
        int expectedLength = this.maxentTagger.hasApproximateScoring() ? tags.length : this.maxentTagger.ySize;
        double[] local = this.localScores.get(word);
        if (local == null) {
            local = getHistories(tags, h, ex.local, rare ? exRare.local : null);
            this.localScores.put(word, local);
        } else if (local.length != expectedLength) {
            // Cached under a different candidate set (e.g. the word was first scored with all
            // open tags); recompute, and only replace the cache entry for multi-tag candidate sets.
            local = getHistories(tags, h, ex.local, rare ? exRare.local : null);
            if (tags.length > 1) {
                this.localScores.put(word, local);
            }
        }
        double[] localContext = this.localContextScores[h.current];
        if (localContext == null) {
            localContext = getHistories(tags, h, ex.localContext, rare ? exRare.localContext : null);
            this.localContextScores[h.current] = localContext;
            ArrayMath.pairwiseAddInPlace(localContext, local);
        }
        double[] total = getHistories(tags, h, ex.dynamic, rare ? exRare.dynamic : null);
        ArrayMath.pairwiseAddInPlace(total, localContext);
        return total;
    }

    /** Dispatches to exact or approximate feature-weight accumulation. */
    private double[] getHistories(String[] tags, History h, List<Pair<Integer, Extractor>> extractors, List<Pair<Integer, Extractor>> extractorsRare) {
        return this.maxentTagger.hasApproximateScoring()
                ? getApproximateHistories(tags, h, extractors, extractorsRare)
                : getExactHistories(h, extractors, extractorsRare);
    }

    /** Accumulates feature weights for every tag in the tag set ({@code ySize} entries, indexed by tag index). */
    private double[] getExactHistories(History h, List<Pair<Integer, Extractor>> extractors, List<Pair<Integer, Extractor>> extractorsRare) {
        double[] scores = new double[this.maxentTagger.ySize];
        addFeatureWeights(scores, null, h, extractors, 0);
        if (extractorsRare != null) {
            // Rare-word feature associations are stored after the regular extractors' ones.
            addFeatureWeights(scores, null, h, extractorsRare, this.maxentTagger.extractors.size());
        }
        return scores;
    }

    /** Accumulates feature weights for the candidate {@code tags} only (one entry per candidate). */
    private double[] getApproximateHistories(String[] tags, History h, List<Pair<Integer, Extractor>> extractors, List<Pair<Integer, Extractor>> extractorsRare) {
        double[] scores = new double[tags.length];
        int[] tagIndices = new int[tags.length];
        for (int i = 0; i < tags.length; i++) {
            tagIndices[i] = this.maxentTagger.tags.getIndex(tags[i]);
        }
        addFeatureWeights(scores, tagIndices, h, extractors, 0);
        if (extractorsRare != null) {
            addFeatureWeights(scores, tagIndices, h, extractorsRare, this.maxentTagger.extractors.size());
        }
        return scores;
    }

    /**
     * Adds to {@code scores[k]} the weight of every active feature for the tag at slot {@code k}.
     *
     * @param scores     accumulator, updated in place
     * @param tagIndices tag index for each slot of {@code scores}, or null when slot {@code k} is tag index {@code k}
     * @param h          history the extractors read
     * @param extractors (extractor number, extractor) pairs to evaluate
     * @param offset     added to each extractor number to locate its feature associations
     */
    private void addFeatureWeights(double[] scores, int[] tagIndices, History h, List<Pair<Integer, Extractor>> extractors, int offset) {
        for (Pair<Integer, Extractor> pair : extractors) {
            int[] featureIds = this.maxentTagger.fAssociations.get(pair.first() + offset).get(pair.second().extract(h));
            if (featureIds == null) {
                continue;
            }
            double[] lambda = this.maxentTagger.getLambdaSolve().lambda;
            for (int k = 0; k < scores.length; k++) {
                int featureId = featureIds[tagIndices == null ? k : tagIndices[k]];
                if (featureId > -1) {
                    scores[k] += lambda[featureId];
                }
            }
        }
    }

    /**
     * Prints each unknown word as {@code word:numSent}, followed by its top-3 tags with
     * probabilities and the rank of the gold tag.
     */
    public void printUnknown(int numSent, PrintFile pfu) {
        DecimalFormat nf = new DecimalFormat("0.0000");
        double[][][] probabilities = new double[this.size][kBestSize][this.maxentTagger.numTags()];
        calculateProbs(probabilities);
        for (int current = 0; current < this.size; current++) {
            if (this.maxentTagger.dict.isUnknown(this.sent.get(current))) {
                pfu.print(this.sent.get(current));
                pfu.print(':');
                pfu.print(numSent);
                printTopCandidates(probabilities, current, pfu, nf);
            }
        }
    }

    /** Prints every gold-tagged word followed by its top-3 tags with probabilities and the rank of the gold tag. */
    public void printTop(PrintFile pf) {
        DecimalFormat nf = new DecimalFormat("0.0000");
        double[][][] probabilities = new double[this.size][kBestSize][this.maxentTagger.numTags()];
        calculateProbs(probabilities);
        for (int current = 0; current < this.correctTags.length; current++) {
            pf.print(this.sent.get(current));
            printTopCandidates(probabilities, current, pf, nf);
        }
    }

    /**
     * Prints, tab-separated, the top-3 tags at {@code pos} with their probabilities, then where
     * the gold tag ranks among them ({@code Correct}, {@code 2nd}, {@code 3rd} or {@code Not top 3}),
     * and ends the line.
     */
    private void printTopCandidates(double[][][] probabilities, int pos, PrintFile pf, DecimalFormat nf) {
        double[] topScores = new double[3];
        String[] topTags = new String[3];
        getTop3(probabilities, pos, topScores, topTags);
        for (int k = 0; k < 3; k++) {
            if (topScores[k] > Double.NEGATIVE_INFINITY) {
                pf.print('\t');
                pf.print(topTags[k]);
                pf.print(' ');
                pf.print(nf.format(Math.exp(topScores[k])));
            }
        }
        String gold = toNice(this.correctTags[pos]);
        int rank = 0;
        while (rank < 3 && !gold.equals(topTags[rank])) {
            rank++;
        }
        pf.print('\t');
        pf.print(rank < RANK_LABELS.length ? RANK_LABELS[rank] : "Not top 3");
        pf.println();
    }

    /**
     * Finds the three highest-scoring tags at {@code pos} (first hypothesis).
     *
     * @param topScores output, length 3, best first; unfilled slots stay negative infinity
     * @param topTags   output, length 3; null for slots with no finite score
     */
    private void getTop3(double[][][] probabilities, int pos, double[] topScores, String[] topTags) {
        int[] topIndices = new int[3];
        double[] scores = probabilities[pos][0];
        Arrays.fill(topScores, Double.NEGATIVE_INFINITY);
        for (int t = 0; t < scores.length; t++) {
            double s = scores[t];
            if (s > topScores[0]) {
                topScores[2] = topScores[1];
                topScores[1] = topScores[0];
                topScores[0] = s;
                topIndices[2] = topIndices[1];
                topIndices[1] = topIndices[0];
                topIndices[0] = t;
            } else if (s > topScores[1]) {
                topScores[2] = topScores[1];
                topScores[1] = s;
                topIndices[2] = topIndices[1];
                topIndices[1] = t;
            } else if (s > topScores[2]) {
                topScores[2] = s;
                topIndices[2] = t;
            }
        }
        for (int k = 0; k < 3; k++) {
            // An unfilled slot still holds tag index 0; leave it null so it can't match the gold tag.
            topTags[k] = topScores[k] > Double.NEGATIVE_INFINITY ? toNice(this.maxentTagger.tags.getTag(topIndices[k])) : null;
        }
    }

    /** Number of positions in the sentence, including the end-of-sentence word. */
    @Override
    public int length() {
        return this.sent.size();
    }

    @Override
    public int leftWindow() {
        return this.maxentTagger.leftContext;
    }

    @Override
    public int rightWindow() {
        return this.maxentTagger.rightContext;
    }

    /** Tag indices of the candidate tags at padded position {@code pos}. */
    @Override
    public int[] getPossibleValues(int pos) {
        String[] tags = stringTagsAt(pos);
        int[] indices = new int[tags.length];
        for (int k = 0; k < indices.length; k++) {
            indices[k] = this.maxentTagger.tags.getIndex(tags[k]);
        }
        return indices;
    }

    /**
     * Score of the tag {@code tags[pos]} at padded position {@code pos}, or negative infinity if
     * it is not a candidate there. If the tag appears more than once, the last match wins.
     */
    @Override
    public double scoreOf(int[] tags, int pos) {
        double[] scores = scoresOf(tags, pos);
        int[] candidates = getPossibleValues(pos);
        double score = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < scores.length; k++) {
            if (candidates[k] == tags[pos]) {
                score = scores[k];
            }
        }
        return score;
    }

    /** Not supported: this model only scores single positions. */
    @Override
    public double scoreOf(int[] sequence) {
        throw new UnsupportedOperationException();
    }

    /** Scores every candidate tag at padded position {@code pos}, given the surrounding tags. */
    @Override
    public double[] scoresOf(int[] tags, int pos) {
        int start = this.endSizePairs - this.size;
        this.history.init(start, this.endSizePairs - 1, start + pos - leftWindow());
        setHistory(pos, this.history, tags);
        return getScores(this.history);
    }

    /**
     * Candidate tags at padded position {@code pos}, in this order of preference:
     * <ol>
     *   <li>{@value #naTag} outside the sentence;</li>
     *   <li>the reused tag, when one was supplied;</li>
     *   <li>the open tags for unknown words;</li>
     *   <li>otherwise the dictionary tags.</li>
     * </ol>
     * Cases 3 and 4 are passed through {@code deterministicallyExpandTags}.
     */
    protected String[] stringTagsAt(int pos) {
        int left = leftWindow();
        if (pos < left || pos >= this.size + left) {
            return naTagArr;
        }
        int wordPos = pos - left;
        if (this.originalTags != null && this.originalTags.get(wordPos) != null) {
            return new String[] { this.originalTags.get(wordPos) };
        }
        String word = this.sent.get(wordPos);
        String[] tags;
        if (this.maxentTagger.dict.isUnknown(word)) {
            Set<String> open = this.maxentTagger.tags.getOpenTags();
            tags = open.toArray(new String[open.size()]);
        } else {
            tags = this.maxentTagger.dict.getTags(word);
        }
        return this.maxentTagger.tags.deterministicallyExpandTags(tags);
    }
}
