import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_

from transformers.activations import ACT2FN
|
|
class FFN(nn.Module):
    """
    Feed-Forward Network module.

    Args:
        embed_dim (int): Input embedding dimension.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
    """

    def __init__(self, embed_dim, ff_dim, output_dim):
        super().__init__()
        self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False)
        self.linear_out = nn.Linear(ff_dim, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_states = self.act(self.linear_in(hidden_states))
        hidden_states = self.linear_out(hidden_states)
        return hidden_states
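
# A minimal usage sketch for FFN (illustrative only; the dimensions below are
# assumptions, not values tied to any released checkpoint):
#
#   ffn = FFN(embed_dim=64, ff_dim=128, output_dim=96)
#   hidden = torch.randn(2, 16, 64)   # (batch_size, seq_len, embed_dim)
#   out = ffn(hidden)                 # -> (2, 16, 96)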
|
|
class CrossAttention(nn.Module):
    """
    Cross-Attention module.

    Args:
        kv_dim (int): Dimension of key and value.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        drop_out_rate (float): Dropout rate. Default is 0.
    """

    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_out_rate)

        self.layer_norm = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(kv_dim)

    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
        """
        Forward pass of the CrossAttention module.

        Args:
            x (torch.Tensor): Input tensor for key and value.
            hidden_states (torch.Tensor): Input tensor for query.
            attn_mask (torch.Tensor, optional): Attention mask. Default is None.
            add_residual (bool): Whether to add a residual connection. Default is False.

        Returns:
            torch.Tensor: Output tensor after cross-attention.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        # nn.MultiheadAttention expects (seq_len, batch_size, embed_dim), so
        # permute from batch-first to sequence-first layout.
        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

        x = self.ln_kv(x)
        key = self.k_proj(x).permute(1, 0, 2)
        value = self.v_proj(x).permute(1, 0, 2)

        attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)

        # Back to batch-first layout: (batch_size, num_queries, embed_dim).
        attn_output = attn_output.permute(1, 0, 2)

        if add_residual:
            attn_output = hidden_states + self.dropout(self.linear(attn_output))
        else:
            attn_output = self.dropout(self.linear(attn_output))

        return attn_output
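
# A minimal usage sketch for CrossAttention (illustrative only; all dimensions
# are assumptions): a set of queries attends over image patch features.
#
#   cross_attn = CrossAttention(kv_dim=32, embed_dim=64, num_heads=4)
#   features = torch.randn(2, 1225, 32)   # (batch_size, num_patches, kv_dim)
#   queries = torch.randn(2, 128, 64)     # (batch_size, num_queries, embed_dim)
#   out = cross_attn(features, queries)   # -> (2, 128, 64)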
|
|
class AriaProjector(nn.Module):
    """
    A projection module with one cross-attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs.

    Args:
        patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers,
            e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        kv_dim (int): Dimension of key and value.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
        norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim).
    """

    def __init__(
        self,
        patch_to_query_dict,
        embed_dim,
        num_heads,
        kv_dim,
        ff_dim,
        output_dim,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.patch_to_query_dict = patch_to_query_dict
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Learnable query bank, sized for the largest supported query count;
        # a prefix of it is selected at forward time based on the patch count.
        self.query = nn.Parameter(
            torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)
        )

        trunc_normal_(self.query, std=0.02)

        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)

        self.ln_ffn = norm_layer(embed_dim)
        self.ffn = FFN(embed_dim, ff_dim, output_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        """
        Forward pass of the Projector module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim).
            attn_mask (torch.Tensor, optional): Attention mask. Default is None.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim).
        """
        bs = x.shape[0]
        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

        # Pick the query count that matches the number of input patches.
        query_num = self.patch_to_query_dict.get(x.shape[1], None)
        assert (
            query_num is not None
        ), f"Query number for {x.shape[1]} patches is not provided"

        queries = queries[:, :query_num, :]

        if attn_mask is not None:
            # Expand the per-patch mask to the (batch_size * num_heads, query_num,
            # num_patches) shape expected by nn.MultiheadAttention.
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

        out = self.ffn(self.ln_ffn(attention_out))

        return out
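
if __name__ == "__main__":
    # Quick smoke test (not part of the original module). The dimensions below
    # are illustrative assumptions, not values tied to any released checkpoint:
    # a 35x35 patch grid (1225 patches) is mapped to 128 queries.
    projector = AriaProjector(
        patch_to_query_dict={1225: 128, 4900: 256},
        embed_dim=64,
        num_heads=4,
        kv_dim=32,
        ff_dim=128,
        output_dim=96,
    )
    vit_features = torch.randn(2, 1225, 32)  # (batch_size, num_patches, kv_dim)
    projected = projector(vit_features)
    print(projected.shape)  # expected: torch.Size([2, 128, 96])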
|
|