"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu

Noise and diffusion utils.
"""

from scipy.stats import norm
import torch
from torchtyping import TensorType

from core import utils


def noise_schedule(
    time: TensorType[float],
    function: str = "uniform",
    sigma_data: float = 10.0,
    psigma_mean: float = -1.2,
    psigma_std: float = 1.2,
    s_min: float = 0.001,
    s_max: float = 60,
    rho: float = 7.0,
    time_power: float = 4.0,
    constant_val: float = 0.0,
):
    """Map diffusion time values to noise levels.

    Args:
        time: Tensor of time values. The "uniform" and "mpnn" schedules map
            time=0 to ``sigma_data * s_min`` and time=1 to
            ``sigma_data * s_max``. The "lognormal" schedule passes ``time``
            to ``scipy.stats.norm.ppf``, i.e. treats it as a probability.
        function: Which schedule to use: "lognormal", "uniform", "mpnn" or
            "constant".
        sigma_data: Multiplicative scale applied to the noise level in every
            schedule except "constant".
        psigma_mean: Mean of the log noise level (before scaling by
            ``sigma_data``) for the "lognormal" schedule.
        psigma_std: Standard deviation of the log noise level for the
            "lognormal" schedule.
        s_min: Unscaled noise level reached at time=0.
        s_max: Unscaled noise level reached at time=1.
        rho: Exponent controlling the curvature of the interpolation between
            ``s_min`` and ``s_max``.
        time_power: Exponent applied to ``time`` before the schedule is
            evaluated when ``function == "mpnn"``.
        constant_val: Value returned for every element when
            ``function == "constant"``.

    Returns:
        Tensor of noise levels with the same shape as ``time``.

    Raises:
        ValueError: If ``function`` is not one of the supported schedules.
    """

    def sampling_noise(time):
        # high noise = 1; low noise = 0. opposite of Karras et al. schedule
        # Interpolates linearly in sigma ** (1 / rho) space, then raises the
        # result back to the power rho.
        term1 = s_max ** (1 / rho)
        term2 = (1 - time) * (s_min ** (1 / rho) - s_max ** (1 / rho))
        noise_level = sigma_data * ((term1 + term2) ** rho)
        return noise_level

    if function == "lognormal":
        # The inverse normal CDF is evaluated by scipy on the CPU; `.to(time)`
        # moves the result back to the device and dtype of `time`.
        normal_sample = torch.Tensor(norm.ppf(time.cpu())).to(time)
        noise_level = sigma_data * torch.exp(psigma_mean + psigma_std * normal_sample)
    elif function == "uniform":
        noise_level = sampling_noise(time)
    elif function == "mpnn":
        time = time**time_power
        noise_level = sampling_noise(time)
    elif function == "constant":
        noise_level = torch.ones_like(time) * constant_val
    else:
        # Previously an unknown name fell through every branch and surfaced
        # as an UnboundLocalError on the return statement.
        raise ValueError(
            f"Unknown noise schedule function {function!r}; expected one of "
            "'lognormal', 'uniform', 'mpnn', 'constant'."
        )
    return noise_level


def noise_coords(
    coords: TensorType["b n a x", float],
    noise_level: TensorType["b", float],
    dummy_fill_masked_atoms: bool = False,
    atom_mask: TensorType["b n a"] = None,
):
    """Add Gaussian noise, scaled by ``noise_level``, to ``coords``.

    Does not apply atom mask after adding noise.

    Args:
        coords: Coordinates to perturb.
        noise_level: Noise scale; it is passed through ``utils.expand`` before
            being multiplied with the Gaussian sample (presumably to broadcast
            it against ``coords`` -- confirm against ``core.utils``).
        dummy_fill_masked_atoms: If True, atoms with ``atom_mask == 0`` are
            overwritten with the coordinates at atom index 1 (CA) before the
            noise is added.
        atom_mask: Mask selecting the atoms to keep; required when
            ``dummy_fill_masked_atoms`` is True.

    Returns:
        Tensor with the same shape as ``coords`` holding the noised
        coordinates.
    """
    if dummy_fill_masked_atoms:
        assert (
            atom_mask is not None
        ), "atom_mask is required when dummy_fill_masked_atoms=True"
        dummy_fill_mask = 1 - atom_mask
        dummy_fill_value = coords[..., 1:2, :]  # CA
        # dummy_fill_value = utils.fill_in_cbeta_for_atom37(coords)[..., 3:4, :]  # CB
        coords = (
            coords * atom_mask[..., None]
            + dummy_fill_value * dummy_fill_mask[..., None]
        )

    noise = torch.randn_like(coords) * utils.expand(noise_level, coords)
    noisy_coords = coords + noise
    return noisy_coords