package edu.stanford.nlp.scenegraph;

import edu.stanford.nlp.classify.KNNClassifier;
import edu.stanford.nlp.classify.KNNClassifierFactory;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.scenegraph.image.SceneGraphImage;
import edu.stanford.nlp.scenegraph.image.SceneGraphImageAttribute;
import edu.stanford.nlp.scenegraph.image.SceneGraphImageRegion;
import edu.stanford.nlp.scenegraph.image.SceneGraphImageRelationship;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/scenegraph/KNNSceneGraphParser.class */
/**
 * A retrieval baseline for scene graph parsing. Instead of parsing, it represents a region's tokens
 * as a bag of word counts and asks a nearest-neighbor classifier which training region is most
 * similar. That stored region (with its annotated attributes and relationships) is returned as the
 * prediction.
 *
 * <p>Command-line modes, see {@link #main(String[])}:
 *
 * <ul>
 *   <li>{@code <trainImages> <modelOut> -train}: builds a classifier from every region of every image
 *       in {@code trainImages} (one JSON image per line) and serializes it to {@code modelOut}.
 *   <li>{@code <testImages> <model> <trainImages> <smatchOut1> <smatchOut2>}: predicts a region for
 *       every test region. It writes Smatch strings to the two output files, prints each gold region
 *       as a stand-alone JSON image to stdout, and prints per-region and averaged scores to stderr.
 * </ul>
 */
public class KNNSceneGraphParser extends AbstractSceneGraphParser {

    /**
     * Nearest-neighbor classifier whose labels have the form {@code <imageId>_<regionIndex>}. It is
     * {@code null} when the parser was constructed without a model path.
     */
    KNNClassifier<String, String> classifier;

    /**
     * Creates a parser, optionally loading a serialized classifier.
     *
     * @param str path of a model written by the {@code -train} mode, or {@code null} to create a
     *     parser without a classifier (as the training mode does)
     * @throws IllegalArgumentException if the model cannot be read or deserialized; the underlying
     *     exception is preserved as the cause
     */
    @SuppressWarnings("unchecked") // the serialized model is written by train() with String labels/features
    public KNNSceneGraphParser(String str) {
        if (str != null) {
            try {
                this.classifier = (KNNClassifier<String, String>) IOUtils.readObjectFromFile(str);
            } catch (IOException | ClassNotFoundException e) {
                // Fail fast. Merely printing the stack trace left a null classifier that later
                // surfaced as an unexplained NullPointerException inside parse().
                throw new IllegalArgumentException("Could not load KNN classifier from " + str, e);
            }
        }
    }

    /**
     * Not supported by this retrieval baseline, which works on token lists.
     *
     * @return always {@code null}
     */
    @Override // edu.stanford.nlp.scenegraph.AbstractSceneGraphParser
    public SceneGraph parse(SemanticGraph semanticGraph) {
        return null;
    }

    /**
     * Retrieves the training region whose token counts are nearest to {@code list}.
     *
     * @param list tokens of the region to "parse"
     * @param map training images keyed by image id, as returned by {@code loadImages}
     * @return the nearest training region, or {@code null} if the predicted image id is not in
     *     {@code map}
     * @throws IllegalStateException if this parser was constructed without a model
     * @throws IOException declared for API compatibility
     */
    public SceneGraphImageRegion parse(List<CoreLabel> list, Map<Integer, SceneGraphImage> map) throws IOException {
        if (this.classifier == null) {
            throw new IllegalStateException("No KNN classifier loaded; construct the parser with a model path");
        }
        String label = this.classifier.classOf(new RVFDatum<String, String>(tokenCounts(list)));
        // Labels are produced by train() as "<imageId>_<regionIndex>".
        String[] parts = label.split("_");
        int imageId = Integer.parseInt(parts[0]);
        int regionIndex = Integer.parseInt(parts[1]);
        SceneGraphImage image = map.get(imageId);
        if (image == null) {
            return null;
        }
        return image.regions.get(regionIndex);
    }

    /** Builds the bag-of-words feature vector: how often each token's word form occurs. */
    private static ClassicCounter<String> tokenCounts(Iterable<CoreLabel> tokens) {
        ClassicCounter<String> counts = new ClassicCounter<>();
        for (CoreLabel token : tokens) {
            counts.incrementCount(token.word());
        }
        return counts;
    }

