import copy
from typing import List

from tokenizers import NormalizedString, PreTokenizedString, normalizers, pre_tokenizers
from transformers import DebertaV2TokenizerFast


class DebertaV2JumanppTokenizerFast(DebertaV2TokenizerFast):
    """DeBERTa-v2 fast tokenizer that pre-segments its input with Juman++.

    A Juman++-specific normalizer is prepended to the backend normalizer, and a
    custom Juman++ pre-tokenizer is prepended to the backend pre-tokenizer.
    Because the custom pre-tokenizer is a Python object, ``save_pretrained``
    temporarily swaps the original (serializable) components back in while
    saving and re-installs the Juman++ ones afterwards.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.juman_normalizer = normalizers.Sequence(
            [
                # Drop line breaks before handing text to Juman++.
                # cf. https://github.com/ku-nlp/rhoknp/blob/v1.3.0/src/rhoknp/units/sentence.py#L36
                normalizers.Replace("\r", ""),
                normalizers.Replace("\n", ""),
                # Map characters handled specially by the Juman format to substitutes.
                # cf. https://github.com/ku-nlp/jumanpp/blob/v2.0.0-rc3/src/jumandic/shared/juman_format.cc#L44-L61
                normalizers.Replace("\t", "\\t"),
                # The full-width targets are written as escapes so that editors or
                # Unicode normalization cannot fold them back to ASCII (which would
                # turn these rules into no-ops).
                normalizers.Replace(" ", "\u3000"),  # IDEOGRAPHIC SPACE
                normalizers.Replace('"', "\u201d"),  # RIGHT DOUBLE QUOTATION MARK
                normalizers.Replace("<", "\uff1c"),  # FULLWIDTH LESS-THAN SIGN
                normalizers.Replace(">", "\uff1e"),  # FULLWIDTH GREATER-THAN SIGN
            ]
        )
        self.juman_pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JumanppPreTokenizer())

        # Untouched copies of the backend components, restored by save_pretrained.
        self.default_normalizer = copy.deepcopy(self.backend_tokenizer.normalizer)
        self.default_pre_tokenizer = copy.deepcopy(self.backend_tokenizer.pre_tokenizer)

        self._install_jumanpp_components()

    def _install_jumanpp_components(self) -> None:
        """Prepend the Juman++ normalizer and pre-tokenizer to the saved defaults."""
        if self.default_normalizer is None:
            self.backend_tokenizer.normalizer = self.juman_normalizer
        else:
            self.backend_tokenizer.normalizer = normalizers.Sequence(
                [self.juman_normalizer, self.default_normalizer]
            )
        if self.default_pre_tokenizer is None:
            self.backend_tokenizer.pre_tokenizer = self.juman_pre_tokenizer
        else:
            self.backend_tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
                [self.juman_pre_tokenizer, self.default_pre_tokenizer]
            )

    def _restore_default_components(self) -> None:
        """Put the original (serializable) normalizer and pre-tokenizer back."""
        self.backend_tokenizer.normalizer = self.default_normalizer
        self.backend_tokenizer.pre_tokenizer = self.default_pre_tokenizer

    def save_pretrained(self, *args, **kwargs):
        """Save the tokenizer using its original, serializable components.

        All arguments are forwarded to ``DebertaV2TokenizerFast.save_pretrained``
        and its return value is passed through. The Juman++ components are
        re-installed afterwards, even if saving raises.
        """
        self._restore_default_components()
        try:
            return super().save_pretrained(*args, **kwargs)
        finally:
            # Keep the tokenizer usable regardless of whether saving succeeded.
            self._install_jumanpp_components()


class JumanppPreTokenizer:
    """Custom ``tokenizers`` pre-tokenizer that splits text into Juman++ morphemes."""

    def __init__(self):
        try:
            import rhoknp
        except ImportError as err:
            raise ImportError(
                "You need to install rhoknp to use JumanppPreTokenizer. "
                "See https://github.com/ku-nlp/rhoknp for installation."
            ) from err

        self.rhoknp = rhoknp
        self.jumanpp = rhoknp.Jumanpp()

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Split every piece of ``pretok`` at Juman++ morpheme boundaries."""
        pretok.split(self.jumanpp_split)

    def jumanpp_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        """Return the slices of ``normalized_string`` covering each Juman++ morpheme.

        Args:
            i: Index of the split, supplied by ``PreTokenizedString.split`` (unused).
            normalized_string: The piece of text to segment.

        Returns:
            One ``NormalizedString`` slice per morpheme, using each morpheme's span.
            Empty input yields an empty list without invoking Juman++.
        """
        text = str(normalized_string)
        if not text:
            return []
        try:
            morphemes = self.jumanpp.apply_to_sentence(text).morphemes
        except RuntimeError:
            # NOTE(review): presumably raised when the text cannot be analysed as a
            # single sentence; fall back to document-level analysis, which splits
            # the raw text into sentences first — confirm against rhoknp.
            doc = self.rhoknp.Document.from_raw_text(text)
            morphemes = self.jumanpp.apply_to_document(doc).morphemes
        return [normalized_string[m.span[0]:m.span[1]] for m in morphemes]