package edu.stanford.nlp.loglinear.inference;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.GenerationStatus;
import com.pholser.junit.quickcheck.generator.Generator;
import com.pholser.junit.quickcheck.generator.InRange;
import com.pholser.junit.quickcheck.random.SourceOfRandomness;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorTable;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

/**
 * Property-based tests for {@link TableFactor}.
 *
 * <p>Random factors are produced by {@link TableFactorGenerator} (over global variables 0..3
 * with cardinalities taken from {@link #variableSizes}). Each operation is checked against a
 * brute-force reference computed directly from the factor's assignments. Operations covered:
 * observe, marginals, maxOut, sumOut, multiply, valueSum and construction with observations.
 */
@RunWith(Theories.class)
public class TableFactorTest {
    /** Cardinality of each global variable index (0..3) used by {@link TableFactorGenerator}. */
    public static int[] variableSizes = {2, 4, 2, 3};

    /**
     * Generates a {@link ConcatVector} with one dense component (index 0) holding 20 values drawn
     * from {@code random.nextDouble()}.
     */
    public static class ConcatVectorGenerator extends Generator<ConcatVector> {
        public ConcatVectorGenerator(Class<ConcatVector> type) {
            super(type);
        }

        // Must be named exactly `generate` so it overrides Generator.generate and is invoked by
        // junit-quickcheck (the decompiled name `m321generate` did not override anything).
        @Override
        public ConcatVector generate(SourceOfRandomness random, GenerationStatus status) {
            ConcatVector vector = new ConcatVector(1);
            double[] dense = new double[20];
            for (int i = 0; i < dense.length; i++) {
                dense[i] = random.nextDouble();
            }
            vector.setDenseComponent(0, dense);
            return vector;
        }
    }

    /**
     * A random {@link GraphicalModel.Factor} together with one observation per neighbor, where
     * {@code -1} means "unobserved".
     */
    public static class PartiallyObservedConstructorData {
        public GraphicalModel.Factor factor;
        public int[] observations;

        private PartiallyObservedConstructorData() {
        }
    }

    /**
     * Generates a factor over 1..4 distinct variables drawn from [0, 8), each with cardinality
     * 1..3. It also generates observations that always leave at least one variable unobserved.
     */
    public static class PartiallyObservedConstructorDataGenerator
            extends Generator<PartiallyObservedConstructorData> {
        public PartiallyObservedConstructorDataGenerator(Class<PartiallyObservedConstructorData> type) {
            super(type);
        }

        @Override
        public PartiallyObservedConstructorData generate(SourceOfRandomness random, GenerationStatus status) {
            int numVariables = random.nextInt(1, 4);
            Set<Integer> usedVariables = new HashSet<>();
            int[] neighborIndices = new int[numVariables];
            int[] dimensions = new int[numVariables];
            int[] observations = new int[numVariables];
            int numObserved = 0;
            for (int i = 0; i < numVariables; i++) {
                // Rejection-sample until we get a variable index not yet used by this factor.
                int variable;
                do {
                    variable = random.nextInt(8);
                } while (usedVariables.contains(variable));
                usedVariables.add(variable);
                neighborIndices[i] = variable;
                dimensions[i] = random.nextInt(1, 3);

                // nextBoolean() is always drawn. The size check caps the number of observed
                // variables at numVariables - 1, so at least one variable stays unobserved.
                boolean observe = random.nextBoolean();
                if (observe && numObserved + 1 < numVariables) {
                    observations[i] = random.nextInt(dimensions[i]);
                    numObserved++;
                } else {
                    observations[i] = -1;
                }
            }

            ConcatVectorTable table = new ConcatVectorTable(dimensions);
            ConcatVectorGenerator vectorGenerator = new ConcatVectorGenerator(ConcatVector.class);
            for (Iterator<int[]> it = table.iterator(); it.hasNext(); ) {
                int[] assignment = it.next();
                ConcatVector value = vectorGenerator.generate(random, status);
                table.setAssignmentValue(assignment, () -> value);
            }

            PartiallyObservedConstructorData data = new PartiallyObservedConstructorData();
            data.factor = new GraphicalModel.Factor(table, neighborIndices);
            data.observations = observations;
            return data;
        }
    }

