| """ | |
| wild mixture of | |
| https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py | |
| https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py | |
| https://github.com/CompVis/taming-transformers | |
| -- merci | |
| """ | |
import torch
import torch.nn as nn
import numpy as np
from contextlib import contextmanager
from functools import partial
from refnet.util import default, count_params, instantiate_from_config, exists
from refnet.ldm.util import make_beta_schedule, extract_into_tensor


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


def rescale_zero_terminal_snr(betas):
    """
    Rescale betas to have zero terminal SNR, based on
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`): the betas the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR.
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt ** 2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas
    return betas
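

# A minimal self-check sketch (hypothetical helper, not part of the original
# file; assumes a plain linear schedule purely for illustration): after
# rescaling, the terminal alpha_bar is exactly zero, i.e. the last timestep
# carries no signal (zero terminal SNR).
def _check_zero_terminal_snr():
    betas = torch.linspace(1e-4, 2e-2, 1000)
    betas = rescale_zero_terminal_snr(betas)
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)
    assert alphas_bar[-1] == 0.0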


class DDPM(nn.Module):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(
        self,
        unet_config,
        timesteps=1000,
        beta_schedule="scaled_linear",
        image_size=256,
        channels=3,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        v_posterior=0.,  # weight for choosing posterior variance as sigma = (1 - v) * beta_tilde + v * beta
        parameterization="eps",  # all assuming fixed variance schedules
        zero_snr=False,
        half_precision_dtype="float16",
        version="sdv1",
        *args,
        **kwargs
    ):
        super().__init__()
        assert parameterization in ["eps", "v"], "currently only supporting 'eps' and 'v'"
        assert half_precision_dtype in ["float16", "bfloat16"], \
            "half_precision_dtype must be 'float16' or 'bfloat16'; K-diffusion samplers do not support bfloat16, so float16 is the default"
        if zero_snr:
            assert parameterization == "v", 'Zero SNR is only available for "v-prediction" models.'

        self.is_sdxl = (version == "sdxl")
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.img_embedder = None
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.model = DiffusionWrapper(unet_config)
        count_params(self.model, verbose=True)
        self.v_posterior = v_posterior
        self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16
        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end,
                               cosine_s=cosine_s, zero_snr=zero_snr)

    def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
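
    # A minimal sketch (hypothetical helper, not in the original file) of how the
    # posterior buffers registered above are used: the mean of q(x_{t-1} | x_t, x_0)
    # is posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t.
    def q_posterior_mean(self, x_start, x_t, t):
        return (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )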

    @contextmanager
    def ema_scope(self, context=None):
        # expects `self.use_ema` and `self.model_ema` to be provided elsewhere
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def predict_start_from_z_and_v(self, x_t, t, v):
        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def add_noise(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype)

    def get_v(self, x, noise, t):
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )
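
    # A minimal sanity-check sketch (hypothetical helper, not in the original
    # file): since v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x_0 and
    # x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps, predict_start_from_z_and_v
    # recovers x_0 exactly (up to float error) from (x_t, v).
    def _check_v_round_trip(self, x0):
        eps = torch.randn_like(x0)
        t = torch.randint(0, self.num_timesteps, (x0.shape[0],), device=x0.device)
        x_t = self.add_noise(x0, t, eps)
        v = self.get_v(x0, eps, t)
        assert torch.allclose(self.predict_start_from_z_and_v(x_t, t, v), x0, atol=1e-4)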

    def normalize_timesteps(self, timesteps):
        # identity here; a hook that subclasses can override to rescale timesteps
        return timesteps


class LatentDiffusion(DDPM):
    """main class"""
    def __init__(
        self,
        first_stage_config,
        cond_stage_config,
        scale_factor=1.0,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        self.first_stage_model, self.cond_stage_model = map(
            lambda t: instantiate_from_config(t).eval().requires_grad_(False),
            (first_stage_config, cond_stage_config)
        )

    def get_first_stage_encoding(self, x):
        encoder_posterior = self.first_stage_model.encode(x)
        z = encoder_posterior.sample() * self.scale_factor
        return z.to(self.dtype).detach()

    def decode_first_stage(self, z):
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z.to(self.first_stage_model.dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        wd_emb, wd_logits = map(lambda t: t.detach() if exists(t) else None, self.img_embedder.encode(c, **kwargs))
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb
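
    # A minimal round-trip sketch (hypothetical helper, not in the original
    # file): encode images to scaled latents and decode back. Assumes `images`
    # is an NCHW tensor in [-1, 1] on the model's device.
    @torch.no_grad()
    def _latent_round_trip(self, images):
        z = self.get_first_stage_encoding(images)  # VAE posterior sample * scale_factor
        return self.decode_first_stage(z)          # back to image space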


class DiffusionWrapper(nn.Module):
    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        for k in cond:
            if k in ["context", "y", "concat"]:
                cond[k] = torch.cat(cond[k], 1)
        out = self.diffusion_model(x, t, **cond)
        return out
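

# A usage sketch of the conditioning convention above (all tensors hypothetical):
# "context", "y", and "concat" are each passed as a *list* of tensors that
# forward() concatenates along dim 1 before calling the inner U-Net, e.g. two
# text embeddings of shape (B, 77, 768) become one (B, 154, 768) context.
# eps = wrapper(x_noisy, t, context=[txt_emb, ref_emb])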