package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.io.FileSystem;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.SentenceUtils;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.dvparser.DVModelReranker;
import edu.stanford.nlp.parser.lexparser.LatticeXMLReader;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.RerankerQuery;
import edu.stanford.nlp.parser.lexparser.RerankingParserQuery;
import edu.stanford.nlp.process.DocumentPreprocessor;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileFilter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.ejml.simple.SimpleMatrix;

/**
 * Command-line tool that parses sentences with a DVParser model and writes, for each sentence,
 * the word vectors and the internal-node vectors computed by the model.
 *
 * <p>Recognized flags: {@code -model <path>}, {@code -output <dir>} (required; must not exist yet),
 * {@code -input <file>}, and {@code -testTreebank ...} (accepted but its description is unused).
 * Every other argument is forwarded to {@link LexicalizedParser#loadModel} as an extra flag.
 *
 * <p>For the n-th sentence of the input (numbered from 1), one file is written into the output
 * directory, named {@code LatticeXMLReader.SENTENCE + n + ".txt"}.
 */
public class ParseAndPrintMatrices {
    private static final Redwood.RedwoodChannels log = Redwood.channels(ParseAndPrintMatrices.class);

    /**
     * Writes every element of {@code matrix}, in its linear-index order, on a single line.
     * Each element is preceded by two spaces.
     *
     * @param out    destination writer
     * @param matrix the matrix to dump
     * @throws IOException if writing fails
     */
    public static void outputMatrix(BufferedWriter out, SimpleMatrix matrix) throws IOException {
        int numElements = matrix.getNumElements();
        for (int index = 0; index < numElements; index++) {
            out.write("  " + matrix.get(index));
        }
        out.newLine();
    }

    /**
     * Writes the vector of every internal node of {@code tree}, one line per node.
     *
     * <p>Pre-terminals and leaves are skipped. The traversal is post-order with children visited
     * right to left, so a node's vector is written after the vectors of all of its descendants.
     *
     * @param out     destination writer
     * @param tree    subtree to dump
     * @param vectors node-to-vector map, keyed by tree identity
     * @throws IOException           if writing fails
     * @throws IllegalStateException if an internal node has no entry in {@code vectors}
     */
    public static void outputTreeMatrices(BufferedWriter out, Tree tree, IdentityHashMap<Tree, SimpleMatrix> vectors) throws IOException {
        if (tree.isPreTerminal() || tree.isLeaf()) {
            return;
        }
        Tree[] children = tree.children();
        // Right-to-left child order is deliberate; it fixes the order of lines in the output file.
        for (int childIndex = children.length - 1; childIndex >= 0; childIndex--) {
            outputTreeMatrices(out, children[childIndex], vectors);
        }
        SimpleMatrix vector = vectors.get(tree);
        if (vector == null) {
            throw new IllegalStateException("No vector for subtree: " + tree);
        }
        outputMatrix(out, vector);
    }

    /**
     * Returns a key of {@code vectors} whose label value is {@code "ROOT"}.
     *
     * <p>If several keys qualify, which one is returned is unspecified: {@link IdentityHashMap}
     * has no defined iteration order.
     *
     * @param vectors node-to-vector map, keyed by tree identity
     * @return a tree labelled {@code ROOT}
     * @throws RuntimeException if no such tree exists
     */
    public static Tree findRootTree(IdentityHashMap<Tree, SimpleMatrix> vectors) {
        for (Tree tree : vectors.keySet()) {
            if (tree.label() != null && "ROOT".equals(tree.label().value())) {
                return tree;
            }
        }
        throw new RuntimeException("Could not find root");
    }

