gustproof's picture
First
fb3e84a
raw
history blame
1.77 kB
import torch
import numpy as np
def edm_sampler(
    net,
    noise,
    labels=None,
    gnet=None,
    num_steps=32,
    sigma_min=0.002,
    sigma_max=80,
    rho=7,
    guidance=1,
    S_churn=0,
    S_min=0,
    S_max=float("inf"),
    S_noise=1,
    dtype=torch.float32,
    randn_like=torch.randn_like,
):
    """Run the sampling loop: an Euler step plus a 2nd-order correction per step.

    Args:
        net: Denoiser, called as ``net(x, t, labels)``; its output is cast
            to ``dtype``.
        noise: Starting tensor. It is cast to ``dtype`` and scaled by the
            first noise level; its device is used for the schedule.
        labels: Passed through to ``net`` unchanged.
        gnet: Guiding denoiser, called as ``gnet(x, t)``. Only used (and
            therefore only required) when ``guidance != 1``.
        num_steps: Number of noise levels in the schedule (a final level
            of 0 is appended on top of these).
        sigma_min: Last non-zero noise level of the schedule.
        sigma_max: First noise level of the schedule.
        rho: Exponent of the schedule; levels are interpolated linearly in
            ``sigma ** (1 / rho)`` space.
        guidance: Interpolation weight between ``gnet`` and ``net``
            outputs; ``1`` returns the ``net`` output directly.
        S_churn: When positive, noise is temporarily added at every step
            whose level lies within ``[S_min, S_max]``.
        S_min: Lower bound of the noise-level range where churn applies.
        S_max: Upper bound of the noise-level range where churn applies.
        S_noise: Multiplier applied to the noise injected by churn.
        dtype: Dtype used for the schedule and for the sampler state.
        randn_like: Callable producing noise shaped like its argument.

    Returns:
        The tensor obtained after the last step (noise level 0).
    """

    # Guided denoiser.
    def denoise(x, t):
        Dx = net(x, t, labels).to(dtype)
        if guidance == 1:
            return Dx
        ref_Dx = gnet(x, t).to(dtype)
        # ref_Dx + guidance * (Dx - ref_Dx): guidance > 1 extrapolates
        # past the main network's prediction, away from the guiding one.
        return ref_Dx.lerp(Dx, guidance)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=dtype, device=noise.device)
    # With a single step the interpolation span would be zero and 0 / 0
    # would turn the whole schedule into NaN; clamp the divisor so that a
    # one-step schedule simply starts at sigma_max.
    span = max(num_steps - 1, 1)
    t_steps = (
        sigma_max ** (1 / rho)
        + step_indices / span * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # t_N = 0

    # Main sampling loop.
    x_next = noise.to(dtype) * t_steps[0]

    # The churn factor does not depend on the step, so compute it once.
    gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_churn > 0 else 0

    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        if S_churn > 0 and S_min <= t_cur <= S_max:
            t_hat = t_cur + gamma * t_cur
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
        else:
            t_hat = t_cur
            x_hat = x_cur

        # Euler step.
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction. Skipped on the last step, where
        # t_next is 0 and the slope at t_next would divide by zero.
        if i < num_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next