File size: 10,842 Bytes
44a9d55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import json
import re
import spacy
from tqdm import tqdm

from src.genie.utils import WhitespaceTokenizer

#x = 0
def find_head(arg_start, arg_end, doc):
    # 设置一个临时变量 存储 论元短语的开始索引 cur_i = arg_start
    cur_i = arg_start
    # 进行遍历
    while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <= arg_end:
        if doc[cur_i].head.i == cur_i:
            # self is the head
            break
        else:
            cur_i = doc[cur_i].head.i

    arg_head = cur_i

    return (arg_head, arg_head)

def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None):
    # 要定义一个match 作为匹配项
    match = None
    # arg 是论元短语 是预测文件中predicted中生成的论元短语 arg_len目前的含义是获取生成论元短语的长度
    arg_len = len(arg)
    # context_words 是文本 min_dis是最短距离
    min_dis = len(context_words)  # minimum distance to trigger
    #print(arg)

    #x = 0
    # i 代表文本中的单词索引 w 代表文本中的i索引对应的单词
    for i, w in enumerate(context_words):
        # 如果文本单词列表中有一段单词 和 模型生成的单词是相等的
        if context_words[i:i + arg_len] == arg:
            # 如果 这个论元单词的开始索引在触发词单词索引之前
            # global x += 1
            # print('aa')
            if i < trigger_start:
                # 那么距离就是 触发词单词的开始索引减去论元短语的开始索引再减去论元短语的长度
                dis = abs(trigger_start - i - arg_len)
            else:
                # 反之
                dis = abs(i - trigger_end)
            if dis < min_dis:
                # match是一个元组
                match = (i, i + arg_len - 1)
                min_dis = dis

    #print(match)
    if match and head_only:
        assert (doc != None)
        match = find_head(match[0], match[1], doc)
    #print(x)
    return match

def get_event_type(ex):
    evt_type = []
    for evt in ex['evt_triggers']:
        for t in evt[2]:
            evt_type.append(t[0])
    return evt_type

def extract_args_from_template(ex, template, ontology_dict,):
    # extract argument text
    # 这个函数的返回值是一个字典 因此需要 template列表和ex中的predicted列表同时进行遍历放入字典中
    # 在这里定义两个列表 分别存放 定义存放模板的列表 TEMPLATE 和 相对应的生成 PREDICTED
    # 传过来的参数中的template就是包含所有模板的列表 因此不需要再定义TEMPLATE 还是需要定义一个存放分词后的template
    # 这里的template是相应事件类型下的模板包含多个
    # 原来处理的方式是一个数据和一个综合性模板 现在模板是分开的 为什么要把template传过来 这不是脱裤子放屁的操作?
    # 下面这段操作是因为上次模板的定义是相同因此只需要去列表中的第一个模板就行 这次需要用循环进行遍历
    # print(ex)
    t = []
    TEMPLATE = []
    for i in template:
        t = i.strip().split()
        TEMPLATE.append(t)
        t = []
    # 到此为止 得到存放该ex即该数据类型下的所有模板的分词后的列表存储 下面获取对应的predicted同理
    PREDICTED = []
    p = []
    # 形参中插入的ex应该包含了该条数据(即该事件类型下)所有应该生成的论元对应的模板
    # 在程序中出现了不一样的情况 貌似只有一条模板数据 这个问题解决了
    # print(ex['predicted'])
    for i in ex['predicted']:
        p = i.strip().split()
        PREDICTED.append(p)
        p = []
    # print(TEMPLATE)
    # print(PREDICTED)
    # 这个字典变量定义了这个函数的返回值 应该是论元角色-论元短语的key-value映射
    predicted_args = {}
    evt_type = get_event_type(ex)[0]
    # print(evt_type)
    # 不出意外的话 TEMPLATE和PREDICTED的长度应该是相等的
    length = len(TEMPLATE)
    for i in range(length):
        #if i < 4:
            #continue
        template_words = TEMPLATE[i]
        predicted_words = PREDICTED[i]
        t_ptr = 0
        p_ptr = 0
        print(template_words)
        print(predicted_words)
        while t_ptr < len(template_words) and p_ptr < len(predicted_words):
            if re.match(r'<(arg\d+)>', template_words[t_ptr]):
                # print('aa')
                m = re.match(r'<(arg\d+)>', template_words[t_ptr])
                # 这一步的操作是从模板中得到 <arg1> 这样的词符 即arg_num 然后通过arg_num找到对应论元角色arg_name
                arg_num = m.group(1)
                # print(arg_num)
                arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num]

                if predicted_words[p_ptr] == '<arg>':
                    # missing argument
                    p_ptr +=1
                    t_ptr +=1
                else:
                    arg_start = p_ptr
                    if t_ptr + 1 == len(template_words):
                        while (p_ptr < len(predicted_words)):
                            p_ptr += 1
                    else:
                        while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr+1]):
                            p_ptr += 1
                    arg_text = predicted_words[arg_start:p_ptr]
                    predicted_args[arg_name] = arg_text
                    t_ptr += 1
                    # aligned
            else:
                t_ptr += 1
                p_ptr += 1

    # print(predicted_args)
    return predicted_args

