import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..modules import (
    TransformerLayer,
    LearnedPositionalEmbedding,
    SinusoidalPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
)


class ProteinBertModel(nn.Module):
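    """Transformer protein language model (ESM-1 / ESM-1b).

    ``args.arch`` selects the variant: ``"roberta_large"`` builds ESM-1b
    (learned positions, tied LM head), anything else builds ESM-1
    (sinusoidal positions, untied output projection).

    Beyond the flags registered in ``add_args``, the model reads ``arch``,
    ``max_positions``, ``final_bias``, ``token_dropout``, and
    ``emb_layer_norm_before`` from ``args``; these typically come from a
    pretrained checkpoint's saved hyperparameters rather than from the CLI.
    """
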
    @classmethod
    def add_args(cls, parser):
        # Flag names match the ``args`` attributes read elsewhere in this
        # class (self.args.layers, self.args.final_bias, ...).
        parser.add_argument(
            "--layers", default=36, type=int, metavar="N", help="number of layers"
        )
        parser.add_argument(
            "--embed_dim", default=1280, type=int, metavar="N", help="embedding dimension"
        )
        parser.add_argument(
            "--final_bias", action="store_true", help="whether to apply bias to the output logits"
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=5120,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=20,
            type=int,
            metavar="N",
            help="number of attention heads",
        )

    def __init__(self, args, alphabet):
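        """Build the encoder; ``alphabet`` supplies the vocabulary size and
        special-token indices, ``args`` the architecture hyperparameters."""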
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.emb_layer_norm_before = getattr(self.args, "emb_layer_norm_before", False)
        # ESM-1b checkpoints use the "roberta_large" architecture name.
        if self.args.arch == "roberta_large":
            self.model_version = "ESM-1b"
            self._init_submodules_esm1b()
        else:
            self.model_version = "ESM-1"
            self._init_submodules_esm1()

    def _init_submodules_common(self):
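        """Modules shared by both variants: token embeddings, the Transformer
        layer stack, and the attention-based contact prediction head."""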
        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    add_bias_kv=(self.model_version != "ESM-1b"),
                    use_esm1b_layer_norm=(self.model_version == "ESM-1b"),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )

    def _init_submodules_esm1b(self):
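        """ESM-1b variant: unscaled token embeddings, learned positional
        embeddings, LayerNorm around the embedding, and a RoBERTa-style LM
        head tied to the token embedding weights."""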
        self._init_submodules_common()
        self.embed_scale = 1
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions, self.args.embed_dim, self.padding_idx
        )
        self.emb_layer_norm_before = (
            ESM1bLayerNorm(self.args.embed_dim) if self.emb_layer_norm_before else None
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def _init_submodules_esm1(self):
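        """ESM-1 variant: token embeddings scaled by sqrt(embed_dim),
        sinusoidal positional embeddings, and an untied output projection."""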
        self._init_submodules_common()
        self.embed_scale = math.sqrt(self.args.embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(self.args.embed_dim, self.padding_idx)
        self.embed_out = nn.Parameter(torch.zeros((self.alphabet_size, self.args.embed_dim)))
        self.embed_out_bias = None
        if self.args.final_bias:
            self.embed_out_bias = nn.Parameter(torch.zeros(self.alphabet_size))

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
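        """Encode a batch of token indices.

        Args:
            tokens: LongTensor of shape (batch, seq_len).
            repr_layers: indices of layers whose hidden states to return
                (0 is the embedding output).
            need_head_weights: also return per-head attention maps.
            return_contacts: also return predicted residue-residue contacts
                (implies ``need_head_weights``).

        Returns a dict with "logits" and "representations", plus "attentions"
        and "contacts" when requested. A minimal usage sketch, assuming
        ``model`` and ``batch_tokens`` come from a pretrained checkpoint and
        its alphabet's batch converter:

            out = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=True)
            logits = out["logits"]  # (B, T, alphabet_size)
            reprs = out["representations"][model.num_layers]
            contacts = out["contacts"]
        """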
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B x T

        x = self.embed_scale * self.embed_tokens(tokens)

        if getattr(self.args, "token_dropout", False):
            x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
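            # Zeroed <mask> embeddings change the average embedding magnitude.
            # Rescale by (1 - mask_ratio_train) / (1 - mask_ratio_observed) so
            # inference matches training, where 15% of tokens were selected
            # and 80% of those were masked (0.15 * 0.8 = 12% zeroed).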
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = (tokens == self.mask_idx).sum(-1).float() / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        x = x + self.embed_positions(tokens)

        if self.model_version == "ESM-1b":
            if self.emb_layer_norm_before:
                x = self.emb_layer_norm_before(x)
            if padding_mask is not None:
                # zero out padded positions
                x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers):
            x, attn = layer(
                x, self_attn_padding_mask=padding_mask, need_head_weights=need_head_weights
            )
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        if self.model_version == "ESM-1b":
            x = self.emb_layer_norm_after(x)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

            # the last hidden representation should have layer norm applied
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x
            x = self.lm_head(x)
        else:
            x = F.linear(x, self.embed_out, bias=self.embed_out_bias)
            x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if self.model_version == "ESM-1":
                # ESM-1 models have an additional null token in attention, which we remove
                attentions = attentions[..., :-1]
            if padding_mask is not None:
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
        if return_contacts:
            contacts = self.contact_head(tokens, attentions)
            result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        return self.args.layers