Leyo committed on
Commit 20f7712
1 parent: 3b15dc9

Use the transformers SigLIP modeling implementation except for flash attention

Files changed (1)
  1. modeling_siglip.py +198 -168
modeling_siglip.py CHANGED
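
For orientation, here is a minimal standalone sketch of the eager multi-head attention pattern that the rewritten SiglipAttention.forward follows after this change. It is an approximation for illustration only, not the actual class: the projection layers are passed in as arguments and the size checks of the real implementation are omitted.

    import torch
    from torch import nn

    def eager_attention(hidden_states, q_proj, k_proj, v_proj, out_proj, num_heads,
                        attention_mask=None, dropout=0.0, training=False):
        # Project, then split into heads: [batch, heads, seq, head_dim].
        batch_size, q_len, embed_dim = hidden_states.size()
        head_dim = embed_dim // num_heads
        scale = head_dim ** -0.5

        query = q_proj(hidden_states).view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
        key = k_proj(hidden_states).view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
        value = v_proj(hidden_states).view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)

        # Scaled dot-product scores, optionally shifted by an additive mask
        # of shape [batch, 1, q_len, k_len].
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax is upcast to fp32 and cast back, as in the updated code.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)

        # Weighted sum over values, merge heads, project out.
        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, q_len, embed_dim)
        return out_proj(attn_output), attn_weights
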
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 Google AI and The HuggingFace Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -15,14 +15,20 @@
15
  """ PyTorch Siglip model."""
16
 
17
 
 
 
18
  from dataclasses import dataclass
19
  from typing import Any, Optional, Tuple, Union
20
 
 
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.utils.checkpoint
24
  from torch import nn
 
 
25
  from transformers.activations import ACT2FN
 
26
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
27
  from transformers.modeling_utils import PreTrainedModel
28
  from transformers.utils import (
@@ -33,7 +39,6 @@ from transformers.utils import (
33
  logging,
34
  replace_return_docstrings,
35
  )
36
-
37
  from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
38
 
39
 
@@ -64,32 +69,99 @@ def _get_unpad_data(attention_mask):
64
  )
65
 
66
 
67
- # Copied from transformers.models.bart.modeling_bart._expand_mask
68
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
69
- """
70
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
71
  """
72
- bsz, src_len = mask.size()
73
- tgt_len = tgt_len if tgt_len is not None else src_len
 
 
74
 
75
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
76
 
77
- inverted_mask = 1.0 - expanded_mask
78
 
79
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
80
 
81
 
82
- # contrastive loss function, adapted from
83
- # https://sachinruk.github.io/blog/2021-03-07-siglip.html
84
- def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
85
- return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
86
 
87
 
88
- # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->siglip
89
- def siglip_loss(similarity: torch.Tensor) -> torch.Tensor:
90
- caption_loss = contrastive_loss(similarity)
91
- image_loss = contrastive_loss(similarity.t())
92
- return (caption_loss + image_loss) / 2.0
93
 
94
 
95
  @dataclass
@@ -168,8 +240,7 @@ class SiglipOutput(ModelOutput):
168
  text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
169
  The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
170
  image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
171
- The image embeddings obtained by applying the projection layer to the pooled output of
172
- [`SiglipVisionModel`].
173
  text_model_output(`BaseModelOutputWithPooling`):
174
  The output of the [`SiglipTextModel`].
175
  vision_model_output(`BaseModelOutputWithPooling`):
@@ -254,10 +325,10 @@ class SiglipTextEmbeddings(nn.Module):
254
  return embeddings
255
 
256
 
257
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Siglip
258
  class SiglipAttention(nn.Module):
259
  """Multi-headed attention from 'Attention Is All You Need' paper"""
260
 
 
261
  def __init__(self, config):
262
  super().__init__()
263
  self.config = config
@@ -277,86 +348,57 @@ class SiglipAttention(nn.Module):
277
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
278
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
279
 
280
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
281
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
282
-
283
  def forward(
284
  self,
285
  hidden_states: torch.Tensor,
286
  attention_mask: Optional[torch.Tensor] = None,
287
- causal_attention_mask: Optional[torch.Tensor] = None,
288
  output_attentions: Optional[bool] = False,
289
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
  """Input shape: Batch x Time x Channel"""
291
 
292
- bsz, tgt_len, embed_dim = hidden_states.size()
293
 
294
- # get query proj
295
- query_states = self.q_proj(hidden_states) * self.scale
296
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
297
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
298
 
299
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
300
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
301
- key_states = key_states.view(*proj_shape)
302
- value_states = value_states.view(*proj_shape)
303
 
304
- src_len = key_states.size(1)
305
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
306
 
307
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
308
  raise ValueError(
309
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
310
  f" {attn_weights.size()}"
311
  )
312
 
313
- # apply the causal_attention_mask first
314
- if causal_attention_mask is not None:
315
- if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
316
- raise ValueError(
317
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
318
- f" {causal_attention_mask.size()}"
319
- )
320
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
321
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
322
-
323
  if attention_mask is not None:
324
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
325
  raise ValueError(
326
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
327
  )
328
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
329
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
330
-
331
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
332
-
333
- if output_attentions:
334
- # this operation is a bit akward, but it's required to
335
- # make sure that attn_weights keeps its gradient.
336
- # In order to do so, attn_weights have to reshaped
337
- # twice and have to be reused in the following
338
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
339
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
340
- else:
341
- attn_weights_reshaped = None
342
 
343
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
 
 
344
 
345
- attn_output = torch.bmm(attn_probs, value_states)
346
-
347
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
348
  raise ValueError(
349
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
350
  f" {attn_output.size()}"
351
  )
352
 
353
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
354
- attn_output = attn_output.transpose(1, 2)
355
- attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
356
 
357
  attn_output = self.out_proj(attn_output)
358
 
359
- return attn_output, attn_weights_reshaped
360
 
361
 
362
  class SiglipFlashAttention2(SiglipAttention):
@@ -581,16 +623,15 @@ class SiglipEncoderLayer(nn.Module):
581
  self,
582
  hidden_states: torch.Tensor,
583
  attention_mask: torch.Tensor,
584
- causal_attention_mask: torch.Tensor,
585
  output_attentions: Optional[bool] = False,
586
  ) -> Tuple[torch.FloatTensor]:
587
  """
588
  Args:
589
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
590
- attention_mask (`torch.FloatTensor`): attention mask of size
591
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
592
- `(config.encoder_attention_heads,)`.
593
- output_attentions (`bool`, *optional*):
594
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
595
  returned tensors for more detail.
596
  """
@@ -600,7 +641,6 @@ class SiglipEncoderLayer(nn.Module):
600
  hidden_states, attn_weights = self.self_attn(
601
  hidden_states=hidden_states,
602
  attention_mask=attention_mask,
603
- causal_attention_mask=causal_attention_mask,
604
  output_attentions=output_attentions,
605
  )
606
  hidden_states = residual + hidden_states
@@ -630,39 +670,44 @@ class SiglipPreTrainedModel(PreTrainedModel):
630
 
631
  def _init_weights(self, module):
632
  """Initialize the weights"""
633
- factor = self.config.initializer_factor
634
- if isinstance(module, SiglipTextEmbeddings):
635
- module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
636
- module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
637
- elif isinstance(module, SiglipVisionEmbeddings):
638
- factor = self.config.initializer_factor
639
- nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
640
- nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
 
641
  elif isinstance(module, SiglipAttention):
642
- factor = self.config.initializer_factor
643
- in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
644
- out_proj_std = (module.embed_dim**-0.5) * factor
645
- nn.init.normal_(module.q_proj.weight, std=in_proj_std)
646
- nn.init.normal_(module.k_proj.weight, std=in_proj_std)
647
- nn.init.normal_(module.v_proj.weight, std=in_proj_std)
648
- nn.init.normal_(module.out_proj.weight, std=out_proj_std)
 
649
  elif isinstance(module, SiglipMLP):
650
- factor = self.config.initializer_factor
651
- in_proj_std = (
652
- (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
653
- )
654
- fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
655
- nn.init.normal_(module.fc1.weight, std=fc_std)
656
- nn.init.normal_(module.fc2.weight, std=in_proj_std)
657
- if isinstance(module, nn.LayerNorm):
658
  module.bias.data.zero_()
659
  module.weight.data.fill_(1.0)
660
- if isinstance(module, nn.Linear) and module.bias is not None:
661
- module.bias.data.zero_()
662
-
663
- def _set_gradient_checkpointing(self, module, value=False):
664
- if isinstance(module, SiglipEncoder):
665
- module.gradient_checkpointing = value
666
 
667
 
668
  SIGLIP_START_DOCSTRING = r"""
@@ -781,11 +826,11 @@ class SiglipEncoder(nn.Module):
781
  self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
782
  self.gradient_checkpointing = False
783
 
 
784
  def forward(
785
  self,
786
  inputs_embeds,
787
  attention_mask: Optional[torch.Tensor] = None,
788
- causal_attention_mask: Optional[torch.Tensor] = None,
789
  output_attentions: Optional[bool] = None,
790
  output_hidden_states: Optional[bool] = None,
791
  return_dict: Optional[bool] = None,
@@ -802,13 +847,6 @@ class SiglipEncoder(nn.Module):
802
  - 1 for tokens that are **not masked**,
803
  - 0 for tokens that are **masked**.
804
 
805
- [What are attention masks?](../glossary#attention-mask)
806
- causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
807
- Causal mask for the text model. Mask values selected in `[0, 1]`:
808
-
809
- - 1 for tokens that are **not masked**,
810
- - 0 for tokens that are **masked**.
811
-
812
  [What are attention masks?](../glossary#attention-mask)
813
  output_attentions (`bool`, *optional*):
814
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -829,28 +867,20 @@ class SiglipEncoder(nn.Module):
829
  all_attentions = () if output_attentions else None
830
 
831
  hidden_states = inputs_embeds
832
- for idx, encoder_layer in enumerate(self.layers):
833
  if output_hidden_states:
834
  encoder_states = encoder_states + (hidden_states,)
835
  if self.gradient_checkpointing and self.training:
836
-
837
- def create_custom_forward(module):
838
- def custom_forward(*inputs):
839
- return module(*inputs, output_attentions)
840
-
841
- return custom_forward
842
-
843
- layer_outputs = torch.utils.checkpoint.checkpoint(
844
- create_custom_forward(encoder_layer),
845
  hidden_states,
846
  attention_mask,
847
- causal_attention_mask,
848
  )
849
  else:
850
  layer_outputs = encoder_layer(
851
  hidden_states,
852
  attention_mask,
853
- causal_attention_mask,
854
  output_attentions=output_attentions,
855
  )
856
 
@@ -909,16 +939,15 @@ class SiglipTextTransformer(nn.Module):
909
 
910
  hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
911
 
912
- # note: SigLIP's text model does not use q causal mask, unlike the original CLIP model.
913
  # expand attention_mask
914
  if attention_mask is not None:
915
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
916
- attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
917
 
918
  encoder_outputs = self.encoder(
919
  inputs_embeds=hidden_states,
920
- attention_mask=None,
921
- causal_attention_mask=None,
922
  output_attentions=output_attentions,
923
  output_hidden_states=output_hidden_states,
924
  return_dict=return_dict,
@@ -985,7 +1014,8 @@ class SiglipTextModel(SiglipPreTrainedModel):
985
  >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
986
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
987
 
988
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
 
989
 
990
  >>> outputs = model(**inputs)
991
  >>> last_hidden_state = outputs.last_hidden_state
@@ -1130,7 +1160,7 @@ class SiglipVisionModel(SiglipPreTrainedModel):
1130
 
1131
  >>> outputs = model(**inputs)
1132
  >>> last_hidden_state = outputs.last_hidden_state
1133
- >>> pooled_output = outputs.pooler_output # pooled CLS states
1134
  ```"""
1135
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1136
 
@@ -1164,19 +1194,11 @@ class SiglipModel(SiglipPreTrainedModel):
1164
  text_config = config.text_config
1165
  vision_config = config.vision_config
1166
 
1167
- self.text_model = SiglipTextModel(text_config)
1168
- self.vision_model = SiglipVisionModel(vision_config)
1169
 
1170
- self.temperature = nn.Parameter(
1171
- torch.randn(
1172
- 1,
1173
- )
1174
- )
1175
- self.bias = nn.Parameter(
1176
- torch.randn(
1177
- 1,
1178
- )
1179
- )
1180
 
1181
  # Initialize weights and apply final processing
1182
  self.post_init()
@@ -1199,13 +1221,16 @@ class SiglipModel(SiglipPreTrainedModel):
1199
  Examples:
1200
 
1201
  ```python
1202
- >>> from transformers import AutoTokenizer, SiglipModel
 
1203
 
1204
- >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1205
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1206
 
1207
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1208
- >>> text_features = model.get_text_features(**inputs)
 
 
1209
  ```"""
1210
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1211
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1245,9 +1270,10 @@ class SiglipModel(SiglipPreTrainedModel):
1245
  ```python
1246
  >>> from PIL import Image
1247
  >>> import requests
1248
- >>> from transformers import AutoProcessor, SiglipModel
 
1249
 
1250
- >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1251
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1252
 
1253
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
@@ -1255,7 +1281,8 @@ class SiglipModel(SiglipPreTrainedModel):
1255
 
1256
  >>> inputs = processor(images=image, return_tensors="pt")
1257
 
1258
- >>> image_features = model.get_image_features(**inputs)
 
1259
  ```"""
1260
  # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1261
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1296,21 +1323,26 @@ class SiglipModel(SiglipPreTrainedModel):
1296
  ```python
1297
  >>> from PIL import Image
1298
  >>> import requests
1299
- >>> from transformers import AutoProcessor, SiglipModel
 
1300
 
1301
- >>> model = SiglipModel.from_pretrained("google/siglip-base-patch16-224")
1302
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1303
 
1304
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1305
  >>> image = Image.open(requests.get(url, stream=True).raw)
1306
 
1307
- >>> inputs = processor(
1308
- ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1309
- ... )
1310
 
1311
- >>> outputs = model(**inputs)
1312
- >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1313
1314
  ```"""
1315
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1316
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1343,11 +1375,9 @@ class SiglipModel(SiglipPreTrainedModel):
1343
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1344
 
1345
  # cosine similarity as logits
1346
- logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.temperature.exp() + self.bias
1347
  logits_per_image = logits_per_text.t()
1348
 
1349
- z = torch.matmul(image_embeds, text_embeds.t()) * self.temperature.exp()
1350
-
1351
  loss = None
1352
  if return_loss:
1353
  raise NotImplementedError("SigLIP loss to be implemented")
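
The diff above also drops the local _expand_mask helper in favor of transformers' _prepare_4d_attention_mask, which performs the same expansion of a [batch, seq_len] padding mask into an additive 4-D mask. A rough standalone equivalent, assuming a 0/1 mask where 1 marks real tokens (the name expand_attention_mask is illustrative):

    import torch

    def expand_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None) -> torch.Tensor:
        # [batch, src_len] -> [batch, 1, tgt_len, src_len]
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        # Padded positions become a very large negative value so softmax ignores them.
        inverted = 1.0 - expanded
        return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)
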
 
1
  # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
15
  """ PyTorch Siglip model."""
16
 
17
 
18
+ import math
19
+ import warnings
20
  from dataclasses import dataclass
21
  from typing import Any, Optional, Tuple, Union
22
 
23
+ import numpy as np
24
  import torch
25
  import torch.nn.functional as F
26
  import torch.utils.checkpoint
27
  from torch import nn
28
+ from torch.nn.init import _calculate_fan_in_and_fan_out
29
+
30
  from transformers.activations import ACT2FN
31
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
32
  from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
33
  from transformers.modeling_utils import PreTrainedModel
34
  from transformers.utils import (
 
39
  logging,
40
  replace_return_docstrings,
41
  )
 
42
  from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
43
 
44
 
 
69
  )
70
 
71
 
72
+ def _trunc_normal_(tensor, mean, std, a, b):
73
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
74
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
75
+ def norm_cdf(x):
76
+ # Computes standard normal cumulative distribution function
77
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
78
+
79
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
80
+ warnings.warn(
81
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
82
+ "The distribution of values may be incorrect.",
83
+ stacklevel=2,
84
+ )
85
+
86
+ # Values are generated by using a truncated uniform distribution and
87
+ # then using the inverse CDF for the normal distribution.
88
+ # Get upper and lower cdf values
89
+ l = norm_cdf((a - mean) / std)
90
+ u = norm_cdf((b - mean) / std)
91
+
92
+ # Uniformly fill tensor with values from [l, u], then translate to
93
+ # [2l-1, 2u-1].
94
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
95
+
96
+ # Use inverse cdf transform for normal distribution to get truncated
97
+ # standard normal
98
+ tensor.erfinv_()
99
+
100
+ # Transform to proper mean, std
101
+ tensor.mul_(std * math.sqrt(2.0))
102
+ tensor.add_(mean)
103
+
104
+ # Clamp to ensure it's in the proper range
105
+ tensor.clamp_(min=a, max=b)
106
+
107
+
108
+ def trunc_normal_tf_(
109
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
110
+ ) -> torch.Tensor:
111
+ """Fills the input Tensor with values drawn from a truncated
112
+ normal distribution. The values are effectively drawn from the
113
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
114
+ with values outside :math:`[a, b]` redrawn until they are within
115
+ the bounds. The method used for generating the random values works
116
+ best when :math:`a \\leq \text{mean} \\leq b`.
117
+
118
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
119
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
120
+ and the result is subsequently scaled and shifted by the mean and std args.
121
+
122
+ Args:
123
+ tensor: an n-dimensional `torch.Tensor`
124
+ mean: the mean of the normal distribution
125
+ std: the standard deviation of the normal distribution
126
+ a: the minimum cutoff value
127
+ b: the maximum cutoff value
128
  """
129
+ with torch.no_grad():
130
+ _trunc_normal_(tensor, 0, 1.0, a, b)
131
+ tensor.mul_(std).add_(mean)
132
+
133
 
134
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
135
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
136
+ if mode == "fan_in":
137
+ denom = fan_in
138
+ elif mode == "fan_out":
139
+ denom = fan_out
140
+ elif mode == "fan_avg":
141
+ denom = (fan_in + fan_out) / 2
142
 
143
+ variance = scale / denom
144
 
145
+ if distribution == "truncated_normal":
146
+ # constant is stddev of standard normal truncated to (-2, 2)
147
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
148
+ elif distribution == "normal":
149
+ with torch.no_grad():
150
+ tensor.normal_(std=math.sqrt(variance))
151
+ elif distribution == "uniform":
152
+ bound = math.sqrt(3 * variance)
153
+ with torch.no_grad():
154
+ tensor.uniform_(-bound, bound)
155
+ else:
156
+ raise ValueError(f"invalid distribution {distribution}")
157
 
158
 
159
+ def lecun_normal_(tensor):
160
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
 
 
161
 
162
 
163
+ def default_flax_embed_init(tensor):
164
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
 
 
 
165
 
166
 
167
  @dataclass
 
240
  text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
241
  The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
242
  image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
243
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
 
244
  text_model_output(`BaseModelOutputWithPooling`):
245
  The output of the [`SiglipTextModel`].
246
  vision_model_output(`BaseModelOutputWithPooling`):
 
325
  return embeddings
326
 
327
 
 
328
  class SiglipAttention(nn.Module):
329
  """Multi-headed attention from 'Attention Is All You Need' paper"""
330
 
331
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
332
  def __init__(self, config):
333
  super().__init__()
334
  self.config = config
 
348
  self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
349
  self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
350
 
 
 
 
351
  def forward(
352
  self,
353
  hidden_states: torch.Tensor,
354
  attention_mask: Optional[torch.Tensor] = None,
 
355
  output_attentions: Optional[bool] = False,
356
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
357
  """Input shape: Batch x Time x Channel"""
358
 
359
+ batch_size, q_len, _ = hidden_states.size()
360
 
361
+ query_states = self.q_proj(hidden_states)
362
+ key_states = self.k_proj(hidden_states)
363
+ value_states = self.v_proj(hidden_states)
 
364
 
365
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
366
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
367
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
368
 
369
+ k_v_seq_len = key_states.shape[-2]
370
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
371
 
372
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
373
  raise ValueError(
374
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
375
  f" {attn_weights.size()}"
376
  )
377
 
378
  if attention_mask is not None:
379
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
380
  raise ValueError(
381
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
382
  )
383
+ attn_weights = attn_weights + attention_mask
384
 
385
+ # upcast attention to fp32
386
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
387
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
388
+ attn_output = torch.matmul(attn_weights, value_states)
389
 
390
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
 
 
391
  raise ValueError(
392
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
393
  f" {attn_output.size()}"
394
  )
395
 
396
+ attn_output = attn_output.transpose(1, 2).contiguous()
397
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
 
398
 
399
  attn_output = self.out_proj(attn_output)
400
 
401
+ return attn_output, attn_weights
402
 
403
 
404
  class SiglipFlashAttention2(SiglipAttention):
 
623
  self,
624
  hidden_states: torch.Tensor,
625
  attention_mask: torch.Tensor,
 
626
  output_attentions: Optional[bool] = False,
627
  ) -> Tuple[torch.FloatTensor]:
628
  """
629
  Args:
630
+ hidden_states (`torch.FloatTensor`):
631
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
632
+ attention_mask (`torch.FloatTensor`):
633
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
634
+ output_attentions (`bool`, *optional*, defaults to `False`):
635
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
636
  returned tensors for more detail.
637
  """
 
641
  hidden_states, attn_weights = self.self_attn(
642
  hidden_states=hidden_states,
643
  attention_mask=attention_mask,
 
644
  output_attentions=output_attentions,
645
  )
646
  hidden_states = residual + hidden_states
 
670
 
671
  def _init_weights(self, module):
672
  """Initialize the weights"""
673
+ if isinstance(module, SiglipVisionEmbeddings):
674
+ width = (
675
+ self.config.vision_config.hidden_size
676
+ if isinstance(self.config, SiglipConfig)
677
+ else self.config.hidden_size
678
+ )
679
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
680
+ elif isinstance(module, nn.Embedding):
681
+ default_flax_embed_init(module.weight)
682
  elif isinstance(module, SiglipAttention):
683
+ nn.init.xavier_uniform_(module.q_proj.weight)
684
+ nn.init.xavier_uniform_(module.k_proj.weight)
685
+ nn.init.xavier_uniform_(module.v_proj.weight)
686
+ nn.init.xavier_uniform_(module.out_proj.weight)
687
+ nn.init.zeros_(module.q_proj.bias)
688
+ nn.init.zeros_(module.k_proj.bias)
689
+ nn.init.zeros_(module.v_proj.bias)
690
+ nn.init.zeros_(module.out_proj.bias)
691
  elif isinstance(module, SiglipMLP):
692
+ nn.init.xavier_uniform_(module.fc1.weight)
693
+ nn.init.xavier_uniform_(module.fc2.weight)
694
+ nn.init.normal_(module.fc1.bias, std=1e-6)
695
+ nn.init.normal_(module.fc2.bias, std=1e-6)
696
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
697
+ nn.init.xavier_uniform_(module.probe.data)
698
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
699
+ nn.init.zeros_(module.attention.in_proj_bias.data)
700
+ elif isinstance(module, SiglipModel):
701
+ logit_scale_init = torch.log(torch.tensor(1.0))
702
+ module.logit_scale.data.fill_(logit_scale_init)
703
+ module.logit_bias.data.zero_()
704
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
705
+ lecun_normal_(module.weight)
706
+ if module.bias is not None:
707
+ nn.init.zeros_(module.bias)
708
+ elif isinstance(module, nn.LayerNorm):
709
  module.bias.data.zero_()
710
  module.weight.data.fill_(1.0)
711
 
712
 
713
  SIGLIP_START_DOCSTRING = r"""
 
826
  self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
827
  self.gradient_checkpointing = False
828
 
829
+ # Ignore copy
830
  def forward(
831
  self,
832
  inputs_embeds,
833
  attention_mask: Optional[torch.Tensor] = None,
 
834
  output_attentions: Optional[bool] = None,
835
  output_hidden_states: Optional[bool] = None,
836
  return_dict: Optional[bool] = None,
 
847
  - 1 for tokens that are **not masked**,
848
  - 0 for tokens that are **masked**.
849
 
850
  [What are attention masks?](../glossary#attention-mask)
851
  output_attentions (`bool`, *optional*):
852
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
 
867
  all_attentions = () if output_attentions else None
868
 
869
  hidden_states = inputs_embeds
870
+ for encoder_layer in self.layers:
871
  if output_hidden_states:
872
  encoder_states = encoder_states + (hidden_states,)
873
  if self.gradient_checkpointing and self.training:
874
+ layer_outputs = self._gradient_checkpointing_func(
875
+ encoder_layer.__call__,
876
  hidden_states,
877
  attention_mask,
878
+ output_attentions,
879
  )
880
  else:
881
  layer_outputs = encoder_layer(
882
  hidden_states,
883
  attention_mask,
 
884
  output_attentions=output_attentions,
885
  )
886
 
 
939
 
940
  hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
941
 
942
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
943
  # expand attention_mask
944
  if attention_mask is not None:
945
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
946
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
947
 
948
  encoder_outputs = self.encoder(
949
  inputs_embeds=hidden_states,
950
+ attention_mask=attention_mask,
 
951
  output_attentions=output_attentions,
952
  output_hidden_states=output_hidden_states,
953
  return_dict=return_dict,
 
1014
  >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
1015
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1016
 
1017
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1018
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1019
 
1020
  >>> outputs = model(**inputs)
1021
  >>> last_hidden_state = outputs.last_hidden_state
 
1160
 
1161
  >>> outputs = model(**inputs)
1162
  >>> last_hidden_state = outputs.last_hidden_state
1163
+ >>> pooled_output = outputs.pooler_output # pooled features
1164
  ```"""
1165
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1166
 
 
1194
  text_config = config.text_config
1195
  vision_config = config.vision_config
1196
 
1197
+ self.text_model = SiglipTextTransformer(text_config)
1198
+ self.vision_model = SiglipVisionTransformer(vision_config)
1199
 
1200
+ self.logit_scale = nn.Parameter(torch.randn(1))
1201
+ self.logit_bias = nn.Parameter(torch.randn(1))
1202
 
1203
  # Initialize weights and apply final processing
1204
  self.post_init()
 
1221
  Examples:
1222
 
1223
  ```python
1224
+ >>> from transformers import AutoTokenizer, AutoModel
1225
+ >>> import torch
1226
 
1227
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1228
  >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1229
 
1230
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1231
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1232
+ >>> with torch.no_grad():
1233
+ ... text_features = model.get_text_features(**inputs)
1234
  ```"""
1235
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1236
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1270
  ```python
1271
  >>> from PIL import Image
1272
  >>> import requests
1273
+ >>> from transformers import AutoProcessor, AutoModel
1274
+ >>> import torch
1275
 
1276
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1277
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1278
 
1279
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 
1281
 
1282
  >>> inputs = processor(images=image, return_tensors="pt")
1283
 
1284
+ >>> with torch.no_grad():
1285
+ ... image_features = model.get_image_features(**inputs)
1286
  ```"""
1287
  # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1288
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1323
  ```python
1324
  >>> from PIL import Image
1325
  >>> import requests
1326
+ >>> from transformers import AutoProcessor, AutoModel
1327
+ >>> import torch
1328
 
1329
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1330
  >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1331
 
1332
  >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1333
  >>> image = Image.open(requests.get(url, stream=True).raw)
1334
 
1335
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1336
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1337
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1338
 
1339
+ >>> with torch.no_grad():
1340
+ ... outputs = model(**inputs)
1341
+
1342
+ >>> logits_per_image = outputs.logits_per_image
1343
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1344
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1345
+ 31.9% that image 0 is 'a photo of 2 cats'
1346
  ```"""
1347
  # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1348
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
1375
  text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1376
 
1377
  # cosine similarity as logits
1378
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1379
  logits_per_image = logits_per_text.t()
1380
 
 
 
1381
  loss = None
1382
  if return_loss:
1383
  raise NotImplementedError("SigLIP loss to be implemented")
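
Putting it together, the usage pattern that the updated docstrings describe: text is padded with padding="max_length" (matching how the model was trained) and probabilities come from a sigmoid over the pairwise logits rather than a softmax. A short end-to-end sketch assembled from those docstring examples:

    import requests
    import torch
    from PIL import Image
    from transformers import AutoModel, AutoProcessor

    model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
    processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    texts = ["a photo of 2 cats", "a photo of 2 dogs"]

    # important: pad to max_length, since the model was trained with this setting
    inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_text = text_embeds @ image_embeds.T * logit_scale.exp() + logit_bias
    probs = torch.sigmoid(outputs.logits_per_image)  # independent probability per image-text pair
    print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
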