package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.ArrayMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Collection;
import java.util.Map;

/* loaded from: input_file:edu/stanford/nlp/classify/OneVsAllClassifier.class */
/**
 * Multi-class classifier assembled from one binary classifier per label.
 *
 * <p>Each binary classifier works over the two string labels {@code "+1"}
 * and {@code "-1"}. The score this classifier assigns to a label is the
 * {@code "+1"} score reported by that label's binary classifier.
 *
 * <p>Labels that have no binary classifier registered (for example when
 * training was restricted to a subset of the dataset's labels) are left
 * out of the scores rather than causing a {@link NullPointerException}.
 *
 * @param <L> type of the multi-class labels
 * @param <F> type of the features
 */
public class OneVsAllClassifier<L, F> implements Classifier<L, F> {
    private static final long serialVersionUID = -743792054415242776L;

    /** Label given to the target class in every binary sub-problem. */
    private static final String POS_LABEL = "+1";
    /** Label given to every non-target class in every binary sub-problem. */
    private static final String NEG_LABEL = "-1";

    /** Shared label index of the binary sub-problems: {@code "+1"} is added first, then {@code "-1"}. */
    private static final Index<String> binaryIndex = new HashIndex<>();
    /** Position of {@link #POS_LABEL} inside {@link #binaryIndex}; only used for logging in this class. */
    private static final int posIndex;

    static {
        binaryIndex.add(POS_LABEL);
        binaryIndex.add(NEG_LABEL);
        posIndex = binaryIndex.indexOf(POS_LABEL);
    }

    private static final Redwood.RedwoodChannels logger = Redwood.channels(OneVsAllClassifier.class);

    private Index<F> featureIndex;
    private Index<L> labelIndex;
    private Map<L, Classifier<String, F>> binaryClassifiers;
    /** Returned by {@link #classOf(Datum)} when no label could be scored; may be {@code null}. */
    private L defaultLabel;

    /**
     * Creates a classifier with no binary classifiers and no default label.
     * Binary classifiers can be registered later with
     * {@link #addBinaryClassifier(Object, Classifier)}.
     *
     * @param index  feature index
     * @param index2 label index
     */
    public OneVsAllClassifier(Index<F> index, Index<L> index2) {
        this(index, index2, Generics.newHashMap(), null);
    }

    /**
     * Creates a classifier from existing binary classifiers, with no default label.
     *
     * @param index  feature index
     * @param index2 label index
     * @param map    binary classifier for each label; stored by reference, not copied
     */
    public OneVsAllClassifier(Index<F> index, Index<L> index2, Map<L, Classifier<String, F>> map) {
        this(index, index2, map, null);
    }

    /**
     * Creates a classifier from existing binary classifiers and a default label.
     *
     * @param index  feature index
     * @param index2 label index
     * @param map    binary classifier for each label; stored by reference, not copied
     * @param l      label returned by {@link #classOf(Datum)} when no label can be scored
     */
    public OneVsAllClassifier(Index<F> index, Index<L> index2, Map<L, Classifier<String, F>> map, L l) {
        this.featureIndex = index;
        this.labelIndex = index2;
        this.binaryClassifiers = map;
        this.defaultLabel = l;
    }

    /**
     * Registers (or replaces) the binary classifier used for a label.
     *
     * @param l          label the classifier recognises as {@code "+1"}
     * @param classifier binary classifier for that label
     */
    public void addBinaryClassifier(L l, Classifier<String, F> classifier) {
        this.binaryClassifiers.put(l, classifier);
    }

    /**
     * Looks up the binary classifier registered for a label.
     *
     * @param l label to look up
     * @return the map entry for {@code l}; {@code null} if the map has no classifier for it
     */
    protected Classifier<String, F> getBinaryClassifier(L l) {
        return this.binaryClassifiers.get(l);
    }

