| |
| import torch |
| import pickle |
| import numpy as np |
| import os |
| import logging |
|
|
| from scipy.spatial.transform import Rotation as scipy_R |
|
|
| from rfdiffusion.util import rigid_from_3_points |
|
|
| from rfdiffusion.util_module import ComputeAllAtomCoords |
|
|
| from rfdiffusion import igso3 |
| import time |
|
|
# Print tensors in fixed-point rather than scientific notation so schedule
# values and coordinates are easier to eyeball when debugging.
torch.set_printoptions(sci_mode=False)
|
|
|
|
def get_beta_schedule(T, b0, bT, schedule_type, schedule_params=None, inference=False):
    """
    Given a noise schedule type, create the beta schedule.

    Args:
        T (int): number of diffusion timesteps (must be >= 15).
        b0 (float): beta at the first step, specified relative to a
            reference schedule of 200 steps (rescaled below).
        bT (float): beta at the final step, same convention as b0.
        schedule_type (str): only "linear" is implemented.
        schedule_params: unused; kept for interface compatibility. (Was a
            mutable ``{}`` default, replaced with ``None``.)
        inference (bool): if True, print a summary of the schedule.

    Returns:
        tuple of three (T,) tensors:
        (beta_schedule, alpha_schedule, alphabar_t_schedule).

    Raises:
        NotImplementedError: for schedule types other than "linear".
    """
    assert schedule_type in ["linear"]

    # Schedules are parameterized for T=200 steps; rescale the betas so the
    # total amount of injected noise stays comparable for other T.
    assert T >= 15, "With discrete time and T < 15, the schedule is badly approximated"
    b0 *= 200 / T
    bT *= 200 / T

    if schedule_type == "linear":
        schedule = torch.linspace(b0, bT, T)
    else:
        raise NotImplementedError(f"Schedule of type {schedule_type} not implemented.")

    # Standard DDPM quantities: alpha_t = 1 - beta_t, alphabar_t = prod(alpha).
    alpha_schedule = 1 - schedule
    alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0)

    if inference:
        print(
            f"With this beta schedule ({schedule_type} schedule, beta_0 = {round(b0, 3)}, beta_T = {round(bT,3)}), alpha_bar_T = {alphabar_t_schedule[-1]}"
        )

    return schedule, alpha_schedule, alphabar_t_schedule
