import sys, os
import torch
from icecream import ic
import random
import numpy as np
from kinematics import get_init_xyz
sys.path.append('../')
from utils.calc_dssp import annotate_sse

ic.configureOutput(includeContext=True)
def mask_inputs(seq,
                msa_masked,
                msa_full,
                xyz_t,
                t1d,
                mask_msa,
                input_seq_mask=None,
                input_str_mask=None,
                input_floating_mask=None,
                input_t1dconf_mask=None,
                loss_seq_mask=None,
                loss_str_mask=None,
                loss_str_mask_2d=None,
                dssp=False,
                hotspots=False,
                diffuser=None,
                t=None,
                freeze_seq_emb=False,
                mutate_seq=False,
                no_clamp_seq=False,
                norm_input=False,
                contacts=None,
                frac_provide_dssp=0.5,
                dssp_mask_percentage=[0, 100],
                frac_provide_contacts=0.5,
                struc_cond=False):
""" | |
Parameters: | |
seq (torch.tensor, required): (I,L) integer sequence | |
msa_masked (torch.tensor, required): (I,N_short,L,48) | |
msa_full (torch,.tensor, required): (I,N_long,L,25) | |
xyz_t (torch,tensor): (T,L,27,3) template crds BEFORE they go into get_init_xyz | |
t1d (torch.tensor, required): (I,L,22) this is the t1d before tacking on the chi angles | |
str_mask_1D (torch.tensor, required): Shape (L) rank 1 tensor where structure is masked at False positions | |
seq_mask_1D (torch.tensor, required): Shape (L) rank 1 tensor where seq is masked at False positions | |
t1d_24: is there an extra dimension to input structure confidence? | |
diffuser: diffuser class | |
t: time step | |
NOTE: in the MSA, the order is 20aa, 1x unknown, 1x mask token. We set the masked region to 22 (masked). | |
For the t1d, this has 20aa, 1x unkown, and 1x template conf. Here, we set the masked region to 21 (unknown). | |
This, we think, makes sense, as the template in normal RF training does not perfectly correspond to the MSA. | |
""" | |
    # ic(input_seq_mask.shape)
    # ic(seq.shape)
    # ic(msa_masked.shape)
    # ic(msa_full.shape)
    # ic(t1d.shape)
    # ic(xyz_t.shape)
    # ic(input_str_mask.shape)
    # ic(mask_msa.shape)
    seq_mask = input_seq_mask
    str_mask = input_str_mask

    ##########################
    ### sequence diffusion ###
    ##########################
    # Encode the sequence as a signed one-hot over 22 tokens (values in {-1, 1})
    # and noise it to time step t; masked (False) positions are the ones diffused.
    x_0 = torch.nn.functional.one_hot(seq[0, ...], num_classes=22).float() * 2 - 1
    seq_diffused = diffuser.q_sample(x_0, t, mask=seq_mask)

    # hard-decode the noised logits back to an integer sequence
    seq_tmp = torch.argmax(seq_diffused, axis=-1).to(device=seq.device)
    seq = seq_tmp.repeat(seq.shape[0], 1)
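    # A sketch of what q_sample is assumed to compute (a standard Gaussian
    # forward process; the actual diffuser implementation may differ):
    #   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
    # with positions where the mask is True left at their x_0 values.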
    #####################
    ### msa diffusion ###
    #####################

    ### msa_masked ###
    # ic(msa_masked.shape)
    B, N, L, _ = msa_masked.shape

    # row 0 is the query sequence: give it the diffused sequence directly
    msa_masked[:, 0, :, :22] = seq_diffused

    # diffuse the remaining MSA rows with the same per-position mask
    x_0_msa = msa_masked[0, 1:, :, :22].float() * 2 - 1
    msa_seq_mask = seq_mask.unsqueeze(0).repeat(N - 1, 1)
    msa_diffused = diffuser.q_sample(x_0_msa, torch.tensor([t]), mask=msa_seq_mask)
    msa_masked[:, 1:, :, :22] = torch.clone(msa_diffused)

    # indices 22:44 mirror the sequence channels: index 42 is the unknown token,
    # index 43 is the mask token (NOTE: check this); indices 44/45 are insertion/deletion
    msa_masked[:, 0, :, 22:44] = seq_diffused
    msa_masked[:, 1:, :, 22:44] = msa_diffused

    # zero the insertion/deletion features at diffused positions
    msa_masked[:, 0, ~seq_mask, 44:46] = 0

    ### msa_full ###
    # ic(msa_full.shape)
    # truncate msa_full to the same depth as msa_masked
    msa_full = msa_full[:, :msa_masked.shape[1], :, :]
    msa_full[:, 0, :, :22] = seq_diffused
    msa_full[:, 1:, :, :22] = msa_diffused
    ###########
    ### t1d ###
    ###########
    # NOTE: adjusting the t1d confidence channel from the sequence mask.
    # Append a timestep channel (index 22), write the diffused sequence into
    # channels 0..20, and set the sequence confidence (index 21) to 0 at
    # diffused positions and 1 at given positions.
    t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 1)).float()), -1).to(seq.device)
    t1d[:, :, :21] = seq_diffused[..., :21]
    # t1d[:, :, 21] *= input_t1dconf_mask
    t1d[:, ~seq_mask, 21] = 0.0
    t1d[:, seq_mask, 21] = 1.0
    t1d[:1, :, 22] = 1 - t / diffuser.num_timesteps

    # TODO: add a structure confidence metric; the checkpoint dimensions need to be expanded first.
    # Structure-confidence channel (index 23); formerly gated by a t1d_24 flag, JG changed it to be the default.
    t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 1)).float()), -1).to(seq.device)
    t1d[:, ~str_mask, 23] = 0.0
    t1d[:, str_mask, 23] = 1.0
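    # Resulting t1d channel layout as assembled above (the dssp/hotspot channels
    # are only appended when the corresponding flags below are set):
    #   0..20  diffused sequence (first 21 classes of seq_diffused)
    #   21     sequence confidence (0 = diffused, 1 = given)
    #   22     normalized timestep, 1 - t/num_timesteps
    #   23     structure confidence (0 = masked, 1 = given)
    #   24..27 optional 4-channel secondary-structure (dssp) encoding
    #   last   optional hotspot/contact channel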
    if dssp:
        print(f'adding dssp {frac_provide_dssp} of time')
        t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 4)).float()), -1).to(seq.device)
        # mask some percentage of the dssp labels, drawn uniformly from
        # [dssp_mask_percentage[0], dssp_mask_percentage[1]]
        percentage_mask = random.randint(dssp_mask_percentage[0], dssp_mask_percentage[1])
        # annotate secondary structure from the CA trace (atom index 1)
        dssp = annotate_sse(np.array(xyz_t[0, :, 1, :].squeeze()), percentage_mask=percentage_mask)
        # dssp_unmasked = annotate_sse(np.array(xyz_t[0, :, 1, :].squeeze()), percentage_mask=0)
        if np.random.rand() > frac_provide_dssp:
            print('masking dssp')
            dssp[...] = 0  # replace everything with the mask token
            dssp[:, -1] = 1
        t1d[..., 24:] = dssp
    if hotspots:
        print(f"adding hotspots {frac_provide_contacts} of time")
        t1d = torch.cat((t1d, torch.zeros((t1d.shape[0], t1d.shape[1], 1)).float()), -1).to(seq.device)
        # withhold all contacts some fraction of the time
        if np.random.rand() > frac_provide_contacts:
            print('masking contacts')
            contacts = torch.zeros(L)
        t1d[..., -1] = contacts
    #############
    ### xyz_t ###
    #############
    xyz_t = get_init_xyz(xyz_t[None])
    xyz_t = xyz_t[0]

    # sequence masking: drop everything beyond the N/CA/C backbone atoms,
    # since the side chains depend on the (diffused) sequence
    xyz_t[:, :, 3:, :] = float('nan')

    # structure masking
    if struc_cond:
        # non-autoregressive structure conditioning: within the masked region,
        # each residue keeps its template coordinates with probability
        # alphas_cumprod[t], so less structure is revealed at larger t
        print("non-autoregressive structure conditioning")
        r = diffuser.alphas_cumprod[t]
        xyz_mask = (torch.rand(xyz_t.shape[1]) > r).to(torch.bool).to(seq.device)
        xyz_mask = torch.logical_and(xyz_mask, ~str_mask)
        xyz_t[:, xyz_mask, :, :] = float('nan')
    else:
        xyz_t[:, ~str_mask, :, :] = float('nan')
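    # In the struc_cond branch above: if, say, alphas_cumprod[t] = 0.9, roughly
    # 90% of residues in the masked region keep their template coordinates;
    # positions where str_mask is True are never NaN'd in either branch.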
    ################
    ### mask_msa ###
    ################
    # NOTE: this is for loss scoring; only score positions where loss_seq_mask is True
    mask_msa[:, :, ~loss_seq_mask] = False
    out = dict(
        seq=seq,
        msa_masked=msa_masked,
        msa_full=msa_full,
        xyz_t=xyz_t,
        t1d=t1d,
        mask_msa=mask_msa,
        seq_diffused=seq_diffused,
    )
    return out
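
# ---------------------------------------------------------------------------
# Minimal usage sketch with hypothetical shapes and a placeholder diffuser
# (my_diffuser must expose q_sample, alphas_cumprod and num_timesteps as used
# above). Kept commented out so importing this module has no side effects.
# ---------------------------------------------------------------------------
# L = 100
# seq        = torch.randint(0, 20, (1, L))             # (I,L)
# msa_masked = torch.zeros(1, 16, L, 48)                # (I,N_short,L,48)
# msa_full   = torch.zeros(1, 64, L, 25)                # (I,N_long,L,25)
# xyz_t      = torch.full((1, L, 27, 3), float('nan'))  # (T,L,27,3)
# t1d        = torch.zeros(1, L, 22)                    # (I,L,22)
# mask_msa   = torch.ones(1, 16, L, dtype=torch.bool)
# seq_mask   = torch.zeros(L, dtype=torch.bool)         # False everywhere: diffuse the full sequence
# str_mask   = torch.zeros(L, dtype=torch.bool)         # False everywhere: mask all structure
# out = mask_inputs(seq, msa_masked, msa_full, xyz_t, t1d, mask_msa,
#                   input_seq_mask=seq_mask, input_str_mask=str_mask,
#                   loss_seq_mask=~seq_mask,  # score the loss on diffused positions
#                   diffuser=my_diffuser, t=25)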