| | |
| | import torch |
| | import pickle |
| | import numpy as np |
| | import os |
| | import logging |
| |
|
| | from scipy.spatial.transform import Rotation as scipy_R |
| |
|
| | from rfdiffusion.util import rigid_from_3_points |
| |
|
| | from rfdiffusion.util_module import ComputeAllAtomCoords |
| |
|
| | from rfdiffusion import igso3 |
| | import time |
| |
|
# Print tensors in fixed-point rather than scientific notation. This is a
# process-wide torch print option; it only affects how tensors are displayed.
torch.set_printoptions(sci_mode=False)
| |
|
| |
|
def get_beta_schedule(T, b0, bT, schedule_type, schedule_params=None, inference=False):
    """Create the beta (noise variance) schedule for Euclidean diffusion.

    ``b0`` and ``bT`` are given relative to a reference schedule of 200 steps
    and are rescaled by ``200 / T``. For the linear schedule this keeps the
    sum of all betas equal to ``100 * (b0 + bT)`` regardless of ``T``.

    Args:
        T (int): number of discrete time steps; must be >= 15.
        b0 (float): beta at the first step (for the T=200 reference).
        bT (float): beta at the last step (for the T=200 reference).
        schedule_type (str): only ``"linear"`` is supported.
        schedule_params (dict, optional): unused; kept for API compatibility.
        inference (bool): if True, print a one-line summary of the schedule.

    Returns:
        Tuple of three 1-D tensors of length ``T``:
        ``(beta_schedule, alpha_schedule, alphabar_schedule)`` where
        ``alpha = 1 - beta`` and ``alphabar = cumprod(alpha)``.

    Raises:
        NotImplementedError: if ``schedule_type`` is not ``"linear"``.
        ValueError: if ``T < 15``.
    """
    # Explicit raises instead of asserts: asserts disappear under `python -O`,
    # and the former `assert schedule_type in ["linear"]` made the
    # NotImplementedError branch unreachable.
    if schedule_type != "linear":
        raise NotImplementedError(f"Schedule of type {schedule_type} not implemented.")
    if T < 15:
        raise ValueError("With discrete time and T < 15, the schedule is badly approximated")

    # Rescale the T=200 reference betas to the requested number of steps.
    b0 *= 200 / T
    bT *= 200 / T

    schedule = torch.linspace(b0, bT, T)

    alpha_schedule = 1 - schedule
    alphabar_t_schedule = torch.cumprod(alpha_schedule, dim=0)

    if inference:
        print(
            f"With this beta schedule ({schedule_type} schedule, beta_0 = {round(b0, 3)}, beta_T = {round(bT,3)}), alpha_bar_T = {alphabar_t_schedule[-1]}"
        )

    return schedule, alpha_schedule, alphabar_t_schedule
