"""Latent diffusion model (LDM) inference interface."""
from typing import List
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from .modules.diffusionmodules.util import (
    make_beta_schedule,
    extract_into_tensor,
    enforce_zero_terminal_snr,
    noise_like,
)
from .util import exists, default, instantiate_from_config
from .modules.distributions.distributions import DiagonalGaussianDistribution
class DiffusionWrapper(nn.Module):
    """Thin nn.Module shim that forwards every call to the wrapped model.

    Exists so the denoiser (typically a UNet — not verifiable from here)
    participates in this module's parameter/buffer registration and device
    moves while keeping the call signature fully transparent.
    """

    def __init__(self, diffusion_model):
        """Keep a reference to the underlying diffusion model."""
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        # Pure delegation: every positional and keyword argument passes
        # straight through to the wrapped model, and its result is returned
        # untouched.
        inner = self.diffusion_model
        return inner(*args, **kwargs)
class LatentDiffusionInterface(nn.Module):
    """A simple interface class for LDM inference.

    Bundles the three sub-models a latent-diffusion pipeline needs — the
    denoising UNet (via :class:`DiffusionWrapper`), a CLIP conditioning
    model, and a VAE first stage — and precomputes the DDPM noise schedule
    as registered buffers so device/dtype moves and state_dict handling
    cover them automatically.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Instantiate the sub-models from configs and build the schedule.

        Args:
            unet_config, clip_config, vae_config: configs consumed by the
                project helper ``instantiate_from_config``.
            parameterization: stored label for what the UNet predicts
                ("eps" by default); not consumed inside this class —
                presumably read by sampler code. TODO confirm with callers.
            scale_factor: latent scaling applied after VAE encoding and
                undone before decoding.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                forwarded to ``make_beta_schedule`` unless ``given_betas``
                is provided.
            given_betas: optional precomputed betas overriding the schedule
                parameters above.
            zero_snr: if True, rescale betas to enforce zero terminal SNR.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()
        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)
        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False
    ):
        """Precompute all DDPM schedule quantities and register them as buffers.

        NOTE(review): the numpy ops below assume ``betas`` behaves like a
        1-D numpy array; ``enforce_zero_terminal_snr`` evidently returns a
        torch tensor (hence the ``.numpy()`` call). Confirm that
        ``make_beta_schedule`` and any ``given_betas`` supply compatible
        array-likes.
        """
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            # Rescale the schedule so the terminal step is pure noise.
            betas = enforce_zero_terminal_snr(betas).numpy()
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)  # abar_t
        # abar_{t-1}, with abar for the step before t=0 defined as 1.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        # The effective step count comes from the betas actually used;
        # given_betas may differ in length from the `timesteps` argument.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"
        # All schedule tensors are stored as float32 buffers.
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        eps = 1e-8  # small epsilon to avoid divide-by-zero (abar can reach 0 with zero_snr)
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps)))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1))
        )
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        # Interpolation weight between the DDPM posterior variance and beta_t;
        # 0 keeps the standard DDPM posterior.
        self.v_posterior = 0
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: sample x_t ~ q(x_t | x_0).

        x_t = sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise; `noise`
        defaults to a standard Gaussian shaped like `x_start`.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """v-parameterization target: v = sqrt(abar_t)*noise - sqrt(1-abar_t)*x."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and predicted noise (eps parameterization)."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from x_t and predicted v (v parameterization)."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Recover the noise eps from x_t and predicted v (v parameterization)."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run the wrapped UNet on a noisy latent.

        `cond` must be a dict; its entries are splatted into the model call
        as keyword arguments. NOTE(review): `assert` is stripped under -O,
        so the dict check is advisory only.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Encode text prompts with the CLIP model."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Encode images via the CLIP model's image branch."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn a VAE encoder output into a scaled latent z.

        Accepts either a DiagonalGaussianDistribution (sampled) or an
        already-materialized tensor; anything else is rejected.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Encode an input with the VAE (unscaled; see get_first_stage_encoding)."""
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Undo the latent scaling, then decode with the VAE."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)