# Bagel-7B-Demo: modeling/bagel/siglip_navit.py
# Copyright (c) 2024 The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
import torch
from torch import nn
from transformers.activations import ACT2FN
from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
from flash_attn import flash_attn_varlen_func
class SiglipVisionConfig(_SiglipVisionConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope (`bool`, *optional*, defaults to `True`):
            Whether to use 2D rotary position embeddings instead of a learned absolute position embedding.
Example:
```python
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipVisionConfig()
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
rope=True,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_channels=num_channels,
image_size=image_size,
patch_size=patch_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_dropout=attention_dropout,
**kwargs)
self.rope = rope
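
# Illustrative usage sketch (not part of the original module): the `rope` flag added
# above selects the 2D rotary position embedding defined below instead of a learned
# absolute position embedding. The values shown are simply the class defaults.
#
#     config = SiglipVisionConfig(image_size=224, patch_size=16, rope=True)
#     assert config.image_size // config.patch_size == 14   # 14 x 14 patch grid
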
class RotaryEmbedding2D(torch.nn.Module):
def __init__(self, dim, max_h, max_w, base=10000):
super().__init__()
freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
inv_freq = 1.0 / (base ** freq)
grid_h = torch.arange(0, max_h)
grid_h = grid_h.to(inv_freq.dtype)
grid_h = grid_h[:, None].repeat(1, max_w)
grid_w = torch.arange(0, max_w)
grid_w = grid_w.to(inv_freq.dtype)
grid_w = grid_w[None, :].repeat(max_h, 1)
cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
self.register_buffer("cos_h", cos_h)
self.register_buffer("sin_h", sin_h)
self.register_buffer("cos_w", cos_w)
self.register_buffer("sin_w", sin_w)
def _forward_one_side(self, grid, inv_freq):
freqs = grid[..., None] * inv_freq[None, None, :]
emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
return emb.cos(), emb.sin()
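
# Shape note (added explanation, hedged): each buffer registered above has shape
# (max_h * max_w, dim), where `dim` is half of an attention head as instantiated
# further below. A flattened position id `p` therefore indexes the frequencies of
# grid cell (p // max_w, p % max_w). Illustrative sketch:
#
#     rope = RotaryEmbedding2D(dim=32, max_h=14, max_w=14)
#     rope.cos_h.shape      # (196, 32)
#     rope.cos_h[15]        # row-frequency cosines for patch (1, 1)
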
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    # insert a singleton head dimension so cos/sin broadcast across attention heads
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
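
# Illustrative sketch (not part of the original module): `rotate_half` plus the
# cos/sin blend above is the standard rotary-embedding rotation applied to pairs of
# feature dimensions. Toy shapes assumed: (seq=1, heads=1, dim=4).
#
#     q = k = torch.ones(1, 1, 4)
#     cos, sin = torch.zeros(1, 4), torch.ones(1, 4)
#     q_rot, _ = apply_rotary_pos_emb(q, k, cos, sin)
#     # q_rot equals rotate_half(q): values [-1, -1, 1, 1]
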
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
if not config.rope:
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def convert_conv2d_to_linear(self, config, meta=False):
if meta:
linear_patch_embedding = nn.Linear(
config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
)
else:
linear_patch_embedding = nn.Linear(
config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
)
W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
self.embed_dim, config.num_channels * self.patch_size ** 2
)
linear_patch_embedding.weight.data = W
linear_patch_embedding.bias.data = self.patch_embedding.bias.data
del self.patch_embedding
self.patch_embedding = linear_patch_embedding
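
    # Note (added explanation, hedged): a Conv2d with stride == kernel_size over a full
    # image is equivalent to an nn.Linear over flattened patches. The permute(0, 2, 3, 1)
    # above turns the (embed_dim, C, p, p) kernel into (embed_dim, p, p, C), so the packed
    # patches fed to the linear layer are assumed to be flattened channels-last, i.e. in
    # (patch_h, patch_w, channel) pixel order, by the upstream patchify step.
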
def forward(
self,
packed_pixel_values: torch.FloatTensor,
packed_flattened_position_ids: torch.LongTensor
) -> torch.Tensor:
patch_embeds = self.patch_embedding(packed_pixel_values)
if not self.config.rope:
embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
else:
embeddings = patch_embeds
return embeddings
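
# Illustrative input sketch (hedged, not original documentation): in this packed
# (NaViT-style) variant every image is pre-patchified and the patches of all images in
# the batch are concatenated along dim 0. Once `convert_conv2d_to_linear` has been
# applied, the expected shapes are:
#
#     packed_pixel_values:            (total_patches, num_channels * patch_size ** 2)
#     packed_flattened_position_ids:  (total_patches,), row-major index into each image's patch grid
#
# The default Conv2d patch embedding cannot consume these flattened patches, so the
# conversion is assumed to run before the first forward pass.
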
class SiglipFlashAttention2(SiglipAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: torch.Tensor = None,
sin_h: torch.Tensor = None,
cos_w: torch.Tensor = None,
sin_w: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
total_q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
        if self.config.rope:
            # 2D RoPE: the first half of each head is rotated with row (h) frequencies,
            # the second half with column (w) frequencies
qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
query_states = torch.cat([qh, qw], dim=-1)
key_states = torch.cat([kh, kw], dim=-1)
        # the flash-attn varlen kernel expects fp16/bf16 inputs and returns (total_q_len, num_heads, head_dim)
        attn_output = flash_attn_varlen_func(
query_states.to(torch.bfloat16),
key_states.to(torch.bfloat16),
value_states.to(torch.bfloat16),
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
)
attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
return attn_output
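
# Illustrative sketch (assumption following the flash-attn varlen convention; this is
# not a helper defined by the original module): `cu_seqlens` is the int32 prefix sum of
# per-image patch counts with a leading zero, and `max_seqlen` is the largest count.
#
#     def _example_cu_seqlens(patch_counts):
#         counts = torch.tensor(patch_counts, dtype=torch.int32)
#         cu_seqlens = torch.nn.functional.pad(counts.cumsum(0, dtype=torch.int32), (1, 0))
#         return cu_seqlens, int(counts.max())
#
#     # e.g. images with 196 and 64 patches -> cu_seqlens = [0, 196, 260], max_seqlen = 196
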
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipFlashAttention2(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: torch.Tensor = None,
sin_h: torch.Tensor = None,
cos_w: torch.Tensor = None,
sin_w: torch.Tensor = None
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
cos_h=cos_h,
sin_h=sin_h,
cos_w=cos_w,
sin_w=sin_w
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
def forward(
self,
inputs_embeds: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: torch.Tensor = None,
sin_h: torch.Tensor = None,
cos_w: torch.Tensor = None,
sin_w: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
if config.rope:
max_size = config.image_size // config.patch_size
dim_head = config.hidden_size // config.num_attention_heads
            # each head's channels are split between row and column rotations, hence dim_head // 2
            self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(
self,
packed_pixel_values: torch.Tensor,
packed_flattened_position_ids: torch.LongTensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
) -> torch.Tensor:
hidden_states = self.embeddings(
packed_pixel_values=packed_pixel_values,
packed_flattened_position_ids=packed_flattened_position_ids
)
extra_inputs = {}
if self.config.rope:
            extra_inputs.update(
                cos_h=self.rope.cos_h[packed_flattened_position_ids],
                sin_h=self.rope.sin_h[packed_flattened_position_ids],
                cos_w=self.rope.cos_w[packed_flattened_position_ids],
                sin_w=self.rope.sin_w[packed_flattened_position_ids],
            )
last_hidden_state = self.encoder(
inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
**extra_inputs
)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
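
# Note (added explanation, hedged): with rope enabled, the per-token cos/sin tables are
# gathered once per forward pass by flattened position id, giving tensors of shape
# (total_patches, head_dim // 2) that every encoder layer reuses; positional information
# then enters only through attention, not through the patch embeddings.
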
class SiglipVisionModel(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "packed_pixel_values"
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.vision_model = SiglipVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
packed_pixel_values: torch.Tensor,
packed_flattened_position_ids: torch.LongTensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
) -> torch.Tensor:
return self.vision_model(
packed_pixel_values=packed_pixel_values,
packed_flattened_position_ids=packed_flattened_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
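

# Illustrative smoke test (not part of the original module). Assumptions: a CUDA GPU with
# flash-attn installed (flash_attn_varlen_func runs only on GPU with fp16/bf16), and the
# patch embedding converted to its linear form before packed, pre-patchified pixels are used.
if __name__ == "__main__":
    config = SiglipVisionConfig(rope=True)
    model = SiglipVisionModel(config)
    model.vision_model.embeddings.convert_conv2d_to_linear(config)
    model = model.cuda().to(torch.bfloat16).eval()

    side = config.image_size // config.patch_size            # 14 patches per side
    n1, n2 = side * side, side * side                        # two full-resolution images
    packed_pixel_values = torch.randn(
        n1 + n2, config.num_channels * config.patch_size ** 2,
        device="cuda", dtype=torch.bfloat16,
    )
    packed_flattened_position_ids = torch.cat(
        [torch.arange(n1), torch.arange(n2)]
    ).cuda()
    cu_seqlens = torch.tensor([0, n1, n1 + n2], dtype=torch.int32, device="cuda")
    max_seqlen = max(n1, n2)

    with torch.no_grad():
        out = model(packed_pixel_values, packed_flattened_position_ids, cu_seqlens, max_seqlen)
    print(out.shape)  # expected: torch.Size([n1 + n2, config.hidden_size])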