qianyuchen commited on
Commit
88d11f4
1 Parent(s): 45387f9

Create modeling_minicpmv.py

Browse files

修改get_vision_embedding 使模型可以适应zero3的finetuning

Files changed (1) hide show
  1. modeling_minicpmv.py +130 -232
modeling_minicpmv.py CHANGED
@@ -1,22 +1,21 @@
1
  import math
2
  from typing import List, Optional
3
  import json
 
4
  import torch
5
  import torchvision
6
- from threading import Thread
7
- from copy import deepcopy
8
  from PIL import Image
 
9
  from torchvision import transforms
10
- from transformers import LlamaTokenizer, LlamaPreTrainedModel, LlamaForCausalLM, AutoModel, PreTrainedTokenizerFast, TextIteratorStreamer
11
- from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
12
-
13
  from .configuration_minicpm import MiniCPMVConfig
 
14
  from .resampler import Resampler
15
 
16
- IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
17
- IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
18
 
19
- class MiniCPMVPreTrainedModel(LlamaPreTrainedModel):
20
  config_class = MiniCPMVConfig
21
 
22
 
@@ -24,7 +23,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
24
  def __init__(self, config):
25
  super().__init__(config)
26
 
27
- self.llm = LlamaForCausalLM(config)
28
  self.vpm = self.init_vision_module()
29
  self.vision_dim = self.vpm.embed_dim
30
  self.embed_dim = self.llm.config.hidden_size
@@ -32,19 +31,26 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
32
  self.transform = self.init_transform()
33
 
34
  def init_vision_module(self):
35
- # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
36
- model = Idefics2VisionTransformer(self.config.vision_config)
37
- if self.config.drop_vision_last_layer:
38
- model.encoder.layers = model.encoder.layers[:-1]
 
 
 
39
 
40
- setattr(model, 'embed_dim', model.embeddings.embed_dim)
41
- setattr(model, 'patch_size', model.embeddings.patch_size)
 
 
 
 
42
 
43
  return model
44
 
45
  def init_resampler(self, embed_dim, vision_dim):
46
  return Resampler(
47
- num_queries=self.config.query_num,
48
  embed_dim=embed_dim,
49
  num_heads=embed_dim // 128,
50
  kv_dim=vision_dim,
@@ -67,94 +73,75 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
67
  def set_input_embeddings(self, value):
68
  self.llm.embed_tokens = value
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def get_vllm_embedding(self, data):
71
- if 'vision_hidden_states' not in data:
72
- dtype = self.vpm.embeddings.position_embedding.weight.dtype
73
- device = self.vpm.embeddings.position_embedding.weight.device
74
- tgt_sizes = data['tgt_sizes']
75
- pixel_values_list = data['pixel_values']
76
  vision_hidden_states = []
77
- all_pixel_values = []
78
- img_cnt = []
79
  for pixel_values in pixel_values_list:
80
- img_cnt.append(len(pixel_values))
81
- all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
82
-
83
- # exist image
84
- if all_pixel_values:
85
- tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
86
-
87
- if self.config.batch_vision_input:
88
- max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
89
-
90
- all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
91
- padding_value=0.0)
92
- B, L, _ = all_pixel_values.shape
93
- all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
94
-
95
- patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
96
- for i in range(B):
97
- patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
98
-
99
- vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state
100
- vision_embedding = self.resampler(vision_embedding, tgt_sizes)
101
- else:
102
- # get vision_embedding foreach
103
- vision_embedding = []
104
- for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
105
- single_pixel_values = single_pixel_values.unsqueeze(0)
106
- B, L, _ = single_pixel_values.shape
107
- single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
108
- single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
109
- single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
110
- vision_embedding.append(single_vision_embedding)
111
- vision_embedding = torch.vstack(vision_embedding)
112
-
113
- start = 0
114
- for pixel_values in pixel_values_list:
115
- img_cnt = len(pixel_values)
116
- if img_cnt > 0:
117
- vision_hidden_states.append(vision_embedding[start: start + img_cnt])
118
- start += img_cnt
119
- else:
120
- vision_hidden_states.append([])
121
- else: # no image
122
- if self.training:
123
  dummy_image = torch.zeros(
124
- (1, 3, 224, 224),
125
- device=device, dtype=dtype
126
  )
