pawlowskipawel committed
Commit 43425cd
1 Parent(s): ee1f1f1

Add task_prefix_attention_mask Argument to _merge_input_ids_with_image_features for Better Padding Handling


This PR makes a small change to `_merge_input_ids_with_image_features` by adding a `task_prefix_attention_mask=None` argument. With it, batched inference that pads prompts to the maximum length in the batch produces a merged attention mask that correctly ignores the padding tokens.

Changes Made:
1. Added a `task_prefix_attention_mask=None` argument to `_merge_input_ids_with_image_features`.
2. Updated the function to use the provided attention mask, so padding tokens are ignored during batch processing (see the sketch below).
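
To make the behavioral change easy to see without reading the whole diff, here is a condensed standalone sketch of the merge logic after this PR. The construction of `image_attention_mask` and the final mask concatenation are inferred from the function's return values rather than shown in the hunks below, so treat this as an approximation of the real method in modeling_florence2.py:

```python
import torch

def merge_sketch(image_features, inputs_embeds, task_prefix_attention_mask=None):
    # image_features: (batch_size, num_image_tokens, hidden_size)
    # inputs_embeds:  (batch_size, text_len, hidden_size) or None
    batch_size, image_token_length = image_features.size()[:-1]
    device = image_features.device

    # Image tokens are always real content, so their mask is all ones.
    image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
    if inputs_embeds is None:
        return image_features, image_attention_mask

    # Before this PR the text mask was unconditionally all ones, so padded
    # prompt positions were attended to; now the caller's mask (with zeros
    # over padding) is used whenever it is provided.
    if task_prefix_attention_mask is None:
        task_prefix_attention_mask = torch.ones(batch_size, inputs_embeds.size(1), device=device)
    if len(task_prefix_attention_mask.shape) == 3:
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

    # Concatenate [image embeds, task prefix embeds] and the matching masks.
    merged_embeds = torch.cat([image_features, inputs_embeds], dim=1)
    merged_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
    return merged_embeds, merged_mask
```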

Below is an example demonstrating the issue and the improvement:
```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Setup (not in the original snippet): the checkpoint ID below is illustrative;
# substitute the repo this PR targets.
model_id = "microsoft/Florence-2-large"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, trust_remote_code=True
).to("cuda")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

prompts = ["prompt", "longer prompt", "much much longer prompt"]

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
images = [image] * len(prompts)

inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda", torch.float16)

inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
image_features = model._encode_image(inputs.pixel_values)

print(inputs.input_ids)
# Output:
# tensor([[    0, 12501,  3320,     2,     1,     1],
#         [    0,  3479,   254, 14302,     2,     1],
#         [    0, 28431,   203,  1181, 14302,     2]], device='cuda:0')

# Before change: the merged mask is all ones, even over the padding tokens.
# (The merged embeddings are discarded here so the second call below can
# reuse the original, unmerged inputs_embeds.)
_, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')

# After change: passing the processor's mask zeroes out the padded positions.
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(
    image_features, inputs_embeds, task_prefix_attention_mask=inputs.attention_mask
)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
```
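
As a quick sanity check on the example above, each row of the corrected mask should sum to the number of image tokens plus that prompt's count of real (non-padding) text tokens. This assertion is hypothetical and simply reuses the variables from the snippet:

```python
# Hypothetical check, continuing from the example above.
num_image_tokens = image_features.size(1)
expected = inputs.attention_mask.sum(dim=1) + num_image_tokens
assert torch.equal(attention_mask.sum(dim=1), expected.to(attention_mask.dtype))
```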

Files changed (1)
  1. modeling_florence2.py (+7 -5)
modeling_florence2.py CHANGED

```diff
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
             return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
-        if len(task_prefix_attention_mask.shape) == 3:
-            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2734,7 +2736,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         if pixel_values is not None:
             # (batch_size, num_image_tokens, hidden_size)
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
```
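
With the patch applied, a padded batch can go through generation end to end. The snippet below is a sketch rather than part of the PR: it assumes the surrounding `generate` path accepts an `attention_mask` keyword and forwards it to the patched call site (new line 2739 above), which is how the corrected mask would reach the language model.

```python
# Hypothetical end-to-end usage with a padded batch; assumes attention_mask
# is forwarded to the patched _merge_input_ids_with_image_features call.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    attention_mask=inputs["attention_mask"],  # zeros over padding now respected
    max_new_tokens=50,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=False))
```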