# NOTE(review): removed web-page extraction artifacts that were accidentally
# captured with this file (page chrome, file-size line, line-number gutter);
# they were not valid Python and not part of the script.
import json
import os
import copy
from collections import defaultdict
from argparse import ArgumentParser
from tqdm import tqdm
import random
from tqdm import tqdm
from scripts.predict_concrete import read_kairos
from sftp import SpanPredictor
# --- CLI / setup ------------------------------------------------------------
parser = ArgumentParser(description='Evaluate span predictions against AIDA annotations.')
parser.add_argument('aida', type=str, help='Path to the AIDA corpus JSON file.')
parser.add_argument('model', type=str, help='Path to the SpanPredictor model archive.')
# BUG FIX: argparse ignores `default` on a positional unless nargs='?' is set;
# without it this argument was still mandatory and the None default was dead.
parser.add_argument('fn2kairos', type=str, nargs='?', default=None,
                    help='Optional FrameNet-to-KAIROS ontology mapping file.')
parser.add_argument('--device', type=int, default=3, help='CUDA device index.')
args = parser.parse_args()

# Close the corpus file deterministically instead of json.load(open(...)),
# which leaked the file handle.
with open(args.aida) as corpus_file:
    corpus = json.load(corpus_file)
mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)

random.seed(42)  # fixed seed so the shuffle (and thus batching) is reproducible
random.shuffle(corpus)
batch_size = 128
def batchify(a_list, size=None):
    """Yield successive chunks of *a_list*.

    Args:
        a_list: Any iterable of items to group.
        size: Chunk length; defaults to the module-level ``batch_size``
            (kept as the fallback for backward compatibility with the
            original zero-argument call sites).

    Yields:
        Lists of up to ``size`` items; the final chunk may be shorter.
    """
    size = batch_size if size is None else size
    chunk = []
    for item in a_list:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:  # emit the short tail chunk, if any
        yield chunk
# --- evaluation -------------------------------------------------------------
batches = list(batchify(corpus))
n_total = n_pos = n_span_match = 0
# tqdm over the list directly so the progress bar knows the total;
# the old enumerate() index was never used.
for lines in tqdm(batches):
    # BUG FIX: the last batch may hold fewer than batch_size examples.
    # Counting `batch_size` per batch overcounted the denominator and
    # deflated both reported precisions; count actual examples instead.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for preds, ann in zip(prediction_lines, lines):
        ann = ann['annotation']  # gold span: start_idx / end_idx / label
        for pred in preds['prediction']:
            if pred['start_idx'] == ann['start_idx'] and pred['end_idx'] == ann['end_idx']:
                n_span_match += 1
                # exact match additionally requires the predicted label to agree
                if pred['label'] == ann['label']:
                    n_pos += 1
print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
print(f'span only precision: {n_span_match * 100 / n_total:.3f}')