update tokenizer code for compatibility with latest transformers
Browse files
tokenization_codegen25.py
CHANGED
@@ -133,15 +133,14 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
|
|
133 |
):
|
134 |
pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
135 |
eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
|
|
136 |
super().__init__(
|
137 |
pad_token=pad_token_added,
|
138 |
eos_token=eos_token_added,
|
139 |
add_eos_token=add_eos_token,
|
140 |
-
add_special_tokens=add_special_tokens,
|
141 |
**kwargs,
|
142 |
)
|
143 |
self.add_eos_token = add_eos_token
|
144 |
-
self.encoder = tiktoken_tokenizer(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
|
145 |
|
146 |
@property
|
147 |
def vocab_size(self):
|
@@ -166,7 +165,11 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
|
|
166 |
|
167 |
def _convert_id_to_token(self, index):
|
168 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
169 |
-
|
|
|
|
|
|
|
|
|
170 |
|
171 |
def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
|
172 |
if skip_special_tokens:
|
|
|
133 |
):
|
134 |
pad_token_added = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
135 |
eos_token_added = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
136 |
+
self.encoder = tiktoken_tokenizer(base="gpt2", pad_token=pad_token, add_special=add_special_tokens)
|
137 |
super().__init__(
|
138 |
pad_token=pad_token_added,
|
139 |
eos_token=eos_token_added,
|
140 |
add_eos_token=add_eos_token,
|
|
|
141 |
**kwargs,
|
142 |
)
|
143 |
self.add_eos_token = add_eos_token
|
|
|
144 |
|
145 |
@property
|
146 |
def vocab_size(self):
|
|
|
165 |
|
166 |
def _convert_id_to_token(self, index):
|
167 |
"""Converts an index (integer) in a token (str) using the vocab."""
|
168 |
+
try:
|
169 |
+
token = self.encoder.decode_single_token_bytes(index).decode("utf-8")
|
170 |
+
except Exception:
|
171 |
+
token = ""
|
172 |
+
return token
|
173 |
|
174 |
def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
|
175 |
if skip_special_tokens:
|