VictorSanh committed
Commit b240ec1
Parent: cbd1c78

remove unnecessary padding images

Files changed (1):
  1. modeling_img2html.py +10 -6
modeling_img2html.py CHANGED
@@ -1390,15 +1390,16 @@ class VMistralModel(VMistralPreTrainedModel):
         vision_pipeline_output_seq_len = image_hidden_states.shape[1]
         vision_hidden_size = image_hidden_states.shape[2]
         new_inputs_embeds = inputs_embeds.clone()
-        # Get a view of the image_hidden_states separating batch_size and num_images, to discard padding hidden_states
-        image_hidden_states = image_hidden_states.view(
-            batch_size, num_images, vision_pipeline_output_seq_len, vision_hidden_size
-        )
+        # Get the number of images for each example
+        num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
+        cum_num_images = num_images.cumsum(dim=-1)
         for batch_idx in range(batch_size):
             # Get the number of images for this particular example
-            example_num_images = (input_ids[batch_idx] == self.image_token_id).sum() // self.image_seq_len
+            example_num_images = num_images[batch_idx]
             # Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
-            example_true_image_hidden_states = image_hidden_states[batch_idx][:example_num_images]
+            start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
+            end = cum_num_images[batch_idx]
+            example_true_image_hidden_states = image_hidden_states[start:end]
             if (
                 new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
                 != example_num_images * vision_pipeline_output_seq_len
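The old code viewed image_hidden_states as (batch_size, num_images, seq_len, hidden) and truncated padding slots per example; since padding images are now dropped before they ever reach the vision encoder (second hunk below), image_hidden_states is flat over real images only, and each example's slice is recovered from cumulative counts. A minimal sketch of that bookkeeping on toy tensors (the token id, sequence lengths, and sizes below are illustrative placeholders, not the model's real configuration):

import torch

# Illustrative placeholders, not the model's real config.
image_token_id = 32001
image_seq_len = 4

# Two examples: the first holds 2 images (8 image tokens), the second 1 image.
input_ids = torch.tensor([
    [image_token_id] * 8 + [5] * 4,
    [image_token_id] * 4 + [5] * 8,
])

# Per-example image counts and their running total, as in the diff.
num_images = (input_ids == image_token_id).sum(dim=-1) // image_seq_len  # tensor([2, 1])
cum_num_images = num_images.cumsum(dim=-1)                               # tensor([2, 3])

# Hidden states are flat over the 3 *real* images: (3, image_seq_len, hidden).
image_hidden_states = torch.randn(3, image_seq_len, 16)

for batch_idx in range(input_ids.shape[0]):
    # Slice [cum[i-1], cum[i]) selects exactly this example's images.
    start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
    end = cum_num_images[batch_idx]
    example_true_image_hidden_states = image_hidden_states[start:end]
    print(batch_idx, example_true_image_hidden_states.shape)
# 0 torch.Size([2, 4, 16])
# 1 torch.Size([1, 4, 16])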
@@ -1484,6 +1485,9 @@
             pixel_values = pixel_values.to(dtype=self.dtype, device=input_ids.device)  # fp16 compatibility
             batch_size, num_images = pixel_values.size(0), pixel_values.size(1)
             pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
+            # Remove padding images - padding images are full 0.
+            real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
+            pixel_values = pixel_values[real_images_inds]
             # Get sequence from the vision encoder
             image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
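And a small sketch of the new padding filter itself, again on made-up shapes; it hinges on the convention that a padded image slot is exactly zero everywhere:

import torch

# Toy batch: 2 examples x 3 image slots of 3x8x8 pixels (made-up sizes).
batch_size, num_images = 2, 3
pixel_values = torch.zeros(batch_size, num_images, 3, 8, 8)
pixel_values[0, 0] = torch.rand(3, 8, 8)  # example 0: two real images
pixel_values[0, 1] = torch.rand(3, 8, 8)
pixel_values[1, 0] = torch.rand(3, 8, 8)  # example 1: one real image

# Flatten batch and image dims, then drop all-zero slots, as in the diff.
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
real_images_inds = pixel_values.sum(dim=(-1, -2, -3)) != 0.0
pixel_values = pixel_values[real_images_inds]
print(pixel_values.shape)  # torch.Size([3, 3, 8, 8])

One caveat worth noting: a genuine image whose pixels sum to exactly zero would be discarded by this test as well, so the scheme assumes real (e.g. normalized) pixel tensors never sum to exactly 0.0.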