Cuiunbo committed on
Commit 9909305
1 Parent(s): 7e9623b

Upload modeling_minicpmv.py with huggingface_hub

Files changed (1)
  1. modeling_minicpmv.py +221 -0
modeling_minicpmv.py ADDED
@@ -0,0 +1,221 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright @2023 AI, ZHIHU Inc. (zhihu.com)
#
# @author: wangchongyi <wangchongyi@zhihu.com>
# @date: 2023/9/1
#

# coding=utf-8
# Copyright 2024 RhapsodyAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn
import math
from dataclasses import dataclass
from typing import Optional, Tuple

from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel

from .configuration_siglip import SiglipVisionConfig
from .configuration_minicpm import MiniCPMConfig
from .configuration_minicpmv import MiniCPMVConfig

from .resampler import Resampler
from .modeling_minicpm import MiniCPMForCausalLM
from .modeling_siglip import SiglipVisionModel

from transformers import LlamaTokenizer  # for text processing


@dataclass
class CausalVLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class MiniCPMVForCausalLM(PreTrainedModel):
    model_type = "minicpm"
    _supports_flash_attn_2 = True

    def __init__(self, config: MiniCPMVConfig, adaptive=False):
        super().__init__(config)

        llm_config = config.llm_config
        vpm_config = config.vpm_config

        self.query_num = config.query_num
        self.patch_size = vpm_config.patch_size
        self.adaptive = adaptive
        self.slice_mode = config.slice_mode
        self.max_slice_nums = config.max_slice_nums
        self.mm_use_im_start_end = config.mm_use_im_start_end

        drop_vision_last_layer = config.drop_vision_last_layer

        # should assert that vpm_config is a SiglipVisionConfig
        vpm = SiglipVisionModel(vpm_config).vision_model

        if drop_vision_last_layer:  # drop the last vision encoder layer
            vpm.encoder.layers = nn.ModuleList(vpm.encoder.layers[:-1])

        self.vpm = vpm

        # should assert that llm_config is a MiniCPMConfig
        self.llm = MiniCPMForCausalLM(llm_config)

        embed_dim = llm_config.hidden_size

        self.resampler = Resampler(
            num_queries=config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,
            kv_dim=vpm_config.hidden_size,
            adaptive=adaptive
        )

        return

    def vpm_forward(self, data):
        # Encode images with the vision tower + resampler, then splice the resulting
        # features into the LLM token embeddings at the positions given by `image_bound`.
        if 'vision_hidden_states' not in data:
            dtype = self.vpm.embeddings.position_embedding.weight.dtype
            device = self.vpm.embeddings.position_embedding.weight.device

            pixel_values_list = data['pixel_values']
            tgt_sizes = data['tgt_sizes']

            vision_hidden_states = []

            all_pixel_values = []
            img_cnt = []

            for pixel_values in pixel_values_list:
                img_cnt.append(len(pixel_values))
                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])  # 42 * L

            # images exist in this batch
            if all_pixel_values:
                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
                max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

                all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
                all_pixel_values = all_pixel_values.to(device)  # here we can finally move `all_pixel_values` to the device

                B, L, _ = all_pixel_values.shape
                all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)  # B, 3, 14, L

                patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
                for i in range(B):
                    patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

                vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state
                vision_embedding = self.resampler(vision_embedding, tgt_sizes)

                start = 0
                for pixel_values in pixel_values_list:
                    img_cnt = len(pixel_values)
                    if img_cnt > 0:
                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
                        start += img_cnt
                    else:
                        vision_hidden_states.append([])
            else:  # no image
                if self.training:
                    dummy_image = torch.zeros(
                        (1, 3, 224, 224),
                        device=device, dtype=dtype
                    )
                    # this is a dummy feature
                    tgt_sizes = torch.Tensor([[(224 // self.patch_size), math.ceil(224 / self.patch_size)]]).type(torch.int32)
                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
                else:
                    dummy_feature = []
                for _ in range(len(pixel_values_list)):
                    vision_hidden_states.append(dummy_feature)

        else:
            vision_hidden_states = data['vision_hidden_states']

        if hasattr(self.llm.config, 'scale_emb'):
            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
        else:
            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])

        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
            i, torch.Tensor) else i for i in vision_hidden_states]

        bs = len(data['input_ids'])
        for i in range(bs):
            cur_vs_hs = vision_hidden_states[i]

            if len(cur_vs_hs) > 0:

                cur_vllm_emb = vllm_embedding[i]

                cur_image_bound = data['image_bound'][i]

                if len(cur_image_bound) > 0:

                    image_indices = torch.stack(
                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                    ).to(vllm_embedding.device)

                    # overwrite the text embeddings at the image placeholder positions
                    # with the resampled vision features
                    cur_vllm_emb.scatter_(
                        0,
                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1])
                    )

        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        vllm_embedding, vision_hidden_states = self.vpm_forward(data)

        output = self.llm(
            inputs_embeds=vllm_embedding,
            attention_mask=data["attention_mask"],
            return_dict=True
        )

        return CausalVLMOutput(
            logits=output.logits,
            hidden_states=output.hidden_states,
            vision_hidden_states=vision_hidden_states
        )

    def generate(self, data, **kwargs):
        vllm_embedding, vision_hidden_states = self.vpm_forward(data)

        # position_ids = torch.arange(data["input_ids"].size(1), dtype=torch.long).to(data["input_ids"].device)
        # position_ids = position_ids.unsqueeze(0).expand_as(data["input_ids"])

        # use attention_mask to set position_ids at padded positions to 0
        # position_ids = position_ids * data["attention_mask"]
        output = self.llm.generate(
            inputs_embeds=vllm_embedding,
            # position_ids=position_ids,
            attention_mask=data["attention_mask"],
            **kwargs
        )

        return output
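For reference, a minimal usage sketch of the uploaded class follows. It is hypothetical and not part of the commit: it assumes the repository's config.json registers MiniCPMVForCausalLM via auto_map so the file can be resolved with trust_remote_code, "<repo_id>" is a placeholder for the actual Hub repository, and it exercises only the text-only path of vpm_forward, where empty per-sample image lists skip the vision tower.

    import torch
    from transformers import AutoModel, LlamaTokenizer

    repo_id = "<repo_id>"  # placeholder for the repository this file was uploaded to
    tokenizer = LlamaTokenizer.from_pretrained(repo_id)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model.eval()

    enc = tokenizer("Describe the MiniCPM-V architecture.", return_tensors="pt")

    # `generate` takes a single `data` dict; with empty per-sample image lists,
    # vpm_forward() takes its no-image branch and only the text embeddings are used.
    data = {
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "pixel_values": [[]],  # one (empty) image list per batch element
        "tgt_sizes": [[]],     # matching empty target sizes
    }

    with torch.no_grad():
        output_ids = model.generate(data, max_new_tokens=64)

    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])

For image inputs, `data` would additionally need `pixel_values`, `tgt_sizes`, and `image_bound` filled in by the repository's preprocessing code, which is not part of this file.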