package edu.stanford.nlp.classify.demo;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.ColumnDataClassifier;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.util.ErasureUtils;
import edu.stanford.nlp.util.Pair;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Iterator;

/* loaded from: input_file:edu/stanford/nlp/classify/demo/ClassifierDemo.class */
/**
 * Demonstrates training and testing a {@link ColumnDataClassifier} on the cheese/disease example
 * data, and round-tripping a trained classifier through Java serialization in two ways.
 */
class ClassifierDemo {
    /** Properties file configuring the {@link ColumnDataClassifier}, relative to {@link #where}. */
    private static final String PROP_FILE = "examples/cheese2007.prop";
    /** Training data file, relative to {@link #where}. */
    private static final String TRAIN_FILE = "examples/cheeseDisease.train";
    /** Test data file, relative to {@link #where}. */
    private static final String TEST_FILE = "examples/cheeseDisease.test";

    /**
     * Directory prefix prepended to the example file paths. Empty by default; set from the first
     * command-line argument (with a trailing {@link File#separator}) when one is given.
     */
    private static String where = "";

    ClassifierDemo() {
    }

    /**
     * Trains a {@link ColumnDataClassifier}, prints its prediction for every test line, prints
     * its accuracy and macro-F1, and then runs the two serialization demos.
     *
     * @param strArr optional arguments; if non-empty, {@code strArr[0]} is the directory that
     *     contains the {@code examples} folder
     * @throws Exception if reading the example files or (de)serialization fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length > 0) {
            where = strArr[0] + File.separator;
        }
        System.out.println("Training ColumnDataClassifier");
        ColumnDataClassifier cdc = new ColumnDataClassifier(where + PROP_FILE);
        cdc.trainClassifier(where + TRAIN_FILE);
        System.out.println();

        System.out.println("Testing predictions of ColumnDataClassifier");
        for (String line : ObjectBank.getLineIterator(where + TEST_FILE, "utf-8")) {
            Datum<String, String> datum = cdc.makeDatumFromLine(line);
            String predicted = cdc.classOf(datum);
            System.out.printf("%s  ==>  %s (%.4f)%n", line, predicted,
                cdc.scoresOf(datum).getCount(predicted));
        }
        System.out.println();

        System.out.println("Testing accuracy of ColumnDataClassifier");
        Pair<Double, Double> performance = cdc.testClassifier(where + TEST_FILE);
        System.out.printf("Accuracy: %.3f; macro-F1: %.3f%n", performance.first(), performance.second());

        demonstrateSerialization();
        demonstrateSerializationColumnDataClassifier();
    }

    /**
     * Builds a classifier via {@link ColumnDataClassifier#makeClassifier}, serializes it to an
     * in-memory byte array with {@link ObjectOutputStream}, and reads it back as a
     * {@link LinearClassifier}. It then prints the predictions of the original and the
     * deserialized classifier for each test line. Each classifier is fed datums produced by its
     * own {@link ColumnDataClassifier} featurizer.
     *
     * @throws IOException if reading the example files or the byte streams fails
     * @throws ClassNotFoundException if the serialized classifier's class cannot be resolved
     */
    private static void demonstrateSerialization() throws IOException, ClassNotFoundException {
        System.out.println();
        System.out.println("Demonstrating working with a serialized classifier");
        ColumnDataClassifier cdc = new ColumnDataClassifier(where + PROP_FILE);
        Classifier<String, String> original =
            cdc.makeClassifier(cdc.readTrainingExamples(where + TRAIN_FILE));
        System.out.println();

        // Round-trip through a byte array; try-with-resources closes the streams even on failure.
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(original);
        }
        LinearClassifier<String, String> deserialized;
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            deserialized = ErasureUtils.uncheckedCast(in.readObject());
        }

        // A fresh ColumnDataClassifier supplies the featurizer for the deserialized classifier.
        ColumnDataClassifier cdc2 = new ColumnDataClassifier(where + PROP_FILE);
        System.out.println();
        System.out.println("Making predictions with both classifiers");
        for (String line : ObjectBank.getLineIterator(where + TEST_FILE, "utf-8")) {
            Datum<String, String> datum = cdc.makeDatumFromLine(line);
            Datum<String, String> datum2 = cdc2.makeDatumFromLine(line);

            String originalLabel = original.classOf(datum);
            System.out.printf("%s  =origi=>  %s (%.4f)%n", line, originalLabel,
                original.scoresOf(datum).getCount(originalLabel));

            // Label and score both come from datum2 so the line reflects the deserialized model alone.
            String deserializedLabel = deserialized.classOf(datum2);
            System.out.printf("%s  =deser=>  %s (%.4f)%n", line, deserializedLabel,
                deserialized.scoresOf(datum2).getCount(deserializedLabel));
        }
    }

    /**
     * Trains a {@link ColumnDataClassifier}, serializes it to an in-memory byte array with
     * {@link ColumnDataClassifier#serializeClassifier}, and restores it with
     * {@link ColumnDataClassifier#getClassifier(ObjectInputStream)}. It then prints the
     * predictions of the original and the restored classifier for each test line.
     *
     * @throws IOException if reading the example files or the byte streams fails
     * @throws ClassNotFoundException if the serialized classifier's class cannot be resolved
     */
    private static void demonstrateSerializationColumnDataClassifier() throws IOException, ClassNotFoundException {
        System.out.println();
        System.out.println("Demonstrating working with a serialized classifier using serializeTo");
        ColumnDataClassifier cdc = new ColumnDataClassifier(where + PROP_FILE);
        cdc.trainClassifier(where + TRAIN_FILE);
        System.out.println();

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            cdc.serializeClassifier(out);
        }
        ColumnDataClassifier restored;
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            restored = ColumnDataClassifier.getClassifier(in);
        }

        System.out.println("Making predictions with both classifiers");
        for (String line : ObjectBank.getLineIterator(where + TEST_FILE, "utf-8")) {
            Datum<String, String> datum = cdc.makeDatumFromLine(line);
            Datum<String, String> datum2 = restored.makeDatumFromLine(line);

            String originalLabel = cdc.classOf(datum);
            System.out.printf("%s  =origi=>  %s (%.4f)%n", line, originalLabel,
                cdc.scoresOf(datum).getCount(originalLabel));

            // Label and score both come from datum2 so the line reflects the restored model alone.
            String restoredLabel = restored.classOf(datum2);
            System.out.printf("%s  =deser=>  %s (%.4f)%n", line, restoredLabel,
                restored.scoresOf(datum2).getCount(restoredLabel));
        }
    }
}
