package edu.stanford.nlp.loglinear.benchmarks;

import edu.stanford.nlp.loglinear.benchmarks.CoNLLBenchmark;
import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorNamespace;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.TimeUnit;

/**
 * Benchmarks repeated singleton-marginal inference on CoNLL sentence models under a simulated "game".
 *
 * <p>Each game takes a number of samples. A sample walks {@value #MOVES_PER_SAMPLE} moves down a
 * randomly grown tree of {@link SampleState}s. Each move observes one model variable and attaches an
 * extra factor. Singleton marginals are computed (and cached) the first time a state is visited.
 */
public class GamePlayerBenchmark {
    /** A logger for this class */
    private static final Redwood.RedwoodChannels log = Redwood.channels(GamePlayerBenchmark.class);

    static final String DATA_PATH = "/u/nlp/data/ner/conll/";

    /** Number of games played in each of the warm-up and timed phases. */
    private static final int NUM_GAMES = 10;
    /** Number of samples (root-to-leaf walks) taken per game. */
    private static final int SAMPLES_PER_GAME = 1000;
    /** Number of moves pushed onto the model in each sample. */
    private static final int MOVES_PER_SAMPLE = 10;
    /** Number of random feature vectors available to the "human observation" factors. */
    private static final int NUM_HUMAN_FEATURE_VECTORS = 1000;
    /** Number of components in every generated {@link ConcatVector}. */
    private static final int NUM_COMPONENTS = 5;
    /** Length of each dense component (and range of sparse indices) in generated vectors. */
    private static final int COMPONENT_SIZE = 30;

    /**
     * One node in the tree of game states. Pushing it adds {@link #addedFactor} to the model and marks
     * {@link #variable} as observed with value {@link #observation}. Popping undoes both.
     */
    public static class SampleState {
        public GraphicalModel.Factor addedFactor;
        public int variable;
        public int observation;
        /** Successor states reachable from this one, grown lazily as the game explores. */
        public List<SampleState> children = new ArrayList<>();
        /** Singleton marginals computed while this state was pushed, or {@code null} if not yet computed. */
        public double[][] cachedMarginal = null;

        public SampleState(GraphicalModel.Factor addedFactor, int variable, int observation) {
            this.addedFactor = addedFactor;
            this.variable = variable;
            this.observation = observation;
        }

        /** Applies this move: adds the factor to the model and records the observed value. */
        public void push(GraphicalModel model) {
            assert !model.factors.contains(addedFactor);
            model.factors.add(addedFactor);
            model.getVariableMetaDataByReference(variable)
                    .put(CliqueTree.VARIABLE_OBSERVED_VALUE, Integer.toString(observation));
        }

        /** Reverts {@link #push}: removes the factor and clears the observed value. */
        public void pop(GraphicalModel model) {
            assert model.factors.contains(addedFactor);
            model.factors.remove(addedFactor);
            model.getVariableMetaDataByReference(variable).remove(CliqueTree.VARIABLE_OBSERVED_VALUE);
        }
    }

