import torch
import torch.nn as nn
import torch.utils.checkpoint
from typing import Optional, Tuple
class Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, hidden_size, num_attention_heads, attention_head_dim, attention_dropout=0.0):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = attention_head_dim
self.scale = self.head_dim**-0.5
self.dropout = attention_dropout
self.inner_dim = self.head_dim * self.num_heads
self.k_proj = nn.Linear(self.embed_dim, self.inner_dim)
self.v_proj = nn.Linear(self.embed_dim, self.inner_dim)
self.q_proj = nn.Linear(self.embed_dim, self.inner_dim)
self.out_proj = nn.Linear(self.inner_dim, self.embed_dim)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
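        # raw attention scores: (bsz * num_heads, tgt_len, src_len); the queries were pre-scaled by 1/sqrt(head_dim)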
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
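        # normalize the (optionally masked) scores into attention probabilities over the source sequence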
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights has to be reshaped
            # twice and reused in the following computation
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
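        # merge heads back together: (bsz * num_heads, tgt_len, head_dim) -> (bsz, tgt_len, inner_dim)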
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
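# Simple feed-forward block: Linear -> SiLU -> Linear, expanding to intermediate_size * mult before projecting back.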
class MLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mult=4):
super().__init__()
self.activation_fn = nn.SiLU()
self.fc1 = nn.Linear(hidden_size, intermediate_size * mult)
self.fc2 = nn.Linear(intermediate_size * mult, hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
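# A plain stack of pre-LayerNorm TransformerBlocks; each block returns the updated hidden states.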
class Transformer(nn.Module):
def __init__(self, depth=12):
super().__init__()
self.layers = nn.ModuleList([TransformerBlock() for _ in range(depth)])
def forward(
self,
hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.FloatTensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
for layer in self.layers:
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
return hidden_states
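# Pre-LayerNorm transformer block: residual self-attention followed by a residual SiLU MLP.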
class TransformerBlock(nn.Module):
def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
super().__init__()
self.embed_dim = hidden_size
        self.self_attn = Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, attention_dropout=attention_dropout)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.FloatTensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs[0]
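# Same pre-LayerNorm block as above, but a learned "output token" is prepended to the sequence
# and only that token's hidden state is returned, acting as a pooled summary of the inputs.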
class DiffusionTransformerBlock(nn.Module):
def __init__(self, hidden_size=512, num_attention_heads=12, attention_head_dim=64, attention_dropout=0.0, dropout=0.0, eps=1e-5):
super().__init__()
self.embed_dim = hidden_size
        self.self_attn = Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, attention_dropout=attention_dropout)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
self.mlp = MLP(hidden_size=hidden_size, intermediate_size=hidden_size)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)
self.output_token = nn.Parameter(torch.randn(1, hidden_size))
def forward(
self,
hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.FloatTensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
output_token = self.output_token.unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)
hidden_states = torch.cat([output_token, hidden_states], dim=1)
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
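        # keep only the prepended output token's hidden state: (batch, 1, hidden_size)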
return outputs[0][:,0:1,...]
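# Expansion MLP used as the V2A (presumably vision/video-to-audio) mapper: Linear -> SiLU -> LayerNorm -> Linear.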
class V2AMapperMLP(nn.Module):
def __init__(self, input_dim=512, output_dim=512, expansion_rate=4):
super().__init__()
self.linear = nn.Linear(input_dim, input_dim * expansion_rate)
self.silu = nn.SiLU()
self.layer_norm = nn.LayerNorm(input_dim * expansion_rate)
self.linear2 = nn.Linear(input_dim * expansion_rate, output_dim)
def forward(self, x):
x = self.linear(x)
x = self.silu(x)
x = self.layer_norm(x)
x = self.linear2(x)
return x
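# Projects a single image embedding into `clip_extra_context_tokens` context tokens for cross-attention
# (IP-Adapter-style projection); the final Linear is zero-initialized, so the tokens start out as zeros.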
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
self.zero_initialize_last_layer()
    def zero_initialize_last_layer(self):
        # zero-initialize the final Linear layer so the projection initially outputs zeros
        last_layer = None
        for _, layer in self.named_modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer
        if last_layer is not None:
            last_layer.weight.data.zero_()
            last_layer.bias.data.zero_()
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
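# Combines the V2A mapper with the token projection: maps a vision embedding of size `embedding_size`
# and expands it into `token_num` context tokens of the same size.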
class VisionAudioAdapter(torch.nn.Module):
def __init__(
self,
embedding_size=768,
expand_dim=4,
token_num=4,
):
super().__init__()
self.mapper = V2AMapperMLP(
embedding_size,
embedding_size,
expansion_rate=expand_dim,
)
self.proj = ImageProjModel(
cross_attention_dim=embedding_size,
clip_embeddings_dim=embedding_size,
clip_extra_context_tokens=token_num,
)
def forward(self, image_embeds):
image_embeds = self.mapper(image_embeds)
image_embeds = self.proj(image_embeds)
return image_embeds
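if __name__ == "__main__":
    # Minimal smoke-test sketch. The shapes below are illustrative assumptions,
    # not values taken from any particular training configuration.
    adapter = VisionAudioAdapter(embedding_size=768, expand_dim=4, token_num=4)
    clip_embeds = torch.randn(2, 768)       # e.g. a batch of 2 pooled CLIP-style embeddings
    context_tokens = adapter(clip_embeds)
    print(context_tokens.shape)             # torch.Size([2, 4, 768])

    transformer = Transformer(depth=2)
    hidden = torch.randn(2, 16, 512)        # (batch, seq_len, hidden_size) with the default hidden_size=512
    print(transformer(hidden).shape)        # torch.Size([2, 16, 512])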