Gosse Minnema
Re-enable LOME
2890e34
from typing import *
import torch
import json
import argparse
import os
from tqdm import tqdm
from sftp.predictor import SpanPredictor
from sftp.models import SpanModel
from sftp.data_reader import BetterDatasetReader
def predict_doc(predictor, json_path: str) -> dict:
    """Annotate every entry of a JSON document with predicted trigger spans.

    The file at ``json_path`` is parsed as JSON and each value under its
    top-level ``'entries'`` mapping is passed to ``predictor.predict_json``.
    The items of the returned ``'prediction'`` list are condensed into
    ``{"span": [start_idx, end_idx], "argument": [[start_idx, end_idx], ...]}``
    records (one ``argument`` pair per item in the trigger's ``'children'``)
    and stored on the entry under the key ``'trigger span'``.

    Args:
        predictor: Object exposing ``predict_json(entry)``; its result must
            contain a ``'prediction'`` list whose items carry ``'start_idx'``,
            ``'end_idx'`` and ``'children'``.
        json_path: Path of the JSON file to read.

    Returns:
        The parsed document, with every entry updated in place.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(json_path, encoding='utf-8') as fp:
        src = json.load(fp)

    # Only the entry payloads are needed; the document names are unused.
    for entry in tqdm(src['entries'].values()):
        pred = predictor.predict_json(entry)
        entry['trigger span'] = [
            {
                "span": [trigger['start_idx'], trigger['end_idx']],
                "argument": [
                    [child['start_idx'], child['end_idx']]
                    for child in trigger['children']
                ],
            }
            for trigger in pred['prediction']
        ]
    return src
if __name__ == '__main__':
    # Command-line entry point: load a span predictor from an archive
    # directory, run it over every matching file found under the source
    # directory, and write the annotated JSON documents to
    # <destination>/<archive directory name>/.
    parser = argparse.ArgumentParser()
    # The three paths are mandatory; without them the script would only fail
    # later with an opaque TypeError when joining None into a path.
    parser.add_argument('-a', type=str, required=True, help='archive path')
    parser.add_argument('-s', type=str, required=True, help='source path')
    parser.add_argument('-d', type=str, required=True, help='destination path')
    parser.add_argument('-c', type=int, default=0, help='cuda device')
    args = parser.parse_args()

    predictor_ = SpanPredictor.from_path(
        os.path.join(args.a, 'model.tar.gz'), 'span', cuda_device=args.c
    )

    # normpath drops any trailing separator; otherwise basename('model/')
    # is '' and the output would land directly in the destination root.
    model_name = os.path.basename(os.path.normpath(args.a))
    tgt_path = os.path.join(args.d, model_name)
    os.makedirs(tgt_path, exist_ok=True)

    for root, _, files in os.walk(args.s):
        for fn in files:
            # Only files whose names end in 'json' or 'valid' are processed.
            if not fn.endswith(('json', 'valid')):
                continue
            processed_json = predict_doc(predictor_, os.path.join(root, fn))
            # NOTE: outputs are written flat into tgt_path, so files with the
            # same name in different source subdirectories overwrite each
            # other.
            with open(os.path.join(tgt_path, fn), 'w', encoding='utf-8') as fp:
                json.dump(processed_json, fp)