Spaces:
Build error
Build error
File size: 6,926 Bytes
6c25ddb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import os
import json
import argparse
from copy import deepcopy
import spacy
from spacy import displacy
import re
from collections import defaultdict
def find_head(arg_start, arg_end, doc):
    """Return the syntactic head of the token span ``[arg_start, arg_end]``.

    The walk begins at the first token of the span and repeatedly follows the
    token's dependency ``head`` while that head still lies inside the span.
    It stops at a token that is its own head, or whose head falls outside
    the span.

    Args:
        arg_start: index of the first token of the span (inclusive).
        arg_end: index of the last token of the span (inclusive).
        doc: indexable token sequence whose items expose ``.head.i``
            (presumably a spaCy ``Doc`` in this file).

    Returns:
        ``(head, head)`` so the head can be used directly as a one-token span.
    """
    current = arg_start
    while True:
        parent = doc[current].head.i
        if not (arg_start <= parent <= arg_end):
            # the head lies outside the span: current is the span's head
            break
        if parent == current:
            # a token that heads itself is a root; stop here
            break
        current = parent
    return (current, current)
def extract_args_from_template(predicted, template, ontology_dict, evt_type):
    """Align a filled-in template with its blank template and extract argument text.

    Both strings are whitespace-split and walked in lockstep. When the template
    word is a placeholder such as ``<arg1>``, the predicted words up to (but not
    including) the next template word become that argument's text. A predicted
    ``<arg>`` token marks an unfilled slot, and nothing is recorded for it. If the
    placeholder is the last template word, all remaining predicted words are
    taken.

    Args:
        predicted: model output, i.e. the template with placeholders filled in.
        template: the unfilled template containing ``<argN>`` placeholders.
        ontology_dict: ontology mapping; ``ontology_dict[evt_type]['argN']``
            gives the role name for placeholder ``argN``.
        evt_type: event-type key into ``ontology_dict``.

    Returns:
        dict mapping role name -> list of argument texts, each a list of words.
        A role appears once per filled slot (several slots may share a role).

    Raises:
        KeyError: if a placeholder has no entry in ``ontology_dict[evt_type]``.
    """
    template_words = template.strip().split()
    predicted_words = predicted.strip().split()
    predicted_args = defaultdict(list)  # role name -> list of word lists
    t_ptr = 0
    p_ptr = 0
    while t_ptr < len(template_words) and p_ptr < len(predicted_words):
        m = re.match(r'<(arg\d+)>', template_words[t_ptr])
        if m is None:
            # Literal template word: advance both pointers together
            # (the prediction is assumed to repeat it verbatim).
            t_ptr += 1
            p_ptr += 1
            continue
        arg_name = ontology_dict[evt_type][m.group(1)]
        if predicted_words[p_ptr] == '<arg>':
            # missing argument
            p_ptr += 1
            t_ptr += 1
            continue
        # The argument runs until the next template word reappears in the
        # prediction. A trailing placeholder has no delimiter, so it extends to
        # the end of the prediction (indexing t_ptr + 1 used to raise IndexError).
        if t_ptr + 1 < len(template_words):
            next_word = template_words[t_ptr + 1]
        else:
            next_word = None
        arg_start = p_ptr
        while p_ptr < len(predicted_words) and predicted_words[p_ptr] != next_word:
            p_ptr += 1
        predicted_args[arg_name].append(predicted_words[arg_start:p_ptr])
        t_ptr += 1
    return dict(predicted_args)
