package edu.stanford.nlp.optimization;

import edu.stanford.nlp.optimization.SparseOnlineFunction;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.logging.Redwood;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/optimization/SparseAdaGradMinimizer.class */
/**
 * Sparse stochastic minimizer that keeps a per-feature sum of squared
 * gradients and scales each feature's step by
 * {@code eta / (sqrt(sumGradSquare) + soften)}.
 *
 * <p>L1 truncation and L2 decay are applied only when a feature shows up in a
 * batch gradient; the amount applied is scaled by the number of steps the
 * feature was idle since it was last touched.
 *
 * <p>Sampling uses a {@link Random} seeded with {@code 1}, so the sequence of
 * sampled indices is the same for every freshly constructed instance.
 *
 * @param <K> feature (key) type of the weight counter
 * @param <F> type of the function being minimized
 */
public class SparseAdaGradMinimizer<K, F extends SparseOnlineFunction<K>> implements SparseMinimizer<K, F> {
    /** When {@code true}, {@link #sayln(String)} discards its messages. */
    public boolean quiet = false;

    /** Number of passes; each pass runs {@code ceil(dataSize / batchSize)} batches. */
    protected int numPasses;
    /** Number of data indices sampled (with replacement) per batch. */
    protected int batchSize;
    /** Base learning rate, the numerator of the per-feature rate. */
    protected double eta;
    /** L1 coefficient used in the truncation step. */
    protected double lambdaL1;
    /** L2 coefficient used in the multiplicative decay step. */
    protected double lambdaL2;
    /** Running per-feature sum of squared gradients; persists across calls to minimize. */
    protected Counter<K> sumGradSquare;
    /** Not read or written by this class. */
    protected Counter<K> x;
    /** Source of sampled data indices. */
    protected Random randGenerator = new Random(1L);

    /** Weights whose magnitude falls below this value are removed from the counter. */
    public final double EPS = 1.0E-15d;
    /** Added to the square root in the rate denominator so it is never zero. */
    public final double soften = 1.0E-4d;

    private static final Redwood.RedwoodChannels log = Redwood.channels(SparseAdaGradMinimizer.class);
    private static final NumberFormat nf = new DecimalFormat("0.000E0");

    /**
     * Creates a minimizer with {@code eta = 0.1}, batch size 1 and no regularization.
     *
     * @param numPasses number of passes over the data
     */
    public SparseAdaGradMinimizer(int numPasses) {
        this(numPasses, 0.1d);
    }

    /**
     * Creates a minimizer with batch size 1 and no regularization.
     *
     * @param numPasses number of passes over the data
     * @param eta base learning rate
     */
    public SparseAdaGradMinimizer(int numPasses, double eta) {
        this(numPasses, eta, 1, 0.0d, 0.0d);
    }

    /**
     * Creates a fully configured minimizer.
     *
     * @param numPasses number of passes over the data
     * @param eta base learning rate
     * @param batchSize number of indices sampled per batch
     * @param lambdaL1 L1 coefficient
     * @param lambdaL2 L2 coefficient
     */
    public SparseAdaGradMinimizer(int numPasses, double eta, int batchSize, double lambdaL1, double lambdaL2) {
        // EPS and soften are final fields initialised at their declaration and
        // must not be assigned again here.
        this.numPasses = numPasses;
        this.eta = eta;
        this.batchSize = batchSize;
        this.lambdaL1 = lambdaL1;
        this.lambdaL2 = lambdaL2;
        this.sumGradSquare = new ClassicCounter<K>();
    }

    /**
     * Minimizes {@code function} with no iteration cap.
     *
     * @param function function to minimize
     * @param initial starting weights; updated in place
     * @return the same counter instance that was passed in
     */
    @Override
    public Counter<K> minimize(F function, Counter<K> initial) {
        return minimize(function, initial, -1);
    }

