Spaces:
Build error
Build error
from typing import * | |
import torch | |
import json | |
import argparse | |
import os | |
from tqdm import tqdm | |
from sftp.predictor import SpanPredictor | |
from sftp.models import SpanModel | |
from sftp.data_reader import BetterDatasetReader | |
def predict_doc(predictor, json_path: str):
    """Annotate every entry of a JSON document with predicted trigger spans.

    The file at ``json_path`` must contain an object with an ``'entries'``
    mapping. Each entry is passed to ``predictor.predict_json`` and the
    ``'prediction'`` list of the result is condensed into
    ``entry['trigger span']``: one dict per trigger holding its ``"span"``
    (``[start_idx, end_idx]``) and ``"argument"`` (the
    ``[start_idx, end_idx]`` pair of each of its children).

    Args:
        predictor: Object exposing ``predict_json(entry)``.
        json_path: Path of the JSON file to read.

    Returns:
        The loaded document, with every entry updated in place.
    """
    # Use a context manager so the handle is closed deterministically, and
    # decode as UTF-8 instead of depending on the locale's default encoding.
    with open(json_path, encoding='utf-8') as fp:
        src = json.load(fp)

    # Only the entry values are needed; the document names are not used.
    for entry in tqdm(src['entries'].values()):
        pred = predictor.predict_json(entry)
        entry['trigger span'] = [
            {
                "span": [trigger['start_idx'], trigger['end_idx']],
                "argument": [
                    [child['start_idx'], child['end_idx']]
                    for child in trigger['children']
                ],
            }
            for trigger in pred['prediction']
        ]
    return src
if __name__ == '__main__':
    # Command-line entry point: load a predictor from an archive directory,
    # run it over every matching file found under the source directory, and
    # write the annotated documents to <destination>/<archive name>/.
    parser = argparse.ArgumentParser()
    # The three paths are mandatory: without them os.path.join() below would
    # fail with an opaque TypeError on None, so let argparse report it.
    parser.add_argument('-a', type=str, required=True, help='archive path')
    parser.add_argument('-s', type=str, required=True, help='source path')
    parser.add_argument('-d', type=str, required=True, help='destination path')
    parser.add_argument('-c', type=int, default=0, help='cuda device')
    args = parser.parse_args()

    predictor_ = SpanPredictor.from_path(
        os.path.join(args.a, 'model.tar.gz'), 'span', cuda_device=args.c
    )

    # Normalise first so that a trailing separator ("models/foo/") still
    # yields "foo"; basename() alone would return an empty string and the
    # outputs would land directly in the destination directory.
    model_name = os.path.basename(os.path.normpath(args.a))
    tgt_path = os.path.join(args.d, model_name)
    os.makedirs(tgt_path, exist_ok=True)

    for root, _, files in os.walk(args.s):
        for fn in files:
            # Only files whose names end in "json" or "valid" are processed.
            if not fn.endswith(('json', 'valid')):
                continue
            processed_json = predict_doc(predictor_, os.path.join(root, fn))
            # NOTE: outputs are written flat into tgt_path, so files with the
            # same name in different source sub-directories overwrite each
            # other.
            with open(os.path.join(tgt_path, fn), 'w') as fp:
                json.dump(processed_json, fp)