package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/loglinear/learning/AbstractBatchOptimizer.class */
public abstract class AbstractBatchOptimizer {
    /** Logging channel shared by all optimizer instances. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(AbstractBatchOptimizer.class);

    /**
     * Constraints registered through {@link #addSparseConstraint} and {@link #addDenseConstraint}.
     * The training loop applies each of them to the derivative and to the weights on every iteration.
     */
    List<Constraint> constraints = new ArrayList<>();

    /**
     * A hard constraint that pins part of one component of a {@link ConcatVector}.
     *
     * <p>A <em>sparse</em> constraint pins a single index of a component to a fixed value; a
     * <em>dense</em> constraint pins the whole component to a fixed array. Which of the two a given
     * instance represents is decided by the constructor that created it.
     */
    private static class Constraint {
        /** Index of the constrained component inside the vector. */
        int component;
        /** {@code true} for a single-index constraint, {@code false} for a whole-component constraint. */
        boolean isSparse;
        /** Constrained index within the component (sparse constraints only). */
        int index;
        /** Value the constrained index is pinned to (sparse constraints only). */
        double value;
        /** Values the whole component is pinned to (dense constraints only). */
        double[] arr;

        /**
         * Creates a sparse constraint.
         *
         * @param component the component to constrain
         * @param index the index within the component to pin
         * @param value the value the index is pinned to
         */
        public Constraint(int component, int index, double value) {
            this.isSparse = true;
            this.component = component;
            this.index = index;
            this.value = value;
        }

        /**
         * Creates a dense constraint.
         *
         * @param component the component to constrain
         * @param arr the values the whole component is pinned to
         */
        public Constraint(int component, double[] arr) {
            // Must be cleared explicitly: otherwise the dense array is ignored and the constraint is
            // applied as a sparse constraint on index 0 with value 0.0.
            this.isSparse = false;
            this.component = component;
            this.arr = arr;
        }

        /**
         * Overwrites the constrained part of {@code concatVector} with the pinned value(s).
         *
         * @param concatVector the weight vector to modify in place
         */
        public void applyToWeights(ConcatVector concatVector) {
            if (this.isSparse) {
                concatVector.setSparseComponent(this.component, this.index, this.value);
            } else {
                concatVector.setDenseComponent(this.component, this.arr);
            }
        }

        /**
         * Zeroes the constrained part of a derivative so that a gradient step cannot move it.
         *
         * @param concatVector the derivative vector to modify in place
         */
        public void applyToDerivative(ConcatVector concatVector) {
            if (this.isSparse) {
                concatVector.setSparseComponent(this.component, this.index, 0.0d);
            } else {
                // The dense component is replaced by a single-element zero array.
                concatVector.setDenseComponent(this.component, new double[]{0.0d});
            }
        }
    }

    /**
     * Runnable that accumulates the log-likelihood and derivative contribution of one slice of the
     * dataset. Results are left in {@link #localLogLikelihood} and {@link #localDerivative} for the
     * creator to collect after the thread has been joined.
     */
    private static class GradientWorker<T> implements Runnable {
        /** Derivative accumulated over this worker's queue; starts as an empty clone of the weights. */
        ConcatVector localDerivative;
        /** The training worker that spawned this worker; its {@code isFinished} flag aborts the run. */
        TrainingWorker mainWorker;
        int threadIdx;
        int numThreads;
        /** The data items this worker is responsible for. */
        List<T> queue;
        AbstractDifferentiableFunction<T> fn;
        ConcatVector weights;
        /** Sum of the values returned by {@code fn.getSummaryForInstance} over the queue. */
        double localLogLikelihood = 0.0d;
        /** Id of the thread running this worker; assigned by the creator before the thread starts. */
        long jvmThreadId = 0;
        /** Wall-clock time (ms) at which the queue was exhausted; stays 0 if the run was aborted. */
        long finishedAtTime = 0;
        /** Thread CPU time consumed by a complete run; stays 0 if the run was aborted. */
        long cpuTimeRequired = 0;

