package edu.stanford.nlp.scenegraph;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.scenegraph.BoWExample;
import edu.stanford.nlp.scenegraph.image.SceneGraphImage;
import edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.PropertiesUtils;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Scanner;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/scenegraph/BoWSceneGraphParser.class */
/**
 * Scene graph parser that labels candidate word pairs of a dependency graph with a linear
 * bag-of-words relation classifier.
 *
 * <p>Candidate pairs are (entity, entity) pairs of distinct words plus (entity, attribute) pairs,
 * as returned by {@link EntityExtractor}. For each pair the classifier predicts a label:
 * {@link #NONE_RELATION} adds nothing, {@link #IS_RELATION} attaches the second word as an
 * attribute of the first, and any other label becomes a labelled edge between the two words.
 */
public class BoWSceneGraphParser extends AbstractSceneGraphParser {

    /** Relation classifier; stays {@code null} when constructed without a model path (training). */
    Classifier<String, String> classifier;

    /** Invoked (via {@code predictEntity}) on every entity/attribute word before feature extraction. */
    EntityClassifier entityClassifer;

    /** If true, only word pairs for which {@code SceneGraphUtils.inSameSubTree} holds are considered. */
    boolean enforceSubtree;

    /** If true (or if no relation was predicted at all), every extracted entity is added as a node. */
    boolean includeAllObjects;

    private final Embedding embeddings;
    private final SceneGraphSentenceMatcher sentenceMatcher;

    /** Label meaning "no relation between the two words". */
    public static final String NONE_RELATION = "----NONE----";
    /** Label meaning "the second word is an attribute of the first word". */
    public static final String IS_RELATION = "is";
    public static final String DEFAULT_MODEL_PATH = "edu/stanford/nlp/models/scenegraph/bow_model_final_subtree.gz";
    public static final String DEFAULT_ENTITY_MODEL_PATH = "edu/stanford/nlp/models/scenegraph/entityModel.gz";

    /** Regularization strength (sigma) handed to {@link LinearClassifierFactory} in {@link #train}. */
    private static final double REG_STRENGTH = 1.0d;

    /** Fixed seed used when subsampling negative training examples. */
    private static final int SUBSAMPLE_SEED = 42;

    private static final BoWExample.FEATURE_SET[] featureSets = {
        BoWExample.FEATURE_SET.LEMMA_BOW, BoWExample.FEATURE_SET.WORD_BOW, BoWExample.FEATURE_SET.TREE_FEAT
    };

    /** Number of values taken by each command-line flag (see {@link StringUtils#argsToProperties}). */
    private static final Map<String, Integer> numArgs = new HashMap<>();

    /**
     * Creates a parser.
     *
     * @param modelPath URL, classpath resource or file of a serialized relation classifier, or
     *     {@code null} to create a parser without one (e.g. for {@link #train})
     * @param entityModelPath entity classifier model, or {@code null} to skip loading it
     * @param embedding word embeddings used by the entity classifier and the sentence matcher
     * @throws RuntimeException if the relation classifier cannot be read
     */
    public BoWSceneGraphParser(String modelPath, String entityModelPath, Embedding embedding) {
        if (modelPath != null) {
            try {
                this.classifier = IOUtils.readObjectFromURLOrClasspathOrFileSystem(modelPath);
            } catch (IOException | ClassNotFoundException e) {
                // Fail fast: a parser without a classifier would only fail later with an NPE in parse().
                throw new RuntimeException("Could not load scene graph relation classifier from " + modelPath, e);
            }
        }
        if (entityModelPath != null) {
            this.entityClassifer = new EntityClassifier(entityModelPath);
        }
        this.embeddings = embedding;
        this.sentenceMatcher = new SceneGraphSentenceMatcher(embedding);
    }

