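"""Convert generated template predictions back into RAMS-style event links.

Reads the template ontology CSV, the RAMS test file, and a JSONL file of
generated predictions; aligns each predicted sequence against its question
template to recover argument text; maps every argument back to a token span
in the source document (optionally matching on head words and coreference);
and writes the resulting links into `gold_evt_links` in the output JSONL.
"""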
import json
import argparse
import re
from copy import deepcopy

import spacy
from tqdm import tqdm

from utils import WhitespaceTokenizer, find_arg_span

print("convert_gen_to_output2.py")
def extract_args_from_template(ex, template, ontology_dict):
    # Recover argument text by walking the template and the prediction in
    # lockstep. Note that only the first sub-template is consulted here.
    template_words = template[0].strip().split()
    predicted_words = ex['predicted'].strip().split()
    predicted_args = {}
    t_ptr = 0
    p_ptr = 0
    evt_type = get_event_type(ex)[0]

    while t_ptr < len(template_words) and p_ptr < len(predicted_words):
        m = re.match(r'<(arg\d+)>', template_words[t_ptr])
        if m:
            arg_num = m.group(1)
            arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num]

            if predicted_words[p_ptr] == '<arg>':
                # missing argument
                p_ptr += 1
                t_ptr += 1
            else:
                # consume words until the prediction realigns with the next
                # literal template word
                arg_start = p_ptr
                while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr + 1]):
                    p_ptr += 1
                predicted_args[arg_name] = predicted_words[arg_start:p_ptr]
                t_ptr += 1
        else:
            # literal template word: advance both pointers in step
            t_ptr += 1
            p_ptr += 1

    return predicted_args
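# Illustrative alignment (hypothetical sentence): with template[0] equal to
# ' what is the <arg1> in <trg>' and ex['predicted'] equal to
# 'what is the rebel fighters in <trg>', the words 'rebel fighters' are
# consumed until the prediction realigns with the literal word 'in', giving
# predicted_args == {'attacker': ['rebel', 'fighters']} (assuming arg1 maps
# to the role 'attacker' for this event type).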


def get_event_type(ex):
    evt_type = []
    for evt in ex['evt_triggers']:
        for t in evt[2]:
            evt_type.append(t[0])
    return evt_type
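# Illustrative trigger entry (format assumed from the RAMS 1.0 release):
#   ex['evt_triggers'] == [[31, 31, [['life.die.deathcausedbyviolentevents', 1.0]]]]
#   get_event_type(ex) -> ['life.die.deathcausedbyviolentevents']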

def check_coref(ex, arg_span, gold_spans):
    # if the predicted span is coreferent with a gold span, return that gold
    # span so coreferential mentions count as matches; otherwise return the
    # predicted span unchanged
    for clus in ex['corefs']:
        if arg_span in clus:
            matched_gold_spans = [span for span in gold_spans if span in clus]
            if len(matched_gold_spans) > 0:
                return matched_gold_spans[0]
    return arg_span
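# Illustrative (hypothetical spans): if ex['corefs'] contains a cluster
# [[5, 6], [20, 20]], the model predicted span [20, 20], and the gold span is
# [5, 6], check_coref returns [5, 6] and the prediction is credited.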
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen-file', type=str, default='checkpoints/gen-new-tokenization-pred/sample_predictions.jsonl')
    parser.add_argument('--test-file', type=str, default='data/RAMS_1.0/data/test_head.jsonlines')
    parser.add_argument('--output-file', type=str, default='test_output.jsonl')
    parser.add_argument('--ontology-file', type=str, default='aida_ontology_new.csv')
    parser.add_argument('--head-only', action='store_true', default=False)
    parser.add_argument('--coref', action='store_true', default=False)
    args = parser.parse_args()

    nlp = spacy.load('en_core_web_sm')
    nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
    # read the event ontology and its templates from the ontology file
    ontology_dict = {}
    with open(args.ontology_file, 'r') as f:
        for lidx, line in enumerate(f):
            # skip the header row
            if lidx == 0:
                continue
            fields = line.strip().split(',')
            # skip rows with no argument roles to extract rather than
            # aborting the whole scan
            if len(fields) < 2:
                continue
            # the event type is the first field
            evt_type = fields[0]
            # fields from the third onward are the argument roles to extract
            arguments = fields[2:]
            # count the non-empty argument slots for this event type
            args_len = sum(1 for arg in arguments if arg != '')
            # build one sub-template question per non-empty argument slot, so
            # the concatenated template reads like
            # " what is the <arg1> in <trg> what is the <arg2> in <trg>"
            temp_template = []
            for i in range(args_len):
                temp_template.append(" what is the <arg{}> in <trg>".format(i + 1))
            # key the ontology dict by event type; alongside the template,
            # store a bidirectional mapping between arg slots and role names
            ontology_dict[evt_type] = {
                'template': temp_template
            }
            for i, arg in enumerate(arguments):
                if arg != '':
                    ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                    ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
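    # Illustrative ontology row (hypothetical values, following the CSV layout
    # assumed above: event_type, template, arg1..arg5):
    #   "conflict.attack.n/a,<arg1> attacked <arg2>,attacker,target,,,"
    # would yield:
    #   ontology_dict['conflict.attack.n/a'] == {
    #       'template': [' what is the <arg1> in <trg>', ' what is the <arg2> in <trg>'],
    #       'arg1': 'attacker', 'attacker': 'arg1',
    #       'arg2': 'target', 'target': 'arg2',
    #   }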
    
    
    # load the gold test examples keyed by doc_key; stash the gold links in
    # ref_evt_links and clear gold_evt_links to receive the predicted links
    examples = {}
    print(args.gen_file)
    # e.g. data/RAMS_1.0/data/test_head_coref.jsonlines
    with open(args.test_file, 'r') as f:
        for line in f:
            ex = json.loads(line.strip())
            ex['ref_evt_links'] = deepcopy(ex['gold_evt_links'])
            ex['gold_evt_links'] = []
            examples[ex['doc_key']] = ex
        
    # merge each generated prediction into its matching example,
    # e.g. checkpoints/gen-RAMS-pred/predictions.jsonl
    with open(args.gen_file, 'r') as f:
        for line in f:
            pred = json.loads(line.strip())
            examples[pred['doc_key']]['predicted'] = pred['predicted']
            examples[pred['doc_key']]['gold'] = pred['gold']
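    # Illustrative prediction line (hypothetical contents): the model emits the
    # template with the argument text filled in, e.g.
    #   {"doc_key": "nw_RC01", "predicted": "what is the rebel fighters in <trg>",
    #    "gold": "what is the militants in <trg>"}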

    # write enriched examples, e.g. checkpoints/gen-RAMS-pred/out_put.jsonl
    writer = open(args.output_file, 'w')
    for ex in tqdm(examples.values()):
        # examples without predictions are skipped (used for testing)
        if 'predicted' not in ex:
            continue
        # get the template for this example's event type
        evt_type = get_event_type(ex)[0]
        context_words = [w for sent in ex['sentences'] for w in sent]
        template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']

        # extract argument text by aligning the prediction with the template
        predicted_args = extract_args_from_template(ex, template, ontology_dict)

        # get the trigger span; parse the context only when head matching is on
        trigger_start = ex['evt_triggers'][0][0]
        trigger_end = ex['evt_triggers'][0][1]
        doc = None
        if args.head_only:
            doc = nlp(' '.join(context_words))

        # map each predicted argument back to a token span in the document
        for argname in predicted_args:
            arg_span = find_arg_span(predicted_args[argname], context_words,
                trigger_start, trigger_end, head_only=args.head_only, doc=doc)
            if arg_span:  # None means the argument was hallucinated
                if args.head_only and args.coref:
                    # consider coreferential mentions as matching
                    assert 'corefs' in ex
                    gold_spans = [a[1] for a in ex['ref_evt_links'] if a[2] == argname]
                    arg_span = check_coref(ex, list(arg_span), gold_spans)

                # each link is [[trigger_start, trigger_end], [arg_start, arg_end], role]
                ex['gold_evt_links'].append([[trigger_start, trigger_end], list(arg_span), argname])

        writer.write(json.dumps(ex) + '\n')

    writer.close()
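    # Example invocation (flags as defined above; paths are illustrative):
    #   python convert_gen_to_output2.py \
    #       --gen-file checkpoints/gen-RAMS-pred/predictions.jsonl \
    #       --test-file data/RAMS_1.0/data/test_head.jsonlines \
    #       --output-file test_output.jsonl \
    #       --ontology-file aida_ontology_new.csv \
    #       --head-only --coref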