qianyuchen committed
Commit: d56777d
Parent(s): 45387f9
Update modeling_minicpmv.py
Modify get_vision_embedding so the model can be fine-tuned with ZeRO-3.
- modeling_minicpmv.py +96 -7
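For context, the commit message refers to DeepSpeed ZeRO-3 fine-tuning. Below is a minimal sketch of the kind of ZeRO-3 setup this change is meant to support; the configuration values and the use of transformers.TrainingArguments are illustrative assumptions, not part of this commit.

# Hedged sketch: a generic DeepSpeed ZeRO-3 configuration of the kind the commit
# message targets. All values are illustrative assumptions, not taken from this repo.
from transformers import TrainingArguments

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,  # ZeRO-3 partitions parameters, gradients and optimizer states
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

training_args = TrainingArguments(
    output_dir="output",            # hypothetical path
    per_device_train_batch_size=1,  # must match train_micro_batch_size_per_gpu
    gradient_accumulation_steps=8,
    bf16=True,
    deepspeed=ds_config,            # TrainingArguments accepts a dict or a JSON file path
)

With ZeRO-3 enabled, parameters are sharded across ranks; the diff below inlines the vision-embedding computation into forward() instead of delegating to get_vllm_embedding.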
modeling_minicpmv.py CHANGED
@@ -42,13 +42,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         return model
 
-    def init_resampler(self, embed_dim, vision_dim):
+    def init_resampler(self, embed_dim, vision_dim,):
         return Resampler(
             num_queries=self.config.query_num,
             embed_dim=embed_dim,
             num_heads=embed_dim // 128,
             kv_dim=vision_dim,
-            adaptive=True
+            adaptive=True,
         )
 
     def init_transform(self):
@@ -60,13 +60,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 ),
             ]
         )
-
+
     def get_input_embeddings(self):
         return self.llm.get_input_embeddings()
 
     def set_input_embeddings(self, value):
         self.llm.embed_tokens = value
-
+
     def get_vllm_embedding(self, data):
         if 'vision_hidden_states' not in data:
             dtype = self.vpm.embeddings.position_embedding.weight.dtype
@@ -152,16 +152,105 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                     image_indices = torch.stack(
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
-
                     cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                                           cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
                 elif self.training:
                     cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
         return vllm_embedding, vision_hidden_states
-
+
     def forward(self, data, **kwargs):
-        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
+
+        if 'vision_hidden_states' not in data:
+            dtype = self.llm.lm_head.weight.dtype
+            device = self.llm.lm_head.weight.device
+            tgt_sizes = data['tgt_sizes']
+            pixel_values_list = data['pixel_values']
+            vision_hidden_states = []
+            all_pixel_values = []
+            img_cnt = []
+            for pixel_values in pixel_values_list:
+                img_cnt.append(len(pixel_values))
+                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+            # exist image
+            if all_pixel_values:
+                tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+                if self.config.batch_vision_input:
+                    max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+                    all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True,
+                                                                       padding_value=0.0)
+                    B, L, _ = all_pixel_values.shape
+                    all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+                    patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+                    for i in range(B):
+                        patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+                    vision_embedding = self.vpm(all_pixel_values.type(dtype), patch_attention_mask=patch_attn_mask).last_hidden_state
+                    vision_embedding = self.resampler(vision_embedding, tgt_sizes)
+                else:
+                    # get vision_embedding foreach
+                    vision_embedding = []
+                    for single_tgt_size, single_pixel_values in zip(tgt_sizes, all_pixel_values):
+                        single_pixel_values = single_pixel_values.unsqueeze(0)
+                        B, L, _ = single_pixel_values.shape
+                        single_pixel_values = single_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+                        single_vision_embedding = self.vpm(single_pixel_values.type(dtype)).last_hidden_state
+                        single_vision_embedding = self.resampler(single_vision_embedding, single_tgt_size.unsqueeze(0))
+                        vision_embedding.append(single_vision_embedding)
+                    vision_embedding = torch.vstack(vision_embedding)
+
+                start = 0
+                for pixel_values in pixel_values_list:
+                    img_cnt = len(pixel_values)
+                    if img_cnt > 0:
+                        vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+                        start += img_cnt
+                    else:
+                        vision_hidden_states.append([])
+            else: # no image
+                if self.training:
+                    dummy_image = torch.zeros(
+                        (1, 3, 224, 224),
+                        device=device, dtype=dtype
+                    )
+                    tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
+                    dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
+                else:
+                    dummy_feature = []
+                for _ in range(len(pixel_values_list)):
+                    vision_hidden_states.append(dummy_feature)
+
+        else:
+            vision_hidden_states = data['vision_hidden_states']
+
+        if hasattr(self.llm.config, 'scale_emb'):
+            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+        else:
+            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+
+        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
+            i, torch.Tensor) else i for i in vision_hidden_states]
+
+        bs = len(data['input_ids'])
+        for i in range(bs):
+            cur_vs_hs = vision_hidden_states[i]
+            if len(cur_vs_hs) > 0:
+                cur_vllm_emb = vllm_embedding[i]
+                cur_image_bound = data['image_bound'][i]
+                if len(cur_image_bound) > 0:
+                    image_indices = torch.stack(
+                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
+                    ).to(vllm_embedding.device)
+                    cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+                                          cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
+                elif self.training:
+                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
+
+        # vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
         position_ids = data["position_ids"]
         if position_ids.dtype != torch.int64:
             position_ids = position_ids.long()
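The batched vision path added above pads variable-length patch sequences and builds a boolean patch attention mask from tgt_sizes. The following self-contained sketch shows only that padding-and-masking pattern, with made-up sizes and an assumed per-patch feature dimension; it does not reproduce MiniCPM-V's actual preprocessing.

import torch

# Standalone sketch of the padding/masking mechanism used in the batched vision
# branch of the diff. Sizes and feat_dim below are illustrative assumptions.
tgt_sizes = torch.tensor([[4, 6], [8, 8]], dtype=torch.int32)  # (h, w) patch grid per image
feat_dim = 588                                                 # assumed per-patch feature size

# One (num_patches, feat_dim) tensor per image, analogous to all_pixel_values.
seqs = [torch.randn(int(h * w), feat_dim) for h, w in tgt_sizes]

padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0.0)
B, L, _ = padded.shape                                         # L == max(h * w)

max_patches = int(torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]))
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
    # Mark the first h*w positions as real patches; the padded tail stays False.
    patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True

print(padded.shape, patch_attn_mask.squeeze(1).sum(dim=-1))    # padded batch + real patch counts

Each mask row flags the valid patch positions for one image, which is the information the diff passes to self.vpm through its patch_attention_mask argument.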