package edu.stanford.nlp.loglinear.learning;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.GenerationStatus;
import com.pholser.junit.quickcheck.generator.Generator;
import com.pholser.junit.quickcheck.random.SourceOfRandomness;
import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.inference.TableFactor;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorTable;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.junit.Assert;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

@RunWith(Theories.class)
/**
 * Property-based test for {@link LogLikelihoodDifferentiableFunction#getSummaryForInstance}.
 *
 * <p>Random graphical models and weight vectors are generated; the value returned by the function under
 * test is compared against a brute-force log-likelihood (multiply every factor into one joint table), and the
 * gradient it accumulates is compared against a central finite-difference approximation of that brute-force
 * log-likelihood.
 */
public class LogLikelihoodFunctionTest {
    /** Number of components in every generated {@link ConcatVector} (weights and features alike). */
    public static final int CONCAT_VEC_COMPONENTS = 2;
    /**
     * Exclusive upper bound on generated sparse indices and on generated dense component lengths; the
     * finite-difference gradient probes exactly this many indices per component.
     */
    public static final int CONCAT_VEC_COMPONENT_LENGTH = 3;

    /** Step size of the central finite-difference gradient approximation. */
    private static final double FINITE_DIFFERENCE_EPSILON = 1.0e-7;
    /** Absolute floor of the tolerance used when comparing log-likelihood values. */
    private static final double LOG_LIKELIHOOD_ABSOLUTE_TOLERANCE = 1.0e-3;
    /** Relative tolerance (fraction of |expected|) used when comparing finite log-likelihood values. */
    private static final double LOG_LIKELIHOOD_RELATIVE_TOLERANCE = 1.0e-2;
    /** Maximum allowed L2 distance between the definitional and the computed gradient. */
    private static final double GRADIENT_TOLERANCE = 0.05;

    /**
     * Builds a random {@link ConcatVector} with {@link #CONCAT_VEC_COMPONENTS} components. Each component is,
     * depending on a coin flip, either a single sparse entry at an index in
     * [0, {@link #CONCAT_VEC_COMPONENT_LENGTH}) or a dense array whose length is in
     * [0, {@link #CONCAT_VEC_COMPONENT_LENGTH}), all values drawn with {@code nextDouble()}.
     *
     * @param random source of randomness; consumed in a fixed order so generation is reproducible
     * @return the freshly generated vector
     */
    private static ConcatVector randomConcatVector(SourceOfRandomness random) {
        ConcatVector vector = new ConcatVector(CONCAT_VEC_COMPONENTS);
        for (int component = 0; component < CONCAT_VEC_COMPONENTS; component++) {
            if (random.nextBoolean()) {
                int index = random.nextInt(CONCAT_VEC_COMPONENT_LENGTH);
                double value = random.nextDouble();
                vector.setSparseComponent(component, index, value);
            } else {
                double[] values = new double[random.nextInt(CONCAT_VEC_COMPONENT_LENGTH)];
                for (int j = 0; j < values.length; j++) {
                    values[j] = random.nextDouble();
                }
                vector.setDenseComponent(component, values);
            }
        }
        return vector;
    }

    /**
     * Generates a small dataset (1 to 10 models) in which every variable that appears in some factor is
     * labelled with a random {@link LogLikelihoodDifferentiableFunction#VARIABLE_TRAINING_VALUE} within that
     * variable's dimension.
     */
    public static class GraphicalModelDatasetGenerator extends Generator<GraphicalModel[]> {
        GraphicalModelGenerator modelGenerator;

        public GraphicalModelDatasetGenerator(Class<GraphicalModel[]> cls) {
            super(cls);
            this.modelGenerator = new GraphicalModelGenerator(GraphicalModel.class);
        }

