package edu.stanford.nlp.stats;

import java.util.Iterator;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/stats/Dirichlet.class */
/**
 * A Dirichlet distribution over {@link Multinomial}s, parameterized by a
 * {@link Counter} that maps each outcome to its (non-negative) parameter.
 *
 * <p>Samples are produced by drawing one Gamma variate per outcome and
 * normalizing the draws so that they sum to one.
 *
 * <p>NOTE: {@link #probabilityOf(Multinomial)} and
 * {@link #logProbabilityOf(Multinomial)} are not implemented and always
 * return {@code 0.0}.
 *
 * @param <E> the type of the outcomes of the multinomials being modeled
 */
public class Dirichlet<E> implements ConjugatePrior<Multinomial<E>, E> {
    private static final long serialVersionUID = 1L;

    /** Private copy of the parameters; never exposed, so callers cannot mutate it. */
    private final Counter<E> parameters;

    /**
     * Creates a Dirichlet from the given parameters. The counter is copied, so
     * later changes to {@code counter} do not affect this distribution.
     *
     * @param counter parameter for each outcome
     * @throws IllegalArgumentException if any parameter is negative or the
     *         total mass is not strictly positive
     */
    public Dirichlet(Counter<E> counter) {
        checkParameters(counter);
        this.parameters = new ClassicCounter<>(counter);
    }

    /**
     * Validates that every parameter is non-negative and that the total mass
     * is strictly positive.
     */
    private void checkParameters(Counter<E> counter) {
        for (E key : counter.keySet()) {
            if (counter.getCount(key) < 0.0d) {
                throw new IllegalArgumentException("Parameters must be non-negative!");
            }
        }
        if (counter.totalCount() <= 0.0d) {
            throw new IllegalArgumentException("Parameters must have positive mass!");
        }
    }

    /**
     * Draws a multinomial from this Dirichlet.
     *
     * @param random source of randomness
     * @return a multinomial over the same outcomes as this distribution's parameters
     */
    @Override // edu.stanford.nlp.stats.ProbabilityDistribution
    public Multinomial<E> drawSample(Random random) {
        return drawSample(random, this.parameters);
    }

    /**
     * Draws a multinomial from the Dirichlet described by {@code counter}
     * without constructing a {@code Dirichlet} instance (and therefore without
     * validating the parameters).
     *
     * @param random source of randomness
     * @param counter Dirichlet parameter for each outcome
     * @param <F> the outcome type
     * @return a multinomial whose weights are the normalized Gamma draws
     */
    public static <F> Multinomial<F> drawSample(Random random, Counter<F> counter) {
        ClassicCounter<F> weights = new ClassicCounter<>();
        double sum = 0.0d;
        for (F key : counter.keySet()) {
            double draw = Gamma.drawSample(random, counter.getCount(key));
            sum += draw;
            weights.setCount(key, draw);
        }
        // Normalize in place; only values of existing keys are overwritten.
        for (F key : weights.keySet()) {
            weights.setCount(key, weights.getCount(key) / sum);
        }
        return new Multinomial<>(weights);
    }

    /**
     * Array variant of {@link #drawSample(Random, Counter)}: draws one Gamma
     * variate per entry of {@code dArr} and divides each by their sum.
     *
     * @param random source of randomness
     * @param dArr Dirichlet parameters
     * @return a new array of the same length holding the normalized draws
     */
    public static double[] drawSample(Random random, double[] dArr) {
        double[] sample = new double[dArr.length];
        double sum = 0.0d;
        for (int i = 0; i < dArr.length; i++) {
            double draw = Gamma.drawSample(random, dArr[i]);
            sum += draw;
            sample[i] = draw;
        }
        for (int i = 0; i < sample.length; i++) {
            sample[i] /= sum;
        }
        return sample;
    }