def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None):
    """Locate the exact occurrence of ``arg`` in ``context_words`` closest to the trigger.

    For an occurrence that starts before the trigger, the distance is the gap
    between the occurrence's exclusive end and ``trigger_start``. Otherwise it is
    the distance from the occurrence's start to ``trigger_end``. Ties go to the
    earliest occurrence.

    Args:
        arg: argument text as a list of words.
        context_words: the tokenized document (list of words).
        trigger_start: index of the first trigger word.
        trigger_end: index of the last trigger word (inclusive).
        head_only: if True, shrink the match to a single head token via ``find_head``.
        doc: token sequence handed to ``find_head``; required when ``head_only`` is True.

    Returns:
        ``(start, end)`` inclusive word indices of the closest match. Returns
        ``None`` when ``arg`` is empty or does not occur in ``context_words``
        (callers treat that as a hallucinated argument).

    Raises:
        ValueError: if ``head_only`` is True, a match was found, and ``doc`` is None.
    """
    arg_len = len(arg)
    if arg_len == 0:
        # An empty slice equals [] at every index, so an empty argument would
        # "match" everywhere and yield an inverted (i, i - 1) span.
        return None
    match = None
    min_dis = float('inf')  # distance of the best match so far
    # Only start positions where the full argument fits can match.
    for i in range(len(context_words) - arg_len + 1):
        if context_words[i:i + arg_len] != arg:
            continue
        if i < trigger_start:
            dis = abs(trigger_start - i - arg_len)
        else:
            dis = abs(i - trigger_end)
        if dis < min_dis:
            match = (i, i + arg_len - 1)
            min_dis = dis
    if match and head_only:
        if doc is None:
            raise ValueError('doc is required when head_only=True')
        match = find_head(match[0], match[1], doc)
    return match
def load_ontology(dataset):
    '''
    Read ontology file for event to argument mapping.

    Loads ``event_role_<dataset>.json`` from the current working directory as
    UTF-8. For every event type it then adds two lookups derived from that
    event's ``roles`` list:

    * ``'argN' -> role name`` (1-based, in ``roles`` order), and
    * ``role name -> ['argN', ...]`` -- a list, because one role may fill
      several template slots.

    Args:
        dataset: name substituted into the file name (e.g. ``'KAIROS'``).

    Returns:
        The parsed ontology dict, with each event entry augmented in place.
    '''
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open('event_role_{}.json'.format(dataset), 'r', encoding='utf-8') as f:
        ontology_dict = json.load(f)
    for evt_dict in ontology_dict.values():
        for i, argname in enumerate(evt_dict['roles'], start=1):
            slot = 'arg{}'.format(i)
            evt_dict[slot] = argname
            # argname -> slot is not a one-to-one mapping, so collect all slots
            evt_dict.setdefault(argname, []).append(slot)
    return ontology_dict
def _word_char_offsets(words):
    """Map each word index to its ``(start, end)`` character offsets.

    Offsets refer to ``' '.join(words)`` (words separated by one space);
    ``end`` is exclusive.
    """
    offsets = {}
    ptr = 0
    for idx, word in enumerate(words):
        offsets[idx] = (ptr, ptr + len(word))
        ptr += len(word) + 1  # skip the joining space
    return offsets


def _gold_links(event, eidx, doc):
    """Return ``(start_word, end_word, label)`` links for one event's gold arguments.

    Word indices are inclusive. Each argument's ``entity_id`` is resolved
    against ``doc['entity_mentions']``, using the first mention with that id.

    Raises:
        KeyError: if an argument references an entity id that has no mention.
    """
    links = []
    for arg in event['arguments']:
        ent_id = arg['entity_id']
        matched_ent = next(
            (ent for ent in doc['entity_mentions'] if ent['id'] == ent_id), None)
        if matched_ent is None:
            raise KeyError('entity {} referenced by an argument is missing from '
                           'entity_mentions'.format(ent_id))
        label = 'E{}-{}'.format(eidx, arg['role'])
        # stored mention end is exclusive; links use inclusive word indices
        links.append((matched_ent['start'], matched_ent['end'] - 1, label))
    return links


