package edu.stanford.nlp.math;

import junit.framework.TestCase;

/* loaded from: input_file:edu/stanford/nlp/math/ArrayMathTest.class */
/**
 * JUnit 3 tests for {@link ArrayMath}.
 *
 * <p>{@link #setUp()} rebuilds every fixture array, so tests that modify {@code d1} in place
 * (the {@code *InPlace} tests) can compare the result against the untouched copy {@code d2}.
 */
public class ArrayMathTest extends TestCase {
    /**
     * Absolute tolerance for comparing two computations that are mathematically equal but may
     * round differently (e.g. {@code mean * n} versus {@code sum}).
     */
    private static final double ROUNDING_TOL = 1.0E-10d;

    /** Ordinary finite values, including a negative entry; mutated by the in-place tests. */
    private double[] d1;
    /** Starts with the same values as {@code d1}; used as the reference copy in in-place tests. */
    private double[] d2;
    /** One NaN, one positive infinity and one finite value. */
    private double[] d3;
    /** Small positive values. */
    private double[] d4;
    /** Four-element array used by {@link #testInnerProduct()}. */
    private double[] d5;

    /** Re-creates all fixture arrays so each test starts from the same values. */
    public void setUp() {
        this.d1 = new double[] {1.0d, 343.33d, -13.1d};
        this.d2 = new double[] {1.0d, 343.33d, -13.1d};
        this.d3 = new double[] {Double.NaN, Double.POSITIVE_INFINITY, 2.0d};
        this.d4 = new double[] {0.1d, 0.2d, 0.3d};
        this.d5 = new double[] {0.1d, 0.2d, 0.3d, 0.8d};
    }

    public void testInnerProduct() {
        // 0.01 + 0.04 + 0.09 = 0.14
        assertEquals("Wrong inner product", 0.14d, ArrayMath.innerProduct(this.d4, this.d4), 1.0E-6d);
        // 0.14 + 0.64 = 0.78
        assertEquals("Wrong inner product", 0.78d, ArrayMath.innerProduct(this.d5, this.d5), 1.0E-6d);
    }

    public void testNumRows() {
        // Expected value first, so a failure message reads correctly.
        assertEquals(3, ArrayMath.numRows(this.d1));
    }

    public void testExpLog() {
        double[] roundTrip = ArrayMath.log(ArrayMath.exp(this.d1));
        assertEquals(0.0d, ArrayMath.norm(ArrayMath.pairwiseSubtract(this.d1, roundTrip)), 1.0E-4d);
    }

    public void testExpLogInplace() {
        ArrayMath.expInPlace(this.d1);
        ArrayMath.logInPlace(this.d1);
        // d1 should now equal its original values (held in d2); the difference should vanish.
        ArrayMath.pairwiseSubtractInPlace(this.d1, this.d2);
        assertEquals(0.0d, ArrayMath.norm(this.d1), 1.0E-4d);
    }

    public void testAddInPlace() {
        ArrayMath.addInPlace(this.d1, 3.0d);
        for (int i = 0; i < this.d1.length; i++) {
            assertEquals(this.d2[i] + 3.0d, this.d1[i], 0.0d);
        }
    }

    public void testMultiplyInPlace() {
        ArrayMath.multiplyInPlace(this.d1, 3.0d);
        for (int i = 0; i < this.d1.length; i++) {
            assertEquals(this.d2[i] * 3.0d, this.d1[i], 0.0d);
        }
    }

    public void testPowInPlace() {
        ArrayMath.powInPlace(this.d1, 3.0d);
        for (int i = 0; i < this.d1.length; i++) {
            assertEquals(Math.pow(this.d2[i], 3.0d), this.d1[i], 0.0d);
        }
    }

    public void testAdd() {
        double[] result = ArrayMath.add(this.d1, 3.0d);
        assertEquals(this.d1.length, result.length);
        for (int i = 0; i < result.length; i++) {
            assertEquals(this.d1[i] + 3.0d, result[i], 0.0d);
        }
    }