| |
|
| |
|
class EuclideanDiffuser:
    """Gaussian noising of residue translations.

    Each step moves a residue's CA (atom index 1) according to a kernel whose
    variance follows a beta schedule from ``get_beta_schedule``. All atoms of
    a residue are shifted by that residue's CA displacement.
    """

    def __init__(
        self,
        T,
        b_0,
        b_T,
        schedule_type="linear",
        schedule_kwargs=None,
    ):
        """
        Args:
            T (int): number of diffusion steps.
            b_0 (float): starting beta (see ``get_beta_schedule``).
            b_T (float): final beta.
            schedule_type (str): beta schedule type.
            schedule_kwargs (dict, optional): extra keyword arguments forwarded
                to ``get_beta_schedule`` (e.g. ``inference``).
        """
        self.T = T

        # None instead of a shared mutable {} default; an empty dict forwards nothing.
        schedule_kwargs = {} if schedule_kwargs is None else schedule_kwargs
        (
            self.beta_schedule,
            self.alpha_schedule,
            self.alphabar_schedule,
        ) = get_beta_schedule(T, b_0, b_T, schedule_type, **schedule_kwargs)

    def diffuse_translations(self, xyz, diffusion_mask=None, var_scale=1):
        """Noise ``xyz`` at every step 1..T (see ``apply_kernel_recursive``)."""
        return self.apply_kernel_recursive(xyz, diffusion_mask, var_scale)

    def apply_kernel(self, x, t, diffusion_mask=None, var_scale=1):
        """
        Applies one noising kernel step to the points in x

        Parameters:
            x (torch.tensor, required): (N, A, 3) per-residue atom coordinates,
                e.g. (N,3,3) backbone; atom index 1 is treated as the CA

            t (int, required): Which timestep (1-indexed)

            diffusion_mask (torch.tensor, optional): (N,) bools; residues that
                are True get zero displacement

            var_scale (float, optional): multiplier on the kernel variance

        Returns:
            out_crds: x with each residue rigidly shifted by its CA displacement
            delta: (N, 3) CA displacement sampled at this step
        """
        t_idx = t - 1  # schedules are 0-indexed, timesteps are 1-indexed

        assert len(x.shape) == 3
        ca_xyz = x[:, 1, :]

        b_t = self.beta_schedule[t_idx]

        # Kernel N(sqrt(1 - b_t) * x, b_t * var_scale * I) on the CA positions.
        mean = torch.sqrt(1 - b_t) * ca_xyz
        # Build the variance with ca_xyz's device/dtype. A bare torch.ones(L, 3)
        # lives on the CPU in the default dtype and breaks torch.normal for
        # GPU or float64 inputs.
        var = torch.ones_like(ca_xyz) * b_t * var_scale

        sampled_crds = torch.normal(mean, torch.sqrt(var))
        delta = sampled_crds - ca_xyz

        if diffusion_mask is not None:
            delta[diffusion_mask, ...] = 0

        # Translate every atom of a residue by that residue's CA displacement.
        out_crds = x + delta[:, None, :]

        return out_crds, delta

    def apply_kernel_recursive(self, xyz, diffusion_mask=None, var_scale=1):
        """
        Repeatedly apply self.apply_kernel for t = 1..T, feeding each output
        into the next step, and return all intermediate results.

        Returns:
            (N, T, A, 3) coordinates after each step and
            (N, T, 3) per-step CA displacements, for an (N, A, 3) input.
        """
        bb_stack = []
        T_stack = []

        cur_xyz = torch.clone(xyz)

        for t in range(1, self.T + 1):
            cur_xyz, cur_T = self.apply_kernel(
                cur_xyz, t, var_scale=var_scale, diffusion_mask=diffusion_mask
            )
            bb_stack.append(cur_xyz)
            T_stack.append(cur_T)

        # Stacks are (T, N, ...); put residues first.
        return torch.stack(bb_stack).transpose(0, 1), torch.stack(T_stack).transpose(
            0, 1
        )
| |
|
| |
|
def write_pkl(save_path: str, pkl_data):
    """Pickle ``pkl_data`` into ``save_path`` (overwriting any existing file)
    using the newest protocol this Python supports."""
    newest_protocol = pickle.HIGHEST_PROTOCOL
    with open(save_path, "wb") as sink:
        pickle.dump(pkl_data, sink, newest_protocol)
| |
|
| |
|
def read_pkl(read_path: str, verbose=False):
    """Read data from a pickle file.

    Args:
        read_path: path of the pickle file to load.
        verbose: if True, print the path before re-raising when unpickling fails.

    Returns:
        The unpickled object.

    Raises:
        Whatever ``open`` or ``pickle.load`` raises, propagated with its
        original traceback.

    Note:
        ``pickle.load`` can execute arbitrary code. Only use this on trusted
        files, such as caches this module wrote itself.
    """
    with open(read_path, "rb") as handle:
        try:
            return pickle.load(handle)
        except Exception:
            if verbose:
                print(f"Failed to read {read_path}")
            # Bare `raise` re-raises the active exception untouched, instead of
            # `raise (e)`, which adds this line as a new raise site.
            raise
