# Source: ProteinGPT-Llama3 / esm/inverse_folding/transformer_layer.py
# Last commit 85ab89d ("finished") by EdwardoSunny; file size 10.7 kB.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Contents of this file were adapted from the open source fairseq repository.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from esm.multihead_attention import MultiheadAttention
from torch import Tensor
class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer.

    Each of the two sub-layers (multi-head self-attention, then a two-layer
    ReLU feed-forward network) is applied as
    `layernorm -> sublayer -> dropout -> add residual`.

    Args:
        args (argparse.Namespace): parsed command-line arguments. Reads
            ``encoder_embed_dim``, ``encoder_ffn_embed_dim``,
            ``encoder_attention_heads``, ``attention_dropout`` and ``dropout``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # One dropout module is shared by both residual branches.
        self.dropout_module = nn.Dropout(args.dropout)
        self.activation_fn = F.relu
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def build_fc1(self, input_dim, output_dim):
        """Build the FFN expansion projection (overridable factory hook)."""
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        """Build the FFN output projection (overridable factory hook)."""
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(self, embed_dim, args):
        """Build the self-attention module (overridable factory hook)."""
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )

    def residual_connection(self, x, residual):
        """Combine a sub-layer output with its residual input."""
        return residual + x

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``;
                forwarded to the self-attention as ``key_padding_mask``.
            attn_mask (Tensor, optional): binary tensor (bool, uint8 or float)
                of shape `(tgt_len, src_len)`, where `tgt_len` is the length of
                output and `src_len` is the length of input, though here both
                are equal to `seq_len`. A nonzero `attn_mask[tgt_i, src_j]`
                means that when calculating the embedding for `tgt_i`, we
                exclude (mask out) `src_j`. This is useful for strided
                self-attention. It is converted to an additive float mask
                before being passed to the attention module.
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if attn_mask is not None:
            # Nonzero entries of the original mask become a large negative
            # number and zero entries become 0. -inf is avoided on purpose: in
            # some edge cases the pre-softmax attention weight of a padded
            # query element would become -inf, which results in NaN in model
            # parameters.
            fill_value = -1e8 if x.dtype == torch.float32 else -1e4
            # Build the result in a floating dtype. Filling a uint8 mask with
            # -1e8 overflows, and a bool mask would just stay True instead of
            # becoming an additive mask. Float masks keep their own dtype.
            mask_dtype = attn_mask.dtype if attn_mask.is_floating_point() else x.dtype
            attn_mask = torch.zeros_like(attn_mask, dtype=mask_dtype).masked_fill(
                attn_mask.to(torch.bool), fill_value
            )

        # Self-attention sub-layer (pre-norm).
        residual = x
        x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        # Feed-forward sub-layer (pre-norm).
        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        return x
class TransformerDecoderLayer(nn.Module):
    """Pre-norm Transformer decoder layer.

    Sub-layers, each applied as `layernorm -> sublayer -> dropout -> add residual`:
    self-attention, optional encoder-decoder attention, and a two-layer ReLU
    feed-forward network.

    Args:
        args (argparse.Namespace): parsed command-line arguments. Reads
            ``decoder_embed_dim``, ``decoder_ffn_embed_dim``,
            ``decoder_attention_heads``, ``attention_dropout``, ``dropout``,
            ``encoder_embed_dim`` (only when encoder attention is built) and
            the optional flags ``scale_fc`` / ``scale_resids``.
        no_encoder_attn (bool, optional): if True, skip building the
            encoder-decoder attention sub-layer (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module.
        add_zero_attn (bool, optional): forwarded to the self-attention module.
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
    ):
        super().__init__()
        # NOTE: construction order is kept stable because it determines the
        # order of random parameter initialization.
        self.embed_dim = args.decoder_embed_dim
        self.dropout_module = nn.Dropout(args.dropout)
        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim
        self.activation_fn = F.relu
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # Optional LayerNorm applied inside the FFN, after the activation and
        # before fc2. Uses nn.LayerNorm; a bare `LayerNorm` is not defined here.
        self.ffn_layernorm = (
            nn.LayerNorm(args.decoder_ffn_embed_dim)
            if getattr(args, "scale_fc", False)
            else None
        )
        # Optional learnable per-channel scale applied to the FFN residual.
        self.w_resid = (
            nn.Parameter(
                torch.ones(
                    self.embed_dim,
                ),
                requires_grad=True,
            )
            if getattr(args, "scale_resids", False)
            else None
        )
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
        # In eval mode, request encoder-attention weights even if the caller
        # did not pass need_attn (see forward).
        self.need_attn = True

    def build_fc1(self, input_dim, output_dim):
        """Build the FFN expansion projection (overridable factory hook)."""
        return nn.Linear(input_dim, output_dim)

    def build_fc2(self, input_dim, output_dim):
        """Build the FFN output projection (overridable factory hook)."""
        return nn.Linear(input_dim, output_dim)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        """Build the self-attention module (overridable factory hook)."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True,
        )

    def build_encoder_attention(self, embed_dim, args):
        """Build the encoder-decoder attention module (overridable factory hook)."""
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=args.encoder_embed_dim,
            vdim=args.encoder_embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

    def residual_connection(self, x, residual):
        """Combine a sub-layer output with its residual input."""
        return residual + x

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): keys/values for encoder-decoder
                attention; that sub-layer runs only when this is given and the
                layer was built with encoder attention.
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            incremental_state (dict, optional): decoding cache passed through
                to both attention modules. Required when either
                ``prev_self_attn_state`` or ``prev_attn_state`` is given.
            prev_self_attn_state / prev_attn_state (list of Tensor, optional):
                ``[prev_key, prev_value]`` with an optional third element
                ``prev_key_padding_mask``. Written into the respective
                attention module's input buffer before it is called.
            self_attn_mask (Tensor, optional): ``attn_mask`` for self-attention.
            self_attn_padding_mask (Tensor, optional): ``key_padding_mask`` for
                self-attention.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads); implies
                ``need_attn``.
        Returns:
            tuple ``(x, attn, None)``: ``x`` is the output of shape
            `(seq_len, batch, embed_dim)`. ``attn`` holds the encoder-decoder
            attention weights when that sub-layer ran. Otherwise it is whatever
            the self-attention module returns as its second output with
            ``need_weights=False`` (presumably None; depends on
            MultiheadAttention).
        """
        if need_head_weights:
            need_attn = True

        # Self-attention sub-layer (pre-norm).
        residual = x
        x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)

        # Encoder-decoder attention sub-layer (pre-norm), if available.
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)

        # Feed-forward sub-layer (pre-norm), with optional inner LayerNorm and
        # optional learnable residual scaling.
        residual = x
        x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        return x, attn, None