CLIPAway / AlphaCLIP /eval /rec_zs_test /entity_extraction.py
hpc-yekin
initial commit
92e0882
raw
history blame
5.58 kB
from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
import numpy as np
from spacy.tokens.token import Token
from spacy.tokens.span import Span
from lattice import Product as L
from heuristics import Heuristics
# A relation: the tokens collected on the way to another entity, paired with
# that entity (see `Entity._get_rel_sups`, which builds these tuples).
Rel = Tuple[List[Token], "Entity"]
# A superlative: a list of tokens (sorted by token index by `find_superlatives`).
Sup = List[Token]
# Heuristics instance used by `Entity.extract` when the caller passes none.
DEFAULT_HEURISTICS = Heuristics()
def find_superlatives(tokens, heuristics) -> List[Sup]:
    """Return `tokens` as one superlative group if any token is a keyword.

    Every entry of `heuristics.superlatives` is consulted in order; as soon
    as one of them lists the text of any token among its `keywords`, the
    `tokens` list is sorted IN PLACE by token index (`tok.i`) and returned
    wrapped in a one-element list. Without any match an empty list is
    returned and `tokens` is left untouched.
    """
    has_keyword = any(
        candidate.text in rule.keywords
        for rule in heuristics.superlatives
        for candidate in tokens
    )
    if not has_keyword:
        return []
    # The caller's list object itself is reordered and handed back.
    tokens.sort(key=lambda candidate: candidate.i)
    return [tokens]
def expand_chunks(doc, chunks):
    """Widen every chunk to take in neighbouring tokens that depend on it.

    Args:
        doc: Indexable sequence of tokens (`doc[i]`, `len(doc)`) whose tokens
            provide `is_ancestor`; passed on to `Span(doc=...)`.
        chunks: Mapping from a key to a chunk exposing `.start` and `.end`.

    Returns:
        A dict with the same keys as `chunks`, each mapped to a new `Span`
        over `doc` covering the (possibly widened) `[start, end)` range.
    """
    # NOTE(review): indentation was lost in the copy of this file that was
    # reviewed; the `else: break` below is assumed to pair with the outer
    # `if` of the forward scan -- confirm against the upstream source.
    expanded = {}
    for key in chunks:
        chunk = chunks[key]
        start = chunk.start
        end = chunk.end
        # Backward scan over every token before the chunk (no early exit):
        # `start` ends up at the left-most qualifying index.
        for i in range(chunk.start-1, -1, -1):
            # doc[i] qualifies if some token of this chunk is its ancestor ...
            if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
                # ... and doc[i] is not itself an ancestor of a token in any
                # entry of `chunks` stored under a different key.
                if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
                    start = i
        # Forward scan over the tokens after the chunk.
        for i in range(chunk.end, len(doc)):
            if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
                # Unlike the backward scan, a token that lies inside another
                # entry's range (i == j) is also rejected here.
                if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
                    end = i+1
            else:
                # Stop at the first token that no chunk token is an ancestor of.
                break
        expanded[key] = Span(doc=doc, start=start, end=end)
    return expanded
