Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Contents of this file were adapted from the open source fairseq repository.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from esm.multihead_attention import MultiheadAttention | |
from torch import Tensor | |
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    Both sub-blocks (self-attention, then the feed-forward network) are
    applied as `layernorm -> sublayer -> dropout -> add residual`.

    Args:
        args (argparse.Namespace): parsed command-line arguments. This layer
            reads `encoder_embed_dim`, `encoder_ffn_embed_dim`,
            `encoder_attention_heads`, `dropout` and `attention_dropout`.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # One dropout module is shared by the attention and FFN sub-blocks.
        self.dropout_module = nn.Dropout(args.dropout)
        self.activation_fn = F.relu
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def build_fc1(self, input_dim, output_dim):
        """Build the first (expanding) feed-forward projection."""
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        """Build the second (contracting) feed-forward projection."""
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(self, embed_dim, args):
        """Build the self-attention module; override to swap implementations."""
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )

    def residual_connection(self, x, residual):
        """Add the sub-block output `x` back onto its input `residual`."""
        return residual + x

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention. Bool, integer and floating
                point masks are accepted; any non-zero entry is masked out.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            masked_positions = attn_mask.to(torch.bool)
            fill_value = -1e8 if x.dtype == torch.float32 else -1e4
            if not torch.is_floating_point(attn_mask):
                # A bool / uint8 mask cannot hold a large negative value
                # (the fill would overflow or be silently truncated), so
                # build the additive mask in the activation dtype instead.
                attn_mask = torch.zeros_like(attn_mask, dtype=x.dtype)
            attn_mask = attn_mask.masked_fill(masked_positions, fill_value)

        # Self-attention sub-block (pre-norm).
        residual = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        # Feed-forward sub-block (pre-norm).
        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        return x
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    Each sub-block (self-attention, optional encoder attention, feed-forward
    network) is applied as `layernorm -> sublayer -> dropout -> add residual`.

    Args:
        args (argparse.Namespace): parsed command-line arguments. This layer
            reads `decoder_embed_dim`, `decoder_ffn_embed_dim`,
            `decoder_attention_heads`, `encoder_embed_dim`, `dropout`,
            `attention_dropout` and the optional flags `scale_fc` and
            `scale_resids`.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module
            (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention module
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        # One dropout module is shared by every sub-block of the layer.
        self.dropout_module = nn.Dropout(args.dropout)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.activation_fn = F.relu
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Optional LayerNorm applied to the FFN hidden activations.
        # (Was a bare `LayerNorm`, which is not defined in this module and
        # raised NameError whenever `scale_fc` was enabled.)
        self.ffn_layernorm = (
            nn.LayerNorm(args.decoder_ffn_embed_dim)
            if getattr(args, "scale_fc", False)
            else None
        )
        # Optional learned per-channel scale on the FFN residual branch.
        self.w_resid = (
            nn.Parameter(
                torch.ones(
                    self.embed_dim,
                ),
                requires_grad=True,
            )
            if getattr(args, "scale_resids", False)
            else None
        )

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        # When True, encoder-attention weights are also requested in eval mode.
        self.need_attn = True

    def build_fc1(self, input_dim, output_dim):
        """Build the first (expanding) feed-forward projection."""
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        """Build the second (contracting) feed-forward projection."""
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the self-attention module; override to swap implementations."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module (keys/values use the encoder width)."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=args.encoder_embed_dim,
            vdim=args.encoder_embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

    def residual_connection(self, x, residual):
        """Add the sub-block output `x` back onto its input `residual`."""
        return residual + x

    @staticmethod
    def _restore_attn_state(attn_module, prev_state, incremental_state):
        """Write previously saved key/value tensors into `attn_module`'s input buffer."""
        prev_key, prev_value = prev_state[:2]
        saved_state: Dict[str, Optional[Tensor]] = {
            "prev_key": prev_key,
            "prev_value": prev_value,
        }
        # A third entry, when present, is the saved key padding mask.
        if len(prev_state) >= 3:
            saved_state["prev_key_padding_mask"] = prev_state[2]
        assert incremental_state is not None
        attn_module._set_input_buffer(incremental_state, saved_state)

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder output used as key and
                value of the encoder attention; the encoder-attention
                sub-block is skipped when this is ``None``.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): incremental decoding state
                passed through to both attention modules; must not be
                ``None`` when a previous attention state is supplied.
            prev_self_attn_state (List[Tensor], optional): `[prev_key,
                prev_value]`, optionally followed by a key padding mask, to
                load into the self-attention input buffer.
            prev_attn_state (List[Tensor], optional): same layout as
                `prev_self_attn_state`, for the encoder attention.
            self_attn_mask (Tensor, optional): attention mask passed to the
                self-attention module.
            self_attn_padding_mask (Tensor, optional): key padding mask passed
                to the self-attention module.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            tuple `(x, attn, None)`: the encoded output of shape
            `(seq_len, batch, embed_dim)` and the second value returned by the
            last attention call that ran (encoder attention when it is
            applied, otherwise self-attention, which is always called with
            ``need_weights=False``).
        """
        if need_head_weights:
            need_attn = True

        # --- Self-attention sub-block (pre-norm) ---
        residual = x
        x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            self._restore_attn_state(
                self.self_attn, prev_self_attn_state, incremental_state
            )
        # NOTE(review): the returned buffer is not used by this layer; the call
        # is retained from the original code -- presumably a pure getter,
        # confirm before removing.
        self.self_attn._get_input_buffer(incremental_state)
        y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        # --- Encoder-decoder attention sub-block (pre-norm, optional) ---
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                self._restore_attn_state(
                    self.encoder_attn, prev_attn_state, incremental_state
                )

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)

        # --- Feed-forward sub-block (pre-norm) ---
        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            # Scale the residual branch channel-wise before adding it back.
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        return x, attn, None