package edu.stanford.nlp.trees;

import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/trees/SplitTrainingSet.class */
/**
 * Command-line utility that randomly partitions a treebank into several output files
 * (by default {@code train}, {@code dev} and {@code test}) according to relative weights.
 *
 * <p>Options are filled from the command line via {@link ArgumentParser#fillOptions}:
 * <ul>
 *   <li>{@code input} (required): treebank path handed to {@link MemoryTreebank#loadPath}.</li>
 *   <li>{@code output} (required): output prefix; the split named {@code s} is written to
 *       {@code output + "." + s}.</li>
 *   <li>{@code split_names}: names of the divisions.</li>
 *   <li>{@code split_weights}: non-negative, finite relative weights, one per name; they are
 *       normalized to sum to 1.</li>
 *   <li>{@code seed}: random seed; 0 (the default) means "derive one from {@link System#nanoTime()}".</li>
 * </ul>
 *
 * <p>Each tree is assigned to a split independently, so split sizes are only approximately
 * proportional to the weights.
 */
public class SplitTrainingSet {
    private static final Redwood.RedwoodChannels logger = Redwood.channels(SplitTrainingSet.class);

    @ArgumentParser.Option(name = "input", gloss = "The file to use as input.", required = true)
    private static String INPUT = null;

    @ArgumentParser.Option(name = "output", gloss = "Where to send the splits.", required = true)
    private static String OUTPUT = null;

    @ArgumentParser.Option(name = "split_names", gloss = "Divisions to use for the output")
    private static String[] SPLIT_NAMES = {"train", "dev", "test"};

    @ArgumentParser.Option(name = "split_weights", gloss = "Portions to use for the divisions")
    private static Double[] SPLIT_WEIGHTS = {Double.valueOf(0.7d), Double.valueOf(0.15d), Double.valueOf(0.15d)};

    @ArgumentParser.Option(name = "seed", gloss = "Random seed to use")
    private static long SEED = 0;

    /**
     * Draws an index with probability proportional to its weight.
     *
     * <p>A single {@link Random#nextDouble()} draw in [0, 1) is walked through the cumulative
     * weights, so the weights are expected to be non-negative and to sum to (about) 1. If
     * floating-point rounding leaves part of the draw unconsumed after the last weight, the
     * last index with a <em>positive</em> weight is returned. A zero-weight entry is therefore
     * never selected.
     *
     * @param list normalized, non-negative selection weights
     * @param random source of randomness; exactly one {@code nextDouble()} is consumed
     * @return the selected index; {@code list.size() - 1} if no weight is positive
     */
    public static int weightedIndex(List<Double> list, Random random) {
        double remaining = random.nextDouble();
        int lastPositive = -1;
        int index = 0;
        for (Double weight : list) {
            double w = weight.doubleValue();
            remaining -= w;
            if (remaining < 0.0d) {
                return index;
            }
            if (w > 0.0d) {
                lastPositive = index;
            }
            index++;
        }
        // Only reached through rounding error (or when no weight is positive). Avoid handing
        // an item to a split the user explicitly weighted at zero.
        return lastPositive >= 0 ? lastPositive : list.size() - 1;
    }

    /**
     * Entry point: reads the options, loads the input treebank, assigns every tree to a random
     * split and writes one file per split.
     *
     * @param strArr command-line arguments, parsed with {@link StringUtils#argsToProperties}
     * @throws IllegalArgumentException if names and weights differ in length, or if the weights
     *     are invalid (negative, non-finite, or summing to zero)
     * @throws IOException if an output file cannot be written
     */
    public static void main(String[] strArr) throws IOException {
        ArgumentParser.fillOptions(new Class<?>[]{ArgumentParser.class, SplitTrainingSet.class}, StringUtils.argsToProperties(strArr));
        if (SPLIT_NAMES.length != SPLIT_WEIGHTS.length) {
            throw new IllegalArgumentException("Name and weight arrays must be of the same length");
        }
        List<Double> weights = normalizeWeights(SPLIT_WEIGHTS);
        logger.info("Splitting into " + weights.size() + " lists with weights " + weights);

        // 0 is the "unset" sentinel, so an explicit seed of 0 is also replaced by a time-based seed.
        if (SEED == 0) {
            SEED = System.nanoTime();
            logger.info("Random seed not set by options, using " + SEED);
        }
        Random random = new Random(SEED);

        List<List<Tree>> splits = new ArrayList<>(weights.size());
        for (int i = 0; i < weights.size(); i++) {
            splits.add(new ArrayList<>());
        }

        MemoryTreebank treebank = new MemoryTreebank(reader -> new PennTreeReader(reader));
        treebank.loadPath(INPUT);
        logger.info("Splitting " + treebank.size() + " trees");
        for (Tree tree : treebank) {
            splits.get(weightedIndex(weights, random)).add(tree);
        }

        for (int i = 0; i < splits.size(); i++) {
            writeTrees(OUTPUT + "." + SPLIT_NAMES[i], splits.get(i));
        }
    }

    /**
     * Validates the raw split weights and rescales them so they sum to 1.
     *
     * @param rawWeights relative weights, one per split
     * @return a new list of normalized weights, in the same order
     * @throws IllegalArgumentException if a weight is negative or not finite, or the total is not positive
     */
    private static List<Double> normalizeWeights(Double[] rawWeights) {
        double total = 0.0d;
        for (Double weight : rawWeights) {
            double w = weight.doubleValue();
            if (w < 0.0d) {
                throw new IllegalArgumentException("Split weights cannot be negative");
            }
            // NaN slips past both "< 0" and "<= 0" checks; infinity turns normalization into NaN.
            if (!Double.isFinite(w)) {
                throw new IllegalArgumentException("Split weights must be finite numbers, got " + w);
            }
            total += w;
        }
        if (total <= 0.0d) {
            throw new IllegalArgumentException("Split weights must total to a positive weight");
        }
        List<Double> normalized = new ArrayList<>(rawWeights.length);
        for (Double weight : rawWeights) {
            normalized.add(Double.valueOf(weight.doubleValue() / total));
        }
        return normalized;
    }

    /**
     * Writes each tree's {@code toString()} form on its own line to {@code filename}, creating or
     * truncating the file. Output is UTF-8 regardless of the JVM's default charset, and the
     * writer is closed even if a write fails.
     */
    private static void writeTrees(String filename, List<Tree> trees) throws IOException {
        logger.info("Writing " + trees.size() + " trees to " + filename);
        try (BufferedWriter out = Files.newBufferedWriter(Paths.get(filename), StandardCharsets.UTF_8)) {
            for (Tree tree : trees) {
                out.write(tree.toString());
                out.newLine();
            }
        }
    }
}
