# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py
import copy
import numbers
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
from AR.modules.activation_onnx import MultiheadAttention
from AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch.nn import functional as F

_shape_t = Union[int, List[int], torch.Size]


class LayerNorm(nn.Module):
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        # A tuple input carries (x, stage_embedding): normalize x and pass the
        # embedding through unchanged so downstream norms can still consume it.
        if isinstance(input, tuple):
            input, embedding = input
            return (
                F.layer_norm(
                    input,
                    self.normalized_shape,
                    self.weight,
                    self.bias,
                    self.eps,
                ),
                embedding,
            )

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class IdentityNorm(nn.Module):
    # No-op normalization that keeps the same (input, embedding) calling
    # convention as LayerNorm above.
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input


class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT (https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        cache=None,
    ) -> Tensor:
        output = src
        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                cache=cache,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):
    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,  # 512 16
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        cache=None,
    ) -> Tensor:
        x = src
        stage_embedding = None
        x = self.norm1(
            x + self._sa_block(x, src_mask, src_key_padding_mask, cache=cache),
            stage_embedding,
        )
        x = self.norm2(x + self._ff_block(x), stage_embedding)

        return x

    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
        cache=None,
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
            cache=cache,
        )
        return self.dropout1(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
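

# --- Illustrative construction sketch (hypothetical helper, not part of the
# upstream file; parameter values are assumptions chosen for illustration).
# It shows one way the classes above can be wired together: a stack of
# TransformerEncoderLayer modules with this file's LayerNorm as the final norm.
# Only construction is sketched; running forward() additionally involves the
# ``cache`` argument threaded through the attention layers, which this helper
# leaves untouched.
def _build_encoder_sketch(
    d_model: int = 512,
    nhead: int = 8,
    num_layers: int = 6,
) -> TransformerEncoder:
    layer = TransformerEncoderLayer(
        d_model=d_model,
        nhead=nhead,
        dim_feedforward=d_model * 4,
        batch_first=True,
        norm_first=True,
    )
    return TransformerEncoder(layer, num_layers, norm=LayerNorm(d_model))

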
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
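

# --- Minimal usage sketch (the ``__main__`` guard and the shapes below are
# illustrative assumptions, not part of the upstream module). It exercises the
# two calling conventions of AdaptiveLayerNorm around this file's LayerNorm:
# a plain (input, embedding) call, and the tuple form that passes the embedding
# through alongside the normalized output.
if __name__ == "__main__":
    d_model = 512
    ada_ln = AdaptiveLayerNorm(d_model, LayerNorm(d_model))
    x = torch.rand(10, 1, d_model)  # (seq_len, batch, d_model)
    stage_embedding = torch.rand(10, 1, d_model)

    y = ada_ln(x, stage_embedding)  # plain call returns just the tensor
    y2, emb = ada_ln((x, stage_embedding))  # tuple call returns (output, embedding)
    print(y.shape, y2.shape)  # torch.Size([10, 1, 512]) torch.Size([10, 1, 512])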