""" https://github.com/ProteinDesignLab/protpardelle License: MIT Author: Alex Chu Configs and convenience functions for wrapping the model sample() function. """ import argparse import time from typing import Optional, Tuple import torch from torchtyping import TensorType from core import residue_constants from core import utils import diffusion def default_backbone_sampling_config(): config = argparse.Namespace( n_steps=500, s_churn=200, step_scale=1.2, sidechain_mode=False, noise_schedule=lambda t: diffusion.noise_schedule(t, s_max=80, s_min=0.001), ) return config def default_allatom_sampling_config(): noise_schedule = lambda t: diffusion.noise_schedule(t, s_max=80, s_min=0.001) stage2 = argparse.Namespace( apply_cond_proportion=1.0, n_steps=200, s_churn=100, step_scale=1.2, sidechain_mode=True, skip_mpnn_proportion=1.0, noise_schedule=noise_schedule, ) config = argparse.Namespace( n_steps=500, s_churn=200, step_scale=1.2, sidechain_mode=True, skip_mpnn_proportion=0.6, use_fullmpnn=False, use_fullmpnn_for_final=True, anneal_seq_resampling_rate="linear", noise_schedule=noise_schedule, stage_2=stage2, ) return config def draw_backbone_samples( model: torch.nn.Module, seq_mask: TensorType["b n", float] = None, n_samples: int = None, sample_length_range: Tuple[int] = (50, 512), pdb_save_path: Optional[str] = None, return_aux: bool = False, return_sampling_runtime: bool = False, **sampling_kwargs, ): device = model.device if seq_mask is None: assert n_samples is not None seq_mask = model.make_seq_mask_for_sampling( n_samples=n_samples, min_len=sample_length_range[0], max_len=sample_length_range[1], ) start = time.time() aux = model.sample( seq_mask=seq_mask, return_last=False, return_aux=True, **sampling_kwargs ) aux["runtime"] = time.time() - start seq_lens = seq_mask.sum(-1).long() cropped_samp_coords = [ s[: seq_lens[i], model.bb_idxs] for i, s in enumerate(aux["xt_traj"][-1]) ] if pdb_save_path is not None: gly_aatype = (seq_mask * residue_constants.restype_order["G"]).long() trimmed_aatype = [a[: seq_lens[i]] for i, a in enumerate(gly_aatype)] atom_mask = utils.atom37_mask_from_aatype(gly_aatype, seq_mask).cpu() for i in range(len(cropped_samp_coords)): utils.write_coords_to_pdb( cropped_samp_coords[i], f"{pdb_save_path}{i}.pdb", batched=False, aatype=trimmed_aatype[i], atom_mask=atom_mask[i], ) if return_aux: return aux else: if return_sampling_runtime: return cropped_samp_coords, seq_mask, aux["runtime"] else: return cropped_samp_coords, seq_mask def draw_allatom_samples( model: torch.nn.Module, seq_mask: TensorType["b n", float] = None, n_samples: int = None, sample_length_range: Tuple[int] = (50, 512), two_stage_sampling: bool = True, pdb_save_path: Optional[str] = None, return_aux: bool = False, return_sampling_runtime: bool = False, **sampling_kwargs, ): """Implement the default 2-stage all-atom sampling routine.""" def save_allatom_samples(aux, path): seq_lens = aux["seq_mask"].sum(-1).long() cropped_samp_coords = [ c[: seq_lens[i]] for i, c in enumerate(aux["xt_traj"][-1]) ] cropped_samp_aatypes = [ s[: seq_lens[i]] for i, s in enumerate(aux["st_traj"][-1]) ] samp_atom_mask = utils.atom37_mask_from_aatype( aux["st_traj"][-1].to(device), seq_mask ) samp_atom_mask = [m[: seq_lens[i]] for i, m in enumerate(samp_atom_mask)] for i, c in enumerate(cropped_samp_coords): utils.write_coords_to_pdb( c, f"{path}{i}.pdb", batched=False, aatype=cropped_samp_aatypes[i], atom_mask=samp_atom_mask[i], conect=True, ) device = model.device if seq_mask is None: assert n_samples is not None 
        seq_mask = model.make_seq_mask_for_sampling(
            n_samples=n_samples,
            min_len=sample_length_range[0],
            max_len=sample_length_range[1],
        )
    sampling_runtime = 0.0

    # Stage 1 sampling
    start = time.time()
    if "stage_2" in sampling_kwargs:
        stage_2_kwargs = vars(sampling_kwargs.pop("stage_2"))
    else:
        # Fall back to the default stage-2 config so stage 2 does not fail
        # with a NameError when no stage_2 kwargs are provided.
        stage_2_kwargs = vars(default_allatom_sampling_config().stage_2)
    aux = model.sample(
        seq_mask=seq_mask,
        return_last=False,
        return_aux=True,
        **sampling_kwargs,
    )
    sampling_runtime = time.time() - start
    if pdb_save_path is not None and two_stage_sampling:
        save_allatom_samples(aux, pdb_save_path + "_init")

    # Stage 2 sampling (sidechain refinement only)
    if two_stage_sampling:
        samp_seq = aux["st_traj"][-1]
        samp_coords = aux["xt_traj"][-1]
        # An all-glycine aatype gives a backbone-only atom mask, so stage 2
        # is conditioned on the stage-1 backbone.
        cond_atom_mask = utils.atom37_mask_from_aatype(
            (seq_mask * residue_constants.restype_order["G"]).long(), seq_mask
        )
        aux = {f"stage1_{k}": v for k, v in aux.items()}
        start = time.time()
        stage2_aux = model.sample(
            gt_cond_atom_mask=cond_atom_mask.to(device),  # condition on backbone
            gt_cond_seq_mask=seq_mask.to(device),
            gt_coords=samp_coords.to(device),
            gt_aatype=samp_seq.to(device),
            seq_mask=seq_mask,
            return_last=False,
            return_aux=True,
            **stage_2_kwargs,
        )
        sampling_runtime += time.time() - start
        aux = {**aux, **stage2_aux}
    if pdb_save_path is not None:
        save_allatom_samples(aux, pdb_save_path + "_samp")
    aux["runtime"] = sampling_runtime

    # Process outputs, crop to correct length
    if return_aux:
        return aux
    else:
        xt_traj = aux["xt_traj"]
        st_traj = aux["st_traj"]
        seq_mask = aux["seq_mask"]
        seq_lens = seq_mask.sum(-1).long()
        cropped_samp_coords = [c[: seq_lens[i]] for i, c in enumerate(xt_traj[-1])]
        cropped_samp_aatypes = [s[: seq_lens[i]] for i, s in enumerate(st_traj[-1])]
        samp_atom_mask = utils.atom37_mask_from_aatype(st_traj[-1].to(device), seq_mask)
        samp_atom_mask = [m[: seq_lens[i]] for i, m in enumerate(samp_atom_mask)]
        # Stage-1 keys are only prefixed when stage 2 actually ran; fall back
        # to the final trajectory for single-stage sampling.
        orig_xt_traj = aux.get("stage1_xt_traj", xt_traj)
        stage1_coords = [c[: seq_lens[i]] for i, c in enumerate(orig_xt_traj[-1])]
        ret = (
            cropped_samp_coords,
            cropped_samp_aatypes,
            samp_atom_mask,
            stage1_coords,
            seq_mask,
        )
        if return_sampling_runtime:
            ret = ret + (sampling_runtime,)
        return ret
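

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assumes a
# trained Protpardelle model is already loaded as `model` with the interface
# used above (model.sample, model.make_seq_mask_for_sampling, model.device,
# model.bb_idxs); checkpoint loading is repo-specific and omitted here.
#
#   cfg = default_allatom_sampling_config()
#   coords, aatypes, atom_mask, stage1_coords, seq_mask = draw_allatom_samples(
#       model,
#       n_samples=4,
#       sample_length_range=(50, 128),
#       pdb_save_path="samples/design",
#       **vars(cfg),
#   )
#
# vars(cfg) forwards the Namespace fields (n_steps, s_churn, step_scale,
# sidechain_mode, ..., stage_2) as keyword arguments; draw_allatom_samples
# pops stage_2 itself and passes the remainder through to model.sample().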