# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.general.utils import Conv1d, normalization, zero_module

from .basic import UNetBlock


class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference from `latent diffusion repo `_.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head.
        num_heads: Number of attention heads. Overrides ``num_head_channels`` if set.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, then self-attention is performed.
        use_self_attention: Whether to use self-attention before cross-attention;
            only applicable if ``encoder_channels`` is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: Height dimension of the input; only used if ``dims`` is 2.
        encoder_hdim: Height dimension of the encoder output; only used if ``dims`` is 2.
        p_dropout: Dropout probability.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.dims = dims

        if dims == 1:
            self.channels = channels
        elif dims == 2:
            # We treat the channel dimension as the product of channel and height,
            # i.e. C x H, because attention is applied along the 1D time axis of
            # the audio signal.
            self.channels = channels * h_dim
        else:
            raise ValueError(f"invalid number of dimensions: {dims}")

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if encoder_channels is not None:
            self.use_self_attention = use_self_attention

            if dims == 1:
                self.encoder_channels = encoder_channels
            elif dims == 2:
                self.encoder_channels = encoder_channels * encoder_hdim
            else:
                raise ValueError(f"invalid number of dimensions: {dims}")

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...];
                if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
        """
        shape = x.size()
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)
        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        return h.reshape(*shape).contiguous()


class BasicAttentionBlock(nn.Module):
    r"""Multi-head self- or cross-attention over 1D sequences, with residual
    connections and a zero-initialized output projection."""

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # Pointwise projection of the attention output. A kernel size of 1 is
        # assumed here, consistent with the other 1x1 convolutions in this block;
        # nn.Conv1d requires an explicit kernel size.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T];
                if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()

        # Keep the original input for the residual connections below; ``q`` is
        # overwritten with a per-head layout in both branches.
        residual = q

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            q = (
                self.to_q(q)
                .reshape(self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .contiguous()
            )
            kv = (
                self.to_kv(kv)
                .reshape(2, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .chunk(2)
            )
            k, v = (
                kv[0].squeeze(0).contiguous(),
                kv[1].squeeze(0).contiguous(),
            )
        else:
            qkv = (
                self.to_qkv(q)
                .reshape(3, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .chunk(3)
            )
            q, k, v = (
                qkv[0].squeeze(0).contiguous(),
                qkv[1].squeeze(0).contiguous(),
                qkv[2].squeeze(0).contiguous(),
            )

        # Note: dropout_p is applied whenever it is non-zero, independent of
        # self.training.
        h = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.p_dropout
        ).transpose(-1, -2)
        h = h.reshape(N, -1, L).contiguous()
        h = self.linear(h)

        x = residual + h
        h = self.proj_out(x)

        return x + h
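

# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative only, not part of the library API).
# It exercises both attention paths with batch size 1, since the per-head
# reshapes above fold the batch dimension into the sequence axis; the channel
# and length values below are arbitrary assumptions for this smoke test.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    batch, channels, length = 1, 64, 100
    encoder_channels, encoder_length = 32, 50

    x = torch.randn(batch, channels, length)

    # Self-attention over a 1D (temporal) signal: [B, C, L] -> [B, C, L].
    self_attn = AttentionBlock(channels=channels, num_head_channels=32, dims=1)
    print(self_attn(x).shape)  # expected: torch.Size([1, 64, 100])

    # Cross-attention against an encoder output (with an extra self-attention
    # pass before it); the conditioning tensor has shape [B, encoder_channels, T].
    cross_attn = AttentionBlock(
        channels=channels,
        num_head_channels=32,
        encoder_channels=encoder_channels,
        use_self_attention=True,
        dims=1,
    )
    cond = torch.randn(batch, encoder_channels, encoder_length)
    print(cross_attn(x, cond).shape)  # expected: torch.Size([1, 64, 100])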