Make _decode compatible with PreTrainedTokenizerBase

#8
by Vinno97 - opened
Files changed (1)
  1. tokenization_codegen25.py +4 -2
tokenization_codegen25.py CHANGED
@@ -4,7 +4,7 @@
 # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/Apache-2.0
 """Tokenization classes for CodeGen2.5."""
 
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
 from transformers.utils import logging
@@ -168,7 +168,9 @@ class CodeGen25Tokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (str) using the vocab."""
         return self.encoder.decode_single_token_bytes(index).decode("utf-8")
 
-    def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
+    def _decode(self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, **kwargs):
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
         if skip_special_tokens:
             token_ids = [t for t in token_ids if t not in self.all_special_ids]
         return self.encoder.decode(token_ids)
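
For context, PreTrainedTokenizerBase.decode accepts either a single token id or a list of ids and forwards them to _decode, so the previous List[int]-only implementation could fail when handed a bare int. Below is a minimal usage sketch of the behavior this patch enables; the checkpoint name is an assumption for illustration and is not specified by this PR.

```python
from transformers import AutoTokenizer

# Assumed checkpoint for illustration; any repo shipping this tokenizer would do.
tokenizer = AutoTokenizer.from_pretrained(
    "Salesforce/codegen25-7b-multi", trust_remote_code=True
)

ids = tokenizer.encode("def hello():")  # list of ints; exact values depend on the vocab

# Decoding a list of ids worked before this change.
print(tokenizer.decode(ids))

# PreTrainedTokenizerBase.decode also accepts a single id; with this patch,
# _decode wraps the bare int in a list before filtering and decoding instead of failing.
print(tokenizer.decode(ids[0]))
```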