from math import atan, cos, pi, sin, sqrt
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
from torch import Tensor

from .utils import *

""" Diffusion Training """

""" Distributions """


class Distribution:
    """Interface for objects that sample one noise level per batch item."""

    def __call__(self, num_samples: int, device: torch.device):
        raise NotImplementedError()


class LogNormalDistribution(Distribution):
    """Samples `exp(mean + std * N(0, 1))`."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        normal = self.mean + self.std * torch.randn((num_samples,), device=device)
        return normal.exp()


class UniformDistribution(Distribution):
    """Samples uniformly from [0, 1)."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        return torch.rand(num_samples, device=device)


class VKDistribution(Distribution):
    """Samples sigmas in `[min_value, max_value]` via an arctan inverse CDF.

    The CDF used is `atan(sigma / sigma_data) * 2 / pi`, so a uniform draw
    `u` is mapped back with `tan(u * pi / 2) * sigma_data`.
    """

    def __init__(
        self,
        min_value: float = 0.0,
        max_value: float = float("inf"),
        sigma_data: float = 1.0,
    ):
        self.min_value = min_value
        self.max_value = max_value
        self.sigma_data = sigma_data

    def __call__(
        self, num_samples: int, device: torch.device = torch.device("cpu")
    ) -> Tensor:
        sigma_data = self.sigma_data
        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
        # Inverse-transform sampling needs a *uniform* draw in [0, 1).
        # A normal draw (torch.randn) would land outside [min_cdf, max_cdf]
        # and yield sigmas outside [min_value, max_value], including negatives.
        uniform = torch.rand((num_samples,), device=device)
        u = (max_cdf - min_cdf) * uniform + min_cdf
        return torch.tan(u * pi / 2) * sigma_data


""" Diffusion Classes """


def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Append `ndim` singleton dimensions to the right of `x`."""
    return x.view(*x.shape, *((1,) * ndim))


def clip(x: Tensor, dynamic_threshold: float = 0.0):
    """Clamp `x` to [-1, 1], optionally with per-item dynamic thresholding.

    With `dynamic_threshold == 0.0` this is a plain clamp. Otherwise the
    `dynamic_threshold` quantile of `|x|` is computed per batch item
    (floored at 1.0), values are clamped to that scale and divided by it.
    """
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)

    # Find dynamic threshold quantile for each batch item
    x_flat = rearrange(x, "b ... -> b (...)")
    scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
    # Clamp to a min of 1.0
    scale.clamp_(min=1.0)
    # Clamp all values and scale
    scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
    return x.clamp(-scale, scale) / scale


def to_batch(
    batch_size: int,
    device: torch.device,
    x: Optional[float] = None,
    xs: Optional[Tensor] = None,
) -> Tensor:
    """Return a per-item tensor from either a scalar `x` or a tensor `xs`.

    Exactly one of `x` / `xs` must be provided. A scalar `x` is repeated
    `batch_size` times and moved to `device`; `xs` is returned unchanged.
    """
    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
    # If x provided use the same for all batch items
    if exists(x):
        xs = torch.full(size=(batch_size,), fill_value=x).to(device)
    assert exists(xs)
    return xs


class Diffusion(nn.Module):
    """Base diffusion class"""

    # Short identifier used to match diffusion types with samplers.
    alias: str = ""

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        raise NotImplementedError("Diffusion class missing forward function")


class VDiffusion(Diffusion):
    """Diffusion trained on the v-objective with a half-circle noise mix."""

    alias = "v"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        """Map sigmas to `(cos, sin)` of the angle `sigmas * pi / 2`."""
        angle = sigmas * pi / 2
        alpha = torch.cos(angle)
        beta = torch.sin(angle)
        return alpha, beta

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Run the network on `x_noisy` with per-item sigmas."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        return self.net(x_noisy, sigmas, **kwargs)

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the MSE between the network output and the v-target.

        NOTE(review): sigmas are reshaped to `(b, 1, 1)`, so `x` is assumed
        to be 3-dimensional - confirm against callers.
        """
        batch_size, device = x.shape[0], x.device

        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Get noise
        noise = default(noise, lambda: torch.randn_like(x))

        # Combine input and noise weighted by half-circle
        alpha, beta = self.get_alpha_beta(sigmas_padded)
        x_noisy = x * alpha + noise * beta
        x_target = noise * alpha - x * beta

        # Denoise and return loss
        x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
        return F.mse_loss(x_denoised, x_target)


