# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn


class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        # Attention logits are summed across the row (alignment) axis, so the
        # usual 1/sqrt(head_dim) scaling is further divided by sqrt(num_rows).
        num_rows = q.size(0)
        return self.scaling / math.sqrt(num_rows)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        # Chunk over rows so that roughly max_tokens_per_msa tokens are
        # processed at a time; only used when gradients are not required.
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        attns = 0
        scaling = self.align_scaling(x)
        for start in range(0, num_rows, max_rows):
            attn_weights = self.compute_attention_weights(
                x[start : start + max_rows],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
                if self_attn_padding_mask is not None
                else None,
            )
            attns += attn_weights
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)

        outputs = []
        for start in range(0, num_rows, max_rows):
            output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
            outputs.append(output)

        output = torch.cat(outputs, 0)
        return output, attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        # x: (num_rows, num_cols, batch_size, embed_dim)
        num_rows, num_cols, batch_size, embed_dim = x.size()
        q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q *= scaling
        if self_attn_padding_mask is not None:
            # Zero out any padded aligned positions - this is important since
            # we take a sum across the alignment axis.
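            # Added note on shapes: the padding mask is laid out as
            # (batch, rows, cols); permute(1, 2, 0) rearranges it to
            # (rows, cols, batch) and the two unsqueezes add singleton head and
            # head_dim axes so it broadcasts against q.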
            q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)

        # attn_weights: (num_heads, batch_size, num_cols, num_cols); the einsum
        # contracts over the row index "r", tying attention across rows.
        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            raise NotImplementedError
            # Mask Size: [B x R x C], Weights Size: [H x B x C x C]

        if self_attn_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
                -10000,
            )
        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        output = self.out_proj(context)
        return output

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
        else:
            scaling = self.align_scaling(x)
            attn_weights = self.compute_attention_weights(
                x, scaling, self_attn_mask, self_attn_padding_mask
            )
            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            output = self.compute_attention_update(x, attn_probs)
            return output, attn_probs


class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input."""

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        # Chunk over columns so that roughly max_tokens_per_msa tokens are
        # processed at a time; only used when gradients are not required.
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            output, attn = self(
                x[:, start : start + max_cols],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q *= self.scaling

            # attn_weights: (num_heads, num_cols, batch_size, num_rows, num_rows)
            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
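                    # large negative fill so the softmax assigns (near-)zero
                    # attention probability to padded positions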
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
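

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: a minimal
    # sketch of how both attention blocks are called, assuming x is shaped
    # (num_rows, num_cols, batch_size, embed_dim) and the padding mask is
    # shaped (batch_size, num_rows, num_cols) with True at padded positions
    # (inferred from how the module permutes and slices the mask). All sizes
    # below are arbitrary illustration values.
    torch.manual_seed(0)
    num_rows, num_cols, batch_size, embed_dim, num_heads = 8, 16, 2, 64, 4

    x = torch.randn(num_rows, num_cols, batch_size, embed_dim)
    padding_mask = torch.zeros(batch_size, num_rows, num_cols, dtype=torch.bool)

    row_attn = RowSelfAttention(embed_dim, num_heads)
    col_attn = ColumnSelfAttention(embed_dim, num_heads)

    row_out, row_probs = row_attn(x, self_attn_padding_mask=padding_mask)
    col_out, col_probs = col_attn(x, self_attn_padding_mask=padding_mask)

    # Outputs keep the input shape; row attention probabilities are tied
    # across rows, column attention probabilities are kept per column.
    print(row_out.shape, row_probs.shape)  # (8, 16, 2, 64), (4, 2, 16, 16)
    print(col_out.shape, col_probs.shape)  # (8, 16, 2, 64), (4, 16, 2, 8, 8)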