Update decode method in tokenizer
tokenization_chatglm.py  (+20 -7)
@@ -31,6 +31,9 @@ class TextTokenizer:
     def tokenize(self, text):
         return self.sp.EncodeAsPieces(text)
 
+    def convert_tokens_to_string(self, tokens):
+        return self.sp.DecodePieces(tokens)
+
     def convert_tokens_to_ids(self, tokens):
         return [self.sp.PieceToId(token) for token in tokens]
 
@@ -111,16 +114,25 @@ class SPTokenizer:
         tokens = [x + self.num_image_tokens for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
-    def decode(self, text_ids: List[int]) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
-        ids = [_id for _id in ids if _id >= 0]
-        text = self._get_text_tokenizer().decode(ids)
+    def postprocess(self, text):
         text = text.replace("<n>", "\n")
         text = text.replace(SPTokenizer.get_tab_token(), "\t")
         for i in range(2, self.max_blank_length + 1):
             text = text.replace(self.get_blank_token(i), " " * i)
         return text
 
+    def decode(self, text_ids: List[int]) -> str:
+        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [_id for _id in ids if _id >= 0]
+        text = self._get_text_tokenizer().decode(ids)
+        text = self.postprocess(text)
+        return text
+
+    def decode_tokens(self, tokens: List[str]) -> str:
+        text = self._get_text_tokenizer().convert_tokens_to_string(tokens)
+        text = self.postprocess(text)
+        return text
+
     def tokenize(
         self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
     ) -> List[str]:
@@ -256,11 +268,12 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
         return seq
 
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        return self.sp_tokenizer.decode_tokens(tokens)
+
     def _decode(
         self,
         token_ids: Union[int, List[int]],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = True,
         **kwargs
     ) -> str:
         if isinstance(token_ids, int):
@@ -269,7 +282,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             return ""
         if self.pad_token_id in token_ids:  # remove pad
            token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
-        return self.sp_tokenizer.decode(token_ids)
+        return super()._decode(token_ids, **kwargs)
 
     def _convert_token_to_id(self, token):
         """ Converts a token (str) in an id using the vocab. """
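
With this refactor, id-based and piece-based decoding share the same whitespace restoration in postprocess. A rough sketch of how the two SPTokenizer entry points line up; here `sp` is assumed to be an SPTokenizer built from the checkpoint's sentencepiece model, and `encode` is assumed to be the id-producing counterpart of `tokenize` elsewhere in this file (both names are assumptions, not part of this diff):

# Rough sketch (not part of this commit).
text = "a  b\n\tc"                 # run of blanks, a newline, a tab

ids = sp.encode(text)              # id path: decode() strips the image-token offset
pieces = sp.tokenize(text)         # piece path: decode_tokens() joins the pieces directly

# Both paths now end in postprocess(), so "<n>", the tab token and the blank
# tokens are mapped back to "\n", "\t" and plain spaces in either case.
print(sp.decode(ids))
print(sp.decode_tokens(pieces))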