# Page-header residue from the file viewer this source was copied from:
# Spaces status "Runtime error"; file size 4,793 bytes; revision 753e275.
import torch
import torch.nn as nn
from diffab.modules.common.geometry import construct_3d_basis
from diffab.modules.common.so3 import rotation_to_so3vec
from diffab.modules.encoders.residue import ResidueEmbedding
from diffab.modules.encoders.pair import PairEmbedding
from diffab.modules.diffusion.dpm_full import FullDPM
from diffab.utils.protein.constants import max_num_heavyatoms, BBHeavyAtom
from ._base import register_model
# Lookup table: configured atom resolution -> number of atoms per residue
# given to the residue and pair embedding layers.
# NOTE(review): 5 presumably counts the backbone atoms plus CB -- confirm
# against the atom ordering in diffab.utils.protein.constants.
resolution_to_num_atoms = {
    'backbone+CB': 5,
    'full': max_num_heavyatoms,
}
@register_model('diffab')
class DiffusionAntibodyDesign(nn.Module):
    """Diffusion model wrapper registered under the name ``'diffab'``.

    Combines a residue embedding, a pair embedding and a ``FullDPM``
    diffusion module. ``forward`` returns whatever the diffusion module
    returns for a training batch; ``sample`` and ``optimize`` call the
    diffusion module's ``sample`` / ``optimize`` with gradients disabled.
    """

    def __init__(self, cfg):
        """Build the embedding layers and the diffusion module from ``cfg``.

        Args:
            cfg: Config object supporting ``.get(key, default)`` and
                attribute access. Reads ``resolution`` (default
                ``'full'``), ``res_feat_dim``, ``pair_feat_dim`` and
                ``diffusion`` (a mapping expanded into ``FullDPM`` keyword
                arguments). ``forward`` later reads ``train_structure``
                and ``train_sequence`` (both default to ``True``).

        Raises:
            KeyError: If the configured resolution is not a key of
                ``resolution_to_num_atoms``.
        """
        super().__init__()
        self.cfg = cfg

        num_atoms = resolution_to_num_atoms[cfg.get('resolution', 'full')]
        self.residue_embed = ResidueEmbedding(cfg.res_feat_dim, num_atoms)
        self.pair_embed = PairEmbedding(cfg.pair_feat_dim, num_atoms)

        self.diffusion = FullDPM(
            cfg.res_feat_dim,
            cfg.pair_feat_dim,
            **cfg.diffusion,
        )

    def encode(self, batch, remove_structure, remove_sequence):
        """Embed a batch and extract the backbone frame of every residue.

        Args:
            batch: Mapping providing ``'aa'``, ``'res_nb'``, ``'chain_nb'``,
                ``'pos_heavyatom'``, ``'mask_heavyatom'``,
                ``'fragment_type'`` and ``'generate_flag'``.
            remove_structure: If truthy, the context mask is passed to the
                embedding layers as ``structure_mask``; otherwise ``None``
                is passed.
            remove_sequence: If truthy, the context mask is passed to the
                embedding layers as ``sequence_mask``; otherwise ``None``
                is passed.

        Returns:
            Tuple ``(res_feat, pair_feat, R, p)``:
                res_feat:   (N, L, res_feat_dim)
                pair_feat:  (N, L, L, pair_feat_dim)
                R:  Result of ``construct_3d_basis`` on the CA, C and N
                    positions (presumably per-residue rotation matrices --
                    confirm against ``construct_3d_basis``).
                p:  CA positions sliced from ``batch['pos_heavyatom']``.
        """
        # This is used throughout embedding and encoding layers
        # to avoid data leakage.
        context_mask = torch.logical_and(
            batch['mask_heavyatom'][:, :, BBHeavyAtom.CA],
            ~batch['generate_flag'],  # Context means ``not generated''
        )

        structure_mask = context_mask if remove_structure else None
        sequence_mask = context_mask if remove_sequence else None

        res_feat = self.residue_embed(
            aa=batch['aa'],
            res_nb=batch['res_nb'],
            chain_nb=batch['chain_nb'],
            pos_atoms=batch['pos_heavyatom'],
            mask_atoms=batch['mask_heavyatom'],
            fragment_type=batch['fragment_type'],
            structure_mask=structure_mask,
            sequence_mask=sequence_mask,
        )

        pair_feat = self.pair_embed(
            aa=batch['aa'],
            res_nb=batch['res_nb'],
            chain_nb=batch['chain_nb'],
            pos_atoms=batch['pos_heavyatom'],
            mask_atoms=batch['mask_heavyatom'],
            structure_mask=structure_mask,
            sequence_mask=sequence_mask,
        )

        pos_heavyatom = batch['pos_heavyatom']
        R = construct_3d_basis(
            pos_heavyatom[:, :, BBHeavyAtom.CA],
            pos_heavyatom[:, :, BBHeavyAtom.C],
            pos_heavyatom[:, :, BBHeavyAtom.N],
        )
        p = pos_heavyatom[:, :, BBHeavyAtom.CA]

        return res_feat, pair_feat, R, p

    @staticmethod
    def _resolve_opt(opt):
        """Return ``opt`` as given, or a fresh default option dict if it is None."""
        if opt is None:
            # Built per call so no dict object is shared between invocations.
            return {
                'sample_structure': True,
                'sample_sequence': True,
            }
        return opt

    def _diffusion_inputs(self, batch, remove_structure, remove_sequence):
        """Gather the tensors shared by ``forward``, ``sample`` and ``optimize``.

        Returns:
            Tuple ``(v_0, p_0, s_0, res_feat, pair_feat, mask_generate,
            mask_res)`` where ``v_0`` is ``rotation_to_so3vec`` applied to
            the frames from ``encode``, ``p_0`` the CA positions from
            ``encode``, ``s_0`` is ``batch['aa']``, ``mask_generate`` is
            ``batch['generate_flag']`` and ``mask_res`` is ``batch['mask']``.
        """
        mask_generate = batch['generate_flag']
        mask_res = batch['mask']
        res_feat, pair_feat, R_0, p_0 = self.encode(
            batch,
            remove_structure=remove_structure,
            remove_sequence=remove_sequence,
        )
        v_0 = rotation_to_so3vec(R_0)
        s_0 = batch['aa']
        return v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res

    def forward(self, batch):
        """Run the diffusion module on a training batch.

        The ``train_structure`` / ``train_sequence`` config flags (default
        ``True``) select both what ``encode`` masks and what the diffusion
        module is asked to denoise.

        Args:
            batch: Mapping with the keys required by ``encode`` plus
                ``'mask'``.

        Returns:
            The value returned by ``self.diffusion(...)`` (named
            ``loss_dict`` here).
        """
        train_structure = self.cfg.get('train_structure', True)
        train_sequence = self.cfg.get('train_sequence', True)

        (v_0, p_0, s_0, res_feat, pair_feat,
         mask_generate, mask_res) = self._diffusion_inputs(
            batch,
            remove_structure=train_structure,
            remove_sequence=train_sequence,
        )

        loss_dict = self.diffusion(
            v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res,
            denoise_structure=train_structure,
            denoise_sequence=train_sequence,
        )
        return loss_dict

    @torch.no_grad()
    def sample(self, batch, sample_opt=None):
        """Call the diffusion module's ``sample`` with gradients disabled.

        Args:
            batch: Mapping with the keys required by ``encode`` plus
                ``'mask'``.
            sample_opt: Optional dict of keyword arguments forwarded to
                ``self.diffusion.sample``. Its ``'sample_structure'`` and
                ``'sample_sequence'`` entries (default ``True`` when
                absent) also decide what ``encode`` masks. ``None`` means
                ``{'sample_structure': True, 'sample_sequence': True}``.

        Returns:
            The value returned by ``self.diffusion.sample`` (named
            ``traj`` here).
        """
        sample_opt = self._resolve_opt(sample_opt)

        (v_0, p_0, s_0, res_feat, pair_feat,
         mask_generate, mask_res) = self._diffusion_inputs(
            batch,
            remove_structure=sample_opt.get('sample_structure', True),
            remove_sequence=sample_opt.get('sample_sequence', True),
        )

        traj = self.diffusion.sample(
            v_0, p_0, s_0, res_feat, pair_feat, mask_generate, mask_res,
            **sample_opt
        )
        return traj

    @torch.no_grad()
    def optimize(self, batch, opt_step, optimize_opt=None):
        """Call the diffusion module's ``optimize`` with gradients disabled.

        Args:
            batch: Mapping with the keys required by ``encode`` plus
                ``'mask'``.
            opt_step: Forwarded unchanged to ``self.diffusion.optimize``.
            optimize_opt: Optional dict of keyword arguments forwarded to
                ``self.diffusion.optimize``. Its ``'sample_structure'`` and
                ``'sample_sequence'`` entries (default ``True`` when
                absent) also decide what ``encode`` masks. ``None`` means
                ``{'sample_structure': True, 'sample_sequence': True}``.

        Returns:
            The value returned by ``self.diffusion.optimize`` (named
            ``traj`` here).
        """
        optimize_opt = self._resolve_opt(optimize_opt)

        (v_0, p_0, s_0, res_feat, pair_feat,
         mask_generate, mask_res) = self._diffusion_inputs(
            batch,
            remove_structure=optimize_opt.get('sample_structure', True),
            remove_sequence=optimize_opt.get('sample_sequence', True),
        )

        traj = self.diffusion.optimize(
            v_0, p_0, s_0, opt_step, res_feat, pair_feat, mask_generate,
            mask_res, **optimize_opt
        )
        return traj
|