KaleiNeely committed on
Commit
551c1fa
1 Parent(s): 42b5867

Update tokenization_rwkv_world.py

Files changed (1)
  1. tokenization_rwkv_world.py +221 -87
tokenization_rwkv_world.py CHANGED
@@ -12,38 +12,20 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tokenization classes for OpenAI GPT."""

  import json
  import os
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
- from transformers.utils import logging, to_py_obj
- from transformers.tokenization_utils_base import BatchEncoding
-
- import bisect
- import itertools
- import re
- import unicodedata
- from collections import OrderedDict
- from typing import Any, Dict, List, Optional, Tuple, Union, overload

  from transformers.tokenization_utils_base import (
- ENCODE_KWARGS_DOCSTRING,
- ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
- INIT_TOKENIZER_DOCSTRING,
- AddedToken,
  BatchEncoding,
  EncodedInput,
- EncodedInputPair,
- PreTokenizedInput,
- PreTokenizedInputPair,
- PreTrainedTokenizerBase,
  TextInput,
- TextInputPair,
  TruncationStrategy,
  )
- from transformers.utils import PaddingStrategy, TensorType, add_end_docstrings, logging


  if TYPE_CHECKING:
@@ -54,11 +36,18 @@ logger = logging.get_logger(__name__)
  VOCAB_FILES_NAMES = {
  "vocab_file": "rwkv_vocab_v20230424.txt",
  }

  class TRIE:
  __slots__ = tuple("ch,to,values,front".split(","))
- to:list
- values:set
  def __init__(self, front=None, ch=None):
  self.ch = ch
  self.to = [None for ch in range(256)]
@@ -68,67 +57,59 @@ class TRIE:
  def __repr__(self):
  fr = self
  ret = []
- while(fr!=None):
- if(fr.ch!=None):
  ret.append(fr.ch)
  fr = fr.front
- return "<TRIE %s %s>"%(ret[::-1], self.values)
-
- def add(self, key:bytes, idx:int=0, val=None):
- if(idx == len(key)):
- if(val is None):
  val = key
  self.values.add(val)
  return self
  ch = key[idx]
- if(self.to[ch] is None):
  self.to[ch] = TRIE(front=self, ch=ch)
- return self.to[ch].add(key, idx=idx+1, val=val)
-
- def find_longest(self, key:bytes, idx:int=0):
- u:TRIE = self
- ch:int = key[idx]
-
- while(u.to[ch] is not None):
  u = u.to[ch]
  idx += 1
- if(u.values):
  ret = idx, u, u.values
- if(idx==len(key)):
  break
  ch = key[idx]
  return ret

  class RWKVWorldTokenizer(PreTrainedTokenizer):
  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]

- def __init__(
- self,
- vocab_file,
- errors="replace",
- **kwargs
- ):
  self.add_bos_token = False
  self.encoder = {}
- sorted = [] # must be already sorted
  with open(vocab_file, "r", encoding="utf-8") as f:
  lines = f.readlines()
  for l in lines:
- idx = int(l[:l.index(' ')])
- x = eval(l[l.index(' '):l.rindex(' ')])
  x = x.encode("utf-8") if isinstance(x, str) else x
  assert isinstance(x, bytes)
- assert len(x) == int(l[l.rindex(' '):])
  sorted += [x]
  self.encoder[idx] = x

- super().__init__(
- errors=errors,
- **kwargs,
- )
  self.decoder = {}
- for k,v in self.encoder.items():
  self.decoder[v] = int(k)

  self.trie = TRIE()
@@ -136,6 +117,23 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  _ = self.trie.add(t, val=(t, i))
  self.errors = errors # how to handle errors in decoding
  self.cache = {}

  @property
  def vocab_size(self):
@@ -144,6 +142,22 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  def get_vocab(self):
  return dict(self.encoder, **self.added_tokens_encoder)

  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
  if self.add_bos_token:
  bos_token_ids = [self.bos_token_id]
@@ -158,8 +172,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  return output + bos_token_ids + token_ids_1

  def get_special_tokens_mask(
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
- already_has_special_tokens: bool = False
  ) -> List[int]:
  """
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
@@ -190,19 +203,19 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  return [1] + ([0] * len(token_ids_0))
  return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

- def encodeBytes(self, src:bytes):
- idx:int = 0
  tokens = []
- while (idx < len(src)):
- _idx:int = idx
  idx, _, values = self.trie.find_longest(src, idx)
- assert(idx != _idx)
- _, token = next(iter(values))
  tokens.append(token)
  return tokens
-
  def decodeBytes(self, tokens):
- return b''.join(map(lambda i: self.encoder[i], tokens))

  def _tokenize(self, text, **kwargs):
  """Tokenize a string."""
@@ -210,23 +223,30 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):

  def _decode_tokens(self, tokens):
  try:
- return self.decodeBytes(tokens).decode('utf-8')
- except:
- return '\ufffd' # bad utf-8

- def _decode(self,
- token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
- skip_special_tokens: bool = False,
- **kwargs
- ) -> str:

  # Convert inputs to python lists
  token_ids = to_py_obj(token_ids)
  if isinstance(token_ids, int):
  if token_ids in self.all_special_ids and skip_special_tokens:
  return ""
  return self.encoder.get(token_ids, self.unk_token)
  elif isinstance(token_ids, list):
  out_str = ""
  out_last = 0
  out_tokens = []
@@ -235,7 +255,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  break
  out_tokens += [token]
  tmp = self._decode_tokens(out_tokens[out_last:])
- if '\ufffd' not in tmp:
  out_str += tmp
  out_last = i + 1
  return out_str
@@ -268,6 +288,11 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  def prepare_for_tokenization(self, text, **kwargs):
  return (text, kwargs)

  def _encode_plus(
  self,
  text: Union[TextInput, EncodedInput],
@@ -285,16 +310,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  return_offsets_mapping: bool = False,
  return_length: bool = False,
  verbose: bool = True,
- **kwargs
  ) -> BatchEncoding:
- def get_input_ids(text):
  if isinstance(text, str):
- text_id = self._tokenize(text)
- return text_id
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
- return [self._tokenize(t) for t in text]
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
  return text
  else:
  raise ValueError(
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
@@ -350,16 +388,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  return_offsets_mapping: bool = False,
  return_length: bool = False,
  verbose: bool = True,
- **kwargs
  ) -> BatchEncoding:
- def get_input_ids(text):
  if isinstance(text, str):
- text_id = self._tokenize(text)
- return text_id
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
- return [self._tokenize(t) for t in text]
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
  return text
  else:
  raise ValueError(
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
@@ -372,15 +423,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
  "transformers.PreTrainedTokenizerFast."
  )

- input_ids = []
  for ids_or_pair_ids in batch_text_or_text_pairs:
  if not isinstance(ids_or_pair_ids, (list, tuple)):
  ids, pair_ids = ids_or_pair_ids, None
  else:
  ids, pair_ids = ids_or_pair_ids
-
  first_ids = get_input_ids(ids)
  second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
  input_ids.append((first_ids, second_ids))

  batch_outputs = self._batch_prepare_for_model(
@@ -402,10 +467,79 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):

  return BatchEncoding(batch_outputs)

  def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
  input_ids = []
  for is_user, text in conversation.iter_texts():
  input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
  if len(input_ids) > self.model_max_length:
- input_ids = input_ids[-self.model_max_length:]
  return input_ids
 
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ """Tokenization classes for RWKV5."""

  import json
  import os
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union

+ from transformers.tokenization_utils import PreTrainedTokenizer
  from transformers.tokenization_utils_base import (
  BatchEncoding,
  EncodedInput,
  TextInput,
  TruncationStrategy,
  )
+ from transformers.utils import PaddingStrategy, TensorType, logging, to_py_obj


  if TYPE_CHECKING:

  VOCAB_FILES_NAMES = {
  "vocab_file": "rwkv_vocab_v20230424.txt",
  }
+ PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "RWKV/rwkv-5-world-169m": "https://huggingface.co/RWKV/rwkv-5-world-169m/blob/main/rwkv_vocab_v20230424.txt",
+ },
+ }
+

  class TRIE:
  __slots__ = tuple("ch,to,values,front".split(","))
+ to: list
+ values: set
+
  def __init__(self, front=None, ch=None):
  self.ch = ch
  self.to = [None for ch in range(256)]

  def __repr__(self):
  fr = self
  ret = []
+ while fr is not None:
+ if fr.ch is not None:
  ret.append(fr.ch)
  fr = fr.front
+ return "<TRIE %s %s>" % (ret[::-1], self.values)
+
+ def add(self, key: bytes, idx: int = 0, val=None):
+ if idx == len(key):
+ if val is None:
  val = key
  self.values.add(val)
  return self
  ch = key[idx]
+ if self.to[ch] is None:
  self.to[ch] = TRIE(front=self, ch=ch)
+ return self.to[ch].add(key, idx=idx + 1, val=val)
+
+ def find_longest(self, key: bytes, idx: int = 0):
+ u: TRIE = self
+ ch: int = key[idx]
+
+ while u.to[ch] is not None:
  u = u.to[ch]
  idx += 1
+ if u.values:
  ret = idx, u, u.values
+ if idx == len(key):
  break
  ch = key[idx]
  return ret

+
  class RWKVWorldTokenizer(PreTrainedTokenizer):
  vocab_files_names = VOCAB_FILES_NAMES
  model_input_names = ["input_ids", "attention_mask"]

+ def __init__(self, vocab_file, errors="replace", pad_token="0", **kwargs):
  self.add_bos_token = False
  self.encoder = {}
+ sorted = [] # must be already sorted
  with open(vocab_file, "r", encoding="utf-8") as f:
  lines = f.readlines()
  for l in lines:
+ idx = int(l[: l.index(" ")])
+ x = eval(l[l.index(" ") : l.rindex(" ")])
  x = x.encode("utf-8") if isinstance(x, str) else x
  assert isinstance(x, bytes)
+ assert len(x) == int(l[l.rindex(" ") :])
  sorted += [x]
  self.encoder[idx] = x

  self.decoder = {}
+ for k, v in self.encoder.items():
  self.decoder[v] = int(k)

  self.trie = TRIE()
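A minimal standalone sketch of what the vocabulary parsing and TRIE above implement, assuming a made-up three-entry vocabulary in the same "<id> <repr> <byte length>" line format as rwkv_vocab_v20230424.txt; the trie is replaced by a plain dictionary scan for brevity, but the greedy longest-match result is the same:

# Toy vocabulary lines in the same format the constructor above parses.
toy_vocab_lines = [
    "1 'a' 1",
    "2 'ab' 2",
    "3 'c' 1",
]

encoder = {}
for line in toy_vocab_lines:
    idx = int(line[: line.index(" ")])
    piece = eval(line[line.index(" ") : line.rindex(" ")])  # str or bytes literal
    piece = piece.encode("utf-8") if isinstance(piece, str) else piece
    assert len(piece) == int(line[line.rindex(" ") :])
    encoder[idx] = piece


def encode_bytes(src: bytes, encoder: dict) -> list:
    # Greedy longest match over raw bytes, like TRIE.find_longest inside encodeBytes.
    pieces = {v: k for k, v in encoder.items()}
    tokens, idx = [], 0
    while idx < len(src):
        best, end = None, idx + 1
        for j in range(idx + 1, len(src) + 1):  # keep the longest prefix present in the vocab
            if src[idx:j] in pieces:
                best, end = pieces[src[idx:j]], j
        assert best is not None, "every byte must be covered by the vocabulary"
        tokens.append(best)
        idx = end
    return tokens


print(encode_bytes(b"abca", encoder))  # [2, 3, 1] -- "ab" wins over "a" at position 0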
 
  _ = self.trie.add(t, val=(t, i))
  self.errors = errors # how to handle errors in decoding
  self.cache = {}
+ self.first_max_length = 0
+ super().__init__(
+ errors=errors,
+ **kwargs,
+ )
+
+ @property
+ def eos_token_id(self) -> Optional[int]:
+ return 0
+
+ @property
+ def eot_token_id(self) -> Optional[int]:
+ return 0
+
+ @property
+ def pad_token_id(self) -> Optional[int]:
+ return 0

  @property
  def vocab_size(self):

  def get_vocab(self):
  return dict(self.encoder, **self.added_tokens_encoder)

+ def add_tokens(self, new_tokens, special_tokens: bool = False):
+ for token in new_tokens:
+ token_id = self.convert_tokens_to_ids(token)
+ self.added_tokens_decoder[token_id] = token
+
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+ if isinstance(ids, int):
+ ids = [ids]
+ tokens = []
+ for id_ in ids:
+ if id_ in self.added_tokens_decoder:
+ tokens.append(self.added_tokens_decoder[id_])
+ else:
+ tokens.append(self._convert_id_to_token(id_))
+ return tokens
+
  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
  if self.add_bos_token:
  bos_token_ids = [self.bos_token_id]

  return output + bos_token_ids + token_ids_1

  def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
  ) -> List[int]:
  """
  Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding

  return [1] + ([0] * len(token_ids_0))
  return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))

+ def encodeBytes(self, src: bytes):
+ idx: int = 0
  tokens = []
+ while idx < len(src):
+ _idx: int = idx
  idx, _, values = self.trie.find_longest(src, idx)
+ assert idx != _idx
+ _, token = next(iter(values))
  tokens.append(token)
  return tokens
+
  def decodeBytes(self, tokens):
+ return b"".join(map(lambda i: self.encoder[i], tokens)) # noqa

  def _tokenize(self, text, **kwargs):
  """Tokenize a string."""


  def _decode_tokens(self, tokens):
  try:
+ return self.decodeBytes(tokens).decode("utf-8")
+ except Exception:
+ return "\ufffd" # bad utf-8

+ def _decode(
+ self,
+ token_ids: Union[int, List[int]],
+ skip_special_tokens: bool = False,
+ **kwargs,
+ ) -> str:
+ def remove_zeros_from_first_segment(token_ids, first_max_length):
+ first_segment = token_ids[:first_max_length]
+ first_segment_cleaned = [token for token in first_segment if token != 0]
+ return first_segment_cleaned + token_ids[first_max_length:]

  # Convert inputs to python lists
  token_ids = to_py_obj(token_ids)
+ token_ids = remove_zeros_from_first_segment(token_ids, self.first_max_length)
  if isinstance(token_ids, int):
  if token_ids in self.all_special_ids and skip_special_tokens:
  return ""
  return self.encoder.get(token_ids, self.unk_token)
  elif isinstance(token_ids, list):
+ self.first_max_length
  out_str = ""
  out_last = 0
  out_tokens = []

  break
  out_tokens += [token]
  tmp = self._decode_tokens(out_tokens[out_last:])
+ if "\ufffd" not in tmp:
  out_str += tmp
  out_last = i + 1
  return out_str
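The decoding loop above commits output only once the accumulated bytes form valid UTF-8, so a multi-byte character split across tokens is never emitted half-finished. A self-contained sketch of the same idea, using invented byte pieces and an explicit UnicodeDecodeError check instead of the '\ufffd' marker:

# Hypothetical token byte pieces; the euro sign is split across two of them.
pieces = [b"price: ", b"\xe2\x82", b"\xac", b"100"]

out_str, out_last = "", 0
for i in range(len(pieces)):
    chunk = b"".join(pieces[out_last : i + 1])
    try:
        tmp = chunk.decode("utf-8")
    except UnicodeDecodeError:
        continue  # incomplete character so far; wait for the next token's bytes
    out_str += tmp
    out_last = i + 1

print(out_str)  # price: €100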
 
  def prepare_for_tokenization(self, text, **kwargs):
  return (text, kwargs)

+ def _get_padding_truncation_strategies(
+ self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
+ ):
+ return PaddingStrategy.LONGEST, TruncationStrategy.DO_NOT_TRUNCATE, -1, kwargs
+
  def _encode_plus(
  self,
  text: Union[TextInput, EncodedInput],

  return_offsets_mapping: bool = False,
  return_length: bool = False,
  verbose: bool = True,
+ **kwargs,
  ) -> BatchEncoding:
+ def get_input_ids(text, max_length=None, pad_token_id=0):
+ def pad_sequence(seq, max_len, pad_tok):
+ return [pad_tok] * (max_len - len(seq)) + seq
+
  if isinstance(text, str):
+ tokens = self._tokenize(text)
+ if max_length is not None:
+ tokens = pad_sequence(tokens, max_length, pad_token_id)
+ return tokens
+
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
+ tokenized_texts = [self._tokenize(t) for t in text]
+ if max_length is None:
+ max_length = max(len(t) for t in tokenized_texts)
+ return [pad_sequence(t, max_length, pad_token_id) for t in tokenized_texts]
+
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+ if max_length is not None and len(text) < max_length:
+ return pad_sequence(text, max_length, pad_token_id)
  return text
+
  else:
  raise ValueError(
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."

  return_offsets_mapping: bool = False,
  return_length: bool = False,
  verbose: bool = True,
+ **kwargs,
  ) -> BatchEncoding:
+ def get_input_ids(text, max_length=None, pad_token_id=0):
+ def pad_sequence(seq, max_len, pad_tok):
+ return [pad_tok] * (max_len - len(seq)) + seq
+
  if isinstance(text, str):
+ tokens = self._tokenize(text)
+ if max_length is not None:
+ tokens = pad_sequence(tokens, max_length, pad_token_id)
+ return tokens
+
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
+ tokenized_texts = [self._tokenize(t) for t in text]
+ if max_length is None:
+ max_length = max(len(t) for t in tokenized_texts)
+ return [pad_sequence(t, max_length, pad_token_id) for t in tokenized_texts]
+
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+ if max_length is not None and len(text) < max_length:
+ return pad_sequence(text, max_length, pad_token_id)
  return text
+
  else:
  raise ValueError(
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."

  "transformers.PreTrainedTokenizerFast."
  )

+ first_max_length = 0
+ second_max_length = 0
  for ids_or_pair_ids in batch_text_or_text_pairs:
  if not isinstance(ids_or_pair_ids, (list, tuple)):
  ids, pair_ids = ids_or_pair_ids, None
  else:
  ids, pair_ids = ids_or_pair_ids
  first_ids = get_input_ids(ids)
  second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
+ first_max_length = max(first_max_length, len(first_ids))
+ if second_ids is not None:
+ second_max_length = max(second_max_length, len(second_ids))
+
+ self.first_max_length = first_max_length
+ input_ids = []
+ for ids_or_pair_ids in batch_text_or_text_pairs:
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
+ ids, pair_ids = ids_or_pair_ids, None
+ else:
+ ids, pair_ids = ids_or_pair_ids
+
+ first_ids = get_input_ids(ids, max_length=first_max_length)
+ second_ids = get_input_ids(pair_ids, max_length=second_max_length) if pair_ids is not None else None
  input_ids.append((first_ids, second_ids))

  batch_outputs = self._batch_prepare_for_model(
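In a self-contained sketch, the two passes above do the following: the first pass measures the longest first-segment length in the batch, the second pass left-pads every sequence with token id 0 to that length, and _decode later removes those leading zeros again through self.first_max_length (the token ids here are invented):

# Pass 1: find the longest sequence; pass 2: left-pad the others with 0.
batch = [[11, 22], [33, 44, 55, 66], [77]]

first_max_length = max(len(seq) for seq in batch)  # 4
padded = [[0] * (first_max_length - len(seq)) + seq for seq in batch]
# padded == [[0, 0, 11, 22], [33, 44, 55, 66], [0, 0, 0, 77]]

# What _decode's remove_zeros_from_first_segment undoes before detokenizing.
restored = [[t for t in seq[:first_max_length] if t != 0] + seq[first_max_length:] for seq in padded]
assert restored == batch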
 

  return BatchEncoding(batch_outputs)

+ def decode(
+ self,
+ token_ids: Union[int, List[int]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = None,
+ **kwargs,
+ ) -> str:
+ """
+ Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
+ tokens and clean up tokenization spaces.
+
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+ Args:
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*):
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
+ `self.clean_up_tokenization_spaces`.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+
+ Returns:
+ `str`: The decoded sentence.
+ """
+ # Convert inputs to python lists
+ return self._decode(
+ token_ids=token_ids,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ def batch_decode(
+ self,
+ sequences: Union[List[int], List[List[int]]],
+ skip_special_tokens: bool = False,
+ clean_up_tokenization_spaces: bool = None,
+ **kwargs,
+ ) -> List[str]:
+ """
+ Convert a list of lists of token ids into a list of strings by calling decode.
+
+ Args:
+ sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
+ List of tokenized input ids. Can be obtained using the `__call__` method.
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not to remove special tokens in the decoding.
+ clean_up_tokenization_spaces (`bool`, *optional*):
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
+ `self.clean_up_tokenization_spaces`.
+ kwargs (additional keyword arguments, *optional*):
+ Will be passed to the underlying model specific decode method.
+
+ Returns:
+ `List[str]`: The list of decoded sentences.
+ """
+ return [
+ self.decode(
+ seq,
+ skip_special_tokens=skip_special_tokens,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+ for seq in sequences
+ ]
+
  def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
  input_ids = []
  for is_user, text in conversation.iter_texts():
  input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
  if len(input_ids) > self.model_max_length:
+ input_ids = input_ids[-self.model_max_length :]
  return input_ids
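For context, a sketch of how the updated tokenizer would typically be exercised once this commit lands, assuming the Hub repository wires the class up through tokenizer_config.json and is loaded with trust_remote_code=True (model id and prompts are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-169m", trust_remote_code=True)

# Shorter sequences are left-padded with token id 0 up to the longest one in the batch.
batch = tokenizer(["Hello", "A much longer prompt than the first one"], return_tensors="pt")
print(batch["input_ids"].shape)

# decode/_decode strip the leading zero padding again before turning bytes back into text.
print(tokenizer.batch_decode(batch["input_ids"]))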