Pull request #66 by pawlowskipawel (opened): Add a `task_prefix_attention_mask` argument to `_merge_input_ids_with_image_features` for better padding handling.

File changed: modeling_florence2.py (+10 −6)
Unified diff for modeling_florence2.py:

@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
             return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2734,7 +2736,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         if pixel_values is not None:
             # (batch_size, num_image_tokens, hidden_size)
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
@@ -2781,6 +2783,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         input_ids,
         inputs_embeds=None,
         pixel_values=None,
+        attention_mask=None,
         **kwargs
     ):
@@ -2791,11 +2794,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         # 2. Merge text and images
         if pixel_values is not None:
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         return self.language_model.generate(
             input_ids=None,
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             **kwargs
         )
Resulting code after the change (file line numbers preserved; elided context marked with "..."):

2643         return x
2644
2645     def _merge_input_ids_with_image_features(
2646         self, image_features, inputs_embeds, task_prefix_attention_mask=None
2647     ):
2648         batch_size, image_token_length = image_features.size()[:-1]
2649         device = image_features.device
...
2655             return image_features, image_attention_mask
2656
2657         task_prefix_embeds = inputs_embeds
2658
2659         if task_prefix_attention_mask is None:
2660             task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
2661
2662         if len(task_prefix_attention_mask.shape) == 3:
2663             task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2664
2665         # concat [image embeds, task prefix embeds]
2666         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
...
2736         if pixel_values is not None:
2737             # (batch_size, num_image_tokens, hidden_size)
2738             image_features = self._encode_image(pixel_values)
2739             inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
2740
2741         if inputs_embeds is not None:
2742             attention_mask = attention_mask.to(inputs_embeds.dtype)
...
2783         input_ids,
2784         inputs_embeds=None,
2785         pixel_values=None,
2786         attention_mask=None,
2787         **kwargs
2788     ):
...
2794         # 2. Merge text and images
2795         if pixel_values is not None:
2796             image_features = self._encode_image(pixel_values)
2797             inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
2798
2799         return self.language_model.generate(
2800             input_ids=None,
2801             inputs_embeds=inputs_embeds,
2802             attention_mask=attention_mask,
2803             **kwargs
2804         )
2805