Bart-gen-arg / viz /visualize_output_KAIROS.py
adherent's picture
what is the <arg> in <trg>
6c25ddb
raw
history blame
No virus
6.93 kB
import os
import json
import argparse
from copy import deepcopy
import spacy
from spacy import displacy
import re
from collections import defaultdict
def find_head(arg_start, arg_end, doc):
cur_i = arg_start
while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <=arg_end:
if doc[cur_i].head.i == cur_i:
# self is the head
break
else:
cur_i = doc[cur_i].head.i
arg_head = cur_i
return (arg_head, arg_head)
def extract_args_from_template(predicted, template, ontology_dict, evt_type):
# extract argument text
template_words = template.strip().split()
predicted_words = predicted.strip().split()
predicted_args = defaultdict(list) # argname -> List of text
t_ptr= 0
p_ptr= 0
while t_ptr < len(template_words) and p_ptr < len(predicted_words):
if re.match(r'<(arg\d+)>', template_words[t_ptr]):
m = re.match(r'<(arg\d+)>', template_words[t_ptr])
arg_num = m.group(1)
arg_name = ontology_dict[evt_type][arg_num]
if predicted_words[p_ptr] == '<arg>':
# missing argument
p_ptr +=1
t_ptr +=1
else:
arg_start = p_ptr
while (p_ptr < len(predicted_words)) and (predicted_words[p_ptr] != template_words[t_ptr+1]):
p_ptr+=1
arg_text = predicted_words[arg_start:p_ptr]
predicted_args[arg_name].append(arg_text)
t_ptr+=1
# aligned
else:
t_ptr+=1
p_ptr+=1
return dict(predicted_args)
def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None):
match = None
arg_len = len(arg)
min_dis = len(context_words) # minimum distance to trigger
for i, w in enumerate(context_words):
if context_words[i:i+arg_len] == arg:
if i < trigger_start:
dis = abs(trigger_start-i-arg_len)
else:
dis = abs(i-trigger_end)
if dis< min_dis:
match = (i, i+arg_len-1)
min_dis = dis
if match and head_only:
assert(doc!=None)
match = find_head(match[0], match[1], doc)
return match
def load_ontology(dataset):
'''
Read ontology file for event to argument mapping.
'''
ontology_dict ={}
with open('event_role_{}.json'.format(dataset),'r') as f:
ontology_dict = json.load(f)
for evt_name, evt_dict in ontology_dict.items():
for i, argname in enumerate(evt_dict['roles']):
evt_dict['arg{}'.format(i+1)] = argname
# argname -> role is not a one-to-one mapping
if argname in evt_dict:
evt_dict[argname].append('arg{}'.format(i+1))
else:
evt_dict[argname] = ['arg{}'.format(i+1)]
return ontology_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--result-file',type=str, default='checkpoints/gen-KAIROS-pointer-pred/predictions.jsonl')
parser.add_argument('--test-file', type=str, default='data/kairos/test.jsonl')
parser.add_argument('--gold', action='store_true')
args = parser.parse_args()
ontology_dict = load_ontology('KAIROS')
render_dicts = []
reader= open(args.result_file, 'r')
with open(args.test_file,'r') as f:
for line in f:
doc = json.loads(line)
# use sent_id for ACE
context_words = doc['tokens']
render_dict = {
"text":' '.join(context_words),
"ents": [],
"title": '{}_gold'.format(doc['doc_id']) if args.gold else doc['doc_id'],
}
word2char = {} # word index to start, end char index (end is not inclusive)
ptr =0
for idx, w in enumerate(context_words):
word2char[idx] = (ptr, ptr+ len(w))
ptr = word2char[idx][1] +1
links = [] # (start_word, end_word, label)
for eidx, e in enumerate(doc['event_mentions']):
predicted = json.loads(reader.readline())
filled_template = predicted['predicted']
evt_type = e['event_type']
label = 'E{}-{}'.format(eidx, e['event_type'])
trigger_start= e['trigger']['start']
trigger_end = e['trigger']['end'] -1
trigger_tup = (trigger_start, trigger_end, label)
links.append(trigger_tup)
if args.gold:
# use gold arguments
for arg in e['arguments']:
label = 'E{}-{}'.format(eidx, arg['role'])
ent_id = arg['entity_id']
# get entity span
matched_ent = [entity for entity in doc['entity_mentions'] if entity['id'] == ent_id][0]
arg_start = matched_ent['start']
arg_end = matched_ent['end'] -1
links.append((arg_start, arg_end, label))
else: # use predicted arguments
template = ontology_dict[evt_type]['template']
# extract argument text
predicted_args = extract_args_from_template(filled_template,template, ontology_dict, evt_type)
# get trigger
# extract argument span
for argname in predicted_args:
for argtext in predicted_args[argname]:
arg_span = find_arg_span(argtext, context_words,
trigger_start, trigger_end, head_only=False, doc=None)
if arg_span:# if None means hullucination
label = 'E{}-{}'.format(eidx, argname)
links.append((arg_span[0], arg_span[1], label))
sorted_links = sorted(links, key=lambda x: x[0]) # sort by start idx
for tup in sorted_links:
arg_start, arg_end, arg_name = tup
label = arg_name
render_dict["ents"].append({
"start": word2char[arg_start][0],
"end": word2char[arg_end][1],
"label": label,
})
render_dicts.append(render_dict)
file_name = args.result_file.split('.')[0]
if args.gold:
file_name += '.gold'
html = displacy.render(render_dicts, style="ent", manual=True, page=True)
with open('{}.html'.format(file_name), 'w') as f:
f.write(html)