File size: 10,442 Bytes
44a9d55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import os 
import json 
import argparse 
import re 
from copy import deepcopy
from tqdm import tqdm 

from utils import find_head, WhitespaceTokenizer, find_arg_span
import spacy 
print("convert_gen_to_output5.py")
def extract_args_from_template(ex, template, ontology_dict,):
    # extract argument text
    # 这个函数的返回值是一个字典 因此需要 template列表和ex中的predicted列表同时进行遍历放入字典中
    # 在这里定义两个列表 分别存放 定义存放模板的列表 TEMPLATE 和 相对应的生成 PREDICTED
    # 传过来的参数中的template就是包含所有模板的列表 因此不需要再定义TEMPLATE 还是需要定义一个存放分词后的template
    # 这里的template是相应事件类型下的模板包含多个
    # 原来处理的方式是一个数据和一个综合性模板 现在模板是分开的 为什么要把template传过来 这不是脱裤子放屁的操作?
    # 下面这段操作是因为上次模板的定义是相同因此只需要去列表中的第一个模板就行 这次需要用循环进行遍历
    t = []
    TEMPLATE = []
    for i in template:
        t = i.strip().split()
        TEMPLATE.append(t)
        t = []
    # 到此为止 得到存放该ex即该数据类型下的所有模板的分词后的列表存储 下面获取对应的predicted同理
    PREDICTED = []
    p = []
    for i in ex['predicted']:
        p = i.strip().split()
        PREDICTED.append(p)
        p = []
    # 这个字典变量定义了这个函数的返回值 应该是论元角色-论元短语的key-value映射
    predicted_args = {}
    evt_type = get_event_type(ex)[0]
    # 不出意外的话 TEMPLATE和PREDICTED的长度应该是相等的
    length = len(TEMPLATE)
    for i in range(length):
        template_words = TEMPLATE[i]
        predicted_words = PREDICTED[i]
        t_ptr = 0
        p_ptr = 0
        while t_ptr < len(template_words) and p_ptr < len(predicted_words):
            if re.match(r'<(arg\d+)>', template_words[t_ptr]):
                m = re.match(r'<(arg\d+)>', template_words[t_ptr])
                # 这一步的操作是从模板中到 <arg1> 这样的词符 即arg_num 然后通过arg_num找到对应论元角色arg_name
                arg_num = m.group(1)
                arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num]

                if predicted_words[p_ptr] == '<arg>':
                    # missing argument
                    p_ptr +=1
                    t_ptr +=1
                else:
                    arg_start = p_ptr
                    if t_ptr + 1 == len(template_words):
                        while p_ptr < len(predicted_words):
                            p_ptr += 1
                    else:
                        while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr+1]):
                            p_ptr += 1
                    arg_text = predicted_words[arg_start:p_ptr]
                    predicted_args[arg_name] = arg_text
                    t_ptr += 1
                    # aligned
            else:
                t_ptr += 1
                p_ptr += 1

    return predicted_args

def get_event_type(ex):
        evt_type = []
        for evt in ex['evt_triggers']:
            for t in evt[2]:
                evt_type.append( t[0])
        return evt_type 

