khalidalt commited on
Commit
d0269df
1 Parent(s): f2abfa4

Upload 4 files

Browse files
rwkv_vocab_v20230424.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
tokenization_rwkv_world.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for OpenAI GPT."""
16
+
17
+ import json
18
+ import os
19
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
20
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
21
+ from transformers.utils import logging, to_py_obj
22
+ from transformers.tokenization_utils_base import BatchEncoding
23
+
24
+ import bisect
25
+ import itertools
26
+ import re
27
+ import unicodedata
28
+ from collections import OrderedDict
29
+ from typing import Any, Dict, List, Optional, Tuple, Union, overload
30
+
31
+ from transformers.tokenization_utils_base import (
32
+ ENCODE_KWARGS_DOCSTRING,
33
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
34
+ INIT_TOKENIZER_DOCSTRING,
35
+ AddedToken,
36
+ BatchEncoding,
37
+ EncodedInput,
38
+ EncodedInputPair,
39
+ PreTokenizedInput,
40
+ PreTokenizedInputPair,
41
+ PreTrainedTokenizerBase,
42
+ TextInput,
43
+ TextInputPair,
44
+ TruncationStrategy,
45
+ )
46
+ from transformers.utils import PaddingStrategy, TensorType, add_end_docstrings, logging
47
+
48
+
49
+ if TYPE_CHECKING:
50
+ from transformers.pipelines.conversational import Conversation
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ VOCAB_FILES_NAMES = {
55
+ "vocab_file": "rwkv_vocab_v20230424.json",
56
+ }
57
+
58
+
59
+ class DATrie:
60
+ class Node:
61
+ def __init__(self, is_leaf=False, leaf_data=None, tail=""):
62
+ self._is_leaf = is_leaf
63
+ self._leaf_data = leaf_data
64
+ self._tail = tail
65
+ self._next_map = {}
66
+
67
+ def is_leaf(self):
68
+ return self._is_leaf
69
+
70
+ def set_leaf(self):
71
+ self._is_leaf = True
72
+
73
+ def has_next(self, w):
74
+ if w in self._next_map:
75
+ return True
76
+ return False
77
+
78
+ def add_node(self, w, node):
79
+ self._next_map[w] = node
80
+
81
+ def get_node(self, w):
82
+ if w in self._next_map:
83
+ return self._next_map[w]
84
+ return None
85
+
86
+ def get_tail(self):
87
+ return self._tail
88
+
89
+ def get_data(self):
90
+ return self._leaf_data
91
+
92
+ def set_data(self, data):
93
+ self._leaf_data = data
94
+
95
+ def __init__(self, special_ids):
96
+ self.root = self.Node()
97
+ self.data = {}
98
+ self.r_data = {}
99
+ self.special_ids = special_ids
100
+
101
+ def insert(self, word, data):
102
+ self.data[word] = data
103
+ self.r_data[data] = word
104
+ idx = 0
105
+ node = self.root
106
+ while idx < len(word):
107
+ w = word[idx]
108
+ is_leaf = (idx == (len(word) - 1))
109
+ leaf_data = (data if is_leaf else None)
110
+ # 不存在则插入
111
+ if not node.has_next(w):
112
+ node.add_node(w, self.Node(is_leaf=is_leaf, leaf_data=leaf_data))
113
+ # last word
114
+ node = node.get_node(w)
115
+ idx += 1
116
+ if not node.is_leaf():
117
+ node.set_leaf()
118
+ node.set_data(data)
119
+
120
+ def findStrict(self, word):
121
+ idx = 0
122
+ node = self.root
123
+ while node is not None and idx < len(word):
124
+ w = word[idx]
125
+ if not node.has_next(w):
126
+ return None
127
+ # last word
128
+ node = node.get_node(w)
129
+ idx += 1
130
+ if node.is_leaf():
131
+ return node.get_data()
132
+ return None
133
+
134
+ def prefix(self, word):
135
+ idx = 0
136
+ node = self.root
137
+ result = []
138
+ while node is not None and idx < len(word):
139
+ w = word[idx]
140
+ if not node.has_next(w):
141
+ return result
142
+ # last word
143
+ node = node.get_node(w)
144
+ if node.is_leaf():
145
+ result.append([word[:idx + 1], node.get_data()])
146
+ idx += 1
147
+ return result
148
+
149
+ def max_prefix(self, content, start_idx):
150
+ idx = start_idx
151
+ node = self.root
152
+ l = len(content)
153
+ result = [["", ], ]
154
+ while node is not None and idx < l:
155
+ w = content[idx]
156
+ if not node.has_next(w):
157
+ return result[-1]
158
+ # last word
159
+ node = node.get_node(w)
160
+ if node.is_leaf():
161
+ result.append([content[start_idx:idx + 1], node.get_data()])
162
+ idx += 1
163
+ return result[-1]
164
+
165
+ def max_score(self, content, start_idx):
166
+ idx = start_idx
167
+ node = self.root
168
+ l = len(content)
169
+ result = [["", (3, 0)], ]
170
+ while node is not None and idx < l:
171
+ w = content[idx]
172
+ if not node.has_next(w):
173
+ break
174
+ # last word
175
+ node = node.get_node(w)
176
+ if node.is_leaf():
177
+ result.append([content[start_idx:idx + 1], node.get_data()])
178
+ idx += 1
179
+ if len(result) > 1:
180
+ result = sorted(result, key=lambda x: x[1][1])
181
+ return result[-1]
182
+
183
+ def match(self, content, add_unk=True, unk_id=-1, **kwargs):
184
+ # length
185
+ l = len(content)
186
+ i = 0
187
+ result_list = []
188
+ while i < l:
189
+ match_word = self.max_prefix(content=content, start_idx=i)
190
+ # print(match_word)
191
+ w = match_word[0]
192
+ if len(w) > 0:
193
+ result_list.append(match_word[1])
194
+ i += len(w)
195
+ else:
196
+ if add_unk:
197
+ result_list.append(unk_id)
198
+ i += 1
199
+ return result_list
200
+
201
+ def id2str(self, ids, escape_special_ids=True, end_ids=[], **kwargs):
202
+ res_str = ""
203
+ for rid in ids:
204
+ if rid in self.r_data:
205
+ if rid in end_ids:
206
+ break
207
+ if escape_special_ids and rid in self.special_ids:
208
+ continue
209
+ rstr = self.r_data[rid]
210
+ res_str += rstr
211
+ elif rid == 0:
212
+ break
213
+ else:
214
+ print("ERROR unknown id %d" % rid)
215
+ res_str += "UNK"
216
+ return res_str
217
+
218
+ def id2str_v2(self, ids, escape_special_ids=True, end_ids=[], **kwargs):
219
+ res_str = ""
220
+ for rid in ids:
221
+ if rid in self.r_data:
222
+ if rid in end_ids:
223
+ break
224
+ rstr = self.r_data[rid]
225
+ if escape_special_ids and rid in self.special_ids:
226
+ continue
227
+ res_str += rstr
228
+ elif rid == 0:
229
+ break
230
+ else:
231
+ print("ERROR unknown id %d" % rid)
232
+ res_str += "UNK"
233
+ return res_str
234
+
235
+
236
+ class RWKVWorldTokenizer(PreTrainedTokenizer):
237
+ vocab_files_names = VOCAB_FILES_NAMES
238
+ model_input_names = ["input_ids", "attention_mask"]
239
+
240
+ def __init__(
241
+ self,
242
+ vocab_file,
243
+ errors="replace",
244
+ **kwargs
245
+ ):
246
+ self.add_bos_token = False
247
+
248
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
249
+ self.encoder = json.load(vocab_handle)
250
+ super().__init__(
251
+ errors=errors,
252
+ **kwargs,
253
+ )
254
+ self.decoder = {v: k for k, v in self.encoder.items()}
255
+ self.trie = DATrie(self.all_special_ids)
256
+ for k, v in self.encoder.items():
257
+ self.trie.insert(k, v)
258
+ self.errors = errors # how to handle errors in decoding
259
+ self.cache = {}
260
+
261
+ @property
262
+ def vocab_size(self):
263
+ return len(self.encoder)
264
+
265
+ def get_vocab(self):
266
+ return dict(self.encoder, **self.added_tokens_encoder)
267
+
268
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
269
+ if self.add_bos_token:
270
+ bos_token_ids = [self.bos_token_id]
271
+ else:
272
+ bos_token_ids = []
273
+
274
+ output = bos_token_ids + token_ids_0
275
+
276
+ if token_ids_1 is None:
277
+ return output
278
+
279
+ return output + bos_token_ids + token_ids_1
280
+
281
+ def get_special_tokens_mask(
282
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
283
+ already_has_special_tokens: bool = False
284
+ ) -> List[int]:
285
+ """
286
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
287
+ special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
288
+
289
+ Args:
290
+ token_ids_0 (`List[int]`):
291
+ List of IDs.
292
+ token_ids_1 (`List[int]`, *optional*):
293
+ Optional second list of IDs for sequence pairs.
294
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
295
+ Whether or not the token list is already formatted with special tokens for the model.
296
+
297
+ Returns:
298
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
299
+ """
300
+ if already_has_special_tokens:
301
+ return super().get_special_tokens_mask(
302
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
303
+ )
304
+
305
+ if not self.add_bos_token:
306
+ return super().get_special_tokens_mask(
307
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
308
+ )
309
+
310
+ if token_ids_1 is None:
311
+ return [1] + ([0] * len(token_ids_0))
312
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
313
+
314
+ def _tokenize(self, text, **kwargs):
315
+ """Tokenize a string."""
316
+ return self.trie.match(text, unk_id=self.unk_token_id, **kwargs)
317
+
318
+ def _decode(self,
319
+ token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
320
+ skip_special_tokens: bool = False,
321
+ **kwargs
322
+ ) -> str:
323
+
324
+ # Convert inputs to python lists
325
+ token_ids = to_py_obj(token_ids)
326
+ if isinstance(token_ids, int):
327
+ if token_ids in self.all_special_ids and skip_special_tokens:
328
+ return ""
329
+ return self.decoder.get(token_ids, self.unk_token)
330
+ elif isinstance(token_ids, list):
331
+ return self.trie.id2str(
332
+ token_ids,
333
+ escape_special_ids=skip_special_tokens,
334
+ **kwargs
335
+ )
336
+ else:
337
+ return token_ids
338
+
339
+ def _convert_token_to_id(self, token):
340
+ """Converts a token (str) in an id using the vocab."""
341
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
342
+
343
+ def _convert_id_to_token(self, index):
344
+ """Converts an index (integer) in a token (str) using the vocab."""
345
+ return self.decoder.get(index)
346
+
347
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
348
+ if not os.path.exists(save_directory):
349
+ os.mkdir(save_directory)
350
+ if not os.path.isdir(save_directory):
351
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
352
+ return
353
+ vocab_file = os.path.join(
354
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
355
+ )
356
+
357
+ with open(vocab_file, "w", encoding="utf-8") as f:
358
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
359
+
360
+ return (vocab_file,)
361
+
362
+ def prepare_for_tokenization(self, text, **kwargs):
363
+ return (text, kwargs)
364
+
365
+ def _encode_plus(
366
+ self,
367
+ text: Union[TextInput, EncodedInput],
368
+ add_special_tokens: bool = True,
369
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
370
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
371
+ max_length: Optional[int] = None,
372
+ stride: int = 0,
373
+ pad_to_multiple_of: Optional[int] = None,
374
+ return_tensors: Optional[Union[str, TensorType]] = None,
375
+ return_token_type_ids: Optional[bool] = None,
376
+ return_attention_mask: Optional[bool] = None,
377
+ return_overflowing_tokens: bool = False,
378
+ return_special_tokens_mask: bool = False,
379
+ return_offsets_mapping: bool = False,
380
+ return_length: bool = False,
381
+ verbose: bool = True,
382
+ **kwargs
383
+ ) -> BatchEncoding:
384
+ def get_input_ids(text):
385
+ if isinstance(text, str):
386
+ text_id = self.trie.match(text, unk_id=self.unk_token_id)
387
+ return text_id
388
+ elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
389
+ return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
390
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
391
+ return text
392
+ else:
393
+ raise ValueError(
394
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
395
+ )
396
+
397
+ if return_offsets_mapping:
398
+ raise NotImplementedError(
399
+ "return_offset_mapping is not available when using Python tokenizers. "
400
+ "To use this feature, change your tokenizer to one deriving from "
401
+ "transformers.PreTrainedTokenizerFast. "
402
+ "More information on available tokenizers at "
403
+ "https://github.com/huggingface/transformers/pull/2674"
404
+ )
405
+
406
+ first_ids = get_input_ids(text)
407
+
408
+ return self.prepare_for_model(
409
+ first_ids,
410
+ pair_ids=None,
411
+ add_special_tokens=add_special_tokens,
412
+ padding=padding_strategy.value,
413
+ truncation=truncation_strategy.value,
414
+ max_length=max_length,
415
+ stride=stride,
416
+ pad_to_multiple_of=pad_to_multiple_of,
417
+ return_tensors=return_tensors,
418
+ prepend_batch_axis=True,
419
+ return_attention_mask=return_attention_mask,
420
+ return_token_type_ids=return_token_type_ids,
421
+ return_overflowing_tokens=return_overflowing_tokens,
422
+ return_special_tokens_mask=return_special_tokens_mask,
423
+ return_length=return_length,
424
+ verbose=verbose,
425
+ )
426
+
427
+ def _batch_encode_plus(
428
+ self,
429
+ batch_text_or_text_pairs: Union[
430
+ List[TextInput],
431
+ List[EncodedInput],
432
+ ],
433
+ add_special_tokens: bool = True,
434
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
435
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
436
+ max_length: Optional[int] = None,
437
+ stride: int = 0,
438
+ pad_to_multiple_of: Optional[int] = None,
439
+ return_tensors: Optional[Union[str, TensorType]] = None,
440
+ return_token_type_ids: Optional[bool] = None,
441
+ return_attention_mask: Optional[bool] = None,
442
+ return_overflowing_tokens: bool = False,
443
+ return_special_tokens_mask: bool = False,
444
+ return_offsets_mapping: bool = False,
445
+ return_length: bool = False,
446
+ verbose: bool = True,
447
+ **kwargs
448
+ ) -> BatchEncoding:
449
+ def get_input_ids(text):
450
+ if isinstance(text, str):
451
+ text_id = self.trie.match(text, unk_id=self.unk_token_id)
452
+ return text_id
453
+ elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
454
+ return [self.trie.match(t, unk_id=self.unk_token_id) for t in text]
455
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
456
+ return text
457
+ else:
458
+ raise ValueError(
459
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
460
+ )
461
+
462
+ if return_offsets_mapping:
463
+ raise NotImplementedError(
464
+ "return_offset_mapping is not available when using Python tokenizers. "
465
+ "To use this feature, change your tokenizer to one deriving from "
466
+ "transformers.PreTrainedTokenizerFast."
467
+ )
468
+
469
+ input_ids = []
470
+ for ids_or_pair_ids in batch_text_or_text_pairs:
471
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
472
+ ids, pair_ids = ids_or_pair_ids, None
473
+ else:
474
+ ids, pair_ids = ids_or_pair_ids
475
+
476
+ first_ids = get_input_ids(ids)
477
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
478
+ input_ids.append((first_ids, second_ids))
479
+
480
+ batch_outputs = self._batch_prepare_for_model(
481
+ input_ids,
482
+ add_special_tokens=add_special_tokens,
483
+ padding_strategy=padding_strategy,
484
+ truncation_strategy=truncation_strategy,
485
+ max_length=max_length,
486
+ stride=stride,
487
+ pad_to_multiple_of=pad_to_multiple_of,
488
+ return_attention_mask=return_attention_mask,
489
+ return_token_type_ids=return_token_type_ids,
490
+ return_overflowing_tokens=return_overflowing_tokens,
491
+ return_special_tokens_mask=return_special_tokens_mask,
492
+ return_length=return_length,
493
+ return_tensors=return_tensors,
494
+ verbose=verbose,
495
+ )
496
+
497
+ return BatchEncoding(batch_outputs)
498
+
499
+ def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
500
+ input_ids = []
501
+ for is_user, text in conversation.iter_texts():
502
+ input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
503
+ if len(input_ids) > self.model_max_length:
504
+ input_ids = input_ids[-self.model_max_length:]
505
+ return input_ids
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name_or_path": "rwkv-world",
3
+ "add_prefix_space": false,
4
+ "tokenizer_class": "RWKVWorldTokenizer",
5
+ "use_fast": false,
6
+ "auto_map": {
7
+ "AutoTokenizer": [
8
+ "tokenization_rwkv_world.RWKVWorldTokenizer",
9
+ null
10
+ ]
11
+ }
12
+ }