class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""

    alias = "k"

    def __init__(
        self,
        net: nn.Module,
        *,
        sigma_distribution: Distribution,
        sigma_data: float,  # data distribution standard deviation
        dynamic_threshold: float = 0.0,
    ):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data
        self.sigma_distribution = sigma_distribution
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return `(c_skip, c_out, c_in, c_noise)` for the given sigmas.

        `c_noise` keeps shape `(b,)`; the other three are `(b, 1, 1)`.
        """
        sigma_data = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
        c_out = sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
        c_in = (sigmas**2 + sigma_data**2) ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return `c_skip * x_noisy + c_out * net(c_in * x_noisy, c_noise)`."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Predict network output and add skip connection
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        """Per-item loss weight depending on the data distribution."""
        return (sigmas**2 + self.sigma_data**2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the sigma-weighted MSE between the denoised output and `x`."""
        batch_size, device = x.shape[0], x.device

        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Compute denoised values
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)

        # Compute weighted loss: per-item mean first, then weight, then batch mean
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")
        losses = losses * self.loss_weight(sigmas)
        return losses.mean()


class VKDiffusion(Diffusion):
    """Diffusion with Karras-style preconditioning (sigma_data fixed to 1.0)
    trained against a v-objective target."""

    alias = "vk"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """Return `(c_skip, c_out, c_in)`, each of shape `(b, 1, 1)`."""
        sigma_data = 1.0
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
        c_out = -sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
        c_in = (sigmas**2 + sigma_data**2) ** -0.5
        return c_skip, c_out, c_in

    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
        """Map sigma to `atan(sigma) * 2 / pi`."""
        return sigmas.atan() / pi * 2

    def t_to_sigma(self, t: Tensor) -> Tensor:
        """Inverse of `sigma_to_t`: `tan(t * pi / 2)`."""
        return (t * pi / 2).tan()

    def denoise_fn(
        self,
        x_noisy: Tensor,
        sigmas: Optional[Tensor] = None,
        sigma: Optional[float] = None,
        **kwargs,
    ) -> Tensor:
        """Return `c_skip * x_noisy + c_out * net(c_in * x_noisy, t)`."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)

        # Predict network output and add skip connection
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised

    def forward(self, x: Tensor, noise: Optional[Tensor] = None, **kwargs) -> Tensor:
        """Return the MSE between the raw network output and the v-target."""
        batch_size, device = x.shape[0], x.device

        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")

        # Add noise to input
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise

        # Compute model output
        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)

        # Compute v-objective target (epsilon guards the division as c_out -> 0)
        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)

        # Compute loss
        return F.mse_loss(x_pred, v_target)


""" Diffusion Sampling """

""" Schedules """


class Schedule(nn.Module):
    """Interface used by different sampling schedules"""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        raise NotImplementedError()


class LinearSchedule(Schedule):
    """Linearly spaced sigmas from 1 towards 0 (0 itself is excluded)."""

    def forward(self, num_steps: int, device: Any) -> Tensor:
        # Build on the requested device, like KarrasSchedule does.
        return torch.linspace(1, 0, num_steps + 1, device=device)[:-1]


class KarrasSchedule(Schedule):
    """https://arxiv.org/abs/2206.00364 equation 5"""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        """Return `num_steps` sigmas from sigma_max to sigma_min, plus a final 0."""
        rho_inv = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        # Guard the denominator: with a single step, 0 / 0 would produce NaN.
        fraction = steps / max(num_steps - 1, 1)
        sigmas = (
            self.sigma_max**rho_inv
            + fraction * (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
        ) ** self.rho
        sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
        return sigmas


""" Samplers """


class Sampler(nn.Module):
    """Interface for samplers; `diffusion_types` lists compatible diffusions."""

    diffusion_types: List[Type[Diffusion]] = []

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        raise NotImplementedError()

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        raise NotImplementedError("Inpainting not available with current sampler")


class VSampler(Sampler):
    """Sampler for `VDiffusion`, stepping along the half-circle schedule."""

    diffusion_types = [VDiffusion]

    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
        """Map a scalar sigma to `(cos, sin)` of the angle `sigma * pi / 2`."""
        angle = sigma * pi / 2
        alpha = cos(angle)
        beta = sin(angle)
        return alpha, beta

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        x = sigmas[0] * noise
        alpha, beta = self.get_alpha_beta(sigmas[0].item())
        # Defined up front so `num_steps <= 1` returns the scaled noise instead
        # of failing on an unbound local (matches the other samplers).
        x_pred = x
        last_index = num_steps - 2

        for i in range(num_steps - 1):
            x_denoised = fn(x, sigma=sigmas[i])
            x_pred = x * alpha - x_denoised * beta
            x_eps = x * beta + x_denoised * alpha

            # Re-noise to the next level; pointless after the final
            # prediction since only `x_pred` is returned.
            if i != last_index:
                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
                x = x_pred * alpha + x_eps * beta

        return x_pred


class KarrasSampler(Sampler):
    """https://arxiv.org/abs/2206.00364 algorithm 1"""

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(
        self,
        s_tmin: float = 0,
        s_tmax: float = float("inf"),
        s_churn: float = 0.0,
        s_noise: float = 1.0,
    ):
        super().__init__()
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise
        self.s_churn = s_churn

    def step(
        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
    ) -> Tensor:
        """Algorithm 2 (step)"""
        # Select temporarily increased noise level
        sigma_hat = sigma + gamma * sigma
        # Add noise to move from sigma to sigma_hat
        epsilon = self.s_noise * torch.randn_like(x)
        x_hat = x + sqrt(sigma_hat**2 - sigma**2) * epsilon
        # Evaluate ∂x/∂sigma at sigma_hat
        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
        # Take euler step from sigma_hat to sigma_next
        x_next = x_hat + (sigma_next - sigma_hat) * d
        # Second order correction
        if sigma_next != 0:
            model_out_next = fn(x_next, sigma=sigma_next)
            d_prime = (x_next - model_out_next) / sigma_next
            # The step size must span sigma_hat -> sigma_next. Using
            # (sigma - sigma_hat) collapses to zero whenever gamma == 0,
            # leaving x unchanged for the whole sampling run.
            x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        x = sigmas[0] * noise
        # Compute gammas: churn only applies inside [s_tmin, s_tmax]
        gammas = torch.where(
            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
            min(self.s_churn / num_steps, sqrt(2) - 1),
            0.0,
        )
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(
                x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i]  # type: ignore # noqa
            )

        return x


