Spaces:
Sleeping
Sleeping
""" | |
https://github.com/ProteinDesignLab/protpardelle | |
License: MIT | |
Author: Alex Chu | |
Noise and diffusion utils. | |
""" | |
from scipy.stats import norm | |
import torch | |
from torchtyping import TensorType | |
from core import utils | |
def noise_schedule(
    time: TensorType[float],
    function: str = "uniform",
    sigma_data: float = 10.0,
    psigma_mean: float = -1.2,
    psigma_std: float = 1.2,
    s_min: float = 0.001,
    s_max: float = 60,
    rho: float = 7.0,
    time_power: float = 4.0,
    constant_val: float = 0.0,
):
    """Map diffusion time values to noise levels.

    Args:
        time: Tensor of time values. For the rho-based schedules, time=0
            yields ``sigma_data * s_min`` and time=1 yields
            ``sigma_data * s_max``. For "lognormal", values are treated as
            quantiles of a standard normal distribution.
        function: Which schedule to use:
            - "uniform": rho-interpolated schedule between ``s_min`` and
              ``s_max``, scaled by ``sigma_data``.
            - "mpnn": same as "uniform", applied to ``time ** time_power``.
            - "lognormal": ``sigma_data * exp(psigma_mean + psigma_std * z)``
              where ``z`` is the inverse normal CDF of ``time``.
            - "constant": a tensor shaped like ``time`` filled with
              ``constant_val``.
        sigma_data: Multiplicative scale applied by every schedule except
            "constant".
        psigma_mean: Mean of the log-noise for the "lognormal" schedule.
        psigma_std: Standard deviation of the log-noise for "lognormal".
        s_min: Lower bound of the rho-based schedule (before scaling).
        s_max: Upper bound of the rho-based schedule (before scaling).
        rho: Exponent controlling the curvature of the rho-based schedule.
        time_power: Power applied to ``time`` for the "mpnn" schedule.
        constant_val: Value returned by the "constant" schedule.

    Returns:
        Tensor of noise levels with the same shape as ``time``.

    Raises:
        ValueError: If ``function`` is not one of the supported schedules.
    """

    def sampling_noise(time):
        # high noise = 1; low noise = 0. opposite of Karras et al. schedule
        term1 = s_max ** (1 / rho)
        term2 = (1 - time) * (s_min ** (1 / rho) - s_max ** (1 / rho))
        return sigma_data * ((term1 + term2) ** rho)

    if function == "lognormal":
        # scipy works on CPU/NumPy data; as_tensor accepts both the array and
        # the NumPy scalar returned for 0-dim input, and `.to(time)` restores
        # the dtype and device of the input.
        normal_sample = torch.as_tensor(norm.ppf(time.cpu())).to(time)
        return sigma_data * torch.exp(psigma_mean + psigma_std * normal_sample)
    if function == "uniform":
        return sampling_noise(time)
    if function == "mpnn":
        return sampling_noise(time**time_power)
    if function == "constant":
        return torch.ones_like(time) * constant_val

    # Previously an unknown name fell through to an unbound local variable.
    raise ValueError(
        f"Unknown noise schedule function: {function!r}; expected one of "
        "'lognormal', 'uniform', 'mpnn', 'constant'"
    )
def noise_coords(
    coords: TensorType["b n a x", float],
    noise_level: TensorType["b", float],
    dummy_fill_masked_atoms: bool = False,
    atom_mask: TensorType["b n a"] = None,
):
    """Return ``coords`` perturbed by Gaussian noise scaled by ``noise_level``.

    Args:
        coords: Coordinates to perturb.
        noise_level: Per-batch noise scale; it is passed through
            ``utils.expand`` against ``coords`` before multiplying the noise.
        dummy_fill_masked_atoms: If True, positions where ``atom_mask`` is 0
            are replaced by the atom at index 1 (CA) of the same residue
            before noise is added.
        atom_mask: Required when ``dummy_fill_masked_atoms`` is True.

    Returns:
        The noised coordinates. The atom mask is NOT re-applied after the
        noise is added.
    """
    if dummy_fill_masked_atoms:
        assert atom_mask is not None
        present = atom_mask[..., None]
        missing = (1 - atom_mask)[..., None]
        # Index 1 along the atom axis is CA; the CB from
        # utils.fill_in_cbeta_for_atom37 was a considered alternative.
        ca_fill = coords[..., 1:2, :]
        coords = coords * present + ca_fill * missing

    return coords + torch.randn_like(coords) * utils.expand(noise_level, coords)