赛萌 committed on
Commit b3bea9b
1 Parent(s): 33eda72
Files changed (2)
  1. qwen.tiktoken +0 -0
  2. tokenization_qwen.py +412 -0
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
tokenization_qwen.py ADDED
@@ -0,0 +1,412 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Tokenization classes for QWen."""

import base64
import logging
import os
import requests
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable

import tiktoken
import numpy as np
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
from transformers import PreTrainedTokenizer, AddedToken

logger = logging.getLogger(__name__)


VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}

PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
    ENDOFTEXT,
    IMSTART,
    IMEND,
) + EXTRAS
IMG_TOKEN_SPAN = 256


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }
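
# Each non-empty line of qwen.tiktoken is expected to hold a base64-encoded token
# followed by its integer rank, e.g. a (hypothetical) entry
#     SGVsbG8= 9906
# which _load_tiktoken_bpe turns into {b"Hello": 9906}.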

def _list_find(
    input_list: List[Any],
    candidates: Tuple[Any],
    start: int = 0,
):
    for i in range(start, len(input_list)):
        if input_list[i] in candidates:
            return i
    return -1

def _replace_closed_tag(
    input_tokens: List[Any],
    start_tags: Union[Any, Tuple[Any]],
    end_tags: Union[Any, Tuple[Any]],
    inclusive_replace_func: Callable,
    exclusive_replace_func: Callable = lambda x: x,
):
    if isinstance(start_tags, (str, int)):
        start_tags = (start_tags,)
    if isinstance(end_tags, (str, int)):
        end_tags = (end_tags,)
    assert len(start_tags) == len(end_tags)

    output_tokens = []
    end = 0
    while True:
        start = _list_find(input_tokens, start_tags, end)
        if start == -1:
            break
        output_tokens.extend(exclusive_replace_func(input_tokens[end : start]))
        tag_idx = start_tags.index(input_tokens[start])
        end = _list_find(input_tokens, (end_tags[tag_idx],), start)
        if end == -1:
            raise ValueError("Unclosed image token")
        output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1]))
        end += 1
    output_tokens.extend(exclusive_replace_func(input_tokens[end : ]))
    return output_tokens
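
# Illustrative example: with start tag "<i>" and end tag "</i>",
#     _replace_closed_tag(["a", "<i>", "x", "</i>", "b"], "<i>", "</i>", lambda seg: ["IMG"])
# returns ["a", "IMG", "b"]: every closed <i>...</i> span is passed to
# inclusive_replace_func and everything outside the spans to exclusive_replace_func.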

class QWenTokenizer(PreTrainedTokenizer):
    """QWen tokenizer."""

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        image_start_tag='<img>',
        image_end_tag='</img>',
        image_pad_tag='<imgpad>',
        ref_start_tag='<ref>',
        ref_end_tag='</ref>',
        box_start_tag='<box>',
        box_end_tag='</box>',
        quad_start_tag='<quad>',
        quad_end_tag='</quad>',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_start_tag = image_start_tag
        self.image_end_tag = image_end_tag
        self.image_pad_tag = image_pad_tag
        self.ref_start_tag = ref_start_tag
        self.ref_end_tag = ref_end_tag
        self.box_start_tag = box_start_tag
        self.box_end_tag = box_end_tag
        self.quad_start_tag = quad_start_tag
        self.quad_end_tag = quad_end_tag
        self.IMAGE_ST = (
            ref_start_tag, ref_end_tag,
            box_start_tag, box_end_tag,
            quad_start_tag, quad_end_tag,
            image_start_tag, image_end_tag,
            image_pad_tag
        )

        self.errors = errors  # how to handle errors in decoding

        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
        self.special_tokens = {
            token: index
            for index, token in enumerate(
                SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
            )
        }
        self.img_start_id = self.special_tokens[self.image_start_tag]
        self.img_end_id = self.special_tokens[self.image_end_tag]
        self.img_pad_id = self.special_tokens[self.image_pad_tag]
        self.ref_start_id = self.special_tokens[self.ref_start_tag]
        self.ref_end_id = self.special_tokens[self.ref_end_tag]
        self.box_start_id = self.special_tokens[self.box_start_tag]
        self.box_end_id = self.special_tokens[self.box_end_tag]
        self.quad_start_id = self.special_tokens[self.quad_start_tag]
        self.quad_end_id = self.special_tokens[self.quad_end_tag]

        enc = tiktoken.Encoding(
            "Qwen",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

        self.decoder = {
            v: k for k, v in self.mergeable_ranks.items()
        }  # type: dict[int, bytes|str]
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # type: tiktoken.Encoding

        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]
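
    # Vocabulary layout: the BPE ranks occupy ids [0, len(mergeable_ranks)), and the
    # special tokens (SPECIAL_TOKENS followed by the image/ref/box/quad tags) are
    # appended after them, so e.g. self.img_start_id is always greater than self.eod_id.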

    def __len__(self) -> int:
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        return self.mergeable_ranks

    def convert_tokens_to_ids(
        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
    ) -> List[int]:
        ids = []
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.mergeable_ranks.get(tokens)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.mergeable_ranks.get(token))
        return ids

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        if not special_tokens and new_tokens:
            raise ValueError('Adding regular tokens is not supported')
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST:
                raise ValueError('Adding unknown special tokens is not supported')
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer (the qwen.tiktoken BPE file).

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_path = os.path.join(save_directory, "qwen.tiktoken")
        with open(file_path, "w", encoding="utf8") as w:
            for k, v in self.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
                w.write(line)
        return (file_path,)

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string into a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Defaults to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not appear in regular texts and trigger errors.
                Defaults to an empty tuple.

            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific encode method.

        Returns:
            `List[bytes|str]`: The list of tokens.
        """
        tokens = []
        text = unicodedata.normalize("NFC", text)

        # this implementation takes a detour: text -> token id -> token surface forms
        for t in self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        ):
            tokens.append(self.decoder[t])

        def _encode_imgurl(img_tokens):
            assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
            img_tokens = img_tokens[1:-1]
            img_url = b''.join(img_tokens)
            out_img_tokens = list(map(self.decoder.get, img_url))
            if len(out_img_tokens) > IMG_TOKEN_SPAN:
                raise ValueError("The content in {}..{} is too long".format(
                    self.image_start_tag, self.image_end_tag))
            out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
            out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
            return out_img_tokens

        return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
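
    # Illustrative behavior: for text like "<img>img_url</img>" (hypothetical URL),
    # the bytes between the image tags are turned into one token per byte and then
    # right-padded with <imgpad> up to IMG_TOKEN_SPAN (256) tokens, so every image
    # reference occupies a fixed-length span in the token sequence.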

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens into a single string.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text
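
    # Tokens produced by tokenize() are byte strings for ordinary text and Python
    # strings for special tokens; consecutive byte tokens are buffered and decoded
    # as UTF-8 in one pass, using self.errors for malformed sequences.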

    @property
    def vocab_size(self):
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Converts an id to a token, special tokens included"""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError("unknown ids")

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Converts a token to an id using the vocab, special tokens included"""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError("unknown token")

    def _tokenize(self, text: str, **kwargs):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for
        word-based vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).

        Does NOT take care of added tokens.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: str = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        def _decode_imgurl(img_token_ids):
            assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
            img_token_ids = img_token_ids[1:-1]
            img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)]
            img_url = bytes(img_token_ids).decode('utf-8')
            return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]

        token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)

        if skip_special_tokens:
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)
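
    # On decoding, a padded <img>...</img> span is first collapsed back to the original
    # URL text and re-encoded, so the <imgpad> filler never appears in the decoded string.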

    def to_list_format(self, text: str):
        text = unicodedata.normalize("NFC", text)
        token_ids = self.tokenizer.encode(
            text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))

        def _encode_vl_info(tokens):
            if len(tokens) == 0:
                return []
            if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
                key = 'image'
            elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
                key = 'ref'
            elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
                key = 'box'
            elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
                key = 'quad'
            else:
                _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
                return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
            val = b''.join(map(self.decoder.get, tokens[1:-1])).decode('utf-8')
            return [{key: val}]

        return _replace_closed_tag(
            token_ids,
            (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
            (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
            _encode_vl_info,
            _encode_vl_info,
        )
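
    # Example (hypothetical input): to_list_format("<img>img_url</img>Describe <ref>the dog</ref>")
    # would return a list of dicts such as
    #     [{'image': 'img_url'}, {'text': 'Describe '}, {'ref': 'the dog'}]
    # i.e. closed tag spans become keyed entries and everything else becomes 'text'.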

    def _fetch_latest_picture(self, response, history):
        if history is None:
            history = []
        _history = history + [(response, None)]
        for q, r in _history[::-1]:
            for ele in self.to_list_format(q)[::-1]:
                if 'image' in ele:
                    return ele['image']
        return None

    def _fetch_all_box_with_ref(self, text):
        list_format = self.to_list_format(text)
        output = []
        for i, ele in enumerate(list_format):
            if 'box' in ele:
                bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
                assert len(bbox) == 4
                output.append({'box': bbox})
                if i > 0 and 'ref' in list_format[i-1]:
                    output[-1]['ref'] = list_format[i-1]['ref'].strip()
        return output
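
    # Boxes are expected as "(x1,y1),(x2,y2)" strings inside <box>...</box>; for a
    # (hypothetical) element {'box': '(100,200),(300,400)'} preceded by {'ref': ' the dog '},
    # this yields {'box': (100, 200, 300, 400), 'ref': 'the dog'}.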

    def draw_bbox_on_latest_picture(
        self,
        response,
        history=None,
    ):
        image = self._fetch_latest_picture(response, history)
        if image is None:
            return None
        if image.startswith("http://") or image.startswith("https://"):
            image = Image.open(requests.get(image, stream=True).raw)
        else:
            image = Image.open(image)
        h, w = image.height, image.width
        image = image.convert("RGB")

        boxes = self._fetch_all_box_with_ref(response)
        if not boxes:
            return None
        fnt = ImageFont.truetype("SimSun.ttf", 20)
        draw = ImageDraw.Draw(image)
        for box in boxes:
            x1, y1, x2, y2 = box['box']
            # box coordinates are on a 0-1000 normalized grid; rescale to pixel coordinates
            x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
            draw.rectangle((x1, y1, x2, y2), outline='red', width=2)
            if 'ref' in box:
                draw.text((x1, y1), box['ref'], fill='red', font=fnt)
        return image
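
A minimal usage sketch (the image URL and prompt are placeholders; it assumes qwen.tiktoken sits next to this file and a transformers/tiktoken version compatible with this code):

    from tokenization_qwen import QWenTokenizer

    tokenizer = QWenTokenizer("qwen.tiktoken")
    tokens = tokenizer.tokenize("<img>https://example.com/demo.jpg</img>What is in the picture?")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(len(tokens))            # the image span is padded to a fixed length
    print(tokenizer.decode(ids))  # the padding is stripped again on decode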