adherent's picture
what is the <arg> in <trg>
6c25ddb
raw
history blame
3.16 kB
import json
import spacy
from spacy.tokens import Doc
PRONOUN_FILE='pronoun_list.txt'
pronoun_set = set()
with open(PRONOUN_FILE, 'r') as f:
for line in f:
pronoun_set.add(line.strip())
def check_pronoun(text):
if text.lower() in pronoun_set:
return True
else:
return False
def clean_mention(text):
'''
Clean up a mention by removing 'a', 'an', 'the' prefixes.
'''
prefixes = ['the ', 'The ', 'an ', 'An ', 'a ', 'A ']
for prefix in prefixes:
if text.startswith(prefix):
return text[len(prefix):]
return text
def safe_div(num, denom):
if denom > 0:
return num / denom
else:
return 0
def compute_f1(predicted, gold, matched):
precision = safe_div(matched, predicted)
recall = safe_div(matched, gold)
f1 = safe_div(2 * precision * recall, precision + recall)
return precision, recall, f1
class WhitespaceTokenizer:
def __init__(self, vocab):
self.vocab = vocab
def __call__(self, text):
words = text.split(" ")
return Doc(self.vocab, words=words)
def find_head(arg_start, arg_end, doc):
cur_i = arg_start
while doc[cur_i].head.i >= arg_start and doc[cur_i].head.i <=arg_end:
if doc[cur_i].head.i == cur_i:
# self is the head
break
else:
cur_i = doc[cur_i].head.i
arg_head = cur_i
return (arg_head, arg_head)
def load_ontology(dataset, ontology_file=None):
'''
Read ontology file for event to argument mapping.
'''
ontology_dict ={}
if not ontology_file: # use the default file path
if not dataset:
raise ValueError
with open('event_role_{}.json'.format(dataset),'r') as f:
ontology_dict = json.load(f)
else:
with open(ontology_file,'r') as f:
ontology_dict = json.load(f)
for evt_name, evt_dict in ontology_dict.items():
for i, argname in enumerate(evt_dict['roles']):
evt_dict['arg{}'.format(i+1)] = argname
# argname -> role is not a one-to-one mapping
if argname in evt_dict:
evt_dict[argname].append('arg{}'.format(i+1))
else:
evt_dict[argname] = ['arg{}'.format(i+1)]
return ontology_dict
def find_arg_span(arg, context_words, trigger_start, trigger_end, head_only=False, doc=None):
match = None
arg_len = len(arg)
min_dis = len(context_words) # minimum distance to trigger
for i, w in enumerate(context_words):
if context_words[i:i+arg_len] == arg:
if i < trigger_start:
dis = abs(trigger_start-i-arg_len)
else:
dis = abs(i-trigger_end)
if dis< min_dis:
match = (i, i+arg_len-1)
min_dis = dis
if match and head_only:
assert(doc!=None)
match = find_head(match[0], match[1], doc)
return match
def get_entity_span(ex, entity_id):
for ent in ex['entity_mentions']:
if ent['id'] == entity_id:
return (ent['start'], ent['end'])