package edu.stanford.nlp.loglinear.benchmarks;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.learning.BacktrackingAdaGradOptimizer;
import edu.stanford.nlp.loglinear.learning.LogLikelihoodDifferentiableFunction;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorNamespace;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.semgraph.semgrex.ssurgeon.AddDep;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/* loaded from: input_file:edu/stanford/nlp/loglinear/benchmarks/CoNLLBenchmark.class */
/**
 * Benchmark that trains a log-linear sequence model for NER on the CoNLL data
 * found under {@link #DATA_PATH} and reports token-level accuracy together with
 * per-tag precision, recall and F1 on the {@code testa} split.
 *
 * <p>The pipeline is: read the three data splits, collect the tag inventory,
 * load (or build and cache) the word embeddings, turn every training sentence
 * into a {@link GraphicalModel}, optimize the weights, then decode the
 * {@code testa} sentences with a {@link CliqueTree} and log the scores.
 */
public class CoNLLBenchmark {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CoNLLBenchmark.class);

    /** Directory holding the CoNLL splits and the trimmed embedding cache. */
    static final String DATA_PATH = "/u/nlp/data/ner/conll/";

    /** Number of components stored for every embedding vector. */
    private static final int EMBEDDING_DIMENSION = 300;

    /**
     * Number of delimiter-separated fields a line of the raw embedding file must
     * have to be accepted: the word, {@link #EMBEDDING_DIMENSION} values, and one
     * trailing field that is ignored.
     */
    private static final int EMBEDDING_LINE_FIELDS = EMBEDDING_DIMENSION + 2;

    /** Word embeddings handed to {@code CoNLLFeaturizer} when sentence models are built. */
    Map<String, double[]> embeddings = new HashMap<>();

    /**
     * One sentence of the data set, stored as four parallel lists that are
     * indexed by token position.
     */
    public static class CoNLLSentence {
        public List<String> token;
        public List<String> ner;
        public List<String> pos;
        public List<String> npchunk;

        /**
         * Wraps the given lists; they are stored by reference, not copied.
         *
         * @param token   the tokens of the sentence
         * @param ner     the NER tag of each token
         * @param pos     the part-of-speech tag of each token
         * @param npchunk the chunk tag of each token
         */
        public CoNLLSentence(List<String> token, List<String> ner, List<String> pos, List<String> npchunk) {
            this.token = token;
            this.ner = ner;
            this.pos = pos;
            this.npchunk = npchunk;
        }
    }

    /**
     * Runs the benchmark with the hard-coded data locations.
     *
     * @param args ignored
     * @throws Exception if reading data, training or testing fails
     */
    public static void main(String[] args) throws Exception {
        new CoNLLBenchmark().benchmarkOptimizer();
    }

    /**
     * Trains on the {@code train} split, evaluates on {@code testa}, and logs
     * accuracy plus per-tag precision/recall/F1.
     *
     * @throws Exception if any data file cannot be read or the embedding cache
     *                   cannot be deserialized
     */
    public void benchmarkOptimizer() throws Exception {
        List<CoNLLSentence> train = getSentences(DATA_PATH + "conll.iob.4class.train");
        List<CoNLLSentence> testA = getSentences(DATA_PATH + "conll.iob.4class.testa");
        List<CoNLLSentence> testB = getSentences(DATA_PATH + "conll.iob.4class.testb");

        List<CoNLLSentence> allData = new ArrayList<>();
        allData.addAll(train);
        allData.addAll(testA);
        allData.addAll(testB);

        // The tag inventory and the embedding vocabulary are drawn from all three splits.
        List<String> tags = collectTags(allData);
        this.embeddings = getEmbeddings(DATA_PATH + "google-300-trimmed.ser.gz", allData);

        log.info("Making the training set...");
        ConcatVectorNamespace namespace = new ConcatVectorNamespace();
        int trainSize = train.size();
        GraphicalModel[] trainingSet = new GraphicalModel[trainSize];
        for (int i = 0; i < trainSize; i++) {
            if (i % 10 == 0) {
                log.info(i + "/" + trainSize);
            }
            trainingSet[i] = generateSentenceModel(namespace, train.get(i), tags);
        }

        log.info("Training system...");
        ConcatVector weights = new BacktrackingAdaGradOptimizer().optimize(
                trainingSet,
                new LogLikelihoodDifferentiableFunction(),
                namespace.newWeightsVector(),
                0.01d,
                1.0E-5d,
                false);

        log.info("Testing system...");
        Map<String, Integer> correctByTag = new HashMap<>();
        Map<String, Integer> goldByTag = new HashMap<>();
        Map<String, Integer> guessedByTag = new HashMap<>();
        int correct = 0;
        int total = 0;
        for (CoNLLSentence sentence : testA) {
            GraphicalModel model = generateSentenceModel(namespace, sentence, tags);
            int[] guesses = new CliqueTree(model, weights).calculateMAP();
            for (int i = 0; i < guesses.length; i++) {
                String guess = tags.get(guesses[i]);
                String gold = sentence.ner.get(i);
                if (guess.equals(gold)) {
                    correct++;
                    correctByTag.merge(guess, 1, Integer::sum);
                }
                total++;
                goldByTag.merge(gold, 1, Integer::sum);
                guessedByTag.merge(guess, 1, Integer::sum);
            }
        }

        log.info("\nSystem results:\n");
        log.info("Accuracy: " + ((double) correct / total) + "\n");
        for (String tag : tags) {
            logTagScores(
                    tag,
                    correctByTag.getOrDefault(tag, 0),
                    guessedByTag.getOrDefault(tag, 0),
                    goldByTag.getOrDefault(tag, 0));
        }
    }

    /** Returns the distinct NER tags occurring in the given sentences. */
    private static List<String> collectTags(List<CoNLLSentence> sentences) {
        HashSet<String> distinct = new HashSet<>();
        for (CoNLLSentence sentence : sentences) {
            distinct.addAll(sentence.ner);
        }
        return new ArrayList<>(distinct);
    }

    /** Logs precision, recall and F1 for one tag; a ratio with a zero denominator is reported as 0. */
    private static void logTagScores(String tag, int correct, int guessed, int gold) {
        double precision = guessed == 0 ? 0.0d : (double) correct / guessed;
        double recall = gold == 0 ? 0.0d : (double) correct / gold;
        double f1 = precision + recall == 0.0d
                ? 0.0d
                : ((precision * recall) * 2.0d) / (precision + recall);
        log.info(tag + " (" + gold + ")");
        log.info("\tP:" + precision + " (" + correct + "/" + guessed + ")");
        log.info("\tR:" + recall + " (" + correct + "/" + gold + ")");
        log.info("\tF1:" + f1);
    }

    /**
     * Classifies the casing of a word.
     *
     * @param word the word to inspect
     * @return {@code "no-case"} when upper- and lower-casing both leave the word
     *         unchanged, otherwise {@code "upper-case"}, {@code "lower-case"},
     *         {@code "capitalized"} (upper-case first character followed by an
     *         all-lower-case remainder, length &gt; 1) or {@code "mixed-case"}
     */
    private static String getWordShape(String word) {
        boolean isAllUpper = word.toUpperCase().equals(word);
        boolean isAllLower = word.toLowerCase().equals(word);
        if (isAllUpper && isAllLower) {
            return "no-case";
        }
        if (isAllUpper) {
            return "upper-case";
        }
        if (isAllLower) {
            return "lower-case";
        }
        if (word.length() > 1 && Character.isUpperCase(word.charAt(0))) {
            String rest = word.substring(1);
            if (rest.toLowerCase().equals(rest)) {
                return "capitalized";
            }
        }
        return "mixed-case";
    }

    /**
     * Builds the graphical model for one sentence: one variable per token,
     * annotated with its metadata, then featurized by {@code CoNLLFeaturizer}.
     *
     * @param namespace the feature namespace passed on to the featurizer
     * @param sentence  the sentence to model
     * @param tags      the tag inventory; the index of a token's tag in this list
     *                  is stored as the variable's training value
     * @return the annotated model
     */
    public GraphicalModel generateSentenceModel(ConcatVectorNamespace namespace, CoNLLSentence sentence, List<String> tags) {
        GraphicalModel model = new GraphicalModel();
        for (int i = 0; i < sentence.token.size(); i++) {
            Map<String, String> metadata = model.getVariableMetaDataByReference(i);
            String nerTag = sentence.ner.get(i);
            metadata.put(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE, String.valueOf(tags.indexOf(nerTag)));
            metadata.put("TOKEN", String.valueOf(sentence.token.get(i)));
            metadata.put(AddDep.POS_KEY, String.valueOf(sentence.pos.get(i)));
            metadata.put("CHUNK", String.valueOf(sentence.npchunk.get(i)));
            metadata.put("TAG", String.valueOf(nerTag));
        }
        CoNLLFeaturizer.annotate(model, tags, namespace, this.embeddings);

        // Sanity checks on the featurizer output; only evaluated when assertions are enabled.
        assert model.factors != null;
        for (GraphicalModel.Factor factor : model.factors) {
            assert factor != null;
        }
        return model;
    }

    /**
     * Reads sentences from a four-column data file.
     *
     * <p>Each line is split on {@code LinearClassifier.TEXT_SERIALIZATION_DELIMITER};
     * lines that do not yield exactly four fields are skipped. The fields are
     * read as token, POS tag, chunk tag and NER tag. An NER tag containing
     * {@code "-"} is reduced to the part after its first hyphen. A sentence is
     * closed whenever the token {@code "."} is seen.
     *
     * <p>NOTE(review): tokens following the last {@code "."} in the file are not
     * returned, and blank lines are not treated as sentence boundaries. This
     * matches the existing behaviour and is kept so benchmark numbers stay
     * comparable.
     *
     * @param path the file to read
     * @return the sentences in file order
     * @throws IOException if the file cannot be opened or read
     */
    public List<CoNLLSentence> getSentences(String path) throws IOException {
        List<CoNLLSentence> sentences = new ArrayList<>();
        List<String> tokens = new ArrayList<>();
        List<String> nerTags = new ArrayList<>();
        List<String> posTags = new ArrayList<>();
        List<String> chunkTags = new ArrayList<>();

        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER);
                if (fields.length != 4) {
                    continue;
                }
                tokens.add(fields[0]);
                posTags.add(fields[1]);
                chunkTags.add(fields[2]);
                String rawTag = fields[3];
                nerTags.add(rawTag.contains("-") ? rawTag.split("-")[1] : rawTag);

                if (fields[0].equals(".")) {
                    sentences.add(new CoNLLSentence(tokens, nerTags, posTags, chunkTags));
                    // The finished sentence keeps the old lists, so start fresh ones.
                    tokens = new ArrayList<>();
                    nerTags = new ArrayList<>();
                    posTags = new ArrayList<>();
                    chunkTags = new ArrayList<>();
                }
            }
        }
        return sentences;
    }

    /**
     * Returns the embeddings for the tokens of the given sentences.
     *
     * <p>If {@code cachePath} exists it is read as a gzip-compressed serialized
     * map. Otherwise the full embedding set is loaded from
     * {@code "../google-300.txt"}, restricted to tokens that occur in
     * {@code sentences}, and written to {@code cachePath} for later runs.
     *
     * <p>NOTE(review): the cache is read with Java native deserialization, so
     * {@code cachePath} must only ever point at a file this program produced.
     *
     * @param cachePath location of the trimmed, serialized embedding cache
     * @param sentences sentences whose tokens define the vocabulary to keep
     * @return a map from token to embedding vector
     * @throws IOException            if a file cannot be read or written
     * @throws ClassNotFoundException if the cached object cannot be deserialized
     */
    public Map<String, double[]> getEmbeddings(String cachePath, List<CoNLLSentence> sentences) throws IOException, ClassNotFoundException {
        File cacheFile = new File(cachePath);
        if (cacheFile.exists()) {
            try (FileInputStream fileIn = new FileInputStream(cachePath);
                 GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
                 ObjectInputStream objectIn = new ObjectInputStream(gzipIn)) {
                // Safe as long as the cache was produced by the branch below, which writes a Map<String, double[]>.
                @SuppressWarnings("unchecked")
                Map<String, double[]> cached = (Map<String, double[]>) objectIn.readObject();
                return cached;
            }
        }

        Map<String, double[]> allEmbeddings = loadEmbeddingsFromFile("../google-300.txt");
        log.info("Got massive embedding set size " + allEmbeddings.size());

        Map<String, double[]> trimmed = new HashMap<>();
        for (CoNLLSentence sentence : sentences) {
            for (String token : sentence.token) {
                if (allEmbeddings.containsKey(token)) {
                    trimmed.put(token, allEmbeddings.get(token));
                }
            }
        }
        log.info("Got trimmed embedding set size " + trimmed.size());

        // FileOutputStream creates the file when it is missing.
        try (FileOutputStream fileOut = new FileOutputStream(cachePath);
             GZIPOutputStream gzipOut = new GZIPOutputStream(fileOut);
             ObjectOutputStream objectOut = new ObjectOutputStream(gzipOut)) {
            objectOut.writeObject(trimmed);
        }
        log.info("Wrote trimmed set to file");
        return trimmed;
    }

    /**
     * Loads embeddings from a text file.
     *
     * <p>The first line is discarded. Every following line is split on
     * {@code AddDep.ATOM_DELIMITER}; only lines with exactly
     * {@link #EMBEDDING_LINE_FIELDS} fields are used. The first field is the
     * word, the next {@link #EMBEDDING_DIMENSION} fields are parsed as doubles,
     * and the final field is ignored. Progress is logged every 10000 lines.
     *
     * @param path the embedding text file
     * @return a map from word to its vector
     * @throws IOException           if the file cannot be opened or read
     * @throws NumberFormatException if a vector component is not a valid double
     */
    public Map<String, double[]> loadEmbeddingsFromFile(String path) throws IOException {
        Map<String, double[]> vectors = new HashMap<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            // Skip the first line (presumably a header; it is never parsed).
            reader.readLine();

            int linesRead = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(AddDep.ATOM_DELIMITER);
                if (fields.length == EMBEDDING_LINE_FIELDS) {
                    double[] vector = new double[EMBEDDING_DIMENSION];
                    for (int i = 1; i < fields.length - 1; i++) {
                        vector[i - 1] = Double.parseDouble(fields[i]);
                    }
                    vectors.put(fields[0], vector);
                }
                linesRead++;
                if (linesRead % 10000 == 0) {
                    log.info("Read " + linesRead + " lines");
                }
            }
        }
        return vectors;
    }
}
