import typing as T
from contextlib import ExitStack
from dataclasses import dataclass, field

import torch
import torch.nn as nn
from openfold.model.structure_module import StructureModule

from esm.esmfold.v1.tri_self_attn_block import TriangularSelfAttentionBlock


@dataclass
class StructureModuleConfig:
    c_s: int = 384
    c_z: int = 128
    c_ipa: int = 16
    c_resnet: int = 128
    no_heads_ipa: int = 12
    no_qk_points: int = 4
    no_v_points: int = 8
    dropout_rate: float = 0.1
    no_blocks: int = 8
    no_transition_layers: int = 1
    no_resnet_blocks: int = 2
    no_angles: int = 7
    trans_scale_factor: int = 10
    epsilon: float = 1e-8
    inf: float = 1e5


@dataclass
class FoldingTrunkConfig:
    _name: str = "FoldingTrunkConfig"
    num_blocks: int = 48
    sequence_state_dim: int = 1024
    pairwise_state_dim: int = 128
    sequence_head_width: int = 32
    pairwise_head_width: int = 32
    position_bins: int = 32
    dropout: float = 0
    layer_drop: float = 0
    cpu_grad_checkpoint: bool = False

    max_recycles: int = 4
    chunk_size: T.Optional[int] = None

    # A default_factory avoids sharing one mutable config instance across
    # FoldingTrunkConfig objects (and the mutable-default error on Python >= 3.11).
    structure_module: StructureModuleConfig = field(default_factory=StructureModuleConfig)


def get_axial_mask(mask):
    """
    Helper to convert a B x L mask of valid positions to the axial mask used
    in row/column attentions.

    Input:
      mask: B x L tensor of booleans

    Output:
      mask: (B * L) x L tensor of booleans
    """

    if mask is None:
        return None
    assert len(mask.shape) == 2
    batch_dim, seq_dim = mask.shape
    m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
    m = m.reshape(batch_dim * seq_dim, seq_dim)
    return m


class RelativePosition(nn.Module):
    def __init__(self, bins, pairwise_state_dim):
        super().__init__()
        self.bins = bins

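        # An additional offset is used so that embedding index 0 is reserved
        # for masked residue pairs.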
        self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
          residue_index: B x L tensor of indices (dtype=torch.long)
          mask: B x L tensor of booleans

        Output:
          pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """

        assert residue_index.dtype == torch.long
        if mask is not None:
            assert residue_index.shape == mask.shape

        diff = residue_index[:, None, :] - residue_index[:, :, None]
        diff = diff.clamp(-self.bins, self.bins)
        diff = diff + self.bins + 1

        if mask is not None:
            mask = mask[:, None, :] * mask[:, :, None]
            diff[mask == False] = 0

        output = self.embedding(diff)
        return output


class FoldingTrunk(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.cfg = FoldingTrunkConfig(**kwargs)
        assert self.cfg.max_recycles > 0

        c_s = self.cfg.sequence_state_dim
        c_z = self.cfg.pairwise_state_dim

        assert c_s % self.cfg.sequence_head_width == 0
        assert c_z % self.cfg.pairwise_head_width == 0
        block = TriangularSelfAttentionBlock

        self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)

        self.blocks = nn.ModuleList(
            [
                block(
                    sequence_state_dim=c_s,
                    pairwise_state_dim=c_z,
                    sequence_head_width=self.cfg.sequence_head_width,
                    pairwise_head_width=self.cfg.pairwise_head_width,
                    dropout=self.cfg.dropout,
                )
                for i in range(self.cfg.num_blocks)
            ]
        )

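        # Recycling: embeddings and distogram bins from the previous iteration
        # are renormalized / embedded and added back into the trunk inputs.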
        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        self.recycle_disto.weight[0].detach().zero_()

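        # The structure module runs at its own state dimensions, so trunk
        # embeddings are linearly projected before each call.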
        self.structure_module = StructureModule(**self.cfg.structure_module)
        self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
        self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)

        self.chunk_size = self.cfg.chunk_size

    def set_chunk_size(self, chunk_size):
        # When set, axial attention inside the trunk blocks is computed in
        # chunks of this size. This is the main knob for peak memory use:
        # smaller chunks use less memory at some cost in speed.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
        """
        Inputs:
          seq_feats: B x L x C tensor of sequence features
          pair_feats: B x L x L x C tensor of pair features
          true_aa: B x L long tensor of amino-acid indices (passed to the structure module)
          residx: B x L long tensor giving the position in the sequence
          mask: B x L boolean tensor indicating valid residues
          no_recycles: optional number of recycling iterations; defaults to cfg.max_recycles

        Output:
          structure: dict of structure-module outputs, with the final trunk
            embeddings attached under the "s_s" and "s_z" keys
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.cfg.max_recycles
        else:
            assert no_recycles >= 0, "Number of recycles must not be negative."
            # The first iteration is the standard forward pass; add one so that
            # `no_recycles` counts actual recycling passes.
            no_recycles += 1

        def trunk_iter(s, z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        assert no_recycles > 0
        for recycle_idx in range(no_recycles):
            with ExitStack() if recycle_idx == no_recycles - 1 else torch.no_grad():
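                # === Recycling ===
                # Only the final iteration keeps gradients; earlier passes run
                # under no_grad, and the previous outputs are renormalized
                # before being added back to the trunk inputs.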
                recycle_s = self.recycle_s_norm(recycle_s.detach())
                recycle_z = self.recycle_z_norm(recycle_z.detach())
                recycle_z += self.recycle_disto(recycle_bins.detach())

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

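                # === Structure module ===
                # The projected trunk embeddings are decoded into frames and atom positions.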
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
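                # Distogram needs the N, CA, C coordinates; the bin constants match AlphaFold's.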
                recycle_bins = FoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        assert isinstance(structure, dict)
        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
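        # coords: [..., L, 3, 3] tensor holding N, CA, C coordinates per residue.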
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]

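        # Infer a virtual CB position from the N, CA, C backbone atoms.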
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)
        return bins