    public void testMultiply() {
        double[] result = ArrayMath.multiply(this.d1, 3.0d);
        assertEquals(this.d1.length, result.length);
        for (int i = 0; i < result.length; i++) {
            assertEquals(this.d1[i] * 3.0d, result[i], 0.0d);
        }
    }

    public void testPow() {
        double[] result = ArrayMath.pow(this.d1, 3.0d);
        assertEquals(this.d1.length, result.length);
        for (int i = 0; i < result.length; i++) {
            assertEquals(Math.pow(this.d1[i], 3.0d), result[i], 0.0d);
        }
    }

    public void testPairwiseAdd() {
        double[] result = ArrayMath.pairwiseAdd(this.d1, this.d2);
        assertEquals(this.d1.length, result.length);
        for (int i = 0; i < this.d1.length; i++) {
            assertEquals(this.d1[i] + this.d2[i], result[i], 0.0d);
        }
    }

    public void testPairwiseSubtract() {
        double[] result = ArrayMath.pairwiseSubtract(this.d1, this.d2);
        assertEquals(this.d1.length, result.length);
        for (int i = 0; i < this.d1.length; i++) {
            assertEquals(this.d1[i] - this.d2[i], result[i], 0.0d);
        }
    }

    public void testPairwiseMultiply() {
        double[] result = ArrayMath.pairwiseMultiply(this.d1, this.d2);
        assertEquals(this.d1.length, result.length);
        for (int i = 0; i < this.d1.length; i++) {
            assertEquals(this.d1[i] * this.d2[i], result[i], 0.0d);
        }
    }

    public void testHasNaN() {
        assertFalse(ArrayMath.hasNaN(this.d1));
        assertFalse(ArrayMath.hasNaN(this.d2));
        assertTrue(ArrayMath.hasNaN(this.d3));
    }

    public void testHasInfinite() {
        assertFalse(ArrayMath.hasInfinite(this.d1));
        assertFalse(ArrayMath.hasInfinite(this.d2));
        assertTrue(ArrayMath.hasInfinite(this.d3));
    }

    public void testCountNaN() {
        assertEquals(0, ArrayMath.countNaN(this.d1));
        assertEquals(0, ArrayMath.countNaN(this.d2));
        assertEquals(1, ArrayMath.countNaN(this.d3));
    }

    public void testFliterNaN() {
        double[] filtered = ArrayMath.filterNaN(this.d3);
        assertEquals(2, ArrayMath.numRows(filtered));
        assertEquals(0, ArrayMath.countNaN(filtered));
    }

    public void testCountInfinite() {
        assertEquals(0, ArrayMath.countInfinite(this.d1));
        assertEquals(0, ArrayMath.countInfinite(this.d2));
        assertEquals(1, ArrayMath.countInfinite(this.d3));
    }

    public void testFliterInfinite() {
        double[] filtered = ArrayMath.filterInfinite(this.d3);
        assertEquals(2, ArrayMath.numRows(filtered));
        assertEquals(0, ArrayMath.countInfinite(filtered));
    }

    public void testFliterNaNAndInfinite() {
        double[] filtered = ArrayMath.filterNaNAndInfinite(this.d3);
        assertEquals(1, ArrayMath.numRows(filtered));
        assertEquals(0, ArrayMath.countInfinite(filtered));
        assertEquals(0, ArrayMath.countNaN(filtered));
    }

    public void testSum() {
        double expected = 0.0d;
        for (double value : this.d1) {
            expected += value;
        }
        assertEquals(expected, ArrayMath.sum(this.d1), 0.0d);
    }

    public void testNorm_inf() {
        assertEquals(ArrayMath.max(this.d1), ArrayMath.norm_inf(this.d1), 0.0d);
        assertEquals(ArrayMath.max(this.d2), ArrayMath.norm_inf(this.d2), 0.0d);
        assertEquals(ArrayMath.max(this.d3), ArrayMath.norm_inf(this.d3), 0.0d);
    }

