wyysf's picture
Upload 107 files
d758270 verified
raw
history blame
No virus
7.51 kB
from typing import List
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from .modules.diffusionmodules.util import (
make_beta_schedule,
extract_into_tensor,
enforce_zero_terminal_snr,
noise_like,
)
from .util import exists, default, instantiate_from_config
from .modules.distributions.distributions import DiagonalGaussianDistribution
class DiffusionWrapper(nn.Module):
    """Thin ``nn.Module`` shell that forwards every call to the wrapped model."""

    def __init__(self, diffusion_model):
        """Keep ``diffusion_model`` as the ``diffusion_model`` attribute."""
        super(DiffusionWrapper, self).__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its result."""
        wrapped = self.diffusion_model
        return wrapped(*args, **kwargs)
class LatentDiffusionInterface(nn.Module):
    """a simple interface class for LDM inference

    Holds three sub-models built with ``instantiate_from_config`` (the UNet,
    wrapped in ``DiffusionWrapper`` as ``self.model``; ``self.clip_model``;
    ``self.vae_model``) together with the precomputed noise-schedule buffers
    registered by :meth:`register_schedule`.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Instantiate the sub-models and register the diffusion schedule.

        Args:
            unet_config: config passed to ``instantiate_from_config`` for the UNet.
            clip_config: config passed to ``instantiate_from_config`` for
                ``self.clip_model``.
            vae_config: config passed to ``instantiate_from_config`` for
                ``self.vae_model``.
            parameterization: stored on the instance as-is; not otherwise
                used inside this class.
            scale_factor: multiplier applied to latents in
                :meth:`get_first_stage_encoding` and undone in
                :meth:`decode_first_stage`.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s,
            given_betas, zero_snr: forwarded to :meth:`register_schedule`.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()
        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)
        self.parameterization = parameterization
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
    ):
        """Compute the beta schedule and register all derived buffers.

        Args:
            given_betas: optional explicit 1-D beta sequence. May be a NumPy
                array, a list/tuple of floats, or a ``torch.Tensor``; when
                given, ``beta_schedule``/``timesteps``/``linear_*``/``cosine_s``
                are not used to build the schedule.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s:
                forwarded to ``make_beta_schedule`` when ``given_betas`` is None.
            zero_snr: when true, the betas are passed through
                ``enforce_zero_terminal_snr`` before anything is derived.

        Side effects:
            Sets ``num_timesteps``, ``linear_start``, ``linear_end`` and
            ``v_posterior`` and registers float32 buffers for the betas, the
            cumulative alpha products and the posterior coefficients.

        Raises:
            ValueError: if the betas are not one-dimensional (tuple unpacking
                of ``betas.shape``).
        """
        # NOTE(review): explicit None check; assumed equivalent to the
        # project's ``exists`` helper -- confirm against ``.util``.
        if given_betas is not None:
            # Normalise user-supplied schedules so lists, tuples and tensors
            # work with the NumPy arithmetic below, not only ndarrays.
            if isinstance(given_betas, torch.Tensor):
                given_betas = given_betas.detach().cpu().numpy()
            betas = np.asarray(given_betas)
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            betas = enforce_zero_terminal_snr(betas).numpy()

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        # The actual number of steps comes from the betas, not the argument.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        eps = 1e-8  # small epsilon to avoid a divide-by-zero error
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps)))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1)),
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.v_posterior = 0
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    def q_sample(self, x_start, t, noise=None):
        """Combine ``x_start`` and ``noise`` using the schedule coefficients at ``t``.

        Returns ``sqrt_alphas_cumprod * x_start +
        sqrt_one_minus_alphas_cumprod * noise``, with both coefficients
        obtained through ``extract_into_tensor(buffer, t, x_start.shape)``.
        ``noise`` falls back to ``torch.randn_like(x_start)`` via ``default``
        (presumably when it is None -- confirm against the helper).
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """Return ``sqrt_alphas_cumprod * noise - sqrt_one_minus_alphas_cumprod * x`` at ``t``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Return ``sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * noise`` at ``t``."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt_alphas_cumprod * x_t - sqrt_one_minus_alphas_cumprod * v`` at ``t``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt_alphas_cumprod * v + sqrt_one_minus_alphas_cumprod * x_t`` at ``t``."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run ``self.model`` on ``(x_noisy, t)`` with ``cond`` expanded as keywords.

        Raises:
            AssertionError: if ``cond`` is not a dict.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Return ``self.clip_model(prompts)``."""
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        """Return ``self.clip_model.forward_image(images)``."""
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a latent multiplied by ``scale_factor``.

        A ``DiagonalGaussianDistribution`` is sampled; a ``torch.Tensor`` is
        used as-is.

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        """Return ``self.vae_model.encode(x)``."""
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        """Divide ``z`` by ``scale_factor`` and return ``self.vae_model.decode`` of it."""
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)