    /**
     * Generates a {@link TableFactor} over 1..3 distinct variables from [0, 3]. Each assignment
     * value is a shared random scale times an independent random double.
     */
    public static class TableFactorGenerator extends Generator<TableFactor> {
        public TableFactorGenerator(Class<TableFactor> type) {
            super(type);
        }

        @Override
        public TableFactor generate(SourceOfRandomness random, GenerationStatus status) {
            int numNeighbors = random.nextInt(1, 3);
            int[] neighborIndices = new int[numNeighbors];
            int[] dimensions = new int[numNeighbors];
            Set<Integer> usedVariables = new HashSet<>();
            for (int i = 0; i < numNeighbors; i++) {
                int variable;
                do {
                    variable = random.nextInt(0, 3);
                } while (usedVariables.contains(variable));
                usedVariables.add(variable);
                neighborIndices[i] = variable;
                dimensions[i] = TableFactorTest.variableSizes[variable];
            }

            double scale = random.nextDouble();
            TableFactor factor = new TableFactor(neighborIndices, dimensions);
            for (Iterator<int[]> it = factor.iterator(); it.hasNext(); ) {
                factor.setAssignmentValue(it.next(), scale * random.nextDouble());
            }
            return factor;
        }
    }

    /**
     * Checks two ways of applying observations and asserts they produce the same factor. The
     * reference path builds the full factor and then calls {@code observe} once per observed
     * variable. The path under test passes the observations directly to the constructor.
     */
    @Theory
    public void testConstructWithObservations(
            @ForAll(sampleSize = 50) @From(PartiallyObservedConstructorDataGenerator.class) PartiallyObservedConstructorData data,
            @ForAll(sampleSize = 2) @From(ConcatVectorGenerator.class) ConcatVector weights) throws Exception {
        // Observation indexed by global variable id; the data generator draws ids from [0, 8).
        int[] observationByVariable = new int[9];
        Arrays.fill(observationByVariable, -1);
        for (int i = 0; i < data.observations.length; i++) {
            observationByVariable[data.factor.neigborIndices[i]] = data.observations[i];
        }

        TableFactor expected = new TableFactor(weights, data.factor);
        for (int variable = 0; variable < observationByVariable.length; variable++) {
            if (observationByVariable[variable] != -1) {
                expected = expected.observe(variable, observationByVariable[variable]);
            }
        }

        TableFactor actual = new TableFactor(weights, data.factor, data.observations);
        Assert.assertArrayEquals(expected.neighborIndices, actual.neighborIndices);
        for (Iterator<int[]> it = expected.iterator(); it.hasNext(); ) {
            int[] assignment = it.next();
            Assert.assertEquals(expected.getAssignmentValue(assignment), actual.getAssignmentValue(assignment), 1.0e-9);
        }
    }

    /**
     * Observing a variable keeps exactly the rows where that variable has the observed value. The
     * remaining variables' values in those rows must be unchanged.
     */
    @Theory
    public void testObserve(
            @ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
            @InRange(minInt = 0, maxInt = 3) @ForAll(sampleSize = 2) int variable,
            @InRange(minInt = 0, maxInt = 1) @ForAll(sampleSize = 2) int value) throws Exception {
        int position = indexOfNeighbor(factor, variable);
        // Nothing to check if the variable isn't in the factor, or observing it would empty it.
        if (position == -1 || factor.neighborIndices.length == 1) {
            return;
        }
        TableFactor observed = factor.observe(variable, value);
        for (Iterator<int[]> it = factor.iterator(); it.hasNext(); ) {
            int[] assignment = it.next();
            if (assignment[position] == value) {
                Assert.assertEquals(factor.getAssignmentValue(assignment),
                        observed.getAssignmentValue(subsetAssignment(assignment, factor, observed)), 1.0e-7);
            }
        }
    }

