# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ..modules import (
    AxialTransformerLayer,
    LearnedPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
)

from ..axial_attention import RowSelfAttention, ColumnSelfAttention


class MSATransformer(nn.Module):
    @classmethod
    def add_args(cls, parser):
        # fmt: off
        parser.add_argument(
            "--num_layers",
            default=12,
            type=int,
            metavar="N",
            help="number of layers"
        )
        parser.add_argument(
            "--embed_dim",
            default=768,
            type=int,
            metavar="N",
            help="embedding dimension"
        )
        parser.add_argument(
            "--logit_bias",
            action="store_true",
            help="whether to apply bias to logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=3072,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=12,
            type=int,
            metavar="N",
            help="number of attention heads",
        )
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            help="Dropout to apply."
        )
        parser.add_argument(
            "--attention_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply to attention weights."
        )
        parser.add_argument(
            "--activation_dropout",
            default=0.1,
            type=float,
            help="Dropout to apply inside the FFN after the activation."
        )
        parser.add_argument(
            "--max_tokens_per_msa",
            default=2 ** 14,
            type=int,
            help=(
                "Used during inference to batch attention computations in a single "
                "forward pass. This allows increased input sizes with less memory."
            ),
        )
        # fmt: on

    def __init__(self, args, alphabet):
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )

        if getattr(self.args, "embed_positions_msa", False):
            emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, 1024, 1, emb_dim),
                requires_grad=True,
            )
        else:
            self.register_parameter("msa_position_embedding", None)

        self.dropout_module = nn.Dropout(self.args.dropout)
        self.layers = nn.ModuleList(
            [
                AxialTransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    self.args.dropout,
                    self.args.attention_dropout,
                    self.args.activation_dropout,
                    getattr(self.args, "max_tokens_per_msa", self.args.max_tokens),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions,
            self.args.embed_dim,
            self.padding_idx,
        )
        self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )
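
    # Shape conventions used in forward(), collected from the inline comments
    # below: B = batch size, R = rows (alignments in the MSA), C = columns
    # (sequence length), D = embedding dim, H = attention heads, L = layers.
    # Tokens enter as B x R x C and are permuted to R x C x B x D before the
    # axial attention blocks, then permuted back before the LM head.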

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 3
        batch_size, num_alignments, seqlen = tokens.size()
        padding_mask = tokens.eq(self.padding_idx)  # B, R, C
        if not padding_mask.any():
            padding_mask = None

        x = self.embed_tokens(tokens)
        x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
        if self.msa_position_embedding is not None:
            if x.size(1) > 1024:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of 1024, but received {x.size(1)} alignments."
                )
            x += self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)
        x = self.dropout_module(x)
        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            row_attn_weights = []
            col_attn_weights = []

        # B x R x C x D -> R x C x B x D
        x = x.permute(1, 2, 0, 3)

        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if need_head_weights:
                x, col_attn, row_attn = x
                # H x C x B x R x R -> B x H x C x R x R
                col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
                # H x B x C x C -> B x H x C x C
                row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # col_attentions: B x L x H x C x R x R
            col_attentions = torch.stack(col_attn_weights, 1)
            # row_attentions: B x L x H x C x C
            row_attentions = torch.stack(row_attn_weights, 1)
            result["col_attentions"] = col_attentions
            result["row_attentions"] = row_attentions
            if return_contacts:
                contacts = self.contact_head(tokens, row_attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        return self.args.layers

    def max_tokens_per_msa_(self, value: int) -> None:
        """The MSA Transformer automatically batches attention computations when
        gradients are disabled to allow you to pass in larger MSAs at test time than
        you can fit in GPU memory. By default this occurs when more than 2^14 tokens
        are passed in the input MSA. You can set this value to infinity to disable
        this behavior.
        """
        for module in self.modules():
            if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
                module.max_tokens_per_msa = value
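

# Minimal usage sketch, added for illustration and not part of the original
# model code. It builds a small randomly initialised MSATransformer from the
# argparse options defined in add_args and runs one forward pass. `_ToyAlphabet`
# and the index values inside it are stand-ins for this demo only; real usage
# passes the alphabet distributed with the pretrained weights. Because of the
# relative imports at the top of the file, this block only runs when the module
# is executed from within its package (e.g. via ``python -m``).
if __name__ == "__main__":
    import argparse

    class _ToyAlphabet:
        """Stand-in exposing only the attributes __init__ reads; the index
        values are illustrative, not those of any released alphabet."""

        padding_idx = 1
        cls_idx = 0
        eos_idx = 2
        mask_idx = 31
        prepend_bos = True
        append_eos = False

        def __len__(self):
            return 32

    parser = argparse.ArgumentParser()
    MSATransformer.add_args(parser)
    args = parser.parse_args(
        ["--num_layers", "2", "--embed_dim", "64", "--ffn_embed_dim", "128", "--attention_heads", "4"]
    )
    # __init__ reads these two names, which the parser above does not define.
    args.layers = args.num_layers
    args.max_positions = 1024

    model = MSATransformer(args, _ToyAlphabet()).eval()
    model.max_tokens_per_msa_(2 ** 16)  # raise the inference batching threshold

    # One MSA: batch of 1, 8 aligned rows, 32 columns (random token indices,
    # chosen above the special-token range so no padding mask is triggered).
    tokens = torch.randint(3, 32, (1, 8, 32))
    with torch.no_grad():
        out = model(tokens, repr_layers=[args.layers], return_contacts=True)

    print(out["logits"].shape)                        # B x R x C x vocab
    print(out["representations"][args.layers].shape)  # B x R x C x D
    print(out["row_attentions"].shape)                # B x L x H x C x C
    print(out["contacts"].shape)                      # B x C' x C' (BOS column stripped by the contact head)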