package edu.stanford.nlp.classify;

import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CollectionValuedMap;
import edu.stanford.nlp.util.Generics;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/stanford/nlp/classify/KNNClassifier.class */
/**
 * A k-nearest-neighbor classifier over real-valued feature vectors.
 *
 * <p>Each training example is stored as its feature {@link Counter}, grouped by label. A query is
 * scored by computing the cosine similarity ({@link Counters#cosine}) between its features and
 * every stored example. The {@code k} most similar examples then vote for their own labels. Each
 * vote is either {@code 1.0} or, when weighted votes are enabled, the neighbor's similarity.
 *
 * <p>Only {@link RVFDatum} queries are supported; for any other {@link Datum} the scoring and
 * classification methods return {@code null}.
 *
 * @param <K> the label type
 * @param <V> the feature type
 */
public class KNNClassifier<K, V> implements Classifier<K, V> {
    private static final long serialVersionUID = 7115357548209007944L;

    /** If true, each neighbor votes with its cosine similarity; otherwise every vote counts 1.0. */
    private final boolean weightedVotes;

    /** Training feature vectors, grouped by their label. */
    private final CollectionValuedMap<K, Counter<V>> instances = new CollectionValuedMap<>();

    /**
     * Feature-vector to label lookup. Still populated so the serialized form is unchanged, but no
     * longer consulted when scoring: keying by (mutable, possibly duplicated) counters collapsed
     * identical training vectors into a single entry that kept only the last label.
     */
    private final Map<Counter<V>, K> classLookup = Generics.newHashMap();

    /**
     * If true, the query's features are passed through {@link Counters#normalize} before scoring.
     * NOTE(review): despite the name, {@code Counters.normalize} is presumably sum-to-one rather
     * than L2 normalization, and cosine similarity is scale-invariant anyway — confirm intent.
     */
    private final boolean l2Normalize;

    /** Number of nearest neighbors that vote on a query's label. */
    int k;

    /**
     * Returns the distinct labels that have at least one training instance.
     *
     * @return a fresh copy of the known labels; mutating it does not affect this classifier
     */
    @Override // edu.stanford.nlp.classify.Classifier
    public Collection<K> labels() {
        // Copy the key set so callers neither see per-instance duplicates nor hold a live view
        // of internal state.
        return new ArrayList<>(this.instances.keySet());
    }

    /**
     * Creates an empty classifier. Training examples are supplied through {@link #addInstances}.
     *
     * @param k number of nearest neighbors that vote on a query's label
     * @param weightedVotes whether each neighbor's vote is weighted by its cosine similarity
     * @param l2Normalize whether to normalize the query's features before scoring
     */
    public KNNClassifier(int k, boolean weightedVotes, boolean l2Normalize) {
        this.k = k;
        this.weightedVotes = weightedVotes;
        this.l2Normalize = l2Normalize;
    }

    /**
     * Adds labeled training examples. Each datum's feature counter is stored under its label.
     *
     * @param collection the labeled examples to add
     */
    public void addInstances(Collection<RVFDatum<K, V>> collection) {
        for (RVFDatum<K, V> datum : collection) {
            K label = datum.label();
            Counter<V> features = datum.asFeaturesCounter();
            this.instances.add(label, features);
            this.classLookup.put(features, label);
        }
    }

    /**
     * Returns the highest-scoring label for {@code datum}.
     *
     * @param datum the query; only {@link RVFDatum} is supported
     * @return the label with the largest vote total, or {@code null} if {@code datum} is not an
     *     {@link RVFDatum} or no neighbor voted (no training data, or {@code k <= 0})
     */
    @Override // edu.stanford.nlp.classify.Classifier
    public K classOf(Datum<K, V> datum) {
        if (!(datum instanceof RVFDatum)) {
            return null;
        }
        ClassicCounter<K> scores = scoresOf(datum);
        if (scores == null || scores.size() == 0) {
            // Previously this threw IndexOutOfBoundsException from get(0) on an empty list.
            return null;
        }
        return Counters.toSortedList(scores).get(0);
    }

    /**
     * Scores every label by the votes of the {@code k} training instances most cosine-similar to
     * {@code datum}.
     *
     * @param datum the query; only {@link RVFDatum} is supported
     * @return label to vote-total counter (only labels that received a vote appear), or
     *     {@code null} if {@code datum} is not an {@link RVFDatum}
     */
    @Override // edu.stanford.nlp.classify.Classifier
    public ClassicCounter<K> scoresOf(Datum<K, V> datum) {
        if (!(datum instanceof RVFDatum)) {
            return null;
        }
        Counter<V> query = ((RVFDatum<K, V>) datum).asFeaturesCounter();
        if (this.l2Normalize) {
            // Normalize a copy so the caller's datum is not mutated.
            ClassicCounter<V> normalized = new ClassicCounter<>(query);
            Counters.normalize(normalized);
            query = normalized;
        }

        // Score each stored instance independently, keyed by position, so that training vectors
        // with identical features are neither merged nor relabeled.
        List<K> neighborLabels = new ArrayList<>();
        ClassicCounter<Integer> similarities = new ClassicCounter<>();
        for (Map.Entry<K, Collection<Counter<V>>> entry : this.instances.entrySet()) {
            for (Counter<V> instance : entry.getValue()) {
                similarities.setCount(neighborLabels.size(), Counters.cosine(query, instance));
                neighborLabels.add(entry.getKey());
            }
        }

        List<Integer> ranked = Counters.toSortedList(similarities);
        ClassicCounter<K> classScores = new ClassicCounter<>();
        for (int rank = 0; rank < this.k && rank < ranked.size(); rank++) {
            int index = ranked.get(rank);
            double vote = this.weightedVotes ? similarities.getCount(index) : 1.0d;
            classScores.incrementCount(neighborLabels.get(index), vote);
        }
        return classScores;
    }

    /**
     * Sanity check: trains a 3-NN classifier on a toy weather dataset, then prints the vote
     * scores and the predicted label for one unlabeled point.
     */
    public static void main(String[] strArr) {
        Collection<RVFDatum<String, String>> trainingInstances = new ArrayList<>();
        trainingInstances.add(new RVFDatum<>(weatherFeatures(5.0d, 35.0d), "rain"));
        trainingInstances.add(new RVFDatum<>(weatherFeatures(4.0d, 32.0d), "rain"));
        trainingInstances.add(new RVFDatum<>(weatherFeatures(6.0d, 30.0d), "rain"));
        trainingInstances.add(new RVFDatum<>(weatherFeatures(2.0d, 33.0d), "dry"));
        trainingInstances.add(new RVFDatum<>(weatherFeatures(1.0d, 34.0d), "dry"));

        KNNClassifier<String, String> classifier =
                new KNNClassifierFactory<String, String>(3, false, true).train(trainingInstances);

        RVFDatum<String, String> testDatum = new RVFDatum<>(weatherFeatures(2.0d, 33.0d));
        System.out.println(classifier.scoresOf(testDatum));
        System.out.println(classifier.classOf(testDatum));
    }

    /** Builds a two-feature counter ({@code humidity}, {@code temperature}) for the demo in main. */
    private static ClassicCounter<String> weatherFeatures(double humidity, double temperature) {
        ClassicCounter<String> features = new ClassicCounter<>();
        features.setCount("humidity", humidity);
        features.setCount("temperature", temperature);
        return features;
    }
}