    /**
     * Draws from a Beta distribution by sampling a two-outcome Dirichlet and
     * returning the probability assigned to the first outcome.
     *
     * @param d parameter of the outcome whose probability is returned
     * @param d2 parameter of the other outcome
     * @param random source of randomness
     * @return the sampled probability of the first outcome
     * @throws IllegalArgumentException if either parameter is negative or both are zero
     */
    public static double sampleBeta(double d, double d2, Random random) {
        ClassicCounter<Boolean> betaParameters = new ClassicCounter<>();
        betaParameters.setCount(Boolean.TRUE, d);
        betaParameters.setCount(Boolean.FALSE, d2);
        return new Dirichlet<>(betaParameters).drawSample(random).probabilityOf(Boolean.TRUE);
    }

    /**
     * Returns the parameter of {@code e} divided by the total parameter mass.
     */
    @Override // edu.stanford.nlp.stats.ConjugatePrior
    public double getPredictiveProbability(E e) {
        return this.parameters.getCount(e) / this.parameters.totalCount();
    }

    /**
     * Returns the natural log of {@link #getPredictiveProbability(Object)}.
     */
    @Override // edu.stanford.nlp.stats.ConjugatePrior
    public double getPredictiveLogProbability(E e) {
        return Math.log(getPredictiveProbability(e));
    }

    /**
     * Returns a new Dirichlet whose parameters are this distribution's
     * parameters plus the given counts. This instance is left unchanged.
     *
     * @param counter counts to add to the parameters
     * @return the updated distribution
     */
    @Override // edu.stanford.nlp.stats.ConjugatePrior
    public Dirichlet<E> getPosteriorDistribution(Counter<E> counter) {
        ClassicCounter<E> posterior = new ClassicCounter<>(this.parameters);
        Counters.addInPlace(posterior, counter);
        return new Dirichlet<>(posterior);
    }

    /**
     * Returns (parameter of {@code e} + count of {@code e}) divided by
     * (total parameter mass + total count), without building the posterior.
     */
    @Override // edu.stanford.nlp.stats.ConjugatePrior
    public double getPosteriorPredictiveProbability(Counter<E> counter, E e) {
        double numerator = this.parameters.getCount(e) + counter.getCount(e);
        double denominator = this.parameters.totalCount() + counter.totalCount();
        return numerator / denominator;
    }

    /**
     * Returns the natural log of
     * {@link #getPosteriorPredictiveProbability(Counter, Object)}.
     */
    @Override // edu.stanford.nlp.stats.ConjugatePrior
    public double getPosteriorPredictiveLogProbability(Counter<E> counter, E e) {
        return Math.log(getPosteriorPredictiveProbability(counter, e));
    }

    /**
     * Not implemented: always returns {@code 0.0} regardless of the argument.
     */
    @Override // edu.stanford.nlp.stats.ProbabilityDistribution
    public double probabilityOf(Multinomial<E> multinomial) {
        return 0.0d;
    }

    /**
     * Computes the sum over {@code i} of
     * {@code (dArr2[i] - 1) * log(dArr[i])}, skipping entries where
     * {@code dArr[i]} is not strictly positive (so no log of zero is taken).
     *
     * @param dArr values whose logarithm is taken; must have at least
     *        {@code dArr2.length} entries
     * @param dArr2 exponents-plus-one; its length bounds the loop
     * @return the accumulated sum
     */
    public static double unnormalizedLogProbabilityOf(double[] dArr, double[] dArr2) {
        double total = 0.0d;
        for (int i = 0; i < dArr2.length; i++) {
            if (dArr[i] > 0.0d) {
                total += (dArr2[i] - 1.0d) * Math.log(dArr[i]);
            }
        }
        return total;
    }

    /**
     * Not implemented: always returns {@code 0.0} regardless of the argument.
     */
    @Override // edu.stanford.nlp.stats.ProbabilityDistribution
    public double logProbabilityOf(Multinomial<E> multinomial) {
        return 0.0d;
    }

    /**
     * Renders the parameters via
     * {@code Counters.toBiggestValuesFirstString(parameters, 50)}.
     */
    @Override
    public String toString() {
        return Counters.toBiggestValuesFirstString(this.parameters, 50);
    }
}
