import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .multihead_attention import MultiheadAttention
from .axial_attention import ColumnSelfAttention, RowSelfAttention


def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


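# Illustrative sanity check for gelu() above (approximate values, not executed at import):
#   gelu(torch.tensor([-1.0, 0.0, 1.0])) -> tensor([-0.1587, 0.0000, 0.8413])

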
def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def apc(x):
    "Perform average product correction, used for contact prediction."
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place division to reduce memory
    normalized = x - avg
    return normalized


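# Usage sketch for symmetrize()/apc() above (illustrative; shapes taken from
# ContactPredictionHead below, where x holds stacked attention maps of shape
# (batch, layers * heads, seqlen, seqlen)):
#   features = apc(symmetrize(attentions))
#   # `features` keeps the same shape and feeds the per-pair logistic regression.

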
class ESM1LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
        super().__init__()
        self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.weight, self.bias = None, None

    def forward(self, x):
        dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
        means = x.mean(dims, keepdim=True)
        x_zeromean = x - means
        variances = x_zeromean.pow(2).mean(dims, keepdim=True)
        x = x_zeromean / torch.sqrt(variances + self.eps)
        if self.affine:
            x = (self.weight * x) + self.bias
        return x


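# Usage sketch for ESM1LayerNorm above (illustrative): normalization is taken over
# the trailing `hidden_size` dimensions, so for hidden states of shape (B, T, 768):
#   ln = ESM1LayerNorm(768)
#   y = ln(torch.randn(2, 10, 768))  # same shape, normalized over the last dimension

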
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class ESM1bLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    from torch.nn import LayerNorm as ESM1bLayerNorm


class TransformerLayer(nn.Module):
    """Transformer layer block."""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = BertLayerNorm(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x

        return x, attn


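# Usage sketch for TransformerLayer above (illustrative; assumes the fairseq-style
# (seq_len, batch, embed_dim) input layout expected by MultiheadAttention):
#   layer = TransformerLayer(embed_dim=768, ffn_embed_dim=3072, attention_heads=12)
#   x, attn = layer(torch.randn(128, 2, 768))  # pre-norm self-attention + FFN, each with a residual

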
class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        super().__init__()

        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        row_self_attention = RowSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        column_self_attention = ColumnSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        feed_forward_layer = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_self_attention)
        self.column_self_attention = self.build_residual(column_self_attention)
        self.feed_forward_layer = self.build_residual(feed_forward_layer)

    def build_residual(self, layer: nn.Module):
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """
        LayerNorm is applied before each self-attention/FFN module (pre-norm),
        with a residual connection around each block.
        """
        x, row_attn = self.row_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, column_attn = self.column_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x = self.feed_forward_layer(x)
        if need_head_weights:
            return x, column_attn, row_attn
        else:
            return x


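# Usage sketch for AxialTransformerLayer above (illustrative; the 4-D MSA layout
# (num_rows, num_columns, batch, embedding_dim) is an assumption based on the
# Row/ColumnSelfAttention modules it wraps):
#   block = AxialTransformerLayer(embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=12)
#   x = block(torch.randn(16, 256, 1, 768))  # row attention -> column attention -> FFN

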
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f"sequence length of {self.max_positions}"
            )
        mask = input.ne(self.padding_idx).int()
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


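# Worked example for LearnedPositionalEmbedding above (illustrative): with
# padding_idx=1 and input tokens [[5, 6, 1]], the non-padding mask is [1, 1, 0],
# its cumulative sum is [1, 2, 2], and after masking and offsetting the computed
# position ids are [[2, 3, 1]] -- padding positions collapse to padding_idx.

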
class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        bsz, seq_len = x.shape
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)

        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        mask = x.ne(self.padding_idx)
        range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        positions = range_buf.expand_as(x)
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        half_dim = self.embed_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad the last dimension for odd embedding sizes
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb


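# Note on get_embedding() above (illustrative): with half_dim = embed_dim // 2 the
# table concatenates a sin half and a cos half rather than interleaving them, i.e.
#   emb[pos, i]            = sin(pos * 10000**(-i / (half_dim - 1)))
#   emb[pos, half_dim + i] = cos(pos * 10000**(-i / (half_dim - 1)))
# and row padding_idx is zeroed so padded tokens receive an all-zero embedding.

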
class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x


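# Note on RobertaLMHead above: `weight` must have shape (output_dim, embed_dim); it is
# typically the token embedding matrix passed in to tie input and output embeddings
# (an assumption about how callers construct this head, not enforced here).

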
class ContactPredictionHead(nn.Module):
    """Performs symmetrization and APC, then computes a logistic regression on the output features."""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # zero out and drop attentions to/from the appended eos token
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # drop attentions to/from the prepended bos token
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: batch x (layers * heads) x seqlen x seqlen
        attentions = attentions.to(
            self.regression.weight.device
        )  # ensure attentions live on the same device as the regression weights
        attentions = apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


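# Shape walkthrough for ContactPredictionHead above (derived from the code):
#   attentions: (batch, layers, heads, tokens, tokens) -> optional bos/eos trimming
#   and flattening to (batch, layers * heads, T, T) -> symmetrize + apc ->
#   permute to (batch, T, T, layers * heads) -> per-pair linear + sigmoid ->
#   contact probabilities of shape (batch, T, T).

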
class NormalizedResidualBlock(nn.Module):
    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        residual = x
        x = self.layer_norm(x)
        outputs = self.layer(x, *args, **kwargs)
        if isinstance(outputs, tuple):
            x, *out = outputs
        else:
            x = outputs
            out = None

        x = self.dropout_module(x)
        x = residual + x

        if out is not None:
            return (x,) + tuple(out)
        else:
            return x


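# NormalizedResidualBlock above implements the pre-norm residual pattern used by
# AxialTransformerLayer: x + dropout(layer(layer_norm(x))), passing through any
# extra outputs (e.g. attention weights) that the wrapped layer returns.

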
class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        return x