|
import torch
|
|
import math
|
|
from tqdm import trange, tqdm
|
|
|
|
import k_diffusion as K
|
|
|
|
|
|
def get_alphas_sigmas(t):
    """Return the scaling factors for the clean image (alpha) and the noise
    (sigma) at timestep *t* on the cosine schedule.

    alpha = cos(t * pi/2), sigma = sin(t * pi/2), so alpha**2 + sigma**2 == 1.
    """
    angle = t * math.pi / 2
    alpha = torch.cos(angle)
    sigma = torch.sin(angle)
    return alpha, sigma
|
|
|
|
def alpha_sigma_to_t(alpha, sigma):
    """Return the timestep corresponding to the scaling factors for the clean
    image (*alpha*) and the noise (*sigma*).

    Inverts the cosine schedule: t = atan2(sigma, alpha) * 2 / pi.
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
|
|
|
|
def t_to_alpha_sigma(t):
    """Return the scaling factors for the clean image and the noise at
    timestep *t*.

    Same mapping as :func:`get_alphas_sigmas` (cosine schedule); kept as a
    separately named inverse companion to :func:`alpha_sigma_to_t`.
    """
    theta = t * math.pi / 2
    return torch.cos(theta), torch.sin(theta)
|
|
|
|
|
|
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
    """Draws samples from a model given starting noise, integrating from
    t=sigma_max down to t=0 with the explicit Euler method.

    Args:
        model: Callable ``model(x, t, **extra_args)`` returning dx/dt, where
            ``t`` is a per-batch-element timestep tensor.
        x: Starting noise, shape (batch, ...).
        steps: Number of Euler steps to take.
        sigma_max: Starting timestep; the grid runs from sigma_max to 0.
        **extra_args: Forwarded to ``model``.

    Returns:
        The sample after the final step (at t=0).
    """
    # Uniform time grid: steps + 1 points, i.e. `steps` Euler intervals.
    t = torch.linspace(sigma_max, 0, steps + 1)

    # `total=steps` is needed because zip() has no __len__, so tqdm would
    # otherwise show an indeterminate progress bar. (Also removed an unused
    # local `ts` that the original computed and never read.)
    for t_curr, t_prev in tqdm(zip(t[:-1], t[1:]), total=steps):
        # Broadcast the scalar timestep across the batch for the model call.
        t_curr_tensor = t_curr * torch.ones(
            (x.shape[0],), dtype=x.dtype, device=x.device
        )
        dt = t_prev - t_curr  # negative: stepping toward t=0
        x = x + dt * model(x, t_curr_tensor, **extra_args)

    return x
|
|
|
|
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draws samples from a v-diffusion model given starting noise.

    Runs `steps` DDIM-style reverse-diffusion steps on the cosine schedule;
    `eta` > 0 adds DDPM-style stochastic noise at each step (eta == 0 is
    fully deterministic DDIM).
    """
    batch_ones = x.new_ones([x.shape[0]])

    # Timesteps descending from 1; the trailing t=0 point is dropped.
    t = torch.linspace(1, 0, steps + 1)[:-1]

    alphas, sigmas = get_alphas_sigmas(t)

    for i in trange(steps):
        # Predict the velocity under autocast, then continue in fp32.
        with torch.cuda.amp.autocast():
            v = model(x, batch_ones * t[i], **extra_args).float()

        # Recover the predicted clean image and the noise component from v.
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        if i >= steps - 1:
            # Final step: only the clean-image prediction is returned.
            continue

        # DDIM noise level for this transition, scaled by eta.
        ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
            (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
        adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()

        # Re-noise the prediction to the next timestep's level.
        x = pred * alphas[i + 1] + eps * adjusted_sigma

        # Optional fresh noise injection (DDPM-like behavior when eta > 0).
        if eta:
            x += torch.randn_like(x) * ddim_sigma

    return pred
|
|
|
|
|
|
|
|
def get_bmask(i, steps, mask):
    """Return a binary mask for step *i* of *steps*.

    The threshold `strength` ramps linearly from 1/steps to 1 as `i` goes
    from 0 to steps - 1; entries of *mask* at or below the threshold become
    1, the rest 0.
    """
    strength = (i + 1) / (steps)
    return torch.where(mask <= strength, 1, 0)
|
|
|
|
def make_cond_model_fn(model, cond_fn):
    """Wrap *model* with gradient guidance from *cond_fn*.

    The returned function runs the model with gradients enabled, calls
    ``cond_fn(x, sigma, denoised=..., **kwargs)`` for a guidance gradient,
    and shifts the denoised output by ``cond_grad * sigma**2`` (broadcast
    to x's rank via ``K.utils.append_dims``).
    """
    def cond_model_fn(x, sigma, **kwargs):
        # Gradients must flow through the model even when sampling runs
        # under torch.no_grad(), hence the explicit enable_grad scope.
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            guidance = cond_grad * K.utils.append_dims(sigma**2, x.ndim)
            cond_denoised = denoised.detach() + guidance
        return cond_denoised
    return cond_model_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_k(
        model_fn,
        noise,
        init_data=None,
        mask=None,
        steps=100,
        sampler_type="dpmpp-2m-sde",
        sigma_min=0.5,
        sigma_max=50,
        rho=1.0, device="cuda",
        callback=None,
        cond_fn=None,
        **extra_args
    ):
    """Draws samples from a v-diffusion model using a k-diffusion sampler.

    Args:
        model_fn: The v-objective model, wrapped here in K.external.VDenoiser.
        noise: Unit-variance starting noise; scaled by sigmas[0] internally.
        init_data: Optional clean data for img2img-style initialization.
        mask: Optional inpainting mask (used together with init_data); regions
            with smaller values adopt the re-noised init data earlier in
            sampling — presumably 0 = keep init data, 1 = regenerate; confirm
            against callers.
        steps: Number of sampling steps.
        sampler_type: Name of the k-diffusion sampler to use.
        sigma_min, sigma_max, rho: Polyexponential sigma-schedule parameters.
        device: Device the sigma schedule is created on.
        callback: Optional per-step callback; receives a dict with at least
            "i", "x" and "sigma" keys.
        cond_fn: Optional guidance gradient function (see make_cond_model_fn).
        **extra_args: Forwarded to the model.

    Returns:
        The sampled tensor.

    Raises:
        ValueError: If sampler_type is not one of the supported samplers
            (previously an unknown type silently returned None).
    """
    denoiser = K.external.VDenoiser(model_fn)

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)

    # Scale the unit-variance input noise up to the first (largest) sigma.
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # img2img: start from the noised init data.
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # Inpainting: blend noised init data with pure noise via the mask.
        bmask = get_bmask(0, steps, mask)
        input_noised = init_data + noise
        x = input_noised * bmask + noise * (1 - bmask)

        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            # Re-noise the init data to the sampler's current sigma level.
            input_noised = init_data + torch.randn_like(init_data) * sigma
            # Progressively widen the region taken from the init data.
            bmask = get_bmask(i, steps, mask)
            new_x = input_noised * bmask + x * (1 - bmask)
            # In-place write so the sampler sees the blended state.
            # NOTE(review): assumes a rank-3 tensor (e.g. batch, channels,
            # time) — confirm against callers.
            x[:, :, :] = new_x[:, :, :]

        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            # Run the inpainting blend first, then the user callback.
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # Unconditional generation: start from the scaled noise.
        x = noise

    with torch.cuda.amp.autocast():
        if sampler_type == "k-heun":
            return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-lms":
            return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpmpp-2s-ancestral":
            return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-2":
            return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-fast":
            return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "k-dpm-adaptive":
            return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-2m-sde":
            return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        elif sampler_type == "dpmpp-3m-sde":
            return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
        else:
            # Fail loudly instead of silently returning None.
            raise ValueError(f"Unknown sampler type: {sampler_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_rf(
        model_fn,
        noise,
        init_data=None,
        steps=100,
        sigma_max=1,
        device="cuda",
        callback=None,
        cond_fn=None,
        **extra_args
    ):
    """Draws samples from a model given starting noise using the discrete
    Euler sampler (rectified-flow-style, per the function name).

    Args:
        model_fn: Model called as ``model_fn(x, t, **extra_args)``.
        noise: Starting noise tensor.
        init_data: Optional clean data; blended with the noise as
            ``init_data * (1 - sigma_max) + noise * sigma_max``.
        steps: Number of Euler steps.
        sigma_max: Starting time; values above 1 are clamped to 1.
        device: Currently unused; kept for signature parity with sample_k.
        callback: Currently ignored — sample_discrete_euler takes no
            callback. (The original also assigned it to an unused local.)
        cond_fn: Optional guidance gradient function (see make_cond_model_fn).
        **extra_args: Forwarded to the model.

    Returns:
        The sampled tensor at t=0.
    """
    # The time grid runs over [0, 1]; cap the start time accordingly.
    if sigma_max > 1:
        sigma_max = 1

    if cond_fn is not None:
        # BUG FIX: the original wrapped an undefined name `denoiser`
        # (NameError whenever cond_fn was supplied) and never used the
        # result. Wrap the model that is actually sampled instead.
        model_fn = make_cond_model_fn(model_fn, cond_fn)

    if init_data is not None:
        # img2img: linear interpolation between data and noise at t=sigma_max.
        x = init_data * (1 - sigma_max) + noise * sigma_max
    else:
        # Sampling from scratch: start from pure noise.
        x = noise

    with torch.cuda.amp.autocast():
        return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)