        /**
         * @param mainWorker the owning training worker
         * @param threadIdx index of this worker among its siblings
         * @param numThreads total number of sibling workers
         * @param queue the data items to process
         * @param fn the function evaluated on each item
         * @param weights the weights at which {@code fn} is evaluated
         */
        public GradientWorker(TrainingWorker<T> mainWorker, int threadIdx, int numThreads, List<T> queue, AbstractDifferentiableFunction<T> fn, ConcatVector weights) {
            this.mainWorker = mainWorker;
            this.threadIdx = threadIdx;
            this.numThreads = numThreads;
            this.queue = queue;
            this.fn = fn;
            this.weights = weights;
            this.localDerivative = weights.newEmptyClone();
        }

        /**
         * Processes every queued item, bailing out early (without recording timing statistics) as
         * soon as the owning training worker is flagged as finished.
         */
        @Override
        public void run() {
            long cpuTimeAtStart = ManagementFactory.getThreadMXBean().getThreadCpuTime(jvmThreadId);
            for (T datum : queue) {
                localLogLikelihood += fn.getSummaryForInstance(datum, weights, localDerivative);
                if (mainWorker.isFinished) {
                    return;
                }
            }
            finishedAtTime = System.currentTimeMillis();
            long cpuTimeAtEnd = ManagementFactory.getThreadMXBean().getThreadCpuTime(jvmThreadId);
            cpuTimeRequired = cpuTimeAtEnd - cpuTimeAtStart;
        }
    }

    /**
     * Base type for optimizer-specific state that lives for the duration of one training run.
     * An instance is obtained from {@code getFreshOptimizationState} when a training worker is
     * created and is handed back to {@code updateWeights} on every iteration. This base class
     * declares no members of its own.
     */
    protected abstract class OptimizationState {
        /* Decompiler note: this constructor's access modifier was changed from protected. */
        public OptimizationState() {
        }
    }

    /**
     * Background task that repeatedly computes the regularized objective and its derivative over a
     * dataset and hands them to {@link AbstractBatchOptimizer#updateWeights} until the derivative
     * norm falls below the convergence threshold, {@code updateWeights} asks to stop, or
     * {@link #isFinished} is set from another thread.
     *
     * <p>Decompiler note: this class's access modifier was changed from private.
     */
    public class TrainingWorker<T> implements Runnable {
        /** Working copy of the weights, updated in place on every iteration. */
        ConcatVector weights;
        OptimizationState optimizationState;
        /**
         * Set to {@code true} when training has ended, or by another thread to request an early
         * stop. Volatile because it is written and polled by different threads.
         */
        volatile boolean isFinished = false;
        /** Whether gradient computation is spread over several threads. */
        boolean useThreads;
        T[] dataset;
        AbstractDifferentiableFunction<T> fn;
        double l2regularization;
        /** Threshold compared against the squared norm (self dot product) of the derivative. */
        double convergenceDerivativeNorm;
        boolean quiet;
        /** Monitor that is notified when this worker terminates. */
        final Object naturalTerminationBarrier;

        /**
         * @param dataset the training items
         * @param fn the function whose per-item summaries are accumulated
         * @param initialWeights starting weights; deep-cloned, so the argument itself is not updated
         * @param l2regularization L2 regularization coefficient
         * @param convergenceDerivativeNorm squared-derivative-norm threshold below which training stops
         * @param quiet suppresses per-iteration logging when {@code true}
         */
        public TrainingWorker(T[] dataset, AbstractDifferentiableFunction<T> fn, ConcatVector initialWeights, double l2regularization, double convergenceDerivativeNorm, boolean quiet) {
            this.useThreads = Runtime.getRuntime().availableProcessors() > 1;
            this.naturalTerminationBarrier = new Object();
            this.optimizationState = AbstractBatchOptimizer.this.getFreshOptimizationState(initialWeights);
            this.weights = initialWeights.deepClone();
            this.dataset = dataset;
            this.fn = fn;
            this.l2regularization = l2regularization;
            this.convergenceDerivativeNorm = convergenceDerivativeNorm;
            this.quiet = quiet;
        }