127
- tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
128
- dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
129
  else:
130
- dummy_feature = []
131
- for _ in range(len(pixel_values_list)):
132
- vision_hidden_states.append(dummy_feature)
133
-
134
- else:
135
- vision_hidden_states = data['vision_hidden_states']
136
 
137
- if hasattr(self.llm.config, 'scale_emb'):
138
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
139
  else:
140
- vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
141
 
142
- vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
143
- i, torch.Tensor) else i for i in vision_hidden_states]
 
 
 
 
 
144
 
145
- bs = len(data['input_ids'])
146
  for i in range(bs):
147
  cur_vs_hs = vision_hidden_states[i]
148
  if len(cur_vs_hs) > 0:
149
  cur_vllm_emb = vllm_embedding[i]
150
- cur_image_bound = data['image_bound'][i]
151
  if len(cur_image_bound) > 0:
152
  image_indices = torch.stack(
153
- [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
 
 
 
154
  ).to(vllm_embedding.device)
155
 
156
- cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
157
- cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
 
 
 
158
  elif self.training:
159
  cur_vllm_emb += cur_vs_hs[0].mean() * 0
160
 
@@ -174,8 +161,12 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
174
  )
175
 
176
  def _convert_to_tensors(
177
- self, tokenizer, input_ids, max_inp_length: Optional[int] = None
178
  ):
 
 
 
 
179
  if max_inp_length is not None:
180
  input_ids = input_ids[:max_inp_length]
181
  input_ids = torch.tensor(input_ids, dtype=torch.int32)
@@ -199,13 +190,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
199
  return model_input
200
 
201
  def _process_list(
202
- self, tokenizer, input_id_list, max_inp_length: Optional[int] = None
203
  ):
204
  pad_keys = ["input_ids"]
205
  input_tensors = []
206
- for input_ids in input_id_list:
207
  input_tensors.append(
208
- self._convert_to_tensors(tokenizer, input_ids, max_inp_length)
209
  )
210
  padded = {}
211
  for key in pad_keys:
@@ -214,36 +205,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
214
  return padded
215
 
216
  def _decode(self, inputs_embeds, tokenizer, **kwargs):
217
- terminators = [
218
- tokenizer.eos_token_id,
219
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
220
- ]
221
  output = self.llm.generate(
222
  inputs_embeds=inputs_embeds,
223
  pad_token_id=0,
224
- eos_token_id=terminators,
225
  **kwargs
226
  )
227
  return self._decode_text(output, tokenizer)
228
-
229
- def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
230
- terminators = [
231
- tokenizer.eos_token_id,
232
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
233
- ]
234
- streamer = TextIteratorStreamer(tokenizer=tokenizer)
235
- generation_kwargs = {
236
- 'inputs_embeds': inputs_embeds,
237
- 'pad_token_id': 0,
238
- 'eos_token_id': terminators,
239
- 'streamer': streamer
240
- }
241
- generation_kwargs.update(kwargs)
242
-
243
- thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
244
- thread.start()
245
-
246
- return streamer
247
 
248
  def _decode_text(self, result_ids, tokenizer):
249
  result_text = []
@@ -251,7 +219,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
251
  result = result[result != 0]
252
  if result[0] == tokenizer.bos_id:
253
  result = result[1:]
254
- if result[-1] == tokenizer.eos_id or result[-1] == tokenizer.eot_id:
255
  result = result[:-1]
256
  result_text.append(tokenizer.decode(result).strip())
257
  return result_text
