praeclarumjj3 committed • Commit cfcddc7
1 Parent(s): b916070

Update vcoder_llava/model/vcoder_ds_llava_arch.py
vcoder_llava/model/vcoder_ds_llava_arch.py
CHANGED
@@ -158,26 +158,17 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
             seg_features = self.encode_seg_images(seg_images)
 
             if depth_images is not None:
-                try:
-                    if type(depth_images) is list or depth_images.ndim == 5:
-                        concat_depth_images = torch.cat([image for image in depth_images], dim=0)
-                        depth_features = self.encode_depth_images(concat_depth_images)
-                        split_sizes = [image.shape[0] for image in depth_images]
-                        depth_features = torch.split(depth_features, split_sizes, dim=0)
-                        depth_features = [x.flatten(0, 1) for x in depth_features]
-
-
-                    else:
-                        depth_features = self.encode_depth_images(depth_images)
-                except:
-                    depth_images = None
-                    mask = input_ids != DEPTH_TOKEN_INDEX # drop depth indices
-                    input_ids = input_ids[mask]
-                    for p in self.get_model().depth_mm_projector.parameters():
-                        p.requires_grad = False
+                is_depth_zero = [torch.mean(d) == 0 for d in depth_images]
+                if type(depth_images) is list or depth_images.ndim == 5:
+                    concat_depth_images = torch.cat([image for image in depth_images], dim=0)
+                    depth_features = self.encode_depth_images(concat_depth_images)
+                    split_sizes = [image.shape[0] for image in depth_images]
+                    depth_features = torch.split(depth_features, split_sizes, dim=0)
+                    depth_features = [x.flatten(0, 1) for x in depth_features]
+                else:
+                    depth_features = self.encode_depth_images(depth_images)
             else:
-                for p in self.get_model().depth_mm_projector.parameters():
-                    p.requires_grad = False
+                is_depth_zero = [True] * input_ids.shape[0]
 
         self.get_model().vcoder_lm_emb.weight.data = self.get_model().get_input_embeddings().weight.data.clone()
 
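The first hunk replaces the old try/except fallback with unconditional encoding plus a per-sample is_depth_zero flag. Previously, any encoding failure silently set depth_images to None, dropped DEPTH_TOKEN_INDEX entries from input_ids, and froze the depth_mm_projector; now an all-zero depth map simply marks its sample as depth-free. A minimal sketch of the new flow, assuming invented tensor shapes; encode_depth_images here is a dummy stand-in and encode_depth_batch is a hypothetical wrapper, not part of the model:

import torch

def encode_depth_images(depth_images):
    # Stand-in for the model's depth encoder: a few feature tokens per image.
    return torch.randn(depth_images.shape[0], 4, 8)

def encode_depth_batch(depth_images, batch_size):
    if depth_images is not None:
        # An all-zero depth map marks a sample that carries no real depth input.
        is_depth_zero = [torch.mean(d) == 0 for d in depth_images]
        if type(depth_images) is list or depth_images.ndim == 5:
            # Several depth crops per sample: encode one big batch, split back.
            concat_depth_images = torch.cat([image for image in depth_images], dim=0)
            depth_features = encode_depth_images(concat_depth_images)
            split_sizes = [image.shape[0] for image in depth_images]
            depth_features = torch.split(depth_features, split_sizes, dim=0)
            depth_features = [x.flatten(0, 1) for x in depth_features]
        else:
            depth_features = encode_depth_images(depth_images)
    else:
        # No depth tensor at all: every sample is treated as depth-free.
        is_depth_zero = [True] * batch_size
        depth_features = None
    return depth_features, is_depth_zero

# Second sample is an all-zero map, so only its flag comes back True.
depth = torch.stack([torch.rand(3, 16, 16), torch.zeros(3, 16, 16)])
features, flags = encode_depth_batch(depth, batch_size=2)
print([bool(f) for f in flags])  # [False, True]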
@@ -187,13 +178,15 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
         cur_seg_idx = 0
         cur_depth_idx = 0
         for batch_idx, cur_input_ids in enumerate(input_ids):
+            print(cur_input_ids)
             if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0 and (cur_input_ids == SEG_TOKEN_INDEX).sum() == 0:
                 # FIXME: this is a hacky fix, for deepspeed zero3 to work
                 cur_image_features = image_features[cur_image_idx]
                 half_len = cur_input_ids.shape[0] // 2
                 if seg_images is not None:
                     cur_seg_features = seg_features[cur_seg_idx]
-                    if depth_images is not None:
+                    is_cur_depth_zero = is_depth_zero[cur_depth_idx]
+                    if not is_cur_depth_zero:
                         cur_depth_features = depth_features[cur_depth_idx]
                     cur_input_embeds_1 = self.get_model().vcoder_lm_emb(cur_input_ids[:half_len])
                     cur_input_embeds_2 = self.get_model().vcoder_lm_emb(cur_input_ids[half_len:])
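In the batch loop, a print(cur_input_ids) trace is added at the top of each iteration, and the depth branch now keys on the per-sample flag instead of the batch-wide depth_images is not None test. Roughly, per sample (flags and shapes invented):

import torch

is_depth_zero = [False, True, False]                     # one flag per sample
depth_features = [torch.randn(4, 8) for _ in range(3)]   # one block per sample

for cur_depth_idx in range(len(is_depth_zero)):
    is_cur_depth_zero = is_depth_zero[cur_depth_idx]
    if not is_cur_depth_zero:
        # Only samples with a real depth map contribute depth features.
        cur_depth_features = depth_features[cur_depth_idx]
        print(cur_depth_idx, cur_depth_features.shape)  # prints for 0 and 2; 1 is skipped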
@@ -201,7 +194,7 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                     cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                     cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                 if seg_images is not None:
-                    if depth_images is not None:
+                    if not is_cur_depth_zero:
                         cur_input_embeds = torch.cat([cur_input_embeds_1, cur_depth_features[0:0], cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
                     else:
                         cur_input_embeds = torch.cat([cur_input_embeds_1, cur_seg_features[0:0], cur_image_features[0:0], cur_input_embeds_2], dim=0)
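The [0:0] slices relate to the FIXME noted in the previous hunk: for a sample with no image or seg tokens, zero-width slices of every feature tensor are still concatenated in, so under DeepSpeed ZeRO-3 every projector still participates in the forward pass without altering the embedding sequence. A toy demonstration with invented shapes:

import torch

cur_input_embeds_1 = torch.randn(5, 8)   # text embeddings before the media tokens
cur_input_embeds_2 = torch.randn(5, 8)   # text embeddings after the media tokens
cur_depth_features = torch.randn(6, 8)
cur_seg_features = torch.randn(6, 8)
cur_image_features = torch.randn(6, 8)

merged = torch.cat([
    cur_input_embeds_1,
    cur_depth_features[0:0],   # shape (0, 8): adds nothing to the sequence
    cur_seg_features[0:0],
    cur_image_features[0:0],
    cur_input_embeds_2,
], dim=0)
print(merged.shape)  # torch.Size([10, 8]), identical to skipping the features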
@@ -243,19 +236,16 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
             while seg_token_indices.numel() > 0:
                 cur_seg_features = seg_features[cur_seg_idx]
                 seg_token_start = seg_token_indices[0]
-                if depth_images is None:
-                    cur_new_input_embeds.append(self.get_model().vcoder_lm_emb(cur_input_ids[:seg_token_start]))
                 cur_new_input_embeds.append(cur_seg_features)
                 if labels is not None:
-                    if depth_images is None:
-                        cur_new_labels.append(cur_labels[:seg_token_start])
                     cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                     cur_labels = cur_labels[seg_token_start+1:]
                 cur_seg_idx += 1
                 cur_input_ids = cur_input_ids[seg_token_start+1:]
                 seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
 
-            if depth_images is not None:
+            is_cur_depth_zero = is_depth_zero[cur_depth_idx]
+            if not is_cur_depth_zero:
                 depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
                 while depth_token_indices.numel() > 0:
                     cur_depth_features = depth_features[cur_depth_idx]
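This hunk removes the old depth_images is None guards inside the seg-token loop and gates the depth-token loop on the per-sample flag instead; the splice pattern itself is unchanged. Each SEG_TOKEN_INDEX placeholder is cut out of the id stream and replaced by seg features, with IGNORE_INDEX written over the matching label span. A self-contained toy version, where the token values, shapes, and embedding layer are invented and the text-prefix append is kept so it runs standalone:

import torch

SEG_TOKEN_INDEX = -300   # illustrative placeholder id
IGNORE_INDEX = -100

cur_input_ids = torch.tensor([5, 6, SEG_TOKEN_INDEX, 7, 8])
cur_labels = cur_input_ids.clone()
seg_features = [torch.randn(3, 8)]        # 3 feature "tokens" for one seg image
embed = torch.nn.Embedding(10, 8)         # stand-in for vcoder_lm_emb

cur_new_input_embeds, cur_new_labels, cur_seg_idx = [], [], 0
seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]
while seg_token_indices.numel() > 0:
    cur_seg_features = seg_features[cur_seg_idx]
    seg_token_start = seg_token_indices[0]
    # Text before the placeholder, then the seg features in its place.
    cur_new_input_embeds.append(embed(cur_input_ids[:seg_token_start]))
    cur_new_input_embeds.append(cur_seg_features)
    # Labels over spliced-in features are masked out with IGNORE_INDEX.
    cur_new_labels.append(cur_labels[:seg_token_start])
    cur_new_labels.append(torch.full((cur_seg_features.shape[0],), IGNORE_INDEX))
    cur_labels = cur_labels[seg_token_start + 1:]
    cur_seg_idx += 1
    cur_input_ids = cur_input_ids[seg_token_start + 1:]
    seg_token_indices = torch.where(cur_input_ids == SEG_TOKEN_INDEX)[0]

# 2 text embeddings + 3 seg features spliced so far; [7, 8] is handled later.
print(sum(e.shape[0] for e in cur_new_input_embeds))  # 5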
@@ -269,6 +259,8 @@ class VCoderDSLlavaMetaForCausalLM(ABC):
                     cur_depth_idx += 1
                     cur_input_ids = cur_input_ids[depth_token_start+1:]
                     depth_token_indices = torch.where(cur_input_ids == DEPTH_TOKEN_INDEX)[0]
+            else:
+                cur_depth_idx += 1
 
             if cur_input_ids.numel() > 0:
                 if seg_images is None:
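The added else branch keeps cur_depth_idx aligned with the batch: a sample whose depth map is all zeros never enters the depth-token loop, so without the extra increment the next sample would read another sample's row of depth_features. In miniature:

# One flag per sample; one feature entry per sample.
is_depth_zero = [False, True, False]
depth_features = ["feat0", "feat1", "feat2"]

cur_depth_idx = 0
picked = []
for flag in is_depth_zero:
    if not flag:
        picked.append(depth_features[cur_depth_idx])
        cur_depth_idx += 1   # consumed inside the depth-token loop
    else:
        cur_depth_idx += 1   # the new else branch: skipped samples still advance
print(picked)  # ['feat0', 'feat2'], indices stay aligned with the batch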