File size: 1,777 Bytes
991f07c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import json
import torch
import sys

from subprocess import run
from data.batch import Batch

sys.path.append("../evaluation")
from evaluate_single_dataset import evaluate


def predict(model, data, input_path, raw_input_path, args, logger, output_directory, device, mode="validation", epoch=None):
    """Run inference over ``data``, write predictions to disk, convert them and optionally evaluate.

    Each record of the JSON Lines file at ``input_path`` has its ``nodes``,
    ``edges`` and ``tops`` reset to empty lists and is tagged with
    ``args.framework`` / ``args.language``; model predictions are then merged
    into the record with the matching ``"id"``.

    Args:
        model: called as ``model(batch, inference=True)``; must return an
            iterable of dicts, each with an ``"id"`` present in ``input_path``.
        data: iterable of batches; each is moved to ``device`` via ``Batch.to``.
        input_path: JSON Lines file, one record (with an ``"id"`` key) per line.
        raw_input_path: file passed as the first argument to ``evaluate``;
            when falsy, conversion still runs but evaluation is skipped.
        args: namespace providing ``framework``, ``language`` and ``graph_mode``.
        logger: ``None`` or an object with ``log_evaluation(results, mode, epoch)``.
        output_directory: directory the prediction file is written to.
        device: device the batches are moved to.
        mode: label used in the output file name and the evaluation log.
        epoch: when not ``None``, included in the output file name; otherwise
            the output is written to ``prediction.json``.

    Returns:
        ``results["sentiment_tuple/f1"]`` from ``evaluate`` when
        ``raw_input_path`` is set, otherwise ``None``.
    """
    model.eval()

    framework, language = args.framework, args.language

    # Reset graph fields so the written file only contains what the model predicts.
    sentences = {}
    with open(input_path, encoding="utf8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines instead of crashing json.loads
            sentence = json.loads(line)
            sentence["nodes"], sentence["edges"], sentence["tops"] = [], [], []
            sentence["framework"], sentence["language"] = framework, language
            sentences[sentence["id"]] = sentence

    with torch.no_grad():
        for batch in data:
            predictions = model(Batch.to(batch, device), inference=True)
            for prediction in predictions:
                # Overwrites/extends the stored record with every predicted field.
                sentences[prediction["id"]].update(prediction)

    if epoch is not None:
        output_path = f"{output_directory}/prediction_{mode}_{epoch}_{framework}_{language}.json"
    else:
        output_path = f"{output_directory}/prediction.json"

    with open(output_path, "w", encoding="utf8") as f:
        for sentence in sentences.values():
            json.dump(sentence, f, ensure_ascii=False)
            f.write("\n")

    # List-form ``run`` passes each argument verbatim, so the flag must not carry
    # trailing whitespace. convert.sh presumably writes "<output_path>_converted",
    # which is the file evaluated below — confirm against the script.
    convert_command = ["./convert.sh", output_path]
    if args.graph_mode != "labeled-edge":
        convert_command.append("--node_centric")
    run(convert_command)

    if not raw_input_path:
        return None

    results = evaluate(raw_input_path, f"{output_path}_converted")
    print(mode, results, flush=True)

    if logger is not None:
        logger.log_evaluation(results, mode, epoch)

    return results["sentiment_tuple/f1"]