import os
import json
import argparse
import re
from copy import deepcopy

from tqdm import tqdm
import spacy

from utils import find_head, WhitespaceTokenizer, find_arg_span

print("convert_gen_to_output2.py")

# Matches a numbered argument slot such as "<arg1>" at the start of a template word.
_ARG_SLOT_RE = re.compile(r'<(arg\d+)>')


def extract_args_from_template(ex, template, ontology_dict,):
    """Align a generated string against a template and pull out argument text.

    Only the first entry of ``template`` (``template[0]``) is used. The
    template and ``ex['predicted']`` are split on whitespace and walked in
    lock-step; whenever the template word is an ``<argN>`` slot, the predicted
    words up to (but not including) the next template word are taken as that
    argument's text.

    Args:
        ex: example dict; reads ``ex['predicted']`` and ``ex['evt_triggers']``.
        template: sequence of template strings; only index 0 is consulted.
        ontology_dict: maps event type -> dict that maps ``'argN'`` -> role name.

    Returns:
        dict mapping role name -> list of predicted words for that role.
    """
    template_words = template[0].strip().split()
    predicted_words = ex['predicted'].strip().split()
    predicted_args = {}
    t_ptr = 0
    p_ptr = 0
    evt_type = get_event_type(ex)[0]

    while t_ptr < len(template_words) and p_ptr < len(predicted_words):
        slot = _ARG_SLOT_RE.match(template_words[t_ptr])
        if slot is None:
            # Ordinary template word: assume it is aligned with the prediction.
            t_ptr += 1
            p_ptr += 1
            continue

        arg_num = slot.group(1)
        arg_name = ontology_dict[evt_type.replace('n/a', 'unspecified')][arg_num]

        # NOTE(review): this comparison used an empty-string literal, which
        # str.split() can never produce, so the branch was unreachable.
        # '<arg>' is assumed to be the unfilled-slot marker emitted by the
        # generator -- confirm against the generation script.
        if predicted_words[p_ptr] == '<arg>':
            # missing argument
            p_ptr += 1
            t_ptr += 1
            continue

        # The argument runs until the template word that follows the slot.
        # If the slot is the last template word there is no delimiter, so the
        # argument absorbs the remaining predicted words (previously this
        # raised IndexError).
        next_idx = t_ptr + 1
        stop_word = template_words[next_idx] if next_idx < len(template_words) else None

        arg_start = p_ptr
        while p_ptr < len(predicted_words) and predicted_words[p_ptr] != stop_word:
            p_ptr += 1
        predicted_args[arg_name] = predicted_words[arg_start:p_ptr]
        # Template and prediction are aligned again on the delimiter word.
        t_ptr += 1

    return predicted_args


def get_event_type(ex):
    """Return the first element of every entry in each trigger's third field.

    Iterates ``ex['evt_triggers']`` and, for each trigger ``evt``, collects
    ``t[0]`` for every ``t`` in ``evt[2]``.
    """
    return [t[0] for evt in ex['evt_triggers'] for t in evt[2]]


def check_coref(ex, arg_span, gold_spans):
    """Swap ``arg_span`` for a gold span that shares a coreference cluster.

    For each cluster in ``ex['corefs']`` that contains ``arg_span``, return the
    first span of ``gold_spans`` that is also in that cluster. If no cluster
    yields such a gold span, ``arg_span`` is returned unchanged.
    """
    for clus in ex['corefs']:
        if arg_span in clus:
            matched_gold_spans = [span for span in gold_spans if span in clus]
            if matched_gold_spans:
                return matched_gold_spans[0]
    return arg_span


