# deberta-v2-base-japanese / tokenization_deberta_v2_jumanpp_fast.py
# Last commit dc384f8 by tealgreen0503: "feat: add custom fast tokenizer" (file size 2.8 kB)
import copy
from tokenizers import NormalizedString, PreTokenizedString, normalizers, pre_tokenizers
from transformers import DebertaV2TokenizerFast
class DebertaV2JumanppTokenizerFast(DebertaV2TokenizerFast):
    """DeBERTa-v2 fast tokenizer that runs Juman++ segmentation before the default pipeline.

    The backend tokenizer's normalizer and pre-tokenizer are each replaced by a
    ``Sequence`` whose first step is Juman++-specific and whose second step is the
    component originally loaded with the tokenizer.

    The Juman++ pre-tokenizer is a custom Python component, which the ``tokenizers``
    library cannot serialize. ``save_pretrained`` therefore swaps the original
    components back in while files are written.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.juman_normalizer = normalizers.Sequence(
            [
                # cf. https://github.com/ku-nlp/rhoknp/blob/v1.3.0/src/rhoknp/units/sentence.py#L36
                normalizers.Replace("\r", ""),
                normalizers.Replace("\n", ""),
                # cf. https://github.com/ku-nlp/jumanpp/blob/v2.0.0-rc3/src/jumandic/shared/juman_format.cc#L44-L61
                # Characters that the Juman format treats specially are mapped to safe
                # equivalents. The full-width targets are written as \u escapes so that
                # text normalization of this source file cannot silently collapse them
                # into no-op identity replacements.
                normalizers.Replace("\t", "\\t"),
                normalizers.Replace(" ", "\u3000"),  # IDEOGRAPHIC SPACE
                normalizers.Replace('"', "\u201d"),  # RIGHT DOUBLE QUOTATION MARK
                normalizers.Replace("<", "\uff1c"),  # FULLWIDTH LESS-THAN SIGN
                normalizers.Replace(">", "\uff1e"),  # FULLWIDTH GREATER-THAN SIGN
            ]
        )
        self.juman_pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JumanppPreTokenizer())
        # Snapshot the components loaded from the checkpoint. They are the second
        # stage of the Juman++ pipeline and are what save_pretrained writes out.
        self.default_normalizer = copy.deepcopy(self.backend_tokenizer.normalizer)
        self.default_pre_tokenizer = copy.deepcopy(self.backend_tokenizer.pre_tokenizer)
        self._install_jumanpp_components()

    def _install_jumanpp_components(self):
        """Set the backend pipeline to (Juman++ step, default step) for normalizer and pre-tokenizer."""
        self.backend_tokenizer.normalizer = normalizers.Sequence(
            [self.juman_normalizer, self.default_normalizer]
        )
        self.backend_tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [self.juman_pre_tokenizer, self.default_pre_tokenizer]
        )

    def save_pretrained(self, *args, **kwargs):
        """Save the tokenizer using its original (serializable) normalizer and pre-tokenizer.

        The default components are installed only for the duration of the save. The
        Juman++ pipeline is reinstalled afterwards, even if saving raises.

        Args:
            *args: Forwarded to the parent ``save_pretrained``.
            **kwargs: Forwarded to the parent ``save_pretrained``.

        Returns:
            The value returned by the parent class's ``save_pretrained``.
        """
        self.backend_tokenizer.normalizer = self.default_normalizer
        self.backend_tokenizer.pre_tokenizer = self.default_pre_tokenizer
        try:
            return super().save_pretrained(*args, **kwargs)
        finally:
            self._install_jumanpp_components()
class JumanppPreTokenizer:
    """Custom pre-tokenizer that splits text into Juman++ morphemes using rhoknp.

    Instances are meant to be wrapped with ``pre_tokenizers.PreTokenizer.custom``.

    Raises:
        ImportError: On construction, if ``rhoknp`` is not installed.
    """

    def __init__(self):
        try:
            import rhoknp
        except ImportError as err:
            raise ImportError(
                "You need to install rhoknp to use JumanppPreTokenizer. "
                "See https://github.com/ku-nlp/rhoknp for installation."
            ) from err
        # A single analyzer instance, reused by every jumanpp_split call.
        self.juman = rhoknp.Jumanpp()

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Split each piece of ``pretok`` in place into Juman++ morphemes."""
        pretok.split(self.jumanpp_split)

    def jumanpp_split(self, i: int, normalized_string: NormalizedString) -> list[NormalizedString]:
        """Return one slice of ``normalized_string`` per Juman++ morpheme.

        Args:
            i: Index of the piece being split. It is unused here but is part of the
                ``PreTokenizedString.split`` callback signature.
            normalized_string: The piece to segment.

        Returns:
            The slices of ``normalized_string`` given by each morpheme's ``span``, in
            morpheme order. Empty input yields an empty list without calling Juman++.
        """
        text = str(normalized_string)
        if not text:
            # Nothing to segment; skip the round-trip to the external analyzer.
            return []
        morphemes = self.juman.apply_to_sentence(text).morphemes
        pieces = []
        for morpheme in morphemes:
            span = morpheme.span
            pieces.append(normalized_string[span[0]:span[1]])
        return pieces