    /** Max-marginals must equal the normalized per-value maximum over all assignments. */
    @Theory
    public void testGetMaxedMarginals(
            @ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
            @InRange(minInt = 0, maxInt = 3) @ForAll(sampleSize = 10) int variable) throws Exception {
        int position = indexOfNeighbor(factor, variable);
        if (position == -1) {
            return;
        }
        double[] expected = new double[factor.getDimensions()[position]];
        Arrays.fill(expected, Double.NEGATIVE_INFINITY);
        for (Iterator<int[]> it = factor.iterator(); it.hasNext(); ) {
            int[] assignment = it.next();
            int v = assignment[position];
            expected[v] = Math.max(expected[v], factor.getAssignmentValue(assignment));
        }
        normalize(expected);
        Assert.assertArrayEquals(expected, factor.getMaxedMarginals()[position], 1.0e-5);
    }

    /** Summed marginals must equal the normalized per-value sum over all assignments. */
    @Theory
    public void testGetSummedMarginals(
            @ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
            @InRange(minInt = 0, maxInt = 3) @ForAll(sampleSize = 10) int variable) throws Exception {
        int position = indexOfNeighbor(factor, variable);
        if (position == -1) {
            return;
        }
        double[] expected = new double[factor.getDimensions()[position]];
        for (Iterator<int[]> it = factor.iterator(); it.hasNext(); ) {
            int[] assignment = it.next();
            expected[assignment[position]] += factor.getAssignmentValue(assignment);
        }
        normalize(expected);
        Assert.assertArrayEquals(expected, factor.getSummedMarginals()[position], 1.0e-5);
    }

    /**
     * Normalizes {@code values} in place to sum to 1. If the sum is exactly 0, every entry is set
     * to the uniform value {@code 1 / values.length} instead.
     */
    private void normalize(double[] values) {
        double total = 0.0;
        for (double v : values) {
            total += v;
        }
        if (total == 0.0) {
            Arrays.fill(values, 1.0 / values.length);
            return;
        }
        for (int i = 0; i < values.length; i++) {
            values[i] /= total;
        }
    }

    /** {@code valueSum()} must equal the brute-force sum over every assignment. */
    @Theory
    public void testValueSum(@ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor) throws Exception {
        double sum = 0.0;
        for (Iterator<int[]> it = factor.iterator(); it.hasNext(); ) {
            sum += factor.getAssignmentValue(it.next());
        }
        Assert.assertEquals(sum, factor.valueSum(), 1.0e-5);
    }

    /**
     * Checks {@code maxOut(variable)} on a factor with more than one neighbor. The result must
     * drop that variable. Each of its entries must equal the maximum over the matching rows of
     * the original factor.
     */
    @Theory
    public void testMaxOut(
            @ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
            @InRange(minInt = 0, maxInt = 3) @ForAll(sampleSize = 10) int variable) throws Exception {
        if (indexOfNeighbor(factor, variable) == -1 || factor.neighborIndices.length <= 1) {
            return;
        }
        TableFactor maxed = factor.maxOut(variable);
        Assert.assertEquals(factor.neighborIndices.length - 1, maxed.neighborIndices.length);
        Assert.assertTrue(indexOfNeighbor(maxed, variable) == -1);

        for (Iterator<int[]> it = factor.iterator(); it.hasNext(); ) {
            int[] assignment = it.next();
            double value = factor.getAssignmentValue(assignment);
            Assert.assertFalse(Double.isNaN(value));
            Assert.assertTrue(value <= maxed.getAssignmentValue(subsetAssignment(assignment, factor, maxed)));
        }

        for (Map.Entry<List<Integer>, List<int[]>> entry : subsetToSupersetAssignments(factor, maxed).entrySet()) {
            double max = Double.NEGATIVE_INFINITY;
            for (int[] supersetAssignment : entry.getValue()) {
                max = Math.max(max, factor.getAssignmentValue(supersetAssignment));
            }
            int[] subset = entry.getKey().stream().mapToInt(Integer::intValue).toArray();
            Assert.assertEquals(max, maxed.getAssignmentValue(subset), 1.0e-5);
        }
    }

