SkyWork commited on
Commit
a852f27
1 Parent(s): 6fcd10a

Update tokenization_sky.py

Browse files
Files changed (1) hide show
  1. tokenization_sky.py +3 -3
tokenization_sky.py CHANGED
@@ -325,7 +325,7 @@ class SkyTokenizer(PreTrainedTokenizer):
325
 
326
  def _tokenize(self, text, **kwargs):
327
  """Tokenize a string."""
328
- return self.trie.match(text, **kwargs)
329
 
330
  def _decode(self,
331
  token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
@@ -393,7 +393,7 @@ class SkyTokenizer(PreTrainedTokenizer):
393
  ) -> BatchEncoding:
394
  def get_input_ids(text):
395
  if isinstance(text, str):
396
- text_id = self.trie.match(text)
397
  return text_id
398
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
399
  return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
@@ -458,7 +458,7 @@ class SkyTokenizer(PreTrainedTokenizer):
458
  ) -> BatchEncoding:
459
  def get_input_ids(text):
460
  if isinstance(text, str):
461
- text_id = self.trie.match(text)
462
  return text_id
463
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
464
  return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
 
325
 
326
  def _tokenize(self, text, **kwargs):
327
  """Tokenize a string."""
328
+ return self.trie.match(text, unk_id=self.unk_token_id, **kwargs)
329
 
330
  def _decode(self,
331
  token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
 
393
  ) -> BatchEncoding:
394
  def get_input_ids(text):
395
  if isinstance(text, str):
396
+ text_id = self.trie.match(text, unk_id=self.unk_token_id)
397
  return text_id
398
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
399
  return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
 
458
  ) -> BatchEncoding:
459
  def get_input_ids(text):
460
  if isinstance(text, str):
461
+ text_id = self.trie.match(text, unk_id=self.unk_token_id)
462
  return text_id
463
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
464
  return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]