| |
|
| |
|
class IGSO3:
    """
    Class for taking in a set of backbone crds and performing IGSO3 diffusion
    on all of them.

    Unlike the diffusion on translations, much of this class is written for a
    scaling between an initial time t=0 and final time t=1.
    """

    def __init__(
        self,
        *,
        T,
        min_sigma,
        max_sigma,
        min_b,
        max_b,
        cache_dir,
        num_omega=1000,
        schedule="linear",
        L=2000,
    ):
        """

        Args:
            T: total number of time steps
            min_sigma: smallest allowed scale parameter, should be at least 0.01 to maintain numerical stability. Recommended value is 0.05.
            max_sigma: for exponential schedule, the largest scale parameter. Ignored for the recommended linear schedule, where it is recomputed as sigma(1).
            min_b: lower value of beta in Ho schedule analogue (linear schedule only)
            max_b: upper value of beta in Ho schedule analogue (linear schedule only)
            cache_dir: directory where the precomputed IGSO3 tables are cached
            num_omega: discretization level in the angles across [0, pi]
            schedule: currently only linear and exponential are supported. The exponential schedule may be noising too slowly.
            L: truncation level (stored; not forwarded to igso3.calculate_igso3 in this class)
        """
        self._log = logging.getLogger(__name__)

        self.T = T

        self.schedule = schedule
        self.cache_dir = cache_dir
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma

        if self.schedule == "linear":
            self.min_b = min_b
            self.max_b = max_b
            # For the linear schedule the largest sigma is implied by min_sigma/min_b/max_b.
            self.max_sigma = self.sigma(1.0)
        self.num_omega = num_omega
        self.num_sigma = 500
        self.L = L
        self.igso3_vals = self._calc_igso3_vals(L=L)
        self.step_size = 1 / self.T

    def _calc_igso3_vals(self, L=2000):
        """_calc_igso3_vals computes numerical approximations to the
        relevant analytically intractable functionals of the igso3
        distribution.

        The calculated values are cached, or loaded from cache if they already
        exist.

        Args:
            L: truncation level for power series expansion of the pdf
                (currently not used by this method).

        Raises:
            ValueError: if the schedule is neither "linear" nor "exponential".
        """

        def replace_period(x):
            # Make numbers filename-friendly: 0.05 -> "0_05".
            return str(x).replace(".", "_")

        if self.schedule == "linear":
            cache_fname = os.path.join(
                self.cache_dir,
                f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}"
                + f"_min_b_{replace_period(self.min_b)}_max_b_{replace_period(self.max_b)}_schedule_{self.schedule}.pkl",
            )
        elif self.schedule == "exponential":
            # NOTE: no ".pkl" suffix here; kept as-is so existing cache files are still found.
            cache_fname = os.path.join(
                self.cache_dir,
                f"T_{self.T}_omega_{self.num_omega}_min_sigma_{replace_period(self.min_sigma)}"
                f"_max_sigma_{replace_period(self.max_sigma)}_schedule_{self.schedule}",
            )
        else:
            raise ValueError(f"Unrecognize schedule {self.schedule}")

        # exist_ok avoids a race when several processes create the directory at once.
        os.makedirs(self.cache_dir, exist_ok=True)

        if os.path.exists(cache_fname):
            self._log.info("Using cached IGSO3.")
            igso3_vals = read_pkl(cache_fname)
        else:
            self._log.info("Calculating IGSO3.")
            igso3_vals = igso3.calculate_igso3(
                num_sigma=self.num_sigma,
                min_sigma=self.min_sigma,
                max_sigma=self.max_sigma,
                num_omega=self.num_omega,
            )
            write_pkl(cache_fname, igso3_vals)

        return igso3_vals

    @property
    def discrete_sigma(self):
        """Grid of sigma values at which the IGSO3 tables were computed."""
        return self.igso3_vals["discrete_sigma"]

    def sigma_idx(self, sigma: np.ndarray):
        """
        Calculates the index for discretized sigma during IGSO(3) initialization,
        i.e. the index of the last grid point of ``discrete_sigma`` that is <= sigma."""
        return np.digitize(sigma, self.discrete_sigma) - 1

    def t_to_idx(self, t: np.ndarray):
        """
        Helper function to go from discrete time index t to corresponding sigma_idx.

        Args:
            t: time index (integer between 1 and self.T)
        """
        continuous_t = t / self.T
        return self.sigma_idx(self.sigma(continuous_t))

    def sigma(self, t: torch.tensor):
        r"""
        Extract \sigma(t) corresponding to chosen sigma schedule.

        exponential: log10 sigma interpolates linearly from min_sigma to max_sigma.
        linear:      sigma(t) = min_sigma + t * min_b + t^2 / 2 * (max_b - min_b).

        Args:
            t: torch tensor (or number) with time between 0 and 1

        Raises:
            ValueError: if any t is outside [0, 1] or the schedule is unknown.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        if torch.any(t < 0) or torch.any(t > 1):
            raise ValueError(f"Invalid t={t}")
        if self.schedule == "exponential":
            sigma = t * np.log10(self.max_sigma) + (1 - t) * np.log10(self.min_sigma)
            return 10**sigma
        elif self.schedule == "linear":
            return (
                self.min_sigma
                + t * self.min_b
                + (1 / 2) * (t**2) * (self.max_b - self.min_b)
            )
        else:
            raise ValueError(f"Unrecognize schedule {self.schedule}")

    def g(self, t):
        r"""
        g returns the drift coefficient at time t

        since
            sigma(t)^2 := \int_0^t g(s)^2 ds,
        for arbitrary sigma(t) we invert this relationship to compute
            g(t) = sqrt(d/dt sigma(t)^2)
        (the derivative is taken with autograd).

        Args:
            t: scalar time between 0 and 1 (int, float, or tensor)

        Returns:
            drift coefficient as a scalar tensor.
        """
        t = torch.as_tensor(t)
        if not t.is_floating_point():
            # Integer tensors cannot require gradients.
            t = t.to(torch.get_default_dtype())
        # Fresh leaf tensor so autograd.grad can differentiate w.r.t. it.
        t = t.detach().clone().requires_grad_(True)
        sigma_sqr = self.sigma(t) ** 2
        grads = torch.autograd.grad(sigma_sqr.sum(), t)[0]
        return torch.sqrt(grads)

    def sample(self, ts, n_samples=1):
        """
        sample uses the inverse cdf to sample an angle of rotation from
        IGSO(3)

        Args:
            ts: iterable (list or array) of one-indexed integer time steps.
            n_samples: number of samples to draw per time step.
        Returns:
            sampled angles of rotation. [len(ts), N]
        Raises:
            ValueError: if any time step is 0.
        """
        # Element-wise check so plain lists work too (`sum(ts == 0)` did not).
        if any(t == 0 for t in ts):
            raise ValueError("assumes one-indexed, not zero indexed")
        all_samples = []
        for t in ts:
            sigma_idx = self.t_to_idx(t)
            sample_i = np.interp(
                np.random.rand(n_samples),
                self.igso3_vals["cdf"][sigma_idx],
                self.igso3_vals["discrete_omega"],
            )
            all_samples.append(sample_i)
        return np.stack(all_samples, axis=0)

    def sample_vec(self, ts, n_samples=1):
        """sample_vec generates a rotation vector(s) from IGSO(3) at time steps
        ts: a uniformly random axis scaled by an angle drawn with ``sample``.

        Return:
            Sampled vector of shape [len(ts), N, 3]
        """
        x = np.random.randn(len(ts), n_samples, 3)
        x /= np.linalg.norm(x, axis=-1, keepdims=True)
        return x * self.sample(ts, n_samples=n_samples)[..., None]

    def score_norm(self, t, omega):
        """
        score_norm computes the score norm based on the time step and angle
        by interpolating the precomputed table.
        Args:
            t: integer time step
            omega: angles (scalar or shape [N])
        Return:
            score_norm with same shape as omega
        """
        sigma_idx = self.t_to_idx(t)
        score_norm_t = np.interp(
            omega,
            self.igso3_vals["discrete_omega"],
            self.igso3_vals["score_norm"][sigma_idx],
        )
        return score_norm_t

    def score_vec(self, ts, vec):
        """score_vec computes the score of the IGSO(3) density as a rotation
        vector. This score vector is in the direction of the sampled vector,
        and has magnitude given by score_norms.

        In particular, Rt @ hat(score_vec(ts, vec)) is what is referred to as
        the score approximation in Algorithm 1

        Args:
            ts: times of shape [T]
            vec: where to compute the score of shape [T, N, 3]
                (zero-length vectors yield NaN)
        Returns:
            score vectors of shape [T, N, 3]
        """
        omega = np.linalg.norm(vec, axis=-1)
        all_score_norm = []
        for i, t in enumerate(ts):
            omega_t = omega[i]
            sigma_idx = self.t_to_idx(t)
            score_norm_t = np.interp(
                omega_t,
                self.igso3_vals["discrete_omega"],
                self.igso3_vals["score_norm"][sigma_idx],
            )[:, None]
            all_score_norm.append(score_norm_t)
        score_norm = np.stack(all_score_norm, axis=0)
        return score_norm * vec / omega[..., None]

    def exp_score_norm(self, ts):
        """exp_score_norm returns the expected value of norm of the score for
        IGSO(3) with time parameter ts of shape [T].
        """
        sigma_idcs = [self.t_to_idx(t) for t in ts]
        return self.igso3_vals["exp_score_norms"][sigma_idcs]

    def diffuse_frames(self, xyz, t_list, diffusion_mask=None):
        """diffuse_frames samples from the IGSO(3) distribution to noise frames

        Each residue frame (from its N, CA, C atoms) is rotated about its own
        origin by an independent IGSO(3) sample for every time step 1..T. Each
        step is drawn from that step's distribution directly; steps are not
        composed.

        Parameters:
            xyz (np.array or torch.tensor, required): (L,3,3) set of backbone
                coordinates (N, CA, C); only the first three atoms are used
            t_list (list or None): if not None, only these one-indexed time
                steps are returned
            diffusion_mask (np.array, optional): (L,) set of bools. True/1 is
                NOT diffused, False/0 IS diffused
        Returns:
            perturbed_crds (np.array): (L, T', 3, 3) N/CA/C coordinates
            R_perturbed (np.array): (L, T', 3, 3) rotated frame matrices
            where T' is T, or len(t_list) when t_list is given.
        """
        if torch.is_tensor(xyz):
            # .cpu() so GPU tensors work; detach() in case xyz tracks gradients.
            xyz = xyz.detach().cpu().numpy()

        t = np.arange(self.T) + 1
        num_res = len(xyz)

        N = torch.from_numpy(xyz[None, :, 0, :])
        Ca = torch.from_numpy(xyz[None, :, 1, :])
        C = torch.from_numpy(xyz[None, :, 2, :])

        # True frames and their origins; drop the leading batch dimension.
        R_true, Ca = rigid_from_3_points(N, Ca, C)
        R_true = R_true[0]
        Ca = Ca[0]

        # One rotation vector per (time step, residue): shape (T, L, 3).
        sampled_rots = self.sample_vec(t, n_samples=num_res)

        if diffusion_mask is not None:
            # Zero rotation vector (identity) for residues that are not diffused.
            non_diffusion_mask = 1 - diffusion_mask[None, :, None]
            sampled_rots = sampled_rots * non_diffusion_mask

        R_sampled = (
            scipy_R.from_rotvec(sampled_rots.reshape(-1, 3))
            .as_matrix()
            .reshape(self.T, num_res, 3, 3)
        )
        R_perturbed = np.einsum("tnij,njk->tnik", R_sampled, R_true.numpy())
        Ca_np = Ca.numpy()
        # Rotate each residue's atoms about its own origin.
        perturbed_crds = (
            np.einsum("tnij,naj->tnai", R_sampled, xyz[:, :3, :] - Ca_np[:, None, ...])
            + Ca_np[None, :, None]
        )

        # `is not None`: `!= None` is element-wise (and ambiguous) for arrays.
        if t_list is not None:
            idx = [i - 1 for i in t_list]
            perturbed_crds = perturbed_crds[idx]
            R_perturbed = R_perturbed[idx]

        return (
            perturbed_crds.transpose(1, 0, 2, 3),
            R_perturbed.transpose(1, 0, 2, 3),
        )

    def reverse_sample_vectorized(
        self, R_t, R_0, t, noise_level, mask=None, return_perturb=False
    ):
        """reverse_sample uses an approximation to the IGSO3 score to sample
        a rotation at the previous time step.

        Roughly - this update follows the reverse time SDE for Riemannian
        manifolds proposed by de Bortoli et al. Theorem 1 [1]. But with an
        approximation to the score based on the prediction of R0.
        The derivation below is written for a geometric variance schedule
        (following [2], appendix C),
            sigma_t = sigma_min * (sigma_max / sigma_min)^{t/T},
        for time step t, viewed as a discretization of the SDE from time 0 to
        1 with step size (1/T). Following Eq. 5 and Eq. 6 this maps onto
            dx = g(t) dBt [FORWARD]
        and
            dx = g(t)^2 score(xt, t)dt + g(t) B't, [REVERSE]
        where Bt and B't are Brownian motions. In this implementation g(t) is
        computed numerically from whichever sigma schedule is configured
        (see ``self.g``).
        Args:
            R_t: noisy rotation of shape [N, 3, 3]
            R_0: prediction of un-noised rotation, shape [N, 3, 3]
            t: integer time step
            noise_level: scaling on the noise added when obtaining sample
                (preliminary performance seems empirically better with noise
                level=0.5)
            mask: optional (N,) tensor; whether the residue is to be updated.
                A value of 1 means the rotation is not updated from r_t. A
                value of 0 means the rotation is updated.
            return_perturb: if True, return the perturbation rotation instead
                of the updated rotation.
        Return:
            sampled rotations for time t-1 (Perturb @ R_t), batched like R_t,
            or the perturbation itself when return_perturb is True.
        Reference:
        [1] De Bortoli, V., Mathieu, E., Hutchinson, M., Thornton, J., Teh, Y.
        W., & Doucet, A. (2022). Riemannian score-based generative modeling.
        arXiv preprint arXiv:2202.02763.
        [2] Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S.,
        & Poole, B. (2020). Score-based generative modeling through stochastic
        differential equations. arXiv preprint arXiv:2011.13456.
        """
        # as_tensor avoids the copy-warning for tensor inputs; detach mirrors
        # torch.tensor(), which never tracked gradients.
        R_0 = torch.as_tensor(R_0).detach()
        R_t = torch.as_tensor(R_t).detach()

        # Relative rotation from the predicted R_0 to the current R_t.
        R_0t = torch.einsum("...ij,...kj->...ik", R_t, R_0)
        R_0t_rotvec = torch.tensor(
            scipy_R.from_matrix(R_0t.cpu().numpy()).as_rotvec()
        ).to(R_0.device)

        # Score approximation: the rotation axis scaled by the IGSO3 score norm.
        # The score table is numpy, so the angles are handled on the CPU.
        Omega = torch.linalg.norm(R_0t_rotvec, dim=-1).cpu().numpy()
        score_norms = self.score_norm(t, Omega)
        # Guard Omega == 0 (R_t == R_0): the axis is undefined there and the
        # plain division produced NaN; use a zero score instead.
        scale = np.zeros_like(Omega)
        np.divide(score_norms, Omega, out=scale, where=Omega > 0)
        Score_approx = R_0t_rotvec * torch.from_numpy(scale).to(R_0t_rotvec)[:, None]

        continuous_t = t / self.T
        rot_g = self.g(continuous_t).to(Score_approx.device)

        # Gaussian noise in the tangent space, scaled by noise_level.
        Z = np.random.normal(size=(R_0.shape[0], 3))
        Z = torch.from_numpy(Z).to(Score_approx.device)
        Z *= noise_level

        # Euler–Maruyama step of the reverse SDE.
        Delta_r = (rot_g**2) * self.step_size * Score_approx
        Perturb_tangent = Delta_r + rot_g * np.sqrt(self.step_size) * Z
        if mask is not None:
            # Perturb_tangent is (N, 3); broadcast the (N,) mask as (N, 1).
            # The former [:, None, None] tried an in-place (N, N, 3) broadcast
            # and always raised.
            keep = (1 - mask.long()).to(Perturb_tangent.device)
            Perturb_tangent *= keep[:, None]
        Perturb = igso3.Exp(Perturb_tangent)

        if return_perturb:
            return Perturb

        Interp_rot = torch.einsum("...ij,...jk->...ik", Perturb, R_t)

        return Interp_rot
| |
|
| |
|
class Diffuser:
    """Joint backbone diffuser: Gaussian noise on residue translations
    (EuclideanDiffuser) plus IGSO(3) noise on residue frame rotations (IGSO3)."""

    def __init__(
        self,
        T,
        b_0,
        b_T,
        min_sigma,
        max_sigma,
        min_b,
        max_b,
        schedule_type,
        so3_schedule_type,
        so3_type,
        crd_scale,
        schedule_kwargs=None,
        var_scale=1.0,
        cache_dir=".",
        partial_T=None,
        truncation_level=2000,
    ):
        """
        Parameters:

            T (int, required): Number of steps in the schedule

            b_0 (float, required): Starting variance for Euclidean schedule

            b_T (float, required): Ending variance for Euclidean schedule

            min_sigma, max_sigma, min_b, max_b: IGSO(3) schedule parameters (see IGSO3)

            schedule_type (str): Euclidean beta schedule type

            so3_schedule_type (str): IGSO(3) sigma schedule ("linear" or "exponential")

            so3_type: accepted for compatibility; not used here

            crd_scale (float): coordinates are multiplied by this before noising
                and divided by it afterwards

            schedule_kwargs (dict, optional): extra keyword arguments for the
                Euclidean beta schedule, forwarded to EuclideanDiffuser

            var_scale (float): stored on the instance (diffuse_pose does not
                forward it to the translation noise)

            cache_dir (str): directory for cached IGSO(3) tables

            partial_T: accepted for compatibility; not used here

            truncation_level (int): IGSO(3) truncation level
        """
        self.T = T
        self.b_0 = b_0
        self.b_T = b_T
        self.min_sigma = min_sigma
        self.max_sigma = max_sigma
        self.crd_scale = crd_scale
        self.var_scale = var_scale
        self.cache_dir = cache_dir

        self.so3_diffuser = IGSO3(
            T=self.T,
            min_sigma=self.min_sigma,
            max_sigma=self.max_sigma,
            schedule=so3_schedule_type,
            min_b=min_b,
            max_b=max_b,
            cache_dir=self.cache_dir,
            L=truncation_level,
        )

        # Hand the dict over as EuclideanDiffuser's `schedule_kwargs` argument.
        # Unpacking it with ** made typical non-empty dicts (e.g.
        # {"inference": True}) a TypeError, since EuclideanDiffuser has no such
        # keyword parameters.
        self.eucl_diffuser = EuclideanDiffuser(
            self.T,
            b_0,
            b_T,
            schedule_type=schedule_type,
            schedule_kwargs={} if schedule_kwargs is None else schedule_kwargs,
        )

        print("Successful diffuser __init__")

    def diffuse_pose(
        self,
        xyz,
        seq,
        atom_mask,
        include_motif_sidechains=True,
        diffusion_mask=None,
        t_list=None,
    ):
        """
        Given full atom xyz, sequence and atom mask, diffuse the protein frame
        translations and rotations

        Parameters:

            xyz (L,14/27,3) set of coordinates

            seq (L,) integer sequence (not used by this method)

            atom_mask: mask describing presence/absence of an atom in pdb (not used by this method)

            include_motif_sidechains (bool): copy the first 14 atoms of motif
                residues into every time step of the output

            diffusion_mask (torch.tensor, optional): Tensor of bools, True means NOT diffused at this residue, False means diffused

            t_list (list, optional): If present, only return the diffused coordinates at timesteps t (one-indexed) within the list

        Returns:
            fa_stack: (T', L, 27, 3) noised coordinates, T' = T or len(t_list).
                Atoms 0-2 (N, CA, C) are noised; other atoms are zero unless
                filled from motif sidechains.
            xyz_true: the centred, unscaled input coordinates.
        """
        if diffusion_mask is None:
            diffusion_mask = torch.zeros(len(xyz.squeeze()), dtype=torch.bool)

        # All backbone (first three) atoms must be present.
        nan_mask = ~torch.isnan(xyz.squeeze()[:, :3]).any(dim=-1).any(dim=-1)
        assert torch.sum(~nan_mask) == 0

        # Centre on the motif CA centre of mass when there is a motif,
        # otherwise on the CA centre of mass of the whole structure.
        if torch.any(diffusion_mask):
            self.motif_com = xyz[diffusion_mask, 1, :].mean(dim=0)
            xyz = xyz - self.motif_com
        else:
            xyz = xyz - xyz[:, 1, :].mean(dim=0)

        xyz_true = torch.clone(xyz)
        xyz = xyz * self.crd_scale

        # 1) Translations: per-step CA displacements, in unscaled coordinates.
        _, deltas = self.eucl_diffuser.diffuse_translations(
            xyz[:, :3, :].clone(), diffusion_mask=diffusion_mask
        )
        deltas /= self.crd_scale

        # 2) Rotations: IGSO(3)-rotated N/CA/C coordinates, shape (L, T, 3, 3).
        diffused_frame_crds, _ = self.so3_diffuser.diffuse_frames(
            xyz[:, :3, :].clone(),
            diffusion_mask=diffusion_mask.cpu().numpy(),
            t_list=None,
        )
        diffused_frame_crds /= self.crd_scale

        # 3) Combine: rotated frames shifted by the cumulative translation noise,
        #    then put time first: (T, L, 3, 3).
        cum_delta = deltas.cumsum(dim=1)
        diffused_BB = (
            torch.from_numpy(diffused_frame_crds).to(cum_delta.device)
            + cum_delta[:, :, None, :]
        ).transpose(0, 1)

        t_steps, L = diffused_BB.shape[:2]

        diffused_fa = torch.zeros(t_steps, L, 27, 3, device=diffused_BB.device)
        diffused_fa[:, :, :3, :] = diffused_BB

        if include_motif_sidechains:
            diffused_fa[:, diffusion_mask, :14, :] = xyz_true[None, diffusion_mask, :14]

        if t_list is None:
            fa_stack = diffused_fa
        else:
            t_idx_list = [t - 1 for t in t_list]
            fa_stack = diffused_fa[t_idx_list]

        return fa_stack, xyz_true
| |
|