File size: 1,684 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import os
import copy
from collections import defaultdict
from argparse import ArgumentParser
from tqdm import tqdm
import random
from tqdm import tqdm
from scripts.predict_concrete import read_kairos

from sftp import SpanPredictor


# Command-line interface.
parser = ArgumentParser()
parser.add_argument('aida', type=str)
parser.add_argument('model', type=str)
# BUG FIX: a positional argument only honors `default` when nargs='?' is set;
# without it argparse treats fn2kairos as required and the default is dead.
parser.add_argument('fn2kairos', type=str, nargs='?', default=None)
parser.add_argument('--device', type=int, default=3)
args = parser.parse_args()

# Load the evaluation corpus, the FrameNet->KAIROS ontology mapping,
# and the trained span predictor.
with open(args.aida) as fp:  # use a context manager so the handle is closed (was leaked)
    corpus = json.load(fp)
mapping = read_kairos(args.fn2kairos)
predictor = SpanPredictor.from_path(args.model, cuda_device=args.device)
random.seed(42)  # fixed seed -> reproducible shuffle order across runs
random.shuffle(corpus)
batch_size = 128


def batchify(a_list, size=None):
    """Yield consecutive fixed-size chunks of *a_list*.

    Parameters
    ----------
    a_list : iterable
        Items to split into batches.
    size : int, optional
        Chunk length. Falls back to the module-level ``batch_size``
        when omitted, so existing call sites behave as before.

    Yields
    ------
    list
        Successive batches; the final batch may be shorter than ``size``.
    """
    if size is None:
        size = batch_size  # module-level default (128)
    cur = []
    for item in a_list:
        cur.append(item)
        if len(cur) == size:
            yield cur
            cur = []
    if cur:  # flush the (possibly short) trailing batch
        yield cur


# Materialize all batches up front so tqdm can show total progress.
batches = list(batchify(corpus))


# Running counts: instances seen, exact (span + label) matches,
# and span-only matches.
n_total = n_pos = n_span_match = 0
for lines in tqdm(batches, total=len(batches)):
    # BUG FIX: the final batch may hold fewer than batch_size items;
    # count the actual instances instead of assuming a full batch,
    # otherwise both precision figures are deflated.
    n_total += len(lines)
    prediction_lines = predictor.predict_batch_sentences(
        [line['tokens'] for line in lines], max_tokens=1024, ontology_mapping=mapping
    )
    for preds, ann in zip(prediction_lines, lines):
        ann = ann['annotation']
        for pred in preds['prediction']:
            # A span match requires both boundaries to agree; a full
            # (exact) match additionally requires the label to agree.
            if pred['start_idx'] == ann['start_idx'] and pred['end_idx'] == ann['end_idx']:
                n_span_match += 1
                if pred['label'] == ann['label']:
                    n_pos += 1

    # Running precision printed after every batch as a progress report.
    print(f'exact match precision: {n_pos * 100 / n_total:.3f}')
    print(f'span only precision: {n_span_match * 100 / n_total:.3f}')