|
|
|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
from functools import partial |
|
from diffusers_vdm.basics import extract_into_tensor |
|
|
|
|
|
# Shorthand: convert array-like schedule data to float32 torch tensors
# (used below when registering noise-schedule buffers).
to_torch = partial(torch.tensor, dtype=torch.float32)
|
|
|
|
|
def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the final timestep has zero terminal SNR.

    Works in sqrt(alpha_cumprod) space: the curve is shifted so its terminal
    value is exactly 0 and then scaled so its initial value is unchanged.
    The per-step betas are recovered from the adjusted cumulative products.

    Args:
        betas: 1-D numpy array of per-step betas.

    Returns:
        1-D numpy array of rescaled betas (same shape); the last beta is 1,
        i.e. the cumulative alpha product ends at exactly 0.
    """
    cumprod_alpha = np.cumprod(1.0 - betas, axis=0)
    sqrt_bar = np.sqrt(cumprod_alpha)

    # Remember the curve's endpoints before modifying it in place.
    first = sqrt_bar[0].copy()
    last = sqrt_bar[-1].copy()

    # Shift so the terminal value hits exactly zero, then scale so the
    # initial value is restored to its original magnitude.
    sqrt_bar -= last
    sqrt_bar *= first / (first - last)

    # Convert the adjusted cumulative products back to per-step alphas:
    # alpha_t = bar_t / bar_{t-1}, with alpha_0 = bar_0.
    bar = np.square(sqrt_bar)
    step_alphas = np.concatenate([bar[0:1], bar[1:] / bar[:-1]])

    return 1 - step_alphas
|
|
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Rescale a CFG prediction toward the std of the text-only prediction.

    Classifier-free guidance inflates the per-sample standard deviation of
    the combined prediction; this shrinks it back toward the conditional
    branch's std, blended by `guidance_rescale` (0 = no change, 1 = fully
    std-matched).

    Args:
        noise_cfg: combined (guided) prediction.
        noise_pred_text: conditional (positive-prompt) prediction.
        guidance_rescale: blend factor in [0, 1].

    Returns:
        Tensor of the same shape as `noise_cfg`.
    """
    # Per-sample std over all non-batch dimensions.
    reduce_dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=reduce_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # Version of the guided prediction whose std matches the text branch.
    matched = noise_cfg * (std_text / std_cfg)

    # Linear blend between the std-matched and the raw guided prediction.
    return guidance_rescale * matched + (1 - guidance_rescale) * noise_cfg
