VictorSanh committed a4ce5f5 (1 parent: 83cef5d)
save changes to modeling_siglip

modeling_siglip.py  CHANGED  (+62 -12)
@@ -39,7 +39,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
 
 
 logger = logging.get_logger(__name__)
@@ -283,16 +283,45 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )
 
-        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side ** 2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
-        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)
 
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        patch_embeds = self.patch_embedding(pixel_values)
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        patches_to_select = patch_attention_mask.view(batch_size, -1)
+        max_num_patches = patches_to_select.sum(dim=-1).max()
+        embeddings = torch.zeros((batch_size, max_num_patches, patch_embeds.size(2)), device=patch_embeds.device, dtype=patch_embeds.dtype)
+        for b_idx, (p_embeds, p_to_select) in enumerate(zip(patch_embeds, patches_to_select)):
+            sub_p_embds = p_embeds[p_to_select]
+            embeddings[b_idx][:len(sub_p_embds)] = sub_p_embds
+
+        boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
+        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[0].sum()
+            nb_patches_w = p_attn_mask[:, 0].sum()
+
+            fractional_coords_h = torch.arange(0, 1, 1/nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1, 1/nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+            pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
+            position_ids[batch_idx][:len(pos_ids)] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
 
 
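The rewritten forward above drops the fixed position_ids buffer and instead buckets each image's fractional patch coordinates onto the full num_patches_per_side grid, so an image that covers only part of the padded canvas still indexes position embeddings spread across the whole learned table. Below is a minimal standalone sketch of that bucketize step, using made-up sizes (num_patches_per_side = 27, as for an example image_size of 378 with patch_size 14, and a 3 x 5 grid of real patches); it is an illustration, not code from the commit.

import torch

num_patches_per_side = 27          # example value: image_size 378 // patch_size 14
nb_patches_h, nb_patches_w = 3, 5  # patches actually covered by this image

# Same construction as in the diff: bucket each fractional coordinate into the full grid.
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1, 1 / nb_patches_h)  # 0.00, 0.33, 0.67
fractional_coords_w = torch.arange(0, 1, 1 / nb_patches_w)  # 0.0, 0.2, 0.4, 0.6, 0.8

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# One index into the num_patches_per_side**2 position table per real patch.
pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
print(pos_ids.shape)  # torch.Size([15])

The bucketed indices land roughly evenly across the full grid rather than clustering in one corner, which is what lets a single learned position table serve images of different sizes and aspect ratios without interpolating its weights.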
@@ -675,7 +704,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
@@ -1055,6 +1084,7 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1069,10 +1099,22 @@ class SiglipVisionTransformer(nn.Module):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.embeddings(pixel_values)
+        if pixel_attention_mask is None:
+            #TODO
+            pass
+
+        batch_size = pixel_attention_mask.size(0) # assuming `pixel_attention_mask` is of size bs x h x w
+        subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
+        patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
+
+        hidden_states = self.embeddings(
+            pixel_values=pixel_values,
+            patch_attention_mask=patch_attention_mask
+        )
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
+            attention_mask=patch_attention_mask.view(batch_size, -1),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
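The unfold chain above turns a pixel-level mask of shape (batch, height, width) into a patch-level mask by cutting the canvas into patch_size x patch_size tiles and keeping every tile that covers at least one real pixel. A self-contained sketch with assumed sizes (a 224 x 224 canvas, patch_size 14, two images, the second padded on the right and bottom), again for illustration only:

import torch

patch_size = 14
pixel_attention_mask = torch.zeros(2, 224, 224, dtype=torch.bool)  # bs x h x w
pixel_attention_mask[0] = True               # image 0 fills the whole canvas
pixel_attention_mask[1, :112, :168] = True   # image 1: 112 x 168 real pixels, rest is padding

# Cut the mask into patch_size x patch_size tiles along height, then width.
subgrids = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
subgrids = subgrids.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()

print(patch_attention_mask.shape)     # torch.Size([2, 16, 16])
print(patch_attention_mask[1].sum())  # tensor(96) -> an 8 x 12 block of real patches

Flattened with .view(batch_size, -1), this same patch-level mask is what the diff passes on to the encoder and to the pooling head as attention_mask.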
@@ -1081,7 +1123,10 @@ class SiglipVisionTransformer(nn.Module):
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
 
-        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.head(
+            hidden_state=last_hidden_state,
+            attention_mask=patch_attention_mask.view(batch_size, -1)
+        )
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1105,11 +1150,16 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
 
-    def forward(self, hidden_state):
+    def forward(self, hidden_state, attention_mask):
         batch_size = hidden_state.shape[0]
         probe = self.probe.repeat(batch_size, 1, 1)
 
-        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+        hidden_state = self.attention(
+            query=probe,
+            key=hidden_state,
+            value=hidden_state,
+            key_padding_mask=~attention_mask
+        )[0]
 
         residual = hidden_state
         hidden_state = self.layernorm(hidden_state)
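One detail worth calling out in the pooling-head change: attention_mask marks real patches with True, whereas nn.MultiheadAttention's key_padding_mask expects True at the key positions to ignore, hence the ~attention_mask inversion. A small standalone sketch with assumed sizes (the upstream SigLIP head is built around a batch_first nn.MultiheadAttention, which is what is assumed here):

import torch
import torch.nn as nn

hidden_size, num_heads, batch_size, num_patches = 32, 4, 2, 9
attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

hidden_state = torch.randn(batch_size, num_patches, hidden_size)
probe = torch.randn(1, 1, hidden_size).repeat(batch_size, 1, 1)  # one learned query per image

attention_mask = torch.ones(batch_size, num_patches, dtype=torch.bool)
attention_mask[1, 5:] = False  # last 4 patches of the second image are padding

# Inverted mask: True means "ignore this key", so padded patches never contribute to the pooled token.
pooled, _ = attention(query=probe, key=hidden_state, value=hidden_state,
                      key_padding_mask=~attention_mask)
print(pooled.shape)  # torch.Size([2, 1, 32])

Without the inversion, the head would attend only to the padded positions instead of masking them out.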