        /**
         * Returns a relative cost estimate for one item: the summed combinatorial neighbor-state
         * counts of its factors for a {@link GraphicalModel}, or 1 for anything else.
         */
        private int estimateRelativeRuntime(T datum) {
            if (!(datum instanceof GraphicalModel)) {
                return 1;
            }
            int cost = 0;
            for (GraphicalModel.Factor factor : ((GraphicalModel) datum).factors) {
                cost += factor.featuresTable.combinatorialNeighborStatesCount();
            }
            return cost;
        }

        /**
         * Greedily distributes the dataset over {@code numThreads} queues, always appending the next
         * item to the queue with the smallest estimated load so far.
         */
        private List<List<T>> buildBalancedQueues(int numThreads) {
            List<List<T>> queues = new ArrayList<List<T>>(numThreads);
            for (int i = 0; i < numThreads; i++) {
                queues.add(new ArrayList<T>());
            }
            int[] estimatedLoad = new int[numThreads];
            for (T datum : this.dataset) {
                int cost = estimateRelativeRuntime(datum);
                int lightest = 0;
                for (int i = 1; i < numThreads; i++) {
                    if (estimatedLoad[i] < estimatedLoad[lightest]) {
                        lightest = i;
                    }
                }
                estimatedLoad[lightest] += cost;
                queues.get(lightest).add(datum);
            }
            return queues;
        }

        /**
         * Runs training to completion. Whatever way training ends (convergence, stop request, or an
         * exception), {@link #isFinished} is set and waiters on {@link #naturalTerminationBarrier}
         * are woken.
         */
        @Override // java.lang.Runnable
        public void run() {
            try {
                trainUntilFinished();
            } finally {
                // Publish the flag *before* notifying and under the monitor, so a waiter that
                // re-checks the flag after waking up cannot see a stale "not finished".
                synchronized (this.naturalTerminationBarrier) {
                    this.isFinished = true;
                    this.naturalTerminationBarrier.notifyAll();
                }
            }
        }

