Spaces:
Sleeping
Sleeping
import json | |
import os | |
import copy | |
from collections import defaultdict | |
from argparse import ArgumentParser | |
from tqdm import tqdm | |
import random | |
from tqdm import tqdm | |
from scripts.predict_concrete import read_kairos | |
from sftp import SpanPredictor | |
# Command-line interface: three positional arguments plus the device index
# that is handed to the predictor as ``cuda_device``.
parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
# NOTE(review): ``default`` has no effect on a required positional argument
# (argparse only honours it together with ``nargs='?'``); kept as-is so the
# command line stays unchanged.
parser.add_argument('fn2kairos', type=str, default=None)
parser.add_argument('--device', type=int, default=3)
args = parser.parse_args()

# Load the JSON corpus; the context manager closes the file handle as soon
# as parsing is done instead of leaking it for the lifetime of the process.
with open(args.aida) as corpus_file:
    corpus = json.load(corpus_file)

mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)

# Fixed seed so the shuffled order (and hence the batch composition) is
# reproducible between runs.
# Assumes the corpus JSON is a top-level list -- TODO confirm.
random.seed(42)
random.shuffle(corpus)

# Number of examples sent to the predictor per call.
batch_size = 128
def batchify(a_list, size=None):
    """Yield consecutive chunks of ``a_list`` as lists.

    Every yielded list holds exactly ``size`` items except possibly the
    last one, which holds the remainder. Nothing is yielded for an empty
    input.

    Args:
        a_list: Any iterable of items to group.
        size: Maximum number of items per chunk. When ``None`` (the
            default) the module-level ``batch_size`` is used, which is the
            behaviour callers that pass a single argument rely on.

    Yields:
        Lists of items in their original order.

    Raises:
        ValueError: If the effective chunk size is smaller than 1. Raised
            on first iteration, since this is a generator.
    """
    if size is None:
        size = batch_size
    if size < 1:
        # A non-positive size would never satisfy the length check below
        # and silently collapse the whole input into one chunk.
        raise ValueError(f'batch size must be positive, got {size}')

    cur = []
    for item in a_list:
        cur.append(item)
        if len(cur) == size:
            yield cur
            # Start a fresh list; the yielded one now belongs to the caller.
            cur = []
    if cur:
        # Trailing partial chunk.
        yield cur
# Materialise the batches up front so the progress bar knows the total.
batches = list(batchify(corpus))

# n_total:      number of corpus examples evaluated
# n_span_match: predictions whose start/end indices equal the annotation's
# n_pos:        span matches whose label also equals the annotation's
n_total = n_pos = n_span_match = 0

for lines in tqdm(batches):
    # Count the examples actually present: the final batch may hold fewer
    # than ``batch_size`` items, so adding ``batch_size`` would inflate the
    # denominator of both reported scores.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for output, line in zip(prediction_lines, lines):
        gold = line['annotation']
        for pred in output['prediction']:
            if pred['start_idx'] == gold['start_idx'] and pred['end_idx'] == gold['end_idx']:
                n_span_match += 1
                # NOTE(review): the label check is nested under the span
                # check so that "exact match" means span AND label agree;
                # the original indentation was lost -- confirm.
                if pred['label'] == gold['label']:
                    n_pos += 1

if n_total == 0:
    # An empty corpus would otherwise end in a bare ZeroDivisionError.
    raise SystemExit('corpus is empty; nothing to evaluate')

print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
print(f'span only precision: {n_span_match * 100 / n_total:.3f}')