Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import math | |
| import torch | |
| import torch.nn as nn | |
class RowSelfAttention(nn.Module):
    """Self-attention along the rows of a 2D input, with one map shared by every row.

    ``x`` is unpacked as ``(num_rows, num_cols, batch_size, embed_dim)``.  Query/key
    products from all rows are summed into a single logit map laid out as
    ``self.attn_shape`` ("hnij": heads, batch, query column, key column), so every
    row uses the same attention pattern over columns.  When gradients are disabled
    and ``num_rows * num_cols`` exceeds ``max_tokens_per_msa``, the rows are
    processed in chunks of ``max(1, max_tokens_per_msa // num_cols)``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads
        self.scaling = self.head_dim ** -0.5
        # Stored for reference only; dropout_module is what actually runs.
        self.dropout = dropout
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"
        # NOTE: the k, v, q, out creation order determines RNG-driven weight
        # initialisation and state_dict key order - keep it.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the per-head scaling further divided by sqrt(q.size(0)) (the row count)."""
        return self.scaling / math.sqrt(q.size(0))

    def _split_heads(self, proj, x):
        """Project ``x`` with ``proj`` and view it as (rows, cols, batch, heads, head_dim)."""
        rows, cols, batch, _ = x.size()
        return proj(x).view(rows, cols, batch, self.num_heads, self.head_dim)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Compute the same result as ``forward`` while projecting a chunk of rows at a time.

        Logits from each row chunk are summed before a single softmax. The
        resulting shared probabilities are then applied to every chunk.
        """
        num_rows, num_cols, _, _ = x.size()
        chunk = max(1, self.max_tokens_per_msa // num_cols)
        starts = range(0, num_rows, chunk)
        scaling = self.align_scaling(x)

        summed = 0
        for lo in starts:
            hi = lo + chunk
            chunk_mask = None
            if self_attn_padding_mask is not None:
                # Padding mask is indexed [B x R x C]; slice the row axis.
                chunk_mask = self_attn_padding_mask[:, lo:hi]
            summed += self.compute_attention_weights(
                x[lo:hi],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=chunk_mask,
            )

        attn_probs = self.dropout_module(summed.softmax(-1))
        pieces = [self.compute_attention_update(x[lo : lo + chunk], attn_probs) for lo in starts]
        return torch.cat(pieces, 0), attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return pre-softmax logits (layout ``self.attn_shape``) summed over the rows of ``x``.

        Raises:
            NotImplementedError: if ``self_attn_mask`` is given.
        """
        queries = self._split_heads(self.q_proj, x)
        keys = self._split_heads(self.k_proj, x)
        queries *= scaling

        has_padding = self_attn_padding_mask is not None
        if has_padding:
            # [B x R x C] -> [R x C x B x 1 x 1]; zero out padded query positions,
            # because the einsum below sums their contributions across rows.
            keep = 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(queries)
            queries *= keep

        logits = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", queries, keys)

        if self_attn_mask is not None:
            raise NotImplementedError

        if has_padding:
            # Column padding of the first row, broadcast to [1 x B x 1 x C], masks keys
            # in the [H x B x C x C] logits.
            key_padding = self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2)
            logits = logits.masked_fill(key_padding, -10000)
        return logits

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        """Apply shared ``attn_probs`` to the value projection of ``x`` and project out."""
        rows, cols, batch, embed_dim = x.size()
        values = self._split_heads(self.v_proj, x)
        mixed = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, values)
        return self.out_proj(mixed.contiguous().view(rows, cols, batch, embed_dim))

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``; chunk over rows when large and grad is off."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        over_budget = num_rows * num_cols > self.max_tokens_per_msa
        if over_budget and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)

        logits = self.compute_attention_weights(
            x, self.align_scaling(x), self_attn_mask, self_attn_padding_mask
        )
        probs = self.dropout_module(logits.softmax(-1))
        return self.compute_attention_update(x, probs), probs
class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input.

    ``x`` is unpacked as ``(num_rows, num_cols, batch_size, embed_dim)``.  Each
    column attends over its own rows independently (einsum
    ``"icnhd,jcnhd->hcnij"``), so attention probabilities have shape
    ``(num_heads, num_cols, batch_size, num_rows, num_rows)``.  When gradients are
    disabled and ``num_rows * num_cols > max_tokens_per_msa``, the columns are
    processed in chunks of ``max(1, max_tokens_per_msa // num_rows)``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        # Stored for reference only; dropout_module is what actually runs.
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        # Creation order fixes RNG-driven initialisation and state_dict key order.
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Run column attention one chunk of columns at a time.

        Columns do not interact, so concatenating per-chunk outputs and
        probabilities along dim 1 reproduces the unchunked computation.

        Returns:
            ``(output, attn_probs)`` concatenated along the column axis.
        """
        num_rows, num_cols, _, _ = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            end = start + max_cols
            # Padding mask is indexed [B x R x C]; slice the column axis.
            chunk_padding_mask = (
                self_attn_padding_mask[:, :, start:end]
                if self_attn_padding_mask is not None
                else None
            )
            # Call the kernel directly rather than ``self(...)``.  Re-entering
            # ``forward`` sends the chunk straight back here whenever
            # num_rows > max_tokens_per_msa: max_cols clamps to 1, but the chunk is
            # still over budget, so the old code recursed until RecursionError.
            # This also avoids firing module hooks once per chunk.
            output, attn = self.compute_attention_update(
                x[:, start:end],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=chunk_padding_mask,
            )
            outputs.append(output)
            attns.append(attn)
        return torch.cat(outputs, 1), torch.cat(attns, 1)

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Compute column attention for ``x`` and return ``(output, attn_probs)``.

        Args:
            x: tensor unpacked as (num_rows, num_cols, batch_size, embed_dim).
            self_attn_mask: not supported; must be None when num_rows > 1.
            self_attn_padding_mask: optional boolean mask indexed [B x R x C]; True
                entries are filled with -10000 as attention keys.

        Raises:
            NotImplementedError: if ``self_attn_mask`` is given and num_rows > 1.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # With a single row, softmax over one key is exactly 1, so attention
            # reduces to out_proj(v_proj(x)); this also sidesteps padding issues.
            # NOTE(review): self_attn_mask is silently ignored on this path.
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            return self.out_proj(self.v_proj(x)), attn_probs

        # Fail before doing any projection work.
        if self_attn_mask is not None:
            raise NotImplementedError

        head_shape = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(head_shape)
        k = self.k_proj(x).view(head_shape)
        v = self.v_proj(x).view(head_shape)
        q *= self.scaling
        attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

        if self_attn_padding_mask is not None:
            # [B x R x C] -> [1 x C x B x 1 x R], masking keys in [H x C x B x R x R].
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                -10000,
            )

        attn_probs = self.dropout_module(attn_weights.softmax(-1))
        context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(context), attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``; chunk over columns when large and grad is off."""
        num_rows, num_cols, _, _ = x.size()
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)