VictorSanh committed
Commit: a4ce5f5
1 Parent(s): 83cef5d

save changes to modeling_siglip

Files changed (1): modeling_siglip.py (+62, -12)
modeling_siglip.py CHANGED
@@ -39,7 +39,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
 
 
 logger = logging.get_logger(__name__)
@@ -283,16 +283,45 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )
 
-        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side ** 2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
-        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)
 
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        patch_embeds = self.patch_embedding(pixel_values)
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        patches_to_select = patch_attention_mask.view(batch_size, -1)
+        max_num_patches = patches_to_select.sum(dim=-1).max()
+        embeddings = torch.zeros((batch_size, max_num_patches, patch_embeds.size(2)), device=patch_embeds.device, dtype=patch_embeds.dtype)
+        for b_idx, (p_embeds, p_to_select) in enumerate(zip(patch_embeds, patches_to_select)):
+            sub_p_embds = p_embeds[p_to_select]
+            embeddings[b_idx][:len(sub_p_embds)] = sub_p_embds
+
+        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[0].sum()
+            nb_patches_w = p_attn_mask[:, 0].sum()
+
+            fractional_coords_h = torch.arange(0, 1, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+            pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
+            position_ids[batch_idx][:len(pos_ids)] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
 
 
@@ -675,7 +704,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
        """Initialize the weights"""
-
+
        if isinstance(module, SiglipVisionEmbeddings):
            width = (
                self.config.vision_config.hidden_size
@@ -1055,6 +1084,7 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1069,10 +1099,22 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.embeddings(pixel_values)
+        if pixel_attention_mask is None:
+            #TODO
+            pass
+
+        batch_size = pixel_attention_mask.size(0)  # assuming `pixel_attention_mask` is of size bs x h x w
+        subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
+        patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
+
+        hidden_states = self.embeddings(
+            pixel_values=pixel_values,
+            patch_attention_mask=patch_attention_mask
+        )
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
+            attention_mask=patch_attention_mask.view(batch_size, -1),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -1081,7 +1123,10 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
 
-        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.head(
+            hidden_state=last_hidden_state,
+            attention_mask=patch_attention_mask.view(batch_size, -1)
+        )
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1105,11 +1150,16 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
 
-    def forward(self, hidden_state):
+    def forward(self, hidden_state, attention_mask):
         batch_size = hidden_state.shape[0]
         probe = self.probe.repeat(batch_size, 1, 1)
 
-        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+        hidden_state = self.attention(
+            query=probe,
+            key=hidden_state,
+            value=hidden_state,
+            key_padding_mask=~attention_mask
+        )[0]
 
         residual = hidden_state
         hidden_state = self.layernorm(hidden_state)
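
Below is a minimal, self-contained sketch (not part of the commit) of the masking logic this diff introduces: it builds a dummy pixel-level attention mask, reduces it to a patch-level mask with the same `unfold` trick used in `SiglipVisionTransformer.forward`, and bucketizes fractional patch coordinates into the fixed position-embedding grid as in `SiglipVisionEmbeddings.forward`. The 384x384 resolution and 16-pixel patch size are illustrative, and the fractional coordinates are derived with `torch.arange(n) / n` (rather than a fractional-step `arange`) and flattened row-major, so the index arithmetic differs slightly from the committed lines.

import torch

patch_size = 16
image_size = 384
num_patches_per_side = image_size // patch_size  # 24

# Dummy batch: one 384x384 image whose bottom half is padding.
pixel_attention_mask = torch.ones(1, image_size, image_size, dtype=torch.bool)
pixel_attention_mask[:, image_size // 2:, :] = False

# Group pixels into patch_size x patch_size subgrids; keep a patch if any of its pixels is real.
subgrids = pixel_attention_mask.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
patch_attention_mask = subgrids.sum(dim=(-1, -2)) > 0  # (1, 24, 24) bool
print(patch_attention_mask.sum().item())  # 288: only the top half of the patch grid is valid

# Bucketize fractional patch coordinates into the 24x24 position-embedding grid.
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
p_attn_mask = patch_attention_mask[0]
nb_patches_h = int(p_attn_mask[:, 0].sum())  # valid patch rows (12)
nb_patches_w = int(p_attn_mask[0].sum())     # valid patch columns (24)

frac_h = torch.arange(nb_patches_h) / nb_patches_h
frac_w = torch.arange(nb_patches_w) / nb_patches_w
bucket_h = torch.bucketize(frac_h, boundaries, right=True)
bucket_w = torch.bucketize(frac_w, boundaries, right=True)

# Row-major flattening: one position id per valid patch, spread over the full grid.
pos_ids = (bucket_h[:, None] * num_patches_per_side + bucket_w[None, :]).flatten()
print(pos_ids.shape)  # torch.Size([288])

The flattened patch mask is also what the updated pooling head consumes: negated, it becomes the `key_padding_mask` of `nn.MultiheadAttention`, so padded patches are ignored when the probe token attends over the patch sequence.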