# AndesVL-0_6B-Thinking / modeling_siglip2_navit_rope.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
is_flash_attn_2_available,
)
try:
from .configuration_siglip2_navit_rope import Siglip2VisionConfig
except ImportError:
from configuration_siglip2_navit_rope import Siglip2VisionConfig
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
else:
flash_attn_varlen_func = None
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
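# Shape note for the two functions above: `freqs` has shape (seq_len, head_dim / 2);
# its first head_dim / 4 columns carry the height (row) angles and the remaining
# columns the width (column) angles (see `Siglip2VisionTransformer.rot_pos_emb`).
# `repeat(1, 1, 2)` duplicates those angles so cos/sin broadcast over the full
# head_dim of the (1, seq_len, num_heads, head_dim) query/key tensor, consistent
# with `rotate_half`, which pairs dimension i with dimension i + head_dim / 2.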
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(
seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size,
num_channels,
embed_dim,
num_patches,
preserve_original_pe=False
):
super().__init__()
self.patch_size = patch_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.preserve_original_pe = preserve_original_pe
self.proj = nn.Linear(
num_channels * patch_size * patch_size, embed_dim
        )  # NOTE: bias defaults to True
if preserve_original_pe:
assert num_patches**0.5 == int(num_patches**0.5), "num_patches must be a perfect square"
self.pos_embed = nn.Embedding(num_patches, embed_dim)
self.original_grid_size = int(num_patches**0.5)
else:
self.pos_embed = None
self.original_grid_size = 0
def get_patch_coordinates(self, grid_hw: torch.Tensor, device: torch.device):
"""
        Generate patch coordinates that match the 2x2 block-scan order of the packed patch sequence.
"""
all_h_coords, all_w_coords, all_target_sizes = [], [], []
for h, w in grid_hw:
h, w = h.item(), w.item()
            # Build the standard row-major grid coordinates
h_grid, w_grid = torch.meshgrid(
torch.arange(h, device=device, dtype=torch.float32),
torch.arange(w, device=device, dtype=torch.float32),
indexing='ij'
)
            # Rearrange into the 2x2 block-scan order
h_coords = h_grid.reshape(
h//2, 2, w//2, 2
).permute(0, 2, 1, 3).flatten()
w_coords = w_grid.reshape(
h//2, 2, w//2, 2
).permute(0, 2, 1, 3).flatten()
all_h_coords.append(h_coords)
all_w_coords.append(w_coords)
target_size = torch.tensor([h, w], device=device, dtype=torch.float32)
all_target_sizes.append(target_size.expand(h * w, -1))
return torch.cat(all_h_coords), torch.cat(all_w_coords), torch.cat(all_target_sizes)
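    # Worked example of the block-scan order: for a 4x4 patch grid the (h, w)
    # coordinate sequence returned above is
    #   (0,0) (0,1) (1,0) (1,1)   (0,2) (0,3) (1,2) (1,3)
    #   (2,0) (2,1) (3,0) (3,1)   (2,2) (2,3) (3,2) (3,3)
    # i.e. 2x2 blocks are visited left-to-right, top-to-bottom, and the four patches
    # inside each block row-major. This matches the reshape/permute in
    # `Siglip2VisionTransformer.rot_pos_emb`, so the interpolated absolute position
    # embedding stays aligned with the rotary embedding and the packed patch order.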
def abs_pos_embed(self, grid_hw: torch.Tensor, mode='bicubic') -> torch.Tensor:
pos_embed_weight = self.pos_embed.weight
pos_embed_2d = pos_embed_weight.transpose(0, 1).reshape(
self.embed_dim, self.original_grid_size, self.original_grid_size
).unsqueeze(0).to(torch.float32)
if grid_hw.numel() == 0:
return torch.empty(0, self.embed_dim, device=pos_embed_2d.device, dtype=pos_embed_weight.dtype)
h_coords, w_coords, target_sizes = self.get_patch_coordinates(grid_hw, pos_embed_2d.device)
if h_coords.shape[0] == 0:
return torch.empty(0, self.embed_dim, device=pos_embed_2d.device, dtype=pos_embed_weight.dtype)
target_h = target_sizes[:, 0]
target_w = target_sizes[:, 1]
        # This normalization (mapping patch centers into [-1, 1]) is correct for align_corners=False.
norm_w = (2.0 * (w_coords + 0.5) / target_w) - 1.0
norm_h = (2.0 * (h_coords + 0.5) / target_h) - 1.0
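        # Example: for target_w = 4 the patch centers map to
        # norm_w = -0.75, -0.25, 0.25, 0.75, i.e. evenly spaced cell centers
        # of the [-1, 1] sampling grid expected by align_corners=False.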
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(0)
interpolated_embed = F.grid_sample(
pos_embed_2d, grid, mode=mode, align_corners=False,
padding_mode='border'
)
adapted_pos_embed = interpolated_embed.squeeze(0).squeeze(1).permute(1, 0)
return adapted_pos_embed.to(pos_embed_weight.dtype)
def forward(self, hidden_states: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
"""
Args:
hidden_states (torch.Tensor): input tensor of shape [seq_len, num_channels*patch_size*patch_size]
            grid_hw (torch.Tensor): tensor of shape [num_images, 2] holding each image's patch-grid height and width
Returns:
torch.Tensor: output tensor of shape [seq_len, embed_dim]
"""
target_dtype = self.proj.weight.dtype
hidden_states = self.proj(hidden_states.to(dtype=target_dtype))
if self.preserve_original_pe:
pos_emb = self.abs_pos_embed(grid_hw)
hidden_states = hidden_states + pos_emb
return hidden_states
class Siglip2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Siglip2Attention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = q.reshape(seq_length, self.num_heads, -1)
k = k.reshape(seq_length, self.num_heads, -1)
v = v.reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.full(
[1, seq_length, seq_length],
torch.finfo(q.dtype).min,
device=q.device,
dtype=q.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.out_proj(attn_output)
return attn_output
class Siglip2FlashAttention2(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
        # Reshape q, k, v into (seq_len, num_heads, head_dim) for multi-head attention
q = q.reshape(seq_length, self.num_heads, -1)
k = k.reshape(seq_length, self.num_heads, -1)
v = v.reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
attn_output = flash_attn_varlen_func(
q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
).reshape(seq_length, -1)
attn_output = self.out_proj(attn_output)
return attn_output
class Siglip2SdpaAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = q.reshape(seq_length, self.num_heads, -1)
k = k.reshape(seq_length, self.num_heads, -1)
v = v.reshape(seq_length, self.num_heads, -1)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
attention_mask = torch.zeros(
[1, seq_length, seq_length], device=q.device, dtype=torch.bool
)
for i in range(1, len(cu_seqlens)):
attention_mask[
...,
cu_seqlens[i - 1] : cu_seqlens[i],
cu_seqlens[i - 1] : cu_seqlens[i],
] = True
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = F.scaled_dot_product_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
)
attn_output = attn_output.squeeze(0).transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.out_proj(attn_output)
return attn_output
VISION_ATTENTION_CLASSES = {
"eager": Siglip2Attention,
"flash_attention_2": Siglip2FlashAttention2,
"sdpa": Siglip2SdpaAttention,
}
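# All three attention variants consume packed inputs: `hidden_states` concatenates the
# patches of every image, and `cu_seqlens` holds the cumulative patch counts marking
# image boundaries. For example, two images with patch grids (4, 4) and (8, 6) give
# 16 + 48 = 64 packed tokens and cu_seqlens = [0, 16, 64], so attention is
# block-diagonal and never crosses image boundaries. The eager and SDPA paths
# materialize that block-diagonal mask explicitly; the flash-attention path hands
# `cu_seqlens` straight to `flash_attn_varlen_func` (CUDA with fp16/bf16 inputs only).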
class Siglip2EncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](
config=config
)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].
Args:
config: Siglip2Config
"""
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
        self.gradient_checkpointing = True  # enabled by default; applied only while the module is in training mode
# Ignore copy
def forward(
self,
hidden_states,
cu_seqlens,
rotary_pos_emb,
):
for encoder_layer in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = torch.utils.checkpoint.checkpoint(
encoder_layer,
hidden_states,
cu_seqlens,
rotary_pos_emb,
use_reentrant=False,
)
else:
hidden_states = encoder_layer(
hidden_states,
cu_seqlens,
rotary_pos_emb,
)
return hidden_states
class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = PatchEmbed(
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=embed_dim,
num_patches=config.num_patches,
preserve_original_pe=config.preserve_original_pe,
)
head_dim = config.hidden_size // config.num_attention_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2, config.rope_theta)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def rot_pos_emb(self, grid_hw):
pos_ids = []
for h, w in grid_hw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // 2,
2,
w // 2,
2,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // 2,
2,
w // 2,
2,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_hw.max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
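    # `rotary_pos_emb_full` has shape (max_grid_size, head_dim // 4). Indexing it with
    # `pos_ids` of shape (total_patches, 2) and flattening yields a
    # (total_patches, head_dim // 2) table whose first half encodes the patch row and
    # whose second half encodes the patch column; `apply_rotary_pos_emb_vision`
    # then duplicates it across the full head dimension.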
def forward(
self,
hidden_states: torch.Tensor,
grid_hw: torch.Tensor,
):
hidden_states = self.embeddings(hidden_states, grid_hw)
rotary_pos_emb = self.rot_pos_emb(grid_hw)
cu_seqlens = (grid_hw[:, 0] * grid_hw[:, 1]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
hidden_states = self.encoder(
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
)
hidden_states = self.post_layernorm(hidden_states)
return hidden_states
class Siglip2VisionModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
config_class = Siglip2VisionConfig
main_input_name = "pixel_values"
def __init__(self, config):
super().__init__(config)
self.vision_model = Siglip2VisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.proj  # PatchEmbed exposes its patch projection as `proj`
def forward(
self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
) -> torch.Tensor:
return self.vision_model(hidden_states, grid_hw)
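

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original checkpoint code). It assumes
    # Siglip2VisionConfig is a standard PretrainedConfig subclass accepting the keyword
    # arguments below; the values are illustrative and may need to be adjusted to the
    # actual configuration_siglip2_navit_rope.py definition.
    config = Siglip2VisionConfig(
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=2,
        num_attention_heads=12,
        num_channels=3,
        patch_size=16,
        num_patches=256,  # 16x16 original grid; must be a perfect square
        preserve_original_pe=True,
        rope_theta=10000.0,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
    )
    config._attn_implementation = "eager"  # or "sdpa" / "flash_attention_2"

    model = Siglip2VisionModel(config).eval()

    # Two images with patch grids (4, 4) and (8, 6): 16 + 48 = 64 packed patches,
    # each flattened to num_channels * patch_size * patch_size pixel values.
    grid_hw = torch.tensor([[4, 4], [8, 6]])
    pixel_patches = torch.randn(64, 3 * 16 * 16)

    with torch.no_grad():
        features = model(pixel_patches, grid_hw)
    print(features.shape)  # expected: torch.Size([64, 768])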