    /**
     * Returns the highest-scoring label for a datum.
     *
     * @param datum datum to classify
     * @return the argmax of {@link #scoresOf(Datum)}, or the default label when
     *         the scores are {@code null} or empty
     */
    @Override // edu.stanford.nlp.classify.Classifier
    public L classOf(Datum<L, F> datum) {
        Counter<L> scores = scoresOf(datum);
        // The null check guards against overriding implementations; this
        // class's own scoresOf never returns null but can return an empty counter.
        if (scores == null || scores.size() == 0) {
            return this.defaultLabel;
        }
        return Counters.argmax(scores);
    }

    /**
     * Scores every label of the label index that has a binary classifier.
     *
     * @param datum datum to score
     * @return counter mapping each scored label to the {@code "+1"} score of its
     *         binary classifier; labels without a binary classifier are absent
     */
    @Override // edu.stanford.nlp.classify.Classifier
    public Counter<L> scoresOf(Datum<L, F> datum) {
        Counter<L> scores = new ClassicCounter<>();
        for (L label : this.labelIndex) {
            Classifier<String, F> binaryClassifier = getBinaryClassifier(label);
            if (binaryClassifier == null) {
                // No classifier was trained/registered for this label; skip it
                // instead of failing with a NullPointerException.
                continue;
            }
            Map<L, String> posLabelMap = new ArrayMap<>();
            posLabelMap.put(label, POS_LABEL);
            Datum<String, F> binaryDatum = GeneralDataset.mapDatum(datum, posLabelMap, NEG_LABEL);
            Counter<String> binaryScores = binaryClassifier.scoresOf(binaryDatum);
            scores.setCount(label, binaryScores.getCount(POS_LABEL));
        }
        return scores;
    }

    /**
     * @return the objects list of the label index
     */
    @Override // edu.stanford.nlp.classify.Classifier
    public Collection<L> labels() {
        return this.labelIndex.objectsList();
    }

    /**
     * Trains a one-vs-all classifier with one binary classifier for every
     * label in the dataset's label index.
     *
     * @param classifierFactory factory used to train each binary classifier
     * @param generalDataset    multi-class training data
     * @param <L>               label type
     * @param <F>               feature type
     * @return the trained classifier
     */
    public static <L, F> OneVsAllClassifier<L, F> train(ClassifierFactory<String, F, Classifier<String, F>> classifierFactory, GeneralDataset<L, F> generalDataset) {
        return train(classifierFactory, generalDataset, generalDataset.labelIndex().objectsList());
    }

    /**
     * Trains a one-vs-all classifier with one binary classifier for each of
     * the given labels. For every label the dataset is remapped with that label
     * mapped to {@code "+1"} and {@code "-1"} passed as the default label.
     *
     * <p>The returned classifier keeps the dataset's full label index; labels
     * not contained in {@code collection} get no binary classifier.
     *
     * @param classifierFactory factory used to train each binary classifier
     * @param generalDataset    multi-class training data
     * @param collection        labels to train binary classifiers for
     * @param <L>               label type
     * @param <F>               feature type
     * @return the trained classifier
     */
    public static <L, F> OneVsAllClassifier<L, F> train(ClassifierFactory<String, F, Classifier<String, F>> classifierFactory, GeneralDataset<L, F> generalDataset, Collection<L> collection) {
        Index<L> labelIndex = generalDataset.labelIndex();
        Index<F> featureIndex = generalDataset.featureIndex();
        Map<L, Classifier<String, F>> classifiers = Generics.newHashMap();
        for (L label : collection) {
            logger.info("Training " + label + " = " + labelIndex.indexOf(label) + ", posIndex = " + posIndex);
            Map<L, String> posLabelMap = new ArrayMap<>();
            posLabelMap.put(label, POS_LABEL);
            GeneralDataset<String, F> binaryDataset = generalDataset.mapDataset(generalDataset, binaryIndex, posLabelMap, NEG_LABEL);
            classifiers.put(label, classifierFactory.trainClassifier(binaryDataset));
        }
        return new OneVsAllClassifier<>(featureIndex, labelIndex, classifiers);
    }
}
