# Gosse Minnema
# Initial commit
# 05922fb
import json
import os
import copy
from collections import defaultdict
from argparse import ArgumentParser
from tqdm import tqdm
import random
from tqdm import tqdm
from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor
# Command-line interface: the evaluation corpus, the model to load, and an
# optional ontology mapping file handed to read_kairos.
parser = ArgumentParser()
parser.add_argument('aida', type=str, help='path to the JSON evaluation corpus')
parser.add_argument('model', type=str, help='path passed to SpanPredictor.from_path')
# nargs='?' is what makes `default=None` take effect: argparse treats a plain
# positional argument as mandatory and never applies its default.
parser.add_argument('fn2kairos', type=str, nargs='?', default=None,
                    help='optional mapping file passed to read_kairos')
parser.add_argument('--device', type=int, default=3, help='CUDA device index for the predictor')
args = parser.parse_args()

# Close the corpus file promptly; JSON text is UTF-8 by specification, so do
# not depend on the platform's default encoding.
with open(args.aida, encoding='utf-8') as corpus_file:
    corpus = json.load(corpus_file)
# Without a mapping file, predictions are later requested with
# ontology_mapping=None.
# NOTE(review): assumes SpanPredictor.predict_batch_sentences accepts None for
# ontology_mapping — confirm against the sftp API.
mapping = read_kairos(args.fn2kairos) if args.fn2kairos is not None else None
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
# Fixed seed so the shuffled evaluation order is reproducible between runs.
random.seed(42)
random.shuffle(corpus)
batch_size = 128


def batchify(a_list, size=None):
    """Yield consecutive chunks of ``a_list`` as lists of at most ``size`` items.

    Args:
        a_list: Any iterable; it is consumed once, in order.
        size: Maximum chunk length. Defaults to the module-level
            ``batch_size``, looked up at call time.

    Yields:
        Lists of exactly ``size`` items, then one shorter list holding any
        remainder. Nothing is yielded for an empty input.

    Raises:
        ValueError: if ``size`` is smaller than 1. Because this is a
            generator, the error surfaces on the first iteration.
    """
    if size is None:
        size = batch_size
    if size < 1:
        # A non-positive size would otherwise never trigger a flush and
        # silently return the whole input as a single batch.
        raise ValueError(f'batch size must be >= 1, got {size}')
    cur = []
    for item in a_list:
        cur.append(item)
        if len(cur) == size:
            yield cur
            cur = []
    # Flush the trailing partial batch, if any.
    if cur:
        yield cur
# Evaluate batch by batch. Each corpus item carries a single gold annotation
# ('annotation' with 'start_idx' / 'end_idx' / 'label'); a hit is counted for
# every predicted span whose boundaries equal the gold span.
batches = list(batchify(corpus))
n_total = n_pos = n_span_match = 0
# Iterate the list directly rather than enumerate(...) so tqdm can read its
# length and show a real progress bar; the batch index was never used.
for lines in tqdm(batches):
    # The final batch may be shorter than batch_size: count the items actually
    # evaluated so the denominators below are exact.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for pred_line, gold_line in zip(prediction_lines, lines):
        gold = gold_line['annotation']
        for pred in pred_line['prediction']:
            if pred['start_idx'] == gold['start_idx'] and pred['end_idx'] == gold['end_idx']:
                n_span_match += 1
                if pred['label'] == gold['label']:
                    n_pos += 1
# NOTE(review): scores are hits per gold item; if the model can emit the same
# span more than once, a single item may contribute several hits — confirm.
if n_total == 0:
    # Avoid ZeroDivisionError on an empty corpus.
    print('no annotations to evaluate')
else:
    print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
    print(f'span only precision: {n_span_match * 100 / n_total:.3f}')