package edu.stanford.nlp.loglinear.learning;

import com.pholser.junit.quickcheck.ForAll;
import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.generator.InRange;
import edu.stanford.nlp.loglinear.learning.LogLikelihoodFunctionTest;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.util.Random;
import org.junit.Assert;
import org.junit.contrib.theories.DataPoint;
import org.junit.contrib.theories.Theories;
import org.junit.contrib.theories.Theory;
import org.junit.runner.RunWith;

/**
 * Property-based checks that every {@link AbstractBatchOptimizer} data point returns weights at which
 * the L2-regularized log-likelihood objective cannot be noticeably improved by small random moves.
 */
@RunWith(Theories.class)
public class OptimizerTests {

    /** Number of random nearby points probed around the optimizer's answer. */
    private static final int NUM_PERTURBATIONS = 1000;

    /** Fixed seed so the probed directions are identical on every run. */
    private static final long PERTURBATION_SEED = 42L;

    /** Each perturbation entry is drawn uniformly from [-SCALE/2, SCALE/2). */
    private static final double PERTURBATION_SCALE = 0.001d;

    /** Allowed slack, scaled by max(1, |objective at optimum|), before a probe counts as "better". */
    private static final double RELATIVE_TOLERANCE = 0.001d;

    @DataPoint
    public static AbstractBatchOptimizer backtrackingAdaGrad = new BacktrackingAdaGradOptimizer();

    /**
     * Optimizes the log-likelihood of a generated dataset, then asserts that none of
     * {@link #NUM_PERTURBATIONS} small random perturbations of the result scores higher than the
     * result itself (up to a relative tolerance).
     *
     * @param optimizer        optimizer under test (supplied from the {@code @DataPoint} fields)
     * @param dataset          generated training models
     * @param initialWeights   generated starting weights
     * @param l2regularization L2 penalty coefficient in [0, 5]
     */
    @Theory
    public void testOptimizeLogLikelihood(
            AbstractBatchOptimizer optimizer,
            @ForAll(sampleSize = 5) @From({LogLikelihoodFunctionTest.GraphicalModelDatasetGenerator.class}) GraphicalModel[] dataset,
            @ForAll(sampleSize = 2) @From({LogLikelihoodFunctionTest.WeightsGenerator.class}) ConcatVector initialWeights,
            @InRange(minDouble = 0.0d, maxDouble = 5.0d) @ForAll(sampleSize = 2) double l2regularization) throws Exception {
        LogLikelihoodDifferentiableFunction logLikelihood = new LogLikelihoodDifferentiableFunction();
        ConcatVector optimum = optimizer.optimize(dataset, logLikelihood, initialWeights, l2regularization, 1.0E-9d, true);
        System.err.println("Finished optimizing");

        double optimalValue = getValueSum(dataset, optimum, logLikelihood, l2regularization);
        // The slack depends only on the value at the optimum, so it is computed once and shared by every probe.
        double tolerance = RELATIVE_TOLERANCE * Math.max(1.0d, Math.abs(optimalValue));
        // The loop below only mutates clones of `optimum`, so its component layout is loop-invariant.
        int[] componentSizes = perturbationComponentSizes(optimum);

        Random random = new Random(PERTURBATION_SEED);
        for (int probe = 0; probe < NUM_PERTURBATIONS; probe++) {
            ConcatVector perturbed = optimum.deepClone();
            perturbed.addVectorInPlace(randomDirection(componentSizes, random), 1.0d);
            double perturbedValue = getValueSum(dataset, perturbed, logLikelihood, l2regularization);

            // Written as ">=" (not "if (<) fail") so that a NaN objective still fails the test.
            Assert.assertTrue(
                    "Thought optimal point was: " + optimalValue + "; discovered better point: " + perturbedValue,
                    optimalValue >= perturbedValue - tolerance);
        }
    }

    /**
     * Returns, per component of {@code weights}, the length of the dense array used to perturb it:
     * {@code getSparseIndex(i) + 1} for a sparse component, the dense array length otherwise.
     */
    private static int[] perturbationComponentSizes(ConcatVector weights) {
        int[] sizes = new int[weights.getNumberOfComponents()];
        for (int component = 0; component < sizes.length; component++) {
            sizes[component] = weights.isComponentSparse(component)
                    ? weights.getSparseIndex(component) + 1
                    : weights.getDenseComponent(component).length;
        }
        return sizes;
    }

    /**
     * Builds a dense random direction with the given per-component lengths; every entry is
     * {@code (random.nextDouble() - 0.5) * PERTURBATION_SCALE}, drawn component by component in order.
     */
    private static ConcatVector randomDirection(int[] componentSizes, Random random) {
        ConcatVector direction = new ConcatVector(componentSizes.length);
        for (int component = 0; component < componentSizes.length; component++) {
            double[] values = new double[componentSizes[component]];
            for (int k = 0; k < values.length; k++) {
                values[k] = (random.nextDouble() - 0.5d) * PERTURBATION_SCALE;
            }
            direction.setDenseComponent(component, values);
        }
        return direction;
    }

    /**
     * Evaluates the regularized objective at {@code weights}: the mean over {@code dataset} of
     * {@code fn.getSummaryForInstance(...)} minus {@code l2regularization * weights.dotProduct(weights)}.
     * A fresh {@code new ConcatVector(0)} is passed as the third argument on each call; its contents
     * are never read here (presumably a discarded gradient accumulator — TODO confirm).
     * Note: an empty {@code dataset} makes the mean 0.0 / 0, i.e. NaN.
     */
    private <T> double getValueSum(T[] dataset, ConcatVector weights, AbstractDifferentiableFunction<T> fn, double l2regularization) {
        double total = 0.0d;
        for (T instance : dataset) {
            total += fn.getSummaryForInstance(instance, weights, new ConcatVector(0));
        }
        return (total / dataset.length) - (weights.dotProduct(weights) * l2regularization);
    }
}
