| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from typing import Tuple |
| |
|
| | from openfold.model.primitives import Linear, LayerNorm |
| | from openfold.utils.tensor_utils import one_hot |
| |
|
| |
|
class InputEmbedder(nn.Module):
    """
    Embeds a subset of the input features.

    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        relpos_k: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            relpos_k:
                Window size used in relative positional encoding
        """
        super(InputEmbedder, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim
        self.c_z = c_z
        self.c_m = c_m

        # Projections of the target/MSA features into the pair and MSA streams
        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # Relative positional encoding (Algorithm 4): offsets are clipped to
        # [-relpos_k, relpos_k], giving 2k + 1 bins
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings.

        Implements Algorithm 4.

        Args:
            ri:
                "residue_index" features of shape [*, N]
        """
        # [*, N, N] pairwise index offsets
        offsets = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(
            -self.relpos_k, self.relpos_k + 1, device=offsets.device
        )
        encoded = one_hot(offsets, boundaries).type(ri.dtype)
        return self.linear_relpos(encoded)

    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        msa: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            tf:
                "target_feat" features of shape [*, N_res, tf_dim]
            ri:
                "residue_index" features of shape [*, N_res]
            msa:
                "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
        Returns:
            msa_emb:
                [*, N_clust, N_res, C_m] MSA embedding
            pair_emb:
                [*, N_res, N_res, C_z] pair embedding
        """
        # [*, N_res, c_z] per-residue projections
        emb_i = self.linear_tf_z_i(tf)
        emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z] outer sum plus relative position bias
        pair_emb = emb_i[..., None, :] + emb_j[..., None, :, :]
        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))

        # Broadcast the target-feature projection over the cluster dimension
        n_clust = msa.shape[-3]
        expand_shape = (-1,) * (tf.dim() - 2) + (n_clust, -1, -1)
        tf_m = self.linear_tf_m(tf).unsqueeze(-3).expand(expand_shape)

        # [*, N_clust, N_res, c_m]
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb
| |
|
| |
|
class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.

    Implements Algorithm 32.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        min_bin: float,
        max_bin: float,
        no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair embedding channel dimension
            min_bin:
                Smallest distogram bin (Angstroms)
            max_bin:
                Largest distogram bin (Angstroms)
            no_bins:
                Number of distogram bins
            inf:
                Upper edge of the final (open-ended) distance bin
        """
        super(RecyclingEmbedder, self).__init__()

        self.c_m = c_m
        self.c_z = c_z
        self.min_bin = min_bin
        self.max_bin = max_bin
        self.no_bins = no_bins
        self.inf = inf

        # Lazily-built distogram bin edges. Kept as a plain attribute (not a
        # buffer) and re-validated per forward pass; see the check below.
        self.bins = None

        self.linear = Linear(self.no_bins, self.c_z)
        self.layer_norm_m = LayerNorm(self.c_m)
        self.layer_norm_z = LayerNorm(self.c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            m:
                First row of the MSA embedding. [*, N_res, C_m]
            z:
                [*, N_res, N_res, C_z] pair embedding
            x:
                [*, N_res, 3] predicted C_beta coordinates
        Returns:
            m:
                [*, N_res, C_m] MSA embedding update
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
        # BUG FIX: the original cached self.bins only once, on the first call,
        # with that call's dtype/device. Because self.bins is not a registered
        # buffer, moving the module (.to(...)) or switching precision (e.g.
        # autocast) left a stale tensor behind, causing device/dtype mismatch.
        # Rebuild the cache whenever it no longer matches the input.
        if (
            self.bins is None
            or self.bins.dtype != x.dtype
            or self.bins.device != x.device
        ):
            self.bins = torch.linspace(
                self.min_bin,
                self.max_bin,
                self.no_bins,
                dtype=x.dtype,
                device=x.device,
                requires_grad=False,
            )

        # [*, N, C_m]
        m_update = self.layer_norm_m(m)

        # Bin edges are compared against squared distances to avoid a sqrt.
        squared_bins = self.bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )

        # [*, N, N, 1] pairwise squared distances
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
        )

        # [*, N, N, no_bins] one-hot bin membership (strict inequalities: a
        # distance exactly on an edge falls into no bin, as in the original)
        d = ((d > squared_bins) * (d < upper)).type(x.dtype)

        # [*, N, N, C_z]
        z_update = self.linear(d) + self.layer_norm_z(z)

        return m_update, z_update
| |
|
| |
|
class TemplateAngleEmbedder(nn.Module):
    """
    Embeds the "template_angle_feat" feature.

    Implements Algorithm 2, line 7.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Final dimension of "template_angle_feat"
            c_out:
                Output channel dimension
        """
        super(TemplateAngleEmbedder, self).__init__()

        self.c_out = c_out
        self.c_in = c_in

        # Two-layer MLP: linear -> ReLU -> linear
        self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.c_out, self.c_out, init="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [*, N_templ, N_res, c_in] "template_angle_feat" features
        Returns:
            x: [*, N_templ, N_res, C_out] embedding
        """
        return self.linear_2(self.relu(self.linear_1(x)))
| |
|
| |
|
class TemplatePairEmbedder(nn.Module):
    """
    Embeds "template_pair_feat" features.

    Implements Algorithm 2, line 9.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(TemplatePairEmbedder, self).__init__()

        self.c_in = c_in
        self.c_out = c_out

        # Single linear projection with ReLU-friendly initialization
        self.linear = Linear(self.c_in, self.c_out, init="relu")

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, C_in] input tensor
        Returns:
            [*, C_out] output tensor
        """
        return self.linear(x)
| |
|
| |
|
class ExtraMSAEmbedder(nn.Module):
    """
    Embeds unclustered MSA sequences.

    Implements Algorithm 2, line 15.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(ExtraMSAEmbedder, self).__init__()

        self.c_in = c_in
        self.c_out = c_out

        # Single linear projection with default initialization
        self.linear = Linear(self.c_in, self.c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
        Returns:
            [*, N_extra_seq, N_res, C_out] embedding
        """
        return self.linear(x)
| |
|