from abc import ABC, abstractmethod

import torch
from diffusers import DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor

class MySchedulers(ABC):
    """Minimal scheduler interface: compute the previous (less noisy) sample from the current one."""

    def __init__(self) -> None:
        pass

    @abstractmethod
    def pred_prev(self, noisy_images, noisy_residual, timestep, timestep_prev):
        raise NotImplementedError("not implemented")

def get_alpha(alphas_cumprod, timestep):
    # Look up alpha_bar_t for a batch of (possibly out-of-range) integer timesteps:
    # t < 0 maps to alpha_bar = 1 (clean sample), t > 999 maps to alpha_bar = 0 (pure noise).
    timestep_lt_zero_mask = torch.lt(timestep, 0).to(alphas_cumprod.dtype)
    timestep_gt_999_mask = torch.gt(timestep, 999).to(alphas_cumprod.dtype)
    normal_alpha = alphas_cumprod[torch.clip(timestep, 0, 999)]
    one_alpha = torch.ones_like(normal_alpha)
    zero_alpha = torch.zeros_like(normal_alpha)
    return (normal_alpha * (1 - timestep_lt_zero_mask) + one_alpha * timestep_lt_zero_mask) * (1 - timestep_gt_999_mask) + zero_alpha * timestep_gt_999_mask
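
# Illustration (an added sketch, not in the original file): with a freshly constructed
# DDPMScheduler, get_alpha clamps in-range timesteps and applies the out-of-range
# conventions used by pred_prev below, e.g.
#
#   sched = DDPMScheduler()
#   t = torch.tensor([-1, 0, 500, 1000])
#   get_alpha(sched.alphas_cumprod, t)
#   # -> [1.0, alphas_cumprod[0], alphas_cumprod[500], 0.0]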

class MyDDIM(MySchedulers):
    def __init__(self, ddpm_or_ddim_scheduler, normnoise=False) -> None:
        super(MyDDIM, self).__init__()
        assert isinstance(ddpm_or_ddim_scheduler, DDPMScheduler) or isinstance(ddpm_or_ddim_scheduler, DDIMScheduler)
        self.alphas_cumprod = ddpm_or_ddim_scheduler.alphas_cumprod
        self.normnoise = normnoise
        self.prediction_type = ddpm_or_ddim_scheduler.config.prediction_type

    def pred_prev(self, noisy_images, model_output, timestep, timestep_prev):
        torch_dtype = model_output.dtype
        #noisy_images = noisy_images.to(torch.float32)
        #model_output = model_output.to(torch.float32)
        alphas_cumprod = self.alphas_cumprod.to(noisy_images.device)  #.to(noisy_images.dtype)
        alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1)
        beta_prod_t = 1 - alpha_prod_t
        if self.prediction_type == "epsilon":
            if self.normnoise:
                # renormalize the predicted noise to (approximately) unit std per sample
                model_output = model_output / (torch.std(model_output, dim=(1, 2, 3), keepdim=True) + 0.0001)
            pred_original_sample = (noisy_images - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (noisy_images - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * noisy_images - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * noisy_images
        else:
            raise ValueError(f"unsupported prediction_type: {self.prediction_type}")
        alpha_prod_t_prev = get_alpha(alphas_cumprod, timestep_prev).view(-1, 1, 1, 1, 1)
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # deterministic DDIM update (formula (12) of https://arxiv.org/pdf/2010.02502.pdf with eta = 0):
        # x_{t_prev} = sqrt(alpha_bar_{t_prev}) * x0_pred + sqrt(1 - alpha_bar_{t_prev}) * eps_pred
        prev_sample = (alpha_prod_t_prev ** (0.5)) * pred_original_sample + beta_prod_t_prev ** (0.5) * pred_epsilon
        return prev_sample.to(torch_dtype), pred_original_sample.to(torch_dtype)

    def pred_prev_train(self, noisy_images, noisy_residual, timestep, timestep_prev):
        # Same deterministic DDIM step as pred_prev (epsilon parameterization only), but the
        # alpha terms are detached so gradients only flow through the model prediction.
        torch_dtype = noisy_residual.dtype
        #noisy_images = noisy_images.to(torch.float32)
        #noisy_residual = noisy_residual.to(torch.float32)
        if self.normnoise:
            noisy_residual = noisy_residual / (torch.std(noisy_residual, dim=(1, 2, 3), keepdim=True) + 0.0001)
        alphas_cumprod = self.alphas_cumprod.to(noisy_images.device)  #.to(noisy_images.dtype)
        alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1).to(torch_dtype).detach()
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (noisy_images - beta_prod_t ** (0.5) * noisy_residual) / alpha_prod_t ** (0.5)
        pred_epsilon = noisy_residual
        alpha_prod_t_prev = get_alpha(alphas_cumprod, timestep_prev).view(-1, 1, 1, 1, 1).to(torch_dtype).detach()
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = beta_prod_t_prev ** (0.5) * pred_epsilon
        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = (alpha_prod_t_prev ** (0.5)) * pred_original_sample + pred_sample_direction
        assert prev_sample.dtype == torch_dtype
        assert pred_original_sample.dtype == torch_dtype
        return prev_sample, pred_original_sample
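
    # Hedged usage note (an addition, not part of the original file): in a training step one
    # would typically call pred_prev_train with the model's epsilon prediction so that gradients
    # flow into `noisy_residual`, e.g. (all names here are illustrative only):
    #
    #   prev_pred, x0_pred = stepper.pred_prev_train(noisy_latents, unet_out, t, t - 20)
    #   loss = torch.nn.functional.mse_loss(prev_pred, target_prev_latents)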

    def add_more_noise(self, noisy_latents, noise, timestep, timestep_more):
        # Push an already-noisy latent from noise level `timestep` to the higher level
        # `timestep_more` without access to the clean sample or the original noise.
        alphas_cumprod = self.alphas_cumprod.to(noisy_latents.device)
        alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1)
        alpha_prod_t_more = get_alpha(alphas_cumprod, timestep_more).view(-1, 1, 1, 1, 1)
        sqrt_alpha_prod = alpha_prod_t ** (0.5)
        sqrt_one_minus_alpha_prod = (1 - alpha_prod_t) ** (0.5)
        sqrt_alpha_prod_more = alpha_prod_t_more ** (0.5)
        sqrt_one_minus_alpha_prod_more = (1 - alpha_prod_t_more) ** (0.5)
        # noisy_latents = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise_p, so
        # noisy_latents * (sqrt_alpha_prod_more / sqrt_alpha_prod)
        #   = sqrt_alpha_prod_more * original_samples + sqrt_one_minus_alpha_prod * noise_p * (sqrt_alpha_prod_more / sqrt_alpha_prod)
        # Adding fresh noise with coefficient noise_coe raises the total noise variance to
        # 1 - alpha_prod_t_more (independent Gaussians add in variance), matching the forward
        # process at timestep_more.
        noise_coe = ((sqrt_one_minus_alpha_prod_more ** 2) - (sqrt_one_minus_alpha_prod * sqrt_alpha_prod_more / sqrt_alpha_prod) ** 2) ** (0.5)
        return noisy_latents * (sqrt_alpha_prod_more / sqrt_alpha_prod) + noise_coe * noise
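
# Hedged usage sketch (an addition, not part of the original file): a minimal deterministic
# DDIM-style sampling loop driven by MyDDIM.pred_prev. The `model` callable and the 5D latent
# shape (batch, channels, frames, height, width) are assumptions for illustration only.
def sample_with_myddim(model, ddpm_scheduler, num_steps=50, shape=(1, 4, 1, 64, 64), device="cpu"):
    stepper = MyDDIM(ddpm_scheduler)
    latents = torch.randn(shape, device=device)
    # coarse grid of timesteps from 999 down to -1 (get_alpha treats t < 0 as alpha_bar = 1)
    timesteps = torch.linspace(999, -1, num_steps + 1).long().to(device)
    for t, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_batch = t.repeat(shape[0])
        t_prev_batch = t_prev.repeat(shape[0])
        with torch.no_grad():
            # assumed to return a prediction matching the scheduler's prediction_type
            model_output = model(latents, t_batch)
        latents, _ = stepper.pred_prev(latents, model_output, t_batch, t_prev_batch)
    return latents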

def get_sigma(sigmas, timestep):
    # Look up sigma_t for a batch of integer timesteps; t < 0 maps to sigma = 0 (no noise).
    timestep_lt_zero_mask = torch.lt(timestep, 0).to(sigmas.dtype)
    normal_sigma = sigmas[torch.clip(timestep, 0)]
    zero_sigma = torch.zeros_like(normal_sigma)
    return normal_sigma * (1 - timestep_lt_zero_mask) + zero_sigma * timestep_lt_zero_mask
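
# Background note (added): sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) is the
# k-diffusion-style noise scale. Since
#   x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps = (x0 + sigma_t * eps) / sqrt(sigma_t**2 + 1),
# MyEulerA.pred_prev below first rescales the variance-preserving latent by sqrt(sigma_t**2 + 1)
# to obtain the "sample" x0 + sigma_t * eps, applies an Euler-ancestral update, and divides by
# sqrt(sigma_prev**2 + 1) at the end to return to the variance-preserving parameterization.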

class MyEulerA(MySchedulers):
    def __init__(self, eulera_scheduler, normnoise=False) -> None:
        super(MyEulerA, self).__init__()
        assert isinstance(eulera_scheduler, EulerAncestralDiscreteScheduler)
        self.sigmas = ((1 - eulera_scheduler.alphas_cumprod) / eulera_scheduler.alphas_cumprod) ** 0.5
        assert len(self.sigmas) == 1000
        self.generator = None
        self.normnoise = normnoise

    def pred_prev(self, noisy_images, noisy_residual, timestep, timestep_prev):
        torch_dtype = noisy_residual.dtype
        noisy_images = noisy_images.to(torch.float32)
        noisy_residual = noisy_residual.to(torch.float32)
        if self.normnoise:
            noisy_residual = noisy_residual / (torch.std(noisy_residual, dim=(1, 2, 3), keepdim=True) + 0.0001)
        sigmas = self.sigmas.to(noisy_images.device)
        sigma_from = get_sigma(sigmas, timestep).view(-1, 1, 1, 1, 1)
        sigma_to = get_sigma(sigmas, timestep_prev).view(-1, 1, 1, 1, 1)
        # split the target noise level into a deterministic part (sigma_down) and a stochastic part (sigma_up)
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        # 1. Convert to the sigma parameterization: sample = x0 + sigma_from * eps
        sample = noisy_images * ((sigma_from**2 + 1) ** 0.5)
        pred_original_sample = sample - sigma_from * noisy_residual
        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_from
        dt = sigma_down - sigma_from
        prev_sample = sample + derivative * dt
        device = noisy_residual.device
        noise = randn_tensor(noisy_residual.shape, dtype=torch_dtype, device=device, generator=self.generator).to(noisy_residual.dtype)
        prev_sample = prev_sample + noise * sigma_up
        # 3. Back to the variance-preserving parameterization at the target noise level
        prev_sample = prev_sample / ((sigma_to**2 + 1) ** 0.5)
        return prev_sample.to(torch_dtype), pred_original_sample.to(torch_dtype)

if __name__ == "__main__":
    a = MyDDIM(DDPMScheduler())
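    # Hedged smoke test (an addition, not part of the original file): one DDIM step and one
    # Euler-ancestral step on random 5D latents, with random tensors standing in for real
    # model outputs, just to sanity-check shapes and dtypes.
    x = torch.randn(2, 4, 1, 8, 8)
    eps = torch.randn_like(x)
    t = torch.tensor([999, 500])
    t_prev = torch.tensor([979, 480])
    prev, x0 = a.pred_prev(x, eps, t, t_prev)
    print("MyDDIM:", prev.shape, x0.shape)
    b = MyEulerA(EulerAncestralDiscreteScheduler())
    prev_e, x0_e = b.pred_prev(x, eps, t, t_prev)
    print("MyEulerA:", prev_e.shape, x0_e.shape)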