Spaces:
Sleeping
Sleeping
# Unique placeholder object used as the default value of the Doc
# ``_constituent_data`` extension. It is compared by identity to detect
# documents that carry no constituency parse.
NOT_PARSED_SENTINEL = object()
class NonConstituentException(Exception):
    """Raised when a span does not correspond to any stored constituent."""
class ConstituentData:
    """Container for the constituency parse information attached to a document.

    Attributes:
        starts: per-constituent start offsets, indexed by constituent position.
        ends: per-constituent end offsets, parallel to ``starts``.
        labels: per-constituent label numbers, parallel to ``starts``; each
            number is used as an index into ``label_vocab``.
        loc_to_constituent: indexed by a span's start offset; gives the first
            constituent position to search for that offset.
        label_vocab: lookup from label number to the label value.
    """

    def __init__(self, starts, ends, labels, loc_to_constituent, label_vocab):
        # The three parallel per-constituent sequences.
        self.starts, self.ends, self.labels = starts, ends, labels
        # Lookup tables used to navigate and decode the sequences above.
        self.loc_to_constituent, self.label_vocab = loc_to_constituent, label_vocab
def get_constituent(span):
    """Locate the stored constituent that ends exactly at ``span.end``.

    The search is limited to the constituent positions that
    ``loc_to_constituent`` associates with ``span.start``.

    Args:
        span: object exposing ``start``, ``end`` and ``doc``; the document must
            carry the ``_constituent_data`` extension.

    Returns:
        A ``(constituent_data, position)`` tuple, where ``position`` indexes
        the ``starts``/``ends``/``labels`` sequences of ``constituent_data``.

    Raises:
        Exception: if the document still holds ``NOT_PARSED_SENTINEL``.
        NonConstituentException: if no candidate ends exactly at ``span.end``.
    """
    data = span.doc._._constituent_data
    if data is NOT_PARSED_SENTINEL:
        raise Exception(
            "No constituency parse is available for this document."
            " Consider adding a BeneparComponent to the pipeline."
        )

    table = data.loc_to_constituent
    lo = table[span.start]
    # The candidates for this start offset run up to the entry of the next
    # offset, or to the end of the constituent list for the last offset.
    hi = table[span.start + 1] if span.start + 1 < len(table) else len(data.ends)

    match = None
    for candidate in range(lo, hi):
        candidate_end = data.ends[candidate]
        if candidate_end > span.end:
            # Still wider than the requested span; keep scanning.
            continue
        # First candidate that fits inside the span: it is either an exact
        # match or proof that there is none, so the scan stops here.
        if candidate_end == span.end:
            match = candidate
        break

    if match is None:
        raise NonConstituentException("Span is not a constituent: {}".format(span))
    return data, match
def get_labels(span):
    """Return the ``label_vocab`` entry for the constituent matching ``span``.

    Raises whatever ``get_constituent`` raises when the span cannot be
    resolved to a constituent.
    """
    data, where = get_constituent(span)
    return data.label_vocab[data.labels[where]]
def parse_string(span):
    """Render the constituent matching ``span`` as a bracketed tree string.

    Constituents are consumed in stored order starting at the matched
    position. A constituent spanning at most one token becomes
    ``(TAG text)`` with bracket characters escaped; a wider one becomes the
    space-joined rendering of the following constituents nested inside it.
    Every rendering is then wrapped once per sublabel of its label.

    Raises whatever ``get_constituent`` raises for an unresolvable span.
    """
    data, root = get_constituent(span)
    vocab = data.label_vocab
    document = span.doc
    # Bracket characters inside a leaf are replaced by their escape names.
    bracket_escapes = str.maketrans(
        {
            "(": "-LRB-",
            ")": "-RRB-",
            "{": "-LCB-",
            "}": "-RCB-",
            "[": "-LSB-",
            "]": "-RSB-",
        }
    )
    # Position of the constituent rendered most recently; shared by all
    # recursive calls so each constituent is consumed exactly once.
    cursor = root - 1

    def render():
        nonlocal cursor
        cursor += 1
        left = data.starts[cursor]
        right = data.ends[cursor]
        sublabels = vocab[data.labels[cursor]]

        if left + 1 >= right:
            tok = document[left]
            leaf = "{} {}".format(tok.tag_, tok.text).translate(bracket_escapes)
            text = "(" + leaf + ")"
        else:
            parts = []
            while True:
                upcoming = cursor + 1
                if upcoming >= len(data.starts):
                    break
                # Stop at the first constituent not nested in [left, right].
                if data.starts[upcoming] < left or data.ends[upcoming] > right:
                    break
                parts.append(render())
            text = " ".join(parts)

        # Wrap innermost-first so the first sublabel ends up outermost.
        for sublabel in reversed(sublabels):
            text = "({} {})".format(sublabel, text)
        return text

    return render()
