from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.multihead_attention import MultiheadAttention
from torch import Tensor

class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    Pre-normalized: each sub-block applies
    ``layernorm -> (self-attention or feed-forward) -> dropout -> add residual``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
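
    Example:
        A minimal usage sketch. The hyperparameter values below are
        illustrative placeholders, not a released configuration::

            from argparse import Namespace

            cfg = Namespace(
                encoder_embed_dim=512,
                encoder_ffn_embed_dim=2048,
                encoder_attention_heads=8,
                dropout=0.1,
                attention_dropout=0.1,
            )
            layer = TransformerEncoderLayer(cfg)
            x = torch.zeros(10, 2, cfg.encoder_embed_dim)  # (seq_len, batch, embed_dim)
            out = layer(x, encoder_padding_mask=None)      # same shape as x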
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout_module = nn.Dropout(args.dropout)
        self.activation_fn = F.relu
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def build_fc1(self, input_dim, output_dim):
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
            )

        residual = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        return x


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    Pre-normalized: each sub-block applies
    ``layernorm -> (attention or feed-forward) -> dropout -> add residual``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
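
    Example:
        A minimal self-attention-only usage sketch (``no_encoder_attn=True``).
        The hyperparameter values below are illustrative placeholders, not a
        released configuration::

            from argparse import Namespace

            cfg = Namespace(
                decoder_embed_dim=512,
                decoder_ffn_embed_dim=2048,
                decoder_attention_heads=8,
                dropout=0.1,
                attention_dropout=0.1,
            )
            layer = TransformerDecoderLayer(cfg, no_encoder_attn=True)
            x = torch.zeros(10, 2, cfg.decoder_embed_dim)  # (seq_len, batch, embed_dim)
            out, attn, _ = layer(x)                        # out has the same shape as x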
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = nn.Dropout(args.dropout)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.activation_fn = F.relu

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
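
        # Optional extras, disabled unless the corresponding flags are set on
        # ``args``: ``scale_fc`` adds a LayerNorm over the FFN hidden
        # activations, and ``scale_resids`` adds a learned per-channel scale
        # applied to the residual branch of the FFN block.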
        self.ffn_layernorm = (
            nn.LayerNorm(args.decoder_ffn_embed_dim)
            if getattr(args, "scale_fc", False)
            else None
        )
        self.w_resid = (
            nn.Parameter(
                torch.ones(
                    self.embed_dim,
                ),
                requires_grad=True,
            )
            if getattr(args, "scale_resids", False)
            else None
        )

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        self.need_attn = True

    def build_fc1(self, input_dim, output_dim):
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )

    def build_encoder_attention(self, embed_dim, args):
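        # Keys and values come from the encoder, whose embedding dimension may
        # differ from the decoder's, hence the explicit ``kdim``/``vdim``.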
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=args.encoder_embed_dim,
            vdim=args.encoder_embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

    def residual_connection(self, x, residual):
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            a tuple ``(x, attn, None)`` where ``x`` is the encoded output of
            shape `(seq_len, batch, embed_dim)` and ``attn`` holds the
            encoder-attention weights when they are computed (otherwise
            ``None``).
        """
        if need_head_weights:
            need_attn = True

        residual = x
        x = self.self_attn_layer_norm(x)
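        # When decoding incrementally, previously cached key/value tensors
        # (and, if present, a key-padding mask) can be passed in explicitly;
        # they are written into the self-attention module's input buffer
        # before the attention call below.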
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
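
        # The input buffer fetched below is not used further in this
        # implementation of the layer.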
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            x = self.encoder_attn_layer_norm(x)
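            # As above, externally supplied key/value tensors can prime the
            # encoder-attention cache. ``static_kv=True`` in the call below
            # tells the attention module that the encoder keys/values do not
            # change between decoding steps, so the cached projections can be
            # reused.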
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)

        residual = x
        x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
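        # ``attn`` carries the encoder-attention weights when they were
        # computed above; the last element of the returned tuple is always
        # ``None`` here.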
        return x, attn, None