YaTharThShaRma999 committed
Commit b1bbf8b
1 Parent(s): 1f583fd

Create utilf.py

Files changed (1)
  1. utilf.py +373 -0
utilf.py ADDED
@@ -0,0 +1,373 @@
from dataclasses import dataclass
from typing import Dict, List

import torch
from PIL.Image import Image
from transformers import LlamaTokenizerFast
from transformers.processing_utils import ProcessorMixin

from deepseek_vl.models.image_processing_vlm import VLMImageProcessor
from deepseek_vl.utils.conversation import get_conv_template


class DictOutput(object):
    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self


class VLChatProcessor(ProcessorMixin):
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = (
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    )

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100,
        system: str = (
            "You are a helpful language and vision assistant. "
            "You are able to understand the visual content that the user provides, "
            "and assist the user with a variety of tasks using natural language."
        ),
        **kwargs,
    ):
        self.system_prompt = system
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        # register the image placeholder as a special token if the tokenizer does not know it yet
        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(
            image_processor,
            tokenizer,
            image_tag,
            num_image_tokens,
            add_special_token,
            sft_format,
            mask_prompt,
            ignore_id,
            **kwargs,
        )

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = "",
    ):
        """
        Applies the SFT template to the conversation.

        An example of conversation:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                "images": [
                    "./multi-images/attribute_comparison_1.png",
                    "./multi-images/attribute_comparison_2.png"
                ]
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        Args:
            conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
            sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted text.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def pad_id(self):
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start:end])

            # add image tokens, and set the mask as False
            input_slices.append(
                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
            )
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert (
            prompt is None or conversations is None
        ), "prompt and conversations cannot be used at the same time."

        if prompt is None:
            # apply sft format
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt,
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens,
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs,
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(
            prompt=prompt, conversations=conversations, images=images
        )

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(
        self, prepare_list: List[VLChatProcessorOutput]
    ) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full(
            (batch_size, input_token_max_len), self.pad_id
        ).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros(
            (batch_size, max_n_images, *self.image_processor.default_shape)
        ).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros(
            (batch_size, max_n_images, self.num_image_tokens)
        ).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format,
        )

        return batched_prepares
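
Usage sketch: a minimal example of how this processor might be driven, assuming the file is importable as `utilf` and that a DeepSeek-VL checkpoint directory containing the tokenizer and the VLMImageProcessor config is available; the checkpoint path, image path, and variable names below are illustrative placeholders, not part of this file.

# Minimal usage sketch; paths below are placeholders.
import PIL.Image
from transformers import LlamaTokenizerFast

from deepseek_vl.models.image_processing_vlm import VLMImageProcessor
from utilf import VLChatProcessor

model_path = "path/to/deepseek-vl-checkpoint"  # illustrative path
tokenizer = LlamaTokenizerFast.from_pretrained(model_path)
image_processor = VLMImageProcessor.from_pretrained(model_path)
processor = VLChatProcessor(image_processor=image_processor, tokenizer=tokenizer)

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder> Describe this image.",
        "images": ["./example.png"],  # illustrative path
    },
    {"role": "Assistant", "content": ""},
]
pil_images = [PIL.Image.open(p).convert("RGB") for p in conversation[0]["images"]]

# __call__ is keyword-only; with force_batchify=True (the default) it returns a
# BatchedVLChatProcessorOutput, which can be moved to the model's device/dtype via .to().
inputs = processor(conversations=conversation, images=pil_images)
print(inputs.input_ids.shape, inputs.pixel_values.shape)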