zhjohnchan commited on
Commit
832389d
1 Parent(s): 7adba1f

Upload tokenization_chexagent.py

Browse files
Files changed (1) hide show
  1. tokenization_chexagent.py +648 -0
tokenization_chexagent.py ADDED
@@ -0,0 +1,648 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import unicodedata
4
+ from shutil import copyfile
5
+ from typing import TYPE_CHECKING, Dict, List, Tuple, Union, Any, Callable, Optional
6
+
7
+ import matplotlib as mpl
8
+ import matplotlib.colors as mcolors
9
+ import matplotlib.colors as mplc
10
+ import matplotlib.figure as mplfigure
11
+ import numpy as np
12
+ import requests
13
+ import sentencepiece as spm
14
+ import torch
15
+ from PIL import Image
16
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
17
+ from transformers import PreTrainedTokenizer, AddedToken
18
+ from transformers.convert_slow_tokenizer import import_protobuf
19
+ from transformers.utils import logging
20
+
21
+ if TYPE_CHECKING:
22
+ from transformers.tokenization_utils_base import TextInput
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
27
+
28
+ PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
31
+ },
32
+ "tokenizer_file": {
33
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
34
+ },
35
+ }
36
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
37
+ "hf-internal-testing/llama-tokenizer": 2048,
38
+ }
39
+ SPIECE_UNDERLINE = "▁"
40
+
41
+ IMG_TOKEN_SPAN = 256
42
+
43
+ DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
44
+
45
+
46
+ def _list_find(
47
+ input_list: List[Any],
48
+ candidates: Tuple[Any],
49
+ start: int = 0,
50
+ ):
51
+ for i in range(start, len(input_list)):
52
+ if input_list[i] in candidates:
53
+ return i
54
+ return -1
55
+
56
+
57
+ def _replace_closed_tag(
58
+ input_tokens: List[Any],
59
+ start_tags: Union[Any, Tuple[Any]],
60
+ end_tags: Union[Any, Tuple[Any]],
61
+ inclusive_replace_func: Callable,
62
+ exclusive_replace_func: Callable = lambda x: x,
63
+ ):
64
+ if isinstance(start_tags, (str, int)):
65
+ start_tags = (start_tags,)
66
+ if isinstance(end_tags, (str, int)):
67
+ end_tags = (end_tags,)
68
+ assert len(start_tags) == len(end_tags)
69
+
70
+ output_tokens = []
71
+ end = 0
72
+ while True:
73
+ start = _list_find(input_tokens, start_tags, end)
74
+ if start == -1:
75
+ break
76
+ output_tokens.extend(exclusive_replace_func(input_tokens[end: start]))
77
+ tag_idx = start_tags.index(input_tokens[start])
78
+ end = _list_find(input_tokens, (end_tags[tag_idx],), start)
79
+ if end == -1:
80
+ raise ValueError("Unclosed image token")
81
+ output_tokens.extend(inclusive_replace_func(input_tokens[start: end + 1]))
82
+ end += 1
83
+ output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
84
+ return output_tokens
85
+
86
+
87
+ class CheXagentTokenizer(PreTrainedTokenizer):
88
+ vocab_files_names = VOCAB_FILES_NAMES
89
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
90
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
91
+ model_input_names = ["input_ids", "attention_mask"]
92
+
93
+ def __init__(
94
+ self,
95
+ vocab_file,
96
+ unk_token="<unk>",
97
+ bos_token="<s>",
98
+ eos_token="</s>",
99
+ pad_token=None,
100
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
101
+ add_bos_token=True,
102
+ add_eos_token=False,
103
+ clean_up_tokenization_spaces=False,
104
+ use_default_system_prompt=False,
105
+ spaces_between_special_tokens=False,
106
+ legacy=None,
107
+ errors="replace",
108
+ image_start_tag='<|img|>',
109
+ image_end_tag='<|/img|>',
110
+ image_pad_tag='<|imgpad|>',
111
+ ref_start_tag='<|ref|>',
112
+ ref_end_tag='<|/ref|>',
113
+ box_start_tag='<|box|>',
114
+ box_end_tag='<|/box|>',
115
+ quad_start_tag='<|quad|>',
116
+ quad_end_tag='<|/quad|>',
117
+ **kwargs,
118
+ ):
119
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
120
+ bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
121
+ eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
122
+ unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
123
+ pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
124
+
125
+ if legacy is None:
126
+ logger.warning_once(
127
+ f"You are using the default legacy behaviour of the {self.__class__}. This is"
128
+ " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
129
+ " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
130
+ " means, and thoroughly read the reason why this was added as explained in"
131
+ " https://github.com/huggingface/transformers/pull/24565"
132
+ )
133
+ legacy = True
134
+
135
+ self.legacy = legacy
136
+ self.vocab_file = vocab_file
137
+ self.add_bos_token = add_bos_token
138
+ self.add_eos_token = add_eos_token
139
+ self.use_default_system_prompt = use_default_system_prompt
140
+ self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
141
+ super().__init__(
142
+ bos_token=bos_token,
143
+ eos_token=eos_token,
144
+ unk_token=unk_token,
145
+ pad_token=pad_token,
146
+ add_bos_token=add_bos_token,
147
+ add_eos_token=add_eos_token,
148
+ sp_model_kwargs=self.sp_model_kwargs,
149
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
150
+ use_default_system_prompt=use_default_system_prompt,
151
+ spaces_between_special_tokens=spaces_between_special_tokens,
152
+ legacy=legacy,
153
+ **kwargs,
154
+ )
155
+ self.errors = errors # how to handle errors in decoding
156
+ self.image_start_tag = image_start_tag
157
+ self.image_end_tag = image_end_tag
158
+ self.image_pad_tag = image_pad_tag
159
+ self.ref_start_tag = ref_start_tag
160
+ self.ref_end_tag = ref_end_tag
161
+ self.box_start_tag = box_start_tag
162
+ self.box_end_tag = box_end_tag
163
+ self.quad_start_tag = quad_start_tag
164
+ self.quad_end_tag = quad_end_tag
165
+ self.IMAGE_ST = (
166
+ image_start_tag, image_end_tag, image_pad_tag,
167
+ ref_start_tag, ref_end_tag, box_start_tag, box_end_tag,
168
+ quad_start_tag, quad_end_tag,
169
+ )
170
+ for special_token in self.IMAGE_ST:
171
+ if special_token not in self.get_vocab():
172
+ self.add_special_tokens({"additional_special_tokens": [special_token]})
173
+ for coordinate in range(10):
174
+ if f"<{coordinate}>" not in self.get_vocab():
175
+ self.add_special_tokens({"additional_special_tokens": [f"<|coord_{coordinate}|>"]})
176
+ if len(self) % 64 != 0:
177
+ for extra in range(((len(self) // 64) + 1) * 64 - len(self)):
178
+ if f"<extra_{extra}>" not in self.get_vocab():
179
+ self.add_special_tokens({"additional_special_tokens": [f"<|extra_{extra}|>"]})
180
+ self.img_start_id = self.convert_tokens_to_ids(self.image_start_tag)
181
+ self.img_end_id = self.convert_tokens_to_ids(self.image_end_tag)
182
+ self.img_pad_id = self.convert_tokens_to_ids(self.image_pad_tag)
183
+ self.ref_start_id = self.convert_tokens_to_ids(self.ref_start_tag)
184
+ self.ref_end_id = self.convert_tokens_to_ids(self.ref_end_tag)
185
+ self.box_start_id = self.convert_tokens_to_ids(self.box_start_tag)
186
+ self.box_end_id = self.convert_tokens_to_ids(self.box_end_tag)
187
+ self.quad_start_id = self.convert_tokens_to_ids(self.quad_start_tag)
188
+ self.quad_end_id = self.convert_tokens_to_ids(self.quad_end_tag)
189
+ self.chat_template = DEFAULT_CHAT_TEMPLATE
190
+
191
+ @property
192
+ def unk_token_length(self):
193
+ return len(self.sp_model.encode(str(self.unk_token)))
194
+
195
+ def get_spm_processor(self, from_slow=False):
196
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
197
+ if self.legacy or from_slow: # no dependency on protobuf
198
+ tokenizer.Load(self.vocab_file)
199
+ return tokenizer
200
+
201
+ with open(self.vocab_file, "rb") as f:
202
+ sp_model = f.read()
203
+ model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
204
+ model = model_pb2.ModelProto.FromString(sp_model)
205
+ normalizer_spec = model_pb2.NormalizerSpec()
206
+ normalizer_spec.add_dummy_prefix = False
207
+ model.normalizer_spec.MergeFrom(normalizer_spec)
208
+ sp_model = model.SerializeToString()
209
+ tokenizer.LoadFromSerializedProto(sp_model)
210
+ return tokenizer
211
+
212
+ def __getstate__(self):
213
+ state = self.__dict__.copy()
214
+ state["sp_model"] = None
215
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
216
+ return state
217
+
218
+ def __setstate__(self, d):
219
+ self.__dict__ = d
220
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
221
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
222
+
223
+ @property
224
+ def vocab_size(self):
225
+ """Returns vocab size"""
226
+ return self.sp_model.get_piece_size()
227
+
228
+ def get_vocab(self):
229
+ """Returns vocab as a dict"""
230
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
231
+ vocab.update(self.added_tokens_encoder)
232
+ return vocab
233
+
234
+ def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
235
+ """
236
+ Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
237
+ first token is special.
238
+ """
239
+
240
+ def _encode_imgurl(img_tokens):
241
+ assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
242
+ img_tokens = img_tokens[1:-1]
243
+ img_url = ''.join(img_tokens)
244
+ out_img_tokens = list(img_url)
245
+ if len(out_img_tokens) > IMG_TOKEN_SPAN:
246
+ raise ValueError("The content in {}..{} is too long".format(self.image_start_tag, self.image_end_tag))
247
+ out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
248
+ out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
249
+ return out_img_tokens
250
+
251
+ if self.legacy or len(text) == 0:
252
+ tokens = super().tokenize(text, **kwargs)
253
+ tokens = _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
254
+ return tokens
255
+
256
+ tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
257
+
258
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
259
+ tokens = tokens[1:]
260
+ return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
261
+
262
+ def _decode(
263
+ self,
264
+ token_ids: Union[int, List[int]],
265
+ skip_special_tokens: bool = False,
266
+ errors: str = None,
267
+ **kwargs,
268
+ ) -> str:
269
+ def _decode_imgurl(img_token_ids):
270
+ assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
271
+ img_token_ids = img_token_ids[1:-1]
272
+ img_token_ids = img_token_ids[: img_token_ids.index(self.img_pad_id)]
273
+ return [self.img_start_id] + img_token_ids + [self.img_end_id]
274
+
275
+ token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
276
+ return super()._decode(token_ids, errors=errors or self.errors)
277
+
278
+ def to_list_format(self, text: str):
279
+ text = unicodedata.normalize("NFC", text)
280
+ token_ids = self.encode(text)[1:]
281
+
282
+ def _encode_vl_info(tokens):
283
+ if len(tokens) == 0:
284
+ return []
285
+ if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
286
+ key = 'image'
287
+ tokens = tokens[: tokens.index(self.img_pad_id)]
288
+ elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
289
+ key = 'ref'
290
+ elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
291
+ key = 'box'
292
+ elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
293
+ key = 'quad'
294
+ else:
295
+ key = 'text'
296
+ return [{key: self.decode(tokens)}]
297
+ return [{key: self.decode(tokens[1:-1])}]
298
+
299
+ return _replace_closed_tag(
300
+ token_ids,
301
+ (self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
302
+ (self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
303
+ _encode_vl_info,
304
+ _encode_vl_info,
305
+ )
306
+
307
+ def from_list_format(self, list_format: List[Dict]):
308
+ text = ''
309
+ num_images = 0
310
+ for ele in list_format:
311
+ if 'image' in ele:
312
+ num_images += 1
313
+ text += f'Picture {num_images}:'
314
+ text += self.image_start_tag + ele['image'] + self.image_end_tag
315
+ text += '\n'
316
+ elif 'text' in ele:
317
+ text += ele['text']
318
+ elif 'box' in ele:
319
+ if 'ref' in ele:
320
+ text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
321
+ for box in ele['box']:
322
+ text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
323
+ else:
324
+ raise ValueError("Unsupport element: " + str(ele))
325
+ return text
326
+
327
+ def _fetch_latest_picture(self, response, history):
328
+ if history is None:
329
+ history = []
330
+ _history = history + [(response, None)]
331
+ for q, r in _history[::-1]:
332
+ for ele in self.to_list_format(q)[::-1]:
333
+ if 'image' in ele:
334
+ return ele['image']
335
+ return None
336
+
337
+ def _fetch_all_box_with_ref(self, text):
338
+ list_format = self.to_list_format(text)
339
+ output = []
340
+ for i, ele in enumerate(list_format):
341
+ if 'box' in ele:
342
+ bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
343
+ assert len(bbox) == 4
344
+ output.append({'box': bbox})
345
+ if i > 0 and 'ref' in list_format[i - 1]:
346
+ output[-1]['ref'] = list_format[i - 1]['ref'].strip()
347
+ return output
348
+
349
+ def draw_bbox_on_latest_picture(
350
+ self,
351
+ response,
352
+ history=None,
353
+ ) -> Optional[Image.Image]:
354
+ image = self._fetch_latest_picture(response, history)
355
+ if image is None:
356
+ return None
357
+ if image.startswith("http://") or image.startswith("https://"):
358
+ image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
359
+ h, w = image.height, image.width
360
+ else:
361
+ image = np.asarray(Image.open(image).convert("RGB"))
362
+ h, w = image.shape[0], image.shape[1]
363
+ visualizer = Visualizer(image)
364
+
365
+ boxes = self._fetch_all_box_with_ref(response)
366
+ if not boxes:
367
+ return None
368
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
369
+ for box in boxes:
370
+ if 'ref' in box: # random new color for new refexps
371
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
372
+ x1, y1, x2, y2 = box['box']
373
+ x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
374
+ visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
375
+ if 'ref' in box:
376
+ visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
377
+ return visualizer.output
378
+
379
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
380
+ def _tokenize(self, text, **kwargs):
381
+ """
382
+ Returns a tokenized string.
383
+
384
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
385
+ SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
386
+ `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
387
+ `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
388
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
389
+ """
390
+ tokens = self.sp_model.encode(text, out_type=str)
391
+ if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
392
+ return tokens
393
+
394
+ # 1. Encode string + prefix ex: "<unk> Hey"
395
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
396
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
397
+ return tokens[self.unk_token_length:] if len(tokens) >= self.unk_token_length else tokens
398
+
399
+ def _convert_token_to_id(self, token):
400
+ """Converts a token (str) in an id using the vocab."""
401
+ return self.sp_model.piece_to_id(token)
402
+
403
+ def _convert_id_to_token(self, index):
404
+ """Converts an index (integer) in a token (str) using the vocab."""
405
+ token = self.sp_model.IdToPiece(index)
406
+ return token
407
+
408
+ def convert_tokens_to_string(self, tokens):
409
+ """Converts a sequence of tokens (string) in a single string."""
410
+ # since we manually add the prefix space, we have to remove it when decoding
411
+ if tokens[0].startswith(SPIECE_UNDERLINE):
412
+ tokens[0] = tokens[0][1:]
413
+
414
+ current_sub_tokens = []
415
+ out_string = ""
416
+ prev_is_special = False
417
+ for i, token in enumerate(tokens):
418
+ # make sure that special tokens are not decoded using sentencepiece model
419
+ if token in self.all_special_tokens:
420
+ if not prev_is_special and i != 0 and self.legacy:
421
+ out_string += " "
422
+ out_string += self.sp_model.decode(current_sub_tokens) + token
423
+ prev_is_special = True
424
+ current_sub_tokens = []
425
+ else:
426
+ current_sub_tokens.append(token)
427
+ prev_is_special = False
428
+ out_string += self.sp_model.decode(current_sub_tokens)
429
+ return out_string
430
+
431
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
432
+ """
433
+ Save the vocabulary and special tokens file to a directory.
434
+
435
+ Args:
436
+ save_directory (`str`):
437
+ The directory in which to save the vocabulary.
438
+
439
+ Returns:
440
+ `Tuple(str)`: Paths to the files saved.
441
+ """
442
+ if not os.path.isdir(save_directory):
443
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
444
+ return
445
+ out_vocab_file = os.path.join(
446
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
447
+ )
448
+
449
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
450
+ copyfile(self.vocab_file, out_vocab_file)
451
+ elif not os.path.isfile(self.vocab_file):
452
+ with open(out_vocab_file, "wb") as fi:
453
+ content_spiece_model = self.sp_model.serialized_model_proto()
454
+ fi.write(content_spiece_model)
455
+
456
+ return (out_vocab_file,)
457
+
458
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
459
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
460
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
461
+
462
+ output = bos_token_id + token_ids_0 + eos_token_id
463
+
464
+ if token_ids_1 is not None:
465
+ output = output + bos_token_id + token_ids_1 + eos_token_id
466
+
467
+ return output
468
+
469
+ def get_special_tokens_mask(
470
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None,
471
+ already_has_special_tokens: bool = False
472
+ ) -> List[int]:
473
+ """
474
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
475
+ special tokens using the tokenizer `prepare_for_model` method.
476
+
477
+ Args:
478
+ token_ids_0 (`List[int]`):
479
+ List of IDs.
480
+ token_ids_1 (`List[int]`, *optional*):
481
+ Optional second list of IDs for sequence pairs.
482
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
483
+ Whether or not the token list is already formatted with special tokens for the model.
484
+
485
+ Returns:
486
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
487
+ """
488
+ if already_has_special_tokens:
489
+ return super().get_special_tokens_mask(
490
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
491
+ )
492
+
493
+ bos_token_id = [1] if self.add_bos_token else []
494
+ eos_token_id = [1] if self.add_eos_token else []
495
+
496
+ if token_ids_1 is None:
497
+ return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
498
+ return (
499
+ bos_token_id
500
+ + ([0] * len(token_ids_0))
501
+ + eos_token_id
502
+ + bos_token_id
503
+ + ([0] * len(token_ids_1))
504
+ + eos_token_id
505
+ )
506
+
507
+ def create_token_type_ids_from_sequences(
508
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
509
+ ) -> List[int]:
510
+ """
511
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
512
+ sequence pair mask has the following format:
513
+
514
+ ```
515
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
516
+ | first sequence | second sequence |
517
+ ```
518
+
519
+ if token_ids_1 is None, only returns the first portion of the mask (0s).
520
+
521
+ Args:
522
+ token_ids_0 (`List[int]`):
523
+ List of ids.
524
+ token_ids_1 (`List[int]`, *optional*):
525
+ Optional second list of IDs for sequence pairs.
526
+
527
+ Returns:
528
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
529
+ """
530
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
531
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
532
+
533
+ output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
534
+
535
+ if token_ids_1 is not None:
536
+ output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
537
+
538
+ return output
539
+
540
+
541
+ class VisImage:
542
+ def __init__(self, img, scale=1.0):
543
+ self.img = img
544
+ self.scale = scale
545
+ self.width, self.height = img.shape[1], img.shape[0]
546
+ self._setup_figure(img)
547
+
548
+ def _setup_figure(self, img):
549
+ fig = mplfigure.Figure(frameon=False)
550
+ self.dpi = fig.get_dpi()
551
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
552
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
553
+ fig.set_size_inches(
554
+ (self.width * self.scale + 1e-2) / self.dpi,
555
+ (self.height * self.scale + 1e-2) / self.dpi,
556
+ )
557
+ self.canvas = FigureCanvasAgg(fig)
558
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
559
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
560
+ ax.axis("off")
561
+ self.fig = fig
562
+ self.ax = ax
563
+ self.reset_image(img)
564
+
565
+ def reset_image(self, img):
566
+ img = img.astype("uint8")
567
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
568
+
569
+ def save(self, filepath):
570
+ self.fig.savefig(filepath)
571
+
572
+ def get_image(self):
573
+ canvas = self.canvas
574
+ s, (width, height) = canvas.print_to_buffer()
575
+
576
+ buffer = np.frombuffer(s, dtype="uint8")
577
+
578
+ img_rgba = buffer.reshape(height, width, 4)
579
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
580
+ return rgb.astype("uint8")
581
+
582
+
583
+ class Visualizer:
584
+ def __init__(self, img_rgb, metadata=None, scale=1.0):
585
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
586
+ self.output = VisImage(self.img, scale=scale)
587
+ self.cpu_device = torch.device("cpu")
588
+
589
+ # too small texts are useless, therefore clamp to 14
590
+ self._default_font_size = max(
591
+ np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
592
+ )
593
+
594
+ def draw_text(
595
+ self,
596
+ text,
597
+ position,
598
+ *,
599
+ font_size=None,
600
+ color="g",
601
+ horizontal_alignment="center",
602
+ rotation=0,
603
+ ):
604
+ if not font_size:
605
+ font_size = self._default_font_size
606
+
607
+ # since the text background is dark, we don't want the text to be dark
608
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
609
+ color[np.argmax(color)] = max(0.8, np.max(color))
610
+
611
+ x, y = position
612
+ self.output.ax.text(
613
+ x,
614
+ y,
615
+ text,
616
+ size=font_size * self.output.scale,
617
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
618
+ verticalalignment="top",
619
+ horizontalalignment=horizontal_alignment,
620
+ color=color,
621
+ zorder=10,
622
+ rotation=rotation,
623
+ )
624
+ return self.output
625
+
626
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
627
+ x0, y0, x1, y1 = box_coord
628
+ width = x1 - x0
629
+ height = y1 - y0
630
+
631
+ linewidth = max(self._default_font_size / 4, 1)
632
+
633
+ self.output.ax.add_patch(
634
+ mpl.patches.Rectangle(
635
+ (x0, y0),
636
+ width,
637
+ height,
638
+ fill=False,
639
+ edgecolor=edge_color,
640
+ linewidth=linewidth * self.output.scale,
641
+ alpha=alpha,
642
+ linestyle=line_style,
643
+ )
644
+ )
645
+ return self.output
646
+
647
+ def get_output(self):
648
+ return self.output