package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.dvparser.DVModelReranker;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.RerankingParserQuery;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedWriter;
import java.io.FileFilter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/parser/dvparser/FindNearestNeighbors.class */
/**
 * Command-line tool that parses a treebank with a {@link LexicalizedParser} that has a
 * {@link DVModelReranker} attached, and then inspects the resulting DVModel node vectors.
 *
 * <p>The tool runs in two phases:
 * <ol>
 *   <li>For each tree in the test treebank, it writes the following to the output file:
 *     <ul>
 *       <li>the words,</li>
 *       <li>the parse,</li>
 *       <li>the ROOT node vector.</li>
 *     </ul>
 *   </li>
 *   <li>For every subtree with at most {@link #maxLength} leaves, it logs the other subtrees whose
 *       node vectors are closest to it in Frobenius norm.</li>
 * </ol>
 *
 * <p>Usage: {@code -model <path> -testTreebank <treebank> -output <file> [other flags]}.
 * Unrecognized arguments are passed through to {@link LexicalizedParser#loadModel}.
 */
public class FindNearestNeighbors {
    private static final Redwood.RedwoodChannels log = Redwood.channels(FindNearestNeighbors.class);

    /** Not referenced within this class; retained in case other code refers to it. */
    static final int numNeighbors = 5;

    /** Subtrees with more than this many leaves are excluded from the neighbor search. */
    static final int maxLength = 8;

    /** Number of closest matches kept (and logged) for each subtree. */
    private static final int MAX_MATCHES = 100;

    /**
     * Everything recorded about one parsed sentence.
     */
    public static class ParseRecord {
        /** The words of the sentence (the yield of the treebank tree). */
        final List<Word> sentence;
        /** The tree read from the test treebank. */
        final Tree goldTree;
        /** The tree taken from the first DeepTree returned by the DVModel reranker. */
        final Tree parse;
        /** The node vector attached to the ROOT node of {@link #parse}. */
        final SimpleMatrix rootVector;
        /** Node vectors for the nodes of {@link #parse}, keyed by node identity. */
        final IdentityHashMap<Tree, SimpleMatrix> nodeVectors;

        public ParseRecord(List<Word> sentence, Tree goldTree, Tree parse, SimpleMatrix rootVector,
                           IdentityHashMap<Tree, SimpleMatrix> nodeVectors) {
            this.sentence = sentence;
            this.goldTree = goldTree;
            this.parse = parse;
            this.rootVector = rootVector;
            this.nodeVectors = nodeVectors;
        }
    }

    /**
     * Entry point; see the class documentation for the flags.
     *
     * @param args command-line arguments
     * @throws IllegalArgumentException if {@code -model}, {@code -testTreebank} or {@code -output}
     *     is missing or has no value, or if the loaded parser has no DVModel reranker attached
     * @throws AssertionError if a sentence cannot be parsed or a parse has no ROOT node vector
     * @throws Exception if loading the model or treebank, or writing the output, fails
     */
    public static void main(String[] args) throws Exception {
        String modelPath = null;
        String outputPath = null;
        String testTreebankPath = null;
        FileFilter testTreebankFilter = null;
        List<String> passThroughArgs = new ArrayList<>();

        int argIndex = 0;
        while (argIndex < args.length) {
            if (args[argIndex].equalsIgnoreCase("-model")) {
                modelPath = requireFlagValue(args, argIndex);
                argIndex += 2;
            } else if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
                Pair<String, FileFilter> treebankDescription =
                        ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                testTreebankPath = treebankDescription.first();
                testTreebankFilter = treebankDescription.second();
            } else if (args[argIndex].equalsIgnoreCase("-output")) {
                outputPath = requireFlagValue(args, argIndex);
                argIndex += 2;
            } else {
                passThroughArgs.add(args[argIndex++]);
            }
        }
        if (modelPath == null) {
            throw new IllegalArgumentException("Need to specify -model");
        }
        if (testTreebankPath == null) {
            throw new IllegalArgumentException("Need to specify -testTreebank");
        }
        if (outputPath == null) {
            throw new IllegalArgumentException("Need to specify -output");
        }

        LexicalizedParser parser = LexicalizedParser.loadModel(
                modelPath, passThroughArgs.toArray(new String[passThroughArgs.size()]));

