| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.models.modeling_utils import ModelMixin |
| |
|
| |
|
| | def rotate_half(x): |
| | x1 = x[..., : x.shape[-1] // 2] |
| | x2 = x[..., x.shape[-1] // 2 :] |
| | return torch.cat((-x2, x1), dim=-1) |
| |
|
| |
|
| | def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): |
| | cos = cos.unsqueeze(unsqueeze_dim) |
| | sin = sin.unsqueeze(unsqueeze_dim) |
| | return (x * cos) + (rotate_half(x) * sin) |
| |
|
| |
|
| | class RotaryEmbedding(nn.Module): |
| | def __init__(self, head_dim): |
| | super().__init__() |
| | self.rope_theta = 10000 |
| | inv_freq = 1.0 / ( |
| | self.rope_theta |
| | ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim) |
| | ) |
| | self.register_buffer("inv_freq", inv_freq, persistent=False) |
| |
|
| | @torch.no_grad() |
| | def forward(self, x, position_ids): |
| | inv_freq_expanded = ( |
| | self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) |
| | ) |
| | position_ids_expanded = position_ids[:, None, :].float() |
| |
|
| | device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| | with torch.autocast(device_type=device_type, enabled=False): |
| | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) |
| | emb = torch.cat((freqs, freqs), dim=-1) |
| | cos = emb.cos() |
| | sin = emb.sin() |
| | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
| |
|
| |
|
| | class Attention(nn.Module): |
| | def __init__(self, query_dim, context_dim, n_heads, head_dim): |
| | super().__init__() |
| | inner_dim = head_dim * n_heads |
| | self.n_heads = n_heads |
| | self.head_dim = head_dim |
| |
|
| | self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) |
| | self.q_norm = nn.RMSNorm(head_dim, eps=1e-6) |
| | self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) |
| | self.k_norm = nn.RMSNorm(head_dim, eps=1e-6) |
| | self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) |
| | self.o_proj = nn.Linear(inner_dim, query_dim, bias=False) |
| |
|
| | def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): |
| | context = x if context is None else context |
| | input_shape = x.shape[:-1] |
| | q_shape = (*input_shape, self.n_heads, self.head_dim) |
| | context_shape = context.shape[:-1] |
| | kv_shape = (*context_shape, self.n_heads, self.head_dim) |
| |
|
| | query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) |
| | key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) |
| | value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) |
| |
|
| | if position_embeddings is not None: |
| | assert position_embeddings_context is not None |
| | cos, sin = position_embeddings |
| | query_states = apply_rotary_pos_emb(query_states, cos, sin) |
| | cos, sin = position_embeddings_context |
| | key_states = apply_rotary_pos_emb(key_states, cos, sin) |
| |
|
| | attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) |
| | attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() |
| | return self.o_proj(attn_output) |
| |
|
| |
|
| | class TransformerBlock(nn.Module): |
| | def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=True): |
| | super().__init__() |
| | self.use_self_attn = use_self_attn |
| |
|
| | if self.use_self_attn: |
| | self.norm_self_attn = nn.RMSNorm(model_dim, eps=1e-6) |
| | self.self_attn = Attention( |
| | query_dim=model_dim, |
| | context_dim=model_dim, |
| | n_heads=num_heads, |
| | head_dim=model_dim // num_heads, |
| | ) |
| |
|
| | self.norm_cross_attn = nn.RMSNorm(model_dim, eps=1e-6) |
| | self.cross_attn = Attention( |
| | query_dim=model_dim, |
| | context_dim=source_dim, |
| | n_heads=num_heads, |
| | head_dim=model_dim // num_heads, |
| | ) |
| |
|
| | self.norm_mlp = nn.RMSNorm(model_dim, eps=1e-6) |
| | self.mlp = nn.Sequential( |
| | nn.Linear(model_dim, int(model_dim * mlp_ratio)), |
| | nn.GELU(), |
| | nn.Linear(int(model_dim * mlp_ratio), model_dim), |
| | ) |
| |
|
| | def forward( |
| | self, |
| | x, |
| | context, |
| | target_attention_mask=None, |
| | source_attention_mask=None, |
| | position_embeddings=None, |
| | position_embeddings_context=None, |
| | ): |
| | if self.use_self_attn: |
| | normed = self.norm_self_attn(x) |
| | attn_out = self.self_attn( |
| | normed, |
| | mask=target_attention_mask, |
| | position_embeddings=position_embeddings, |
| | position_embeddings_context=position_embeddings, |
| | ) |
| | x = x + attn_out |
| |
|
| | normed = self.norm_cross_attn(x) |
| | attn_out = self.cross_attn( |
| | normed, |
| | mask=source_attention_mask, |
| | context=context, |
| | position_embeddings=position_embeddings, |
| | position_embeddings_context=position_embeddings_context, |
| | ) |
| | x = x + attn_out |
| | x = x + self.mlp(self.norm_mlp(x)) |
| | return x |
| |
|
| |
|
| | class AnimaLLMAdapter(ModelMixin, ConfigMixin): |
| | @register_to_config |
| | def __init__( |
| | self, |
| | source_dim: int = 1024, |
| | target_dim: int = 1024, |
| | model_dim: int = 1024, |
| | num_layers: int = 6, |
| | num_heads: int = 16, |
| | mlp_ratio: float = 4.0, |
| | vocab_size: int = 32128, |
| | use_self_attn: bool = True, |
| | ): |
| | super().__init__() |
| |
|
| | self.embed = nn.Embedding(vocab_size, target_dim) |
| | if model_dim != target_dim: |
| | self.in_proj = nn.Linear(target_dim, model_dim) |
| | else: |
| | self.in_proj = nn.Identity() |
| | self.rotary_emb = RotaryEmbedding(model_dim // num_heads) |
| | self.blocks = nn.ModuleList( |
| | [ |
| | TransformerBlock( |
| | source_dim, |
| | model_dim, |
| | num_heads=num_heads, |
| | mlp_ratio=mlp_ratio, |
| | use_self_attn=use_self_attn, |
| | ) |
| | for _ in range(num_layers) |
| | ] |
| | ) |
| | self.out_proj = nn.Linear(model_dim, target_dim) |
| | self.norm = nn.RMSNorm(target_dim, eps=1e-6) |
| |
|
| | def forward( |
| | self, |
| | source_hidden_states: torch.Tensor, |
| | target_input_ids: torch.Tensor, |
| | target_attention_mask: torch.Tensor = None, |
| | source_attention_mask: torch.Tensor = None, |
| | ) -> torch.Tensor: |
| | if target_attention_mask is not None: |
| | target_attention_mask = target_attention_mask.to(torch.bool) |
| | if target_attention_mask.ndim == 2: |
| | target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) |
| |
|
| | if source_attention_mask is not None: |
| | source_attention_mask = source_attention_mask.to(torch.bool) |
| | if source_attention_mask.ndim == 2: |
| | source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) |
| |
|
| | x = self.in_proj(self.embed(target_input_ids)) |
| | context = source_hidden_states |
| |
|
| | position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) |
| | position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) |
| | position_embeddings = self.rotary_emb(x, position_ids) |
| | position_embeddings_context = self.rotary_emb(x, position_ids_context) |
| |
|
| | for block in self.blocks: |
| | x = block( |
| | x, |
| | context, |
| | target_attention_mask=target_attention_mask, |
| | source_attention_mask=source_attention_mask, |
| | position_embeddings=position_embeddings, |
| | position_embeddings_context=position_embeddings_context, |
| | ) |
| |
|
| | return self.norm(self.out_proj(x)) |
| |
|