# RAG-ColPali/models/siglip.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional


@dataclass
class SigLIPConfig:
hidden_size: int = 1152
intermediate_size: int = 4304
num_attention_heads: int = 16
num_hidden_layers: int = 27
num_image_tokens: int = 256
patch_size: int = 14
projection_dim: int = 2048
n_channels: int = 3
img_size: int = 224
norm_eps: float = 1e-6
attention_dropout: float = 0.0
@classmethod
def from_dict(cls, data):
return cls(
hidden_size = data['hidden_size'],
intermediate_size = data['intermediate_size'],
num_attention_heads = data['num_attention_heads'],
num_hidden_layers = data['num_hidden_layers'],
num_image_tokens = data['num_image_tokens'],
patch_size = data['patch_size'],
projection_dim = data['projection_dim']
)
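

# Illustrative usage (a sketch; the dict keys mirror the dataclass fields above).
# from_dict reads only the seven keys shown; the remaining fields (n_channels,
# img_size, norm_eps, attention_dropout) keep their dataclass defaults.
#
#   cfg = SigLIPConfig()          # all defaults
#   cfg = SigLIPConfig.from_dict({
#       'hidden_size': 1152, 'intermediate_size': 4304,
#       'num_attention_heads': 16, 'num_hidden_layers': 27,
#       'num_image_tokens': 256, 'patch_size': 14, 'projection_dim': 2048,
#   })

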
class SigLIPEmbedding(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
        self.patch_embedding = nn.Conv2d(
            cfg.n_channels,
            cfg.hidden_size,
            kernel_size=cfg.patch_size,
            stride=cfg.patch_size,
            padding='valid',
        )
self.num_patches = (cfg.img_size // cfg.patch_size) ** 2
self.position_embedding = nn.Embedding(cfg.num_image_tokens, cfg.hidden_size)
self.register_buffer('position_ids',
torch.arange(cfg.num_image_tokens).expand(1, -1),
persistent=False)
def forward(self, x: torch.FloatTensor):
        # x: (n, n_channels, h, w) -> (n, hidden_size, n_patches_h, n_patches_w)
        img_embeds = self.patch_embedding(x)
        # (n, hidden_size, n_patches_h, n_patches_w) -> (n, hidden_size, num_patches) -> (n, num_patches, hidden_size)
img_embeds = img_embeds.reshape(*img_embeds.shape[:2], -1).transpose(1, 2)
return img_embeds + self.position_embedding(self.position_ids.to(torch.int64))
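

# Shape sketch with the default config (img_size=224, patch_size=14):
#   (n, 3, 224, 224) --patch_embedding--> (n, 1152, 16, 16) --flatten/transpose--> (n, 256, 1152)
# i.e. num_patches = (224 // 14) ** 2 = 256, which matches num_image_tokens used
# for the position-embedding table.

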
class SigLIPTransformerAttention(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
self.cfg = cfg
self.num_attention_heads = cfg.num_attention_heads
self.head_dim = cfg.hidden_size // self.num_attention_heads
self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
self.k_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
self.v_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
self.out_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
self.dropout_p = self.cfg.attention_dropout
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
batch_size, num_patches, _ = x.shape
xq = self.q_proj(x)
xk = self.k_proj(x)
xv = self.v_proj(x)
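        # split the hidden dimension into attention heads:
        # (batch_size, num_patches, hidden_size) -> (batch_size, num_attention_heads, num_patches, head_dim)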
xq = xq.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
xk = xk.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
        # Reference: equivalent manual attention (would also require `import math`):
        # attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
        # attn_output = torch.matmul(attn_weights, xv)
        # attn_output = attn_output.transpose(1, 2).contiguous()
        # attn_output = attn_output.view(batch_size, num_patches, -1)
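        # F.scaled_dot_product_attention applies the 1/sqrt(head_dim) scaling and
        # the softmax internally; attn_mask=None means full bidirectional attention.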
attn_output = torch.nn.functional.scaled_dot_product_attention(
query=xq,
key=xk,
value=xv,
attn_mask=attention_mask,
dropout_p=self.dropout_p,
is_causal=False
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, num_patches, -1)
attn_output = self.out_proj(attn_output)
return attn_output, None


class SigLIPTransformerMLP(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
self.cfg = cfg
self.fc1 = nn.Linear(cfg.hidden_size, cfg.intermediate_size)
self.fc2 = nn.Linear(cfg.intermediate_size, cfg.hidden_size)
def forward(self, x: torch.Tensor):
x = self.fc1(x)
x = F.gelu(x, approximate='tanh')
x = self.fc2(x)
return x


class SigLIPTransformerBlock(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
self.layer_norm1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
self.layer_norm2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
self.self_attn = SigLIPTransformerAttention(cfg)
self.mlp = SigLIPTransformerMLP(cfg)
def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
residual = x
x = self.layer_norm1(x)
x = residual + self.self_attn(x, attention_mask)[0]
residual = x
x = self.layer_norm2(x)
x = residual + self.mlp(x)
return x
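

# The block follows the pre-norm layout: LayerNorm -> self-attention -> residual
# add, then LayerNorm -> MLP -> residual add.

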
class SigLIPTransformerEncoder(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
self.cfg = cfg
self.layers = nn.ModuleList(
[SigLIPTransformerBlock(cfg) for _ in range(cfg.num_hidden_layers)]
)
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
for layer in self.layers:
x = layer(x, attention_mask)
return x


class SigLIPModel(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
self.embeddings = SigLIPEmbedding(cfg)
self.encoder = SigLIPTransformerEncoder(cfg)
self.post_layernorm = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
def forward(self, x: torch.Tensor):
img_embed = self.embeddings(x)
output = self.encoder(img_embed)
output = self.post_layernorm(output)
return output


class SigLIPVisionTower(nn.Module):
def __init__(self, cfg: SigLIPConfig):
super().__init__()
self.cfg = cfg
self.vision_model = SigLIPModel(cfg)
def forward(self, x: torch.Tensor):
return self.vision_model(x)
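

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the default config (224x224 input,
    # patch size 14, hidden size 1152). Expected output shape: (1, 256, 1152).
    cfg = SigLIPConfig()
    tower = SigLIPVisionTower(cfg)
    dummy = torch.randn(1, cfg.n_channels, cfg.img_size, cfg.img_size)
    with torch.no_grad():
        features = tower(dummy)
    print(features.shape)  # torch.Size([1, 256, 1152])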