HugoLaurencon committed on
Commit 81425ef
1 Parent(s): 0ece5f5
Files changed (1)
  1. modeling_siglip.py +84 -68
modeling_siglip.py CHANGED
@@ -95,10 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
 
     # Use inverse cdf transform for normal distribution to get truncated
     # standard normal
-    if tensor.dtype == torch.bfloat16:
+    if tensor.dtype in [torch.float16, torch.bfloat16]:
+        # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
+        og_dtype = tensor.dtype
         tensor = tensor.to(torch.float32)
         tensor.erfinv_()
-        tensor = tensor.to(torch.bfloat16)
+        tensor = tensor.to(og_dtype)
     else:
         tensor.erfinv_()
 
@@ -107,7 +109,13 @@ def _trunc_normal_(tensor, mean, std, a, b):
     tensor.add_(mean)
 
     # Clamp to ensure it's in the proper range
-    tensor.clamp_(min=a, max=b)
+    if tensor.dtype == torch.float16:
+        # The `clamp_` op is not (yet?) defined in float16+cpu
+        tensor = tensor.to(torch.float32)
+        tensor.clamp_(min=a, max=b)
+        tensor = tensor.to(torch.float16)
+    else:
+        tensor.clamp_(min=a, max=b)
 
 
 def trunc_normal_tf_(
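Note on the two hunks above: `erfinv_` and `clamp_` lack kernels for some half-precision device/dtype combinations (per the comments added in the diff), so the op is run in float32 and the result cast back. A minimal standalone sketch of the same cast-and-restore pattern (the helper name is illustrative, not part of the file):

```python
import torch

def erfinv_in_fp32(tensor: torch.Tensor) -> torch.Tensor:
    # Same pattern as the hunk above: run the missing half-precision op in
    # float32, then cast the result back to the tensor's original dtype.
    if tensor.dtype in (torch.float16, torch.bfloat16):
        og_dtype = tensor.dtype
        return tensor.to(torch.float32).erfinv_().to(og_dtype)
    return tensor.erfinv_()

x = torch.empty(4, dtype=torch.bfloat16).uniform_(-0.9, 0.9)
print(erfinv_in_fp32(x).dtype)  # torch.bfloat16
```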
@@ -119,11 +127,9 @@ def trunc_normal_tf_(
     with values outside :math:`[a, b]` redrawn until they are within
     the bounds. The method used for generating the random values works
     best when :math:`a \\leq \text{mean} \\leq b`.
-
     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
     and the result is subsquently scaled and shifted by the mean and std args.
-
     Args:
         tensor: an n-dimensional `torch.Tensor`
         mean: the mean of the normal distribution
@@ -174,7 +180,6 @@ def default_flax_embed_init(tensor):
 class SiglipVisionModelOutput(ModelOutput):
     """
     Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
-
     Args:
         image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
             The image embeddings obtained by applying the projection layer to the pooler_output.
@@ -183,12 +188,10 @@ class SiglipVisionModelOutput(ModelOutput):
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
             sequence_length)`.
-
             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
             heads.
     """
@@ -204,7 +207,6 @@ class SiglipVisionModelOutput(ModelOutput):
 class SiglipTextModelOutput(ModelOutput):
     """
     Base class for text model's outputs that also contains a pooling of the last hidden states.
-
     Args:
         text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
             The text embeddings obtained by applying the projection layer to the pooler_output.
@@ -213,12 +215,10 @@ class SiglipTextModelOutput(ModelOutput):
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
-
             Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
             sequence_length)`.
-
             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
             heads.
     """
@@ -283,16 +283,44 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )
 
-        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side**2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)
+
+        patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+        position_ids = torch.full(
+            size=(
+                batch_size,
+                max_nb_patches_h * max_nb_patches_w,
+            ),
+            fill_value=0,
+        )
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
+            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
 
 
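The new `forward` above supports variable-sized (padded) images: each real patch gets a position id by bucketizing its fractional (row, column) coordinate onto the fixed `num_patches_per_side` grid, so the pretrained position embeddings are reused without interpolation. A standalone illustration of that bucketing with made-up sizes (a 980-pixel, patch-14 setup is assumed purely for the example):

```python
import torch

# Suppose a checkpoint with image_size 980 and patch_size 14, i.e. 70 patches
# per side, and an image in the batch that only fills a 35 x 50 patch grid
# (the rest of the canvas is padding).
num_patches_per_side = 70
nb_patches_h, nb_patches_w = 35, 50

boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# Each real patch is mapped to the nearest cell of the full 70 x 70 grid.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()

print(pos_ids.shape)         # torch.Size([1750]) == 35 * 50
print(pos_ids.max().item())  # < 70 * 70 == 4900, i.e. a valid embedding index
```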
@@ -504,7 +532,6 @@ class SiglipFlashAttention2(SiglipAttention):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
         first unpad the input, then computes the attention scores and pad the final attention scores.
-
         Args:
             query_states (`torch.Tensor`):
                 Input query states to be passed to Flash Attention API
@@ -675,7 +702,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
@@ -704,7 +731,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
             nn.init.normal_(module.attention.in_proj_weight.data)
             nn.init.zeros_(module.attention.in_proj_bias.data)
         elif isinstance(module, SiglipModel):
-            logit_scale_init = torch.log(torch.tensor(1.0))
+            logit_scale_init = torch.tensor(0.0)
             module.logit_scale.data.fill_(logit_scale_init)
             module.logit_bias.data.zero_()
         elif isinstance(module, (nn.Linear, nn.Conv2d)):
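The `logit_scale` initialization above is numerically unchanged: log(1.0) = 0.0, so the new code simply drops the redundant `torch.log` call. A one-line check (illustrative only):

```python
import torch

assert torch.log(torch.tensor(1.0)) == torch.tensor(0.0)  # both fill logit_scale with 0.0
```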
@@ -720,11 +747,9 @@ SIGLIP_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
@@ -736,22 +761,17 @@ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.max_position_embeddings - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
@@ -783,22 +803,17 @@ SIGLIP_INPUTS_DOCSTRING = r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.max_position_embeddings - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
             Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
@@ -821,7 +836,6 @@ class SiglipEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
     [`SiglipEncoderLayer`].
-
     Args:
         config: SiglipConfig
     """
@@ -849,10 +863,8 @@ class SiglipEncoder(nn.Module):
                 than the model's internal embedding lookup matrix.
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
-
                 [What are attention masks?](../glossary#attention-mask)
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -929,7 +941,6 @@ class SiglipTextTransformer(nn.Module):
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
-
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1011,18 +1022,13 @@ class SiglipTextModel(SiglipPreTrainedModel):
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
-
         Examples:
-
         ```python
         >>> from transformers import AutoTokenizer, SiglipTextModel
-
         >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
         >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
-
         >>> # important: make sure to set padding="max_length" as that's how the model was trained
         >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
-
         >>> outputs = model(**inputs)
         >>> last_hidden_state = outputs.last_hidden_state
         >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
@@ -1055,13 +1061,13 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
-
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1069,10 +1075,36 @@ class SiglipVisionTransformer(nn.Module):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.embeddings(pixel_values)
+        batch_size = pixel_values.size(0)
+        if patch_attention_mask is None:
+            patch_attention_mask = torch.ones(
+                size=(
+                    batch_size,
+                    pixel_values.size(2) // self.config.patch_size,
+                    pixel_values.size(3) // self.config.patch_size,
+                ),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+
+        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+        # The call to `_upad_input` in `_flash_attention_forward` is expensive
+        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
+        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
+        if not torch.any(~patch_attention_mask):
+            attention_mask=None
+        else:
+            attention_mask = (
+                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+                if not self.config._flash_attn_2_enabled
+                else patch_attention_mask
+            )
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
+            attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
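The hunk above threads an optional `patch_attention_mask` through the vision tower and only materializes an encoder attention mask when some patches are actually padding. A hypothetical usage sketch (names and sizes are assumptions, not part of the diff):

```python
import torch

# Two images padded to a common 980 x 980 canvas; the second one only covers
# the top-left 490 x 700 pixels, so the remaining patches are padding.
patch_size = 14
grid = 980 // patch_size  # 70 patches per side after padding

patch_attention_mask = torch.zeros((2, grid, grid), dtype=torch.bool)
patch_attention_mask[0] = True                                            # fully real image
patch_attention_mask[1, : 490 // patch_size, : 700 // patch_size] = True  # 35 x 50 real patches

# With the change above, this mask is forwarded to the encoder; if every entry
# were True, the model would pass `attention_mask=None` and skip `_upad_input`.
# outputs = vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
```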
@@ -1081,7 +1113,10 @@ class SiglipVisionTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
 
-        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.head(
+            hidden_state=last_hidden_state,
+            attention_mask=patch_attention_mask,
+        )
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1105,11 +1140,13 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
 
-    def forward(self, hidden_state):
+    def forward(self, hidden_state, attention_mask):
         batch_size = hidden_state.shape[0]
         probe = self.probe.repeat(batch_size, 1, 1)
 
-        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+        hidden_state = self.attention(
+            query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
+        )[0]
 
         residual = hidden_state
         hidden_state = self.layernorm(hidden_state)
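In the pooling head above, `nn.MultiheadAttention` expects `key_padding_mask` entries to be True for keys that should be ignored, hence the `~attention_mask` inversion of the keep-mask. A minimal standalone sketch with made-up sizes (assuming `batch_first=True`, as the pooling head's attention layer is constructed elsewhere in this file):

```python
import torch
from torch import nn

hidden_size, num_heads, seq_len = 16, 4, 6
attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

hidden_state = torch.randn(2, seq_len, hidden_size)
probe = torch.randn(2, 1, hidden_size)                       # one learned query per image
keep = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])  # per-patch keep mask

# key_padding_mask is True where keys should be IGNORED, hence the inversion.
pooled = attn(probe, hidden_state, hidden_state, key_padding_mask=~keep)[0]
print(pooled.shape)  # torch.Size([2, 1, 16])
```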
@@ -1142,28 +1179,23 @@ class SiglipVisionModel(SiglipPreTrainedModel):
     def forward(
         self,
         pixel_values,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
-
         Examples:
-
         ```python
         >>> from PIL import Image
         >>> import requests
         >>> from transformers import AutoProcessor, SiglipVisionModel
-
         >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
-
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-
         >>> inputs = processor(images=image, return_tensors="pt")
-
         >>> outputs = model(**inputs)
         >>> last_hidden_state = outputs.last_hidden_state
         >>> pooled_output = outputs.pooler_output  # pooled features
@@ -1172,6 +1204,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
 
         return self.vision_model(
             pixel_values=pixel_values,
+            patch_attention_mask=patch_attention_mask,
            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1223,16 +1256,12 @@ class SiglipModel(SiglipPreTrainedModel):
         Returns:
             text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
             applying the projection layer to the pooled output of [`SiglipTextModel`].
-
         Examples:
-
         ```python
         >>> from transformers import AutoTokenizer, AutoModel
         >>> import torch
-
         >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
         >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
-
         >>> # important: make sure to set padding="max_length" as that's how the model was trained
         >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
         >>> with torch.no_grad():
@@ -1270,23 +1299,17 @@ class SiglipModel(SiglipPreTrainedModel):
         Returns:
             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
             applying the projection layer to the pooled output of [`SiglipVisionModel`].
-
         Examples:
-
         ```python
         >>> from PIL import Image
         >>> import requests
         >>> from transformers import AutoProcessor, AutoModel
         >>> import torch
-
         >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
-
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-
         >>> inputs = processor(images=image, return_tensors="pt")
-
         >>> with torch.no_grad():
         ...     image_features = model.get_image_features(**inputs)
         ```"""
@@ -1323,28 +1346,21 @@ class SiglipModel(SiglipPreTrainedModel):
     ) -> Union[Tuple, SiglipOutput]:
         r"""
         Returns:
-
         Examples:
-
         ```python
         >>> from PIL import Image
         >>> import requests
         >>> from transformers import AutoProcessor, AutoModel
         >>> import torch
-
         >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
         >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
-
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-
         >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
         >>> # important: we pass `padding=max_length` since the model was trained with this
         >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
-
         >>> with torch.no_grad():
         ...     outputs = model(**inputs)
-
         >>> logits_per_image = outputs.logits_per_image
         >>> probs = torch.sigmoid(logits_per_image)  # these are the probabilities
         >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")