# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from .multihead_attention import MultiheadAttention # noqa from .axial_attention import ColumnSelfAttention, RowSelfAttention def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def symmetrize(x): "Make layer symmetric in final two dimensions, used for contact prediction." return x + x.transpose(-1, -2) def apc(x): "Perform average product correct, used for contact prediction." a1 = x.sum(-1, keepdims=True) a2 = x.sum(-2, keepdims=True) a12 = x.sum((-1, -2), keepdims=True) avg = a1 * a2 avg.div_(a12) # in-place to reduce memory normalized = x - avg return normalized class ESM1LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12, affine=True): """Construct a layernorm layer in the TF style (eps inside the sqrt).""" super().__init__() self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size) self.eps = eps self.affine = bool(affine) if self.affine: self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) else: self.weight, self.bias = None, None def forward(self, x): dims = tuple(-(i + 1) for i in range(len(self.hidden_size))) means = x.mean(dims, keepdim=True) x_zeromean = x - means variances = x_zeromean.pow(2).mean(dims, keepdim=True) x = x_zeromean / torch.sqrt(variances + self.eps) if self.affine: x = (self.weight * x) + self.bias return x try: from apex.normalization import FusedLayerNorm as _FusedLayerNorm class ESM1bLayerNorm(_FusedLayerNorm): @torch.jit.unused def forward(self, x): if not x.is_cuda: return super().forward(x) else: with torch.cuda.device(x.device): return super().forward(x) except ImportError: from torch.nn import LayerNorm as ESM1bLayerNorm class TransformerLayer(nn.Module): """Transformer layer block.""" def __init__( self, embed_dim, ffn_embed_dim, attention_heads, add_bias_kv=True, use_esm1b_layer_norm=False, use_rotary_embeddings: bool = False, ): super().__init__() self.embed_dim = embed_dim self.ffn_embed_dim = ffn_embed_dim self.attention_heads = attention_heads self.use_rotary_embeddings = use_rotary_embeddings self._init_submodules(add_bias_kv, use_esm1b_layer_norm) def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm): BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm self.self_attn = MultiheadAttention( self.embed_dim, self.attention_heads, add_bias_kv=add_bias_kv, add_zero_attn=False, use_rotary_embeddings=self.use_rotary_embeddings, ) self.self_attn_layer_norm = BertLayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim) self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim) self.final_layer_norm = BertLayerNorm(self.embed_dim) def forward( self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False ): residual = x x = self.self_attn_layer_norm(x) x, attn = self.self_attn( query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=True, need_head_weights=need_head_weights, attn_mask=self_attn_mask, ) x = residual + x residual = x x = self.final_layer_norm(x) x = gelu(self.fc1(x)) x = self.fc2(x) x = residual + x return x, attn class AxialTransformerLayer(nn.Module): """Implements an Axial MSA Transformer block.""" def __init__( self, embedding_dim: int = 768, ffn_embedding_dim: int = 3072, num_attention_heads: int = 8, dropout: float = 0.1, attention_dropout: float = 0.1, activation_dropout: float = 0.1, max_tokens_per_msa: int = 2**14, ) -> None: super().__init__() # Initialize parameters self.embedding_dim = embedding_dim self.dropout_prob = dropout row_self_attention = RowSelfAttention( embedding_dim, num_attention_heads, dropout=dropout, max_tokens_per_msa=max_tokens_per_msa, ) column_self_attention = ColumnSelfAttention( embedding_dim, num_attention_heads, dropout=dropout, max_tokens_per_msa=max_tokens_per_msa, ) feed_forward_layer = FeedForwardNetwork( embedding_dim, ffn_embedding_dim, activation_dropout=activation_dropout, max_tokens_per_msa=max_tokens_per_msa, ) self.row_self_attention = self.build_residual(row_self_attention) self.column_self_attention = self.build_residual(column_self_attention) self.feed_forward_layer = self.build_residual(feed_forward_layer) def build_residual(self, layer: nn.Module): return NormalizedResidualBlock( layer, self.embedding_dim, self.dropout_prob, ) def forward( self, x: torch.Tensor, self_attn_mask: Optional[torch.Tensor] = None, self_attn_padding_mask: Optional[torch.Tensor] = None, need_head_weights: bool = False, ): """ LayerNorm is applied either before or after the self-attention/ffn modules similar to the original Transformer implementation. """ x, row_attn = self.row_self_attention( x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, ) x, column_attn = self.column_self_attention( x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, ) x = self.feed_forward_layer(x) if need_head_weights: return x, column_attn, row_attn else: return x class LearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to the forward function. """ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): if padding_idx is not None: num_embeddings_ = num_embeddings + padding_idx + 1 else: num_embeddings_ = num_embeddings super().__init__(num_embeddings_, embedding_dim, padding_idx) self.max_positions = num_embeddings def forward(self, input: torch.Tensor): """Input is expected to be of size [bsz x seqlen].""" if input.size(1) > self.max_positions: raise ValueError( f"Sequence length {input.size(1)} above maximum " f" sequence length of {self.max_positions}" ) mask = input.ne(self.padding_idx).int() positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx return F.embedding( positions, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse, ) class SinusoidalPositionalEmbedding(nn.Module): def __init__(self, embed_dim, padding_idx, learned=False): super().__init__() self.embed_dim = embed_dim self.padding_idx = padding_idx self.register_buffer("_float_tensor", torch.FloatTensor(1)) self.weights = None def forward(self, x): bsz, seq_len = x.shape max_pos = self.padding_idx + 1 + seq_len if self.weights is None or max_pos > self.weights.size(0): self.weights = self.get_embedding(max_pos) self.weights = self.weights.type_as(self._float_tensor) positions = self.make_positions(x) return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() def make_positions(self, x): mask = x.ne(self.padding_idx) range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1 positions = range_buf.expand_as(x) return positions * mask.long() + self.padding_idx * (1 - mask.long()) def get_embedding(self, num_embeddings): half_dim = self.embed_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if self.embed_dim % 2 == 1: # zero pad emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) if self.padding_idx is not None: emb[self.padding_idx, :] = 0 return emb class RobertaLMHead(nn.Module): """Head for masked language modeling.""" def __init__(self, embed_dim, output_dim, weight): super().__init__() self.dense = nn.Linear(embed_dim, embed_dim) self.layer_norm = ESM1bLayerNorm(embed_dim) self.weight = weight self.bias = nn.Parameter(torch.zeros(output_dim)) def forward(self, features): x = self.dense(features) x = gelu(x) x = self.layer_norm(x) # project back to size of vocabulary with bias x = F.linear(x, self.weight) + self.bias return x class ContactPredictionHead(nn.Module): """Performs symmetrization, apc, and computes a logistic regression on the output features""" def __init__( self, in_features: int, prepend_bos: bool, append_eos: bool, bias=True, eos_idx: Optional[int] = None, ): super().__init__() self.in_features = in_features self.prepend_bos = prepend_bos self.append_eos = append_eos if append_eos and eos_idx is None: raise ValueError("Using an alphabet with eos token, but no eos token was passed in.") self.eos_idx = eos_idx self.regression = nn.Linear(in_features, 1, bias) self.activation = nn.Sigmoid() def forward(self, tokens, attentions): # remove eos token attentions if self.append_eos: eos_mask = tokens.ne(self.eos_idx).to(attentions) eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) attentions = attentions * eos_mask[:, None, None, :, :] attentions = attentions[..., :-1, :-1] # remove cls token attentions if self.prepend_bos: attentions = attentions[..., 1:, 1:] batch_size, layers, heads, seqlen, _ = attentions.size() attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) # features: B x C x T x T attentions = attentions.to( self.regression.weight.device ) # attentions always float32, may need to convert to float16 attentions = apc(symmetrize(attentions)) attentions = attentions.permute(0, 2, 3, 1) return self.activation(self.regression(attentions).squeeze(3)) class NormalizedResidualBlock(nn.Module): def __init__( self, layer: nn.Module, embedding_dim: int, dropout: float = 0.1, ): super().__init__() self.embedding_dim = embedding_dim self.layer = layer self.dropout_module = nn.Dropout( dropout, ) self.layer_norm = ESM1bLayerNorm(self.embedding_dim) def forward(self, x, *args, **kwargs): residual = x x = self.layer_norm(x) outputs = self.layer(x, *args, **kwargs) if isinstance(outputs, tuple): x, *out = outputs else: x = outputs out = None x = self.dropout_module(x) x = residual + x if out is not None: return (x,) + tuple(out) else: return x class FeedForwardNetwork(nn.Module): def __init__( self, embedding_dim: int, ffn_embedding_dim: int, activation_dropout: float = 0.1, max_tokens_per_msa: int = 2**14, ): super().__init__() self.embedding_dim = embedding_dim self.ffn_embedding_dim = ffn_embedding_dim self.max_tokens_per_msa = max_tokens_per_msa self.activation_fn = nn.GELU() self.activation_dropout_module = nn.Dropout( activation_dropout, ) self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim) self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim) def forward(self, x): x = self.activation_fn(self.fc1(x)) x = self.activation_dropout_module(x) x = self.fc2(x) return x