def pro():
    nlp = spacy.load('en_core_web_sm')
    nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
    ontology_dict = {}
    with open('./aida_ontology_fj-5.csv', 'r') as f:
        for lidx, line in enumerate(f):
            if lidx == 0:  # header
                continue
            fields = line.strip().split(',')
            if len(fields) < 2:
                break
            evt_type = fields[0]
            if evt_type in ontology_dict.keys():
                arguments = fields[2:]
                ontology_dict[evt_type]['template'].append(fields[1])
                for i, arg in enumerate(arguments):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
            else:
                ontology_dict[evt_type] = {}
                arguments = fields[2:]
                ontology_dict[evt_type]['template'] = []
                ontology_dict[evt_type]['template'].append(fields[1])
                for i, arg in enumerate(arguments):
                    if arg != '':
                        ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                        ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)

    examples = {}
    x = 0
    with open('./data/RAMS_1.0/data/test_head_coref.jsonlines', 'r') as f:
        for line in f:
            x += 1
            ex = json.loads(line.strip())
            ex['ref_evt_links'] = ex['gold_evt_links']
            ex['gold_evt_links'] = []
            examples[ex['doc_key']] = ex

    flag = {}
    y = 0
    with open('./checkpoints/gen-RAMS-pred/predictions.jsonl', 'r') as f:
        for line in f:
            y += 1
            pred = json.loads(line.strip())
            # print(pred['predicted'])
            if pred['doc_key'] in flag.keys():
                examples[pred['doc_key']]['predicted'].append(pred['predicted'])
                examples[pred['doc_key']]['gold'].append(pred['gold'])
                # 如果没有 说明这是新的事件类型
            else:
                flag[pred['doc_key']] = True
                examples[pred['doc_key']]['predicted'] = []
                examples[pred['doc_key']]['gold'] = []
                # 然后将此条数据存入
                examples[pred['doc_key']]['predicted'].append(pred['predicted'])
                examples[pred['doc_key']]['gold'].append(pred['gold'])
    # print(len(examples), x, y) 871 871 3614
    
    for ex in tqdm(examples.values()):
        if 'predicted' not in ex:# this is used for testing
            continue
        # print(ex)
        # break
        # print(ex)
        # get template  获取事件类型
        # print('nw_RC00c8620ef5810429342a1c339e6c76c1b0b9add3f6010f04482fd832')
        evt_type = get_event_type(ex)[0]
        context_words = [w for sent in ex['sentences'] for w in sent]
        # 这里的template是ontology_dict中 template 包含一个事件类型下的所有事件模板
        template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
        # extract argument text
        # 这里应该是提取预测文件中预测到的论元短语 ex是一条json数据 template是这条json数据对应下的模板 on是论元角色和<arg1>的映射
        # 这里ex中的predicted和gold已经包括了该事件类型下的所有论元 用列表的形式进行存储 且顺序是一一对应的
        # 这里返回的predicted_args是一个字典:
    # ex = {'predicted': [' A man attacked target using something at place in order to take something', ' Attacker attacked EgyptAir plane using something at place in order to take something', ' Attacker attacked target using a suicide belt at place in order to take something', ' Attacker attacked target using something at Flight 181 place in order to take something', ' Attacker attacked target using something at place in order to take EgyptAir Flight 181']}
    # template = ontology_dict['conflict.attack.stealrobhijack']['template']
        # print(ex)
        predicted_args = extract_args_from_template(ex, template, ontology_dict)
        # print(predicted_args)
        # break
        trigger_start = ex['evt_triggers'][0][0]
        trigger_end = ex['evt_triggers'][0][1]
        # 上面返回的predicted_args是一个字典 暂时认为是论元角色和具体论元短语的映射
        # 还没有发现doc的作用
        doc = None
        # 通过test_rams.sh文件的设置 可以发现args.head_only的值为true
        head_only = True
        if head_only:
        #     # 从原始文本中取出标记
            doc = nlp(' '.join(context_words))
        for argname in predicted_args:
            # 通过find_arg_span函数找出
            arg_span = find_arg_span(predicted_args[argname], context_words,
                                     trigger_start, trigger_end, head_only=True, doc=doc)
            # print()
            #print(arg_span)
pro()
#print(x)

# dict = {'A': 1, 'B': 2, 'C': 3}
#
# for x in dict:
#     print(x)
# if '1' in dict.keys():
#     print('aaaaaaaa')