Add task_prefix_attention_mask Argument to _merge_input_ids_with_image_features for Better Padding Handling
Browse filesThis PR introduces a small change in the _merge_input_ids_with_image_features function by adding a task_prefix_attention_mask=None argument. This enhancement ensures that when doing batch processing with padding to the max length, the attention mask correctly ignores padding tokens.
Changes Made:
1. Added task_prefix_attention_mask=None argument to _merge_input_ids_with_image_features function.
2. Updated the function to incorporate the provided attention mask, allowing it to ignore padding tokens during batch processing.
Below is an example demonstrating the issue and the improvement:
```python
prompts =["prompt", "longer prompt", "much much longer prompt"]
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
images = [image] * len(prompts)
inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda", torch.float16)
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
image_features = model._encode_image(inputs.pixel_values)
print(inputs.input_ids)
# Output:
# tensor([[ 0, 12501, 3320, 2, 1, 1],
# [ 0, 3479, 254, 14302, 2, 1],
# [ 0, 28431, 203, 1181, 14302, 2]], device='cuda:0')
# Before change
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
# After change
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=inputs.attention_mask)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
```
- modeling_florence2.py +7 -5
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2643 |
return x
|
2644 |
|
2645 |
def _merge_input_ids_with_image_features(
|
2646 |
-
self, image_features, inputs_embeds
|
2647 |
):
|
2648 |
batch_size, image_token_length = image_features.size()[:-1]
|
2649 |
device = image_features.device
|
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2655 |
return image_features, image_attention_mask
|
2656 |
|
2657 |
task_prefix_embeds = inputs_embeds
|
2658 |
-
task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
|
2659 |
|
2660 |
-
if
|
2661 |
-
task_prefix_attention_mask =
|
|
|
|
|
|
|
2662 |
|
2663 |
# concat [image embeds, task prefix embeds]
|
2664 |
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
|
@@ -2734,7 +2736,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2734 |
if pixel_values is not None:
|
2735 |
# (batch_size, num_image_tokens, hidden_size)
|
2736 |
image_features = self._encode_image(pixel_values)
|
2737 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
2738 |
|
2739 |
if inputs_embeds is not None:
|
2740 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
|
|
2643 |
return x
|
2644 |
|
2645 |
def _merge_input_ids_with_image_features(
|
2646 |
+
self, image_features, inputs_embeds, task_prefix_attention_mask=None
|
2647 |
):
|
2648 |
batch_size, image_token_length = image_features.size()[:-1]
|
2649 |
device = image_features.device
|
|
|
2655 |
return image_features, image_attention_mask
|
2656 |
|
2657 |
task_prefix_embeds = inputs_embeds
|
|
|
2658 |
|
2659 |
+
if task_prefix_attention_mask is None:
|
2660 |
+
task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
|
2661 |
+
|
2662 |
+
if len(task_prefix_attention_mask.shape) == 3:
|
2663 |
+
task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
|
2664 |
|
2665 |
# concat [image embeds, task prefix embeds]
|
2666 |
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
|
|
|
2736 |
if pixel_values is not None:
|
2737 |
# (batch_size, num_image_tokens, hidden_size)
|
2738 |
image_features = self._encode_image(pixel_values)
|
2739 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
|
2740 |
|
2741 |
if inputs_embeds is not None:
|
2742 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|