from typing import List
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from .modules.diffusionmodules.util import (
    make_beta_schedule,
    extract_into_tensor,
    enforce_zero_terminal_snr,
    noise_like,
)
from .util import exists, default, instantiate_from_config
from .modules.distributions.distributions import DiagonalGaussianDistribution


class DiffusionWrapper(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards every call to the wrapped model."""

    def __init__(self, diffusion_model):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its output."""
        return self.diffusion_model(*args, **kwargs)


class LatentDiffusionInterface(nn.Module):
    """A simple interface class for LDM inference.

    Instantiates three sub-models from their configs (a UNet wrapped in
    ``DiffusionWrapper``, a CLIP model and a VAE) and registers the diffusion
    schedule constants as buffers on this module.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Build the sub-models and register the noise schedule.

        Args:
            unet_config: config passed to ``instantiate_from_config`` for the UNet.
            clip_config: config passed to ``instantiate_from_config`` for the
                conditioning (CLIP) model.
            vae_config: config passed to ``instantiate_from_config`` for the VAE.
            parameterization: stored on ``self.parameterization``; not otherwise
                used in this class.
            scale_factor: multiplier applied to latents after encoding and
                divided out before decoding.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s,
            given_betas, zero_snr: forwarded to ``register_schedule``.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()

        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False
    ):
        """Compute the diffusion schedule and register it as float32 buffers.

        Args:
            given_betas: optional explicit 1-D beta schedule. Any array-like
                accepted by ``np.asarray`` works (ndarray, list, tuple, ...).
                When provided, ``beta_schedule``/``timesteps``/``linear_*``/
                ``cosine_s`` are not used to build the betas.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                forwarded to ``make_beta_schedule`` when ``given_betas`` is absent.
            zero_snr: when true, the betas are passed through
                ``enforce_zero_terminal_snr`` before anything is derived from them.

        Side effects:
            Sets ``num_timesteps``, ``linear_start``, ``linear_end`` and
            ``v_posterior`` on ``self`` and registers the buffers named below.
        """
        if exists(given_betas):
            # Coerce to ndarray so lists/tuples work with the array arithmetic
            # and the ``.shape`` unpacking below; an ndarray is returned as-is.
            betas = np.asarray(given_betas)
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            # NOTE(review): the helper's result is converted with ``.numpy()``,
            # so it presumably returns a torch tensor — confirm against the helper.
            betas = enforce_zero_terminal_snr(betas).numpy()
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 for the first step.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # The actual number of steps is taken from the betas, overriding the
        # ``timesteps`` argument (relevant when ``given_betas`` is supplied).
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        eps = 1e-8  # small epsilon to avoid a divide-by-zero error when alphas_cumprod is 0
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps)))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1)),
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # v_posterior is fixed to 0 here, so only the first term below contributes.
        self.v_posterior = 0
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    # NOTE(review): in the methods below ``extract_into_tensor(buf, t, shape)``
    # presumably gathers ``buf[t]`` per sample and reshapes it to broadcast
    # against ``shape`` — confirm against the helper's definition.

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.

        When ``noise`` is not supplied, ``default`` falls back to
        ``torch.randn_like(x_start)``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """Return ``sqrt_alphas_cumprod[t] * noise - sqrt_one_minus_alphas_cumprod[t] * x``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Return ``sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise``."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt_alphas_cumprod[t] * v + sqrt_one_minus_alphas_cumprod[t] * x_t``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run the wrapped UNet on ``x_noisy`` at timestep ``t``.

        ``cond`` must be a dict; its items are expanded into keyword arguments
        of the model call together with ``kwargs``.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Return the output of calling the CLIP model on ``prompts``."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Return the output of the CLIP model's ``forward_image`` on ``images``."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a scaled latent.

        A ``DiagonalGaussianDistribution`` is sampled; a ``torch.Tensor`` is
        used as-is. The result is multiplied by ``self.scale_factor``.

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Return the VAE's ``encode`` output for ``x`` (not scaled here)."""
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Divide ``z`` by ``self.scale_factor`` and return the VAE's ``decode`` output."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)