    /**
     * Enhances {@code semanticGraph} in place and builds a scene graph from the relations the
     * classifier predicts between its candidate word pairs.
     */
    @Override
    public SceneGraph parse(SemanticGraph semanticGraph) {
        SemanticGraphEnhancer.enhance(semanticGraph);
        List<IndexedWord> entities = EntityExtractor.extractEntities(semanticGraph);
        List<IndexedWord> attributes = EntityExtractor.extractAttributes(semanticGraph);
        predictEntities(entities);
        predictEntities(attributes);

        List<BoWExample> candidates = candidateExamples(semanticGraph, entities, attributes, Generics.newHashSet());

        SceneGraph sceneGraph = new SceneGraph();
        for (BoWExample example : candidates) {
            String relation = this.classifier.classOf(new BasicDatum<>(example.extractFeatures(featureSets)));
            if (relation.equals(NONE_RELATION)) {
                continue;
            }
            if (relation.equals(IS_RELATION)) {
                sceneGraph.getOrAddNode(example.w1).addAttribute(example.w2);
            } else {
                sceneGraph.addEdge(sceneGraph.getOrAddNode(example.w1), sceneGraph.getOrAddNode(example.w2), relation);
            }
        }
        if (this.includeAllObjects || sceneGraph.nodeListSorted().isEmpty()) {
            for (IndexedWord entity : entities) {
                sceneGraph.getOrAddNode(entity);
            }
        }
        return sceneGraph;
    }