def get_subconstituents(span):
    """Yield document slices for the constituents stored from ``span`` onward.

    Iteration starts at the position matched for ``span`` itself and walks the
    stored constituents in order, stopping at the first one that starts at or
    after ``span.end`` or that ends beyond it.

    This is a generator, so errors from ``get_constituent`` (for example
    ``NonConstituentException``) surface on first iteration, not at call time.

    Yields:
        ``doc[start:end]`` for each constituent visited.
    """
    constituent_data, position = get_constituent(span)
    doc = span.doc
    starts = constituent_data.starts
    ends = constituent_data.ends
    while position < len(starts):
        start = starts[position]
        end = ends[position]
        # The first constituent that is not inside the span ends the walk.
        if span.end <= start or span.end < end:
            break
        yield doc[start:end]
        position += 1
def get_child_spans(span):
    """Yield the direct children of the constituent matching ``span``.

    Walks the constituents stored after the matched position. A constituent is
    yielded when it starts where the previous child ended (initially
    ``span.start``); anything else inside the span is skipped. The walk stops
    at the first constituent that starts at or after ``span.end`` or that ends
    beyond it.

    This is a generator, so errors from ``get_constituent`` surface on first
    iteration, not at call time.

    Yields:
        ``doc[start:end]`` for each child, in stored order.
    """
    constituent_data, position = get_constituent(span)
    doc = span.doc
    starts = constituent_data.starts
    ends = constituent_data.ends
    child_start_expected = span.start
    position += 1
    while position < len(starts):
        start = starts[position]
        end = ends[position]
        if span.end <= start or span.end < end:
            break
        if start == child_start_expected:
            yield doc[start:end]
            child_start_expected = end
            if end == span.end:
                # A further child would have to start at span.end, which the
                # check above always rejects, so the rest of the scan cannot
                # yield anything.
                break
        position += 1
def get_parent_span(span):
    """Return the nearest preceding stored constituent that contains ``span``.

    Walks backwards from the position just before the one matched for
    ``span`` and returns the first constituent whose range covers
    ``[span.start, span.end]``. The walk gives up once a constituent ends
    before the start of ``span.sent``.

    Returns:
        ``doc[start:end]`` for the enclosing constituent, or ``None`` when no
        enclosing constituent is found.

    Raises whatever ``get_constituent`` raises for an unresolvable span.
    """
    constituent_data, position = get_constituent(span)
    doc = span.doc
    # Resolve the sentence start once; it does not change during the walk.
    sent_start = span.sent.start
    starts = constituent_data.starts
    ends = constituent_data.ends
    position -= 1
    while position >= 0:
        start = starts[position]
        end = ends[position]
        if start <= span.start and span.end <= end:
            return doc[start:end]
        if end < sent_start:
            # This constituent ends before the span's sentence begins.
            break
        position -= 1
    return None
def install_spacy_extensions():
    """Register the constituency-parse extension attributes with spaCy.

    Adds the ``_constituent_data`` Doc extension, the ``labels``,
    ``parse_string``, ``constituents``, ``parent`` and ``children`` Span
    getters, and the ``labels``, ``parse_string`` and ``parent`` Token
    getters. Each Token getter delegates to the Span getter on the
    one-token slice for that token.

    Raises:
        ImportError: if ``spacy.tokens`` cannot be imported.
    """
    from spacy.tokens import Doc, Span, Token

    def single_token_span(token):
        # One-token slice of the document, positioned at the token.
        return token.doc[token.i : token.i + 1]

    # None is not allowed as a default extension value!
    Doc.set_extension("_constituent_data", default=NOT_PARSED_SENTINEL)

    span_getters = (
        ("labels", get_labels),
        ("parse_string", parse_string),
        ("constituents", get_subconstituents),
        ("parent", get_parent_span),
        ("children", get_child_spans),
    )
    for attr_name, getter in span_getters:
        Span.set_extension(attr_name, getter=getter)

    Token.set_extension(
        "labels", getter=lambda token: get_labels(single_token_span(token))
    )
    Token.set_extension(
        "parse_string", getter=lambda token: parse_string(single_token_span(token))
    )
    Token.set_extension(
        "parent", getter=lambda token: get_parent_span(single_token_span(token))
    )
# Register the extensions as soon as this module is loaded. The registration
# function imports ``spacy.tokens`` itself; if that import fails, the error is
# deliberately ignored so the module can still be imported, just without the
# extensions installed.
try:
    install_spacy_extensions()
except ImportError:
    pass