Spaces:
Starting
on
T4
Starting
on
T4
File size: 2,057 Bytes
8c639ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
"""
https://github.com/ProteinDesignLab/protpardelle
License: MIT
Author: Alex Chu
Noise and diffusion utils.
"""
from scipy.stats import norm
import torch
from torchtyping import TensorType
from core import utils
def noise_schedule(
    time: TensorType[float],
    function: str = "uniform",
    sigma_data: float = 10.0,
    psigma_mean: float = -1.2,
    psigma_std: float = 1.2,
    s_min: float = 0.001,
    s_max: float = 60,
    rho: float = 7.0,
    time_power: float = 4.0,
    constant_val: float = 0.0,
):
    """Map diffusion time(s) to noise level(s) under the chosen schedule.

    Args:
        time: Diffusion time tensor. For the "uniform" and "mpnn" schedules
            time = 1 gives the highest noise (sigma_data * s_max) and
            time = 0 the lowest (sigma_data * s_min), i.e. the opposite
            direction of the Karras et al. schedule. For "lognormal", time is
            passed to the standard-normal inverse CDF (norm.ppf), so values
            should lie strictly inside (0, 1).
        function: Schedule name: "lognormal", "uniform", "mpnn" or "constant".
        sigma_data: Scale multiplied into the "lognormal", "uniform" and
            "mpnn" schedules.
        psigma_mean: Mean of log(noise_level / sigma_data) for "lognormal".
        psigma_std: Std of log(noise_level / sigma_data) for "lognormal".
        s_min: Lower endpoint (before sigma_data scaling) of the rho-power
            interpolation used by "uniform" and "mpnn".
        s_max: Upper endpoint (before sigma_data scaling) of that
            interpolation.
        rho: Exponent controlling the curvature of that interpolation.
        time_power: Exponent applied to time before interpolation ("mpnn").
        constant_val: Noise level returned everywhere by "constant".

    Returns:
        Noise level(s) with the same shape as ``time``.

    Raises:
        ValueError: If ``function`` is not one of the recognized names.
    """

    def sampling_noise(time):
        # high noise = 1; low noise = 0. opposite of Karras et al. schedule
        term1 = s_max ** (1 / rho)
        term2 = (1 - time) * (s_min ** (1 / rho) - s_max ** (1 / rho))
        noise_level = sigma_data * ((term1 + term2) ** rho)
        return noise_level

    if function == "lognormal":
        # scipy operates on NumPy data: detach so grad-tracking tensors can be
        # converted, and use as_tensor so a 0-dim input (for which norm.ppf
        # returns a NumPy scalar) is handled; .to(time) restores device/dtype.
        quantiles = norm.ppf(time.detach().cpu().numpy())
        normal_sample = torch.as_tensor(quantiles).to(time)
        noise_level = sigma_data * torch.exp(psigma_mean + psigma_std * normal_sample)
    elif function == "uniform":
        noise_level = sampling_noise(time)
    elif function == "mpnn":
        time = time**time_power
        noise_level = sampling_noise(time)
    elif function == "constant":
        noise_level = torch.ones_like(time) * constant_val
    else:
        # Previously fell through and raised UnboundLocalError on return.
        raise ValueError(
            f"Unknown noise schedule function: {function!r}; expected one of "
            "'lognormal', 'uniform', 'mpnn', 'constant'."
        )
    return noise_level
def noise_coords(
    coords: TensorType["b n a x", float],
    noise_level: TensorType["b", float],
    dummy_fill_masked_atoms: bool = False,
    atom_mask: TensorType["b n a"] = None,
):
    """Add Gaussian noise, scaled by ``noise_level``, to coordinates.

    Args:
        coords: Coordinate tensor, annotated as (b, n, a, x).
        noise_level: Noise standard deviation, annotated as (b,). It is
            broadcast against ``coords`` via ``utils.expand`` (assumed to
            append trailing singleton dims — TODO confirm in core.utils).
        dummy_fill_masked_atoms: If True, atoms whose ``atom_mask`` entry is 0
            are first overwritten with the coordinates of atom index 1 (CA)
            of the same residue; entries with mask 1 are kept unchanged.
        atom_mask: Per-atom mask, annotated as (b, n, a). Required when
            ``dummy_fill_masked_atoms`` is True. Bool masks are accepted.

    Returns:
        ``coords`` (possibly dummy-filled) plus ``randn_like(coords)`` scaled
        by the expanded ``noise_level``. The atom mask is NOT re-applied after
        adding noise, so masked positions are noised too.

    Raises:
        ValueError: If ``dummy_fill_masked_atoms`` is True and ``atom_mask``
            is None.
    """
    if dummy_fill_masked_atoms:
        # Explicit check instead of `assert`, which is stripped under -O.
        if atom_mask is None:
            raise ValueError(
                "atom_mask must be provided when dummy_fill_masked_atoms=True"
            )
        # torch does not support `1 - mask` on bool tensors; cast bool masks
        # to the coordinate dtype (other dtypes are left as-is).
        if atom_mask.dtype == torch.bool:
            atom_mask = atom_mask.to(coords.dtype)
        dummy_fill_mask = 1 - atom_mask
        dummy_fill_value = coords[..., 1:2, :]  # CA
        # NOTE: an alternative fill is the CB position, e.g.
        # utils.fill_in_cbeta_for_atom37(coords)[..., 3:4, :]
        coords = (
            coords * atom_mask[..., None]
            + dummy_fill_value * dummy_fill_mask[..., None]
        )
    noise = torch.randn_like(coords) * utils.expand(noise_level, coords)
    noisy_coords = coords + noise
    return noisy_coords
|