tianxie-sf committed on
Commit
5fe0f1b
1 Parent(s): 517e720

update _convert_id_to_token (#11)


- update _convert_id_to_token (f6ff02a0461b229744a70cf8816019e75adf6591)

Files changed (1)
  1. tokenization_xgen.py +13 -7
tokenization_xgen.py CHANGED
@@ -149,20 +149,22 @@ class XgenTokenizer(PreTrainedTokenizer):
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
         if isinstance(token, str):
-            ids = self._tokenize(token)
-            return ids[0]
-        return token
+            return self.encoder.encode_single_token(token)
+        else:
+            return token

     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        return self.encoder.decode_single_token_bytes(index)
+        return self.encoder.decode_single_token_bytes(index).decode("utf-8")

     def _decode(self, token_ids: List[int], skip_special_tokens: bool = False, **kwargs):
+        if skip_special_tokens:
+            token_ids = [t for t in token_ids if t not in self.all_special_ids]
         return self.encoder.decode(token_ids)

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """Build model inputs from a sequence by appending eos_token_id."""
-        eos_token_id = [50256] if self.add_eos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

         output = token_ids_0 + eos_token_id

@@ -218,11 +220,15 @@ class XgenTokenizer(PreTrainedTokenizer):
         Returns:
             `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
         """
-        eos_token_id = [50256] if self.add_eos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

         output = [0] * len(token_ids_0 + eos_token_id)

         if token_ids_1 is not None:
             output += [1] * len(token_ids_1 + eos_token_id)

         return output
+
+    # has no vocab file
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
+        return ()
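
The updated `_convert_token_to_id`/`_convert_id_to_token` pair maps pieces to ids through the underlying encoder instead of re-tokenizing the string and taking the first id. A minimal round-trip sketch of what that relies on, assuming `self.encoder` is a tiktoken `Encoding` (the class that provides `encode_single_token` and `decode_single_token_bytes`); the `"gpt2"` encoding below is an illustrative stand-in, not XGen's actual vocabulary:

import tiktoken

enc = tiktoken.get_encoding("gpt2")  # stand-in encoding for the sketch

for token_id in enc.encode("hello world"):
    piece = enc.decode_single_token_bytes(token_id)    # id -> raw bytes of one BPE piece
    assert enc.encode_single_token(piece) == token_id  # piece -> the same id
    print(token_id, piece.decode("utf-8"))             # the new _convert_id_to_token returns str

One caveat: `decode_single_token_bytes` returns raw bytes, and `.decode("utf-8")` can raise for a token whose bytes split a multi-byte character; pieces of plain ASCII text, as here, always decode cleanly.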
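The `_decode` change filters special ids out before handing the list to the encoder, since tiktoken's `decode` would otherwise render them as literal text. A sketch of that path, assuming `all_special_ids` holds the registered special-token ids as on any `PreTrainedTokenizer`; id 50256 is used here only because it is `<|endoftext|>` in the stand-in encoding:

import tiktoken

enc = tiktoken.get_encoding("gpt2")                # stand-in encoding
all_special_ids = {50256}                          # e.g. <|endoftext|> in this encoding

token_ids = enc.encode("hello world") + [50256]    # pretend add_eos_token appended eos
kept = [t for t in token_ids if t not in all_special_ids]
print(enc.decode(kept))                            # "hello world" with the eos dropped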