# Spaces: Running on Zero
from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional | |
import numpy as np | |
from spacy.tokens.token import Token | |
from spacy.tokens.span import Span | |
from lattice import Product as L | |
from heuristics import Heuristics | |
# A superlative phrase: the tokens that matched a superlative heuristic.
Sup = List[Token]
# A relation: the tokens on the dependency path, paired with the Entity they lead to.
Rel = Tuple[List[Token], "Entity"]

# Heuristics used by Entity.extract when the caller supplies none.
DEFAULT_HEURISTICS = Heuristics()
def find_superlatives(tokens, heuristics) -> List[Sup]:
    """Detect whether ``tokens`` contain a superlative keyword.

    Each heuristic in ``heuristics.superlatives`` is consulted in turn. If the
    text of any token appears in any heuristic's ``keywords``, ``tokens`` is
    sorted in place by document position (``tok.i``). It is then returned
    wrapped in a one-element list; otherwise an empty list is returned.
    """
    matched = any(
        any(tok.text in heuristic.keywords for tok in tokens)
        for heuristic in heuristics.superlatives
    )
    if not matched:
        return []
    # Callers rely on the in-place modification of the list they passed.
    tokens.sort(key=lambda tok: tok.i)
    return [tokens]
def expand_chunks(doc, chunks):
    """Grow each chunk so it also covers dependency descendants of its tokens.

    For every entry of ``chunks`` two scans are made:

    * Tokens to the left of the chunk are scanned right-to-left.
    * Tokens to the right of the chunk are scanned left-to-right.

    A candidate token qualifies when some token inside the chunk is its
    syntactic ancestor (spaCy ``Token.is_ancestor``). It is rejected when it is
    itself an ancestor of a token belonging to a chunk stored under a
    different key. On the right side it is also rejected when it lies inside
    such a chunk.

    Args:
        doc: the spaCy ``Doc`` the chunks index into (used as ``doc[i]`` and
            passed to ``Span``).
        chunks: mapping whose values expose integer token offsets ``start``
            and ``end`` (span-like). NOTE(review): keys are presumably token
            indices, as in ``Entity.extract`` -- confirm against callers.

    Returns:
        A new dict with the same keys, each mapped to ``Span(doc, start, end)``
        covering the expanded range.
    """
    expanded = {}
    for key in chunks:
        chunk = chunks[key]
        start = chunk.start
        end = chunk.end
        # Leftward scan, nearest token first. There is no ``break`` here, so
        # ``start`` ends at the left-most qualifying token. Non-qualifying
        # tokens in between are absorbed as well.
        for i in range(chunk.start-1, -1, -1):
            # doc[i] is a descendant of some token of this chunk ...
            if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
                # ... and does not dominate any token of a chunk stored under another key.
                # NOTE(review): this compares keys, not chunk objects; keys sharing this
                # same chunk are treated as "other" chunks too.
                if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
                    start = i
        # Rightward scan. Besides dominance, it also refuses tokens that sit
        # inside another chunk (``i == j``).
        for i in range(chunk.end, len(doc)):
            if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
                if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
                    end = i+1
                else:
                    # Stop at the first descendant that conflicts with another chunk.
                    # NOTE(review): indentation was lost in this copy; this ``else`` is
                    # read as pairing with the inner ``if`` -- confirm against upstream.
                    break
        expanded[key] = Span(doc=doc, start=start, end=end)
    return expanded
