# Copyright (c) Meta Platforms, Inc. and affiliates. # # Contents of this file were adapted from the open source fairseq repository. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from typing import Dict, List, Optional import torch import torch.nn as nn import torch.nn.functional as F from esm.multihead_attention import MultiheadAttention from torch import Tensor class TransformerEncoderLayer(nn.Module): """Encoder layer block. `layernorm -> dropout -> add residual` Args: args (argparse.Namespace): parsed command-line arguments """ def __init__(self, args): super().__init__() self.args = args self.embed_dim = args.encoder_embed_dim self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim) self.dropout_module = nn.Dropout(args.dropout) self.activation_fn = F.relu self.fc1 = self.build_fc1( self.embed_dim, args.encoder_ffn_embed_dim, ) self.fc2 = self.build_fc2( args.encoder_ffn_embed_dim, self.embed_dim, ) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def build_fc1(self, input_dim, output_dim): return nn.Linear(input_dim, output_dim) def build_fc2(self, input_dim, output_dim): return nn.Linear(input_dim, output_dim) def build_self_attention(self, embed_dim, args): return MultiheadAttention( embed_dim, args.encoder_attention_heads, dropout=args.attention_dropout, self_attention=True, ) def residual_connection(self, x, residual): return residual + x def forward( self, x, encoder_padding_mask: Optional[Tensor], attn_mask: Optional[Tensor] = None, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor): binary ByteTensor of shape `(batch, seq_len)` where padding elements are indicated by ``1``. attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`, where `tgt_len` is the length of output and `src_len` is the length of input, though here both are equal to `seq_len`. `attn_mask[tgt_i, src_j] = 1` means that when calculating the embedding for `tgt_i`, we exclude (mask out) `src_j`. This is useful for strided self-attention. Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ # anything in original attn_mask = 1, becomes -1e8 # anything in original attn_mask = 0, becomes 0 # Note that we cannot use -inf here, because at some edge cases, # the attention weight (before softmax) for some padded element in query # will become -inf, which results in NaN in model parameters if attn_mask is not None: attn_mask = attn_mask.masked_fill( attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 ) residual = x x = self.self_attn_layer_norm(x) x, _ = self.self_attn( query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False, attn_mask=attn_mask, ) x = self.dropout_module(x) x = self.residual_connection(x, residual) residual = x x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) x = self.fc2(x) x = self.dropout_module(x) x = self.residual_connection(x, residual) return x class TransformerDecoderLayer(nn.Module): """Decoder layer block. `layernorm -> dropout -> add residual` Args: args (argparse.Namespace): parsed command-line arguments no_encoder_attn (bool, optional): whether to attend to encoder outputs (default: False). """ def __init__( self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False ): super().__init__() self.embed_dim = args.decoder_embed_dim self.dropout_module = nn.Dropout(args.dropout) self.self_attn = self.build_self_attention( self.embed_dim, args, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) self.nh = self.self_attn.num_heads self.head_dim = self.self_attn.head_dim self.activation_fn = F.relu self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) if no_encoder_attn: self.encoder_attn = None self.encoder_attn_layer_norm = None else: self.encoder_attn = self.build_encoder_attention(self.embed_dim, args) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.ffn_layernorm = ( LayerNorm(args.decoder_ffn_embed_dim) if getattr(args, "scale_fc", False) else None ) self.w_resid = ( nn.Parameter( torch.ones( self.embed_dim, ), requires_grad=True, ) if getattr(args, "scale_resids", False) else None ) self.fc1 = self.build_fc1( self.embed_dim, args.decoder_ffn_embed_dim, ) self.fc2 = self.build_fc2( args.decoder_ffn_embed_dim, self.embed_dim, ) self.final_layer_norm = nn.LayerNorm(self.embed_dim) self.need_attn = True def build_fc1(self, input_dim, output_dim): return nn.Linear(input_dim, output_dim) def build_fc2(self, input_dim, output_dim): return nn.Linear(input_dim, output_dim) def build_self_attention( self, embed_dim, args, add_bias_kv=False, add_zero_attn=False ): return MultiheadAttention( embed_dim, args.decoder_attention_heads, dropout=args.attention_dropout, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=True, ) def build_encoder_attention(self, embed_dim, args): return MultiheadAttention( embed_dim, args.decoder_attention_heads, kdim=args.encoder_embed_dim, vdim=args.encoder_embed_dim, dropout=args.attention_dropout, encoder_decoder_attention=True, ) def residual_connection(self, x, residual): return residual + x def forward( self, x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, prev_attn_state: Optional[List[torch.Tensor]] = None, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_attn: bool = False, need_head_weights: bool = False, ): """ Args: x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` encoder_padding_mask (ByteTensor, optional): binary ByteTensor of shape `(batch, src_len)` where padding elements are indicated by ``1``. need_attn (bool, optional): return attention weights need_head_weights (bool, optional): return attention weights for each head (default: return average over heads). Returns: encoded output of shape `(seq_len, batch, embed_dim)` """ if need_head_weights: need_attn = True residual = x x = self.self_attn_layer_norm(x) if prev_self_attn_state is not None: prev_key, prev_value = prev_self_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_self_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] assert incremental_state is not None self.self_attn._set_input_buffer(incremental_state, saved_state) _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) y = x x, attn = self.self_attn( query=x, key=y, value=y, key_padding_mask=self_attn_padding_mask, incremental_state=incremental_state, need_weights=False, attn_mask=self_attn_mask, ) x = self.dropout_module(x) x = self.residual_connection(x, residual) if self.encoder_attn is not None and encoder_out is not None: residual = x x = self.encoder_attn_layer_norm(x) if prev_attn_state is not None: prev_key, prev_value = prev_attn_state[:2] saved_state: Dict[str, Optional[Tensor]] = { "prev_key": prev_key, "prev_value": prev_value, } if len(prev_attn_state) >= 3: saved_state["prev_key_padding_mask"] = prev_attn_state[2] assert incremental_state is not None self.encoder_attn._set_input_buffer(incremental_state, saved_state) x, attn = self.encoder_attn( query=x, key=encoder_out, value=encoder_out, key_padding_mask=encoder_padding_mask, incremental_state=incremental_state, static_kv=True, need_weights=need_attn or (not self.training and self.need_attn), need_head_weights=need_head_weights, ) x = self.dropout_module(x) x = self.residual_connection(x, residual) residual = x x = self.final_layer_norm(x) x = self.activation_fn(self.fc1(x)) if self.ffn_layernorm is not None: x = self.ffn_layernorm(x) x = self.fc2(x) x = self.dropout_module(x) if self.w_resid is not None: residual = torch.mul(self.w_resid, residual) x = self.residual_connection(x, residual) return x, attn, None