package edu.stanford.nlp.loglinear.inference;

import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.loglinear.model.NDArrayDoubles;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.BiFunction;

/* loaded from: input_file:edu/stanford/nlp/loglinear/inference/TableFactor.class */
public class TableFactor extends NDArrayDoubles {
    public int[] neighborIndices;
    public static final boolean USE_EXP_APPROX = false;
    static final /* synthetic */ boolean $assertionsDisabled;

    public TableFactor(ConcatVector concatVector, GraphicalModel.Factor factor) {
        super(factor.featuresTable.getDimensions());
        this.neighborIndices = factor.neigborIndices;
        Iterator<int[]> fastPassByReferenceIterator = factor.featuresTable.fastPassByReferenceIterator();
        int[] next = fastPassByReferenceIterator.next();
        while (true) {
            setAssignmentLogValue(next, factor.featuresTable.getAssignmentValue(next).get().dotProduct(concatVector));
            if (!fastPassByReferenceIterator.hasNext()) {
                return;
            } else {
                fastPassByReferenceIterator.next();
            }
        }
    }

    public static double exp(double d) {
        return Double.longBitsToDouble(((long) ((1512775.0d * d) + 1.072632447E9d)) << 32);
    }

    public TableFactor(ConcatVector concatVector, GraphicalModel.Factor factor, int[] iArr) {
        if (!$assertionsDisabled && iArr.length != factor.neigborIndices.length) {
            throw new AssertionError();
        }
        int i = 0;
        for (int i2 : iArr) {
            if (i2 == -1) {
                i++;
            }
        }
        this.neighborIndices = new int[i];
        this.dimensions = new int[i];
        int[] iArr2 = new int[i];
        int[] iArr3 = new int[factor.neigborIndices.length];
        int i3 = 0;
        for (int i4 = 0; i4 < factor.neigborIndices.length; i4++) {
            if (iArr[i4] == -1) {
                this.neighborIndices[i3] = factor.neigborIndices[i4];
                this.dimensions[i3] = factor.featuresTable.getDimensions()[i4];
                iArr2[i3] = i4;
                i3++;
            } else {
                iArr3[i4] = iArr[i4];
            }
        }
        if (!$assertionsDisabled && i3 != i) {
            throw new AssertionError();
        }
        this.values = new double[combinatorialNeighborStatesCount()];
        Iterator<int[]> it = iterator();
        while (it.hasNext()) {
            int[] next = it.next();
            for (int i5 = 0; i5 < next.length; i5++) {
                iArr3[iArr2[i5]] = next[i5];
            }
            setAssignmentLogValue(next, factor.featuresTable.getAssignmentValue(iArr3).get().dotProduct(concatVector));
        }
    }

