ssa-perin / utility /predict.py
larkkin's picture
Add code
991f07c
raw
history blame
1.78 kB
import os
import json
import torch
import sys
from subprocess import run
from data.batch import Batch
sys.path.append("../evaluation")
from evaluate_single_dataset import evaluate
def predict(model, data, input_path, raw_input_path, args, logger, output_directory, device, mode="validation", epoch=None):
    """Run inference over ``data``, write predictions to disk and optionally score them.

    The file at ``input_path`` is read one JSON object per line. Each object has
    ``nodes``/``edges``/``tops`` reset to empty lists, is tagged with
    ``args.framework``/``args.language`` and is indexed by its ``"id"``. Every
    prediction dict produced by the model is merged into the object with the
    same ``"id"``, and all objects are written back out one JSON object per line.

    The written file is passed to ``./convert.sh`` (with ``--node_centric`` unless
    ``args.graph_mode == "labeled-edge"``); the file ``<output_path>_converted``
    is then evaluated against ``raw_input_path``.

    Args:
        model: model object; ``model.eval()`` is called on it (and not undone),
            then it is called as ``model(batch, inference=True)`` and must return
            an iterable of dicts that each carry an ``"id"`` key present in the
            input file.
        data: iterable of batches; each is passed through ``Batch.to(batch, device)``.
        input_path: line-delimited JSON file whose objects serve as output templates.
        raw_input_path: file handed to ``evaluate``; falsy to skip evaluation.
        args: namespace providing ``framework``, ``language`` and ``graph_mode``.
        logger: object with ``log_evaluation(results, mode, epoch)``, or ``None``.
        output_directory: directory that receives the prediction file.
        device: forwarded to ``Batch.to``.
        mode: label used in the output filename and in the printed results.
        epoch: when not ``None``, embedded in the output filename; otherwise the
            file is named ``prediction.json``.

    Returns:
        ``results["sentiment_tuple/f1"]`` when ``raw_input_path`` is truthy,
        otherwise ``None``.

    Raises:
        subprocess.CalledProcessError: if ``./convert.sh`` exits with a non-zero status.
    """
    model.eval()

    framework, language = args.framework, args.language

    # Load the input objects as templates; their graph fields are cleared so the
    # output only contains what the model predicts.
    sentences = {}
    with open(input_path, encoding="utf8") as f:
        for raw_line in f:
            if not raw_line.strip():
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # failing inside json.loads.
                continue
            sentence = json.loads(raw_line)
            sentence["nodes"], sentence["edges"], sentence["tops"] = [], [], []
            sentence["framework"], sentence["language"] = framework, language
            sentences[sentence["id"]] = sentence

    # One no_grad context for the whole inference loop; every predicted field
    # overwrites the corresponding field of its template.
    with torch.no_grad():
        for batch in data:
            predictions = model(Batch.to(batch, device), inference=True)
            for prediction in predictions:
                sentences[prediction["id"]].update(prediction)

    if epoch is not None:
        filename = f"prediction_{mode}_{epoch}_{framework}_{language}.json"
    else:
        filename = "prediction.json"
    output_path = os.path.join(output_directory, filename)

    # Leaving the with-block closes (and thus flushes) the file before
    # convert.sh reads it, so no per-line flush is needed.
    with open(output_path, "w", encoding="utf8") as f:
        for sentence in sentences.values():
            json.dump(sentence, f, ensure_ascii=False)
            f.write("\n")

    # List-form argv passes each element verbatim, so the flag must not carry
    # trailing whitespace.
    convert_command = ["./convert.sh", output_path]
    if args.graph_mode != "labeled-edge":
        convert_command.append("--node_centric")
    # check=True: fail loudly rather than evaluate a missing or stale
    # "<output_path>_converted" file.
    run(convert_command, check=True)

    if not raw_input_path:
        return None

    results = evaluate(raw_input_path, f"{output_path}_converted")
    print(mode, results, flush=True)

    if logger is not None:
        logger.log_evaluation(results, mode, epoch)

    return results["sentiment_tuple/f1"]