enhg-parsing / benepar /integrations /spacy_extensions.py
nielklug's picture
add parsing
8778cfe
NOT_PARSED_SENTINEL = object()
class NonConstituentException(Exception):
pass
class ConstituentData:
def __init__(self, starts, ends, labels, loc_to_constituent, label_vocab):
self.starts = starts
self.ends = ends
self.labels = labels
self.loc_to_constituent = loc_to_constituent
self.label_vocab = label_vocab
def get_constituent(span):
constituent_data = span.doc._._constituent_data
if constituent_data is NOT_PARSED_SENTINEL:
raise Exception(
"No constituency parse is available for this document."
" Consider adding a BeneparComponent to the pipeline."
)
search_start = constituent_data.loc_to_constituent[span.start]
if span.start + 1 < len(constituent_data.loc_to_constituent):
search_end = constituent_data.loc_to_constituent[span.start + 1]
else:
search_end = len(constituent_data.ends)
found_position = None
for position in range(search_start, search_end):
if constituent_data.ends[position] <= span.end:
if constituent_data.ends[position] == span.end:
found_position = position
break
if found_position is None:
raise NonConstituentException("Span is not a constituent: {}".format(span))
return constituent_data, found_position
def get_labels(span):
constituent_data, position = get_constituent(span)
label_num = constituent_data.labels[position]
return constituent_data.label_vocab[label_num]
def parse_string(span):
constituent_data, position = get_constituent(span)
label_vocab = constituent_data.label_vocab
doc = span.doc
idx = position - 1
def make_str():
nonlocal idx
idx += 1
i, j, label_idx = (
constituent_data.starts[idx],
constituent_data.ends[idx],
constituent_data.labels[idx],
)
label = label_vocab[label_idx]
if (i + 1) >= j:
token = doc[i]
s = (
"("
+ u"{} {}".format(token.tag_, token.text)
.replace("(", "-LRB-")
.replace(")", "-RRB-")
.replace("{", "-LCB-")
.replace("}", "-RCB-")
.replace("[", "-LSB-")
.replace("]", "-RSB-")
+ ")"
)
else:
children = []
while (
(idx + 1) < len(constituent_data.starts)
and i <= constituent_data.starts[idx + 1]
and constituent_data.ends[idx + 1] <= j
):
children.append(make_str())
s = u" ".join(children)
for sublabel in reversed(label):
s = u"({} {})".format(sublabel, s)
return s
return make_str()
def get_subconstituents(span):
constituent_data, position = get_constituent(span)
label_vocab = constituent_data.label_vocab
doc = span.doc
while position < len(constituent_data.starts):
start = constituent_data.starts[position]
end = constituent_data.ends[position]
if span.end <= start or span.end < end:
break
yield doc[start:end]
position += 1
def get_child_spans(span):
constituent_data, position = get_constituent(span)
label_vocab = constituent_data.label_vocab
doc = span.doc
child_start_expected = span.start
position += 1
while position < len(constituent_data.starts):
start = constituent_data.starts[position]
end = constituent_data.ends[position]
if span.end <= start or span.end < end:
break
if start == child_start_expected:
yield doc[start:end]
child_start_expected = end
position += 1
def get_parent_span(span):
constituent_data, position = get_constituent(span)
label_vocab = constituent_data.label_vocab
doc = span.doc
sent = span.sent
position -= 1
while position >= 0:
start = constituent_data.starts[position]
end = constituent_data.ends[position]
if start <= span.start and span.end <= end:
return doc[start:end]
if end < span.sent.start:
break
position -= 1
return None
def install_spacy_extensions():
from spacy.tokens import Doc, Span, Token
# None is not allowed as a default extension value!
Doc.set_extension("_constituent_data", default=NOT_PARSED_SENTINEL)
Span.set_extension("labels", getter=get_labels)
Span.set_extension("parse_string", getter=parse_string)
Span.set_extension("constituents", getter=get_subconstituents)
Span.set_extension("parent", getter=get_parent_span)
Span.set_extension("children", getter=get_child_spans)
Token.set_extension(
"labels", getter=lambda token: get_labels(token.doc[token.i : token.i + 1])
)
Token.set_extension(
"parse_string",
getter=lambda token: parse_string(token.doc[token.i : token.i + 1]),
)
Token.set_extension(
"parent", getter=lambda token: get_parent_span(token.doc[token.i : token.i + 1])
)
try:
install_spacy_extensions()
except ImportError:
pass