@@ -259,9 +227,9 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
259
  def slice_image(self, image):
260
  return slice_image(
261
  image,
262
- self.config.slice_config.max_slice_nums,
263
- self.config.slice_config.scale_resolution,
264
- self.config.slice_config.patch_size,
265
  )
266
 
267
  def get_slice_image_placeholder(self, image, tokenizer):
@@ -275,9 +243,9 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
275
 
276
  source_image, patches, best_grid = slice_image(
277
  image,
278
- self.config.slice_config.max_slice_nums,
279
- self.config.slice_config.scale_resolution,
280
- self.config.slice_config.patch_size,
281
  )
282
 
283
  slice_images.append(source_image)
@@ -294,56 +262,36 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
294
 
295
  return slice_images, final_placeholder
296
 
297
- def reshape_by_patch(self, image_tensor):
298
- """
299
- :param image_tensor: shape [3, H, W]
300
- :param patch_size:
301
- :return: [3, patch_size, HW/patch_size]
302
- """
303
- patch_size = self.config.patch_size
304
- patches = torch.nn.functional.unfold(
305
- image_tensor,
306
- (patch_size, patch_size),
307
- stride=(patch_size, patch_size)
308
- )
309
-
310
- patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
311
- patches = patches.permute(0, 1, 3, 2).reshape(image_tensor.size(0), patch_size, -1)
312
- return patches
313
-
314
  def generate(
315
  self,
316
- input_id_list=None,
317
  img_list=None,
318
- tgt_sizes=None,
319
  tokenizer=None,
320
  max_inp_length: Optional[int] = None,
321
  vision_hidden_states=None,
322
  return_vision_hidden_states=False,
323
- stream=False,
324
  **kwargs
325
  ):
326
 
327
- assert input_id_list is not None
328
- bs = len(input_id_list)
329
  if img_list == None:
330
  img_list = [[] for i in range(bs)]
331
  assert bs == len(img_list)
332
 
333
- model_inputs = self._process_list(tokenizer, input_id_list, max_inp_length)
334
 
335
  if vision_hidden_states is None:
336
  pixel_values = []
337
  for i in range(bs):
338
  img_inps = []
339
  for img in img_list[i]:
340
- img_inps.append(img.to(self.device))
341
  if img_inps:
342
  pixel_values.append(img_inps)
343
  else:
344
  pixel_values.append([])
345
  model_inputs["pixel_values"] = pixel_values
346
- model_inputs['tgt_sizes'] = tgt_sizes
347
  else:
348
  model_inputs["vision_hidden_states"] = vision_hidden_states
349
 
@@ -353,10 +301,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
353
  vision_hidden_states,
354
  ) = self.get_vllm_embedding(model_inputs)
355
 
356
- if stream:
357
- result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
358
- else:
359
- result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
360
 
361
  if return_vision_hidden_states:
362
  return result, vision_hidden_states
@@ -367,70 +312,42 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
367
  self,
368
  image,
369
  msgs,
 
370
  tokenizer,
371
  vision_hidden_states=None,
372
  max_new_tokens=1024,
373
  sampling=True,
374
  max_inp_length=2048,
375
- system_prompt='',
376
- stream=False,
377
  **kwargs
378
  ):
379
  if isinstance(msgs, str):
380
  msgs = json.loads(msgs)
381
-
382
- copy_msgs = deepcopy(msgs)
383
- assert len(copy_msgs) > 0, 'msgs is empty'
384
- assert sampling or not stream, 'if use stream mode, make sure sampling=True'
385
-
386
- if image is not None and isinstance(copy_msgs[0]['content'], str):
387
- copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
388
-
389
- images = []
390
- tgt_sizes = []
391
- for i, msg in enumerate(copy_msgs):
392
  role = msg["role"]
393
  content = msg["content"]
394
  assert role in ["user", "assistant"]
395
  if i == 0:
