Add task_prefix_attention_mask argument to _merge_input_ids_with_image_features for better padding handling

#66
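This adds an optional `task_prefix_attention_mask` parameter to `_merge_input_ids_with_image_features` and threads the caller's `attention_mask` through `forward` and `generate`. Previously the merge step unconditionally rebuilt an all-ones mask for the text prefix, so in a padded batch the pad tokens of shorter prompts were attended to. With this change the real padding mask is used, the old all-ones behavior remains the fallback when no mask is passed, and `generate` now forwards the merged mask to `self.language_model.generate`.

Below is a minimal usage sketch, not part of this PR. The checkpoint id, image sizes, and prompts are illustrative assumptions; the call pattern is the standard Florence-2 remote-code API plus the new `attention_mask` argument.

from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-base"  # assumption: any Florence-2 checkpoint with this custom code
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Two prompts of different lengths, so the tokenizer pads the shorter one.
prompts = ["<CAPTION>", "<DETAILED_CAPTION>"]
images = [Image.new("RGB", (448, 448)), Image.new("RGB", (448, 448))]  # dummy images for illustration
inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True)

# With this PR, the processor's padding mask is forwarded into the merge step
# instead of being replaced by an all-ones mask.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=False))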
Files changed (1)
  1. modeling_florence2.py +10 -6
modeling_florence2.py CHANGED
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
             return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
-        if len(task_prefix_attention_mask.shape) == 3:
-            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2734,7 +2736,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         if pixel_values is not None:
             # (batch_size, num_image_tokens, hidden_size)
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
@@ -2781,6 +2783,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         input_ids,
         inputs_embeds=None,
         pixel_values=None,
+        attention_mask=None,
         **kwargs
     ):
 
@@ -2791,11 +2794,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         # 2. Merge text and images
         if pixel_values is not None:
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         return self.language_model.generate(
             input_ids=None,
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             **kwargs
         )
 
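For reference, here is a self-contained sketch of the merge logic changed above, with toy tensors rather than the model code. The helper name `merge_masks` is hypothetical, and the final mask concatenation is inferred from the function returning `(inputs_embeds, attention_mask)` at its call sites; it is not shown in the diff context. Image tokens always receive an all-ones mask, while the task-prefix mask now comes from the caller and falls back to all ones.

import torch

def merge_masks(image_features, task_prefix_embeds, task_prefix_attention_mask=None):
    batch_size, image_token_length = image_features.size()[:-1]
    device = image_features.device
    # Image tokens are never padding, so their mask is all ones.
    image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
    if task_prefix_attention_mask is None:
        # Fallback keeps the pre-PR behavior: treat every text token as real.
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
    if len(task_prefix_attention_mask.shape) == 3:
        # Collapse an accidentally 3-D mask to (batch, seq_len).
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
    inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
    attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
    return inputs_embeds, attention_mask

# Toy shapes: 2 samples, 577 image tokens, prompts padded to length 8.
img = torch.randn(2, 577, 768)
txt = torch.randn(2, 8, 768)
mask = torch.tensor([[1.0] * 8, [1.0] * 5 + [0.0] * 3])  # second prompt has 3 pad tokens
embeds, full_mask = merge_masks(img, txt, mask)
print(embeds.shape, full_mask.shape)  # torch.Size([2, 585, 768]) torch.Size([2, 585])

With the old code, `full_mask` would be all ones regardless of the padding in `mask`, so the three pad positions of the second prompt would be attended to during generation.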