def _load_ontology(path):
    """Read the ontology CSV at ``path`` into a dict keyed by event type.

    The first line is skipped as a header. For every following row, column 0
    is the event type and columns 2 onward are the argument roles. Each entry
    holds a ``'template'`` list (one sub-template per argument column) plus a
    two-way mapping between ``'argN'`` and the role name for non-empty roles.
    Reading stops at the first row with fewer than two fields.
    """
    ontology_dict = {}
    with open(path, 'r') as f:
        for lidx, line in enumerate(f):
            if lidx == 0:
                # header row
                continue
            fields = line.strip().split(',')
            if len(fields) < 2:
                break
            # The event type is the first column.
            evt_type = fields[0]
            # From the third column on, each field is an argument role.
            arguments = fields[2:]

            # One sub-template per argument column, of the form
            # "what is the <argN> in".
            # NOTE(review): the format string previously had no replacement
            # field, so the argument number was dropped and no template ever
            # contained an <argN> slot for extract_args_from_template to match.
            temp_template = [
                " what is the <arg{}> in ".format(i + 1)
                for i in range(len(arguments))
            ]
            print(temp_template)

            # Create the entry for this event type, keyed by event type.
            ontology_dict[evt_type] = {
                'template': temp_template
            }

            for i, arg in enumerate(arguments):
                if arg != '':
                    ontology_dict[evt_type]['arg{}'.format(i + 1)] = arg
                    ontology_dict[evt_type][arg] = 'arg{}'.format(i + 1)
    return ontology_dict


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gen-file', type=str,
                        default='checkpoints/gen-new-tokenization-pred/sample_predictions.jsonl')
    parser.add_argument('--test-file', type=str,
                        default='data/RAMS_1.0/data/test_head.jsonlines')
    parser.add_argument('--output-file', type=str, default='test_output.jsonl')
    parser.add_argument('--ontology-file', type=str, default='aida_ontology_new.csv')
    parser.add_argument('--head-only', action='store_true', default=False)
    parser.add_argument('--coref', action='store_true', default=False)
    args = parser.parse_args()

    nlp = spacy.load('en_core_web_sm')
    nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)

    # Read the event ontology / template file. The path now honours
    # --ontology-file (its default equals the previously hard-coded name).
    ontology_dict = _load_ontology(args.ontology_file)

    examples = {}
    print(args.gen_file)
    # e.g. data/RAMS_1.0/data/test_head_coref.jsonlines
    with open(args.test_file, 'r') as f:
        for line in f:
            ex = json.loads(line.strip())
            # Keep the reference links aside; 'gold_evt_links' is refilled
            # below with the predicted links.
            ex['ref_evt_links'] = deepcopy(ex['gold_evt_links'])
            ex['gold_evt_links'] = []
            examples[ex['doc_key']] = ex

    # e.g. checkpoints/gen-RAMS-pred/predictions.jsonl
    with open(args.gen_file, 'r') as f:
        for line in f:
            pred = json.loads(line.strip())
            examples[pred['doc_key']]['predicted'] = pred['predicted']
            examples[pred['doc_key']]['gold'] = pred['gold']

    # e.g. checkpoints/gen-RAMS-pred/out_put.jsonl
    # The context manager closes the output even if an example raises.
    with open(args.output_file, 'w') as writer:
        for ex in tqdm(examples.values()):
            if 'predicted' not in ex:  # this is used for testing
                continue
            # get template
            evt_type = get_event_type(ex)[0]
            context_words = [w for sent in ex['sentences'] for w in sent]
            template = ontology_dict[evt_type.replace('n/a', 'unspecified')]['template']
            # extract argument text
            predicted_args = extract_args_from_template(ex, template, ontology_dict)
            # get trigger
            trigger_start = ex['evt_triggers'][0][0]
            trigger_end = ex['evt_triggers'][0][1]
            # extract argument span
            doc = None
            if args.head_only:
                doc = nlp(' '.join(context_words))

            for argname in predicted_args:
                arg_span = find_arg_span(predicted_args[argname], context_words,
                                         trigger_start, trigger_end,
                                         head_only=args.head_only, doc=doc)
                if arg_span:  # a falsy span is treated as hallucination
                    if args.head_only and args.coref:
                        # consider coreferential mentions as matching
                        assert 'corefs' in ex
                        gold_spans = [a[1] for a in ex['ref_evt_links'] if a[2] == argname]
                        arg_span = check_coref(ex, list(arg_span), gold_spans)

                    ex['gold_evt_links'].append(
                        [[trigger_start, trigger_end], list(arg_span), argname])

            writer.write(json.dumps(ex) + '\n')