396
  assert role == "user", "The role of first msg should be user"
397
- if isinstance(content, str):
398
- content = [content]
399
-
400
- cur_msgs = []
401
- for c in content:
402
- if isinstance(c, Image.Image):
403
- image = c
404
- if self.config.slice_mode:
405
- slice_images, image_placeholder = self.get_slice_image_placeholder(
406
- image, tokenizer
407
- )
408
- cur_msgs.append(image_placeholder)
409
- for slice_image in slice_images:
410
- slice_image = self.transform(slice_image)
411
- H, W = slice_image.shape[1:]
412
- images.append(self.reshape_by_patch(slice_image))
413
- tgt_sizes.append(torch.Tensor([H // self.config.patch_size, W // self.config.patch_size]).type(torch.int32))
414
- else:
415
- images.append(self.transform(image))
416
- cur_msgs.append(
417
- tokenizer.im_start
418
- + tokenizer.unk_token * self.config.query_num
419
- + tokenizer.im_end
420
- )
421
- elif isinstance(c, str):
422
- cur_msgs.append(c)
423
-
424
-
425
- msg['content'] = '\n'.join(cur_msgs)
426
- if tgt_sizes:
427
- tgt_sizes = torch.vstack(tgt_sizes)
428
-
429
- if system_prompt:
430
- sys_msg = {'role': 'system', 'content': system_prompt}
431
- copy_msgs = [sys_msg] + copy_msgs
432
-
433
- input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
434
 
435
  if sampling:
436
  generation_config = {
@@ -452,34 +369,25 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
452
 
453
  with torch.inference_mode():
454
  res, vision_hidden_states = self.generate(
455
- input_id_list=[input_ids],
456
  max_inp_length=max_inp_length,
457
  img_list=[images],
458
- tgt_sizes=[tgt_sizes],
459
  tokenizer=tokenizer,
460
  max_new_tokens=max_new_tokens,
461
  vision_hidden_states=vision_hidden_states,
462
  return_vision_hidden_states=True,
463
- stream=stream,
464
  **generation_config
465
  )
 
 
 
466
 
467
- if stream:
468
- def stream_gen():
469
- for text in res:
470
- text = text.replace(tokenizer.eot_token, '').replace(tokenizer.eos_token, '')
471
- yield text
472
- return stream_gen()
473
-
474
- else:
475
- answer = res[0]
476
- return answer
477
 
478
 
479
- class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
480
  def __init__(self, **kwargs):
481
  super().__init__(**kwargs)
482
- self.eot_token = "<|eot_id|>"
483
  self.im_start = "<image>"
484
  self.im_end = "</image>"
485
  self.ref_start = "<ref>"
@@ -488,40 +396,30 @@ class PreTrainedTokenizerFastWrapper(PreTrainedTokenizerFast):
488
  self.box_end = "</box>"
489
  self.quad_start = "<quad>"
490
  self.quad_end = "</quad>"
 
 
491
  self.slice_start = "<slice>"
492
  self.slice_end = "</slice>"
493
 
494
  @property
495
  def eos_id(self):
496
- return self.eos_token_id
497
 
498
  @property
499
  def bos_id(self):
500
- return self.bos_token_id
501
 
502
  @property
503
  def unk_id(self):
504
- return self.unk_token_id
505
-
506
- @property
507
- def eot_id(self):
508
- return self.convert_tokens_to_ids(self.eot_token)
509
 
510
  @property
511
  def im_start_id(self):
512
- return self.convert_tokens_to_ids(self.im_start)
513
 
514
  @property
515
  def im_end_id(self):
516
- return self.convert_tokens_to_ids(self.im_end)
517
-
518
- @staticmethod
519
- def escape(text: str) -> str:
520
- return text
521
-
522
- @staticmethod
523
- def unescape(text: str) -> str:
524
- return text
525
 
526
 
527
  def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):
 
1
  import math
2
  from typing import List, Optional
3
  import json
4
+ import timm
5
  import torch
6
  import torchvision
7
+ import deepspeed
 
8
  from PIL import Image
9
+ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
10
  from torchvision import transforms
11
+ from transformers import LlamaTokenizer
12
+ from transformers.integrations import is_deepspeed_zero3_enabled
 
13
  from .configuration_minicpm import MiniCPMVConfig
14
+ from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
15
  from .resampler import Resampler
16
 
 
 
17
 
18
+ class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
19
  config_class = MiniCPMVConfig
20
 
21
 
 
23
  def __init__(self, config):
24
  super().__init__(config)
25
 
26
+ self.llm = MiniCPMForCausalLM(config)
27
  self.vpm = self.init_vision_module()
28
  self.vision_dim = self.vpm.embed_dim
29
  self.embed_dim = self.llm.config.hidden_size
 
31
  self.transform = self.init_transform()
32
 
33
  def init_vision_module(self):
34
+ model = timm.create_model(
35
+ self.config.vision_encoder,
36
+ pretrained=False,
37
+ num_classes=0,
38
+ dynamic_img_size=True,
39
+ dynamic_img_pad=True
40
+ )
41
 
42
+ if isinstance(model, timm.models.VisionTransformer):
43
+ if model.attn_pool is not None:
44
+ model.attn_pool = torch.nn.Identity()
45
+
46
+ if self.config.drop_vision_last_layer:
47
+ model.blocks = model.blocks[:-1]
48
 
49
  return model
50
 
51
  def init_resampler(self, embed_dim, vision_dim):
52
  return Resampler(
53
+ grid_size=int(math.sqrt(self.config.query_num)),
54
  embed_dim=embed_dim,
55
  num_heads=embed_dim // 128,
56
  kv_dim=vision_dim,
 
73
  def set_input_embeddings(self, value):
74
  self.llm.embed_tokens = value
75
 
76
+ def get_vision_embedding(self, pixel_values):
77
+ res = []
78
+ dtype = self.llm.lm_head.weight.dtype
79
+ def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
80
+ H, W = pixel_value.shape[-2:]
81
+ target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
82
+ vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
83
+ if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
84
+ vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
85
+ return resampler(vision_embedding, target_size)
86
+
87
+ if is_deepspeed_zero3_enabled():
88
+ with deepspeed.zero.GatheredParameters(self.vpm.pos_embed, modifier_rank=0):
89
+ for pixel_value in pixel_values:
90
+ result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
91
+ res.append(result)
92
+ else:
93
+ for pixel_value in pixel_values:
94
+ result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
95
+ res.append(result)
96
+ return torch.vstack(res)
97
+
98
  def get_vllm_embedding(self, data):
99
+ if "vision_hidden_states" not in data:
100
+ pixel_values_list = data["pixel_values"]
 
 
 
101
  vision_hidden_states = []
 
 
102
  for pixel_values in pixel_values_list:
103
+ if len(pixel_values) > 0:
104
+ vision_hidden_states.append(self.get_vision_embedding(pixel_values))
105
+ elif self.training:
106
+ dtype = self.llm.lm_head.weight.dtype
107
+ device = self.llm.lm_head.weight.device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  dummy_image = torch.zeros(
109
+ (1, 3, 224, 224), device=device, dtype=dtype
 
110
  )
111
+ vision_hidden_states.append(self.get_vision_embedding(dummy_image))
 
112
  else:
113
+ vision_hidden_states.append([])
 
 
 
 
 
114
 
 
 
115
  else:
116
+ vision_hidden_states = data["vision_hidden_states"]
117
 
118
+ vllm_embedding = (
119
+ self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
120
+ )
121
+ vision_hidden_states = [
122
+ i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i
123
+ for i in vision_hidden_states
124
+ ]
125
 
126
+ bs = len(data["input_ids"])
127
  for i in range(bs):
128
  cur_vs_hs = vision_hidden_states[i]
129
  if len(cur_vs_hs) > 0:
130
  cur_vllm_emb = vllm_embedding[i]
131
+ cur_image_bound = data["image_bound"][i]
132
  if len(cur_image_bound) > 0:
133
  image_indices = torch.stack(
134
+ [
135
+ torch.arange(r[0], r[1], dtype=torch.long)
136
+ for r in cur_image_bound
137
+ ]
138
  ).to(vllm_embedding.device)
139
 
140
+ cur_vllm_emb.scatter_(
141
+ 0,
142
+ image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
143
+ cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
144
+ )
145
  elif self.training:
146
  cur_vllm_emb += cur_vs_hs[0].mean() * 0
147
 
 
161
  )
162
 
163
  def _convert_to_tensors(
164
+ self, tokenizer, input_str, max_inp_length: Optional[int] = None
165
  ):
166
+ if tokenizer.add_bos_token:
167
+ input_ids = tokenizer.encode(input_str)
168
+ else:
169
+ input_ids = [tokenizer.bos_id] + tokenizer.encode(input_str)
170
  if max_inp_length is not None:
171
  input_ids = input_ids[:max_inp_length]
172
  input_ids = torch.tensor(input_ids, dtype=torch.int32)
 
190
  return model_input
191
 
192
  def _process_list(
193
+ self, tokenizer, data_list: List[str], max_inp_length: Optional[int] = None
194
  ):
195
  pad_keys = ["input_ids"]
196
  input_tensors = []
197
+ for data in data_list:
198
  input_tensors.append(
199
+ self._convert_to_tensors(tokenizer, data, max_inp_length)
200
  )
201
  padded = {}
202
  for key in pad_keys:
 
205
  return padded
206
 
207
  def _decode(self, inputs_embeds, tokenizer, **kwargs):
 
 
 
 
208
  output = self.llm.generate(
209
  inputs_embeds=inputs_embeds,
210
  pad_token_id=0,
211
+ eos_token_id=tokenizer.eos_token_id,
212
  **kwargs
213
  )
214
  return self._decode_text(output, tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def _decode_text(self, result_ids, tokenizer):
217
  result_text = []
 
219
  result = result[result != 0]
220
  if result[0] == tokenizer.bos_id:
221
  result = result[1:]
222
+ if result[-1] == tokenizer.eos_id:
223
  result = result[:-1]
224
  result_text.append(tokenizer.decode(result).strip())
225
  return result_text
 
227
  def slice_image(self, image):
228
  return slice_image(
229
  image,
230
+ self.config.max_slice_nums,
231
+ self.config.scale_resolution,
232
+ self.config.patch_size,
233
  )
234
 
235
  def get_slice_image_placeholder(self, image, tokenizer):
 
243
 
244
  source_image, patches, best_grid = slice_image(
245
  image,
246
+ self.config.max_slice_nums,
247
+ self.config.scale_resolution,
248
+ self.config.patch_size,
249
  )
250
 
251
  slice_images.append(source_image)
 
262
 
263
  return slice_images, final_placeholder
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  def generate(
266
  self,
267
+ data_list=None,
268
  img_list=None,
 
269
  tokenizer=None,
270
  max_inp_length: Optional[int] = None,
271
  vision_hidden_states=None,
272
  return_vision_hidden_states=False,
 
273
  **kwargs
274
  ):
275
 
276
+ assert data_list is not None
277
+ bs = len(data_list)
278
  if img_list == None:
279
  img_list = [[] for i in range(bs)]
280
  assert bs == len(img_list)
281
 
282
+ model_inputs = self._process_list(tokenizer, data_list, max_inp_length)
283
 
284
  if vision_hidden_states is None:
285
  pixel_values = []
286
  for i in range(bs):
287
  img_inps = []
288
  for img in img_list[i]:
289
+ img_inps.append(self.transform(img).to(self.device))
290
  if img_inps:
291
  pixel_values.append(img_inps)
292
  else:
293
  pixel_values.append([])
294
  model_inputs["pixel_values"] = pixel_values
 
295
  else:
296
  model_inputs["vision_hidden_states"] = vision_hidden_states
297
 
 
301
  vision_hidden_states,
302
  ) = self.get_vllm_embedding(model_inputs)
303
 
304
+ result = self._decode(model_inputs["inputs_embeds"], tokenizer, **kwargs)
 
 
 
305
 
306
  if return_vision_hidden_states:
307
  return result, vision_hidden_states
 
312
  self,
313
  image,
314
  msgs,
315
+ context,
316
  tokenizer,
317
  vision_hidden_states=None,
318
  max_new_tokens=1024,
319
  sampling=True,
320
  max_inp_length=2048,
 
 
321
  **kwargs
322
  ):
323
  if isinstance(msgs, str):
324
  msgs = json.loads(msgs)
325
+ # msgs to prompt
326
+ prompt = ""
327
+ for i, msg in enumerate(msgs):
 
 
 
 
 
 
 
 
328
  role = msg["role"]
329
  content = msg["content"]
330
  assert role in ["user", "assistant"]
331
  if i == 0:
332
  assert role == "user", "The role of first msg should be user"
333
+ if self.config.slice_mode:
334
+ images, final_placeholder = self.get_slice_image_placeholder(
335
+ image, tokenizer
336
+ )
337
+ content = final_placeholder + "\n" + content
338
+ else:
339
+ images = [image]
340
+ content = (
341
+ tokenizer.im_start
342
+ + tokenizer.unk_token * self.config.query_num
343
+ + tokenizer.im_end
344
+ + "\n"
345
+ + content
346
+ )
347
+ prompt += "<用户>" if role == "user" else "<AI>"
348
+ prompt += content
349
+ prompt += "<AI>"
350
+ final_input = prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
  if sampling:
353
  generation_config = {
 
369
 
370
  with torch.inference_mode():
371
  res, vision_hidden_states = self.generate(
372
+ data_list=[final_input],
373
  max_inp_length=max_inp_length,
374
  img_list=[images],
 
375
  tokenizer=tokenizer,
376
  max_new_tokens=max_new_tokens,
377
  vision_hidden_states=vision_hidden_states,
378
  return_vision_hidden_states=True,
 
379
  **generation_config
380
  )
381
+ answer = res[0]
382
+ context = msgs.copy()
383
+ context.append({"role": "assistant", "content": answer})
384
 
385
+ return answer, context, generation_config
 
 
 
 
 
 
 
 
 
386
 
387
 
388
+ class LlamaTokenizerWrapper(LlamaTokenizer):
389
  def __init__(self, **kwargs):
390
  super().__init__(**kwargs)
 
391
  self.im_start = "<image>"
392
  self.im_end = "</image>"
393
  self.ref_start = "<ref>"
 
396
  self.box_end = "</box>"
397
  self.quad_start = "<quad>"
398
  self.quad_end = "</quad>"
399
+ self.point_start = "<point>"
400
+ self.point_end = "</point>"
401
  self.slice_start = "<slice>"
402
  self.slice_end = "</slice>"
403
 
404
  @property
405
  def eos_id(self):
406
+ return self.sp_model.eos_id()
407
 
408
  @property
409
  def bos_id(self):
410
+ return self.sp_model.bos_id()
411
 
412
  @property
413
  def unk_id(self):
414
+ return self.sp_model.unk_id()
 
 
 
 
415
 
416
  @property
417
  def im_start_id(self):
418
+ return self._convert_token_to_id(self.im_start)
419
 
420
  @property
421
  def im_end_id(self):
422
+ return self._convert_token_to_id(self.im_end)
 
 
 
 
 
 
 
 
423
 
424
 
425
  def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"):