package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.loglinear.learning.AbstractBatchOptimizer;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.util.logging.Redwood;

/* loaded from: input_file:edu/stanford/nlp/loglinear/learning/BacktrackingAdaGradOptimizer.class */
/**
 * Batch optimizer that takes AdaGrad-scaled steps along the gradient and backtracks whenever the
 * reported log-likelihood gets worse.
 *
 * <p>Each call to {@link #updateWeights} does exactly one of three things, chosen by comparing the
 * new log-likelihood with the one recorded at the last accepted step:
 * <ul>
 *   <li>no change at all: report convergence;</li>
 *   <li>a decrease: halve the previously applied step, move the weights back by that halved
 *       amount, and report convergence once the remaining step is negligible;</li>
 *   <li>otherwise: accumulate the squared gradient, scale the gradient per coordinate by
 *       {@code alpha / sqrt(accumulator)}, and add the result to the weights.</li>
 * </ul>
 */
public class BacktrackingAdaGradOptimizer extends AbstractBatchOptimizer {
    /** Logger for this class. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(BacktrackingAdaGradOptimizer.class);

    /** Base step size: a coordinate's step multiplier is {@code alpha / sqrt(accumulated squared gradient)}. */
    static final double alpha = 0.1d;

    /**
     * Backtracking reports convergence once the dot product of the halved step with itself
     * (i.e. its squared norm) drops below this value.
     */
    private static final double CONVERGENCE_THRESHOLD = 1.0E-10d;

    /**
     * Per-run state carried between calls to {@link #updateWeights}.
     */
    protected class AdaGradOptimizationState extends AbstractBatchOptimizer.OptimizationState {
        /** The step most recently added to the weights; halved in place on every backtrack. */
        ConcatVector lastDerivative;
        /** Running sum of element-wise squared gradients over all accepted steps. */
        ConcatVector adagradAccumulator;
        /** Log-likelihood recorded at the most recent accepted (non-backtracking) step. */
        double lastLogLikelihood;

        /**
         * Starts with empty vectors and a log-likelihood of negative infinity, so the first
         * finite log-likelihood always counts as an improvement.
         */
        protected AdaGradOptimizationState() {
            super();
            this.lastDerivative = new ConcatVector(0);
            this.adagradAccumulator = new ConcatVector(0);
            this.lastLogLikelihood = Double.NEGATIVE_INFINITY;
        }
    }

    /**
     * Performs one optimization step, mutating {@code weights} in place.
     *
     * @param weights           the current weights; updated in place
     * @param gradient          the gradient at {@code weights}; on an accepted step it is rescaled
     *                          in place and retained by {@code optimizationState}
     * @param logLikelihood     the objective value at {@code weights}
     * @param optimizationState state produced by {@link #getFreshOptimizationState}; must be an
     *                          {@link AdaGradOptimizationState}
     * @param quiet             when {@code true}, suppresses all logging
     * @return {@code true} if optimization should stop (the cases that log "quitting"),
     *         {@code false} if another iteration is wanted
     * @throws ClassCastException if {@code optimizationState} is not an {@link AdaGradOptimizationState}
     */
    @Override // edu.stanford.nlp.loglinear.learning.AbstractBatchOptimizer
    public boolean updateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, AbstractBatchOptimizer.OptimizationState optimizationState, boolean quiet) {
        AdaGradOptimizationState state = (AdaGradOptimizationState) optimizationState;
        double logLikelihoodChange = logLikelihood - state.lastLogLikelihood;

        if (logLikelihoodChange == 0.0d) {
            if (!quiet) {
                log.info("\tlogLikelihood improvement = 0: quitting");
            }
            return true;
        }
        if (logLikelihoodChange < 0.0d) {
            return backtrack(weights, state, quiet);
        }
        // Everything else lands here: a genuine improvement, the first call (change is +infinity),
        // and also a NaN change, since NaN fails both comparisons above.
        applyAdaGradStep(weights, gradient, logLikelihood, state, quiet);
        return false;
    }

    /**
     * Halves the previously applied step and moves the weights back by that halved amount.
     *
     * @return {@code true} when the halved step's squared norm has fallen below
     *         {@link #CONVERGENCE_THRESHOLD}, meaning there is nothing meaningful left to undo
     */
    private static boolean backtrack(ConcatVector weights, AdaGradOptimizationState state, boolean quiet) {
        state.lastDerivative.mapInPlace(value -> value / 2.0d);
        weights.addVectorInPlace(state.lastDerivative, -1.0d);
        if (!quiet) {
            log.info("\tBACKTRACK...");
        }

        // Computed once and reused for both the test and the message.
        double squaredNorm = state.lastDerivative.dotProduct(state.lastDerivative);
        if (squaredNorm >= CONVERGENCE_THRESHOLD) {
            return false;
        }
        if (!quiet) {
            // The reported bound is taken from the constant so it always matches the test above.
            log.info("\tBacktracking derivative norm " + squaredNorm + " < " + CONVERGENCE_THRESHOLD + ": quitting");
        }
        return true;
    }

    /**
     * Applies one AdaGrad-scaled step along {@code gradient} and records it so that a later call
     * can backtrack from it.
     */
    private static void applyAdaGradStep(ConcatVector weights, ConcatVector gradient, double logLikelihood, AdaGradOptimizationState state, boolean quiet) {
        ConcatVector squaredGradient = gradient.deepClone();
        squaredGradient.mapInPlace(value -> value * value);
        state.adagradAccumulator.addVectorInPlace(squaredGradient, 1.0d);

        ConcatVector stepSizes = state.adagradAccumulator.deepClone();
        // A zero accumulator entry would divide by zero, so fall back to the plain base step size.
        stepSizes.mapInPlace(sum -> sum == 0.0d ? alpha : alpha / Math.sqrt(sum));

        gradient.elementwiseProductInPlace(stepSizes);
        weights.addVectorInPlace(gradient, 1.0d);

        // NOTE: the state keeps a reference to the caller's gradient object (no copy), and
        // backtracking later halves that same object in place.
        state.lastDerivative = gradient;
        state.lastLogLikelihood = logLikelihood;

        if (!quiet) {
            log.info("\tLL: " + logLikelihood);
        }
    }

    /**
     * Creates the state object for a new optimization run.
     *
     * @param concatVector not used by this implementation
     * @return a fresh {@link AdaGradOptimizationState}
     */
    @Override // edu.stanford.nlp.loglinear.learning.AbstractBatchOptimizer
    protected AbstractBatchOptimizer.OptimizationState getFreshOptimizationState(ConcatVector concatVector) {
        return new AdaGradOptimizationState();
    }
}
