"""Evaluate a SpanPredictor against a shuffled AIDA-style JSON corpus.

Each corpus item is read as a dict with ``tokens`` and a single gold
``annotation`` dict carrying ``start_idx``/``end_idx``/``label`` (schema
inferred from the lookups below -- confirm against the data producer).
A prediction whose span equals the gold span counts as a span match; one that
also agrees on the label counts as an exact match. Both counts are reported as
percentages of the number of corpus items evaluated.
"""
import json
import os
import copy
from collections import defaultdict
from argparse import ArgumentParser
import random

from tqdm import tqdm

from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor

parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
# NOTE(review): ``default`` has no effect on a required positional argument;
# it would only apply with ``nargs='?'``. Left as-is to keep the CLI unchanged.
parser.add_argument('fn2kairos', type=str, default=None)
parser.add_argument('--device', type=int, default=3)
args = parser.parse_args()

# Close the corpus file promptly; JSON text is UTF-8 by specification, so do
# not depend on the platform's locale encoding.
with open(args.aida, encoding='utf-8') as corpus_file:
    corpus = json.load(corpus_file)
mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Fixed seed so the shuffle (and thus batch composition) is reproducible.
random.seed(42)
random.shuffle(corpus)
batch_size = 128


def batchify(a_list, size=None):
    """Yield consecutive chunks of ``a_list`` holding at most ``size`` items.

    ``size`` defaults to the module-level ``batch_size`` (read at call time).
    The final chunk may be shorter; an empty chunk is never yielded.
    """
    if size is None:
        size = batch_size
    cur = list()
    for item in a_list:
        cur.append(item)
        if len(cur) == size:
            yield cur
            cur = list()
    if cur:
        yield cur


def _percent(count, total):
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return count * 100 / total if total else 0.0


batches = list(batchify(corpus))

n_total = n_pos = n_span_match = 0
# Iterating the list directly (no enumerate) lets tqdm show a total.
for lines in tqdm(batches):
    # Count the items actually in this batch: the last batch may be partial,
    # and adding ``batch_size`` unconditionally would inflate the denominator.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    # Assumes predictions are returned in input order -- TODO confirm.
    for preds, ann in zip(prediction_lines, lines):
        ann = ann['annotation']
        preds = preds['prediction']
        for pred in preds:
            if pred['start_idx'] == ann['start_idx'] and pred['end_idx'] == ann['end_idx']:
                n_span_match += 1
                if pred['label'] == ann['label']:
                    n_pos += 1

# NOTE(review): the denominator is the number of gold annotations (one per
# item), so these figures behave more like recall than precision; the printed
# labels are kept unchanged for compatibility with existing log parsing.
print(f'exact match precision: {_percent(n_pos, n_total):.3f}')
print(f'span only precision: {_percent(n_span_match, n_total):.3f}')