"""Utility functions for experiments."""

import logging
import os
import random
import re

import esm
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.utilities.rank_zero import rank_zero_only

from analysis import utils as au
from data import utils as du
from data.cal_trans_rotmats import cal_trans_rotmats
from data.repr import get_pre_repr
from data.residue_constants import restype_atom37_mask, restype_order
from openfold.data import data_transforms
from openfold.utils import rigid_utils


class LengthDataset(torch.utils.data.Dataset):

    def __init__(self, samples_cfg):
        self._samples_cfg = samples_cfg

        validcsv = pd.read_csv(self._samples_cfg.validset_path)

        self._all_sample_seqs = []
        self._all_filename = []

        # Discretized exponential distribution over `prob_num` energy levels.
        # `exp_prob` is the normalized cumulative distribution, so lower energy
        # levels are sampled more often than higher ones.
        prob_num = 500
        exp_prob = np.exp([-prob / prob_num * 2 for prob in range(prob_num)]).cumsum()
        exp_prob = exp_prob / np.max(exp_prob)

        for idx in range(len(validcsv['seq'])):
            self._all_filename += [validcsv['file'][idx]] * self._samples_cfg.sample_num

            for batch_idx in range(self._samples_cfg.sample_num):
                # Inverse-CDF sampling: draw u ~ U[0, 1) and take the first
                # level whose cumulative probability exceeds it.
                rand = random.random()
                for prob in range(prob_num):
                    if rand < exp_prob[prob]:
                        energy = torch.tensor(prob / prob_num)
                        break

                self._all_sample_seqs += [(validcsv['seq'][idx], energy)]

        self._all_sample_ids = self._all_sample_seqs

        # ESM-2 language model used to precompute residue and pair representations.
        self.device_esm = f'cuda:{torch.cuda.current_device()}'
        self.model_esm2, self.alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model_esm2.eval().to(self.device_esm)
        self.model_esm2.requires_grad_(False)

        # ESMFold model used to predict a reference structure for each sequence.
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model.requires_grad_(False)
        self._folding_model = self._folding_model.to(self.device_esm)

        self.esm_savepath = self._samples_cfg.esm_savepath

    def run_folding(self, sequence, save_path):
        """Run ESMFold on a sequence and write the predicted structure to a PDB file."""
        with torch.no_grad():
            output = self._folding_model.infer_pdb(sequence)
        # Move the folding model back to CPU to free GPU memory after inference.
        self._folding_model.to("cpu")

        with open(save_path, "w") as f:
            f.write(output)
        return output

    def __len__(self):
        return len(self._all_sample_ids)

    def __getitem__(self, idx):
        seq, energy = self._all_sample_ids[idx]
        aatype = torch.tensor([restype_order[s] for s in seq])
        num_res = len(aatype)

        # Precompute ESM-2 node (per-residue) and pair representations for the sequence.
        node_repr_pre, pair_repr_pre = get_pre_repr(
            aatype, self.model_esm2, self.alphabet, self.batch_converter,
            device=self.device_esm)
        node_repr_pre = node_repr_pre[0].cpu()
        pair_repr_pre = pair_repr_pre[0].cpu()

        motif_mask = torch.ones(aatype.shape)

        # Fold the sequence with ESMFold once and cache the predicted PDB on disk.
        save_path = os.path.join(self.esm_savepath, "esm_" + self._all_filename[idx] + ".pdb")
        if not os.path.exists(save_path):
            with torch.no_grad():
                output = self._folding_model.infer_pdb(seq)
            with open(save_path, "w") as f:
                f.write(output)

        # Extract backbone translations and rotation matrices from the ESMFold prediction.
        trans_esmfold, rotmats_esmfold = cal_trans_rotmats(save_path)

        batch = {
            'filename': self._all_filename[idx],
            'trans_esmfold': trans_esmfold,
            'rotmats_esmfold': rotmats_esmfold,
            'motif_mask': motif_mask,
            'res_mask': torch.ones(num_res).int(),
            'num_res': num_res,
            'energy': energy,
            'aatype': aatype,
            'seq': seq,
            'node_repr_pre': node_repr_pre,
            'pair_repr_pre': pair_repr_pre,
        }
        return batch
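

# The usage sketch below is illustrative only: `samples_cfg` is whatever config
# object the experiment passes in (it only needs the fields used above, e.g.
# `validset_path`, `sample_num`, `esm_savepath`), and the DataLoader settings
# are assumptions rather than part of this module. Each item is a complete
# example, so batch_size=1 is the simplest way to iterate.
def _example_length_dataloader(samples_cfg):
    dataset = LengthDataset(samples_cfg)
    # num_workers=0 because the dataset holds CUDA-resident models (ESM-2 and
    # ESMFold), which should not be forked into DataLoader worker processes.
    return torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=0)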


def save_traj(
        sample: np.ndarray,
        bb_prot_traj: np.ndarray,
        x0_traj: np.ndarray,
        diffuse_mask: np.ndarray,
        output_dir: str,
        aatype=None,
        index=0,
):
    """Writes the final sample and the reverse diffusion trajectory.

    Args:
        sample: [N, 37, 3] final atom37 sample of the reverse trajectory.
        bb_prot_traj: [T, N, 37, 3] atom37 sampled diffusion states.
            T is the number of time steps. The first time step is t=eps,
            i.e. bb_prot_traj[0] is the final sample after reverse diffusion.
            N is the number of residues.
        x0_traj: [T, N, 3] x_0 predictions of C-alpha at each time step.
        diffuse_mask: [N] which residues are diffused.
        output_dir: where to save samples.
        aatype: [T, N, 21] amino acid probability vector trajectory.
        index: integer suffix used in the output file names.

    Returns:
        Dictionary with paths to saved samples.
            'sample_path': PDB file of the final state of the reverse trajectory.
            'traj_path': PDB file of all intermediate diffused states.
            'x0_traj_path': PDB file of C-alpha x_0 predictions at each state.
        b_factors are set to 100 for diffused residues and 0 for motif
        residues, if there are any.
    """
    diffuse_mask = diffuse_mask.astype(bool)
    sample_path = os.path.join(output_dir, 'sample_' + str(index) + '.pdb')
    prot_traj_path = os.path.join(output_dir, 'bb_traj_' + str(index) + '.pdb')
    x0_traj_path = os.path.join(output_dir, 'x0_traj_' + str(index) + '.pdb')

    # Use the diffusion mask as per-atom b-factors: 100 for diffused residues, 0 for motif residues.
    b_factors = np.tile((diffuse_mask * 100)[:, None], (1, 37))

    sample_path = au.write_prot_to_pdb(
        sample,
        sample_path,
        b_factors=b_factors,
        no_indexing=True,
        aatype=aatype,
    )
    prot_traj_path = au.write_prot_to_pdb(
        bb_prot_traj,
        prot_traj_path,
        b_factors=b_factors,
        no_indexing=True,
        aatype=aatype,
    )
    x0_traj_path = au.write_prot_to_pdb(
        x0_traj,
        x0_traj_path,
        b_factors=b_factors,
        no_indexing=True,
        aatype=aatype,
    )
    return {
        'sample_path': sample_path,
        'traj_path': prot_traj_path,
        'x0_traj_path': x0_traj_path,
    }
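

# Illustrative sketch of the array shapes save_traj expects; the sizes and the
# random arrays below are placeholders, not real model outputs, and `output_dir`
# is assumed to already exist.
def _example_save_traj(output_dir: str):
    T, N = 10, 64                                # time steps, residues (assumed)
    sample = np.random.rand(N, 37, 3)            # final atom37 structure
    bb_prot_traj = np.random.rand(T, N, 37, 3)   # reverse-diffusion trajectory
    x0_traj = np.random.rand(T, N, 3)            # C-alpha x_0 predictions
    diffuse_mask = np.ones(N)                    # all residues are diffused
    return save_traj(sample, bb_prot_traj, x0_traj, diffuse_mask, output_dir)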


def get_pylogger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly python command line logger."""

    logger = logging.getLogger(name)

    # Decorate all logging methods with rank_zero_only so that messages are
    # emitted only once (by global rank 0) in multi-GPU runs.
    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level in logging_levels:
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
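

# A tiny usage sketch (illustrative only) showing how the rank-zero-safe logger
# is meant to be used; in a multi-GPU run only global rank 0 emits the message.
def _example_logging_usage():
    log = get_pylogger(__name__)
    log.info("This message is printed only on global rank 0.")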


def flatten_dict(raw_dict):
    """Flattens a nested dict."""
    flattened = []
    for k, v in raw_dict.items():
        if isinstance(v, dict):
            flattened.extend([
                (f'{k}:{i}', j) for i, j in flatten_dict(v)
            ])
        else:
            flattened.append((k, v))
    return flattened
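

# A small illustrative check of flatten_dict (a sketch, not part of the
# pipeline): nested keys are joined with ':' and the result is a list of
# (key, value) pairs rather than a dict.
def _example_flatten_dict():
    nested = {'model': {'lr': 1e-4, 'depth': 4}, 'seed': 0}
    flat = flatten_dict(nested)   # [('model:lr', 0.0001), ('model:depth', 4), ('seed', 0)]
    assert dict(flat) == {'model:lr': 1e-4, 'model:depth': 4, 'seed': 0}
    return flat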