import math

import torch
import torch.nn as nn


class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input.

    Expects input of shape (num_rows, num_cols, batch_size, embed_dim).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)
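
    # Row attention logits are summed across rows before the softmax (see
    # compute_attention_weights), so the usual 1/sqrt(head_dim) factor is
    # further divided by sqrt(num_rows) to keep the logit scale independent
    # of the number of rows.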
    def align_scaling(self, q):
        num_rows = q.size(0)
        return self.scaling / math.sqrt(num_rows)
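
    # Inference-time fallback for large inputs: process the rows in chunks of
    # at most max_tokens_per_msa tokens, summing the tied attention logits
    # across chunks before a single softmax.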
    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        attns = 0
        scaling = self.align_scaling(x)
        for start in range(0, num_rows, max_rows):
            attn_weights = self.compute_attention_weights(
                x[start : start + max_rows],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
                if self_attn_padding_mask is not None
                else None,
            )
            attns += attn_weights
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)

        outputs = []
        for start in range(0, num_rows, max_rows):
            output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
            outputs.append(output)

        output = torch.cat(outputs, 0)
        return output, attn_probs
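
    # The einsum contracts over the row axis "r" and the head dimension "d",
    # producing a single column-by-column attention map per head and per
    # batch element ("hnij") that is shared by every row.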
    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q *= scaling
        if self_attn_padding_mask is not None:
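            # Zero out the queries at padded positions so they contribute
            # nothing to the sum across the row axis below.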
            q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            raise NotImplementedError

        if self_attn_padding_mask is not None:
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
                -10000,
            )

        return attn_weights
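
    # Apply the shared attention map to every row's values, then project the
    # result back to the embedding dimension.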
    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        output = self.out_proj(context)
        return output
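
    # Fall back to the chunked (_batched_forward) path only when the input
    # exceeds max_tokens_per_msa and gradients are disabled.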
    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
        else:
            scaling = self.align_scaling(x)
            attn_weights = self.compute_attention_weights(
                x, scaling, self_attn_mask, self_attn_padding_mask
            )
            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            output = self.compute_attention_update(x, attn_probs)
            return output, attn_probs


class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input.

    Expects input of shape (num_rows, num_cols, batch_size, embed_dim).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)
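
    # Inference-time fallback for large inputs: split along the column axis,
    # call the regular forward on each slice, then concatenate outputs and
    # attention maps along that axis.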
    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            output, attn = self(
                x[:, start : start + max_cols],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns
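
    # Attention is taken along the row axis within each column: the einsum
    # contracts over the head dimension only, giving per-column row-by-row
    # attention maps of shape (heads, cols, batch, rows, rows).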
    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
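            # With a single row, the softmax over the row axis would be
            # identically 1, so skip the attention computation and apply only
            # the value and output projections.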
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q *= self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        num_rows, num_cols, batch_size, embed_dim = x.size()

        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
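

# Minimal usage sketch (not part of the original module): exercises both
# attention blocks on a small random input. The sizes below are illustrative
# assumptions, chosen only to show the expected (num_rows, num_cols,
# batch_size, embed_dim) input layout and the returned attention shapes.
if __name__ == "__main__":
    num_rows, num_cols, batch_size, embed_dim, num_heads = 8, 16, 2, 64, 4
    x = torch.randn(num_rows, num_cols, batch_size, embed_dim)

    row_attn = RowSelfAttention(embed_dim, num_heads)
    col_attn = ColumnSelfAttention(embed_dim, num_heads)

    out_r, probs_r = row_attn(x)
    out_c, probs_c = col_attn(x)

    print(out_r.shape)    # (num_rows, num_cols, batch_size, embed_dim)
    print(probs_r.shape)  # (num_heads, batch_size, num_cols, num_cols)
    print(out_c.shape)    # (num_rows, num_cols, batch_size, embed_dim)
    print(probs_c.shape)  # (num_heads, num_cols, batch_size, num_rows, num_rows)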