        @Override
        public GraphicalModel[] generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
            GraphicalModel[] dataset = new GraphicalModel[sourceOfRandomness.nextInt(1, 10)];
            for (int i = 0; i < dataset.length; i++) {
                GraphicalModel model = modelGenerator.generate(sourceOfRandomness, generationStatus);
                for (GraphicalModel.Factor factor : model.factors) {
                    int[] dimensions = factor.featuresTable.getDimensions();
                    for (int j = 0; j < factor.neigborIndices.length; j++) {
                        model.getVariableMetaDataByReference(factor.neigborIndices[j]).put(
                                LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE,
                                Integer.toString(sourceOfRandomness.nextInt(dimensions[j])));
                    }
                }
                dataset[i] = model;
            }
            return dataset;
        }
    }

    /**
     * Generates a random tree-shaped graphical model over {@link #NUM_VARIABLES} variables with random
     * feature tables, random junk metadata, and a random subset of variables marked as observed.
     */
    public static class GraphicalModelGenerator extends Generator<GraphicalModel> {
        /** Number of variables in every generated model. */
        private static final int NUM_VARIABLES = 8;
        /**
         * Number of variable indices that receive random junk metadata.
         * NOTE(review): larger than NUM_VARIABLES, so metadata is also attached to indices that appear in no
         * factor — presumably to check that such metadata is ignored; confirm this is intentional.
         */
        private static final int METADATA_VARIABLE_COUNT = 20;
        /** A factor neighbour is marked observed when {@code nextDouble()} exceeds this threshold. */
        private static final double OBSERVE_THRESHOLD = 0.8;
        /** Probability of adding one more fresh variable to a clique that is already non-empty. */
        private static final double CLIQUE_GROWTH_PROBABILITY = 0.7;

        public GraphicalModelGenerator(Class<GraphicalModel> cls) {
            super(cls);
        }

        /**
         * Adds between 0 and 8 random {@code "key:<int>" -> "value:<int>"} entries to {@code map}.
         *
         * @return the same {@code map}, for convenience
         */
        private Map<String, String> generateMetaData(SourceOfRandomness sourceOfRandomness, Map<String, String> map) {
            int entries = sourceOfRandomness.nextInt(9);
            for (int i = 0; i < entries; i++) {
                map.put("key:" + sourceOfRandomness.nextInt(), "value:" + sourceOfRandomness.nextInt());
            }
            return map;
        }

        @Override
        public GraphicalModel generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
            GraphicalModel model = new GraphicalModel();

            // Each variable takes between 1 and 3 values.
            int[] variableSizes = new int[NUM_VARIABLES];
            for (int i = 0; i < variableSizes.length; i++) {
                variableSizes[i] = sourceOfRandomness.nextInt(1, 3);
            }
            generateCliques(variableSizes, new ArrayList<>(), new HashSet<>(), model, sourceOfRandomness);

            generateMetaData(sourceOfRandomness, model.getModelMetaDataByReference());
            for (int i = 0; i < METADATA_VARIABLE_COUNT; i++) {
                generateMetaData(sourceOfRandomness, model.getVariableMetaDataByReference(i));
            }
            for (GraphicalModel.Factor factor : model.factors) {
                generateMetaData(sourceOfRandomness, factor.getMetaDataByReference());
            }

            // Randomly observe variables, drawing each observed value within that variable's dimension.
            for (GraphicalModel.Factor factor : model.factors) {
                int[] dimensions = factor.featuresTable.getDimensions();
                for (int j = 0; j < factor.neigborIndices.length; j++) {
                    if (sourceOfRandomness.nextDouble() > OBSERVE_THRESHOLD) {
                        model.getVariableMetaDataByReference(factor.neigborIndices[j]).put(
                                CliqueTree.VARIABLE_OBSERVED_VALUE,
                                Integer.toString(sourceOfRandomness.nextInt(dimensions[j])));
                    }
                }
            }
            return model;
        }

        /**
         * Recursively adds one factor (clique) containing {@code startSet} plus some fresh variables, then
         * recurses into child cliques, each of which shares a disjoint, non-empty subset of the fresh variables.
         *
         * @param variableSizes      dimension of every variable, indexed by variable id
         * @param startSet           variables shared with the parent clique (empty for the root)
         * @param alreadyRepresented variable ids already placed in some clique; updated in place
         * @param model              model that receives the generated factors
         * @param random             source of randomness
         */
        private void generateCliques(int[] variableSizes, List<Integer> startSet, Set<Integer> alreadyRepresented,
                                     GraphicalModel model, SourceOfRandomness random) {
            if (alreadyRepresented.size() == variableSizes.length) {
                return;
            }

            // Grow the clique from the shared startSet with fresh variables: one is always added if the clique
            // would otherwise be empty, further ones with probability CLIQUE_GROWTH_PROBABILITY.
            List<Integer> cliqueContents = new ArrayList<>(startSet);
            while (alreadyRepresented.size() != variableSizes.length
                    && (cliqueContents.isEmpty() || random.nextDouble(0.0, 1.0) < CLIQUE_GROWTH_PROBABILITY)) {
                int variable;
                do {
                    variable = random.nextInt(variableSizes.length);
                } while (alreadyRepresented.contains(variable));
                alreadyRepresented.add(variable);
                cliqueContents.add(variable);
            }

            // Build the factor's feature table with a random feature vector for every joint assignment.
            int[] neighbors = new int[cliqueContents.size()];
            int[] neighborSizes = new int[neighbors.length];
            for (int i = 0; i < neighbors.length; i++) {
                neighbors[i] = cliqueContents.get(i);
                neighborSizes[i] = variableSizes[neighbors[i]];
            }
            ConcatVectorTable table = new ConcatVectorTable(neighborSizes);
            for (int[] assignment : table) {
                ConcatVector features = randomConcatVector(random);
                table.setAssignmentValue(assignment, () -> features);
            }
            model.addFactor(table, neighbors);

            // Only variables introduced by this clique may be shared with its children.
            List<Integer> availableVariables = new ArrayList<>(cliqueContents);
            availableVariables.removeAll(startSet);

            int numChildren = random.nextInt(0, availableVariables.size());
            if (numChildren == 0) {
                return;
            }
            List<List<Integer>> children = new ArrayList<>(numChildren);
            for (int i = 0; i < numChildren; i++) {
                children.add(new ArrayList<>());
            }

            // Deal the available variables round-robin into the children. A child with no variable yet always
            // receives one (numChildren <= availableVariables.size(), so every child gets at least one); the
            // loop must stop as soon as variables run out or a coin flip fails for an already-seeded child.
            int cursor = 0;
            while (!availableVariables.isEmpty()
                    && (children.get(cursor).isEmpty() || random.nextBoolean())) {
                int pick = random.nextInt(availableVariables.size());
                children.get(cursor).add(availableVariables.remove(pick));
                cursor = (cursor + 1) % numChildren;
            }

            // Children must share pairwise-disjoint variable sets, otherwise the model would not be a tree.
            for (List<Integer> shared1 : children) {
                for (int variable : shared1) {
                    for (List<Integer> shared2 : children) {
                        assert shared1 == shared2 || !shared2.contains(variable);
                    }
                }
            }

            for (List<Integer> shared : children) {
                if (!shared.isEmpty()) {
                    generateCliques(variableSizes, shared, alreadyRepresented, model, random);
                }
            }
        }
    }

    /** Generates random weight vectors with the same layout as the generated feature vectors. */
    public static class WeightsGenerator extends Generator<ConcatVector> {
        public WeightsGenerator(Class<ConcatVector> cls) {
            super(cls);
        }

        @Override
        public ConcatVector generate(SourceOfRandomness sourceOfRandomness, GenerationStatus generationStatus) {
            return randomConcatVector(sourceOfRandomness);
        }
    }

    /**
     * Checks, for every model in a random dataset, that the value and the gradient produced by
     * {@link LogLikelihoodDifferentiableFunction#getSummaryForInstance} match the brute-force definitions.
     */
    @Theory
    public void testGetSummaryForInstance(@ForAll(sampleSize = 50) @From({GraphicalModelDatasetGenerator.class}) GraphicalModel[] graphicalModelArr, @ForAll(sampleSize = 2) @From({WeightsGenerator.class}) ConcatVector concatVector) throws Exception {
        LogLikelihoodDifferentiableFunction function = new LogLikelihoodDifferentiableFunction();
        for (GraphicalModel model : graphicalModelArr) {
            double expectedLogLikelihood = logLikelihood(model, concatVector);
            ConcatVector definitionOfDerivative = definitionOfDerivative(model, concatVector);

            // getSummaryForInstance accumulates the gradient into this (initially empty) vector.
            ConcatVector computedGradient = new ConcatVector(0);
            double computedLogLikelihood = function.getSummaryForInstance(model, concatVector, computedGradient);
            Assert.assertEquals(expectedLogLikelihood, computedLogLikelihood,
                    logLikelihoodTolerance(expectedLogLikelihood));

            ConcatVector difference = definitionOfDerivative.deepClone();
            difference.addVectorInPlace(computedGradient, -1.0d);
            double distance = Math.sqrt(difference.dotProduct(difference));
            if (distance > GRADIENT_TOLERANCE) {
                System.err.println("Definitional and calculated gradient differ!");
                System.err.println("Definition approx: " + definitionOfDerivative);
                System.err.println("Calculated: " + computedGradient);
            }
            Assert.assertEquals(0.0d, distance, GRADIENT_TOLERANCE);
        }
    }

    /**
     * Tolerance for comparing a log-likelihood against {@code expected}: relative to |expected| when it is
     * finite (a log-likelihood is normally negative, so the magnitude must be taken), never below the absolute
     * floor. Non-finite expectations get only the floor, so an infinite expected value is not trivially
     * matched by an infinite tolerance.
     */
    private static double logLikelihoodTolerance(double expected) {
        if (!Double.isFinite(expected)) {
            return LOG_LIKELIHOOD_ABSOLUTE_TOLERANCE;
        }
        return Math.max(LOG_LIKELIHOOD_ABSOLUTE_TOLERANCE, Math.abs(expected) * LOG_LIKELIHOOD_RELATIVE_TOLERANCE);
    }

    /**
     * Brute-force conditional log-likelihood of the training assignment given the observed variables.
     *
     * @return {@code log P(training values | observed values)}; {@code 0.0} when every variable is observed;
     *         {@link Double#NEGATIVE_INFINITY} when the training assignment or the partition function is zero
     */
    private double logLikelihood(GraphicalModel graphicalModel, ConcatVector concatVector) {
        // Multiply every factor into a single joint table (slow but obviously correct).
        TableFactor joint = null;
        for (GraphicalModel.Factor factor : graphicalModel.factors) {
            TableFactor table = new TableFactor(concatVector, factor);
            joint = (joint == null) ? table : joint.multiply(table);
        }
        assert joint != null : "model has no factors";

        // Condition on every observed variable.
        TableFactor conditioned = joint;
        for (int variable : joint.neighborIndices) {
            Map<String, String> metaData = graphicalModel.getVariableMetaDataByReference(variable);
            if (metaData.containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE)) {
                int observedValue = Integer.parseInt(metaData.get(CliqueTree.VARIABLE_OBSERVED_VALUE));
                if (conditioned.neighborIndices.length <= 1) {
                    // The last remaining variable is observed too: nothing is left to score, log(1) = 0.
                    return 0.0d;
                }
                conditioned = conditioned.observe(variable, observedValue);
            }
        }

        double partitionFunction = conditioned.valueSum();

        // Every remaining (unobserved) variable must carry a training value.
        int[] trainingAssignment = new int[conditioned.neighborIndices.length];
        for (int i = 0; i < trainingAssignment.length; i++) {
            Map<String, String> metaData = graphicalModel.getVariableMetaDataByReference(conditioned.neighborIndices[i]);
            assert !metaData.containsKey(CliqueTree.VARIABLE_OBSERVED_VALUE);
            trainingAssignment[i] = Integer.parseInt(metaData.get(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE));
        }

        double score = conditioned.getAssignmentValue(trainingAssignment);
        if (score == 0.0d || partitionFunction == 0.0d) {
            return Double.NEGATIVE_INFINITY;
        }
        return Math.log(score) - Math.log(partitionFunction);
    }

    /**
     * Central finite-difference approximation of the gradient of {@link #logLikelihood} with respect to the
     * weights, probing indices [0, CONCAT_VEC_COMPONENT_LENGTH) of each of the CONCAT_VEC_COMPONENTS components.
     *
     * @return a dense vector of partial derivatives; NaN partials are reported as 0
     */
    private ConcatVector definitionOfDerivative(GraphicalModel graphicalModel, ConcatVector concatVector) {
        ConcatVector gradient = new ConcatVector(CONCAT_VEC_COMPONENTS);
        for (int component = 0; component < CONCAT_VEC_COMPONENTS; component++) {
            double[] partials = new double[CONCAT_VEC_COMPONENT_LENGTH];
            for (int index = 0; index < CONCAT_VEC_COMPONENT_LENGTH; index++) {
                ConcatVector direction = new ConcatVector(CONCAT_VEC_COMPONENTS);
                direction.setSparseComponent(component, index, 1.0d);

                ConcatVector plus = concatVector.deepClone();
                plus.addVectorInPlace(direction, FINITE_DIFFERENCE_EPSILON);
                ConcatVector minus = concatVector.deepClone();
                minus.addVectorInPlace(direction, -FINITE_DIFFERENCE_EPSILON);

                double partial = (logLikelihood(graphicalModel, plus) - logLikelihood(graphicalModel, minus))
                        / (2.0d * FINITE_DIFFERENCE_EPSILON);
                // e.g. -inf - (-inf) yields NaN; treat such a direction as having zero gradient.
                partials[index] = Double.isNaN(partial) ? 0.0d : partial;
            }
            gradient.setDenseComponent(component, partials);
        }
        return gradient;
    }
}