    /**
     * Checks {@code sumOut(variable)} on a factor with more than one neighbor. The result must
     * drop that variable. Each of its entries must equal the sum over the matching rows of the
     * original factor.
     */
    @Theory
    public void testSumOut(
            @ForAll(sampleSize = 50) @From(TableFactorGenerator.class) TableFactor factor,
            @InRange(minInt = 0, maxInt = 3) @ForAll(sampleSize = 10) int variable) throws Exception {
        if (indexOfNeighbor(factor, variable) == -1 || factor.neighborIndices.length <= 1) {
            return;
        }
        TableFactor summed = factor.sumOut(variable);
        Assert.assertEquals(factor.neighborIndices.length - 1, summed.neighborIndices.length);
        Assert.assertTrue(indexOfNeighbor(summed, variable) == -1);

        for (Map.Entry<List<Integer>, List<int[]>> entry : subsetToSupersetAssignments(factor, summed).entrySet()) {
            double sum = 0.0;
            for (int[] supersetAssignment : entry.getValue()) {
                sum += factor.getAssignmentValue(supersetAssignment);
            }
            int[] subset = entry.getKey().stream().mapToInt(Integer::intValue).toArray();
            Assert.assertEquals(sum, summed.getAssignmentValue(subset), 1.0e-5);
        }
    }

    /**
     * Each entry of a product factor must equal the product of the matching entries of both
     * inputs. The product's neighbor indices must also be pairwise distinct.
     */
    @Theory
    public void testMultiply(
            @ForAll(sampleSize = 10) @From(TableFactorGenerator.class) TableFactor left,
            @ForAll(sampleSize = 10) @From(TableFactorGenerator.class) TableFactor right) throws Exception {
        TableFactor product = left.multiply(right);
        for (Iterator<int[]> it = product.iterator(); it.hasNext(); ) {
            int[] assignment = it.next();
            double expected = left.getAssignmentValue(subsetAssignment(assignment, product, left))
                    * right.getAssignmentValue(subsetAssignment(assignment, product, right));
            Assert.assertEquals(expected, product.getAssignmentValue(assignment), 1.0e-5);
        }
        for (int i = 0; i < product.neighborIndices.length; i++) {
            for (int j = i + 1; j < product.neighborIndices.length; j++) {
                Assert.assertNotEquals(product.neighborIndices[i], product.neighborIndices[j]);
            }
        }
    }

    /** Returns the position of {@code variable} in {@code factor.neighborIndices}, or -1. */
    private static int indexOfNeighbor(TableFactor factor, int variable) {
        for (int i = 0; i < factor.neighborIndices.length; i++) {
            if (factor.neighborIndices[i] == variable) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Projects an assignment of {@code superset} onto the variables of {@code subset}. Every
     * neighbor of {@code subset} is expected to also be a neighbor of {@code superset}. When run
     * with {@code -ea}, this is checked by an assertion.
     */
    private int[] subsetAssignment(int[] supersetAssignment, TableFactor superset, TableFactor subset) {
        int[] result = new int[subset.neighborIndices.length];
        for (int i = 0; i < result.length; i++) {
            int position = indexOfNeighbor(superset, subset.neighborIndices[i]);
            result[i] = position == -1 ? -1 : supersetAssignment[position];
            assert result[i] != -1 : "variable " + subset.neighborIndices[i] + " missing from superset factor";
        }
        return result;
    }

    /**
     * Maps each assignment of {@code subset} (as a boxed list) to all assignments of
     * {@code superset} that project onto it.
     */
    private Map<List<Integer>, List<int[]>> subsetToSupersetAssignments(TableFactor superset, TableFactor subset) {
        Map<List<Integer>, List<int[]>> result = new HashMap<>();
        for (Iterator<int[]> it = subset.iterator(); it.hasNext(); ) {
            int[] subsetAssignment = it.next();
            List<Integer> key = Arrays.stream(subsetAssignment).boxed().collect(Collectors.toList());
            List<int[]> matches = new ArrayList<>();
            for (Iterator<int[]> it2 = superset.iterator(); it2.hasNext(); ) {
                int[] supersetAssignment = it2.next();
                if (Arrays.equals(subsetAssignment, subsetAssignment(supersetAssignment, superset, subset))) {
                    matches.add(supersetAssignment);
                }
            }
            result.put(key, matches);
        }
        return result;
    }
}