    public void testArgmax() {
        assertEquals(ArrayMath.max(this.d1), this.d1[ArrayMath.argmax(this.d1)], 0.0d);
        assertEquals(ArrayMath.max(this.d2), this.d2[ArrayMath.argmax(this.d2)], 0.0d);
        assertEquals(ArrayMath.max(this.d3), this.d3[ArrayMath.argmax(this.d3)], 0.0d);
    }

    public void testArgmin() {
        assertEquals(ArrayMath.min(this.d1), this.d1[ArrayMath.argmin(this.d1)], 0.0d);
        assertEquals(ArrayMath.min(this.d2), this.d2[ArrayMath.argmin(this.d2)], 0.0d);
        assertEquals(ArrayMath.min(this.d3), this.d3[ArrayMath.argmin(this.d3)], 0.0d);
    }

    public void testLogSum() {
        double sumOfExps = 0.0d;
        for (double value : this.d4) {
            sumOfExps += Math.exp(value);
        }
        // logSum may use a different (e.g. max-shifted) formulation than log(sum(exp)),
        // so only equality up to rounding can be expected, not bit-for-bit equality.
        assertEquals(Math.log(sumOfExps), ArrayMath.logSum(this.d4), ROUNDING_TOL);
    }

    public void testNormalize() {
        ArrayMath.normalize(this.d1);
        ArrayMath.normalize(this.d2);
        ArrayMath.normalize(this.d4);
        // Symmetric tolerance: the previous one-sided check (sum - 1 < tol) accepted any sum below 1.
        assertEquals(1.0d, ArrayMath.sum(this.d1), 1.0E-4d);
        assertEquals(1.0d, ArrayMath.sum(this.d2), 1.0E-4d);
        assertEquals(1.0d, ArrayMath.sum(this.d4), 1.0E-4d);
    }

    public void testKLDivergence() {
        // d1 and d2 hold identical values, so the divergence between them should be zero.
        assertEquals(0.0d, ArrayMath.klDivergence(this.d1, this.d2), 0.0d);
    }

    public void testSumAndMean() {
        // mean * n and sum are equal mathematically, but the division/multiplication may round.
        assertEquals(ArrayMath.sum(this.d1), ArrayMath.mean(this.d1) * this.d1.length, ROUNDING_TOL);
        assertEquals(ArrayMath.sum(this.d2), ArrayMath.mean(this.d2) * this.d2.length, ROUNDING_TOL);
        assertEquals(ArrayMath.sum(this.d4), ArrayMath.mean(this.d4) * this.d4.length, ROUNDING_TOL);
    }

    /**
     * Checks that {@code safeMean} of {@code dArr} times the number of non-NaN, non-infinite
     * entries equals the sum of those entries.
     *
     * @param dArr array that may contain NaN and infinite values
     */
    public static void helpTestSafeSumAndMean(double[] dArr) {
        double[] finite = ArrayMath.filterNaNAndInfinite(dArr);
        assertEquals(
                ArrayMath.sum(finite),
                ArrayMath.safeMean(dArr) * ArrayMath.numRows(finite),
                ROUNDING_TOL);
    }

    public void testSafeSumAndMean() {
        helpTestSafeSumAndMean(this.d1);
        helpTestSafeSumAndMean(this.d2);
        helpTestSafeSumAndMean(this.d3);
        helpTestSafeSumAndMean(this.d4);
    }

    public void testJensenShannon() {
        double[] p = {0.1d, 0.1d, 0.7d, 0.1d, 0.0d, 0.0d};
        double[] q = {0.0d, 0.1d, 0.1d, 0.7d, 0.1d, 0.0d};
        assertEquals(0.46514844544032313d, ArrayMath.jensenShannonDivergence(p, q), 1.0E-5d);

        double[] r = {1.0d, 0.0d, 0.0d};
        double[] s = {0.0d, 0.5d, 0.5d};
        assertEquals(1.0d, ArrayMath.jensenShannonDivergence(r, s), 1.0E-5d);
    }
}