|
|
|
|
class EuclideanDiffuser:
    """DDPM-style diffuser for backbone translations.

    Noise is sampled for the CA atom of each residue and the same
    displacement is applied rigidly to N, CA and C.
    """

    def __init__(
        self,
        T,
        b_0,
        b_T,
        schedule_type="linear",
        schedule_kwargs=None,
    ):
        """
        Args:
            T (int): number of diffusion timesteps.
            b_0 (float): starting variance of the beta schedule.
            b_T (float): ending variance of the beta schedule.
            schedule_type (str): beta schedule type (only "linear").
            schedule_kwargs (dict, optional): extra keyword arguments
                forwarded to get_beta_schedule. (Was a mutable ``{}``
                default, replaced with ``None``.)
        """
        self.T = T

        (
            self.beta_schedule,
            self.alpha_schedule,
            self.alphabar_schedule,
        ) = get_beta_schedule(
            T, b_0, b_T, schedule_type, **(schedule_kwargs or {})
        )

    def diffuse_translations(self, xyz, diffusion_mask=None, var_scale=1):
        """Noise the translations of xyz over all T steps.

        Thin wrapper around apply_kernel_recursive; see that method for the
        return shapes.
        """
        return self.apply_kernel_recursive(xyz, diffusion_mask, var_scale)

    def apply_kernel(self, x, t, diffusion_mask=None, var_scale=1):
        """
        Applies a noising kernel to the points in x

        Parameters:
            x (torch.tensor, required): (L,3,3) set of backbone coordinates

            t (int, required): Which timestep (1-indexed)

            diffusion_mask (torch.tensor, optional): (L,) bools; True
                residues are held fixed (NOT noised)

            var_scale (float, required): scale for noise

        Returns:
            (out_crds, delta): noised (L,3,3) coordinates and the (L,3)
            CA displacement applied at this step.
        """
        t_idx = t - 1  # schedule tensors are 0-indexed

        assert len(x.shape) == 3
        L, _, _ = x.shape

        # CA position of each residue drives the translation noise.
        ca_xyz = x[:, 1, :]

        b_t = self.beta_schedule[t_idx]

        # Forward kernel: q(x_t | x_{t-1}) = N(sqrt(1 - b_t) x_{t-1}, b_t I)
        mean = torch.sqrt(1 - b_t) * ca_xyz
        var = torch.ones(L, 3) * (b_t) * var_scale

        sampled_crds = torch.normal(mean, torch.sqrt(var))
        delta = sampled_crds - ca_xyz

        # Residues under the mask keep a zero displacement.
        if diffusion_mask is not None:
            delta[diffusion_mask, ...] = 0

        # Apply the same CA displacement rigidly to all three backbone atoms.
        out_crds = x + delta[:, None, :]

        return out_crds, delta

    def apply_kernel_recursive(self, xyz, diffusion_mask=None, var_scale=1):
        """
        Repeatedly apply self.apply_kernel T times and return all crds

        Returns:
            (crds, deltas): tensors of shape (L, T, 3, 3) and (L, T, 3)
            holding the coordinates / per-step CA displacements for every
            timestep.
        """
        bb_stack = []
        T_stack = []

        cur_xyz = torch.clone(xyz)

        for t in range(1, self.T + 1):
            cur_xyz, cur_T = self.apply_kernel(
                cur_xyz, t, var_scale=var_scale, diffusion_mask=diffusion_mask
            )
            bb_stack.append(cur_xyz)
            T_stack.append(cur_T)

        # Stack over time then transpose to put the residue axis first.
        return torch.stack(bb_stack).transpose(0, 1), torch.stack(T_stack).transpose(
            0, 1
        )