class Entity(NamedTuple):
    """Represents an entity with locative constraints extracted from the parse.

    Fields:
        head: the chunk ``Span`` naming the entity.
        relations: ``(tokens, entity)`` pairs. Each pair links this entity,
            through the listed dependency-path tokens, to another extracted
            ``Entity``. That entity may be ``None`` when extraction at the
            target failed.
        superlatives: token lists that matched a superlative heuristic.
    """

    head: Span
    relations: List[Rel]
    superlatives: List[Sup]

    @classmethod
    def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> Optional["Entity"]:
        """Extract entities from a spacy parse.

        Jointly recursive with `_get_rel_sups`.

        Args:
            head: token at which extraction starts.
            chunks: mapping from token index (``token.i``) to its chunk
                ``Span``. Tokens of one chunk are expected to map to the same
                object, because chunks are compared with ``is``.
            heuristics: keyword heuristics; ``DEFAULT_HEURISTICS`` when omitted.

        Returns:
            The extracted ``Entity``, or ``None`` when neither ``head`` nor its
            first child belongs to a chunk.
        """
        if heuristics is None:
            heuristics = DEFAULT_HEURISTICS
        if head.i not in chunks:
            # Handles predicative cases: fall back to the first child.
            children = list(head.children)
            if children and children[0].i in chunks:
                head = children[0]
                # TODO: Also extract predicative relations.
            else:
                return None
        hchunk = chunks[head.i]
        rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
        return cls(hchunk, rels, sups)

    @classmethod
    def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
        """Collect relations and superlatives in the subtree rooted at ``token``.

        ``tokens`` holds the tokens accumulated on the path from ``head`` so
        far. Suppose the walk reaches a token that belongs to a different chunk
        than ``head`` and is not a relation/null keyword. Then a relation to
        the entity extracted there is recorded and that branch stops. At
        leaves, the accumulated tokens are checked for superlatives.

        Returns:
            ``(relations, superlatives)`` found below ``token``.
        """
        hchunk = chunks[head.i]
        is_keyword = any(token.text in h.keywords for h in heuristics.relations)
        is_keyword |= token.text in heuristics.null_keywords

        # Found another entity head.
        if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
            # Sorted copy: ``tokens`` may be the caller's list, shared by siblings.
            path = sorted(tokens, key=lambda tok: tok.i)
            subhead = cls.extract(token, chunks, heuristics)
            return [(path, subhead)], []

        # End of a chain of modifiers.
        children = list(token.children)
        n_children = len(children)
        if n_children == 0:
            return [], find_superlatives(tokens + [token], heuristics)

        relations = []
        superlatives = []
        is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
        for child in children:
            if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
                # Same-chunk children are not walked unless they are superlative keywords.
                if not any(child.text in h.keywords for h in heuristics.superlatives):
                    if n_children == 1:
                        # Catches "the goat on the left"
                        sups = find_superlatives(tokens + [token], heuristics)
                        superlatives.extend(sups)
                    continue
            # Tokens outside any chunk, and keyword tokens, are added to the path.
            new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
            subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
            relations.extend(subrel)
            superlatives.extend(subsup)
        return relations, superlatives

    def expand(self, span: Optional[Span] = None) -> Text:
        """Return the head text widened with dependency-tree context.

        With ``span=None`` the result covers the head tokens and all of their
        dependency descendants. Otherwise, consider each token of ``span`` that
        is reached while walking down from the head tokens. For each such
        token the result adds its chain of heads, up to a token that is its
        own head, plus its own subtree. Head tokens are always included.
        Tokens are de-duplicated and joined with spaces in document order.
        """
        tokens = list(self.head)
        targets = [None] if span is None else span
        for target_token in targets:
            include = False
            stack = list(self.head)
            while stack:
                token = stack.pop()
                if token == target_token:
                    # Add every head above the target until reaching the root (its own head).
                    ancestor = target_token.head
                    while ancestor.head != ancestor:
                        tokens.append(ancestor)
                        ancestor = ancestor.head
                    tokens.append(ancestor)
                    # Restrict the rest of the walk to the target's own subtree.
                    stack = []
                    include = True
                if target_token is None or include:
                    tokens.append(token)
                stack.extend(token.children)
        return ' '.join(token.text for token in sorted(set(tokens), key=lambda tok: tok.i))

    def __eq__(self, other: object) -> bool:
        """Compare head text, relations, and superlatives (not span identity)."""
        if not isinstance(other, Entity):
            # Relations may hold ``None`` sub-entities; defer instead of raising.
            return NotImplemented
        return (
            self.text == other.text
            and self.relations == other.relations
            and self.superlatives == other.superlatives
        )

    def __ne__(self, other: object) -> bool:
        # Without this, ``!=`` would use tuple.__ne__ and bypass __eq__ above.
        equal = self.__eq__(other)
        return equal if equal is NotImplemented else not equal

    @property
    def text(self) -> Text:
        """Get the text predicate associated with this entity."""
        return self.head.text