# KoichiYasuoka's picture
# initial release
# 9235968
import os
from transformers import PreTrainedTokenizerFast
from transformers.models.bert_japanese.tokenization_bert_japanese import MecabTokenizer
# Use transformers.utils.cached_file when it can be imported; otherwise build
# an equivalent helper from the transformers.file_utils functions.
try:
    from transformers.utils import cached_file
except ImportError:
    # Only a failed import should trigger the fallback; anything else
    # (KeyboardInterrupt, SystemExit, ...) must propagate.
    from transformers.file_utils import cached_path, hf_bucket_url

    def cached_file(x, y):
        """Return a path for file *y* belonging to model *x*.

        Args:
            x: A local directory, or an identifier accepted by
                ``hf_bucket_url``.
            y: File name to resolve.

        Returns:
            ``os.path.join(x, y)`` when *x* is an existing directory,
            otherwise the result of ``cached_path(hf_bucket_url(x, y))``.
        """
        if os.path.isdir(x):
            return os.path.join(x, y)
        return cached_path(hf_bucket_url(x, y))
class MecabPreTokenizer(MecabTokenizer):
    """MecabTokenizer exposing the hooks needed to act as a pre-tokenizer."""

    def mecab_split(self, i, normalized_string):
        """Cut *normalized_string* into one slice per token from ``self.tokenize``.

        Args:
            i: Not used by this method.
                NOTE(review): presumably the split index supplied by
                ``PreTokenizedString.split`` - confirm against tokenizers docs.
            normalized_string: Object that supports ``str()`` and slicing by
                character offsets.

        Returns:
            A list of slices of *normalized_string*, in token order.  Tokens
            that cannot be found in the text at or after the end of the
            previous match are skipped, as are spans whose end offset is 0.
        """
        text = str(normalized_string)
        spans = []
        cursor = 0  # end offset of the most recent successful match
        for token in self.tokenize(text):
            start = text.find(token, cursor)
            if start < 0:
                # Token text does not occur verbatim; record a span that the
                # filter below discards, and leave the cursor where it is.
                spans.append((0, 0))
                continue
            cursor = start + len(token)
            spans.append((start, cursor))
        pieces = []
        for begin, end in spans:
            if end > 0:
                pieces.append(normalized_string[begin:end])
        return pieces

    def pre_tokenize(self, pretok):
        """Apply ``mecab_split`` to *pretok* through its ``split`` method."""
        pretok.split(self.mecab_split)
class JumanPreTrainedTokenizerFast(PreTrainedTokenizerFast):
    """Fast tokenizer whose pre-tokenizer is MeCab followed by ``Whitespace``.

    The MeCab step is a Python-side custom pre-tokenizer.  It is installed on
    the backend tokenizer in ``__init__`` and temporarily replaced by a plain
    ``Whitespace`` pre-tokenizer while ``save_pretrained`` runs.
    """

    def __init__(self, **kwargs):
        """Construct the tokenizer and attach the MeCab pre-tokenizer.

        Args:
            **kwargs: Forwarded unchanged to
                ``PreTrainedTokenizerFast.__init__``.

        If ``/var/lib/mecab/dic/juman-utf8`` (directory) and ``/etc/mecabrc``
        (file) are not both present, ``mecab-jumandic-utf8.zip`` is resolved
        with ``cached_file(self.name_or_path, ...)``, extracted into a
        temporary directory, and a ``mecabrc`` pointing at it is written
        there.
        """
        from tokenizers.pre_tokenizers import PreTokenizer, Whitespace, Sequence

        super().__init__(**kwargs)
        dic_dir, rc_path = "/var/lib/mecab/dic/juman-utf8", "/etc/mecabrc"
        if not (os.path.isdir(dic_dir) and os.path.isfile(rc_path)):
            import zipfile
            import tempfile

            # Keep the TemporaryDirectory object on the instance so the
            # extracted dictionary is not cleaned up while it is still in use.
            self.dicdir = tempfile.TemporaryDirectory()
            dic_dir = self.dicdir.name
            archive = cached_file(self.name_or_path, "mecab-jumandic-utf8.zip")
            with zipfile.ZipFile(archive) as z:
                z.extractall(dic_dir)
            rc_path = os.path.join(dic_dir, "mecabrc")
            with open(rc_path, "w", encoding="utf-8") as w:
                print("dicdir =", dic_dir, file=w)
        mecab = MecabPreTokenizer(
            mecab_dic=None,
            mecab_option="-d " + dic_dir + " -r " + rc_path,
        )
        self.custom_pre_tokenizer = Sequence([PreTokenizer.custom(mecab), Whitespace()])
        self._tokenizer.pre_tokenizer = self.custom_pre_tokenizer

    def save_pretrained(self, save_directory, **kwargs):
        """Save the tokenizer together with ``juman.py`` and the dictionary zip.

        Args:
            save_directory: Destination directory, passed to the parent
                ``save_pretrained`` and used as the target for the two
                extra files.
            **kwargs: Forwarded unchanged to the parent ``save_pretrained``.

        Returns:
            Whatever the parent ``save_pretrained`` returns.
        """
        import shutil
        from tokenizers.pre_tokenizers import Whitespace

        def copy_asset(source, name):
            # Saving into the directory the assets were loaded from makes
            # source and destination identical; nothing needs copying then.
            try:
                shutil.copy(source, os.path.join(save_directory, name))
            except shutil.SameFileError:
                pass

        self._auto_map = {"AutoTokenizer": [None, "juman.JumanPreTrainedTokenizerFast"]}
        # NOTE(review): the swap to Whitespace is presumably needed because a
        # custom Python pre-tokenizer cannot be serialized - confirm.
        self._tokenizer.pre_tokenizer = Whitespace()
        try:
            saved = super().save_pretrained(save_directory, **kwargs)
        finally:
            # Restore the MeCab pipeline even if saving failed, so the
            # tokenizer keeps working afterwards.
            self._tokenizer.pre_tokenizer = self.custom_pre_tokenizer
        copy_asset(os.path.abspath(__file__), "juman.py")
        copy_asset(
            cached_file(self.name_or_path, "mecab-jumandic-utf8.zip"),
            "mecab-jumandic-utf8.zip",
        )
        return saved