    public TableFactor observe(int i, int i2) {
        return marginalize(i, 0.0d, (num, iArr) -> {
            return num.intValue() == i2 ? (d, d2) -> {
                return d2;
            } : (d3, d4) -> {
                return d3;
            };
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    public double[][] getSummedMarginals() {
        ?? r0 = new double[this.neighborIndices.length];
        for (int i = 0; i < this.neighborIndices.length; i++) {
            r0[i] = new double[getDimensions()[i]];
        }
        double[] dArr = new double[this.neighborIndices.length];
        for (int i2 = 0; i2 < this.neighborIndices.length; i2++) {
            dArr[i2] = new double[getDimensions()[i2]];
            for (int i3 = 0; i3 < dArr[i2].length; i3++) {
                dArr[i2][i3] = -4503599627370496;
            }
        }
        Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator();
        int[] next = fastPassByReferenceIterator.next();
        while (true) {
            double assignmentLogValue = getAssignmentLogValue(next);
            for (int i4 = 0; i4 < this.neighborIndices.length; i4++) {
                if (dArr[i4][next[i4]] < assignmentLogValue) {
                    dArr[i4][next[i4]] = assignmentLogValue;
                }
            }
            if (!fastPassByReferenceIterator.hasNext()) {
                break;
            }
            fastPassByReferenceIterator.next();
        }
        Iterator<int[]> fastPassByReferenceIterator2 = fastPassByReferenceIterator();
        int[] next2 = fastPassByReferenceIterator2.next();
        while (true) {
            double assignmentLogValue2 = getAssignmentLogValue(next2);
            for (int i5 = 0; i5 < this.neighborIndices.length; i5++) {
                double[] dArr2 = r0[i5];
                int i6 = next2[i5];
                dArr2[i6] = dArr2[i6] + Math.exp(assignmentLogValue2 - dArr[i5][next2[i5]]);
            }
            if (!fastPassByReferenceIterator2.hasNext()) {
                break;
            }
            fastPassByReferenceIterator2.next();
        }
        for (int i7 = 0; i7 < this.neighborIndices.length; i7++) {
            double d = 0.0d;
            for (int i8 = 0; i8 < r0[i7].length; i8++) {
                r0[i7][i8] = Math.exp(dArr[i7][i8]) * r0[i7][i8];
                d += r0[i7][i8];
            }
            if (Double.isInfinite(d)) {
                for (int i9 = 0; i9 < r0[i7].length; i9++) {
                    r0[i7][i9] = 1.0d / r0[i7].length;
                }
            } else {
                for (int i10 = 0; i10 < r0[i7].length; i10++) {
                    double[] dArr3 = r0[i7];
                    int i11 = i10;
                    dArr3[i11] = dArr3[i11] / d;
                }
            }
        }
        return r0;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    public double[][] getMaxedMarginals() {
        ?? r0 = new double[this.neighborIndices.length];
        for (int i = 0; i < this.neighborIndices.length; i++) {
            r0[i] = new double[getDimensions()[i]];
            for (int i2 = 0; i2 < r0[i].length; i2++) {
                r0[i][i2] = -4503599627370496;
            }
        }
        Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator();
        int[] next = fastPassByReferenceIterator.next();
        while (true) {
            double assignmentLogValue = getAssignmentLogValue(next);
            for (int i3 = 0; i3 < this.neighborIndices.length; i3++) {
                if (r0[i3][next[i3]] < assignmentLogValue) {
                    r0[i3][next[i3]] = assignmentLogValue;
                }
            }
            if (!fastPassByReferenceIterator.hasNext()) {
                break;
            }
            fastPassByReferenceIterator.next();
        }
        for (int i4 = 0; i4 < this.neighborIndices.length; i4++) {
            normalizeLogArr(r0[i4]);
        }
        return r0;
    }

    public TableFactor maxOut(int i) {
        return marginalize(i, Double.NEGATIVE_INFINITY, (num, iArr) -> {
            return (v0, v1) -> {
                return Math.max(v0, v1);
            };
        });
    }

    public TableFactor sumOut(int i) {
        if (getDimensions().length != 2) {
            TableFactor maxOut = maxOut(i);
            TableFactor marginalize = marginalize(i, 0.0d, (num, iArr) -> {
                return (d, d2) -> {
                    return Double.valueOf(d.doubleValue() + Math.exp(d2.doubleValue() - maxOut.getAssignmentLogValue(iArr)));
                };
            });
            Iterator<int[]> it = marginalize.iterator();
            while (it.hasNext()) {
                int[] next = it.next();
                marginalize.setAssignmentLogValue(next, maxOut.getAssignmentLogValue(next) + Math.log(marginalize.getAssignmentLogValue(next)));
            }
            return marginalize;
        }
        if (this.neighborIndices[0] == i) {
            TableFactor tableFactor = new TableFactor(new int[]{this.neighborIndices[1]}, new int[]{getDimensions()[1]});
            for (int i2 = 0; i2 < tableFactor.values.length; i2++) {
                tableFactor.values[i2] = 0.0d;
            }
            double[] dArr = new double[getDimensions()[1]];
            for (int i3 = 0; i3 < getDimensions()[1]; i3++) {
                dArr[i3] = Double.NEGATIVE_INFINITY;
            }
            for (int i4 = 0; i4 < getDimensions()[0]; i4++) {
                int i5 = i4 * getDimensions()[1];
                for (int i6 = 0; i6 < getDimensions()[1]; i6++) {
                    int i7 = i5 + i6;
                    if (this.values[i7] > dArr[i6]) {
                        dArr[i6] = this.values[i7];
                    }
                }
            }
            for (int i8 = 0; i8 < getDimensions()[0]; i8++) {
                int i9 = i8 * getDimensions()[1];
                for (int i10 = 0; i10 < getDimensions()[1]; i10++) {
                    int i11 = i9 + i10;
                    if (Double.isFinite(dArr[i10])) {
                        double[] dArr2 = tableFactor.values;
                        int i12 = i10;
                        dArr2[i12] = dArr2[i12] + Math.exp(this.values[i11] - dArr[i10]);
                    }
                }
            }
            for (int i13 = 0; i13 < getDimensions()[1]; i13++) {
                if (Double.isFinite(dArr[i13])) {
                    tableFactor.values[i13] = dArr[i13] + Math.log(tableFactor.values[i13]);
                } else {
                    tableFactor.values[i13] = dArr[i13];
                }
            }
            return tableFactor;
        }
        if (!$assertionsDisabled && this.neighborIndices[1] != i) {
            throw new AssertionError();
        }
        TableFactor tableFactor2 = new TableFactor(new int[]{this.neighborIndices[0]}, new int[]{getDimensions()[0]});
        for (int i14 = 0; i14 < tableFactor2.values.length; i14++) {
            tableFactor2.values[i14] = 0.0d;
        }
        double[] dArr3 = new double[getDimensions()[0]];
        for (int i15 = 0; i15 < getDimensions()[0]; i15++) {
            dArr3[i15] = Double.NEGATIVE_INFINITY;
        }
        for (int i16 = 0; i16 < getDimensions()[0]; i16++) {
            int i17 = i16 * getDimensions()[1];
            for (int i18 = 0; i18 < getDimensions()[1]; i18++) {
                int i19 = i17 + i18;
                if (this.values[i19] > dArr3[i16]) {
                    dArr3[i16] = this.values[i19];
                }
            }
        }
        for (int i20 = 0; i20 < getDimensions()[0]; i20++) {
            int i21 = i20 * getDimensions()[1];
            for (int i22 = 0; i22 < getDimensions()[1]; i22++) {
                int i23 = i21 + i22;
                if (Double.isFinite(dArr3[i20])) {
                    double[] dArr4 = tableFactor2.values;
                    int i24 = i20;
                    dArr4[i24] = dArr4[i24] + Math.exp(this.values[i23] - dArr3[i20]);
                }
            }
        }
        for (int i25 = 0; i25 < getDimensions()[0]; i25++) {
            if (Double.isFinite(dArr3[i25])) {
                tableFactor2.values[i25] = dArr3[i25] + Math.log(tableFactor2.values[i25]);
            } else {
                tableFactor2.values[i25] = dArr3[i25];
            }
        }
        return tableFactor2;
    }

    public TableFactor multiply(TableFactor tableFactor) {
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (int i : this.neighborIndices) {
            arrayList.add(Integer.valueOf(i));
            arrayList3.add(Integer.valueOf(i));
        }
        for (int i2 : tableFactor.neighborIndices) {
            arrayList2.add(Integer.valueOf(i2));
            if (!arrayList3.contains(Integer.valueOf(i2))) {
                arrayList3.add(Integer.valueOf(i2));
            }
        }
        int[] iArr = new int[arrayList3.size()];
        int[] iArr2 = new int[iArr.length];
        for (int i3 = 0; i3 < arrayList3.size(); i3++) {
            int intValue = ((Integer) arrayList3.get(i3)).intValue();
            iArr[i3] = intValue;
            if (!$assertionsDisabled && ((getVariableSize(intValue) != 0 || tableFactor.getVariableSize(intValue) <= 0) && ((getVariableSize(intValue) <= 0 || tableFactor.getVariableSize(intValue) != 0) && getVariableSize(intValue) != tableFactor.getVariableSize(intValue)))) {
                throw new AssertionError();
            }
            iArr2[i3] = Math.max(getVariableSize(((Integer) arrayList3.get(i3)).intValue()), tableFactor.getVariableSize(((Integer) arrayList3.get(i3)).intValue()));
        }
        TableFactor tableFactor2 = new TableFactor(iArr, iArr2);
        if (arrayList2.size() == 1 && arrayList3.size() == arrayList.size() && arrayList.size() == 2) {
            int indexOf = arrayList3.indexOf(Integer.valueOf(((Integer) arrayList2.get(0)).intValue()));
            if (indexOf == 0) {
                for (int i4 = 0; i4 < iArr2[0]; i4++) {
                    double d = tableFactor.values[i4];
                    int i5 = i4 * iArr2[1];
                    for (int i6 = 0; i6 < iArr2[1]; i6++) {
                        int i7 = i5 + i6;
                        tableFactor2.values[i7] = this.values[i7] + d;
                    }
                }
            } else if (indexOf == 1) {
                for (int i8 = 0; i8 < iArr2[0]; i8++) {
                    int i9 = i8 * iArr2[1];
                    for (int i10 = 0; i10 < iArr2[1]; i10++) {
                        int i11 = i9 + i10;
                        tableFactor2.values[i11] = this.values[i11] + tableFactor.values[i10];
                    }
                }
            }
        } else {
            if (arrayList.size() == 1 && arrayList3.size() == arrayList2.size() && arrayList3.size() == 2) {
                return tableFactor.multiply(this);
            }
            int[] iArr3 = new int[tableFactor2.neighborIndices.length];
            int[] iArr4 = new int[tableFactor2.neighborIndices.length];
            for (int i12 = 0; i12 < tableFactor2.neighborIndices.length; i12++) {
                iArr3[i12] = arrayList.indexOf(Integer.valueOf(tableFactor2.neighborIndices[i12]));
                iArr4[i12] = arrayList2.indexOf(Integer.valueOf(tableFactor2.neighborIndices[i12]));
            }
            int[] iArr5 = new int[this.neighborIndices.length];
            int[] iArr6 = new int[tableFactor.neighborIndices.length];
            Iterator<int[]> fastPassByReferenceIterator = tableFactor2.fastPassByReferenceIterator();
            int[] next = fastPassByReferenceIterator.next();
            while (true) {
                for (int i13 = 0; i13 < next.length; i13++) {
                    if (iArr3[i13] != -1) {
                        iArr5[iArr3[i13]] = next[i13];
                    }
                    if (iArr4[i13] != -1) {
                        iArr6[iArr4[i13]] = next[i13];
                    }
                }
                tableFactor2.setAssignmentLogValue(next, getAssignmentLogValue(iArr5) + tableFactor.getAssignmentLogValue(iArr6));
                if (!fastPassByReferenceIterator.hasNext()) {
                    break;
                }
                fastPassByReferenceIterator.next();
            }
        }
        return tableFactor2;
    }

    public double valueSum() {
        double d = 0.0d;
        Iterator<int[]> it = iterator();
        while (it.hasNext()) {
            double assignmentLogValue = getAssignmentLogValue(it.next());
            if (assignmentLogValue > d) {
                d = assignmentLogValue;
            }
        }
        double d2 = 0.0d;
        Iterator<int[]> it2 = iterator();
        while (it2.hasNext()) {
            d2 += Math.exp(getAssignmentLogValue(it2.next()) - d);
        }
        return d2 * Math.exp(d);
    }

    @Override // edu.stanford.nlp.loglinear.model.NDArrayDoubles
    public double getAssignmentValue(int[] iArr) {
        return Math.exp(super.getAssignmentValue(iArr));
    }

    @Override // edu.stanford.nlp.loglinear.model.NDArrayDoubles
    public void setAssignmentValue(int[] iArr, double d) {
        super.setAssignmentValue(iArr, Math.log(d));
    }

    private double getAssignmentLogValue(int[] iArr) {
        return super.getAssignmentValue(iArr);
    }

    private void setAssignmentLogValue(int[] iArr, double d) {
        super.setAssignmentValue(iArr, d);
    }

    private TableFactor marginalize(int i, double d, BiFunction<Integer, int[], BiFunction<Double, Double, Double>> biFunction) {
        if (!$assertionsDisabled && getDimensions().length <= 1) {
            throw new AssertionError();
        }
        ArrayList arrayList = new ArrayList();
        for (int i2 : this.neighborIndices) {
            if (i2 != i) {
                arrayList.add(Integer.valueOf(i2));
            }
        }
        int[] iArr = new int[arrayList.size()];
        int[] iArr2 = new int[iArr.length];
        for (int i3 = 0; i3 < arrayList.size(); i3++) {
            int intValue = ((Integer) arrayList.get(i3)).intValue();
            iArr[i3] = intValue;
            iArr2[i3] = getVariableSize(intValue);
        }
        TableFactor tableFactor = new TableFactor(iArr, iArr2);
        int[] iArr3 = new int[this.neighborIndices.length];
        for (int i4 = 0; i4 < this.neighborIndices.length; i4++) {
            iArr3[i4] = arrayList.indexOf(Integer.valueOf(this.neighborIndices[i4]));
        }
        Iterator<int[]> it = tableFactor.iterator();
        while (it.hasNext()) {
            tableFactor.setAssignmentLogValue(it.next(), d);
        }
        int[] iArr4 = new int[tableFactor.neighborIndices.length];
        int i5 = 0;
        Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator();
        int[] next = fastPassByReferenceIterator.next();
        while (true) {
            for (int i6 = 0; i6 < next.length; i6++) {
                if (iArr3[i6] != -1) {
                    iArr4[iArr3[i6]] = next[i6];
                } else {
                    i5 = next[i6];
                }
            }
            tableFactor.setAssignmentLogValue(iArr4, biFunction.apply(Integer.valueOf(i5), iArr4).apply(Double.valueOf(tableFactor.getAssignmentLogValue(iArr4)), Double.valueOf(getAssignmentLogValue(next))).doubleValue());
            if (!fastPassByReferenceIterator.hasNext()) {
                return tableFactor;
            }
            fastPassByReferenceIterator.next();
        }
    }

    private int getVariableSize(int i) {
        for (int i2 = 0; i2 < this.neighborIndices.length; i2++) {
            if (this.neighborIndices[i2] == i) {
                return getDimensions()[i2];
            }
        }
        return 0;
    }

    private static void normalizeLogArr(double[] dArr) {
        double d = Double.NEGATIVE_INFINITY;
        for (double d2 : dArr) {
            if (d2 > d) {
                d = d2;
            }
        }
        double d3 = 0.0d;
        for (double d4 : dArr) {
            d3 += Math.exp(d4 - d);
        }
        double log = d + Math.log(d3);
        if (Double.isInfinite(log)) {
            for (int i = 0; i < dArr.length; i++) {
                dArr[i] = 1.0d / dArr.length;
            }
            return;
        }
        for (int i2 = 0; i2 < dArr.length; i2++) {
            dArr[i2] = Math.exp(dArr[i2] - log);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public TableFactor(int[] iArr, int[] iArr2) {
        super(iArr2);
        this.neighborIndices = iArr;
        for (int i = 0; i < this.values.length; i++) {
            this.values[i] = Double.NEGATIVE_INFINITY;
        }
    }

    private boolean assertsEnabled() {
        boolean z = false;
        if (!$assertionsDisabled) {
            z = true;
            if (1 == 0) {
                throw new AssertionError();
            }
        }
        return z;
    }

    static {
        $assertionsDisabled = !TableFactor.class.desiredAssertionStatus();
    }
}