class AEulerSampler(Sampler):
    """Euler sampler that adds fresh noise (`sigma_up`) after every step."""

    diffusion_types = [KDiffusion, VKDiffusion]

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
        """Split the move to `sigma_next` into `(sigma_up, sigma_down)`."""
        sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
        sigma_down = sqrt(sigma_next**2 - sigma_up**2)
        return sigma_up, sigma_down

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        # Sigma steps
        sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (∂x/∂sigma)
        d = (x - fn(x, sigma=sigma)) / sigma
        # Euler method
        x_next = x + d * (sigma_down - sigma)
        # Add randomness
        x_next = x_next + torch.randn_like(x) * sigma_up
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        x = sigmas[0] * noise
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
        return x


class ADPM2Sampler(Sampler):
    """https://www.desmos.com/calculator/jbxjlqd9mb"""

    diffusion_types = [KDiffusion, VKDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Return `(sigma_up, sigma_down, sigma_mid)` for one step.

        `sigma_mid` is the rho-power mean of `sigma` and `sigma_down`.
        """
        r = self.rho
        sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
        sigma_down = sqrt(sigma_next**2 - sigma_up**2)
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        # Sigma steps
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma (∂x/∂sigma)
        d = (x - fn(x, sigma=sigma)) / sigma
        # Denoise to midpoint
        x_mid = x + d * (sigma_mid - sigma)
        # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        # Denoise to next
        x = x + d_mid * (sigma_down - sigma)
        # Add randomness
        x_next = x + torch.randn_like(x) * sigma_up
        return x_next

    def forward(
        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
    ) -> Tensor:
        x = sigmas[0] * noise
        # Denoise to sample
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
        return x

    def inpaint(
        self,
        source: Tensor,
        mask: Tensor,
        fn: Callable,
        sigmas: Tensor,
        num_steps: int,
        num_resamples: int,
    ) -> Tensor:
        """Sample while keeping `source` wherever `mask` is True.

        `mask` is inverted with `~`, so it must be a boolean tensor.
        """
        x = sigmas[0] * torch.randn_like(source)

        for i in range(num_steps - 1):
            # Noise source to current noise level
            source_noisy = source + sigmas[i] * torch.randn_like(source)
            for r in range(num_resamples):
                # Merge noisy source and current then denoise
                x = source_noisy * mask + x * ~mask
                x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
                # Renoise if not last resample step
                if r < num_resamples - 1:
                    sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
                    x = x + sigma * torch.randn_like(x)

        return source * mask + x * ~mask


""" Main Classes """


class DiffusionSampler(nn.Module):
    """Generates samples from noise using a diffusion model's `denoise_fn`,
    a sampler and a sigma schedule."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        sampler: Sampler,
        sigma_schedule: Schedule,
        num_steps: Optional[int] = None,
        clamp: bool = True,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

        # Check sampler is compatible with diffusion type
        sampler_class = sampler.__class__.__name__
        diffusion_class = diffusion.__class__.__name__
        message = f"{sampler_class} incompatible with {diffusion_class}"
        assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message

    # Sampling is inference only: without no_grad an autograd graph would be
    # retained across every sampler step (DiffusionInpainter does the same).
    @torch.no_grad()
    def forward(
        self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
    ) -> Tensor:
        """Sample starting from `noise`.

        `num_steps` falls back to the value given at construction; extra
        `kwargs` are forwarded to every `denoise_fn` call. The result is
        clamped to [-1, 1] when `self.clamp` is set.
        """
        device = noise.device
        num_steps = default(num_steps, self.num_steps)  # type: ignore
        assert exists(num_steps), "Parameter `num_steps` must be provided"

        # Compute sigmas using schedule
        sigmas = self.sigma_schedule(num_steps, device)

        # Append additional kwargs to denoise function (used e.g. for conditional unet)
        def fn(*args, **fn_kwargs):
            return self.denoise_fn(*args, **{**fn_kwargs, **kwargs})

        # Sample using sampler
        x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
        x = x.clamp(-1.0, 1.0) if self.clamp else x
        return x


class DiffusionInpainter(nn.Module):
    """Inpaints masked regions using a sampler's `inpaint` method."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        num_steps: int,
        num_resamples: int,
        sampler: Sampler,
        sigma_schedule: Schedule,
    ):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.num_steps = num_steps
        self.num_resamples = num_resamples
        self.inpaint_fn = sampler.inpaint
        self.sigma_schedule = sigma_schedule

    @torch.no_grad()
    def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
        x = self.inpaint_fn(
            source=inpaint,
            mask=inpaint_mask,
            fn=self.denoise_fn,
            sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
            num_steps=self.num_steps,
            num_resamples=self.num_resamples,
        )
        return x


def sequential_mask(like: Tensor, start: int) -> Tensor:
    """Return a bool mask shaped like `like`: True before index `start`
    along dimension 2, False from `start` onwards."""
    mask = torch.ones_like(like, dtype=torch.bool)
    mask[:, :, start:] = False
    return mask


class SpanBySpanComposer(nn.Module):
    """Generates a long sequence by repeatedly inpainting the second half
    of a window from its first half."""

    def __init__(
        self,
        inpainter: DiffusionInpainter,
        *,
        num_spans: int,
    ):
        super().__init__()
        self.inpainter = inpainter
        self.num_spans = num_spans

    def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
        """Return `num_spans` generated half-windows concatenated on dim 2.

        With `keep_start` the two halves of `start` are prepended.
        """
        half_length = start.shape[2] // 2

        spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
        # Inpaint second half from first half
        inpaint = torch.zeros_like(start)
        inpaint[:, :, :half_length] = start[:, :, half_length:]
        inpaint_mask = sequential_mask(like=start, start=half_length)

        for _ in range(self.num_spans):
            # Inpaint second half
            span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
            # Replace first half with generated second half
            second_half = span[:, :, half_length:]
            inpaint[:, :, :half_length] = second_half
            # Save generated span
            spans.append(second_half)

        return torch.cat(spans, dim=2)


class XDiffusion(nn.Module):
    """Wrapper that builds the diffusion class whose `alias` equals `type`."""

    def __init__(self, type: str, net: nn.Module, **kwargs):
        super().__init__()

        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
        aliases = [t.alias for t in diffusion_classes]  # type: ignore
        message = f"type='{type}' must be one of {*aliases,}"
        assert type in aliases, message
        self.net = net

        # Loop variable deliberately not named `XDiffusion`, which would
        # shadow this class inside its own constructor.
        for diffusion_class in diffusion_classes:
            if diffusion_class.alias == type:  # type: ignore
                self.diffusion = diffusion_class(net=net, **kwargs)
                break

    def forward(self, *args, **kwargs) -> Tensor:
        return self.diffusion(*args, **kwargs)

    def sample(
        self,
        noise: Tensor,
        num_steps: int,
        sigma_schedule: Schedule,
        sampler: Sampler,
        clamp: bool,
        **kwargs,
    ) -> Tensor:
        """Build a `DiffusionSampler` and sample from `noise`; extra `kwargs`
        are forwarded to the denoise function."""
        diffusion_sampler = DiffusionSampler(
            diffusion=self.diffusion,
            sampler=sampler,
            sigma_schedule=sigma_schedule,
            num_steps=num_steps,
            clamp=clamp,
        )
        return diffusion_sampler(noise, **kwargs)