|
|
|
|
|
class SamplerDynamicTSNR(torch.nn.Module):
    """DDIM-style sampler using v-prediction, a zero-terminal-SNR beta
    schedule, and a per-timestep "dynamic" x0 rescale.

    The noise schedule is a squared-linspace ("scaled linear") beta schedule
    rescaled with `rescale_zero_terminal_snr` so the cumulative alpha product
    ends at exactly zero. `scale_arr` ramps linearly from 1.0 down to
    `terminal_scale` over the first 400 timesteps and stays constant after
    that; `forward` divides successive scales out of the predicted x0, and
    `get_ground_truth` applies the scale when building training targets.
    """

    @torch.no_grad()  # schedule construction only; no autograd graph needed
    def __init__(self, unet, terminal_scale=0.7):
        super().__init__()
        self.unet = unet

        # v-parameterization is always on in this sampler; the eps branches in
        # forward() are kept for the (unused here) is_v = False case.
        self.is_v = True
        self.n_timestep = 1000
        self.guidance_rescale = 0.7  # blend factor passed to rescale_noise_cfg

        # Endpoints of the "scaled linear" schedule (linspace in sqrt space).
        linear_start = 0.00085
        linear_end = 0.012

        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
        # Force zero terminal SNR: alphas_cumprod[-1] becomes exactly 0.
        betas = rescale_zero_terminal_snr(betas)
        alphas = 1. - betas

        alphas_cumprod = np.cumprod(alphas, axis=0)

        # Buffers so the schedule follows the module through .to()/state_dict.
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))

        # Dynamic x0 scale: linear ramp 1.0 -> terminal_scale over the first
        # `turning_step` timesteps, constant at terminal_scale afterwards.
        turning_step = 400
        scale_arr = np.concatenate([
            np.linspace(1.0, terminal_scale, turning_step),
            np.full(self.n_timestep - turning_step, terminal_scale)
        ])
        self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))

    def predict_eps_from_z_and_v(self, x_t, t, v):
        # v-parameterization identity:
        #   eps = sqrt(alpha_bar_t) * v + sqrt(1 - alpha_bar_t) * x_t
        # NOTE: `t` is indexed directly into the 1-D buffers, so callers pass
        # a scalar timestep (forward() uses a Python int).
        return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t

    def predict_start_from_z_and_v(self, x_t, t, v):
        # v-parameterization identity:
        #   x0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
        return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v

    def q_sample(self, x0, t, noise):
        # Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.
        # Unlike predict_* above, `t` here is a batched index tensor;
        # extract_into_tensor broadcasts the coefficients to x0's shape.
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    def get_v(self, x0, t, noise):
        # Training target for v-prediction:
        #   v = sqrt(abar_t) * noise - sqrt(1 - abar_t) * x0
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)

    def dynamic_x0_rescale(self, x0, t):
        # Apply the per-timestep TSNR scale to a clean latent (batched `t`).
        return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)

    @torch.no_grad()
    def get_ground_truth(self, x0, noise, t):
        """Build a (noisy input, training target) pair for timestep batch `t`.

        The clean latent is first rescaled by `scale_arr[t]`, then diffused;
        the target is the v-target when `is_v`, otherwise the raw noise.
        """
        x0 = self.dynamic_x0_rescale(x0, t)
        xt = self.q_sample(x0, t, noise)
        target = self.get_v(x0, t, noise) if self.is_v else noise
        return xt, target

    def get_uniform_trailing_steps(self, steps):
        """Uniformly spaced "trailing" timesteps for `steps` sampling steps.

        Counts down from n_timestep in increments of n_timestep/steps, so the
        final timestep (n_timestep - 1, after the -1 shift to 0-based indices)
        is always included — required for a zero-terminal-SNR schedule.
        Returns an ascending int64 tensor of length `steps` on the UNet device.
        """
        c = self.n_timestep / steps
        ddim_timesteps = np.flip(np.round(np.arange(self.n_timestep, 0, -c))).astype(np.int64)
        steps_out = ddim_timesteps - 1  # shift from 1-based to 0-based indices
        return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)

    @torch.no_grad()
    def forward(self, latent_shape, steps, extra_args, progress_tqdm=None):
        """Sample a latent of `latent_shape` with `steps` DDIM iterations.

        `extra_args` must carry 'cfg_scale', 'positive', and 'negative'
        (see model_apply). With eta = 1.0 the update adds the full DDPM-level
        stochastic noise at each step. Returns the final latent tensor.
        """
        bar = tqdm if progress_tqdm is None else progress_tqdm

        eta = 1.0  # fully stochastic (DDPM-like) variant of the DDIM update

        timesteps = self.get_uniform_trailing_steps(steps)
        # Previous-step indices: shift right by one, prepend timestep 0.
        timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))

        x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)

        # Per-step schedule values gathered once up front.
        alphas = self.alphas_cumprod[timesteps]
        alphas_prev = self.alphas_cumprod[timesteps_prev]
        scale_arr = self.scale_arr[timesteps]
        scale_arr_prev = self.scale_arr[timesteps_prev]

        sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
        # DDIM sigma_t for the given eta.
        # NOTE(review): this deliberately mixes numpy arrays and CPU torch
        # tensors and relies on torch/numpy operator interop; the result is
        # only ever read back via sigmas[index].item(), so device/type of the
        # intermediate does not matter — but it is fragile. Verify before
        # touching.
        sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

        # Broadcast helpers: per-batch scalar (s_in) and per-batch
        # singleton-shaped (s_x) ones.
        s_in = x.new_ones((x.shape[0]))
        s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
        for i in bar(range(len(timesteps))):
            # timesteps is ascending; walk it from the end (high noise -> low).
            index = len(timesteps) - 1 - i
            t = timesteps[index].item()

            model_output = self.model_apply(x, t * s_in, **extra_args)

            if self.is_v:
                e_t = self.predict_eps_from_z_and_v(x, t, model_output)
            else:
                e_t = model_output

            a_prev = alphas_prev[index].item() * s_x
            sigma_t = sigmas[index].item() * s_x

            if self.is_v:
                pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
            else:
                # eps-parameterization recovery of x0.
                a_t = alphas[index].item() * s_x
                sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
                pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

            # Move pred_x0 from this step's TSNR scale to the previous step's
            # scale (undoes the ramp incrementally as sampling proceeds).
            scale_t = scale_arr[index].item() * s_x
            prev_scale_t = scale_arr_prev[index].item() * s_x
            rescale = (prev_scale_t / scale_t)
            pred_x0 = pred_x0 * rescale

            # Standard DDIM update: deterministic direction + sigma-scaled noise.
            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
            noise = sigma_t * torch.randn_like(x)
            x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x

    @torch.no_grad()
    def model_apply(self, x, t, **extra_args):
        """One guided UNet evaluation.

        Runs the UNet on the positive and negative conditionings from
        `extra_args`, combines them with classifier-free guidance at
        `extra_args['cfg_scale']`, then applies the guidance-rescale trick
        with `self.guidance_rescale`.
        """
        x = x.to(device=self.unet.device, dtype=self.unet.dtype)
        cfg_scale = extra_args['cfg_scale']
        p = self.unet(x, t, **extra_args['positive'])
        n = self.unet(x, t, **extra_args['negative'])
        o = n + cfg_scale * (p - n)  # standard CFG combination
        o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
        return o_better
|
|