    /**
     * Entry point. See the class documentation for the accepted flags.
     *
     * @param args command-line arguments
     * @throws IOException if reading the input or writing an output file fails
     */
    public static void main(String[] args) throws IOException {
        String modelPath = null;
        String outputPath = null;
        String inputPath = null;
        List<String> extraArgs = Generics.newArrayList();

        int argIndex = 0;
        while (argIndex < args.length) {
            String flag = args[argIndex];
            if (flag.equalsIgnoreCase("-model")) {
                modelPath = requireFlagValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-output")) {
                outputPath = requireFlagValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-input")) {
                inputPath = requireFlagValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-testTreebank")) {
                // The returned description is not used by this tool. The call is kept (before
                // numSubArgs, as before) in case ArgUtils validates the arguments.
                // NOTE(review): confirm whether it does.
                ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                argIndex += ArgUtils.numSubArgs(args, argIndex) + 1;
            } else {
                extraArgs.add(flag);
                argIndex++;
            }
        }
        // Fail fast, before the (expensive) model load.
        if (outputPath == null) {
            throw new IllegalArgumentException("-output <directory> is required");
        }

        LexicalizedParser parser = LexicalizedParser.loadModel(modelPath, extraArgs.toArray(new String[0]));
        DVModel model = DVParser.getModelFromLexicalizedParser(parser);

        File outputDir = new File(outputPath);
        FileSystem.checkNotExistsOrFail(outputDir);
        FileSystem.mkdirOrFail(outputDir);

        if (inputPath == null) {
            return;
        }
        try (BufferedReader input = new BufferedReader(new FileReader(inputPath))) {
            int sentenceCount = 0;
            for (List<HasWord> sentence : new DocumentPreprocessor(input)) {
                sentenceCount++; // output files are numbered from 1
                DeepTree deepTree = parseToDeepTree(parser, sentence);
                IdentityHashMap<Tree, SimpleMatrix> vectors = deepTree.getVectors();
                for (Map.Entry<Tree, SimpleMatrix> entry : vectors.entrySet()) {
                    log.info(entry.getKey() + "   " + entry.getValue());
                }
                String fileName = outputPath + File.separator + LatticeXMLReader.SENTENCE + sentenceCount + ".txt";
                writeSentenceFile(fileName, sentence, deepTree, vectors, model);
            }
        }
    }

    /** Returns {@code args[flagIndex + 1]}, failing clearly if the flag has no value. */
    private static String requireFlagValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }

    /**
     * Parses one sentence with a fresh reranking parser query and returns the first DVModel deep tree.
     *
     * @throws IllegalArgumentException if the parser is not a DVModel reranking parser
     * @throws RuntimeException         if the sentence cannot be parsed
     */
    private static DeepTree parseToDeepTree(LexicalizedParser parser, List<HasWord> sentence) {
        ParserQuery parserQuery = parser.parserQuery();
        if (!(parserQuery instanceof RerankingParserQuery)) {
            throw new IllegalArgumentException("Expected a RerankingParserQuery");
        }
        RerankingParserQuery rerankingQuery = (RerankingParserQuery) parserQuery;
        if (!rerankingQuery.parse(sentence)) {
            throw new RuntimeException("Unparsable sentence: " + sentence);
        }
        RerankerQuery rerankerQuery = rerankingQuery.rerankerQuery();
        if (!(rerankerQuery instanceof DVModelReranker.Query)) {
            throw new IllegalArgumentException("Expected a DVModelReranker");
        }
        return ((DVModelReranker.Query) rerankerQuery).getDeepTrees().get(0);
    }

    /**
     * Writes one sentence's dump file. The file contents, in order, are:
     * <ol>
     *   <li>the sentence text;</li>
     *   <li>the parse tree;</li>
     *   <li>one word-vector line per token;</li>
     *   <li>the internal-node vectors, written by {@link #outputTreeMatrices}.</li>
     * </ol>
     * The file is closed even if writing fails part-way.
     */
    private static void writeSentenceFile(String fileName, List<HasWord> sentence, DeepTree deepTree,
                                          IdentityHashMap<Tree, SimpleMatrix> vectors, DVModel model) throws IOException {
        try (BufferedWriter out = new BufferedWriter(new FileWriter(fileName))) {
            out.write(SentenceUtils.listToString(sentence));
            out.newLine();
            out.write(deepTree.getTree().toString());
            out.newLine();
            for (HasWord word : sentence) {
                outputMatrix(out, model.getWordVector(word.word()));
            }
            outputTreeMatrices(out, findRootTree(vectors), vectors);
        }
    }
}