def check_coref(ex, arg_span, gold_spans):
    for clus in ex['corefs']:
        if arg_span in clus:
            matched_gold_spans = [span for span in gold_spans if span in clus]
            if len(matched_gold_spans) > 0:
                return matched_gold_spans[0]
    return arg_span 

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen-file',type=str, default='checkpoints/gen-new-tokenization-pred/sample_predictions.jsonl')
    parser.add_argument('--test-file', type=str,default='data/RAMS_1.0/data/test.jsonlines')
    parser.add_argument('--output-file',type=str, default='test_output.jsonl')
    parser.add_argument('--ontology-file',type=str, default='aida_ontology_new.csv')
    parser.add_argument('--head-only',action='store_true',default=False)
    parser.add_argument('--coref', action='store_true', default=False)
    args = parser.parse_args()

    # 加载词典
    nlp = spacy.load('en_core_web_sm')
    nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
    # read ontology 读取事件本体 模板文件中的内容
    ontology_dict = {}
    with open('aida_ontology_fj-5.csv', 'r') as f:
        for lidx, line in enumerate(f):
            if lidx == 0:  # header
                continue
            fields = line.strip().split(',')
            if len(fields) < 2:
                break
            evt_type = fields[0]
            if evt_type in ontology_dict.keys():
                arguments = fields[2:]
                ontology_dict[evt_type]['template'].append(fields[1])
                for i, arg in enumerate(arguments):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
            else:
                ontology_dict[evt_type] = {}
                arguments = fields[2:]
                ontology_dict[evt_type]['template'] = []
                ontology_dict[evt_type]['template'].append(fields[1])
                for i, arg in enumerate(arguments):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
    examples = {}
    print(args.gen_file)
    # data/RAMS_1.0/data/test_head_coref.jsonlines
    key = []
    with open(args.test_file, 'r') as f:
        for line in f:
            ex = json.loads(line.strip())
            #if ex['gold_evt_links'] == []:
                #key.append(ex['doc_key'])
                #continue
            ex['ref_evt_links'] = deepcopy(ex['gold_evt_links']) 
            ex['gold_evt_links'] = []
            examples[ex['doc_key']] = ex
        
    # checkpoints/gen-RAMS-pred/predictions.jsonl
    flag = {}
    with open(args.gen_file,'r') as f:
        for line in f:
            pred = json.loads(line.strip()) 
            # print(pred)
            # 因为最后生成 应该是 多个相同的事件类型在并列 这个操作好像把已经填入的predicte覆盖掉了
            # 在这里的循环中 应该继续向下扫描 采取和ontology中相同的处理方式 用列表的方式存储将pred中的内容存放到examples中的数据中
            # pred 是对预测文件中的预测结果句用空格进行分隔单词后的结果
            # pred中的内容主要包括 doc_key predicted gold
            # 如果扫描到的预测json数据事件类型在examples中存在 那么就将predicted存入列表
            # if pred['doc_key'] not in key:
            if pred['doc_key'] in flag.keys():
                #print(examples[pred['doc_key']]['predicted'])
                examples[pred['doc_key']]['predicted'].append(pred['predicted'])
                examples[pred['doc_key']]['gold'].append(pred['gold'])
            # 如果没有 说明这是新的事件类型
            else:
                flag[pred['doc_key']] = True
                examples[pred['doc_key']]['predicted'] = []
                examples[pred['doc_key']]['gold'] = []
                # 然后将此条数据存入
                examples[pred['doc_key']]['predicted'].append(pred['predicted'])
                examples[pred['doc_key']]['gold'].append(pred['gold'])

    # checkpoints/gen-RAMS-pred/out_put.jsonl
    writer = open(args.output_file, 'w')
    for ex in tqdm(examples.values()):
        if 'predicted' not in ex:# this is used for testing 
            continue 
        # get template  获取事件类型
        evt_type = get_event_type(ex)[0]
        context_words = [w for sent in ex['sentences'] for w in sent]
        # 这里的template是ontology_dict中 template 包含一个事件类型下的所有事件模板
        template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        # extract argument text 
        # 这里应该是提取预测文件中预测到的论元短语 ex是一条json数据 template是这条json数据对应下的模板 on是论元角色和<arg1>的映射
        # 这里ex中的predicted和gold已经包括了该事件类型下的所有论元 用列表的形式进行存储 且顺序是一一对应的
        # 这里返回的predicted_args是一个字典:
        predicted_args = extract_args_from_template(ex, template, ontology_dict)
        # get trigger
        # extract argument span 找出触发词在文段中的索引
        str_p = ''
        str_g = ''
        for i in range(len(ex['predicted'])):
            str_p += ex['predicted'][i]
            str_g += ex['gold'][i]

        ex['predicted'] = str_p
        ex['gold'] = str_g
        trigger_start = ex['evt_triggers'][0][0]
        trigger_end = ex['evt_triggers'][0][1]
        # 上面返回的predicted_args是一个字典 暂时认为是论元角色和具体论元短语的映射
        # 还没有发现doc的作用
        doc = None
        # 通过test_rams.sh文件的设置 可以发现args.head_only的值为true
        # print('aa', args.head_only, args.coref)
        if args.head_only:
            # 从原始文本中取出标记
            doc = nlp(' '.join(context_words))
        # 其中arg_name是论元角色类型
        for argname in predicted_args:
            # 通过find_arg_span函数找出
            arg_span = find_arg_span(predicted_args[argname], context_words, 
                trigger_start, trigger_end, head_only=args.head_only, doc=doc) 
            #print(arg_span)
            if arg_span:# if None means hullucination

                if args.head_only and args.coref:
                    # consider coreferential mentions as matching 
                    assert('corefs' in ex)
                    print('aaa')
                    gold_spans = [a[1] for a in ex['ref_evt_links'] if a[2]==argname]
                    arg_span = check_coref(ex, list(arg_span), gold_spans)

                ex['gold_evt_links'].append([[trigger_start, trigger_end], list(arg_span), argname])

        writer.write(json.dumps(ex)+'\n')
    
    writer.close()