Spaces · Runtime error
Commit cfcddc7 · Parent(s): b916070
Update vcoder_llava/model/vcoder_ds_llava_arch.py
vcoder_llava/model/vcoder_ds_llava_arch.py (CHANGED)
```diff
@@ -158,26 +158,17 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
             seg_features = self.encode_seg_images(seg_images)
 
         if depth_images is not None:
-            try:
-                if type(depth_images) is list or depth_images.ndim == 5:
-                    concat_depth_images = torch.cat([image for image in depth_images], dim=0)
-                    depth_features = self.encode_depth_images(concat_depth_images)
-                    split_sizes = [image.shape[0] for image in depth_images]
-                    depth_features = torch.split(depth_features, split_sizes, dim=0)
-                    depth_features = [x.flatten(0, 1) for x in depth_features]
-                else:
-                    depth_features = self.encode_depth_images(depth_images)
-            except:
-                depth_images = None
-                mask = input_ids != DEPTH_TOKEN_INDEX # drop depth indices
-                input_ids = input_ids[mask]
-                for p in self.get_model().depth_mm_projector.parameters():
-                    p.requires_grad = False
+            is_depth_zero = [torch.mean(d) == 0 for d in depth_images]
+            if type(depth_images) is list or depth_images.ndim == 5:
+                concat_depth_images = torch.cat([image for image in depth_images], dim=0)
+                depth_features = self.encode_depth_images(concat_depth_images)
+                split_sizes = [image.shape[0] for image in depth_images]
+                depth_features = torch.split(depth_features, split_sizes, dim=0)
+                depth_features = [x.flatten(0, 1) for x in depth_features]
+            else:
+                depth_features = self.encode_depth_images(depth_images)
         else:
-            for p in self.get_model().depth_mm_projector.parameters():
-                p.requires_grad = False
+            is_depth_zero = [True] * input_ids.shape[0]
 
         self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()
```
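This hunk drops the old try/except fallback, which on any encoding failure silently dropped the depth tokens and froze `depth_mm_projector`, in favor of an explicit per-sample flag: an all-zero depth map is treated as "no depth input for this sample". A minimal sketch of that sentinel convention; the tensor shapes below are illustrative assumptions, not taken from the repo:

```python
import torch

# All-zero depth maps act as placeholders for "no depth input", so a
# mean of exactly 0 flags the sample as depth-free.
depth_images = [
    torch.zeros(1, 3, 336, 336),       # placeholder: no real depth map
    torch.rand(1, 3, 336, 336) + 0.1,  # real depth map (strictly positive)
]
is_depth_zero = [torch.mean(d) == 0 for d in depth_images]
print(is_depth_zero)  # [tensor(True), tensor(False)]
```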
```diff
@@ -187,13 +178,15 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
         cur_seg_idx = 0
         cur_depth_idx = 0
         for batch_idx, cur_input_ids in enumerate(input_ids):
+            print(cur_input_ids)
             if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 and (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
                 # FIXME: this is a hacky fix, for deepspeed zero3 to work
                 cur_image_features = image_features[cur_image_idx]
                 half_len = cur_input_ids.shape[0] // 2
                 if seg_images is not None:
                     cur_seg_features = seg_features[cur_seg_idx]
-                    if depth_images is not None:
+                    is_cur_depth_zero = is_depth_zero[cur_depth_idx]
+                    if not is_cur_depth_zero:
                         cur_depth_features = depth_features[cur_depth_idx]
                     cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
                     cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
```
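The batch-level `depth_images is not None` guard becomes a per-sample lookup into `is_depth_zero`, so a mixed batch can attach depth features to some rows while skipping others. A schematic sketch of the new control flow (names reused from the diff, loop body reduced to prints):

```python
is_depth_zero = [False, True, False]  # mixed batch: middle sample is a placeholder

for cur_depth_idx, is_cur_depth_zero in enumerate(is_depth_zero):
    if not is_cur_depth_zero:
        print(f"sample {cur_depth_idx}: splice in depth features")
    else:
        print(f"sample {cur_depth_idx}: text/seg only")
```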
```diff
@@ -201,7 +194,7 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                     cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                     cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                 if seg_images is not None:
-                    if depth_images is not None:
+                    if not is_cur_depth_zero:
                         cur_input_embeds = torch.cat([cur_input_embeds_1, cur_depth_features[0:0], cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
                     else:
                         cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
```
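The `[0:0]` slices here are what the `# FIXME ... deepspeed zero3` comment refers to: a zero-length slice contributes no embedding rows, but keeps the feature tensor (and the vision/seg/depth modules that produced it) connected to the computation graph, so every rank executes the same operations. A toy demonstration with arbitrary sizes:

```python
import torch

emb_1 = torch.randn(5, 8)  # first half of the text embeddings (toy sizes)
emb_2 = torch.randn(5, 8)  # second half
feats = torch.randn(3, 8)  # stand-in for image/seg/depth features

# feats[0:0] has shape (0, 8): it adds no rows to the result, yet keeps
# feats in the autograd graph, which is what the ZeRO-3 workaround needs.
merged = torch.cat([emb_1, feats[0:0], emb_2], dim=0)
print(merged.shape)  # torch.Size([10, 8])
```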
```diff
@@ -243,19 +236,16 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                 while seg_token_indices.numel() > 0:
                     cur_seg_features = seg_features[cur_seg_idx]
                     seg_token_start = seg_token_indices[0]
-                    if depth_images is None:
-                        cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
                     cur_new_input_embeds.append(cur_seg_features)
                     if labels is not None:
-                        if depth_images is None:
-                            cur_new_labels.append(cur_labels[:seg_token_start])
                         cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                         cur_labels = cur_labels[seg_token_start+1:]
                     cur_seg_idx += 1
                     cur_input_ids = cur_input_ids[seg_token_start+1:]
                     seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
 
-                if depth_images is not None:
+                is_cur_depth_zero = is_depth_zero[cur_depth_idx]
+                if not is_cur_depth_zero:
                     depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
                     while depth_token_indices.numel() > 0:
                         cur_depth_features = depth_features[cur_depth_idx]
```
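The `torch.full(...)` context line shows the labeling rule for spliced-in features: every inserted feature position is labeled `IGNORE_INDEX` so the language-model loss only covers real text tokens. A compact sketch, assuming the usual Hugging Face value `IGNORE_INDEX = -100` (the diff does not show the constant's definition):

```python
import torch

IGNORE_INDEX = -100                     # assumed: standard HF ignore id
cur_seg_features = torch.randn(4, 8)    # 4 inserted seg-feature positions (toy)
cur_labels = torch.tensor([42, 7, 99])  # labels following the insertion point

masked = torch.full((cur_seg_features.shape[0],), IGNORE_INDEX,
                    dtype=cur_labels.dtype)
print(torch.cat([masked, cur_labels]))
# tensor([-100, -100, -100, -100,   42,    7,   99])
```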
```diff
@@ -269,6 +259,8 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                         cur_depth_idx += 1
                         cur_input_ids = cur_input_ids[depth_token_start+1:]
                         depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
+                else:
+                    cur_depth_idx += 1
 
             if cur_input_ids.numel() > 0:
                 if seg_images is None:
```
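The added `else: cur_depth_idx += 1` is the subtle part of this hunk: depth features are packed one entry per sample (placeholders get encoded too), so a sample whose depth was flagged as zero must still consume its slot, otherwise every later sample would read the wrong features. A stand-alone illustration:

```python
depth_features = ["featA", "featB", "featC"]  # stand-ins, one entry per sample
is_depth_zero = [False, True, False]          # sample 1 carries a placeholder

cur_depth_idx = 0
used = []
for zero in is_depth_zero:
    if not zero:
        used.append(depth_features[cur_depth_idx])
        cur_depth_idx += 1
    else:
        cur_depth_idx += 1  # skip the slot, mirroring the new else branch
print(used)  # ['featA', 'featC'] -- featB's slot consumed but unused
```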