praeclarumjj3 committed
Commit cfcddc7
1 Parent(s): b916070

Update vcoder_llava/model/vcoder_ds_llava_arch.py

vcoder_llava/model/vcoder_ds_llava_arch.py CHANGED
@@ -158,26 +158,17 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
         seg_features = self.encode_seg_images(seg_images)
 
         if depth_images is not None:
-            try:
-                for p in self.get_model().depth_mm_projector.parameters():
-                    p.requires_grad = True
-                if type(depth_images) is list or depth_images.ndim == 5:
-                    concat_depth_images = torch.cat([image for image in depth_images], dim=0)
-                    depth_features = self.encode_depth_images(concat_depth_images)
-                    split_sizes = [image.shape[0] for image in depth_images]
-                    depth_features = torch.split(depth_features, split_sizes, dim=0)
-                    depth_features = [x.flatten(0, 1) for x in depth_features]
-                else:
-                    depth_features = self.encode_depth_images(depth_images)
-            except:
-                depth_images = None
-                mask = input_ids != DEPTH_TOKEN_INDEX  # drop depth indices
-                input_ids = input_ids[mask]
-                for p in self.get_model().depth_mm_projector.parameters():
-                    p.requires_grad = False
+            is_depth_zero = [torch.mean(d) == 0 for d in depth_images]
+            if type(depth_images) is list or depth_images.ndim == 5:
+                concat_depth_images = torch.cat([image for image in depth_images], dim=0)
+                depth_features = self.encode_depth_images(concat_depth_images)
+                split_sizes = [image.shape[0] for image in depth_images]
+                depth_features = torch.split(depth_features, split_sizes, dim=0)
+                depth_features = [x.flatten(0, 1) for x in depth_features]
+            else:
+                depth_features = self.encode_depth_images(depth_images)
         else:
-            for p in self.get_model().depth_mm_projector.parameters():
-                p.requires_grad = False
+            is_depth_zero = [True] * input_ids.shape[0]
 
         self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()
 
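The commit swaps the old try/except recovery (which, on any failure, disabled gradients for the depth projector and stripped depth tokens from input_ids) for an explicit per-sample flag: a depth map whose mean is exactly zero is treated as a placeholder carrying no depth signal. A minimal sketch of that gating, with a dummy stand-in for `encode_depth_images` (the real method lives on the model):

import torch

# Dummy batch: sample 0 has a real depth map, sample 1 an all-zero placeholder.
depth_images = torch.zeros(2, 3, 336, 336)
depth_images[0].uniform_(0, 1)

# Per-sample flag, as in the commit: mean == 0 marks a placeholder.
is_depth_zero = [torch.mean(d) == 0 for d in depth_images]
print(is_depth_zero)  # [tensor(False), tensor(True)]

# Stand-in for self.encode_depth_images: one feature row per image.
def encode_depth_images(imgs):
    return imgs.flatten(1).mean(dim=1, keepdim=True)

depth_features = encode_depth_images(depth_images)
# Downstream, features are spliced in only when the flag is False.
for idx, is_zero in enumerate(is_depth_zero):
    if not is_zero:
        print(f"sample {idx}: splice depth_features[{idx}]")
    else:
        print(f"sample {idx}: skip depth for this sample")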
@@ -187,13 +178,15 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
         cur_seg_idx = 0
         cur_depth_idx = 0
         for batch_idx, cur_input_ids in enumerate(input_ids):
+            print(cur_input_ids)
             if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 and (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
                 # FIXME: this is a hacky fix, for deepspeed zero3 to work
                 cur_image_features = image_features[cur_image_idx]
                 half_len = cur_input_ids.shape[0] // 2
                 if seg_images is not None:
                     cur_seg_features = seg_features[cur_seg_idx]
-                    if depth_images is not None:
+                    is_cur_depth_zero = is_depth_zero[cur_depth_idx]
+                    if not is_cur_depth_zero:
                         cur_depth_features = depth_features[cur_depth_idx]
                     cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
                     cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
@@ -201,7 +194,7 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                     cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                     cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                 if seg_images is not None:
-                    if depth_images is not None:
+                    if not is_cur_depth_zero:
                         cur_input_embeds = torch.cat([cur_input_embeds_1, cur_depth_features[0:0], cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
                     else:
                         cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
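The `[0:0]` slices above are zero-length: they add no rows to the concatenation but keep the image, seg, and depth features in the forward graph, which appears to be the point of the "hacky fix" for DeepSpeed ZeRO-3 noted in the FIXME. A quick standalone illustration with dummy tensors:

import torch

feats = torch.randn(5, 4096)    # stand-in for projected image features
embeds = torch.randn(10, 4096)  # stand-in for token embeddings

# A zero-length slice keeps dtype/device/graph connectivity, adds no rows.
out = torch.cat([embeds[:5], feats[0:0], embeds[5:]], dim=0)
print(feats[0:0].shape)  # torch.Size([0, 4096])
print(out.shape)         # torch.Size([10, 4096]) -- sequence length unchanged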
@@ -243,19 +236,16 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
             while seg_token_indices.numel() > 0:
                 cur_seg_features = seg_features[cur_seg_idx]
                 seg_token_start = seg_token_indices[0]
-                if depth_images is None:
-                    cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
                 cur_new_input_embeds.append(cur_seg_features)
                 if labels is not None:
-                    if depth_images is None:
-                        cur_new_labels.append(cur_labels[:seg_token_start])
                     cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                     cur_labels = cur_labels[seg_token_start+1:]
                 cur_seg_idx += 1
                 cur_input_ids = cur_input_ids[seg_token_start+1:]
                 seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
 
-            if depth_images is not None:
+            is_cur_depth_zero = is_depth_zero[cur_depth_idx]
+            if not is_cur_depth_zero:
                 depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
                 while depth_token_indices.numel() > 0:
                     cur_depth_features = depth_features[cur_depth_idx]
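The depth branch follows the same splice pattern as the seg loop: find the next placeholder token, embed the text before it, substitute the feature block, then continue on the remainder. A self-contained sketch of that pattern (the token id, feature size, and `embed` helper are invented for the demo):

import torch

DEPTH_TOKEN_INDEX = -300  # hypothetical placeholder id, for the demo only
cur_input_ids = torch.tensor([1, 2, DEPTH_TOKEN_INDEX, 3, 4])
depth_features = torch.randn(6, 8)  # 6 feature "tokens" of dim 8

def embed(ids):  # stand-in for the model's embedding table
    return torch.randn(ids.shape[0], 8)

pieces = []
token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
while token_indices.numel() > 0:
    start = token_indices[0]
    pieces.append(embed(cur_input_ids[:start]))  # text before the placeholder
    pieces.append(depth_features)                # spliced-in depth features
    cur_input_ids = cur_input_ids[start + 1:]    # drop the placeholder token
    token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
if cur_input_ids.numel() > 0:
    pieces.append(embed(cur_input_ids))          # trailing text
new_embeds = torch.cat(pieces, dim=0)
print(new_embeds.shape)  # torch.Size([10, 8]): 2 text + 6 depth + 2 text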
@@ -269,6 +259,8 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                     cur_depth_idx += 1
                     cur_input_ids = cur_input_ids[depth_token_start+1:]
                     depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
+            else:
+                cur_depth_idx += 1
 
             if cur_input_ids.numel() > 0:
                 if seg_images is None:
 
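The new `else` branch is the bookkeeping that makes per-sample gating safe in mixed batches: `cur_depth_idx` must advance for every sample, whether or not its depth features were spliced in, or later samples would read the wrong row of `depth_features` and `is_depth_zero`. Roughly:

# Sketch of the index bookkeeping only, not the actual method.
is_depth_zero = [False, True, False]  # sample 1 carries a zero placeholder
cur_depth_idx = 0
for batch_idx in range(3):
    if not is_depth_zero[cur_depth_idx]:
        print(f"sample {batch_idx}: splice row {cur_depth_idx}")
        cur_depth_idx += 1  # advanced inside the splice loop in the real code
    else:
        print(f"sample {batch_idx}: no depth, but still advance the index")
        cur_depth_idx += 1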