# Copyright (c) 2024 The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.

import torch
from torch import nn
from transformers.activations import ACT2FN
from flash_attn import flash_attn_varlen_func

from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel


class SiglipVisionConfig(_SiglipVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
    Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
    [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope (`bool`, *optional*, defaults to `True`):
            Whether to use 2D rotary position embeddings instead of learned absolute position embeddings.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        rope=True,
        **kwargs,
    ):
        super().__init__(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            image_size=image_size,
            patch_size=patch_size,
            hidden_act=hidden_act,
            layer_norm_eps=layer_norm_eps,
            attention_dropout=attention_dropout,
            **kwargs,
        )
        self.rope = rope
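

# Note: `rope` is the only addition relative to the upstream SiglipVisionConfig. With the
# default `rope=True`, the vision tower uses the 2D rotary position embedding defined
# below and no learned position table is created; `rope=False` restores the learned
# absolute position embeddings of the original SigLIP encoder.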


class RotaryEmbedding2D(torch.nn.Module):
    def __init__(self, dim, max_h, max_w, base=10000):
        super().__init__()
        # Standard RoPE inverse frequencies for a `dim`-sized rotation.
        freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        inv_freq = 1.0 / (base ** freq)

        # Row (h) and column (w) indices for every cell of the max_h x max_w patch grid.
        grid_h = torch.arange(0, max_h)
        grid_h = grid_h.to(inv_freq.dtype)
        grid_h = grid_h[:, None].repeat(1, max_w)

        grid_w = torch.arange(0, max_w)
        grid_w = grid_w.to(inv_freq.dtype)
        grid_w = grid_w[None, :].repeat(max_h, 1)

        cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
        cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)

        self.register_buffer("cos_h", cos_h)
        self.register_buffer("sin_h", sin_h)
        self.register_buffer("cos_w", cos_w)
        self.register_buffer("sin_w", sin_w)

    def _forward_one_side(self, grid, inv_freq):
        freqs = grid[..., None] * inv_freq[None, None, :]
        # (max_h, max_w, dim) -> (max_h * max_w, dim)
        emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
        return emb.cos(), emb.sin()
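

# The cos/sin buffers above are stored flattened as (max_h * max_w, dim), so indexing
# them with `packed_flattened_position_ids` (row-major positions in the patch grid)
# yields the per-token rotary tables consumed by SiglipFlashAttention2 below.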


def rotate_half(x):
    # (..., [x1, x2]) -> (..., [-x2, x1]): rotate the two halves of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    # unsqueeze due to the head dimension
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        if not config.rope:
            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def convert_conv2d_to_linear(self, config, meta=False):
        # Replace the Conv2d patch embedding with an equivalent nn.Linear over
        # flattened patches of shape (num_patches, num_channels * patch_size**2).
        if meta:
            linear_patch_embedding = nn.Linear(
                config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
            )
        else:
            linear_patch_embedding = nn.Linear(
                config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
            )
        W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
            self.embed_dim, config.num_channels * self.patch_size ** 2
        )
        linear_patch_embedding.weight.data = W
        linear_patch_embedding.bias.data = self.patch_embedding.bias.data
        del self.patch_embedding
        self.patch_embedding = linear_patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.FloatTensor,
        packed_flattened_position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        patch_embeds = self.patch_embedding(packed_pixel_values)
        if not self.config.rope:
            embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
        else:
            embeddings = patch_embeds
        return embeddings
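

# Note: this packed forward path consumes one feature vector per patch rather than image
# tensors, so callers are expected to run convert_conv2d_to_linear() once after loading
# the pretrained Conv2d weights and to pass `packed_pixel_values` as flattened patches
# of shape (total_patches, num_channels * patch_size**2).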


class SiglipFlashAttention2(SiglipAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        total_q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)

        if self.config.rope:
            # Split each head in half: the first half is rotated with row (h) frequencies,
            # the second half with column (w) frequencies.
            qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
            kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
            qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
            qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
            query_states = torch.cat([qh, qw], dim=-1)
            key_states = torch.cat([kh, kw], dim=-1)

        attn_output = flash_attn_varlen_func(
            query_states.to(torch.bfloat16),
            key_states.to(torch.bfloat16),
            value_states.to(torch.bfloat16),
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
        )

        attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
        return attn_output
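

# `cu_seqlens` follows the flash-attn varlen convention: an int32 tensor of cumulative
# sequence lengths with a leading 0, e.g. torch.tensor([0, 196, 245], dtype=torch.int32)
# for two packed images of 196 and 49 patches, with `max_seqlen` the longest of them (196).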


class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = SiglipFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            cos_h=cos_h,
            sin_h=sin_h,
            cos_w=cos_w,
            sin_w=sin_w,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class SiglipEncoder(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
                                          cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
        return hidden_states


class SiglipVisionTransformer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        if config.rope:
            max_size = config.image_size // config.patch_size
            dim_head = config.hidden_size // config.num_attention_heads
            # Each head splits into an h-half and a w-half, so each side rotates dim_head // 2 dims.
            self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
        )

        extra_inputs = {}
        if self.config.rope:
            # Gather per-token rotary tables for the packed sequence.
            extra_inputs.update(
                cos_h=self.rope.cos_h[packed_flattened_position_ids],
                sin_h=self.rope.sin_h[packed_flattened_position_ids],
                cos_w=self.rope.cos_w[packed_flattened_position_ids],
                sin_w=self.rope.sin_w[packed_flattened_position_ids],
            )

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
            **extra_inputs
        )
        last_hidden_state = self.post_layernorm(last_hidden_state)
        return last_hidden_state


class SiglipVisionModel(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "packed_pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        return self.vision_model(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
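

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes flash-attn is
    # installed and a CUDA device is available (flash_attn_varlen_func is GPU-only),
    # and uses a small randomly initialized model. The layout of the position ids below
    # (row * patches_per_side + col) matches the flattened buffers in RotaryEmbedding2D.
    config = SiglipVisionConfig(
        hidden_size=256, intermediate_size=512, num_hidden_layers=2,
        num_attention_heads=4, image_size=224, patch_size=16,
    )
    model = SiglipVisionModel(config)
    # The packed interface expects flattened patches, so swap the Conv2d patch
    # embedding for its linear equivalent before moving the model to the GPU.
    model.vision_model.embeddings.convert_conv2d_to_linear(config)
    model = model.cuda().to(torch.bfloat16).eval()

    side = config.image_size // config.patch_size          # 14 patches per side
    patch_dim = config.num_channels * config.patch_size ** 2

    # Pack one full-resolution image (14x14 = 196 patches) and one 7x7 crop (49 patches).
    pos_full = torch.arange(side * side)
    rows = torch.arange(7).repeat_interleave(7)
    cols = torch.arange(7).repeat(7)
    pos_crop = rows * side + cols
    position_ids = torch.cat([pos_full, pos_crop]).cuda()

    pixels = torch.randn(position_ids.numel(), patch_dim, dtype=torch.bfloat16).cuda()
    cu_seqlens = torch.tensor([0, 196, 245], dtype=torch.int32).cuda()

    with torch.no_grad():
        out = model(
            packed_pixel_values=pixels,
            packed_flattened_position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=196,
        )
    print(out.shape)  # torch.Size([245, 256])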