""" Based on: https://github.com/crowsonkb/k-diffusion Copyright (c) 2022 Katherine Crowson Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ import numpy as np import torch as th from .gaussian_diffusion import GaussianDiffusion, mean_flat class KarrasDenoiser: def __init__(self, sigma_data: float = 0.5): self.sigma_data = sigma_data def get_snr(self, sigmas): return sigmas**-2 def get_sigmas(self, sigmas): return sigmas def get_scalings(self, sigma): c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 return c_skip, c_out, c_in def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None): if model_kwargs is None: model_kwargs = {} if noise is None: noise = th.randn_like(x_start) terms = {} dims = x_start.ndim x_t = x_start + noise * append_dims(sigmas, dims) c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)] model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs) target = (x_start - c_skip * x_t) / c_out terms["mse"] = mean_flat((model_output - target) ** 2) terms["xs_mse"] = mean_flat((denoised - x_start) ** 2) if "vb" in terms: terms["loss"] = terms["mse"] + terms["vb"] else: terms["loss"] = terms["mse"] return terms def denoise(self, model, x_t, sigmas, **model_kwargs): c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)] rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44) model_output = model(c_in * x_t, rescaled_t, **model_kwargs) denoised = c_out * model_output + c_skip * x_t return model_output, denoised class GaussianToKarrasDenoiser: def __init__(self, model, diffusion): from scipy import interpolate self.model = model self.diffusion = diffusion self.alpha_cumprod_to_t = interpolate.interp1d( diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps) ) def sigma_to_t(self, sigma): alpha_cumprod = 1.0 / (sigma**2 + 1) if alpha_cumprod > self.diffusion.alphas_cumprod[0]: return 0 elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]: return self.diffusion.num_timesteps - 1 else: return float(self.alpha_cumprod_to_t(alpha_cumprod)) def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None, condition_latents=None): t = th.tensor( [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()], dtype=th.long, device=sigmas.device, ) c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim) out = self.diffusion.p_mean_variance( self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latents ) return None, out["pred_xstart"] def karras_sample(*args, **kwargs): last = None x_sequence = [] # print("kraras_sample_model_kwargs", kwargs["model_kwargs"]['embeddings'].shape) for x in karras_sample_progressive(*args, **kwargs): last = x["x"] x_sequence.append(last) return last, x_sequence def karras_sample_progressive( diffusion, model, shape, steps, clip_denoised=True, progress=False, model_kwargs=None, device=None, sigma_min=0.002, sigma_max=80, # higher for highres? rho=7.0, sampler="heun", s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, guidance_scale=0.0, condition_latent=None, initial_noise=None, ): sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device) # print("sigmas", sigmas.shape, sigmas) if initial_noise is None: x_T = th.randn(*shape, device=device) * sigma_max else: x_T = initial_noise.clone() * sigma_max sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[ sampler ] if sampler != "ancestral": sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise) else: sampler_args = {} if isinstance(diffusion, KarrasDenoiser): def denoiser(x_t, sigma): _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs) if clip_denoised: denoised = denoised.clamp(-1, 1) return denoised elif isinstance(diffusion, GaussianDiffusion): model = GaussianToKarrasDenoiser(model, diffusion) def denoiser(x_t, sigma): _, denoised = model.denoise( x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latent ) return denoised else: raise NotImplementedError if guidance_scale != 0 and guidance_scale != 1: def guided_denoiser(x_t, sigma): x_t = th.cat([x_t, x_t], dim=0) sigma = th.cat([sigma, sigma], dim=0) x_0 = denoiser(x_t, sigma) cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0) x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0) return x_0 else: guided_denoiser = denoiser for obj in sample_fn( guided_denoiser, x_T, sigmas, progress=progress, condition_latent=condition_latent, **sampler_args, ): if isinstance(diffusion, GaussianDiffusion): # print("is gaussian diffusion", obj) yield diffusion.unscale_out_dict(obj) else: yield obj def karras_sample_progressive_condition( diffusion, model, shape, steps, clip_denoised=True, progress=False, model_kwargs=None, device=None, sigma_min=0.002, sigma_max=80, # higher for highres? rho=7.0, sampler="heun", s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, text_guidance_scale=0.0, image_guidance_scale=0.0, condition_latent=None, ): sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device) x_T = th.randn(*shape, device=device) * sigma_max sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[ sampler ] if sampler != "ancestral": sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise) else: sampler_args = {} if isinstance(diffusion, KarrasDenoiser): def denoiser(x_t, sigma): _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs) if clip_denoised: denoised = denoised.clamp(-1, 1) return denoised elif isinstance(diffusion, GaussianDiffusion): model = GaussianToKarrasDenoiser(model, diffusion) def denoiser(x_t, sigma): _, denoised = model.denoise( x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latent ) return denoised else: raise NotImplementedError if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (image_guidance_scale != 1.0 and image_guidance_scale != 0.0): def guided_denoiser(x_t, sigma): x_t = th.cat([x_t, x_t, x_t], dim=0) sigma = th.cat([sigma, sigma, sigma], dim=0) x_0 = denoiser(x_t, sigma) # import pdb; pdb.set_trace() cond_x_0_text, cond_x_0_image, uncond_x_0 = th.chunk(x_0, 3, dim=0) x_0 = uncond_x_0 + text_guidance_scale * (cond_x_0_text - cond_x_0_image) + image_guidance_scale * (cond_x_0_image - uncond_x_0) return x_0 else: guided_denoiser = denoiser for obj in sample_fn( guided_denoiser, x_T, sigmas, progress=progress, condition_latent=condition_latent, **sampler_args, ): if isinstance(diffusion, GaussianDiffusion): yield diffusion.unscale_out_dict(obj) else: yield obj def karras_sample_addition_condition(*args, **kwargs): last = None x_sequence = [] for x in karras_sample_progressive_condition(*args, **kwargs): last = x["x"] x_sequence.append(x["pred_xstart"]) return last, x_sequence def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"): """Constructs the noise schedule of Karras et al. (2022).""" ramp = th.linspace(0, 1, n) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return append_zero(sigmas).to(device) def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / append_dims(sigma, x.ndim) def get_ancestral_step(sigma_from, sigma_to): """Calculates the noise level (sigma_down) to step down to and the amount of noise to add (sigma_up) when doing an ancestral sampling step.""" sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 return sigma_down, sigma_up @th.no_grad() def sample_euler_ancestral(model, x, sigmas, progress=False): """Ancestral sampling with Euler method steps.""" s_in = x.new_ones([x.shape[0]]) indices = range(len(sigmas) - 1) if progress: from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: denoised = model(x, sigmas[i] * s_in) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised} d = to_d(x, sigmas[i], denoised) # Euler method dt = sigma_down - sigmas[i] x = x + d * dt x = x + th.randn_like(x) * sigma_up yield {"x": x, "pred_xstart": x} @th.no_grad() def sample_heun( denoiser, x, sigmas, progress=False, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, condition_latent=None, ): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" s_in = x.new_ones([x.shape[0]]) indices = range(len(sigmas) - 1) if progress: from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: gamma = ( min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 ) eps = th.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = denoiser(x, sigma_hat * s_in) d = to_d(x, sigma_hat, denoised) yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised} dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method x = x + d * dt else: # Heun's method x_2 = x + d * dt denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 x = x + d_prime * dt yield {"x": x, "pred_xstart": denoised} @th.no_grad() def sample_dpm( denoiser, x, sigmas, progress=False, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, ): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" s_in = x.new_ones([x.shape[0]]) indices = range(len(sigmas) - 1) if progress: from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: gamma = ( min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 ) eps = th.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = denoiser(x, sigma_hat * s_in) d = to_d(x, sigma_hat, denoised) yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised} # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 dt_1 = sigma_mid - sigma_hat dt_2 = sigmas[i + 1] - sigma_hat x_2 = x + d * dt_1 denoised_2 = denoiser(x_2, sigma_mid * s_in) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 yield {"x": x, "pred_xstart": denoised} def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") return x[(...,) + (None,) * dims_to_append] def append_zero(x): return th.cat([x, x.new_zeros([1])])