    /**
     * Minimizes {@code function}, updating {@code weights} in place.
     *
     * <p>The step counter advances once per updated feature (not once per
     * batch), and {@code maxIterations} is compared against that counter.
     *
     * @param function function to minimize
     * @param weights starting weights; updated in place
     * @param maxIterations maximum value of the step counter before the run
     *        stops; a negative value means no limit
     * @return the same counter instance that was passed in
     */
    @Override
    public Counter<K> minimize(F function, Counter<K> weights, int maxIterations) {
        sayln("       Batch size of: " + this.batchSize);
        sayln("       Data dimension of: " + function.dataSize());
        int numBatches = ((function.dataSize() - 1) / this.batchSize) + 1;
        sayln("       Batches per pass through data:  " + numBatches);
        sayln("       Number of passes is = " + this.numPasses);
        sayln("       Max iterations is = " + maxIterations);

        if (function.dataSize() <= 0) {
            // Nothing to sample from; Random.nextInt(0) would throw.
            return weights;
        }

        // The two-argument overload passes -1 to mean "unbounded".
        final boolean limited = maxIterations >= 0;

        // Step at which each feature was last updated.
        Counter<K> lastUpdated = new ClassicCounter<K>();
        int timeStep = 0;

        Timing total = new Timing();
        total.start();

        for (int pass = 0; pass < this.numPasses; pass++) {
            double totalObjValue = 0.0d;
            for (int batch = 0; batch < numBatches; batch++) {
                int[] selectedData = getSample(function, this.batchSize);
                Counter<K> gradient = function.derivativeAt(weights, selectedData);
                totalObjValue += function.valueAt(weights, selectedData);

                for (K feature : gradient.keySet()) {
                    double grad = gradient.getCount(feature);

                    // Rate before and after folding this gradient into the running sum.
                    double prevRate = this.eta / (Math.sqrt(this.sumGradSquare.getCount(feature)) + this.soften);
                    double currentRate = this.eta
                            / (Math.sqrt(this.sumGradSquare.incrementCount(feature, grad * grad)) + this.soften);

                    double candidate = weights.getCount(feature) - (currentRate * grad);

                    // Steps this feature sat untouched; regularization for them is applied now.
                    double idleInterval = (timeStep - lastUpdated.getCount(feature)) - 1.0d;
                    lastUpdated.setCount(feature, timeStep);

                    double effectiveRate = currentRate + (prevRate * idleInterval);
                    double truncated = Math.max(0.0d, Math.abs(candidate) - (effectiveRate * this.lambdaL1));
                    double updated = Math.signum(candidate) * truncated
                            * Math.pow(1.0d - this.lambdaL2, effectiveRate);

                    // Compare the magnitude: a signed comparison would drop every negative weight.
                    if (Math.abs(updated) < this.EPS) {
                        weights.remove(feature);
                    } else {
                        weights.setCount(feature, updated);
                    }

                    timeStep++;
                    if (limited && timeStep > maxIterations) {
                        sayln("Stochastic Optimization complete.  Stopped after max iterations");
                        // Return rather than break so the remaining batches and passes do not run.
                        return weights;
                    }
                    // String.format, not System.out.format: the latter prints to stdout
                    // regardless of 'quiet' and returns the PrintStream, not the text.
                    sayln(String.format("Iter %d \t batch: %d \t time=%.2f \t obj=%.4f",
                            pass, timeStep, total.report() / 1000.0d, totalObjValue));
                }
            }
        }
        return weights;
    }

    /**
     * Draws {@code size} data indices uniformly at random, with replacement,
     * from {@code [0, function.dataSize())}.
     */
    private int[] getSample(F function, int size) {
        int[] sample = new int[size];
        int dataSize = function.dataSize();
        for (int k = 0; k < size; k++) {
            sample[k] = this.randGenerator.nextInt(dataSize);
        }
        return sample;
    }

    /** Returns a name built from the batch size, eta and both regularization coefficients. */
    protected String getName() {
        return "SparseAdaGrad_batchsize" + this.batchSize + "_eta" + nf.format(this.eta) + "_lambdaL1" + nf.format(this.lambdaL1) + "_lambdaL2" + nf.format(this.lambdaL2);
    }

    /** Logs {@code str} at info level unless {@link #quiet} is set. */
    protected void sayln(String str) {
        if (this.quiet) {
            return;
        }
        log.info(str);
    }
}