    /**
     * Reads one JSON-encoded image per line. Lines that {@code SceneGraphImage.readFromJSON} rejects
     * (returns {@code null} for) are skipped.
     *
     * @param str file (or other location accepted by {@code IOUtils.readerFromString}) to read
     * @return images keyed by their id; a later image with the same id replaces an earlier one
     * @throws IOException if reading fails
     */
    private Map<Integer, SceneGraphImage> loadImages(String str) throws IOException {
        Map<Integer, SceneGraphImage> images = Generics.newHashMap();
        try (BufferedReader reader = IOUtils.readerFromString(str)) {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                SceneGraphImage image = SceneGraphImage.readFromJSON(line);
                if (image != null) {
                    images.put(image.id, image);
                }
            }
        }
        return images;
    }

    /**
     * Trains a nearest-neighbor classifier over every region of every training image and serializes
     * it.
     *
     * @param str training images, one JSON image per line
     * @param str2 output path for the serialized classifier
     * @throws IOException if reading the images or writing the model fails
     */
    private void train(String str, String str2) throws IOException {
        Map<Integer, SceneGraphImage> images = loadImages(str);
        // Arguments as in the original: (1, false, false) -- presumably k = 1 with unweighted votes and
        // no vector normalization; confirm against KNNClassifierFactory.
        KNNClassifierFactory<String, String> factory = new KNNClassifierFactory<String, String>(1, false, false);
        List<RVFDatum<String, String>> data = Generics.newArrayList();
        for (SceneGraphImage image : images.values()) {
            int regionCount = image.regions.size();
            for (int i = 0; i < regionCount; i++) {
                SceneGraphImageRegion region = image.regions.get(i);
                // String concatenation keeps the label locale-independent, unlike String.format("%d").
                String label = image.id + "_" + i;
                data.add(new RVFDatum<String, String>(tokenCounts(region.tokens), label));
            }
        }
        IOUtils.writeObjectToFile(factory.train(data), str2);
    }

    /**
     * Copies {@code region} into a stand-alone single-region image. The copy contains the source
     * image's id, url and size, the objects that the region's attributes and relationships refer to,
     * and clones of those attributes and relationships. The clones point at a fresh copy of the
     * region and at the new image.
     */
    private static SceneGraphImage regionAsImage(SceneGraphImage image, SceneGraphImageRegion region) {
        SceneGraphImage result = new SceneGraphImage();
        result.id = image.id;
        result.url = image.url;
        result.height = image.height;
        result.width = image.width;

        // Indices into image.objects of every object the region mentions.
        Set<Integer> objectIndices = Generics.newHashSet();
        for (SceneGraphImageAttribute attribute : region.attributes) {
            objectIndices.add(image.objects.indexOf(attribute.subject));
        }
        for (SceneGraphImageRelationship relationship : region.relationships) {
            objectIndices.add(image.objects.indexOf(relationship.subject));
            objectIndices.add(image.objects.indexOf(relationship.object));
        }
        result.objects = Generics.newArrayList();
        for (int index : objectIndices) {
            result.objects.add(image.objects.get(index));
        }

        SceneGraphImageRegion regionCopy = new SceneGraphImageRegion();
        regionCopy.phrase = region.phrase;
        regionCopy.x = region.x;
        regionCopy.y = region.y;
        regionCopy.h = region.h;
        regionCopy.w = region.w;
        regionCopy.attributes = Generics.newHashSet();
        regionCopy.relationships = Generics.newHashSet();
        result.regions = Generics.newArrayList();
        result.regions.add(regionCopy);

        result.attributes = Generics.newLinkedList();
        for (SceneGraphImageAttribute attribute : region.attributes) {
            SceneGraphImageAttribute copy = attribute.m11clone();
            copy.region = regionCopy;
            copy.image = result;
            result.addAttribute(copy);
        }
        result.relationships = Generics.newLinkedList();
        for (SceneGraphImageRelationship relationship : region.relationships) {
            SceneGraphImageRelationship copy = relationship.m12clone();
            copy.image = result;
            copy.region = regionCopy;
            result.addRelationship(copy);
        }
        return result;
    }

    /**
     * Entry point; see the class documentation for the two argument layouts.
     *
     * @param strArr command-line arguments
     * @throws IOException if any input cannot be read or output cannot be written
     */
    public static void main(String[] strArr) throws IOException {
        if (strArr.length >= 3 && strArr[2].equals("-train")) {
            new KNNSceneGraphParser(null).train(strArr[0], strArr[1]);
            return;
        }
        if (strArr.length < 5) {
            System.err.println("Usage: KNNSceneGraphParser <trainImages> <modelOut> -train");
            System.err.println("   or: KNNSceneGraphParser <testImages> <model> <trainImages> <smatchOut1> <smatchOut2>");
            return;
        }
        KNNSceneGraphParser parser = new KNNSceneGraphParser(strArr[1]);
        Map<Integer, SceneGraphImage> trainImages = parser.loadImages(strArr[2]);
        double regionCount = 0.0d;
        double f1Sum = 0.0d;
        // try-with-resources: the Smatch writers must be closed, or buffered output may never reach disk.
        try (BufferedReader reader = IOUtils.readerFromString(strArr[0]);
             PrintWriter smatchOut1 = IOUtils.getPrintWriter(strArr[3]);
             PrintWriter smatchOut2 = IOUtils.getPrintWriter(strArr[4])) {
            SceneGraphEvaluation evaluation = new SceneGraphEvaluation();
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                SceneGraphImage image = SceneGraphImage.readFromJSON(line);
                if (image == null) {
                    // Same policy as loadImages(): skip lines that do not decode to an image.
                    continue;
                }
                for (SceneGraphImageRegion goldRegion : image.regions) {
                    regionCount += 1.0d;
                    SceneGraphImageRegion predicted = parser.parse(goldRegion.tokens, trainImages);
                    Triple<Double, Double, Double> scores = evaluation.evaluate(predicted, goldRegion);
                    evaluation.toSmatchString(predicted, goldRegion, smatchOut1, smatchOut2);
                    System.out.println(regionAsImage(image, goldRegion).toJSON());
                    System.err.printf("Prec: %f, Recall: %f, F1: %f%n", scores.first, scores.second, scores.third);
                    f1Sum += scores.third;
                }
            }
        }
        // Mean of the per-region F1 scores (NaN if no regions were evaluated).
        System.err.println("#########################################################");
        System.err.printf("Macro-averaged F1: %f%n", f1Sum / regionCount);
        System.err.println("#########################################################");
    }
}