        /** The training loop proper; returns when training should end. */
        private void trainUntilFinished() {
            int numThreads = Math.max(1, Runtime.getRuntime().availableProcessors());
            List<List<T>> queues = new ArrayList<List<T>>();
            if (this.useThreads) {
                queues = buildBalancedQueues(numThreads);
            }
            Random random = new Random();

            while (!this.isFinished) {
                long iterationStart = System.currentTimeMillis();
                long threadWaitMillis = 0;
                ConcatVector derivative = this.weights.newEmptyClone();
                double logLikelihood = 0.0d;

                if (this.useThreads) {
                    List<GradientWorker<T>> workers = new ArrayList<GradientWorker<T>>(numThreads);
                    List<Thread> threads = new ArrayList<Thread>(numThreads);
                    for (int i = 0; i < numThreads; i++) {
                        GradientWorker<T> worker = new GradientWorker<T>(this, i, numThreads, queues.get(i), this.fn, this.weights);
                        Thread thread = new Thread(worker);
                        worker.jvmThreadId = thread.getId();
                        workers.add(worker);
                        threads.add(thread);
                        thread.start();
                    }

                    long minFinishedAt = Long.MAX_VALUE;
                    long maxFinishedAt = Long.MIN_VALUE;
                    long minCpuTime = Long.MAX_VALUE;
                    long maxCpuTime = Long.MIN_VALUE;
                    int slowest = 0;
                    int fastest = 0;
                    for (int i = 0; i < numThreads; i++) {
                        try {
                            threads.get(i).join();
                        } catch (InterruptedException e) {
                            throw new RuntimeInterruptedException(e);
                        }
                        GradientWorker<T> worker = workers.get(i);
                        logLikelihood += worker.localLogLikelihood;
                        derivative.addVectorInPlace(worker.localDerivative, 1.0d);
                        minFinishedAt = Math.min(minFinishedAt, worker.finishedAtTime);
                        maxFinishedAt = Math.max(maxFinishedAt, worker.finishedAtTime);
                        if (worker.cpuTimeRequired < minCpuTime) {
                            fastest = i;
                            minCpuTime = worker.cpuTimeRequired;
                        }
                        if (worker.cpuTimeRequired > maxCpuTime) {
                            slowest = i;
                            maxCpuTime = worker.cpuTimeRequired;
                        }
                    }
                    threadWaitMillis = maxFinishedAt - minFinishedAt;

                    // Move a share of the slowest queue to the fastest one, proportional to half the
                    // relative CPU-time gap. The fraction must be computed in floating point (long
                    // division truncates it to 0 and throws when maxCpuTime is 0) and is clamped to
                    // [0, 1] so that no more items are moved than the queue holds.
                    double imbalance = 0.0d;
                    if (maxCpuTime > 0) {
                        imbalance = Math.min(1.0d, Math.max(0.0d, (double) (maxCpuTime - minCpuTime) / (double) maxCpuTime));
                    }
                    List<T> slowQueue = queues.get(slowest);
                    List<T> fastQueue = queues.get(fastest);
                    int itemsToMove = (int) Math.floor(slowQueue.size() * imbalance * 0.5d);
                    for (int i = 0; i < itemsToMove; i++) {
                        int pick = random.nextInt(slowQueue.size());
                        fastQueue.add(slowQueue.remove(pick));
                    }

                    if (this.isFinished) {
                        return;
                    }
                } else {
                    for (T datum : this.dataset) {
                        assert datum != null;
                        logLikelihood += this.fn.getSummaryForInstance(datum, this.weights, derivative);
                        if (this.isFinished) {
                            return;
                        }
                    }
                }

                // Average over the dataset, then add the L2 penalty and its gradient.
                derivative.mapInPlace(entry -> entry / this.dataset.length);
                long gradientMillis = System.currentTimeMillis() - iterationStart;
                double objective = (logLikelihood / this.dataset.length) - (this.l2regularization * this.weights.dotProduct(this.weights));
                derivative.addVectorInPlace(this.weights, (-2.0d) * this.l2regularization);
                for (Constraint constraint : AbstractBatchOptimizer.this.constraints) {
                    constraint.applyToDerivative(derivative);
                }

                double derivativeNorm = derivative.dotProduct(derivative);
                if (derivativeNorm >= this.convergenceDerivativeNorm) {
                    if (!this.quiet) {
                        AbstractBatchOptimizer.log.info("[" + gradientMillis + " ms, threads waiting " + threadWaitMillis + " ms]");
                    }
                    boolean stopRequested = AbstractBatchOptimizer.this.updateWeights(this.weights, derivative, objective, this.optimizationState, this.quiet);
                    for (Constraint constraint : AbstractBatchOptimizer.this.constraints) {
                        constraint.applyToWeights(this.weights);
                    }
                    if (stopRequested) {
                        return;
                    }
                } else {
                    if (!this.quiet) {
                        AbstractBatchOptimizer.log.info("Derivative norm " + derivativeNorm + " < " + this.convergenceDerivativeNorm + ": quitting");
                    }
                    // Converged: leave the loop instead of iterating forever on unchanged weights.
                    return;
                }
            }
        }
    }

    /**
     * Trains with default settings: an initial {@code new ConcatVector(0)}, no L2 regularization,
     * a convergence threshold of 1e-5, and non-quiet (interactive) mode.
     *
     * @param tArr the training items
     * @param abstractDifferentiableFunction the function whose per-item summaries are optimized
     * @return the trained weights
     */
    public <T> ConcatVector optimize(T[] tArr, AbstractDifferentiableFunction<T> abstractDifferentiableFunction) {
        final ConcatVector initialWeights = new ConcatVector(0);
        final double l2regularization = 0.0d;
        final double convergenceDerivativeNorm = 1.0E-5d;
        final boolean quiet = false;
        return optimize(tArr, abstractDifferentiableFunction, initialWeights, l2regularization, convergenceDerivativeNorm, quiet);
    }

