package edu.stanford.nlp.stats;

import edu.stanford.nlp.util.Generics;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/stats/Distributions.class */
/**
 * Static helpers for comparing and combining two {@link Distribution} objects over the same key space.
 *
 * <p>Each binary operation works on the union of the keys explicitly stored in the two distributions'
 * counters. Probability mass not assigned to those keys ("remaining mass" = 1 minus the stored mass) is
 * handled in aggregate. {@link #klDivergence} treats it as spread evenly over the
 * {@code getNumberOfKeys() - |union|} keys that appear in neither counter.
 */
public class Distributions {

    /** Not instantiable: this class only exposes static helpers. */
    private Distributions() {
    }

    /**
     * Returns the union of the keys explicitly stored in the counters of two distributions, after checking
     * that both distributions declare the same number of keys.
     *
     * @param distribution first distribution
     * @param distribution2 second distribution; must report the same {@code getNumberOfKeys()}
     * @param <K> key type
     * @return a new set containing every key stored in either counter
     * @throws RuntimeException if the declared key counts differ, or if the union of stored keys is larger
     *     than the declared key count
     */
    protected static <K> Set<K> getSetOfAllKeys(Distribution<K> distribution, Distribution<K> distribution2) {
        int numberOfKeys = distribution.getNumberOfKeys();
        if (numberOfKeys != distribution2.getNumberOfKeys()) {
            throw new RuntimeException("Tried to compare two Distribution<K> objects but d1.numberOfKeys != d2.numberOfKeys");
        }
        Set<K> allKeys = Generics.newHashSet(distribution.getCounter().keySet());
        allKeys.addAll(distribution2.getCounter().keySet());
        // The union of both counters' keys can never legitimately exceed the declared key space.
        if (allKeys.size() > numberOfKeys) {
            throw new RuntimeException("Tried to compare two Distribution<K> objects but d1.counter union d2.counter > numberOfKeys");
        }
        return allKeys;
    }

    /**
     * Computes the overlap of two distributions, i.e. the sum over all keys of {@code min(p1(k), p2(k))}.
     *
     * <p>Keys stored in neither counter contribute {@code min(remaining1, remaining2)}, where each remaining
     * mass is 1 minus the probability of the stored keys.
     *
     * @param distribution first distribution
     * @param distribution2 second distribution over the same number of keys
     * @param <K> key type
     * @return the overlap
     * @throws RuntimeException propagated from {@link #getSetOfAllKeys} on incompatible key spaces
     */
    public static <K> double overlap(Distribution<K> distribution, Distribution<K> distribution2) {
        double result = 0.0;
        double remainingMass1 = 1.0;
        double remainingMass2 = 1.0;
        for (K key : getSetOfAllKeys(distribution, distribution2)) {
            double p1 = distribution.probabilityOf(key);
            double p2 = distribution2.probabilityOf(key);
            remainingMass1 -= p1;
            remainingMass2 -= p2;
            result += Math.min(p1, p2);
        }
        return result + Math.min(remainingMass1, remainingMass2);
    }

    /**
     * Returns the mixture {@code d * distribution + (1 - d) * distribution2}, computed key by key over the
     * keys stored in either counter.
     *
     * <p>The result is built with {@link Distribution#getDistributionFromPartiallySpecifiedCounter} using the
     * shared key count. NOTE(review): that factory presumably assigns the leftover mass to the unstored
     * keys; confirm. {@code d} is not range-checked. Values outside [0, 1] do not describe a valid mixture.
     *
     * @param distribution distribution weighted by {@code d}
     * @param d mixing weight for {@code distribution}
     * @param distribution2 distribution weighted by {@code 1 - d}
     * @param <K> key type
     * @return the mixed distribution
     * @throws RuntimeException propagated from {@link #getSetOfAllKeys} on incompatible key spaces
     */
    public static <K> Distribution<K> weightedAverage(Distribution<K> distribution, double d, Distribution<K> distribution2) {
        double weight2 = 1.0 - d;
        Set<K> allKeys = getSetOfAllKeys(distribution, distribution2);
        int numberOfKeys = distribution.getNumberOfKeys();
        ClassicCounter<K> mixture = new ClassicCounter<>();
        for (K key : allKeys) {
            mixture.setCount(key, (distribution.probabilityOf(key) * d) + (distribution2.probabilityOf(key) * weight2));
        }
        return Distribution.getDistributionFromPartiallySpecifiedCounter(mixture, numberOfKeys);
    }

    /**
     * Returns the equal-weight mixture of two distributions.
     *
     * @param distribution first distribution
     * @param distribution2 second distribution
     * @param <K> key type
     * @return {@code weightedAverage(distribution, 0.5, distribution2)}
     */
    public static <K> Distribution<K> average(Distribution<K> distribution, Distribution<K> distribution2) {
        return weightedAverage(distribution, 0.5, distribution2);
    }

