File size: 1,852 Bytes
05922fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import sys
import json
from pprint import pprint
from collections import defaultdict

from sftp.metrics.exact_match import ExactMatch


def _read_jsonl(path):
    """Parse a JSON-lines file into a list of objects, skipping blank lines.

    The file is closed deterministically and decoded as UTF-8 (the encoding
    JSON text uses), independent of the platform's locale.
    """
    with open(path, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _pred_to_eval(pred_sent):
    """Flatten one sentence's predicted frames into a list of span dicts.

    Each element of ``pred_sent`` is one predicted frame carrying
    ``start_idx``/``end_idx``/``label`` plus a ``children`` list of frame
    elements with the same keys. Frames get ``parent`` 0; each frame element's
    ``parent`` is the 1-based position of its frame in ``pred_sent``. All
    frames come first, followed by all frame elements.
    """
    span_keys = ("start_idx", "end_idx", "label")
    pred_frames, pred_fes = [], []
    for parent, fr in enumerate(pred_sent, start=1):
        frame = {key: fr[key] for key in span_keys}
        frame['parent'] = 0
        pred_frames.append(frame)
        for fe in fr['children']:
            fe_span = {key: fe[key] for key in span_keys}
            fe_span['parent'] = parent
            pred_fes.append(fe_span)
    return pred_frames + pred_fes


def _gold_to_eval(gold_sent):
    """Flatten one gold sentence's ``frame`` annotations into span dicts.

    A frame's span runs from the first to the last element of its ``target``
    list and is labelled with its ``name``; it gets ``parent`` 0. Each
    ``(start_idx, end_idx, fe_name)`` triple in the frame's ``fe`` list
    becomes a span whose ``parent`` is the frame's 1-based position. All
    frames come first, followed by all frame elements.
    """
    gold_frames, gold_fes = [], []
    for parent, fr in enumerate(gold_sent['frame'], start=1):
        gold_frames.append({
            'start_idx': fr['target'][0], 'end_idx': fr['target'][-1], "label": fr['name'], 'parent': 0
        })
        for start_idx, end_idx, fe_name in fr['fe']:
            gold_fes.append({
                "start_idx": start_idx, "end_idx": end_idx, "label": fe_name, "parent": parent
            })
    return gold_frames + gold_fes


def evaluate():
    """Score predicted frames and frame elements against gold annotations.

    Command-line arguments ``sys.argv[1:]`` must be exactly two paths: the
    gold JSON-lines file and the prediction JSON-lines file. Any other count
    makes the tuple unpacking raise ``ValueError``.

    Gold lines are keyed by ``meta['sentence ID']``. If an ID repeats, the
    later line wins. Prediction lines are grouped by the same key, one
    predicted frame per line. Every gold sentence is scored, and one with no
    predictions is scored against an empty list. Predictions whose sentence
    ID does not appear in the gold file are ignored.

    Two ``ExactMatch`` metrics, constructed with ``True`` and ``False``, are
    updated for every sentence and printed under the labels ``EM`` and
    ``SM``. The boolean's exact meaning is defined in
    ``sftp.metrics.exact_match`` (presumably strict vs. soft matching;
    confirm there).
    """
    em = ExactMatch(True)
    sm = ExactMatch(False)
    gold_file, pred_file = sys.argv[1:]

    # Parse each gold line once (the original parsed every line twice).
    test_sentences = {sent['meta']['sentence ID']: sent for sent in _read_jsonl(gold_file)}

    pred_sentences = defaultdict(list)
    for one_pred in _read_jsonl(pred_file):
        pred_sentences[one_pred['meta']['sentence ID']].append(one_pred)

    for sent_id, gold_sent in test_sentences.items():
        # .get() avoids inserting empty entries into the defaultdict.
        pred_to_eval = _pred_to_eval(pred_sentences.get(sent_id, []))
        gold_to_eval = _gold_to_eval(gold_sent)
        em(pred_to_eval, gold_to_eval)
        sm(pred_to_eval, gold_to_eval)

    # The True argument to get_metric is passed through unchanged; its
    # meaning (likely "reset after reading") is defined by ExactMatch.
    print('EM')
    pprint(em.get_metric(True))
    print('SM')
    pprint(sm.get_metric(True))


if __name__ == "__main__":
    # Script entry point: run as `python <script> GOLD_JSONL PRED_JSONL`.
    evaluate()