|
|
|
|
def write_pkl(save_path: str, pkl_data):
    """Serialize an object to disk as a pickle file.

    Args:
        save_path: destination path for the pickle file.
        pkl_data: any picklable object.
    """
    with open(save_path, "wb") as out_file:
        pickle.dump(pkl_data, out_file, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
def read_pkl(read_path: str, verbose=False):
    """Read data from a pickle file.

    Args:
        read_path: path of the pickle file to load.
        verbose: if True, print the failing path before re-raising.

    Returns:
        The unpickled object.

    Raises:
        Whatever exception ``pickle.load`` (or ``open``'s caller context)
        raised, unchanged.
    """
    with open(read_path, "rb") as handle:
        try:
            return pickle.load(handle)
        except Exception:
            if verbose:
                print(f"Failed to read {read_path}")
            # Bare `raise` re-raises the active exception with its original
            # traceback intact (the old `raise (e)` reset the raise site).
            raise
|
|
|
|
class IGSO3:
    """
    Class for taking in a set of backbone crds and performing IGSO3 diffusion
    on all of them.

    Unlike the diffusion on translations, much of this class is written for a
    scaling between an initial time t=0 and final time t=1.
    """

    def __init__(
        self,
        *,
        T,
        min_sigma,
        max_sigma,
        min_b,
        max_b,
        cache_dir,
        num_omega=1000,
        schedule="linear",
        L=2000,
    ):
        """
        Args:
            T: total number of time steps
            min_sigma: smallest allowed scale parameter, should be at least
                0.01 to maintain numerical stability. Recommended value is 0.05.
            max_sigma: for exponential schedule, the largest scale parameter.
                Ignored for recommended linear schedule.
            min_b: lower value of beta in Ho schedule analogue
            max_b: upper value of beta in Ho schedule analogue
            cache_dir: directory in which precomputed IGSO3 values are cached
            num_omega: discretization level in the angles across [0, pi]
            schedule: currently only linear and exponential are supported. The
                exponential schedule may be noising too slowly.
            L: truncation level for the power series expansion of the pdf
        """
        self._log = logging.getLogger(__name__)

        self.T = T

        self.schedule = schedule
        self.cache_dir = cache_dir
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

        # BUGFIX: num_omega and num_sigma are required by _calc_igso3_vals for
        # *every* schedule. Previously they were assigned only inside the
        # "linear" branch, so the (documented) exponential schedule crashed
        # with an AttributeError before computing anything.
        self.num_omega = num_omega
        self.num_sigma = 500

        if self.schedule == "linear":
            self.min_b = min_b
            self.max_b = max_b
            # For the linear schedule, the largest sigma is determined by the
            # schedule itself at t=1 rather than by the max_sigma argument.
            self.max_sigma = self.sigma(1.0)

        self.L = L  # truncation level
        self.igso3_vals = self._calc_igso3_vals(L=L)
        self.step_size = 1 / self.T

    def _calc_igso3_vals(self, L=2000):
        """_calc_igso3_vals computes numerical approximations to the
        relevant analytically intractable functionals of the igso3
        distribution.

        The calculated values are cached, or loaded from cache if they already
        exist.

        Args:
            L: truncation level for power series expansion of the pdf.

        Returns:
            dict of numerical IGSO3 tables (as produced by
            igso3.calculate_igso3), keyed by e.g. "cdf", "score_norm",
            "discrete_sigma", "discrete_omega".
        """
        replace_period = lambda x: str(x).replace(".", "_")
        if self.schedule == "linear":
            cache_fname = os.path.join(
                self.cache_dir,
                f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}"
                + f"_min_b_{replace_period(self.min_b)}_max_b_{replace_period(self.max_b)}_schedule_{self.schedule}.pkl",
            )
        elif self.schedule == "exponential":
            # NOTE(review): unlike the linear branch, this cache name has no
            # ".pkl" suffix; left unchanged to avoid invalidating old caches.
            cache_fname = os.path.join(
                self.cache_dir,
                f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}"
                f"_max_sigma_{replace_period(self.max_sigma)}_schedule_{self.schedule}",
            )
        else:
            raise ValueError(f"Unrecognize schedule {self.schedule}")

        if not os.path.isdir(self.cache_dir):
            os.makedirs(self.cache_dir)

        if os.path.exists(cache_fname):
            self._log.info("Using cached IGSO3.")
            igso3_vals = read_pkl(cache_fname)
        else:
            self._log.info("Calculating IGSO3.")
            igso3_vals = igso3.calculate_igso3(
                num_sigma=self.num_sigma,
                min_sigma=self.min_sigma,
                max_sigma=self.max_sigma,
                num_omega=self.num_omega
            )
            write_pkl(cache_fname, igso3_vals)

        return igso3_vals

    @property
    def discrete_sigma(self):
        # Discretized sigma grid underlying the cached IGSO3 tables.
        return self.igso3_vals["discrete_sigma"]

    def sigma_idx(self, sigma: np.ndarray):
        """
        Calculates the index for discretized sigma during IGSO(3) initialization."""
        return np.digitize(sigma, self.discrete_sigma) - 1

    def t_to_idx(self, t: np.ndarray):
        """
        Helper function to go from discrete time index t to corresponding sigma_idx.

        Args:
            t: time index (integer between 1 and T)
        """
        continuous_t = t / self.T
        return self.sigma_idx(self.sigma(continuous_t))

    def sigma(self, t: torch.tensor):
        """
        Extract \sigma(t) corresponding to chosen sigma schedule.

        Args:
            t: torch tensor with time between 0 and 1

        Raises:
            ValueError: if t is outside [0, 1] or the schedule is unknown.
        """
        if not type(t) == torch.Tensor:
            t = torch.tensor(t)
        if torch.any(t < 0) or torch.any(t > 1):
            raise ValueError(f"Invalid t={t}")
        if self.schedule == "exponential":
            # Geometric interpolation between min_sigma and max_sigma.
            sigma = t * np.log10(self.max_sigma) + (1 - t) * np.log10(self.min_sigma)
            return 10**sigma
        elif self.schedule == "linear":
            # sigma(t)^2 grows as the integral of the linear beta schedule.
            return (
                self.min_sigma
                + t * self.min_b
                + (1 / 2) * (t**2) * (self.max_b - self.min_b)
            )
        else:
            raise ValueError(f"Unrecognize schedule {self.schedule}")

    def g(self, t):
        """
        g returns the drift coefficient at time t

        since
            sigma(t)^2 := \int_0^t g(s)^2 ds,
        for arbitrary sigma(t) we invert this relationship to compute
            g(t) = sqrt(d/dt sigma(t)^2).

        Args:
            t: scalar time between 0 and 1

        Returns:
            drift cooeficient as a scalar.
        """
        # Differentiate sigma(t)^2 with autograd so this works for any
        # differentiable sigma schedule.
        t = torch.tensor(t, requires_grad=True)
        sigma_sqr = self.sigma(t) ** 2
        grads = torch.autograd.grad(sigma_sqr.sum(), t)[0]
        return torch.sqrt(grads)

    def sample(self, ts, n_samples=1):
        """
        sample uses the inverse cdf to sample an angle of rotation from
        IGSO(3)

        Args:
            ts: array of integer time steps to sample from.
            n_samples: number of samples to draw.
        Returns:
            sampled angles of rotation. [len(ts), N]
        """
        assert sum(ts == 0) == 0, "assumes one-indexed, not zero indexed"
        all_samples = []
        for t in ts:
            sigma_idx = self.t_to_idx(t)
            # Inverse-CDF sampling against the cached CDF table for this sigma.
            sample_i = np.interp(
                np.random.rand(n_samples),
                self.igso3_vals["cdf"][sigma_idx],
                self.igso3_vals["discrete_omega"],
            )
            all_samples.append(sample_i)
        return np.stack(all_samples, axis=0)

    def sample_vec(self, ts, n_samples=1):
        """sample_vec generates a rotation vector(s) from IGSO(3) at time steps
        ts.

        Return:
            Sampled vector of shape [len(ts), N, 3]
        """
        # Uniform random axis times an IGSO3-distributed angle.
        x = np.random.randn(len(ts), n_samples, 3)
        x /= np.linalg.norm(x, axis=-1, keepdims=True)
        return x * self.sample(ts, n_samples=n_samples)[..., None]

    def score_norm(self, t, omega):
        """
        score_norm computes the score norm based on the time step and angle
        Args:
            t: integer time step
            omega: angles (scalar or shape [N])
        Return:
            score_norm with same shape as omega
        """
        sigma_idx = self.t_to_idx(t)
        score_norm_t = np.interp(
            omega,
            self.igso3_vals["discrete_omega"],
            self.igso3_vals["score_norm"][sigma_idx],
        )
        return score_norm_t

    def score_vec(self, ts, vec):
        """score_vec computes the score of the IGSO(3) density as a rotation
        vector. This score vector is in the direction of the sampled vector,
        and has magnitude given by score_norms.

        In particular, Rt @ hat(score_vec(ts, vec)) is what is referred to as
        the score approximation in Algorithm 1

        Args:
            ts: times of shape [T]
            vec: where to compute the score of shape [T, N, 3]
        Returns:
            score vectors of shape [T, N, 3]
        """
        omega = np.linalg.norm(vec, axis=-1)
        all_score_norm = []
        for i, t in enumerate(ts):
            omega_t = omega[i]
            sigma_idx = self.t_to_idx(t)
            score_norm_t = np.interp(
                omega_t,
                self.igso3_vals["discrete_omega"],
                self.igso3_vals["score_norm"][sigma_idx],
            )[:, None]
            all_score_norm.append(score_norm_t)
        score_norm = np.stack(all_score_norm, axis=0)
        # Scale the unit direction of each rotation vector by its score norm.
        return score_norm * vec / omega[..., None]

    def exp_score_norm(self, ts):
        """exp_score_norm returns the expected value of norm of the score for
        IGSO(3) with time parameter ts of shape [T].
        """
        sigma_idcs = [self.t_to_idx(t) for t in ts]
        return self.igso3_vals["exp_score_norms"][sigma_idcs]

    def diffuse_frames(self, xyz, t_list, diffusion_mask=None):
        """diffuse_frames samples from the IGSO(3) distribution to noise frames

        Parameters:
            xyz (np.array or torch.tensor, required): (L,3,3) set of backbone coordinates
            t_list (list or None): if given, only return timesteps in the list
            diffusion_mask (np.array, optional): (L,) set of bools. True/1 is
                NOT diffused, False/0 IS diffused
        Returns:
            np.array : N/CA/C coordinates for each residue
                       (L,T,3,3) and the perturbed rotations (L,T,3,3),
                       where T is num timesteps (or len(t_list))
        """
        if torch.is_tensor(xyz):
            xyz = xyz.numpy()

        t = np.arange(self.T) + 1  # 1-indexed timesteps
        num_res = len(xyz)

        N = torch.from_numpy(xyz[None, :, 0, :])
        Ca = torch.from_numpy(xyz[None, :, 1, :])
        C = torch.from_numpy(xyz[None, :, 2, :])

        # Build true backbone frames from the three backbone atoms.
        R_true, Ca = rigid_from_3_points(N, Ca, C)
        R_true = R_true[0]
        Ca = Ca[0]

        # Sample rotation vectors for every (timestep, residue) pair.
        sampled_rots = self.sample_vec(t, n_samples=num_res)

        # Zero out rotations on motif residues (mask == True/1).
        if diffusion_mask is not None:
            non_diffusion_mask = 1 - diffusion_mask[None, :, None]
            sampled_rots = sampled_rots * non_diffusion_mask

        # Apply the sampled rotations about each residue's CA.
        R_sampled = (
            scipy_R.from_rotvec(sampled_rots.reshape(-1, 3))
            .as_matrix()
            .reshape(self.T, num_res, 3, 3)
        )
        R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true)
        perturbed_crds = (
            np.einsum(
                "tnij,naj->tnai", R_sampled, xyz[:, :3, :] - Ca[:, None, ...].numpy()
            )
            + Ca[None, :, None].numpy()
        )

        if t_list is not None:
            # Keep only the requested (1-indexed) timesteps.
            idx = [i - 1 for i in t_list]
            perturbed_crds = perturbed_crds[idx]
            R_perturbed = R_perturbed[idx]

        return (
            perturbed_crds.transpose(1, 0, 2, 3),
            R_perturbed.transpose(1, 0, 2, 3),
        )

    def reverse_sample_vectorized(
        self, R_t, R_0, t, noise_level, mask=None, return_perturb=False
    ):
        """reverse_sample uses an approximation to the IGSO3 score to sample
        a rotation at the previous time step.

        Roughly - this update follows the reverse time SDE for Reimannian
        manifolds proposed by de Bortoli et al. Theorem 1 [1]. But with an
        approximation to the score based on the prediction of R0.
        Unlike in reference [1], this diffusion on SO(3) relies on geometric
        variance schedule.  Specifically we follow [2] (appendix C) and assume
            sigma_t = sigma_min * (sigma_max / sigma_min)^{t/T},
        for time step t.  When we view this as a discretization  of the SDE
        from time 0 to 1 with step size (1/T).  Following Eq. 5 and Eq. 6,
        this maps on to the forward  time SDEs
            dx = g(t) dBt [FORWARD]
        and
            dx = g(t)^2 score(xt, t)dt + g(t) B't, [REVERSE]
        where g(t) = sigma_t * sqrt(2 * log(sigma_max/ sigma_min)), and Bt and
        B't are Brownian motions. The formula for g(t) obtains from equation 9
        of [2], from which this sampling function may be generalized to
        alternative noising schedules.
        Args:
            R_t: noisy rotation of shape [N, 3, 3]
            R_0: prediction of un-noised rotation
            t: integer time step
            noise_level: scaling on the noise added when obtaining sample
                (preliminary performance seems empirically better with noise
                level=0.5)
            mask: whether the residue is to be updated.  A value of 1 means the
                rotation is not updated from r_t.  A value of 0 means the
                rotation is updated.
        Return:
            sampled rotation matrix for time t-1 of shape [3, 3]
        Reference:
        [1] De Bortoli, V., Mathieu, E., Hutchinson, M., Thornton, J., Teh, Y.
        W., & Doucet, A. (2022). Riemannian score-based generative modeling.
        arXiv preprint arXiv:2202.02763.
        [2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S.,
        & Poole, B. (2020). Score-based generative modeling through stochastic
        differential equations. arXiv preprint arXiv:2011.13456.
        """
        # Relative rotation R_t R_0^T and its rotation-vector form.
        R_0, R_t = torch.tensor(R_0), torch.tensor(R_t)
        R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0)
        R_0t_rotvec = torch.tensor(
            scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec()
        ).to(R_0.device)

        # Approximate the score from the rotation angle back toward R_0.
        Omega = torch.linalg.norm(R_0t_rotvec, axis=-1).numpy()
        Score_approx = R_0t_rotvec * (self.score_norm(t, Omega) / Omega)[:, None]

        # Compute scaling for score and sampled noise (following Eq 6 of [2])
        continuous_t = t / self.T
        rot_g = self.g(continuous_t).to(Score_approx.device)

        # Sample and scale noise to add to the rotation perturbation in the
        # SO(3) tangent space.
        Z = np.random.normal(size=(R_0.shape[0], 3))
        Z = torch.from_numpy(Z).to(Score_approx.device)
        Z *= noise_level

        Delta_r = (rot_g**2) * self.step_size * Score_approx

        # Euler-Maruyama step of the reverse SDE in the tangent space.
        Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z
        if mask is not None:
            # BUGFIX: Perturb_tangent is (N, 3), so the mask must broadcast as
            # (N, 1). The previous [:, None, None] indexing produced an
            # (N, 1, 1) tensor whose product with (N, 3) broadcasts to
            # (N, N, 3), corrupting shapes downstream.
            Perturb_tangent *= (1 - mask.long())[:, None]
        Perturb = igso3.Exp(Perturb_tangent)

        if return_perturb:
            return Perturb

        Interp_rot = torch.einsum("...ij,...jk->...ik", Perturb, R_t)

        return Interp_rot
