| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| from typing import Optional, List, Tuple |
|
|
| from openfold.model.primitives import ( |
| Linear, |
| LayerNorm, |
| Attention, |
| GlobalAttention, |
| _attention_chunked_trainable, |
| ) |
| from openfold.utils.checkpointing import get_checkpoint_fn |
| from openfold.utils.tensor_utils import ( |
| chunk_layer, |
| permute_final_dims, |
| flatten_final_dims, |
| ) |
|
|
|
|
class MSAAttention(nn.Module):
    """Self-attention over the rows of the MSA embedding.

    Optionally biases the attention logits with a projection of the pair
    embedding (used by the row-attention variant, Algorithm 7).
    """

    def __init__(
        self,
        c_in: int,
        c_hidden: int,
        no_heads: int,
        pair_bias: bool = False,
        c_z: Optional[int] = None,
        inf: float = 1e9,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            pair_bias:
                Whether to use pair embedding bias
            c_z:
                Pair embedding channel dimension. Ignored unless pair_bias
                is true
            inf:
                A large number to be used in computing the attention mask
        """
        super(MSAAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.pair_bias = pair_bias
        self.c_z = c_z
        self.inf = inf

        self.layer_norm_m = LayerNorm(self.c_in)

        # Pair-bias sub-modules are only instantiated when requested, so
        # the no-bias variant carries no unused parameters.
        self.layer_norm_z = None
        self.linear_z = None
        if self.pair_bias:
            self.layer_norm_z = LayerNorm(self.c_z)
            self.linear_z = Linear(
                self.c_z, self.no_heads, bias=False, init="normal"
            )

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        # Run the MHA over chunks of the batch dims to bound peak memory.
        return chunk_layer(
            self.mha,
            {"q_x": m, "kv_x": m, "biases": biases},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def _prep_inputs(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize m, build the mask bias, and project z into a pair bias.

        Returns the normalized MSA embedding, an additive attention-mask
        bias, and the (possibly None) pair bias.
        """
        m = self.layer_norm_m(m)

        n_seq, n_res = m.shape[-3:-1]
        if mask is None:
            # [*, N_seq, N_res] — all positions valid by default
            mask = m.new_ones(
                m.shape[:-3] + (n_seq, n_res),
            )

        # [*, N_seq, 1, 1, N_res]: 0 where mask == 1, -inf where mask == 0,
        # so masked keys are suppressed after the softmax.
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        if (self.pair_bias and
            z is not None and
            self.layer_norm_z is not None and
            self.linear_z is not None
        ):
            # [*, N_res, N_res, C_z]
            z = self.layer_norm_z(z)

            # [*, N_res, N_res, no_heads]
            z = self.linear_z(z)

            # [*, 1, no_heads, N_res, N_res] — assumes permute_final_dims
            # permutes the last three dims; TODO confirm against tensor_utils
            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

        return m, mask_bias, z

    @torch.jit.ignore
    def _chunked_msa_attn(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor],
        mask: Optional[torch.Tensor],
        chunk_logits: int,
        checkpoint: bool,
    ) -> torch.Tensor:
        """Memory-frugal attention: chunk the logit computation along the
        MSA dimension, optionally checkpointing each stage."""
        MSA_DIM = -4  # the N_seq dimension of the [*, N_seq, N_res, C] input

        def _get_qkv(m, z):
            m, mask_bias, z = self._prep_inputs(m, z, mask)
            q, k, v = self.mha._prep_qkv(m, m)
            return m, q, k, v, mask_bias, z

        checkpoint_fn = get_checkpoint_fn()

        # Checkpointing is only meaningful (and only legal) under grad.
        if(torch.is_grad_enabled() and checkpoint):
            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
        else:
            m, q, k, v, mask_bias, z = _get_qkv(m, z)

        o = _attention_chunked_trainable(
            query=q,
            key=k,
            value=v,
            biases=[mask_bias, z],
            chunk_size=chunk_logits,
            chunk_dim=MSA_DIM,
            checkpoint=checkpoint,
        )

        if(torch.is_grad_enabled() and checkpoint):
            # Gating / output projection, checkpointed to match the above
            m = checkpoint_fn(self.mha._wrap_up, o, m)
        else:
            m = self.mha._wrap_up(o, m)

        return m

    def forward(self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        _chunk_logits: Optional[int] = None,
        _checkpoint_chunks: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding. Required only if
                pair_bias is True
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
            _chunk_logits:
                Chunk size for the logit computation along the MSA dim; when
                set, the low-memory `_chunked_msa_attn` path is taken instead
            _checkpoint_chunks:
                Whether to checkpoint each chunk in the `_chunk_logits` path
        Returns:
            [*, N_seq, N_res, C_m] updated MSA embedding
        """
        if(_chunk_logits is not None):
            return self._chunked_msa_attn(
                m=m, z=z, mask=mask,
                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
            )

        m, mask_bias, z = self._prep_inputs(m, z, mask)

        # Biases are summed into the attention logits by the MHA primitive.
        biases = [mask_bias]
        if(z is not None):
            biases.append(z)

        if chunk_size is not None:
            m = self._chunk(m, biases, chunk_size)
        else:
            m = self.mha(
                q_x=m,
                kv_x=m,
                biases=biases
            )

        return m
|
|
|
|
class MSARowAttentionWithPairBias(MSAAttention):
    """
    Implements Algorithm 7.

    A thin configuration wrapper: MSAAttention with the pair-embedding
    bias switched on.
    """

    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                Input channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super().__init__(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=True,
            c_z=c_z,
            inf=inf,
        )
|
|
|
|
class MSAColumnAttention(nn.Module):
    """
    Implements Algorithm 8.

    By rights, this should also be a subclass of MSAAttention. Alas,
    most inheritance isn't supported by TorchScript.

    Column attention is realized by transposing the sequence and residue
    dimensions, running plain (no pair-bias) row attention, and
    transposing back.
    """

    def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_m:
                MSA channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
        """
        super(MSAColumnAttention, self).__init__()

        self.c_m = c_m
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        # Composition instead of inheritance (see class docstring).
        self._msa_att = MSAAttention(
            c_in=c_m,
            c_hidden=c_hidden,
            no_heads=no_heads,
            pair_bias=False,
            c_z=None,
            inf=inf,
        )

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. A low value decreases memory overhead at the
                cost of slower execution. Chunking is not performed by default.
        Returns:
            [*, N_seq, N_res, C_m] updated MSA embedding
        """
        # [*, N_res, N_seq, C_m] — attend over columns by transposing
        m = m.transpose(-2, -3)
        if mask is not None:
            mask = mask.transpose(-1, -2)

        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)

        # [*, N_seq, N_res, C_m]. (A second transpose of the local `mask`
        # was dead code — the mask is not used after the attention call —
        # and has been removed.)
        m = m.transpose(-2, -3)

        return m
|
|
|
|
class MSAColumnGlobalAttention(nn.Module):
    """Global (mean-query) attention over MSA columns.

    Like MSAColumnAttention, this transposes the sequence and residue
    dimensions, applies the attention row-wise, and transposes back.
    """

    def __init__(
        self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Per-head hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large number used to construct attention masks
            eps:
                Small number added for numerical stability
        """
        super(MSAColumnGlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.layer_norm_m = nn.LayerNorm(c_in)

        self.global_attention = GlobalAttention(
            c_in=c_in,
            c_hidden=c_hidden,
            no_heads=no_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        # Run global attention over chunks of the batch dims to bound
        # peak memory.
        mha_input = {
            "m": m,
            "mask": mask,
        }
        return chunk_layer(
            self.global_attention,
            mha_input,
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_in] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Size of chunks into which the inputs are split along their
                batch dimensions. Chunking is not performed by default.
        Returns:
            [*, N_seq, N_res, C_in] updated MSA embedding
        """
        # (The original unpacked m.shape[-3:] into three never-used locals;
        # that dead code has been removed.)
        if mask is None:
            # [*, N_seq, N_res] — all positions valid. A fresh ones tensor
            # never requires grad, so the original .detach() was redundant.
            mask = torch.ones(
                m.shape[:-1],
                dtype=m.dtype,
                device=m.device,
            )

        # [*, N_res, N_seq, C_in] — attend over columns by transposing
        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        m = self.layer_norm_m(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self.global_attention(m=m, mask=mask)

        # [*, N_seq, N_res, C_in]
        m = m.transpose(-2, -3)

        return m
|
|