VictorSanh committed
Commit • b240ec1
Parent(s): cbd1c78

remove unnecessary padding images

modeling_img2html.py +10 -6
modeling_img2html.py CHANGED

@@ -1390,15 +1390,16 @@ class VMistralModel(VMistralPreTrainedModel):
         vision_pipeline_output_seq_len = image_hidden_states.shape[1]
         vision_hidden_size = image_hidden_states.shape[2]
         new_inputs_embeds = inputs_embeds.clone()
-        # Get
-
-
-        )
+        # Get the number of images for each example
+        num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
+        cum_num_images = num_images.cumsum(dim=-1)
         for batch_idx in range(batch_size):
             # Get the number of images for this particular example
-            example_num_images =
+            example_num_images = num_images[batch_idx]
             # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
-
+            start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
+            end = cum_num_images[batch_idx]
+            example_true_image_hidden_states = image_hidden_states[start:end]
             if (
                 new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
                 != example_num_images * vision_pipeline_output_seq_len
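The replacement lines derive the number of real images per example from the count of image placeholder tokens in input_ids (divided by image_seq_len), then use the running total to slice each example's rows out of the batch-flattened vision-encoder output. A minimal standalone sketch of that indexing pattern, with toy values (image_token_id, image_seq_len, and all tensor sizes here are illustrative assumptions, not the model's real configuration):

import torch

image_token_id = 32000  # hypothetical id of the <image> placeholder token
image_seq_len = 4       # hypothetical number of placeholder tokens per image
vision_seq_len, hidden = 4, 8

# Batch of 2 examples: the first holds 2 images, the second 1 image
# (id 1 stands in for ordinary text tokens).
input_ids = torch.tensor([
    [1] + [image_token_id] * 8 + [1],      # 8 placeholders -> 2 images
    [1] + [image_token_id] * 4 + [1] * 5,  # 4 placeholders -> 1 image
])

# One row per real image across the whole batch, as the vision encoder
# produces once padding images have been removed: (total_images, seq, hidden).
image_hidden_states = torch.randn(3, vision_seq_len, hidden)

num_images = (input_ids == image_token_id).sum(dim=-1) // image_seq_len
cum_num_images = num_images.cumsum(dim=-1)

for batch_idx in range(input_ids.size(0)):
    # Consecutive slices of the flattened tensor belong to consecutive examples.
    start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
    end = cum_num_images[batch_idx]
    example_states = image_hidden_states[start:end]
    print(batch_idx, int(num_images[batch_idx]), tuple(example_states.shape))
# 0 2 (2, 4, 8)
# 1 1 (1, 4, 8)

Because padding images are filtered out before the vision encoder runs (see the second hunk below), image_hidden_states has exactly num_images.sum() rows, so this cumulative-count slicing lines up each example with its own images.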
@@ -1484,6 +1485,9 @@ class VMistralModel(VMistralPreTrainedModel):
         pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device)  # fp16 compatibility
         batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
         pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
+        # Remove padding images - padding images are full 0.
+        real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
+        pixel_values = pixel_values[real_images_inds]
         # Get sequence from the vision encoder
         image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
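The second hunk is where the padding images actually get dropped: padding slots are all-zero tensors, so any image whose pixel sum over the (channels, height, width) dims is 0.0 is discarded before the vision encoder runs. A small sketch of the filter, assuming toy tensor shapes:

import torch

batch_size, num_images, c, h, w = 2, 3, 3, 4, 4  # illustrative sizes
pixel_values = torch.zeros(batch_size, num_images, c, h, w)
pixel_values[0, 0] = torch.rand(c, h, w)  # example 0: 1 real image, 2 padding
pixel_values[1, 0] = torch.rand(c, h, w)  # example 1: 2 real images, 1 padding
pixel_values[1, 1] = torch.rand(c, h, w)

# Flatten the batch and image dims, as the forward pass above does.
pixel_values = pixel_values.view(batch_size * num_images, c, h, w)

# Keep only the images whose pixels are not all zero.
real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
pixel_values = pixel_values[real_images_inds]
print(pixel_values.shape)  # torch.Size([3, 3, 4, 4])

This relies on padding slots being exact zeros: a real image whose normalized pixels happened to sum to exactly 0.0 would be misclassified as padding, which is practically never the case but worth keeping in mind.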