import copy
import json
import os
import random
from argparse import ArgumentParser
from collections import defaultdict

from tqdm import tqdm

from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor
# Command-line interface for evaluating a span-prediction model on an AIDA corpus.
parser = ArgumentParser(description='Evaluate a SpanPredictor model on an AIDA corpus.')
parser.add_argument('aida', type=str, help='Path to the AIDA corpus JSON file.')
parser.add_argument('model', type=str, help='Path to the serialized SpanPredictor model.')
# BUG FIX: a positional argument's `default` is ignored unless nargs='?' makes the
# argument optional -- as originally written, fn2kairos was required and the
# default=None was dead code. nargs='?' makes it genuinely optional.
parser.add_argument('fn2kairos', type=str, nargs='?', default=None,
                    help='Optional FrameNet-to-KAIROS ontology mapping file.')
parser.add_argument('--device', type=int, default=3, help='CUDA device index for inference.')
args = parser.parse_args()
# Load the evaluation corpus. Use a context manager so the file descriptor is
# closed promptly (the original `json.load(open(...))` leaked the handle).
with open(args.aida) as corpus_file:
    corpus = json.load(corpus_file)
# Optional FrameNet -> KAIROS ontology mapping (may be None).
mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)

# Fixed seed so the shuffle -- and therefore batch composition -- is reproducible.
random.seed(42)
random.shuffle(corpus)
batch_size = 128
def batchify(a_list, size=None):
    """Yield consecutive chunks of *a_list*; the last chunk may be shorter.

    Args:
        a_list: Any iterable of items to group.
        size: Chunk length. Defaults to the module-level ``batch_size`` so
            existing callers keep their original behavior.

    Yields:
        Lists of up to ``size`` consecutive items; yields nothing for empty input.
    """
    if size is None:
        size = batch_size  # late binding preserves the original global-based behavior
    cur = []
    for item in a_list:
        cur.append(item)
        if len(cur) == size:
            yield cur
            cur = []
    if cur:  # emit the final partial chunk
        yield cur
batches = list(batchify(corpus))
n_total = n_pos = n_span_match = 0
# Iterating the list directly (not enumerate) lets tqdm infer the total;
# the batch index was never used.
for lines in tqdm(batches):
    # BUG FIX: the original did `n_total += batch_size`, over-counting the
    # (usually shorter) final batch and deflating both precision figures.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for preds, ann in zip(prediction_lines, lines):
        ann = ann['annotation']
        preds = preds['prediction']
        for pred in preds:
            # A span match requires identical boundaries; an exact match
            # additionally requires the predicted label to agree.
            if pred['start_idx'] == ann['start_idx'] and pred['end_idx'] == ann['end_idx']:
                n_span_match += 1
                if pred['label'] == ann['label']:
                    n_pos += 1
print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
print(f'span only precision: {n_span_match * 100 / n_total:.3f}')