package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.util.Pair;
import java.util.Arrays;
import junit.framework.TestCase;

/**
 * Tests the train/dev splitting methods of {@link GeneralDataset}, exercised through a small
 * {@link Dataset} of symptom features labelled {@code "cold"} or {@code "flu"}.
 */
public class GeneralDatasetTest extends TestCase {

    /**
     * Builds a datum with the given label and features.
     *
     * @param label the datum's label
     * @param features the datum's features, in order
     * @return a new {@link BasicDatum} holding {@code features} and {@code label}
     */
    private static BasicDatum<String, String> datum(String label, String... features) {
        return new BasicDatum<String, String>(Arrays.asList(features), label);
    }

    /**
     * Checks {@code split(int, int)} and {@code split(double)}.
     *
     * <p>It verifies the partition sizes, the label of the last datum in each partition, and
     * that an index-based split and the equivalent fraction-based split produce identical
     * partitions.
     */
    public static void testCreateFolds() {
        Dataset<String, String> dataset = new Dataset<String, String>();
        dataset.add(datum("cold", "fever", "cough", "congestion"));
        dataset.add(datum("flu", "fever", "cough", "nausea"));
        dataset.add(datum("cold", "cough", "congestion"));
        dataset.add(datum("cold", "cough", "congestion"));
        dataset.add(datum("flu", "fever", "nausea"));
        dataset.add(datum("cold", "cough", "sore throat"));

        // Items [3, 5) go to the second partition; the remaining four go to the first.
        Pair<GeneralDataset<String, String>, GeneralDataset<String, String>> split =
                dataset.split(3, 5);
        GeneralDataset<String, String> first = split.first();
        GeneralDataset<String, String> second = split.second();
        assertEquals(4, first.size());
        assertEquals(2, second.size());
        assertEquals("cold", first.getDatum(first.size() - 1).label());
        assertEquals("flu", second.getDatum(second.size() - 1).label());

        Pair<GeneralDataset<String, String>, GeneralDataset<String, String>> split2 =
                dataset.split(0, 2);
        assertEquals(4, split2.first().size());
        assertEquals(2, split2.second().size());

        // A one-third split of six items is expected to match split(0, 2) exactly.
        Pair<GeneralDataset<String, String>, GeneralDataset<String, String>> split3 =
                dataset.split(1.0 / 3.0);
        assertEquals(split2.first().size(), split3.first().size());
        assertEquals(split2.first().labelIndex(), split3.first().labelIndex());
        assertEquals(split2.second().size(), split3.second().size());
        // Compare the two splits against each other; the previous version compared each
        // array with itself, which could never fail.
        assertTrue(Arrays.equals(split2.first().labels, split3.first().labels));
        assertTrue(Arrays.equals(split2.second().labels, split3.second().labels));

        // With seven items, a one-third split puts (int) (7 / 3.0) == 2 items in the second part.
        dataset.add(datum("flu", "fever", "nausea"));
        Pair<GeneralDataset<String, String>, GeneralDataset<String, String>> split4 =
                dataset.split(1.0 / 3.0);
        assertEquals(5, split4.first().size());
        assertEquals(2, split4.second().size());

        // 0.125 * 7 truncates to 0, so everything stays in the first partition.
        Pair<GeneralDataset<String, String>, GeneralDataset<String, String>> split5 =
                dataset.split(0.125);
        assertEquals(7, split5.first().size());
        assertEquals(0, split5.second().size());
    }
}