        log.info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null) {
            log.info("Filtering on " + testTreebankFilter);
        }
        MemoryTreebank testTreebank = parser.getOp().tlpParams.memoryTreebank();
        testTreebank.loadPath(testTreebankPath, testTreebankFilter);
        log.info("Read in " + testTreebank.size() + " trees for testing");

        // Close (and flush) the output as soon as parsing is done, even if parsing throws.
        List<ParseRecord> records;
        try (BufferedWriter out = new BufferedWriter(new FileWriter(outputPath))) {
            records = parseAndWriteRootVectors(parser, testTreebank, out);
        }

        List<Pair<Tree, SimpleMatrix>> subtrees = collectSubtrees(records);
        log.info("There are " + subtrees.size() + " subtrees in the set of trees");
        logNearestNeighbors(subtrees);
    }

    /** Returns the argument following the flag at {@code flagIndex}, failing clearly if absent. */
    private static String requireFlagValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }

    /**
     * Parses every tree's yield and writes the following to {@code out}:
     * <ul>
     *   <li>the sentence,</li>
     *   <li>the DVModel tree,</li>
     *   <li>the ROOT vector.</li>
     * </ul>
     * Returns one record per parsed tree.
     */
    private static List<ParseRecord> parseAndWriteRootVectors(LexicalizedParser parser,
                                                              MemoryTreebank treebank,
                                                              BufferedWriter out) throws IOException {
        log.info("Parsing " + treebank.size() + " trees");
        List<ParseRecord> records = new ArrayList<>();
        int parsedCount = 0;
        for (Tree goldTree : treebank) {
            List<Word> sentence = goldTree.yieldWords();
            DeepTree deepTree = parseToDeepTree(parser, sentence);
            SimpleMatrix rootVector = findRootVector(deepTree);

            out.write(sentence + "\n");
            out.write(deepTree.getTree() + "\n");
            for (int k = 0; k < rootVector.getNumElements(); k++) {
                out.write("  " + rootVector.get(k));
            }
            out.write("\n\n\n");

            parsedCount++;
            if (parsedCount % 10 == 0) {
                log.info("  " + parsedCount);
            }
            records.add(new ParseRecord(sentence, goldTree, deepTree.getTree(), rootVector, deepTree.getVectors()));
        }
        log.info("  done parsing");
        return records;
    }

    /**
     * Parses {@code sentence} and returns the first DeepTree produced by the DVModel reranker.
     *
     * @throws AssertionError if the parse fails
     * @throws IllegalArgumentException if the parser has no DVModel reranker attached
     */
    private static DeepTree parseToDeepTree(LexicalizedParser parser, List<Word> sentence) {
        ParserQuery parserQuery = parser.parserQuery();
        if (!parserQuery.parse(sentence)) {
            throw new AssertionError("Could not parse: " + sentence);
        }
        if (!(parserQuery instanceof RerankingParserQuery)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
        }
        RerankingParserQuery rerankingQuery = (RerankingParserQuery) parserQuery;
        if (!(rerankingQuery.rerankerQuery() instanceof DVModelReranker.Query)) {
            throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
        }
        return ((DVModelReranker.Query) rerankingQuery.rerankerQuery()).getDeepTrees().get(0);
    }

    /**
     * Returns the vector of the first node labelled {@code ROOT} in {@code deepTree}'s vector map.
     *
     * @throws AssertionError if no such node exists
     */
    private static SimpleMatrix findRootVector(DeepTree deepTree) {
        for (Map.Entry<Tree, SimpleMatrix> entry : deepTree.getVectors().entrySet()) {
            if ("ROOT".equals(entry.getKey().label().value())) {
                return entry.getValue();
            }
        }
        throw new AssertionError("Could not find root nodevector");
    }

    /** Gathers (subtree, vector) pairs for all subtrees with at most {@link #maxLength} leaves. */
    private static List<Pair<Tree, SimpleMatrix>> collectSubtrees(List<ParseRecord> records) {
        List<Pair<Tree, SimpleMatrix>> subtrees = new ArrayList<>();
        for (ParseRecord record : records) {
            for (Map.Entry<Tree, SimpleMatrix> entry : record.nodeVectors.entrySet()) {
                if (entry.getKey().getLeaves().size() <= maxLength) {
                    subtrees.add(Pair.makePair(entry.getKey(), entry.getValue()));
                }
            }
        }
        return subtrees;
    }

    /**
     * For each subtree, logs the {@link #MAX_MATCHES} other subtrees with the smallest Frobenius
     * distance between node vectors, closest first.
     */
    private static void logNearestNeighbors(List<Pair<Tree, SimpleMatrix>> subtrees) {
        // DESCENDING_COMPARATOR puts the highest score at the head, so poll() evicts the largest
        // distance and the queue retains the MAX_MATCHES closest pairs.
        PriorityQueue<ScoredObject<Pair<Tree, Tree>>> bestMatches =
                new PriorityQueue<ScoredObject<Pair<Tree, Tree>>>(MAX_MATCHES + 1, ScoredComparator.DESCENDING_COMPARATOR);
        for (int i = 0; i < subtrees.size(); i++) {
            Pair<Tree, SimpleMatrix> anchor = subtrees.get(i);
            log.info(anchor.first().yieldWords());
            log.info(anchor.first());
            for (int j = 0; j < subtrees.size(); j++) {
                if (i == j) {
                    continue;
                }
                Pair<Tree, SimpleMatrix> candidate = subtrees.get(j);
                double distance = anchor.second().minus(candidate.second()).normF();
                bestMatches.add(new ScoredObject<Pair<Tree, Tree>>(Pair.makePair(anchor.first(), candidate.first()), distance));
                if (bestMatches.size() > MAX_MATCHES) {
                    bestMatches.poll();
                }
            }

            // Draining yields largest distance first; reverse so the closest match is logged first.
            List<ScoredObject<Pair<Tree, Tree>>> ordered = new ArrayList<>(bestMatches.size());
            while (!bestMatches.isEmpty()) {
                ordered.add(bestMatches.poll());
            }
            Collections.reverse(ordered);
            for (ScoredObject<Pair<Tree, Tree>> match : ordered) {
                Tree neighbor = match.object().second();
                log.info(" MATCHED " + neighbor.yieldWords() + " ... " + neighbor + " with a score of " + match.score());
            }
            log.info();
            log.info();
        }
    }
}