    /**
     * Reads scene graph images from a file with one JSON object per line; lines for which
     * {@code SceneGraphImage.readFromJSON} returns {@code null} are skipped.
     */
    public static List<SceneGraphImage> loadImages(String str) throws IOException {
        List<SceneGraphImage> images = Generics.newArrayList();
        try (BufferedReader reader = IOUtils.readerFromString(str)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                SceneGraphImage image = SceneGraphImage.readFromJSON(line);
                if (image != null) {
                    images.add(image);
                }
            }
        }
        return images;
    }

    /**
     * Builds a training set from the images in {@code path}: gold relations matched to the
     * sentence become positive examples, and every other candidate pair becomes a
     * {@link #NONE_RELATION} example.
     *
     * @param path file with one JSON image per line
     * @param balance if true and negatives outnumber positives, negatives are randomly subsampled
     *     to roughly the number of positives
     */
    public Dataset<String, String> getTrainingExamples(String path, boolean balance) throws IOException {
        Dataset<String, String> positives = new Dataset<>();
        Dataset<String, String> negatives = new Dataset<>();
        for (SceneGraphImage image : loadImages(path)) {
            for (SceneGraphImageRegion region : image.regions) {
                addRegionExamples(region, positives, negatives);
            }
        }
        if (balance && positives.size() < negatives.size()) {
            double keepFraction = (double) positives.size() / negatives.size();
            negatives = negatives.getRandomSubDataset(keepFraction, SUBSAMPLE_SEED);
        }
        positives.addAll(negatives);
        return positives;
    }

    /** Adds the positive and negative examples of a single image region to the given datasets. */
    private void addRegionExamples(SceneGraphImageRegion region,
                                   Dataset<String, String> positives,
                                   Dataset<String, String> negatives) {
        SemanticGraph graph = region.getEnhancedSemanticGraph();
        SemanticGraphEnhancer.processQuanftificationModifiers(graph);
        SemanticGraphEnhancer.collapseCompounds(graph);
        SemanticGraphEnhancer.collapseParticles(graph);
        SemanticGraphEnhancer.resolvePronouns(graph);

        // Gold pairs are never used as negatives, even when they could not become positives below.
        Set<Long> goldPairs = Generics.newHashSet();
        for (Triple<IndexedWord, IndexedWord, String> triple : this.sentenceMatcher.getRelationTriples(region)) {
            IndexedWord w1 = graph.getNodeByIndexSafe(triple.first.index());
            IndexedWord w2 = graph.getNodeByIndexSafe(triple.second.index());
            if (w1 != null && w2 != null && inScope(graph, w1, w2)) {
                this.entityClassifer.predictEntity(w1, this.embeddings);
                this.entityClassifer.predictEntity(w2, this.embeddings);
                positives.add(new BoWExample(w1, w2, graph).extractFeatures(featureSets), triple.third);
            }
            goldPairs.add(pairKey(triple.first.index(), triple.second.index()));
        }

        List<IndexedWord> entities = EntityExtractor.extractEntities(graph);
        List<IndexedWord> attributes = EntityExtractor.extractAttributes(graph);
        predictEntities(entities);
        predictEntities(attributes);
        for (BoWExample example : candidateExamples(graph, entities, attributes, goldPairs)) {
            negatives.add(example.extractFeatures(featureSets), NONE_RELATION);
        }
    }

    /**
     * Enumerates candidate pairs in a fixed order: first all ordered (entity, entity) pairs of
     * distinct indices, then all (entity, attribute) pairs. Pairs whose {@link #pairKey} is in
     * {@code excludedPairs} or that fail {@link #inScope} are skipped.
     */
    private List<BoWExample> candidateExamples(SemanticGraph graph,
                                               List<IndexedWord> entities,
                                               List<IndexedWord> attributes,
                                               Set<Long> excludedPairs) {
        List<BoWExample> examples = Generics.newArrayList();
        for (IndexedWord first : entities) {
            for (IndexedWord second : entities) {
                if (first.index() != second.index()
                        && !excludedPairs.contains(pairKey(first.index(), second.index()))
                        && inScope(graph, first, second)) {
                    examples.add(new BoWExample(first, second, graph));
                }
            }
        }
        for (IndexedWord entity : entities) {
            for (IndexedWord attribute : attributes) {
                if (!excludedPairs.contains(pairKey(entity.index(), attribute.index()))
                        && inScope(graph, entity, attribute)) {
                    examples.add(new BoWExample(entity, attribute, graph));
                }
            }
        }
        return examples;
    }

    /** True unless {@link #enforceSubtree} is set and the two words are not in the same subtree. */
    private boolean inScope(SemanticGraph graph, IndexedWord first, IndexedWord second) {
        return !this.enforceSubtree || SceneGraphUtils.inSameSubTree(graph, first, second);
    }

    /** Runs the entity classifier on each word. */
    private void predictEntities(List<IndexedWord> words) {
        for (IndexedWord word : words) {
            this.entityClassifer.predictEntity(word, this.embeddings);
        }
    }

    /**
     * Encodes an ordered pair of token indices as one collision-free key. The former
     * {@code (first << 4) + second} encoding collided as soon as an index reached 16.
     */
    private static long pairKey(int first, int second) {
        return ((long) first << 32) | (second & 0xFFFFFFFFL);
    }

    /**
     * Trains a relation classifier on the examples in {@code trainPath} (with negative
     * subsampling) and serializes it to {@code modelPath}.
     */
    public void train(String trainPath, String modelPath) throws IOException {
        LinearClassifierFactory<String, String> factory =
            new LinearClassifierFactory<>(new QNMinimizer(15), 1.0E-4d, false, REG_STRENGTH);
        Classifier<String, String> trained = factory.trainClassifier(getTrainingExamples(trainPath, true));
        IOUtils.writeObjectToFile(trained, modelPath);
    }

    /**
     * Command-line entry point. With {@code -train}, trains a model from {@code -input}. Otherwise
     * it parses either stdin interactively (no {@code -input}) or an evaluation file, and reports
     * macro-averaged F1.
     */
    public static void main(String[] strArr) throws IOException {
        Properties props = StringUtils.argsToProperties(strArr, numArgs);
        boolean train = PropertiesUtils.getBool(props, "train", false);
        boolean enforceSubtree = PropertiesUtils.getBool(props, "enforceSubtree", true);
        boolean includeAllObjects = PropertiesUtils.getBool(props, "includeAllObjects", false);
        boolean verbose = PropertiesUtils.getBool(props, "verbose", false);
        String modelPath = PropertiesUtils.getString(props, "model", DEFAULT_MODEL_PATH);
        String entityModelPath = PropertiesUtils.getString(props, "entityModel", DEFAULT_ENTITY_MODEL_PATH);
        String inputPath = PropertiesUtils.getString(props, "input", (String) null);
        String embeddingsPath = PropertiesUtils.getString(props, "embeddings", (String) null);

        // Training reads its examples from -input; there is no stdin fallback for it.
        boolean missingArgs = modelPath == null || entityModelPath == null || embeddingsPath == null
            || (train && inputPath == null);
        if (missingArgs) {
            System.err.printf("Usage java %s -model <model.gz> -entityModel <entityModel.gz> -embeddings <wordVectors.gz> [-input <input.json> -train -verbose -enforceSubtree -includeAllObjects -evalFilePrefix <run0>]%n", BoWSceneGraphParser.class.getCanonicalName());
            return;
        }

        Embedding embedding = new Embedding(embeddingsPath);
        if (train) {
            BoWSceneGraphParser trainer = new BoWSceneGraphParser(null, entityModelPath, embedding);
            trainer.enforceSubtree = enforceSubtree;
            trainer.train(inputPath, modelPath);
            return;
        }

        BoWSceneGraphParser parser = new BoWSceneGraphParser(modelPath, entityModelPath, embedding);
        parser.enforceSubtree = enforceSubtree;
        parser.includeAllObjects = includeAllObjects;
        if (inputPath == null) {
            runInteractive(parser);
        } else {
            String evalFilePrefix = PropertiesUtils.getString(props, "evalFilePrefix", (String) null);
            runEvaluation(parser, inputPath, evalFilePrefix, verbose);
        }
    }

    /** Parses one sentence per line from stdin until end of input. */
    private static void runInteractive(BoWSceneGraphParser parser) throws IOException {
        System.err.println("Processing from stdin. Enter one sentence per line.");
        System.err.print("> ");
        try (Scanner scanner = new Scanner(System.in)) {
            // nextLine() throws at EOF rather than returning null, so test hasNextLine() first.
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                System.err.println(parser.parse(line).toReadableString());
                System.err.println("------------------------");
                System.err.print("> ");
            }
        }
    }

    /**
     * Parses every region of every image in {@code inputPath}. Each predicted graph is printed as
     * JSON to stdout, and macro-averaged F1 is reported on stderr. When {@code evalFilePrefix} is
     * non-null, smatch files {@code <prefix>.smatch} and {@code <prefix>_gold.smatch} are written.
     */
    private static void runEvaluation(BoWSceneGraphParser parser, String inputPath,
                                      String evalFilePrefix, boolean verbose) throws IOException {
        SceneGraphEvaluation evaluation = new SceneGraphEvaluation();
        int regionCount = 0;
        double f1Sum = 0.0d;
        // try-with-resources skips null resources, so the optional smatch writers are safe here.
        try (BufferedReader reader = IOUtils.readerFromString(inputPath);
             PrintWriter predictedOut = evalFilePrefix == null ? null : IOUtils.getPrintWriter(evalFilePrefix + ".smatch");
             PrintWriter goldOut = evalFilePrefix == null ? null : IOUtils.getPrintWriter(evalFilePrefix + "_gold.smatch")) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                SceneGraphImage image = SceneGraphImage.readFromJSON(line);
                if (image == null) {
                    continue;
                }
                for (SceneGraphImageRegion region : image.regions) {
                    regionCount++;
                    SceneGraph parsed = parser.parse(region.getEnhancedSemanticGraph());
                    System.out.println(parsed.toJSON(image.id, image.url, region.phrase));
                    Triple<Double, Double, Double> scores = evaluation.evaluate(parsed, region);
                    if (evalFilePrefix != null) {
                        evaluation.toSmatchString(parsed, region, predictedOut, goldOut);
                    }
                    if (verbose) {
                        System.err.println(region.phrase);
                        System.err.println(parsed.toReadableString());
                        System.err.println(region.toReadableString());
                        System.err.printf("Prec: %f, Recall: %f, F1: %f%n", scores.first, scores.second, scores.third);
                        System.err.println("------------------------");
                    }
                    f1Sum += scores.third;
                }
            }
        }
        System.err.println("#########################################################");
        System.err.printf("Macro-averaged F1: %f%n", f1Sum / regionCount);
        System.err.println("#########################################################");
    }

    static {
        numArgs.put("model", 1);
        numArgs.put("entityModel", 1);
        numArgs.put("evalFilePrefix", 1);
        numArgs.put("input", 1);
    }
}
