Commit
•
52c7ed6
1
Parent(s):
1878519
Update modeling_minicpmv.py (#17)
Browse files- Update modeling_minicpmv.py (36640bdb52d2a56817da2a08c72252377539769b)
Co-authored-by: qianyu chen <qianyuchen@users.noreply.huggingface.co>
- modeling_minicpmv.py +43 -29
modeling_minicpmv.py
CHANGED
@@ -1,17 +1,20 @@
|
|
1 |
import math
|
2 |
-
from typing import List, Optional
|
3 |
import json
|
4 |
import timm
|
5 |
import torch
|
6 |
import torchvision
|
|
|
7 |
from PIL import Image
|
8 |
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
9 |
from torchvision import transforms
|
10 |
from transformers import LlamaTokenizer
|
11 |
-
|
12 |
from .configuration_minicpm import MiniCPMVConfig
|
13 |
from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
|
14 |
from .resampler import Resampler
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
|
@@ -72,17 +75,29 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
72 |
def set_input_embeddings(self, value):
|
73 |
self.llm.embed_tokens = value
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
def get_vision_embedding(self, pixel_values):
|
76 |
res = []
|
77 |
-
dtype = self.
|
78 |
-
|
79 |
H, W = pixel_value.shape[-2:]
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
if hasattr(
|
84 |
-
vision_embedding = vision_embedding[:,
|
85 |
-
|
|
|
|
|
|
|
|
|
86 |
return torch.vstack(res)
|
87 |
|
88 |
def get_vllm_embedding(self, data):
|
@@ -93,8 +108,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
93 |
if len(pixel_values) > 0:
|
94 |
vision_hidden_states.append(self.get_vision_embedding(pixel_values))
|
95 |
elif self.training:
|
96 |
-
dtype = self.
|
97 |
-
device = self.
|
98 |
dummy_image = torch.zeros(
|
99 |
(1, 3, 224, 224), device=device, dtype=dtype
|
100 |
)
|
@@ -319,24 +334,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
319 |
content = msg["content"]
|
320 |
assert role in ["user", "assistant"]
|
321 |
if i == 0:
|
322 |
-
|
323 |
-
|
|
|
|
|
|
|
|
|
324 |
else:
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
content = (
|
334 |
-
tokenizer.im_start
|
335 |
-
+ tokenizer.unk_token * self.config.query_num
|
336 |
-
+ tokenizer.im_end
|
337 |
-
+ "\n"
|
338 |
-
+ content
|
339 |
-
)
|
340 |
prompt += "<用户>" if role == "user" else "<AI>"
|
341 |
prompt += content
|
342 |
prompt += "<AI>"
|
@@ -377,6 +389,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
|
|
377 |
|
378 |
return answer, context, generation_config
|
379 |
|
|
|
|
|
380 |
|
381 |
class LlamaTokenizerWrapper(LlamaTokenizer):
|
382 |
def __init__(self, **kwargs):
|
|
|
1 |
import math
|
|
|
2 |
import json
|
3 |
import timm
|
4 |
import torch
|
5 |
import torchvision
|
6 |
+
import deepspeed
|
7 |
from PIL import Image
|
8 |
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
9 |
from torchvision import transforms
|
10 |
from transformers import LlamaTokenizer
|
11 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
12 |
from .configuration_minicpm import MiniCPMVConfig
|
13 |
from .modeling_minicpm import MiniCPMForCausalLM, MiniCPMPreTrainedModel
|
14 |
from .resampler import Resampler
|
15 |
+
from functools import partial
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
17 |
+
from peft.utils.other import ModulesToSaveWrapper
|
18 |
|
19 |
|
20 |
class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel):
|
|
|
75 |
def set_input_embeddings(self, value):
|
76 |
self.llm.embed_tokens = value
|
77 |
|
78 |
+
def vpm_forward_features(self, pixel_value):
|
79 |
+
if isinstance(self.vpm, ModulesToSaveWrapper):
|
80 |
+
if self.vpm.disable_adapters or (self.vpm.active_adapter not in self.vpm.modules_to_save):
|
81 |
+
return self.vpm.original_module.forward_features(pixel_value)
|
82 |
+
return self.vpm.modules_to_save[self.vpm.active_adapter].forward_features(pixel_value)
|
83 |
+
else:
|
84 |
+
return self.vpm.forward_features(pixel_value)
|
85 |
+
|
86 |
def get_vision_embedding(self, pixel_values):
|
87 |
res = []
|
88 |
+
dtype = self.llm.lm_head.weight.dtype
|
89 |
+
def process_each_pixel(pixel_value, dtype, config, vpm, resampler):
|
90 |
H, W = pixel_value.shape[-2:]
|
91 |
+
target_size = (math.ceil(H / config.patch_size), math.ceil(W / config.patch_size))
|
92 |
+
vision_embedding = self.vpm_forward_features(pixel_value.unsqueeze(0).type(dtype))
|
93 |
+
|
94 |
+
if hasattr(vpm, 'num_prefix_tokens') and vpm.num_prefix_tokens > 0:
|
95 |
+
vision_embedding = vision_embedding[:, vpm.num_prefix_tokens:]
|
96 |
+
return resampler(vision_embedding, target_size)
|
97 |
+
|
98 |
+
for pixel_value in pixel_values:
|
99 |
+
result = process_each_pixel(pixel_value, dtype, self.config, self.vpm, self.resampler)
|
100 |
+
res.append(result)
|
101 |
return torch.vstack(res)
|
102 |
|
103 |
def get_vllm_embedding(self, data):
|
|
|
108 |
if len(pixel_values) > 0:
|
109 |
vision_hidden_states.append(self.get_vision_embedding(pixel_values))
|
110 |
elif self.training:
|
111 |
+
dtype = self.llm.lm_head.weight.dtype
|
112 |
+
device = self.llm.lm_head.weight.device
|
113 |
dummy_image = torch.zeros(
|
114 |
(1, 3, 224, 224), device=device, dtype=dtype
|
115 |
)
|
|
|
334 |
content = msg["content"]
|
335 |
assert role in ["user", "assistant"]
|
336 |
if i == 0:
|
337 |
+
assert role == "user", "The role of first msg should be user"
|
338 |
+
if self.config.slice_mode:
|
339 |
+
images, final_placeholder = self.get_slice_image_placeholder(
|
340 |
+
image, tokenizer
|
341 |
+
)
|
342 |
+
content = final_placeholder + "\n" + content
|
343 |
else:
|
344 |
+
images = [image]
|
345 |
+
content = (
|
346 |
+
tokenizer.im_start
|
347 |
+
+ tokenizer.unk_token * self.config.query_num
|
348 |
+
+ tokenizer.im_end
|
349 |
+
+ "\n"
|
350 |
+
+ content
|
351 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
prompt += "<用户>" if role == "user" else "<AI>"
|
353 |
prompt += content
|
354 |
prompt += "<AI>"
|
|
|
389 |
|
390 |
return answer, context, generation_config
|
391 |
|
392 |
+
|
393 |
+
|
394 |
|
395 |
class LlamaTokenizerWrapper(LlamaTokenizer):
|
396 |
def __init__(self, **kwargs):
|