import sys, os

import random

import numpy as np
import torch
from icecream import ic

from kinematics import get_init_xyz

sys.path.append('../')
from utils.calc_dssp import annotate_sse

ic.configureOutput(includeContext=True)

def mask_inputs(seq,
                msa_masked,
                msa_full,
                xyz_t,
                t1d,
                mask_msa,
                input_seq_mask=None,
                input_str_mask=None,
                input_floating_mask=None,
                input_t1dconf_mask=None,
                loss_seq_mask=None,
                loss_str_mask=None,
                loss_str_mask_2d=None,
                dssp=False,
                hotspots=False,
                diffuser=None,
                t=None,
                freeze_seq_emb=False,
                mutate_seq=False,
                no_clamp_seq=False,
                norm_input=False,
                contacts=None,
                frac_provide_dssp=0.5,
                dssp_mask_percentage=[0,100],
                frac_provide_contacts=0.5,
                struc_cond=False):
""" |
|
Parameters: |
|
seq (torch.tensor, required): (I,L) integer sequence |
|
|
|
msa_masked (torch.tensor, required): (I,N_short,L,48) |
|
|
|
msa_full (torch,.tensor, required): (I,N_long,L,25) |
|
|
|
xyz_t (torch,tensor): (T,L,27,3) template crds BEFORE they go into get_init_xyz |
|
|
|
t1d (torch.tensor, required): (I,L,22) this is the t1d before tacking on the chi angles |
|
|
|
str_mask_1D (torch.tensor, required): Shape (L) rank 1 tensor where structure is masked at False positions |
|
|
|
seq_mask_1D (torch.tensor, required): Shape (L) rank 1 tensor where seq is masked at False positions |
|
t1d_24: is there an extra dimension to input structure confidence? |
|
|
|
diffuser: diffuser class |
|
|
|
t: time step |
|
|
|
NOTE: in the MSA, the order is 20aa, 1x unknown, 1x mask token. We set the masked region to 22 (masked). |
|
For the t1d, this has 20aa, 1x unkown, and 1x template conf. Here, we set the masked region to 21 (unknown). |
|
This, we think, makes sense, as the template in normal RF training does not perfectly correspond to the MSA. |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
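    # Overview of the steps below (a descriptive summary of the code that follows):
    #   1. diffuse the one-hot sequence to timestep t with diffuser.q_sample
    #   2. write the diffused sequence/MSA into msa_masked and msa_full
    #   3. rebuild t1d with sequence/structure confidence, a timestep channel and,
    #      optionally, secondary-structure and hotspot channels
    #   4. initialize xyz_t and hide coordinates in the masked-structure region
    #   5. restrict mask_msa to the sequence-loss region
    # The diffuser object is only used through q_sample(x_0, t, mask=...),
    # num_timesteps and (for struc_cond) alphas_cumprod.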
    seq_mask = input_seq_mask
    str_mask = input_str_mask

    # Diffuse the sequence: one-hot encode over 22 classes, rescale to [-1,1],
    # and noise to timestep t. Following the docstring convention, positions where
    # seq_mask is False are the diffused (masked) region.
    x_0 = torch.nn.functional.one_hot(seq[0,...], num_classes=22).float()*2 - 1
    seq_diffused = diffuser.q_sample(x_0, t, mask=seq_mask)

    # Replace the integer sequence with the argmax of the diffused one-hot sequence
    seq_tmp = torch.argmax(seq_diffused, dim=-1).to(device=seq.device)
    seq = seq_tmp.repeat(seq.shape[0], 1)

    # msa_masked: write the diffused query sequence into the first row and diffuse
    # the remaining rows with the same per-position mask
    B, N, L, _ = msa_masked.shape
    msa_masked[:,0,:,:22] = seq_diffused

    x_0_msa = msa_masked[0,1:,:,:22].float()*2 - 1
    msa_seq_mask = seq_mask.unsqueeze(0).repeat(N-1, 1)
    msa_diffused = diffuser.q_sample(x_0_msa, torch.tensor([t]), mask=msa_seq_mask)

    msa_masked[:,1:,:,:22] = torch.clone(msa_diffused)

    # Mirror the diffused tokens into the second 22-channel block of msa_masked
    msa_masked[:,0,:,22:44] = seq_diffused
    msa_masked[:,1:,:,22:44] = msa_diffused

    # Zero channels 44 and 45 at the diffused positions of the query row
    msa_masked[:,0,~seq_mask,44:46] = 0

    # msa_full: keep at most as many rows as msa_masked and copy in the diffused tokens
    msa_full = msa_full[:,:msa_masked.shape[1],:,:]
    msa_full[:,0,:,:22] = seq_diffused
    msa_full[:,1:,:,:22] = msa_diffused

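    # t1d channel layout assembled below (after all the torch.cat calls):
    #   0-20  : diffused sequence (20 aa + unknown)
    #   21    : sequence confidence (1 where the sequence is given, 0 where diffused)
    #   22    : timestep feature, 1 - t/num_timesteps
    #   23    : structure confidence (1 where structure is given, 0 where masked)
    #   24-27 : secondary-structure one-hot, only if dssp=True
    #   last  : hotspot feature, only if hotspots=True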
    # t1d: append a channel for the timestep feature, then overwrite the sequence
    # channels with the diffused sequence
    t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 1)).float()), -1).to(seq.device)
    t1d[:,:,:21] = seq_diffused[...,:21]

    # Channel 21: sequence confidence, 0 where the sequence is being diffused
    t1d[:,~seq_mask,21] = 0.0
    t1d[:,seq_mask,21] = 1.0

    # Channel 22: timestep feature, decays from 1 at t=0 to 0 at t=T
    t1d[:1,:,22] = 1 - t/diffuser.num_timesteps

    # Channel 23: structure confidence, 0 where the structure is masked
    t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 1)).float()), -1).to(seq.device)
    t1d[:,~str_mask,23] = 0.0
    t1d[:,str_mask,23] = 1.0

    if dssp:
        # Channels 24-27: secondary-structure one-hot (3 SSE classes + mask)
        print(f'adding dssp {frac_provide_dssp} of time')
        t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 4)).float()), -1).to(seq.device)

        # Annotate secondary structure from the CA trace, masking a random percentage of positions
        percentage_mask = random.randint(dssp_mask_percentage[0], dssp_mask_percentage[1])
        dssp_feat = annotate_sse(np.array(xyz_t[0,:,1,:].squeeze()), percentage_mask=percentage_mask)

        # With probability 1 - frac_provide_dssp, withhold the annotation entirely (mask class only)
        if np.random.rand() > frac_provide_dssp:
            print('masking dssp')
            dssp_feat[...] = 0
            dssp_feat[:,-1] = 1
        t1d[...,24:] = dssp_feat

    if hotspots:
        # Final channel: hotspot/contact feature
        print(f"adding hotspots {frac_provide_contacts} of time")
        t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 1)).float()), -1).to(seq.device)

        # With probability 1 - frac_provide_contacts, withhold the hotspot feature
        if np.random.rand() > frac_provide_contacts:
            print('masking contacts')
            contacts = torch.zeros(L)
        t1d[...,-1] = contacts

    # Initialize template coordinates and hide everything beyond the N, CA, C backbone atoms
    xyz_t = get_init_xyz(xyz_t[None])
    xyz_t = xyz_t[0]

    xyz_t[:,:,3:,:] = float('nan')

    if struc_cond:
        # Non-autoregressive structure conditioning: positions outside the structure mask
        # keep their template coordinates with probability alphas_cumprod[t]; the rest are hidden
        print("non-autoregressive structure conditioning")
        r = diffuser.alphas_cumprod[t]
        xyz_mask = (torch.rand(xyz_t.shape[1]) > r).to(torch.bool).to(seq.device)
        xyz_mask = torch.logical_and(xyz_mask, ~str_mask)
        xyz_t[:,xyz_mask,:,:] = float('nan')
    else:
        # Hide coordinates everywhere the structure is masked
        xyz_t[:,~str_mask,:,:] = float('nan')

    # Only score the sequence recovery loss inside the loss mask
    mask_msa[:,:,~loss_seq_mask] = False

    out = dict(
        seq=seq,
        msa_masked=msa_masked,
        msa_full=msa_full,
        xyz_t=xyz_t,
        t1d=t1d,
        mask_msa=mask_msa,
        seq_diffused=seq_diffused,
    )

    return out
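

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not executed on import). The toy
# shapes follow the docstring of mask_inputs; `my_diffuser` stands in for the
# project's diffuser class, which is assumed to expose q_sample(x_0, t, mask=...)
# and num_timesteps as used above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    I, N_short, N_long, L, T = 1, 4, 8, 50, 1            # assumed toy dimensions
    example_inputs = dict(
        seq=torch.randint(0, 20, (I, L)),                # (I,L) integer sequence
        msa_masked=torch.zeros(I, N_short, L, 48),       # (I,N_short,L,48)
        msa_full=torch.zeros(I, N_long, L, 25),          # (I,N_long,L,25)
        xyz_t=torch.zeros(T, L, 27, 3),                  # (T,L,27,3) template coords
        t1d=torch.zeros(I, L, 22),                       # (I,L,22)
        mask_msa=torch.ones(I, N_short, L, dtype=torch.bool),
        input_seq_mask=torch.zeros(L, dtype=torch.bool), # False = sequence is diffused here
        input_str_mask=torch.zeros(L, dtype=torch.bool), # False = structure is masked here
        loss_seq_mask=torch.ones(L, dtype=torch.bool),
    )
    # out = mask_inputs(**example_inputs, diffuser=my_diffuser, t=25)
    # out['seq'], out['t1d'], out['xyz_t'], ... are then fed to the downstream model.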