import torch
from torch import Tensor, nn
from torch.nn.functional import one_hot

from . import vb_const as const
from .vb_layers_outer_product_mean import OuterProductMean
from .vb_layers_pair_averaging import PairWeightedAveraging
from .vb_layers_pairformer import (
    PairformerNoSeqLayer,
    PairformerNoSeqModule,
    get_dropout_mask,
)
from .vb_layers_transition import Transition
from .vb_modules_encodersv2 import (
    AtomAttentionEncoder,
    AtomEncoder,
    FourierEmbedding,
)


class ContactConditioning(nn.Module):
    def __init__(self, token_z: int, cutoff_min: float, cutoff_max: float):
        super().__init__()

        self.fourier_embedding = FourierEmbedding(token_z)
        self.encoder = nn.Linear(
            token_z + len(const.contact_conditioning_info) - 1, token_z
        )
        self.encoding_unspecified = nn.Parameter(torch.zeros(token_z))
        self.encoding_unselected = nn.Parameter(torch.zeros(token_z))
        self.cutoff_min = cutoff_min
        self.cutoff_max = cutoff_max

    def forward(self, feats):
        assert const.contact_conditioning_info["UNSPECIFIED"] == 0
        assert const.contact_conditioning_info["UNSELECTED"] == 1
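        # feats["contact_conditioning"] stacks one channel per entry of
        # const.contact_conditioning_info along the last dimension; the asserts
        # above pin channel 0 to UNSPECIFIED and channel 1 to UNSELECTED, so the
        # remaining channels (sliced below) carry the actual contact classes.
        # Sketch of the gating applied at the end of this method, for a pair (i, j):
        #   out_ij = encoder(features)_ij * (1 - unspecified_ij - unselected_ij)
        #            + encoding_unspecified * unspecified_ij
        #            + encoding_unselected * unselected_ij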
        contact_conditioning = feats["contact_conditioning"][:, :, :, 2:]
        contact_threshold = feats["contact_threshold"]
        contact_threshold_normalized = (contact_threshold - self.cutoff_min) / (
            self.cutoff_max - self.cutoff_min
        )
        contact_threshold_fourier = self.fourier_embedding(
            contact_threshold_normalized.flatten()
        ).reshape(contact_threshold_normalized.shape + (-1,))

        contact_conditioning = torch.cat(
            [
                contact_conditioning,
                contact_threshold_normalized.unsqueeze(-1),
                contact_threshold_fourier,
            ],
            dim=-1,
        )
        contact_conditioning = self.encoder(contact_conditioning)

        contact_conditioning = (
            contact_conditioning
            * (
                1
                - feats["contact_conditioning"][:, :, :, 0:2].sum(dim=-1, keepdim=True)
            )
            + self.encoding_unspecified * feats["contact_conditioning"][:, :, :, 0:1]
            + self.encoding_unselected * feats["contact_conditioning"][:, :, :, 1:2]
        )
        return contact_conditioning


class InputEmbedder(nn.Module):
    def __init__(
        self,
        atom_s: int,
        atom_z: int,
        token_s: int,
        token_z: int,
        atoms_per_window_queries: int,
        atoms_per_window_keys: int,
        atom_feature_dim: int,
        atom_encoder_depth: int,
        atom_encoder_heads: int,
        activation_checkpointing: bool = False,
        add_method_conditioning: bool = False,
        add_modified_flag: bool = False,
        add_cyclic_flag: bool = False,
        add_mol_type_feat: bool = False,
        use_no_atom_char: bool = False,
        use_atom_backbone_feat: bool = False,
        use_residue_feats_atoms: bool = False,
    ) -> None:
        """Initialize the input embedder.

        Parameters
        ----------
        atom_s : int
            The atom embedding size.
        atom_z : int
            The atom pairwise embedding size.
        token_s : int
            The token embedding size.
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.token_s = token_s
        self.add_method_conditioning = add_method_conditioning
        self.add_modified_flag = add_modified_flag
        self.add_cyclic_flag = add_cyclic_flag
        self.add_mol_type_feat = add_mol_type_feat

        self.atom_encoder = AtomEncoder(
            atom_s=atom_s,
            atom_z=atom_z,
            token_s=token_s,
            token_z=token_z,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_feature_dim=atom_feature_dim,
            structure_prediction=False,
            use_no_atom_char=use_no_atom_char,
            use_atom_backbone_feat=use_atom_backbone_feat,
            use_residue_feats_atoms=use_residue_feats_atoms,
        )

        self.atom_enc_proj_z = nn.Sequential(
            nn.LayerNorm(atom_z),
            nn.Linear(atom_z, atom_encoder_depth * atom_encoder_heads, bias=False),
        )

        self.atom_attention_encoder = AtomAttentionEncoder(
            atom_s=atom_s,
            token_s=token_s,
            atoms_per_window_queries=atoms_per_window_queries,
            atoms_per_window_keys=atoms_per_window_keys,
            atom_encoder_depth=atom_encoder_depth,
            atom_encoder_heads=atom_encoder_heads,
            structure_prediction=False,
            activation_checkpointing=activation_checkpointing,
        )

        self.res_type_encoding = nn.Linear(const.num_tokens, token_s, bias=False)
        self.msa_profile_encoding = nn.Linear(const.num_tokens + 1, token_s, bias=False)

        if add_method_conditioning:
            self.method_conditioning_init = nn.Embedding(
                const.num_method_types, token_s
            )
            self.method_conditioning_init.weight.data.fill_(0)
        if add_modified_flag:
            self.modified_conditioning_init = nn.Embedding(2, token_s)
            self.modified_conditioning_init.weight.data.fill_(0)
        if add_cyclic_flag:
            self.cyclic_conditioning_init = nn.Linear(1, token_s, bias=False)
            self.cyclic_conditioning_init.weight.data.fill_(0)
        if add_mol_type_feat:
            self.mol_type_conditioning_init = nn.Embedding(
                len(const.chain_type_ids), token_s
            )
            self.mol_type_conditioning_init.weight.data.fill_(0)

    def forward(self, feats: dict[str, Tensor], affinity: bool = False) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        feats : dict[str, Tensor]
            Input features
        affinity : bool
            Whether to use the affinity MSA profile features.

        Returns
        -------
        Tensor
            The embedded tokens.

        """
        res_type = feats["res_type"].float()
        if affinity:
            profile = feats["profile_affinity"]
            deletion_mean = feats["deletion_mean_affinity"].unsqueeze(-1)
        else:
            profile = feats["profile"]
            deletion_mean = feats["deletion_mean"].unsqueeze(-1)

        q, c, p, to_keys = self.atom_encoder(feats)
        atom_enc_bias = self.atom_enc_proj_z(p)
        a, _, _, _ = self.atom_attention_encoder(
            feats=feats,
            q=q,
            c=c,
            atom_enc_bias=atom_enc_bias,
            to_keys=to_keys,
        )

        s = (
            a
            + self.res_type_encoding(res_type)
            + self.msa_profile_encoding(torch.cat([profile, deletion_mean], dim=-1))
        )

        if self.add_method_conditioning:
            s = s + self.method_conditioning_init(feats["method_feature"])
        if self.add_modified_flag:
            s = s + self.modified_conditioning_init(feats["modified"])
        if self.add_cyclic_flag:
            cyclic = feats["cyclic_period"].clamp(max=1.0).unsqueeze(-1)
            s = s + self.cyclic_conditioning_init(cyclic)
        if self.add_mol_type_feat:
            s = s + self.mol_type_conditioning_init(feats["mol_type"])

        return s


class TemplateModule(nn.Module):
    """Template module."""

    def __init__(
        self,
        token_z: int,
        template_dim: int,
        template_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        min_dist: float = 3.25,
        max_dist: float = 50.75,
        num_bins: int = 38,
        **kwargs,
    ) -> None:
        """Initialize the template module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.relu = nn.ReLU()
        self.z_norm = nn.LayerNorm(token_z)
        self.v_norm = nn.LayerNorm(template_dim)
        self.z_proj = nn.Linear(token_z, template_dim, bias=False)
        self.a_proj = nn.Linear(
            const.num_tokens * 2 + num_bins + 5,
            template_dim,
            bias=False,
        )
        self.u_proj = nn.Linear(template_dim, token_z, bias=False)
        self.pairformer = PairformerNoSeqModule(
            template_dim,
            num_blocks=template_blocks,
            dropout=dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
            post_layer_norm=post_layer_norm,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        z: Tensor,
        feats: dict[str, Tensor],
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        feats : dict[str, Tensor]
            Input features
        pair_mask : Tensor
            The pair mask
        use_kernels : bool
            Whether to use kernels for triangular updates

        Returns
        -------
        Tensor
            The updated pairwise embeddings.

        """
        asym_id = feats["asym_id"]
        res_type = feats["template_restype"]
        frame_rot = feats["template_frame_rot"]
        frame_t = feats["template_frame_t"]
        frame_mask = feats["template_mask_frame"]
        cb_coords = feats["template_cb"]
        ca_coords = feats["template_ca"]
        cb_mask = feats["template_mask_cb"]
        template_mask = feats["template_mask"].any(dim=2).float()
        num_templates = template_mask.sum(dim=1)
        num_templates = num_templates.clamp(min=1)

        # Pairwise validity masks for CB atoms and frames
        b_cb_mask = cb_mask[:, :, :, None] * cb_mask[:, :, None, :]
        b_frame_mask = frame_mask[:, :, :, None] * frame_mask[:, :, None, :]

        b_cb_mask = b_cb_mask[..., None]
        b_frame_mask = b_frame_mask[..., None]

        # Restrict template pair features to within-chain pairs
        B, T = res_type.shape[:2]
        asym_mask = (asym_id[:, :, None] == asym_id[:, None, :]).float()
        asym_mask = asym_mask[:, None].expand(-1, T, -1, -1)

        with torch.autocast(device_type="cuda", enabled=False):
            # One-hot CB-CB distogram, computed in full precision
            cb_dists = torch.cdist(cb_coords, cb_coords)
            boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
            boundaries = boundaries.to(cb_dists.device)
            distogram = (cb_dists[..., None] > boundaries).sum(dim=-1).long()
            distogram = one_hot(distogram, num_classes=self.num_bins)

            # CA displacement vectors expressed in the local template frames,
            # normalized to unit length
            frame_rot = frame_rot.unsqueeze(2).transpose(-1, -2)
            frame_t = frame_t.unsqueeze(2).unsqueeze(-1)
            ca_coords = ca_coords.unsqueeze(3).unsqueeze(-1)
            vector = torch.matmul(frame_rot, (ca_coords - frame_t))
            norm = torch.norm(vector, dim=-1, keepdim=True)
            unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))
            unit_vector = unit_vector.squeeze(-1)

            # Assemble and mask the per-template pair features
            a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
            a_tij = torch.cat(a_tij, dim=-1)
            a_tij = a_tij * asym_mask.unsqueeze(-1)

            res_type_i = res_type[:, :, :, None]
            res_type_j = res_type[:, :, None, :]
            res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
            res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1)
            a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1)
            a_tij = self.a_proj(a_tij)

        # Expand the pair mask to one row per template and flatten batch x template
        pair_mask = pair_mask[:, None].expand(-1, T, -1, -1)
        pair_mask = pair_mask.reshape(B * T, *pair_mask.shape[2:])

        # Project the pairwise embeddings, add the template features, and run the
        # per-template pairformer
        v = self.z_proj(self.z_norm(z[:, None])) + a_tij
        v = v.view(B * T, *v.shape[2:])
        v = v + self.pairformer(v, pair_mask, use_kernels=use_kernels)
        v = self.v_norm(v)
        v = v.view(B, T, *v.shape[1:])

        # Average over valid templates
        template_mask = template_mask[:, :, None, None, None]
        num_templates = num_templates[:, None, None, None]
        u = (v * template_mask).sum(dim=1) / num_templates.to(v)

        # Project back to the token pairwise dimension
        u = self.u_proj(self.relu(u))
        return u


class TemplateV2Module(nn.Module):
    """Template module with visibility-based pair masking."""

    def __init__(
        self,
        token_z: int,
        template_dim: int,
        template_blocks: int,
        dropout: float = 0.25,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        post_layer_norm: bool = False,
        activation_checkpointing: bool = False,
        min_dist: float = 3.25,
        max_dist: float = 50.75,
        num_bins: int = 38,
        **kwargs,
    ) -> None:
        """Initialize the template module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.relu = nn.ReLU()
        self.z_norm = nn.LayerNorm(token_z)
        self.v_norm = nn.LayerNorm(template_dim)
        self.z_proj = nn.Linear(token_z, template_dim, bias=False)
        self.a_proj = nn.Linear(
            const.num_tokens * 2 + num_bins + 5,
            template_dim,
            bias=False,
        )
        self.u_proj = nn.Linear(template_dim, token_z, bias=False)
        self.pairformer = PairformerNoSeqModule(
            template_dim,
            num_blocks=template_blocks,
            dropout=dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
            post_layer_norm=post_layer_norm,
            activation_checkpointing=activation_checkpointing,
        )

    def forward(
        self,
        z: Tensor,
        feats: dict[str, Tensor],
        pair_mask: Tensor,
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        feats : dict[str, Tensor]
            Input features
        pair_mask : Tensor
            The pair mask
        use_kernels : bool
            Whether to use kernels for triangular updates

        Returns
        -------
        Tensor
            The updated pairwise embeddings.

        """
        res_type = feats["template_restype"]
        frame_rot = feats["template_frame_rot"]
        frame_t = feats["template_frame_t"]
        frame_mask = feats["template_mask_frame"]
        cb_coords = feats["template_cb"]
        ca_coords = feats["template_ca"]
        cb_mask = feats["template_mask_cb"]
        visibility_ids = feats["visibility_ids"]
        template_mask = feats["template_mask"].any(dim=2).float()
        num_templates = template_mask.sum(dim=1)
        num_templates = num_templates.clamp(min=1)

        # Pairwise validity masks for CB atoms and frames
        b_cb_mask = cb_mask[:, :, :, None] * cb_mask[:, :, None, :]
        b_frame_mask = frame_mask[:, :, :, None] * frame_mask[:, :, None, :]

        b_cb_mask = b_cb_mask[..., None]
        b_frame_mask = b_frame_mask[..., None]

        # Restrict template pair features to pairs with matching visibility ids
        B, T = res_type.shape[:2]
        tmlp_pair_mask = (
            visibility_ids[:, :, :, None] == visibility_ids[:, :, None, :]
        ).float()

        with torch.autocast(device_type="cuda", enabled=False):
            # One-hot CB-CB distogram, computed in full precision
            cb_dists = torch.cdist(cb_coords, cb_coords)
            boundaries = torch.linspace(self.min_dist, self.max_dist, self.num_bins - 1)
            boundaries = boundaries.to(cb_dists.device)
            distogram = (cb_dists[..., None] > boundaries).sum(dim=-1).long()
            distogram = one_hot(distogram, num_classes=self.num_bins)

            # CA displacement vectors expressed in the local template frames,
            # normalized to unit length
            frame_rot = frame_rot.unsqueeze(2).transpose(-1, -2)
            frame_t = frame_t.unsqueeze(2).unsqueeze(-1)
            ca_coords = ca_coords.unsqueeze(3).unsqueeze(-1)
            vector = torch.matmul(frame_rot, (ca_coords - frame_t))
            norm = torch.norm(vector, dim=-1, keepdim=True)
            unit_vector = torch.where(norm > 0, vector / norm, torch.zeros_like(vector))
            unit_vector = unit_vector.squeeze(-1)

            # Assemble and mask the per-template pair features
            a_tij = [distogram, b_cb_mask, unit_vector, b_frame_mask]
            a_tij = torch.cat(a_tij, dim=-1)
            a_tij = a_tij * tmlp_pair_mask.unsqueeze(-1)

            res_type_i = res_type[:, :, :, None]
            res_type_j = res_type[:, :, None, :]
            res_type_i = res_type_i.expand(-1, -1, -1, res_type.size(2), -1)
            res_type_j = res_type_j.expand(-1, -1, res_type.size(2), -1, -1)
            a_tij = torch.cat([a_tij, res_type_i, res_type_j], dim=-1)
            a_tij = self.a_proj(a_tij)

        # Expand the pair mask to one row per template and flatten batch x template
        pair_mask = pair_mask[:, None].expand(-1, T, -1, -1)
        pair_mask = pair_mask.reshape(B * T, *pair_mask.shape[2:])

        # Project the pairwise embeddings, add the template features, and run the
        # per-template pairformer
        v = self.z_proj(self.z_norm(z[:, None])) + a_tij
        v = v.view(B * T, *v.shape[2:])
        v = v + self.pairformer(v, pair_mask, use_kernels=use_kernels)
        v = self.v_norm(v)
        v = v.view(B, T, *v.shape[1:])

        # Average over valid templates
        template_mask = template_mask[:, :, None, None, None]
        num_templates = num_templates[:, None, None, None]
        u = (v * template_mask).sum(dim=1) / num_templates.to(v)

        # Project back to the token pairwise dimension
        u = self.u_proj(self.relu(u))
        return u


class MSAModule(nn.Module):
    """MSA module."""

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        token_s: int,
        msa_blocks: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
        activation_checkpointing: bool = False,
        use_paired_feature: bool = True,
        subsample_msa: bool = False,
        num_subsampled_msa: int = 1024,
        **kwargs,
    ) -> None:
        """Initialize the MSA module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.msa_blocks = msa_blocks
        self.msa_dropout = msa_dropout
        self.z_dropout = z_dropout
        self.use_paired_feature = use_paired_feature
        self.activation_checkpointing = activation_checkpointing
        self.subsample_msa = subsample_msa
        self.num_subsampled_msa = num_subsampled_msa

        self.s_proj = nn.Linear(token_s, msa_s, bias=False)
        # Input features per MSA position: one-hot residue type, has_deletion,
        # deletion_value, and (optionally) the paired-MSA indicator
        self.msa_proj = nn.Linear(
            const.num_tokens + 2 + int(use_paired_feature),
            msa_s,
            bias=False,
        )
        self.layers = nn.ModuleList()
        for _ in range(msa_blocks):
            self.layers.append(
                MSALayer(
                    msa_s,
                    token_z,
                    msa_dropout,
                    z_dropout,
                    pairwise_head_width,
                    pairwise_num_heads,
                )
            )

    def forward(
        self,
        z: Tensor,
        emb: Tensor,
        feats: dict[str, Tensor],
        use_kernels: bool = False,
    ) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        emb : Tensor
            The input embeddings
        feats : dict[str, Tensor]
            Input features
        use_kernels : bool
            Whether to use kernels for triangular updates

        Returns
        -------
        Tensor
            The output pairwise embeddings.

        """
        if not self.training:
            if z.shape[1] > const.chunk_size_threshold:
                chunk_heads_pwa = True
                chunk_size_transition_z = 64
                chunk_size_transition_msa = 32
                chunk_size_outer_product = 4
                chunk_size_tri_attn = 128
            else:
                chunk_heads_pwa = False
                chunk_size_transition_z = None
                chunk_size_transition_msa = None
                chunk_size_outer_product = None
                chunk_size_tri_attn = 512
        else:
            chunk_heads_pwa = False
            chunk_size_transition_z = None
            chunk_size_transition_msa = None
            chunk_size_outer_product = None
            chunk_size_tri_attn = None

        # Build the MSA representation from the raw MSA features
        msa = feats["msa"]
        msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
        has_deletion = feats["has_deletion"].unsqueeze(-1)
        deletion_value = feats["deletion_value"].unsqueeze(-1)
        is_paired = feats["msa_paired"].unsqueeze(-1)
        msa_mask = feats["msa_mask"]
        token_mask = feats["token_pad_mask"].float()
        token_mask = token_mask[:, :, None] * token_mask[:, None, :]

        if self.use_paired_feature:
            m = torch.cat([msa, has_deletion, deletion_value, is_paired], dim=-1)
        else:
            m = torch.cat([msa, has_deletion, deletion_value], dim=-1)

        # Optionally subsample MSA rows
        if self.subsample_msa:
            msa_indices = torch.randperm(msa.shape[1])[: self.num_subsampled_msa]
            m = m[:, msa_indices]
            msa_mask = msa_mask[:, msa_indices]

        # Project the MSA features and add the projected single representation
        m = self.msa_proj(m)
        m = m + self.s_proj(emb).unsqueeze(1)

        # Run the MSA blocks
        for i in range(self.msa_blocks):
            if self.activation_checkpointing and self.training:
                z, m = torch.utils.checkpoint.checkpoint(
                    self.layers[i],
                    z,
                    m,
                    token_mask,
                    msa_mask,
                    chunk_heads_pwa,
                    chunk_size_transition_z,
                    chunk_size_transition_msa,
                    chunk_size_outer_product,
                    chunk_size_tri_attn,
                    use_kernels,
                )
            else:
                z, m = self.layers[i](
                    z,
                    m,
                    token_mask,
                    msa_mask,
                    chunk_heads_pwa,
                    chunk_size_transition_z,
                    chunk_size_transition_msa,
                    chunk_size_outer_product,
                    chunk_size_tri_attn,
                    use_kernels,
                )
        return z


class MSALayer(nn.Module):
    """MSA layer."""

    def __init__(
        self,
        msa_s: int,
        token_z: int,
        msa_dropout: float,
        z_dropout: float,
        pairwise_head_width: int = 32,
        pairwise_num_heads: int = 4,
    ) -> None:
        """Initialize the MSA layer.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.msa_dropout = msa_dropout
        self.msa_transition = Transition(dim=msa_s, hidden=msa_s * 4)
        self.pair_weighted_averaging = PairWeightedAveraging(
            c_m=msa_s,
            c_z=token_z,
            c_h=32,
            num_heads=8,
        )

        self.pairformer_layer = PairformerNoSeqLayer(
            token_z=token_z,
            dropout=z_dropout,
            pairwise_head_width=pairwise_head_width,
            pairwise_num_heads=pairwise_num_heads,
        )
        self.outer_product_mean = OuterProductMean(
            c_in=msa_s,
            c_hidden=32,
            c_out=token_z,
        )

    def forward(
        self,
        z: Tensor,
        m: Tensor,
        token_mask: Tensor,
        msa_mask: Tensor,
        chunk_heads_pwa: bool = False,
        chunk_size_transition_z: int = None,
        chunk_size_transition_msa: int = None,
        chunk_size_outer_product: int = None,
        chunk_size_tri_attn: int = None,
        use_kernels: bool = False,
    ) -> tuple[Tensor, Tensor]:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings
        m : Tensor
            The MSA embeddings
        token_mask : Tensor
            The token pair mask
        msa_mask : Tensor
            The MSA mask
        use_kernels : bool
            Whether to use kernels for triangular updates

        Returns
        -------
        tuple[Tensor, Tensor]
            The updated pairwise and MSA embeddings.

        """
        msa_dropout = get_dropout_mask(self.msa_dropout, m, self.training)
        m = m + msa_dropout * self.pair_weighted_averaging(
            m, z, token_mask, chunk_heads_pwa
        )
        m = m + self.msa_transition(m, chunk_size_transition_msa)

        z = z + self.outer_product_mean(m, msa_mask, chunk_size_outer_product)

        z = self.pairformer_layer(
            z, token_mask, chunk_size_tri_attn, use_kernels=use_kernels
        )

        return z, m


class BFactorModule(nn.Module):
    """BFactor Module."""

    def __init__(self, token_s: int, num_bins: int) -> None:
        """Initialize the bfactor module.

        Parameters
        ----------
        token_s : int
            The token embedding size.

        """
        super().__init__()
        self.bfactor = nn.Linear(token_s, num_bins)
        self.num_bins = num_bins

    def forward(self, s: Tensor) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        s : Tensor
            The sequence embeddings

        Returns
        -------
        Tensor
            The predicted bfactor histogram.

        """
        return self.bfactor(s)


class DistogramModule(nn.Module):
    """Distogram Module."""

    def __init__(self, token_z: int, num_bins: int, num_distograms: int = 1) -> None:
        """Initialize the distogram module.

        Parameters
        ----------
        token_z : int
            The token pairwise embedding size.

        """
        super().__init__()
        self.distogram = nn.Linear(token_z, num_distograms * num_bins)
        self.num_distograms = num_distograms
        self.num_bins = num_bins

    def forward(self, z: Tensor) -> Tensor:
        """Perform the forward pass.

        Parameters
        ----------
        z : Tensor
            The pairwise embeddings

        Returns
        -------
        Tensor
            The predicted distogram.

        """
        z = z + z.transpose(1, 2)
        return self.distogram(z).reshape(
            z.shape[0], z.shape[1], z.shape[2], self.num_distograms, self.num_bins
        )