class Entity(NamedTuple):
    """Represents an entity with locative constraints extracted from the parse.

    Equality is defined on the head text plus the relations and superlatives
    (see `__eq__`), not on the identity of the underlying spans.
    """

    # Chunk (span) naming the entity itself; `text` reads `head.text`.
    head: Span
    # (tokens, Entity) pairs linking this entity to other entities.
    relations: List[Rel]
    # Token groups returned by `find_superlatives`.
    superlatives: List[Sup]

    @classmethod
    def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> Optional["Entity"]:
        """Extract entities from a spacy parse.

        Jointly recursive with `_get_rel_sups`.

        Args:
            head: Token to start from (needs `.i` and `.children`).
            chunks: Mapping from token index to the chunk containing it.
            heuristics: Keyword heuristics; `DEFAULT_HEURISTICS` when None.

        Returns:
            The extracted entity, or None when neither `head` nor its first
            child is covered by `chunks`.
        """
        if heuristics is None:
            heuristics = DEFAULT_HEURISTICS

        if head.i not in chunks:
            # Handles predicative cases: fall back to the first child.
            children = list(head.children)
            if not (children and children[0].i in chunks):
                return None
            # TODO: Also extract predicative relations.
            head = children[0]

        hchunk = chunks[head.i]
        rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
        return cls(hchunk, rels, sups)

    @classmethod
    def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
        """Collect relations and superlatives below `token` for entity `head`.

        Args:
            token: Token currently being visited.
            head: Head token of the entity being built.
            tokens: Tokens accumulated on the path from `head` to `token`.
            chunks: Mapping from token index to the chunk containing it.
            heuristics: Provides `relations` and `superlatives` (items with
                `.keywords`) and `null_keywords`.

        Returns:
            A `(relations, superlatives)` pair of lists.
        """
        hchunk = chunks[head.i]
        is_keyword = (
            any(token.text in h.keywords for h in heuristics.relations)
            or token.text in heuristics.null_keywords
        )

        # Found another entity head.
        if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
            subhead = cls.extract(token, chunks, heuristics)
            # Sorted copy: `tokens` may be the very list object handed to
            # sibling branches, so it is not reordered in place.
            return [(sorted(tokens, key=lambda tok: tok.i), subhead)], []

        # End of a chain of modifiers.
        children = list(token.children)
        n_children = len(children)
        if n_children == 0:
            return [], find_superlatives(tokens + [token], heuristics)

        relations = []
        superlatives = []
        is_keyword = is_keyword or any(
            token.text in h.keywords for h in heuristics.superlatives
        )
        for child in children:
            same_chunk = (
                token.i in chunks
                and child.i in chunks
                and chunks[token.i] is chunks[child.i]
            )
            if same_chunk and not any(
                child.text in h.keywords for h in heuristics.superlatives
            ):
                # Non-superlative child inside the same chunk: do not descend.
                if n_children == 1:
                    # Catches "the goat on the left"
                    superlatives.extend(
                        find_superlatives(tokens + [token], heuristics)
                    )
                continue
            # The current token joins the path only if it lies outside every
            # chunk or is a relation/null/superlative keyword.
            new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
            subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
            relations.extend(subrel)
            superlatives.extend(subsup)
        return relations, superlatives

    def expand(self, span: Optional[Span] = None) -> Text:
        """Return the head text expanded with tokens from the dependency tree.

        With `span` None, the result contains the head tokens and every token
        reachable from them through `.children`. Otherwise, for each token of
        `span` that is reachable from the head, the result additionally
        contains that token, all of its ancestors up to the root (the token
        that is its own `.head`) and its whole subtree.

        Returns:
            The selected tokens' texts, de-duplicated, ordered by token index
            and joined with single spaces.
        """
        tokens = list(self.head)
        if span is None:
            span = [None]
        for target_token in span:
            include = False
            stack = list(self.head)
            while stack:
                token = stack.pop()
                if token == target_token:
                    # Climb from the target to the root, keeping each ancestor.
                    token2 = target_token.head
                    while token2.head != token2:
                        tokens.append(token2)
                        token2 = token2.head
                    tokens.append(token2)
                    # From now on only the target's subtree is traversed.
                    stack = []
                    include = True
                if target_token is None or include:
                    tokens.append(token)
                stack.extend(token.children)
        ordered = sorted(set(tokens), key=lambda tok: tok.i)
        return ' '.join(token.text for token in ordered)

    def __eq__(self, other: object) -> bool:
        """Compare by head text, relations and superlatives.

        Returns `NotImplemented` for operands that are not `Entity`
        instances, so comparing against e.g. None yields False rather than
        raising AttributeError.
        """
        if not isinstance(other, Entity):
            return NotImplemented
        return (
            self.text == other.text
            and self.relations == other.relations
            and self.superlatives == other.superlatives
        )

    def __ne__(self, other: object) -> bool:
        """Negate `__eq__`.

        Needed because `tuple` supplies its own `__ne__`, which would compare
        the raw fields and could disagree with `__eq__`.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    @property
    def text(self) -> Text:
        """Get the text predicate associated with this entity."""
        return self.head.text