    /**
     * Trains on a background thread and blocks until training ends.
     *
     * <p>In quiet mode this simply waits for the training thread to terminate. Otherwise stdin is
     * polled while training runs, and any available input stops training early; in that case the
     * weights are returned immediately after the stop has been requested.
     *
     * @param tArr the training items
     * @param abstractDifferentiableFunction the function whose per-item summaries are optimized
     * @param concatVector initial weights (the training worker works on a deep clone of them)
     * @param d L2 regularization coefficient
     * @param d2 threshold on the squared derivative norm below which training stops
     * @param z quiet mode: no per-iteration logging and no stdin polling
     * @return the training worker's weights
     * @throws RuntimeInterruptedException if the calling thread is interrupted while waiting
     */
    public <T> ConcatVector optimize(T[] tArr, AbstractDifferentiableFunction<T> abstractDifferentiableFunction, ConcatVector concatVector, double d, double d2, boolean z) {
        if (z) {
            log.info("[Beginning quiet training]");
        } else {
            log.info("\n**************\nBeginning training\n");
        }
        TrainingWorker<T> trainingWorker = new TrainingWorker<T>(tArr, abstractDifferentiableFunction, concatVector, d, d2, z);
        Thread trainingThread = new Thread(trainingWorker);
        trainingThread.start();

        if (z) {
            // Joining the thread cannot miss the termination signal, unlike checking a flag outside
            // the monitor and then calling an untimed wait().
            try {
                trainingThread.join();
            } catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
            log.info("[Quiet training complete]");
            return trainingWorker.weights;
        }

        log.info("NOTE: you can press any key (and maybe ENTER afterwards to jog stdin) to terminate learning early.");
        log.info("The convergence criteria are quite aggressive if left uninterrupted, and will run for a while");
        log.info("if left to their own devices.\n");

        // System.in is deliberately not closed: it belongs to the process, not to this call.
        BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
        boolean stdinReadable = true;
        final long pollMillis = 50L;
        while (!trainingWorker.isFinished && trainingThread.isAlive()) {
            if (stdinReadable) {
                try {
                    if (stdin.ready()) {
                        log.info("received quit command: quitting");
                        log.info("training completed by interruption");
                        trainingWorker.isFinished = true;
                        return trainingWorker.weights;
                    }
                } catch (IOException e) {
                    // Stop polling a broken stdin rather than reporting the same failure on every pass.
                    stdinReadable = false;
                    log.info("could not poll stdin for a quit command (" + e + "); waiting for training to finish");
                }
            }
            // Sleep between polls instead of spinning; returns early once training ends.
            try {
                trainingThread.join(pollMillis);
            } catch (InterruptedException e) {
                throw new RuntimeInterruptedException(e);
            }
        }
        log.info("training completed without interruption");
        return trainingWorker.weights;
    }

    /**
     * Registers a constraint that pins a single index of one component to a fixed value.
     *
     * @param i the component to constrain
     * @param i2 the index within that component
     * @param d the value the index is pinned to
     */
    public void addSparseConstraint(int i, int i2, double d) {
        Constraint sparseConstraint = new Constraint(i, i2, d);
        constraints.add(sparseConstraint);
    }

    /**
     * Registers a constraint built from a component index and an array of values.
     *
     * @param i the component to constrain
     * @param dArr the values supplied for that component
     */
    public void addDenseConstraint(int i, double[] dArr) {
        Constraint denseConstraint = new Constraint(i, dArr);
        constraints.add(denseConstraint);
    }

    /**
     * Performs one optimization step. The training loop calls this once per iteration, and only
     * while the squared derivative norm is at or above the convergence threshold. The loop reads no
     * new vector back, so implementations are expected to update {@code concatVector} in place.
     *
     * @param concatVector the training worker's current weights
     * @param concatVector2 the derivative: averaged over the dataset, with the L2 term added and
     *     constrained entries zeroed
     * @param d the averaged summary value minus the L2 penalty
     * @param optimizationState the state object returned by {@code getFreshOptimizationState} for this run
     * @param z quiet flag passed through from the caller of {@code optimize}
     * @return {@code true} to end training after this step, {@code false} to continue
     */
    public abstract boolean updateWeights(ConcatVector concatVector, ConcatVector concatVector2, double d, OptimizationState optimizationState, boolean z);

    /**
     * Creates the per-run optimizer state. Called once when a training worker is constructed, with
     * the initial weights that were passed to {@code optimize}.
     *
     * @param concatVector the initial weights of the run
     * @return the state object that will be passed to every {@code updateWeights} call of that run
     */
    protected abstract OptimizationState getFreshOptimizationState(ConcatVector concatVector);
}
