import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1

from models.positional_embeddings import PositionalEmbedding, FourierEmbedding

torch.fx.wrap("rearrange")
allow_ops_in_compiled_graph()


class FusedMLP(nn.Sequential):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear."""

    def __init__(
        self,
        dim_model: int,
        dropout: float,
        activation: Type[nn.Module],
        hidden_layer_multiplier: int = 4,
        bias: bool = True,
    ):
        super().__init__(
            nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias),
            activation(),
            nn.Dropout(dropout),
            nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias),
        )


def _cast_if_autocast_enabled(tensor):
    """Cast `tensor` to the active autocast dtype when autocast is enabled."""
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LayerNorm16Bits(torch.nn.LayerNorm):
    """
    16-bit friendly version of torch.nn.LayerNorm
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias)
            if self.bias is not None
            else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )


class StochasticDepth(nn.Module):
    """Randomly drops the residual branch for entire samples during training."""

    def __init__(self, p: float):
        super().__init__()
        self.survival_prob = 1.0 - p

    def forward(self, x: Tensor) -> Tensor:
        if self.training and self.survival_prob < 1:
            # Per-sample Bernoulli mask: equals 1 with probability survival_prob.
            mask = (
                torch.empty(x.shape[0], 1, 1, device=x.device).uniform_()
                + self.survival_prob
            )
            mask = mask.floor()
            if self.survival_prob > 0:
                # Rescale so the expected output matches evaluation mode.
                mask = mask / self.survival_prob
            return x * mask
        else:
            return x


# Backward-compatible alias for the original (misspelled) class name.
StochatichDepth = StochasticDepth


class CrossAttentionOp(nn.Module):
    """Multi-head attention, usable as self-attention (is_sa=True) or cross-attention."""

    def __init__(
        self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
    ):
        super().__init__()
        self.dim_q = dim_q
        self.dim_kv = dim_kv
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.use_biases = use_biases
        self.is_sa = is_sa
        if self.is_sa:
            self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
        else:
            self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
            self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
        self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)

    def forward(self, x_to, x_from=None, attention_mask=None, materialize_sdpa=False):
        if x_from is None:
            x_from = x_to
        if self.is_sa:
            q, k, v = self.qkv(x_to).chunk(3, dim=-1)
        else:
            q = self.q(x_to)
            k, v = self.kv(x_from).chunk(2, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1)
        if materialize_sdpa:
            # Explicitly materializes the attention matrix (e.g. to inspect scores).
            x = self.materialize_sdpa(q, k, v, attention_mask)
        else:
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask
            )
        x = rearrange(x, "b h n d -> b n (h d)")
        x = self.out(x)
        return x

    def materialize_sdpa(self, q, k, v, attn_mask=None):
        scale = 1.0 / math.sqrt(q.shape[-1])
        attn_matrix = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
        if attn_mask is not None:
            # Follow the boolean-mask semantics of scaled_dot_product_attention:
            # positions where the mask is False receive no attention weight.
            attn_matrix = attn_matrix.masked_fill(~attn_mask.bool(), float("-inf"))
        attn_matrix = torch.nn.functional.softmax(attn_matrix, dim=-1)
        return torch.einsum("b h i j, b h j d -> b h i d", attn_matrix, v)


class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        retrieve_attention_scores: bool = False,
        use_16_bits_layer_norm: bool = False,
    ):
        super().__init__()
        if use_16_bits_layer_norm and not retrieve_attention_scores:
            LayerNorm = LayerNorm16Bits
        else:
            LayerNorm = nn.LayerNorm
        self.retrieve_attention_scores = retrieve_attention_scores
        self.initial_to_ln = LayerNorm(dim_q, eps=1e-6)
        attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim
        self.ca = CrossAttentionOp(
            attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
        )
        self.ca_stochastic_depth = StochasticDepth(stochastic_depth)
        self.middle_ln = LayerNorm(dim_q, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_q,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
        # Non-trainable all-True mask used to build default token masks; registered
        # so it follows the module across devices.
        self.register_parameter(
            "attention_mask_dummy",
            nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False),
        )

    def forward(
        self,
        to_tokens: Tensor,
        from_tokens: Tensor,
        to_token_mask: Optional[Tensor] = None,
        from_token_mask: Optional[Tensor] = None,
    ) -> Tensor:
        if to_token_mask is None and from_token_mask is None:
            attention_mask = None
        else:
            if to_token_mask is None:
                to_token_mask = self.attention_mask_dummy.expand(
                    to_tokens.shape[0],
                    to_tokens.shape[1],
                )
            if from_token_mask is None:
                from_token_mask = self.attention_mask_dummy.expand(
                    from_tokens.shape[0],
                    from_tokens.shape[1],
                )
            attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2)
        if self.retrieve_attention_scores:
            attention_output = self.ca(
                self.initial_to_ln(to_tokens),
                from_tokens,
                attention_mask=attention_mask,
                materialize_sdpa=True,
            )
        else:
            attention_output = self.ca(
                self.initial_to_ln(to_tokens),
                from_tokens,
                attention_mask=attention_mask,
            )
        to_tokens = to_tokens + self.ca_stochastic_depth(attention_output)
        to_tokens = to_tokens + self.ffn_stochastic_depth(
            self.ffn(self.middle_ln(to_tokens))
        )
        return to_tokens


class SelfAttentionBlock(nn.Module):
    def __init__(
        self,
        dim_qkv: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        use_layer_scale: bool = False,
        layer_scale_value: float = 0.1,
        retrieve_attention_scores: bool = False,
        use_16_bits_layer_norm: bool = False,
    ):
        super().__init__()
        if use_16_bits_layer_norm and not retrieve_attention_scores:
            LayerNorm = LayerNorm16Bits
        else:
            LayerNorm = nn.LayerNorm
        self.retrieve_attention_scores = retrieve_attention_scores
        self.initial_ln = LayerNorm(dim_qkv, eps=1e-6)
        attention_dim = dim_qkv if attention_dim == 0 else attention_dim
        self.sa = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_qkv,
            is_sa=True,
            use_biases=use_biases,
        )
        self.sa_stochastic_depth = StochasticDepth(stochastic_depth)
        self.middle_ln = LayerNorm(dim_qkv, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_qkv,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochasticDepth(stochastic_depth)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
        self.register_parameter(
            "attention_mask_dummy",
            nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False),
        )

    def forward(
        self,
        tokens: torch.Tensor,
        token_mask: Optional[torch.Tensor] = None,
    ):
        if token_mask is None:
            attention_mask = None
        else:
            attention_mask = token_mask.unsqueeze(1) * self.attention_mask_dummy.expand(
                tokens.shape[0],
                tokens.shape[1],
            ).unsqueeze(2)
        if self.retrieve_attention_scores:
            attention_output = self.sa(
                self.initial_ln(tokens),
                attention_mask=attention_mask,
                materialize_sdpa=True,
            )
        else:
            attention_output = self.sa(
                self.initial_ln(tokens),
                attention_mask=attention_mask,
            )
        if self.use_layer_scale:
            tokens = tokens + self.sa_stochastic_depth(
                self.layer_scale_1 * attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
            )
        else:
            tokens = tokens + self.sa_stochastic_depth(attention_output)
            tokens = tokens + self.ffn_stochastic_depth(
                self.ffn(self.middle_ln(tokens))
            )
        return tokens
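

if __name__ == "__main__":
    # Minimal usage sketch: illustrates the tensor shapes and boolean-mask
    # convention (True = valid token, False = padding) that SelfAttentionBlock
    # and CrossAttentionBlock expect. The sizes below are arbitrary example
    # values, not anything prescribed by the model configuration.
    batch, n_tokens, n_context, dim, dim_ctx = 2, 16, 8, 64, 32

    tokens = torch.randn(batch, n_tokens, dim)
    context = torch.randn(batch, n_context, dim_ctx)
    token_mask = torch.ones(batch, n_tokens, dtype=torch.bool)
    context_mask = torch.ones(batch, n_context, dtype=torch.bool)
    context_mask[:, -2:] = False  # pretend the last two context tokens are padding

    sa_block = SelfAttentionBlock(dim_qkv=dim, num_heads=4, stochastic_depth=0.1)
    ca_block = CrossAttentionBlock(dim_q=dim, dim_kv=dim_ctx, num_heads=4)

    # Self-attention over the token sequence, then cross-attention to the context.
    tokens = sa_block(tokens, token_mask=token_mask)
    tokens = ca_block(
        tokens,
        context,
        to_token_mask=token_mask,
        from_token_mask=context_mask,
    )
    assert tokens.shape == (batch, n_tokens, dim)
    print("output shape:", tuple(tokens.shape))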