# Force-decode frame label distributions for annotated spans in an AIDA corpus
# and write the top-k frames alongside each span's KAIROS label.
import copy
import json
import os
import random
from argparse import ArgumentParser
from collections import defaultdict

from tqdm import tqdm

from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor

parser = ArgumentParser()
parser.add_argument('aida', type=str)      # path to the AIDA corpus (JSON list of annotated sentences)
parser.add_argument('model', type=str)     # path to the trained SpanPredictor model archive
parser.add_argument('dst', type=str)       # output file; predictions are appended as JSON lines
parser.add_argument('--topk', type=int, default=10)   # number of top-scoring frames to keep per span
parser.add_argument('--device', type=int, default=0)  # CUDA device index
args = parser.parse_args()
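# Example invocation (the script and file names below are placeholders):
#   python force_decode.py aida_corpus.json model.tar.gz predictions.jsonl --topk 10 --device 0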

k = args.topk
corpus = json.load(open(args.aida))
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Mapping from label indices to frame names, used to decode the model's output distribution.
idx2fn = predictor._model.vocab.get_index_to_token_vocabulary('span_label')

# Shuffle deterministically so repeated runs process the corpus in the same order.
random.seed(42)
random.shuffle(corpus)

# Open in append mode so any existing predictions in dst are preserved.
output_fp = open(args.dst, 'a')

for line in tqdm(corpus):
    tokens, ann = line['tokens'], line['annotation']
    start, end, kairos_label = ann['start_idx'], ann['end_idx'], ann['label']
    # Force the predictor to score every frame label for the annotated span.
    prob_dist = predictor.force_decode(tokens, [(start, end)])[0]
    topk_indices = prob_dist.argsort(descending=True)[:k]
    prob = prob_dist[topk_indices].tolist()
    frames = [(idx2fn[int(idx)], p) for idx, p in zip(topk_indices, prob)]
    # One JSON object per line: the sentence tokens, the top-k (frame, probability)
    # pairs, and the original KAIROS label of the span.
    output_fp.write(json.dumps({
        'tokens': tokens,
        'frames': frames,
        'kairos': kairos_label,
    }) + '\n')
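
# Make sure buffered predictions reach disk before the script exits.
output_fp.close()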