|
import re
|
|
from collections import defaultdict
|
|
from operator import itemgetter
|
|
from pathlib import Path
|
|
from re import Pattern
|
|
from typing import Optional, Callable, Iterable, Dict, Tuple
|
|
|
|
from spacy.lang.hu import Hungarian
|
|
from spacy.language import Language
|
|
from spacy.lookups import Lookups, Table
|
|
from spacy.pipeline import Pipe
|
|
from spacy.pipeline.lemmatizer import lemmatizer_score
|
|
from spacy.tokens import Token
|
|
from spacy.tokens.doc import Doc
|
|
|
|
|
|
from spacy.training.example import Example
|
|
from spacy.util import ensure_path
|
|
|
|
|
|
class LookupLemmatizer(Pipe):
    """Lemmatizer that learns `(token, POS, morph. feats) -> lemma` mappings during
    training and applies the most frequent mapping at prediction time.

    Tokens containing digits are "masked" (every digit replaced by ``0``) before
    lookup so that, e.g., all four-digit numbers share a single table entry; the
    original digits are restored in the emitted lemma.
    """

    # Matches a single digit; used to mask numbers in tokens and lemmas.
    _number_pattern: Pattern = re.compile(r"\d")

    @staticmethod
    @Hungarian.factory(
        "lookup_lemmatizer",
        assigns=["token.lemma"],
        requires=["token.pos"],
        default_config={"scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"}, "source": ""},
    )
    def create(nlp: Language, name: str, scorer: Optional[Callable], source: str) -> "LookupLemmatizer":
        """Factory used by spaCy to construct this pipe from the config."""
        return LookupLemmatizer(None, source, scorer)

    def __init__(
        self,
        lookups: Optional[Lookups] = None,
        source: Optional[str] = None,
        scorer: Optional[Callable] = lemmatizer_score,
    ):
        """Create the pipe.

        Args:
            lookups: pre-built lookup tables, or None if they will be provided
                later by `train()`, `initialize()` or `from_disk()`
            source: path the tables are loaded from in `initialize()`
            scorer: scoring callable used during evaluation
        """
        # `token -> {pos+feats: lemma}` mapping; None until populated.
        self._lookups: Optional[Lookups] = lookups
        self.scorer = scorer
        self.source = source

    def train(self, sentences: Iterable[Iterable[Tuple[str, str, str, str]]], min_occurrences: int = 1) -> None:
        """Learn `(token, pos + feats) -> lemma` mappings from gold data.

        Args:
            sentences (Iterable[Iterable[Tuple[str, str, str, str]]]): Sentences to
                learn the mappings from; each token is a `(form, pos, feats, lemma)` tuple
            min_occurrences (int): mapping occurring less than this threshold are not learned
        """
        # (masked form, pos + feats) -> lemma -> occurrence count
        lemma_counts: Dict[Tuple[str, str], Dict[str, int]] = defaultdict(lambda: defaultdict(int))

        for sentence in sentences:
            for token, pos, feats, lemma in sentence:
                token = self.__mask_numbers(token)
                lemma = self.__mask_numbers(lemma)
                feats_str = ("|" + feats) if feats else ""
                lemma_counts[(token, pos + feats_str)][lemma] += 1

        table = Table(name="lemma_lookups")
        for (form, pos), lemma_freq in lemma_counts.items():
            # Keep only the most frequent lemma per key. `max` returns the
            # first-inserted entry among ties, matching the previous
            # stable-sort-then-take-first behavior.
            most_freq_lemma, freq = max(lemma_freq.items(), key=itemgetter(1))
            if freq >= min_occurrences:
                if form not in table:
                    table[form] = dict()
                table[form][pos] = most_freq_lemma

        self._lookups = Lookups()
        self._lookups.set_table(name="lemma_lookups", table=table)

    def __call__(self, doc: Doc) -> Doc:
        """Assign lemmata to the tokens of `doc` from the learned lookup table."""
        assert self._lookups is not None, "Lookup table should be initialized first"
        # Loop-invariant: fetch the table once instead of per token.
        lemma_lookup_table = self._lookups.get_table("lemma_lookups")

        token: Token
        for token in doc:
            masked_token = self.__mask_numbers(token.text)
            if masked_token in lemma_lookup_table:
                lemma_by_pos: Dict[str, str] = lemma_lookup_table[masked_token]
                feats_str = ("|" + str(token.morph)) if str(token.morph) else ""
                key = token.pos_ + feats_str
                if key in lemma_by_pos:
                    if masked_token != token.text:
                        # Token contained digits: restore them in the lemma.
                        token.lemma_ = self.__replace_numbers(lemma_by_pos[key], token.text)
                    else:
                        token.lemma_ = lemma_by_pos[key]
        return doc

    def to_disk(self, path, exclude=tuple()):
        """Serialize the learned lookup tables into the directory `path`."""
        assert self._lookups is not None, "Lookup table should be initialized first"

        path: Path = ensure_path(path)
        path.mkdir(exist_ok=True)
        self._lookups.to_disk(path)

    def from_disk(self, path, exclude=tuple()) -> "LookupLemmatizer":
        """Load the lookup tables from the directory `path` and return self."""
        path: Path = ensure_path(path)
        self._lookups = Lookups().from_disk(path=path)
        return self

    def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Optional[Language] = None) -> None:
        """Load the lookup tables from `self.source` (spaCy initialization hook)."""
        self._lookups = Lookups().from_disk(path=self.source)

    @classmethod
    def __mask_numbers(cls, token: str) -> str:
        """Replace every digit in `token` with ``0``."""
        return cls._number_pattern.sub("0", token)

    @classmethod
    def __replace_numbers(cls, lemma: str, token: str) -> str:
        """Restore the original digits of `token` at the masked (``0``) positions of `lemma`."""
        return cls._number_pattern.sub(lambda match: token[match.start()], lemma)