Update modeling_minicpmv.py

#17
Files changed (1)
modeling_minicpmv.py  +43 -29
@@ -1,17 +1,20 @@
 import math
-from typing import List, Optional
 import json
 import timm
 import torch
 import torchvision
+import deepspeed
 from PIL import Image
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from torchvision import transforms
 from transformers import LlamaTokenizer
-
+from transformers.integrations import is_deepspeed_zero3_enabled
 from .configuration_minicpm import MiniCPMVConfig
 from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
 from .resampler import Resampler
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from peft.utils.other import ModulesToSaveWrapper


 class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
@@ -72,17 +75,29 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
     def set_input_embeddings(self, value):
         self.llm.embed_tokens = value

+    def vpm_forward_features(self, pixel_value):
+        if isinstance(self.vpm, ModulesToSaveWrapper):
+            if self.vpm.disable_adapters or (self.vpm.active_adapter not in self.vpm.modules_to_save):
+                return self.vpm.original_module.forward_features(pixel_value)
+            return self.vpm.modules_to_save[self.vpm.active_adapter].forward_features(pixel_value)
+        else:
+            return self.vpm.forward_features(pixel_value)
+
     def get_vision_embedding(self, pixel_values):
         res = []
-        dtype = self.vpm.pos_embed.data.dtype
-        for pixel_value in pixel_values:
+        dtype = self.llm.lm_head.weight.dtype
+        def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
             H, W = pixel_value.shape[-2:]
-            tgt_size = (
-                math.ceil(H / self.vpm.patch_embed.patch_size[0]), math.ceil(W / self.vpm.patch_embed.patch_size[0]))
-            vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
-            if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
-                vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
-            res.append(self.resampler(vision_embedding, tgt_size))
+            target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+            vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+
+            if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+                vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+            return resampler(vision_embedding, target_size)
+
+        for pixel_value in pixel_values:
+            result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+            res.append(result)
         return torch.vstack(res)

     def get_vllm_embedding(self, data):
@@ -93,8 +108,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 if len(pixel_values) > 0:
                     vision_hidden_states.append(self.get_vision_embedding(pixel_values))
                 elif self.training:
-                    dtype = self.vpm.pos_embed.data.dtype
-                    device = self.vpm.pos_embed.data.device
+                    dtype = self.llm.lm_head.weight.dtype
+                    device = self.llm.lm_head.weight.device
                     dummy_image = torch.zeros(
                         (1, 3, 224, 224), device=device, dtype=dtype
                     )
@@ -319,24 +334,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             content = msg["content"]
             assert role in ["user", "assistant"]
             if i == 0:
-                if image is None:
-                    images = []
+                assert role == "user", "The role of first msg should be user"
+                if self.config.slice_mode:
+                    images, final_placeholder = self.get_slice_image_placeholder(
+                        image, tokenizer
+                    )
+                    content = final_placeholder + "\n" + content
                 else:
-                    assert role == "user", "The role of first msg should be user"
-                    if self.config.slice_mode:
-                        images, final_placeholder = self.get_slice_image_placeholder(
-                            image, tokenizer
-                        )
-                        content = final_placeholder + "\n" + content
-                    else:
-                        images = [image]
-                        content = (
-                            tokenizer.im_start
-                            + tokenizer.unk_token * self.config.query_num
-                            + tokenizer.im_end
-                            + "\n"
-                            + content
-                        )
+                    images = [image]
+                    content = (
+                        tokenizer.im_start
+                        + tokenizer.unk_token * self.config.query_num
+                        + tokenizer.im_end
+                        + "\n"
+                        + content
+                    )
             prompt += "<用户>" if role == "user" else "<AI>"
             prompt += content
             prompt += "<AI>"
@@ -377,6 +389,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):

         return answer, context, generation_config

+
+

 class LlamaTokenizerWrapper(LlamaTokenizer):
     def __init__(self, **kwargs):
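
Note: the new `vpm_forward_features` helper is needed because PEFT replaces any module listed in `modules_to_save` with a `peft.utils.other.ModulesToSaveWrapper`, so `self.vpm` may no longer expose `forward_features` directly. The sketch below is not part of this PR; the checkpoint name and LoRA hyperparameters are illustrative assumptions, and it only shows how the vision tower ends up wrapped.

```python
# Minimal sketch (illustrative, not from this repository) of how self.vpm
# becomes a ModulesToSaveWrapper, the case the new vpm_forward_features()
# dispatch handles.
from transformers import AutoModel
from peft import LoraConfig, get_peft_model
from peft.utils.other import ModulesToSaveWrapper

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V", trust_remote_code=True  # illustrative checkpoint
)

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "v_proj"],    # LoRA adapters on the LLM attention (illustrative)
    modules_to_save=["vpm", "resampler"],   # train full copies of the vision tower and resampler
)
peft_model = get_peft_model(model, lora_config)

# PEFT wraps modules listed in modules_to_save, so forward_features() now
# lives on the wrapped copy rather than on self.vpm itself.
vpm = peft_model.base_model.model.vpm
print(isinstance(vpm, ModulesToSaveWrapper))  # True
```

For the same reason, dtype and device are now probed from `self.llm.lm_head.weight` instead of `self.vpm.pos_embed.data`: once the vision tower is wrapped by PEFT (or its parameters are partitioned under DeepSpeed ZeRO-3, consistent with the new `deepspeed` and `is_deepspeed_zero3_enabled` imports), `self.vpm.pos_embed` is likely no longer a reliable attribute to read, while the LLM head weight is.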