def _predicted_links(filled_template, evt_type, eidx, context_words,
                     trigger_start, trigger_end, ontology_dict):
    """Return links for the arguments extracted from a model-filled template.

    An argument whose text is not found verbatim in ``context_words``
    (``find_arg_span`` returns None) is dropped as a hallucination.
    """
    template = ontology_dict[evt_type]['template']
    predicted_args = extract_args_from_template(
        filled_template, template, ontology_dict, evt_type)
    links = []
    for argname, arg_texts in predicted_args.items():
        for argtext in arg_texts:
            arg_span = find_arg_span(argtext, context_words, trigger_start,
                                     trigger_end, head_only=False, doc=None)
            if arg_span:
                links.append((arg_span[0], arg_span[1],
                              'E{}-{}'.format(eidx, argname)))
    return links


def _render_document(doc, pred_reader, ontology_dict, use_gold):
    """Build one displaCy manual-mode ``ent`` entry for a parsed test document.

    Args:
        doc: one parsed line of the test file (reads ``tokens``, ``doc_id``,
            ``event_mentions`` and, in gold mode, ``entity_mentions``).
        pred_reader: open predictions file. When ``use_gold`` is False, one
            line is consumed per event mention, so the file must be aligned
            with the test file's event mentions.
        ontology_dict: ontology returned by ``load_ontology``.
        use_gold: render gold arguments instead of predicted ones.

    Raises:
        ValueError: if the predictions file ends before every event mention
            has been paired with a prediction.
    """
    context_words = doc['tokens']
    render_dict = {
        "text": ' '.join(context_words),
        "ents": [],
        "title": '{}_gold'.format(doc['doc_id']) if use_gold else doc['doc_id'],
    }
    word2char = _word_char_offsets(context_words)
    links = []  # (start_word, end_word, label), word indices inclusive
    for eidx, e in enumerate(doc['event_mentions']):
        evt_type = e['event_type']
        trigger_start = e['trigger']['start']
        trigger_end = e['trigger']['end'] - 1  # stored end is exclusive
        links.append((trigger_start, trigger_end, 'E{}-{}'.format(eidx, evt_type)))
        if use_gold:
            links.extend(_gold_links(e, eidx, doc))
        else:
            pred_line = pred_reader.readline()
            if not pred_line:
                raise ValueError(
                    'predictions file ended before all event mentions of {} '
                    'were processed'.format(doc['doc_id']))
            filled_template = json.loads(pred_line)['predicted']
            links.extend(_predicted_links(filled_template, evt_type, eidx,
                                          context_words, trigger_start,
                                          trigger_end, ontology_dict))
    # displaCy expects entities in document order; sorted() is stable.
    for arg_start, arg_end, label in sorted(links, key=lambda link: link[0]):
        render_dict["ents"].append({
            "start": word2char[arg_start][0],
            "end": word2char[arg_end][1],
            "label": label,
        })
    return render_dict


def main():
    """Render KAIROS event triggers and arguments as a displaCy HTML page.

    Output is written to ``<result-file minus extension>[.gold].html``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--result-file', type=str,
                        default='checkpoints/gen-KAIROS-pointer-pred/predictions.jsonl')
    parser.add_argument('--test-file', type=str, default='data/kairos/test.jsonl')
    parser.add_argument('--gold', action='store_true')
    args = parser.parse_args()

    ontology_dict = load_ontology('KAIROS')
    render_dicts = []
    # Both files are managed by `with`; the predictions reader used to leak.
    with open(args.result_file, 'r', encoding='utf-8') as reader, \
            open(args.test_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the JSONL file
            render_dicts.append(
                _render_document(json.loads(line), reader, ontology_dict, args.gold))

    # splitext strips only the final extension; split('.')[0] cut the path at
    # its first dot anywhere (e.g. './out/predictions.jsonl' -> '').
    file_name = os.path.splitext(args.result_file)[0]
    if args.gold:
        file_name += '.gold'
    html = displacy.render(render_dicts, style="ent", manual=True, page=True)
    with open('{}.html'.format(file_name), 'w', encoding='utf-8') as f:
        f.write(html)


if __name__ == "__main__":
    main()
|