fillmorle-app / sftp /data_reader /concrete_srl.py
gossminn's picture
First version
6680682
raw
history blame
6.9 kB
from time import time
from typing import *
from collections import defaultdict
from concrete import (
Token, TokenList, TextSpan, MentionArgument, SituationMentionSet, SituationMention, TokenRefSequence,
Communication, EntityMention, EntityMentionSet, Entity, EntitySet, AnnotationMetadata, Sentence
)
from concrete.util import create_comm, AnalyticUUIDGeneratorFactory
from concrete.validate import validate_communication
from ..utils import Span
def _process_sentence(sent, comm_sent, aug, char_idx_offset: int):
token_list = list()
for tok_idx, (start_idx, end_idx) in enumerate(sent['tokenization']):
token_list.append(Token(
tokenIndex=tok_idx,
text=sent['sentence'][start_idx:end_idx + 1],
textSpan=TextSpan(
start=start_idx + char_idx_offset,
ending=end_idx + char_idx_offset + 1
),
))
comm_sent.tokenization.tokenList = TokenList(tokenList=token_list)
sm_list, em_dict, entity_list = list(), dict(), list()
annotation = sent['annotations'] if isinstance(sent['annotations'], Span) else Span.from_json(sent['annotations'])
for event in annotation:
char_start_idx = sent['tokenization'][event.start_idx][0]
char_end_idx = sent['tokenization'][event.end_idx][1]
sm = SituationMention(
uuid=next(aug),
text=sent['sentence'][char_start_idx: char_end_idx + 1],
situationType='EVENT',
situationKind=event.label,
argumentList=list(),
tokens=TokenRefSequence(
tokenIndexList=list(range(event.start_idx, event.end_idx + 1)),
tokenizationId=comm_sent.tokenization.uuid
),
)
for arg in event:
em = em_dict.get((arg.start_idx, arg.end_idx + 1))
if em is None:
char_start_idx = sent['tokenization'][arg.start_idx][0]
char_end_idx = sent['tokenization'][arg.end_idx][1]
em = EntityMention(next(aug), TokenRefSequence(
tokenIndexList=list(range(arg.start_idx, arg.end_idx + 1)),
tokenizationId=comm_sent.tokenization.uuid,
), text=sent['sentence'][char_start_idx: char_end_idx + 1])
entity_list.append(Entity(next(aug), id=em.text, mentionIdList=[em.uuid]))
em_dict[(arg.start_idx, arg.end_idx + 1)] = em
sm.argumentList.append(MentionArgument(
role=arg.label,
entityMentionId=em.uuid,
))
sm_list.append(sm)
return sm_list, list(em_dict.values()), entity_list
def concrete_doc(
sentences: List[Dict[str, Any]],
doc_name: str = 'document',
) -> Communication:
"""
Data format: A list of sentences. Each sentence should be a dict of the following format:
{
"sentence": String.
"tokenization": A list of Tuple[int, int] for start and end indices. Both inclusive.
"annotations": A list of event dict, or Span object.
}
If it is dict, its format should be:
Each event should be a dict of the following format:
{
"span": [start_idx, end_idx]: Integer. Both inclusive.
"label": String.
"children": A list of arguments.
}
Each argument should be a dict of the following format:
{
"span": [start_idx, end_idx]: Integer. Both inclusive.
"label": String.
}
Note the "indices" above all refer to the indices of tokens, instead of characters.
"""
comm = create_comm(
doc_name,
'\n'.join([sent['sentence'] for sent in sentences]),
)
aug = AnalyticUUIDGeneratorFactory(comm).create()
situation_mention_set = SituationMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
comm.situationMentionSetList = [situation_mention_set]
entity_mention_set = EntityMentionSet(next(aug), AnnotationMetadata('Span Finder', time()), list())
comm.entityMentionSetList = [entity_mention_set]
entity_set = EntitySet(
next(aug), AnnotationMetadata('O(0) Coref Paser.', time()), list(), None, entity_mention_set.uuid
)
comm.entitySetList = [entity_set]
assert len(sentences) == len(comm.sectionList[0].sentenceList)
char_idx_offset = 0
for sent, comm_sent in zip(sentences, comm.sectionList[0].sentenceList):
sm_list, em_list, entity_list = _process_sentence(sent, comm_sent, aug, char_idx_offset)
entity_set.entityList.extend(entity_list)
situation_mention_set.mentionList.extend(sm_list)
entity_mention_set.mentionList.extend(em_list)
char_idx_offset += len(sent['sentence']) + 1
validate_communication(comm)
return comm
def concrete_doc_tokenized(
sentences: List[List[str]],
spans: List[Span],
doc_name: str = "document",
):
"""
Similar to concrete_doc, but with tokenized words and spans.
"""
inputs = list()
for sent, vr in zip(sentences, spans):
cur_start = 0
tokenization = list()
for token in sent:
tokenization.append((cur_start, cur_start + len(token) - 1))
cur_start += len(token) + 1
inputs.append({
"sentence": " ".join(sent),
"tokenization": tokenization,
"annotations": vr
})
return concrete_doc(inputs, doc_name)
def collect_concrete_srl(comm: Communication) -> List[Tuple[List[str], Span]]:
# Mapping from <sentence uuid> to [<ConcreteSentence>, <Associated situation mentions>]
sentences = defaultdict(lambda: [None, list()])
for sec in comm.sectionList:
for sen in sec.sentenceList:
sentences[sen.uuid.uuidString][0] = sen
# Assume there's only ONE situation mention set
assert len(comm.situationMentionSetList) == 1
# Assign each situation mention to the corresponding sentence
for men in comm.situationMentionSetList[0].mentionList:
if men.tokens is None: continue # For ACE relations
sentences[men.tokens.tokenization.sentence.uuid.uuidString][1].append(men)
ret = list()
for sen, mention_list in sentences.values():
tokens = [t.text for t in sen.tokenization.tokenList.tokenList]
spans = list()
for mention in mention_list:
mention_tokens = sorted(mention.tokens.tokenIndexList)
event = Span(mention_tokens[0], mention_tokens[-1], mention.situationKind, True)
for men_arg in mention.argumentList:
arg_tokens = sorted(men_arg.entityMention.tokens.tokenIndexList)
event.add_child(Span(arg_tokens[0], arg_tokens[-1], men_arg.role, False))
spans.append(event)
vr = Span.virtual_root(spans)
ret.append((tokens, vr))
return ret