# 3dart / imagedream / ldm / interface.py
# File-viewer header residue: gusgeneris's picture | First | c24da45 | raw | history blame | 7.51 kB
from typing import List
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from .modules.diffusionmodules.util import (
make_beta_schedule,
extract_into_tensor,
enforce_zero_terminal_snr,
noise_like,
)
from .util import exists, default, instantiate_from_config
from .modules.distributions.distributions import DiagonalGaussianDistribution
class DiffusionWrapper(nn.Module):
    """Minimal ``nn.Module`` shell around a diffusion model.

    The wrapped object is stored as ``diffusion_model`` and every forward
    call is delegated to it without touching the arguments or the result.
    """

    def __init__(self, diffusion_model):
        """Store ``diffusion_model`` on the module after base initialisation."""
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        """Invoke the wrapped model with exactly the arguments received."""
        wrapped = self.diffusion_model
        return wrapped(*args, **kwargs)
class LatentDiffusionInterface(nn.Module):
    """Inference-only facade over the parts of a latent diffusion model.

    The instance owns three sub-models built from config dictionaries (the
    denoiser, wrapped in ``DiffusionWrapper`` and stored as ``model``, plus
    ``clip_model`` and ``vae_model``) together with the noise-schedule
    buffers created by :meth:`register_schedule`.  The remaining methods are
    small helpers for noising, converting between parameterisations,
    computing conditioning and moving in and out of the latent space.
    """

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        """Build the sub-models and register the noise schedule.

        Args:
            unet_config: config passed to ``instantiate_from_config`` for the
                denoiser; the result is wrapped in ``DiffusionWrapper``.
            clip_config: config for the model stored as ``clip_model``.
            vae_config: config for the model stored as ``vae_model``.
            parameterization: stored on the instance as-is; this class does
                not branch on it.
            scale_factor: latent scaling used by
                :meth:`get_first_stage_encoding` (multiply) and
                :meth:`decode_first_stage` (divide).
            beta_schedule, timesteps, linear_start, linear_end, cosine_s,
            given_betas, zero_snr: forwarded unchanged to
                :meth:`register_schedule`.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()

        # Sub-models are created in this order so that module registration
        # order on the instance is: model, clip_model, vae_model.
        self.model = DiffusionWrapper(instantiate_from_config(unet_config))
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization
        self.scale_factor = scale_factor

        schedule_options = {
            "given_betas": given_betas,
            "beta_schedule": beta_schedule,
            "timesteps": timesteps,
            "linear_start": linear_start,
            "linear_end": linear_end,
            "cosine_s": cosine_s,
            "zero_snr": zero_snr,
        }
        self.register_schedule(**schedule_options)

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
    ):
        """Compute the diffusion schedule and store it as float32 buffers.

        ``given_betas`` is used directly when it is provided (as decided by
        ``exists``); otherwise the betas come from ``make_beta_schedule``.
        When ``zero_snr`` is truthy the betas are additionally passed through
        ``enforce_zero_terminal_snr`` and converted back with ``.numpy()``.

        Side effects: sets ``num_timesteps``, ``linear_start``,
        ``linear_end`` and ``v_posterior`` and registers the buffers
        ``betas``, ``alphas_cumprod``, ``alphas_cumprod_prev``,
        ``sqrt_alphas_cumprod``, ``sqrt_one_minus_alphas_cumprod``,
        ``log_one_minus_alphas_cumprod``, ``sqrt_recip_alphas_cumprod``,
        ``sqrt_recipm1_alphas_cumprod``, ``posterior_variance``,
        ``posterior_log_variance_clipped``, ``posterior_mean_coef1`` and
        ``posterior_mean_coef2`` (in that order).
        """
        if not exists(given_betas):
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        else:
            betas = given_betas

        if zero_snr:
            print("--- using zero snr---")
            betas = enforce_zero_terminal_snr(betas).numpy()

        alphas = 1.0 - betas
        alpha_bar = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 in front.
        alpha_bar_prev = np.append(1.0, alpha_bar[:-1])

        # The schedule must be one-dimensional; its length is the step count.
        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alpha_bar.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        def add_buffer(name, values):
            # Every schedule quantity is stored as a float32 tensor buffer.
            self.register_buffer(name, torch.tensor(values, dtype=torch.float32))

        one_minus_alpha_bar = 1.0 - alpha_bar

        add_buffer("betas", betas)
        add_buffer("alphas_cumprod", alpha_bar)
        add_buffer("alphas_cumprod_prev", alpha_bar_prev)

        # Quantities for the forward process q(x_t | x_{t-1}) and related terms.
        add_buffer("sqrt_alphas_cumprod", np.sqrt(alpha_bar))
        add_buffer("sqrt_one_minus_alphas_cumprod", np.sqrt(one_minus_alpha_bar))
        add_buffer("log_one_minus_alphas_cumprod", np.log(one_minus_alpha_bar))

        # Small offset so the reciprocal never divides by zero.
        eps = 1e-8
        recip_alpha_bar = 1.0 / (alpha_bar + eps)
        add_buffer("sqrt_recip_alphas_cumprod", np.sqrt(recip_alpha_bar))
        add_buffer("sqrt_recipm1_alphas_cumprod", np.sqrt(recip_alpha_bar - 1))

        # Quantities for the posterior q(x_{t-1} | x_t, x_0).
        self.v_posterior = 0
        mix = self.v_posterior
        one_minus_alpha_bar_prev = 1.0 - alpha_bar_prev
        # With mix == 0 this equals
        # 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t).
        posterior_variance = (
            (1 - mix) * betas * one_minus_alpha_bar_prev / one_minus_alpha_bar
            + mix * betas
        )
        add_buffer("posterior_variance", posterior_variance)

        # The posterior variance is 0 at the first step of the chain, so it is
        # floored before taking the logarithm.
        add_buffer(
            "posterior_log_variance_clipped",
            np.log(np.maximum(posterior_variance, 1e-20)),
        )
        add_buffer(
            "posterior_mean_coef1",
            betas * np.sqrt(alpha_bar_prev) / one_minus_alpha_bar,
        )
        add_buffer(
            "posterior_mean_coef2",
            one_minus_alpha_bar_prev * np.sqrt(alphas) / one_minus_alpha_bar,
        )

    def _coef(self, schedule, t, like):
        """Look up ``schedule`` at ``t`` via ``extract_into_tensor`` for ``like.shape``."""
        return extract_into_tensor(schedule, t, like.shape)

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.

        When ``noise`` is not supplied, ``default`` falls back to
        ``torch.randn_like(x_start)``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_part = self._coef(self.sqrt_alphas_cumprod, t, x_start) * x_start
        noise_part = self._coef(self.sqrt_one_minus_alphas_cumprod, t, x_start) * noise
        return signal_part + noise_part

    def get_v(self, x, noise, t):
        """Return ``sqrt_alphas_cumprod[t] * noise - sqrt_one_minus_alphas_cumprod[t] * x``."""
        noise_part = self._coef(self.sqrt_alphas_cumprod, t, x) * noise
        data_part = self._coef(self.sqrt_one_minus_alphas_cumprod, t, x) * x
        return noise_part - data_part

    def predict_start_from_noise(self, x_t, t, noise):
        """Return ``sqrt_recip_alphas_cumprod[t] * x_t - sqrt_recipm1_alphas_cumprod[t] * noise``."""
        scaled_sample = self._coef(self.sqrt_recip_alphas_cumprod, t, x_t) * x_t
        scaled_noise = self._coef(self.sqrt_recipm1_alphas_cumprod, t, x_t) * noise
        return scaled_sample - scaled_noise

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt_alphas_cumprod[t] * x_t - sqrt_one_minus_alphas_cumprod[t] * v``."""
        sample_part = self._coef(self.sqrt_alphas_cumprod, t, x_t) * x_t
        v_part = self._coef(self.sqrt_one_minus_alphas_cumprod, t, x_t) * v
        return sample_part - v_part

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Return ``sqrt_alphas_cumprod[t] * v + sqrt_one_minus_alphas_cumprod[t] * x_t``."""
        v_part = self._coef(self.sqrt_alphas_cumprod, t, x_t) * v
        sample_part = self._coef(self.sqrt_one_minus_alphas_cumprod, t, x_t) * x_t
        return v_part + sample_part

    def apply_model(self, x_noisy, t, cond, **kwargs):
        """Run ``self.model`` on ``(x_noisy, t)`` with ``cond`` and ``kwargs`` as keywords.

        ``cond`` must be a ``dict``; its entries are expanded into keyword
        arguments alongside ``kwargs``.
        """
        assert isinstance(cond, dict), "cond has to be a dictionary"
        denoiser = self.model
        return denoiser(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        """Return ``self.clip_model(prompts)``."""
        text_encoder = self.clip_model
        return text_encoder(prompts)

    def get_learned_image_conditioning(self, images):
        """Return ``self.clip_model.forward_image(images)``."""
        image_encoder = self.clip_model.forward_image
        return image_encoder(images)

    def get_first_stage_encoding(self, encoder_posterior):
        """Turn an encoder output into a latent multiplied by ``scale_factor``.

        A ``DiagonalGaussianDistribution`` is sampled; a ``torch.Tensor`` is
        used as-is.

        Raises:
            NotImplementedError: for any other input type.
        """
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            latent = encoder_posterior.sample()
            return self.scale_factor * latent
        if isinstance(encoder_posterior, torch.Tensor):
            latent = encoder_posterior
            return self.scale_factor * latent
        raise NotImplementedError(
            f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
        )

    def encode_first_stage(self, x):
        """Return ``self.vae_model.encode(x)``."""
        autoencoder = self.vae_model
        return autoencoder.encode(x)

    def decode_first_stage(self, z):
        """Divide ``z`` by ``scale_factor`` and return ``self.vae_model.decode`` of the result."""
        unscaled = 1.0 / self.scale_factor * z
        return self.vae_model.decode(unscaled)