|
|
|
|
class Diffuser:
    """Top-level SE(3) diffuser.

    Combines Euclidean (DDPM) diffusion of CA translations with IGSO(3)
    diffusion of backbone frame rotations.
    """

    def __init__(
        self,
        T,
        b_0,
        b_T,
        min_sigma,
        max_sigma,
        min_b,
        max_b,
        schedule_type,
        so3_schedule_type,
        so3_type,
        crd_scale,
        schedule_kwargs=None,
        var_scale=1.0,
        cache_dir=".",
        partial_T=None,
        truncation_level=2000,
    ):
        """
        Parameters:

            T (int, required): Number of steps in the schedule

            b_0 (float, required): Starting variance for Euclidean schedule

            b_T (float, required): Ending variance for Euclidean schedule

            min_sigma / max_sigma / min_b / max_b: IGSO3 schedule parameters

            schedule_type (str): beta schedule type for the Euclidean diffuser

            so3_schedule_type (str): schedule type for the SO(3) diffuser

            so3_type: not used here; accepted for config compatibility

            crd_scale (float): scaling applied to coordinates before noising

            schedule_kwargs (dict, optional): extra kwargs forwarded to
                get_beta_schedule. (Was a mutable ``{}`` default, replaced
                with ``None``.)

            var_scale (float): noise variance scale for translations

            cache_dir (str): cache directory for precomputed IGSO3 values

            partial_T: not used here; accepted for config compatibility

            truncation_level (int): power-series truncation level for IGSO3
        """
        self.T = T
        self.b_0 = b_0
        self.b_T = b_T
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma
        self.crd_scale = crd_scale
        self.var_scale = var_scale
        self.cache_dir = cache_dir

        # Rotation (SO(3)) diffuser.
        self.so3_diffuser = IGSO3(
            T=self.T,
            min_sigma=self.min_sigma,
            max_sigma=self.max_sigma,
            schedule=so3_schedule_type,
            min_b=min_b,
            max_b=max_b,
            cache_dir=self.cache_dir,
            L=truncation_level,
        )

        # Translation diffuser.
        self.eucl_diffuser = EuclideanDiffuser(
            self.T, b_0, b_T, schedule_type=schedule_type, **(schedule_kwargs or {})
        )

        print("Successful diffuser __init__")

    def diffuse_pose(
        self,
        xyz,
        seq,
        atom_mask,
        include_motif_sidechains=True,
        diffusion_mask=None,
        t_list=None,
    ):
        """
        Given full atom xyz, sequence and atom mask, diffuse the protein frame
        translations and rotations

        Parameters:

            xyz (L,14/27,3) set of coordinates

            seq (L,) integer sequence (currently unused here)

            atom_mask: mask describing presence/absence of an atom in pdb
                (currently unused here)

            diffusion_mask (torch.tensor, optional): Tensor of bools, True
                means NOT diffused at this residue, False means diffused

            t_list (list, optional): If present, only return the diffused
                coordinates at timesteps t within the list

        Returns:
            (fa_stack, xyz_true): diffused full-atom coordinates of shape
            (T or len(t_list), L, 27, 3) and the centered true coordinates.

        Raises:
            ValueError: if any backbone atom coordinate is NaN.
        """
        if diffusion_mask is None:
            diffusion_mask = torch.zeros(len(xyz.squeeze())).to(dtype=bool)

        # Refuse to diffuse structures with missing backbone atoms. An
        # explicit raise (rather than assert) survives `python -O`.
        if torch.isnan(xyz.squeeze()[:, :3]).any():
            raise ValueError("Cannot diffuse structures with NaN backbone coordinates")

        # Center the structure: on the motif CA center of mass when a motif is
        # present, otherwise on the CA center of mass of the whole chain.
        if torch.sum(diffusion_mask) != 0:
            self.motif_com = xyz[diffusion_mask, 1, :].mean(dim=0)
            xyz = xyz - self.motif_com
        else:
            xyz = xyz - xyz[:, 1, :].mean(dim=0)

        xyz_true = torch.clone(xyz)
        xyz = xyz * self.crd_scale

        # 1. Diffuse CA translations, then undo the coordinate scaling.
        diffused_T, deltas = self.eucl_diffuser.diffuse_translations(
            xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask
        )
        diffused_T /= self.crd_scale
        deltas /= self.crd_scale

        # 2. Diffuse backbone frames (rotations about each residue's CA).
        diffused_frame_crds, diffused_frames = self.so3_diffuser.diffuse_frames(
            xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask.numpy(), t_list=None
        )
        diffused_frame_crds /= self.crd_scale

        # 3. Combine: rotated frames plus cumulative translation offsets give
        # the full diffused backbone trajectory, shape (T, L, 3, 3).
        cum_delta = deltas.cumsum(dim=1)
        diffused_BB = (
            torch.from_numpy(diffused_frame_crds) + cum_delta[:, :, None, :]
        ).transpose(
            0, 1
        )

        t_steps, L = diffused_BB.shape[:2]

        # Embed the diffused backbone in a 27-atom full-atom representation.
        diffused_fa = torch.zeros(t_steps, L, 27, 3)
        diffused_fa[:, :, :3, :] = diffused_BB

        # Keep motif side chains fixed at their true (centered) coordinates.
        if include_motif_sidechains:
            diffused_fa[:, diffusion_mask, :14, :] = xyz_true[None, diffusion_mask, :14]

        if t_list is None:
            fa_stack = diffused_fa
        else:
            # t_list is 1-indexed; convert to 0-indexed rows.
            t_idx_list = [t - 1 for t in t_list]
            fa_stack = diffused_fa[t_idx_list]

        return fa_stack, xyz_true
|
|