package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import junit.framework.Assert;
import junit.framework.TestCase;

/**
 * Integration tests that train a {@link LinearClassifier} with a default-configured
 * {@link LinearClassifierFactory} on tiny real-valued-feature datasets, then check that the
 * trained classifier reproduces the training labels (and, for the multi-class case, the label of
 * one unseen datum).
 */
public class LinearClassifierITest extends TestCase {

    /**
     * Builds a real-valued datum whose feature {@code features[i]} has value {@code values[i]}.
     *
     * @param label the datum's label
     * @param features feature names; parallel to {@code values}
     * @param values feature values; parallel to {@code features}
     * @return a new {@link RVFDatum} carrying {@code label} and the given feature counts
     * @throws IllegalArgumentException if {@code features} and {@code values} differ in length
     */
    private static <L, F> RVFDatum<L, F> newDatum(L label, F[] features, double[] values) {
        // Fail loudly instead of silently truncating extra values or throwing an opaque
        // ArrayIndexOutOfBoundsException on a mismatched fixture.
        if (features.length != values.length) {
            throw new IllegalArgumentException(
                "features and values differ in length: " + features.length + " vs " + values.length);
        }
        ClassicCounter<F> counts = new ClassicCounter<>();
        for (int i = 0; i < features.length; i++) {
            counts.setCount(features[i], values[i]);
        }
        return new RVFDatum<>(counts, label);
    }

    /**
     * Trains on exactly two datums over features {@code f1}/{@code f2}, one labeled "alpha" and one
     * labeled "beta", and asserts that the classifier assigns each datum its own label.
     *
     * @param alphaF1 value of {@code f1} for the "alpha" datum
     * @param alphaF2 value of {@code f2} for the "alpha" datum
     * @param betaF1 value of {@code f1} for the "beta" datum
     * @param betaF2 value of {@code f2} for the "beta" datum
     */
    private static void assertBinaryDatumsClassified(
            double alphaF1, double alphaF2, double betaF1, double betaF2) throws Exception {
        String[] features = {"f1", "f2"};
        RVFDatum<String, String> alpha = newDatum("alpha", features, new double[] {alphaF1, alphaF2});
        RVFDatum<String, String> beta = newDatum("beta", features, new double[] {betaF1, betaF2});

        RVFDataset<String, String> dataset = new RVFDataset<>();
        dataset.add(alpha);
        dataset.add(beta);

        // Explicit type arguments: a diamond on a chained call would infer <Object, Object>.
        LinearClassifier<String, String> classifier =
            new LinearClassifierFactory<String, String>().trainClassifier(dataset);

        Assert.assertEquals(alpha.label(), classifier.classOf(alpha));
        Assert.assertEquals(beta.label(), classifier.classOf(beta));
    }

    /** Checks several two-point configurations, each with a distinct pair of feature vectors. */
    public void testStrBinaryDatums() throws Exception {
        assertBinaryDatumsClassified(-1.0, 0.0, 1.0, 0.0);
        assertBinaryDatumsClassified(1.0, 0.0, -1.0, 0.0);
        assertBinaryDatumsClassified(0.0, 1.0, 0.0, -1.0);
        assertBinaryDatumsClassified(0.0, -1.0, 0.0, 1.0);
        assertBinaryDatumsClassified(1.0, 1.0, -1.0, -1.0);
        assertBinaryDatumsClassified(0.0, 1.0, 1.0, 0.0);
        assertBinaryDatumsClassified(1.0, 0.0, 0.0, 1.0);
    }

    /**
     * Trains on three labeled datums ("alpha", "beta", "charlie"). Asserts that every training datum
     * is classified as its own label, and that an unseen datum with an extra feature {@code f3} is
     * classified as "alpha".
     */
    public void testStrMultiClassDatums() throws Exception {
        String[] features = {"f1", "f2"};
        List<RVFDatum<String, String>> training = new ArrayList<>();
        training.add(newDatum("alpha", features, new double[] {1.0, 0.0}));
        training.add(newDatum("beta", features, new double[] {0.0, 1.0}));
        training.add(newDatum("charlie", features, new double[] {5.0, 5.0}));

        RVFDataset<String, String> dataset = new RVFDataset<>();
        for (RVFDatum<String, String> datum : training) {
            dataset.add(datum);
        }

        LinearClassifier<String, String> classifier =
            new LinearClassifierFactory<String, String>().trainClassifier(dataset);

        // f3 never appears in training, so the classifier has to rely on f1/f2 for this datum.
        RVFDatum<String, String> unseen =
            newDatum("alpha", new String[] {"f1", "f2", "f3"}, new double[] {2.0, 0.0, 5.5});

        for (RVFDatum<String, String> datum : training) {
            Assert.assertEquals(datum.label(), classifier.classOf(datum));
        }
        Assert.assertEquals(unseen.label(), classifier.classOf(unseen));
    }
}
