package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/parser/dvparser/AverageDVModels.class */
/**
 * Command-line utility that loads several serialized parser models, averages the
 * matrices of their {@link DVModel}s entry by entry, and saves the result as a new model.
 *
 * <p>Usage: {@code -input model1[,model2 ...] [model3 ...] -output averagedModel}
 */
public class AverageDVModels {
    private static final Redwood.RedwoodChannels log = Redwood.channels(AverageDVModels.class);

    /**
     * Collects the union of all (first key, second key) pairs used by the given maps.
     *
     * @param list the two-dimensional maps to scan
     * @return every key pair that occurs in at least one of the maps
     */
    public static TwoDimensionalSet<String, String> getBinaryMatrixNames(List<TwoDimensionalMap<String, String, SimpleMatrix>> list) {
        TwoDimensionalSet<String, String> names = new TwoDimensionalSet<>();
        for (TwoDimensionalMap<String, String, SimpleMatrix> matrices : list) {
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : matrices) {
                names.add(entry.getFirstKey(), entry.getSecondKey());
            }
        }
        return names;
    }

    /**
     * Collects the union of all keys used by the given maps.
     *
     * @param list the maps to scan
     * @return every key that occurs in at least one of the maps
     */
    public static Set<String> getUnaryMatrixNames(List<Map<String, SimpleMatrix>> list) {
        Set<String> names = Generics.newHashSet();
        for (Map<String, SimpleMatrix> matrices : list) {
            names.addAll(matrices.keySet());
        }
        return names;
    }

    /**
     * Averages the matrices stored under each key pair across the given maps.
     * A key pair missing from some maps is averaged over only the maps that contain it.
     *
     * @param list the maps whose matrices should be averaged
     * @return a tree-backed map from each key pair to the average of its matrices
     */
    public static TwoDimensionalMap<String, String, SimpleMatrix> averageBinaryMatrices(List<TwoDimensionalMap<String, String, SimpleMatrix>> list) {
        TwoDimensionalMap<String, String, SimpleMatrix> averages = TwoDimensionalMap.treeMap();
        for (Pair<String, String> name : getBinaryMatrixNames(list)) {
            int count = 0;
            SimpleMatrix sum = null;
            for (TwoDimensionalMap<String, String, SimpleMatrix> matrices : list) {
                if (!matrices.contains(name.first(), name.second())) {
                    continue;
                }
                SimpleMatrix matrix = matrices.get(name.first(), name.second());
                count++;
                // plus() returns a new matrix, so the input models are never modified
                sum = (sum == null) ? matrix : sum.plus(matrix);
            }
            // every name comes from at least one map, so count >= 1 here
            averages.put(name.first(), name.second(), sum.divide(count));
        }
        return averages;
    }

    /**
     * Averages the matrices stored under each key across the given maps.
     * A key missing from some maps is averaged over only the maps that contain it.
     *
     * @param list the maps whose matrices should be averaged
     * @return a sorted map from each key to the average of its matrices
     */
    public static Map<String, SimpleMatrix> averageUnaryMatrices(List<Map<String, SimpleMatrix>> list) {
        Map<String, SimpleMatrix> averages = Generics.newTreeMap();
        for (String name : getUnaryMatrixNames(list)) {
            int count = 0;
            SimpleMatrix sum = null;
            for (Map<String, SimpleMatrix> matrices : list) {
                if (!matrices.containsKey(name)) {
                    continue;
                }
                SimpleMatrix matrix = matrices.get(name);
                count++;
                // plus() returns a new matrix, so the input models are never modified
                sum = (sum == null) ? matrix : sum.plus(matrix);
            }
            // every name comes from at least one map, so count >= 1 here
            averages.put(name, sum.divide(count));
        }
        return averages;
    }

    /**
     * Loads the input models, averages their matrices and saves the averaged model.
     *
     * <p>Recognized arguments (case-insensitive):
     * <ul>
     *   <li>{@code -output name} - where to save the averaged model (required)</li>
     *   <li>{@code -input a,b c ...} - one or more model names; every following argument up
     *       to the next one starting with {@code -} is split on commas and added (required)</li>
     * </ul>
     * The first input model supplies the base parser and the options of the saved model.
     * Exits with status 2 when the output name or the input list is missing.
     *
     * @param strArr the command-line arguments
     * @throws IllegalArgumentException if {@code -output} is not followed by a value
     * @throws RuntimeException if an unknown argument is encountered
     */
    public static void main(String[] strArr) {
        String outputModelFilename = null;
        List<String> inputModelFilenames = Generics.newArrayList();

        int argIndex = 0;
        while (argIndex < strArr.length) {
            String flag = strArr[argIndex];
            if (flag.equalsIgnoreCase("-output")) {
                if (argIndex + 1 >= strArr.length) {
                    throw new IllegalArgumentException("-output requires a model name");
                }
                outputModelFilename = strArr[argIndex + 1];
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-input")) {
                // consume every value up to the next flag (or the end of the arguments)
                argIndex++;
                while (argIndex < strArr.length && !strArr[argIndex].startsWith("-")) {
                    inputModelFilenames.addAll(Arrays.asList(strArr[argIndex].split(",")));
                    argIndex++;
                }
            } else {
                throw new RuntimeException("Unknown argument " + flag);
            }
        }

        if (outputModelFilename == null) {
            log.info("Need to specify output model name with -output");
            System.exit(2);
        }
        if (inputModelFilenames.isEmpty()) {
            log.info("Need to specify input model names with -input");
            System.exit(2);
        }

        log.info("Averaging " + inputModelFilenames);
        log.info("Outputting result to " + outputModelFilename);

        // The first loaded parser is kept as the base of the averaged model.
        LexicalizedParser baseParser = null;
        List<DVModel> models = Generics.newArrayList();
        for (String filename : inputModelFilenames) {
            LexicalizedParser parser = LexicalizedParser.loadModel(filename, new String[0]);
            if (baseParser == null) {
                baseParser = parser;
            }
            models.add(DVParser.getModelFromLexicalizedParser(parser));
        }

        TwoDimensionalMap<String, String, SimpleMatrix> binaryTransformAverages =
                averageBinaryMatrices(CollectionUtils.transformAsList(models, model -> model.binaryTransform));
        Map<String, SimpleMatrix> unaryTransformAverages =
                averageUnaryMatrices(CollectionUtils.transformAsList(models, model -> model.unaryTransform));
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreAverages =
                averageBinaryMatrices(CollectionUtils.transformAsList(models, model -> model.binaryScore));
        Map<String, SimpleMatrix> unaryScoreAverages =
                averageUnaryMatrices(CollectionUtils.transformAsList(models, model -> model.unaryScore));
        Map<String, SimpleMatrix> wordVectorAverages =
                averageUnaryMatrices(CollectionUtils.transformAsList(models, model -> model.wordVectors));

        DVModel averagedModel = new DVModel(binaryTransformAverages, unaryTransformAverages,
                binaryScoreAverages, unaryScoreAverages, wordVectorAverages, baseParser.getOp());
        new DVParser(averagedModel, baseParser).saveModel(outputModelFilename);
    }
}