    /**
     * Computes the Kullback-Leibler divergence D(distribution || distribution2) in bits (log base 2).
     *
     * <p>Stored keys whose probability under {@code distribution} is below 1e-10 contribute nothing. The
     * mass not assigned to stored keys is treated as spread evenly over the keys stored in neither counter,
     * and their combined contribution is added in one step.
     *
     * <p>Returns {@link Double#POSITIVE_INFINITY}, after printing a diagnostic to {@code System.out}, as
     * soon as {@code distribution} puts mass where {@code distribution2} has none.
     *
     * @param distribution the "from" distribution P
     * @param distribution2 the "to" distribution Q
     * @param <K> key type
     * @return D(P || Q) in bits, possibly {@code +inf}
     * @throws RuntimeException propagated from {@link #getSetOfAllKeys} on incompatible key spaces
     */
    public static <K> double klDivergence(Distribution<K> distribution, Distribution<K> distribution2) {
        final double epsilon = 1.0e-10;
        final double log2 = Math.log(2.0);
        Set<K> allKeys = getSetOfAllKeys(distribution, distribution2);
        int numKeysRemaining = distribution.getNumberOfKeys();
        double result = 0.0;
        double assignedMass1 = 0.0;
        double assignedMass2 = 0.0;
        for (K key : allKeys) {
            double p1 = distribution.probabilityOf(key);
            double p2 = distribution2.probabilityOf(key);
            numKeysRemaining--;
            assignedMass1 += p1;
            assignedMass2 += p2;
            // Negligible mass under P is skipped (0 * log(0/q) is taken as 0).
            if (p1 >= epsilon) {
                double logFraction = Math.log(p1 / p2);
                if (logFraction == Double.POSITIVE_INFINITY) {
                    System.out.println("Distributions.klDivergence returning +inf: p1=" + p1 + ", p2=" + p2);
                    System.out.flush();
                    return Double.POSITIVE_INFINITY;
                }
                result += p1 * (logFraction / log2);
            }
        }
        if (numKeysRemaining != 0) {
            double p1 = (1.0 - assignedMass1) / numKeysRemaining;
            if (p1 > epsilon) {
                double p2 = (1.0 - assignedMass2) / numKeysRemaining;
                // Rounding can leave Q's leftover mass at -0.0 or slightly negative. Math.log of the resulting
                // negative ratio would be NaN and silently corrupt the sum, so treat it like zero mass (+inf).
                double logFraction = p2 > 0.0 ? Math.log(p1 / p2) : Double.POSITIVE_INFINITY;
                if (logFraction == Double.POSITIVE_INFINITY) {
                    System.out.println("Distributions.klDivergence (remaining mass) returning +inf: p1=" + p1 + ", p2=" + p2);
                    System.out.flush();
                    return Double.POSITIVE_INFINITY;
                }
                result += numKeysRemaining * p1 * (logFraction / log2);
            }
        }
        return result;
    }

    /**
     * Computes the Jensen-Shannon divergence: the mean of D(P || M) and D(Q || M), where M is the
     * equal-weight mixture of P and Q. Measured in bits, like {@link #klDivergence}.
     *
     * @param distribution P
     * @param distribution2 Q
     * @param <K> key type
     * @return the Jensen-Shannon divergence
     */
    public static <K> double jensenShannonDivergence(Distribution<K> distribution, Distribution<K> distribution2) {
        Distribution<K> average = average(distribution, distribution2);
        return (klDivergence(distribution, average) + klDivergence(distribution2, average)) / 2.0;
    }

    /**
     * Computes the skew divergence D(P || d * Q + (1 - d) * P), which stays finite for {@code d < 1}
     * because the second argument always contains some of P's mass.
     *
     * @param distribution P
     * @param distribution2 Q
     * @param d weight given to Q in the smoothed second argument
     * @param <K> key type
     * @return the skew divergence in bits
     */
    public static <K> double skewDivergence(Distribution<K> distribution, Distribution<K> distribution2, double d) {
        return klDivergence(distribution, weightedAverage(distribution2, d, distribution));
    }

    /**
     * Computes the information radius D(P || M) + D(Q || M), where M is the equal-weight mixture of P and Q.
     * This is the sum, rather than the mean, of the two terms used by {@link #jensenShannonDivergence}.
     *
     * @param distribution P
     * @param distribution2 Q
     * @param <K> key type
     * @return the information radius in bits
     */
    public static <K> double informationRadius(Distribution<K> distribution, Distribution<K> distribution2) {
        Distribution<K> average = average(distribution, distribution2);
        return klDivergence(distribution, average) + klDivergence(distribution2, average);
    }
}
