import torch
import numpy as np
def edm_sampler(
    net,
    noise,
    labels=None,
    gnet=None,
    num_steps=32,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
    guidance=1,
    S_churn=0,
    S_min=0,
    S_max=float("inf"),
    S_noise=1,
    dtype=torch.float32,
    randn_like=torch.randn_like,
):
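    """EDM-style sampler: 2nd-order (Heun) integration over a rho-spaced
    sigma schedule, with optional stochastic churn (S_churn) and optional
    guidance against a reference network `gnet`."""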
    # Guided denoiser.
    def denoise(x, t):
        Dx = net(x, t, labels).to(dtype)
        if guidance == 1:
            return Dx
        ref_Dx = gnet(x, t).to(dtype)
        return ref_Dx.lerp(Dx, guidance)
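    # Note: ref_Dx.lerp(Dx, guidance) == ref_Dx + guidance * (Dx - ref_Dx),
    # so guidance > 1 extrapolates the primary model's prediction away from
    # the reference model's, in the spirit of classifier-free guidance.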
    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices
        / (num_steps - 1)
        * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0
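    # t_steps decreases from sigma_max to sigma_min under rho-controlled
    # spacing, with an extra final entry of 0 so the loop ends at clean data.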
    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            gamma = min(S_churn / num_steps, np.sqrt(2) - 1)
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur
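        # When churn is applied, x_hat is lifted to noise level t_hat: the
        # fresh Gaussian noise tops the variance up from t_cur**2 to t_hat**2
        # (exactly so when S_noise == 1).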
        # Euler step.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
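            # Heun's method: replace the Euler slope with the average of the
            # slopes at t_hat and t_next. The final step (t_next == 0) keeps
            # the plain Euler result, since d_prime would divide by zero.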
    return x_next
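

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original code). It
# assumes only that `net` is a callable matching the interface used above,
# i.e. net(x, sigma, labels) -> denoised x. The dummy denoiser, batch shape,
# and step count are placeholders for a real EDM-preconditioned network.
if __name__ == "__main__":
    def dummy_denoiser(x, sigma, labels=None):
        # Ideal denoiser for zero-mean, unit-variance Gaussian data; it only
        # stands in for a trained network so the sampler can be exercised.
        return x / (1 + sigma**2)

    noise = torch.randn(4, 3, 64, 64)
    samples = edm_sampler(dummy_denoiser, noise, num_steps=18)
    print(samples.shape)  # torch.Size([4, 3, 64, 64])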