if001 committed
Commit 6c63bd9
1 Parent(s): fb7d963
Files changed (1):
  1. sentencepiece_ja.py +56 -0
sentencepiece_ja.py ADDED
@@ -0,0 +1,56 @@
+ import os
+ from typing import Union, List, Optional, Tuple, Dict
+ 
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ 
+ 
+ class SentencePieceJA(PreTrainedTokenizer):
+     def __init__(self, model_path, **kwargs):
+         super().__init__(**kwargs)
+         from tokenizers import Tokenizer
+         self._tokenizer = Tokenizer.from_file(model_path)
+         # Resolve the special-token ids once at load time.
+         self.__pad_id = self._tokenize("<PAD>")[0]
+         self.__bos_id = self._tokenize("<BOS>")[0]
+         self.__eos_id = self._tokenize("<EOS>")[0]
+         self.__unk_id = self._tokenize("<UNK>")[0]
+         self.__mask_id = self._tokenize("<MASK>")[0]
+ 
+     def get_vocab(self) -> Dict[str, int]:
+         return self._tokenizer.get_vocab()
+ 
+     def vocab_size(self) -> int:
+         return self._tokenizer.get_vocab_size()
+ 
+     def _tokenize(self, text, **kwargs):
+         # Returns token ids directly instead of string tokens, so the
+         # token<->id conversion methods below are pass-throughs.
+         return self._tokenizer.encode(text).ids
+ 
+     def _convert_token_to_id(self, token):
+         return token
+ 
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._tokenizer.decode([index])
+         # return self._tokenizer.id_to_token(index)
+ 
+     def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
+         return tokens
+ 
+     def convert_ids_to_tokens(
+         self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+     ) -> Union[str, List[str]]:
+         decoded = self._tokenizer.decode(ids)
+         return decoded
+ 
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         # Write the vocabulary, one token per line, ordered by token id.
+         index = 0
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt"
+             )
+         else:
+             vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(self.get_vocab().items(), key=lambda kv: kv[1]):
+                 if index != token_index:
+                     index = token_index
+                 writer.write(token + "\n")
+                 index += 1
+         return (vocab_file,)
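
For reference, a minimal usage sketch of the class added above (not part of this commit): the tokenizer JSON path and the sample sentence are placeholders, and it only exercises the low-level methods defined in the file, assuming a tokenizer trained with the `tokenizers` library has been saved to that path.

    from sentencepiece_ja import SentencePieceJA

    # Placeholder path to a trained `tokenizers` JSON model file.
    tok = SentencePieceJA("tokenizer.json")

    ids = tok._tokenize("こんにちは、世界")   # _tokenize returns token ids, not string tokens
    text = tok.convert_ids_to_tokens(ids)      # decodes the ids back to a string
    print(ids, text)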