    public static void main(String[] args) throws IOException, ClassNotFoundException {
        CoNLLBenchmark coNLL = new CoNLLBenchmark();

        List<CoNLLBenchmark.CoNLLSentence> train = coNLL.getSentences(DATA_PATH + "conll.iob.4class.train");
        List<CoNLLBenchmark.CoNLLSentence> testA = coNLL.getSentences(DATA_PATH + "conll.iob.4class.testa");
        List<CoNLLBenchmark.CoNLLSentence> testB = coNLL.getSentences(DATA_PATH + "conll.iob.4class.testb");

        List<CoNLLBenchmark.CoNLLSentence> allData = new ArrayList<>();
        allData.addAll(train);
        allData.addAll(testA);
        allData.addAll(testB);

        // Collect the tag vocabulary across every split.
        Set<String> tagsSet = new HashSet<>();
        for (CoNLLBenchmark.CoNLLSentence sentence : allData) {
            for (String tag : sentence.ner) {
                tagsSet.add(tag);
            }
        }
        List<String> tags = new ArrayList<>(tagsSet);

        coNLL.embeddings = coNLL.getEmbeddings(DATA_PATH + "google-300-trimmed.ser.gz", allData);

        log.info("Making the training set...");
        ConcatVectorNamespace namespace = new ConcatVectorNamespace();
        int trainSize = train.size();
        GraphicalModel[] trainingSet = new GraphicalModel[trainSize];
        for (int i = 0; i < trainSize; i++) {
            if (i % 10 == 0) {
                log.info(i + "/" + trainSize);
            }
            trainingSet[i] = coNLL.generateSentenceModel(namespace, train.get(i), tags);
        }

        // Fixed seed so runs are comparable. Generation order below must stay stable to keep the sequence.
        Random random = new Random(10L);
        ConcatVector[] humanFeatureVectors = new ConcatVector[NUM_HUMAN_FEATURE_VECTORS];
        for (int i = 0; i < humanFeatureVectors.length; i++) {
            humanFeatureVectors[i] = randomMixedVector(random);
        }
        ConcatVector weights = new ConcatVector(NUM_COMPONENTS);
        for (int component = 0; component < NUM_COMPONENTS; component++) {
            weights.setDenseComponent(component, randomDenseComponent(random));
        }

        // Guard against a training set smaller than the number of games we want to play.
        int gamesToPlay = Math.min(NUM_GAMES, trainingSet.length);

        log.info("Warming up the JIT...");
        for (int i = 0; i < gamesToPlay; i++) {
            log.info(i);
            gameplay(random, trainingSet[i], weights, humanFeatureVectors);
        }

        log.info("Timing actual run...");
        long start = System.nanoTime();
        for (int i = 0; i < gamesToPlay; i++) {
            log.info(i);
            gameplay(random, trainingSet[i], weights, humanFeatureVectors);
        }
        log.info("Duration: " + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start) + " ms");
    }

    /** Builds a vector whose components are each randomly sparse (one entry) or dense. */
    private static ConcatVector randomMixedVector(Random random) {
        ConcatVector vector = new ConcatVector(NUM_COMPONENTS);
        for (int component = 0; component < NUM_COMPONENTS; component++) {
            if (random.nextBoolean()) {
                // Draw index before value, matching the original argument evaluation order.
                int index = random.nextInt(COMPONENT_SIZE);
                double value = random.nextDouble();
                vector.setSparseComponent(component, index, value);
            } else {
                vector.setDenseComponent(component, randomDenseComponent(random));
            }
        }
        return vector;
    }

    /** Returns {@link #COMPONENT_SIZE} uniform random doubles in [0, 1). */
    private static double[] randomDenseComponent(Random random) {
        double[] values = new double[COMPONENT_SIZE];
        for (int i = 0; i < values.length; i++) {
            values[i] = random.nextDouble();
        }
        return values;
    }

    /**
     * Plays one game on {@code model} and logs how long marginal inference took.
     *
     * <p>The model is mutated during each sample and restored before the next one.
     */
    private static void gameplay(Random random, GraphicalModel model, ConcatVector weights, ConcatVector[] humanFeatureVectors) {
        // Every variable the model's factors touch, with its domain size, in first-seen order.
        Map<Integer, Integer> variableToSize = new LinkedHashMap<>();
        for (GraphicalModel.Factor factor : model.factors) {
            int[] dimensions = factor.featuresTable.getDimensions();
            for (int i = 0; i < factor.neigborIndices.length; i++) {
                variableToSize.putIfAbsent(factor.neigborIndices[i], dimensions[i]);
            }
        }
        if (variableToSize.isEmpty()) {
            // Nothing to observe; Random.nextInt(0) would throw below.
            log.info("\tModel has no variables; skipping game");
            return;
        }

        int[] variables = new int[variableToSize.size()];
        int[] variableSizes = new int[variableToSize.size()];
        int k = 0;
        for (Map.Entry<Integer, Integer> entry : variableToSize.entrySet()) {
            variables[k] = entry.getKey();
            variableSizes[k] = entry.getValue();
            k++;
        }

        List<SampleState> childrenOfRoot = new ArrayList<>();
        // The tree keeps a reference to the model, so it sees the factors pushed below.
        CliqueTree tree = new CliqueTree(model, weights);
        int initialFactorCount = model.factors.size();

        long start = System.nanoTime();
        long marginalsNanos = 0;
        long marginalsComputed = 0;
        for (int sample = 0; sample < SAMPLES_PER_GAME; sample++) {
            log.info("\tTaking sample " + sample);
            Deque<SampleState> stack = new ArrayDeque<>();
            SampleState state = selectOrCreateChildAtRandom(random, model, variables, variableSizes, childrenOfRoot, humanFeatureVectors);
            long sampleNanos = 0;
            for (int move = 0; move < MOVES_PER_SAMPLE; move++) {
                state.push(model);
                assert model.factors.size() == initialFactorCount + move + 1;

                // This is the operation being benchmarked; revisited states reuse their cached result.
                if (state.cachedMarginal == null) {
                    long callStart = System.nanoTime();
                    state.cachedMarginal = tree.calculateMarginalsJustSingletons();
                    sampleNanos += System.nanoTime() - callStart;
                    marginalsComputed++;
                }

                stack.push(state);
                // Note: after the last move this still selects (and possibly creates) one more child.
                state = selectOrCreateChildAtRandom(random, model, variables, variableSizes, state.children, humanFeatureVectors);
            }
            log.info("\t\t" + TimeUnit.NANOSECONDS.toMillis(sampleNanos) + " ms");
            marginalsNanos += sampleNanos;

            // Undo the moves in reverse order so the model returns to its initial state.
            while (!stack.isEmpty()) {
                stack.pop().pop(model);
            }
            assert model.factors.size() == initialFactorCount;
        }

        double avgMillisPerMarginal = marginalsComputed == 0 ? 0.0 : (marginalsNanos / 1e6) / marginalsComputed;
        log.info("Marginals computed: " + marginalsComputed);
        log.info("Marginals time: " + TimeUnit.NANOSECONDS.toMillis(marginalsNanos) + " ms");
        log.info("Avg time per marginal: " + avgMillisPerMarginal + " ms");
        log.info("Total time: " + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start) + " ms");
    }

    /**
     * Picks a random (variable, observation) move. Returns the existing child in {@code children} for
     * that move if present. Otherwise creates a new child, appends it to {@code children} and returns it.
     *
     * <p>A new child's factor links the chosen variable to a fresh variable index. That index is one past
     * the largest index referenced by any factor currently in the model.
     */
    private static SampleState selectOrCreateChildAtRandom(Random random,
                                                           GraphicalModel model,
                                                           int[] variables,
                                                           int[] variableSizes,
                                                           List<SampleState> children,
                                                           ConcatVector[] humanFeatureVectors) {
        int choice = random.nextInt(variables.length);
        int variable = variables[choice];
        int size = variableSizes[choice];
        int observation = random.nextInt(size);

        for (SampleState child : children) {
            if (child.variable == variable && child.observation == observation) {
                return child;
            }
        }

        int humanObservationVariable = 0;
        for (GraphicalModel.Factor factor : model.factors) {
            for (int neighbor : factor.neigborIndices) {
                humanObservationVariable = Math.max(humanObservationVariable, neighbor + 1);
            }
        }

        // NOTE(review): assumes size * size <= humanFeatureVectors.length (1000). Domains larger
        // than 31 values would index out of bounds here.
        GraphicalModel.Factor factor = model.addFactor(
                new int[]{variable, humanObservationVariable},
                new int[]{size, size},
                assignment -> humanFeatureVectors[assignment[0] * size + assignment[1]]);
        // Detach the factor immediately; it is re-added only while this state is pushed.
        model.factors.remove(factor);

        SampleState newState = new SampleState(factor, variable, observation);
        children.add(newState);
        return newState;
    }
}
