qianyuchen committed
Commit 832499a
1 Parent(s): 1878519

Update modeling_minicpmv.py


update model.py for zero3 and lora finetuning

Files changed (1)
  1. modeling_minicpmv.py +49 -29
modeling_minicpmv.py CHANGED
@@ -1,17 +1,20 @@
 import math
-from typing import List, Optional
 import json
 import timm
 import torch
 import torchvision
+import deepspeed
 from PIL import Image
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from torchvision import transforms
 from transformers import LlamaTokenizer
-
+from transformers.integrations import is_deepspeed_zero3_enabled
 from .configuration_minicpm import MiniCPMVConfig
 from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
 from .resampler import Resampler
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from peft.utils.other import ModulesToSaveWrapper
 
 
 class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
@@ -72,17 +75,35 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
     def set_input_embeddings(self, value):
         self.llm.embed_tokens = value
 
+    def vpm_forward_features(self, pixel_value):
+        if isinstance(self.vpm, ModulesToSaveWrapper):
+            if self.vpm.disable_adapters or (self.vpm.active_adapter not in self.vpm.modules_to_save):
+                return self.vpm.original_module.forward_features(pixel_value)
+            return self.vpm.modules_to_save[self.vpm.active_adapter].forward_features(pixel_value)
+        else:
+            return self.vpm.forward_features(pixel_value)
+
     def get_vision_embedding(self, pixel_values):
         res = []
-        dtype = self.vpm.pos_embed.data.dtype
-        for pixel_value in pixel_values:
+        dtype = self.llm.lm_head.weight.dtype
+        def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
             H, W = pixel_value.shape[-2:]
-            tgt_size = (
-                math.ceil(H / self.vpm.patch_embed.patch_size[0]), math.ceil(W / self.vpm.patch_embed.patch_size[0]))
-            vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype))
-            if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0:
-                vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:]
-            res.append(self.resampler(vision_embedding, tgt_size))
+            target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
+            vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
+            if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
+                vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
+            return resampler(vision_embedding, target_size)
+
+        if is_deepspeed_zero3_enabled():
+            with deepspeed.zero.GatheredParameters(self.vpm.pos_embed):
+                for pixel_value in pixel_values:
+                    result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+                    res.append(result)
+        else:
+            for pixel_value in pixel_values:
+                print(pixel_value.shape)
+                result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
+                res.append(result)
         return torch.vstack(res)
 
     def get_vllm_embedding(self, data):
@@ -93,8 +114,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             if len(pixel_values) > 0:
                 vision_hidden_states.append(self.get_vision_embedding(pixel_values))
             elif self.training:
-                dtype = self.vpm.pos_embed.data.dtype
-                device = self.vpm.pos_embed.data.device
+                dtype = self.llm.lm_head.weight.dtype
+                device = self.llm.lm_head.weight.device
                 dummy_image = torch.zeros(
                     (1, 3, 224, 224), device=device, dtype=dtype
                 )
@@ -319,24 +340,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             content = msg["content"]
            assert role in ["user", "assistant"]
            if i == 0:
-                if image is None:
-                    images = []
+                assert role == "user", "The role of first msg should be user"
+                if self.config.slice_mode:
+                    images, final_placeholder = self.get_slice_image_placeholder(
+                        image, tokenizer
+                    )
+                    content = final_placeholder + "\n" + content
                 else:
-                    assert role == "user", "The role of first msg should be user"
-                    if self.config.slice_mode:
-                        images, final_placeholder = self.get_slice_image_placeholder(
-                            image, tokenizer
-                        )
-                        content = final_placeholder + "\n" + content
-                    else:
-                        images = [image]
-                        content = (
-                            tokenizer.im_start
-                            + tokenizer.unk_token * self.config.query_num
-                            + tokenizer.im_end
-                            + "\n"
-                            + content
-                        )
+                    images = [image]
+                    content = (
+                        tokenizer.im_start
+                        + tokenizer.unk_token * self.config.query_num
+                        + tokenizer.im_end
+                        + "\n"
+                        + content
+                    )
             prompt += "<用户>" if role == "user" else "<AI>"
             prompt += content
         prompt += "<AI>"
@@ -377,6 +395,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         return answer, context, generation_config
 
+
+
 
 class LlamaTokenizerWrapper(LlamaTokenizer):
     def __init__(self, **kwargs):
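Background for the ZeRO-3 branch added to get_vision_embedding: under DeepSpeed ZeRO stage 3 every parameter is partitioned across data-parallel ranks, so outside of a DeepSpeed-managed forward a tensor such as self.vpm.pos_embed is only an empty placeholder on each rank, and deepspeed.zero.GatheredParameters temporarily materializes the full tensor. A minimal sketch of the same pattern on a standalone timm vision tower (the helper name is illustrative, not part of this file):

import torch
import deepspeed
from transformers.integrations import is_deepspeed_zero3_enabled

def pos_embed_grid(vpm: torch.nn.Module):
    """Read the full pos_embed shape, gathering the shards first under ZeRO-3."""
    if is_deepspeed_zero3_enabled():
        # Each rank stores only a shard (the local tensor has shape [0]);
        # the context manager all-gathers the parameter for the duration of
        # the block and re-partitions it afterwards.
        with deepspeed.zero.GatheredParameters(vpm.pos_embed):
            return tuple(vpm.pos_embed.shape)
    # Without ZeRO-3 the parameter is already fully materialized.
    return tuple(vpm.pos_embed.shape)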
 
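The new vpm_forward_features helper exists because of the LoRA half of this commit: when the vision tower is listed in peft's modules_to_save, peft replaces self.vpm with a ModulesToSaveWrapper that holds the frozen original module plus one trainable copy per adapter, so a plain self.vpm.forward_features(...) call no longer reaches the right weights. A rough sketch of how that wrapping arises; the LoraConfig values below (r, lora_alpha, target_modules) are illustrative assumptions, not the released finetuning recipe:

from peft import LoraConfig, get_peft_model
from peft.utils.other import ModulesToSaveWrapper

# `model` is assumed to be an already-loaded MiniCPMV instance.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],   # assumed: LoRA on the LLM attention projections
    modules_to_save=["vpm", "resampler"],  # assumed: train vision tower and resampler in full
)
peft_model = get_peft_model(model, lora_config)

# peft has swapped the plain timm module for a wrapper holding two copies:
vpm = peft_model.base_model.model.vpm
assert isinstance(vpm, ModulesToSaveWrapper)
# vpm.original_module              -> frozen original weights
# vpm.modules_to_save["default"]   -> trainable copy used while the adapter is active

vpm_forward_features simply dispatches forward_features to whichever copy is currently active, so the same inference code works both before and after get_peft_model.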
 
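For reference, the tightened i == 0 branch (hunk @@ -319,24 +340,21 @@) now requires the first message to be a user turn carrying an image. The sketch below spells out the prompt string the non-slice path builds for a single-turn query; the special-token spellings and query_num value are assumptions based on the MiniCPM-V tokenizer and config, not guaranteed by this diff:

# Assumed values: tokenizer.im_start == "<image>", tokenizer.im_end == "</image>",
# tokenizer.unk_token == "<unk>", config.query_num == 64.
im_start, im_end, unk_token, query_num = "<image>", "</image>", "<unk>", 64

content = "What is in this picture?"
# Image placeholder prepended to the first user message; each <unk> slot is
# later overwritten with one resampled vision embedding in get_vllm_embedding.
content = im_start + unk_token * query_num + im_end + "\n" + content

prompt = "<用户>" + content   # first (and only) user turn
prompt += "<AI>"              # trailing tag that cues the model to answer
# prompt == "<用户><image><unk>...<unk></image>\nWhat is in this picture?<AI>"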