Spaces:
Runtime error
Runtime error
File size: 8,758 Bytes
c83dd81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
from abc import ABC, abstractmethod
from diffusers import DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler
import torch
from diffusers.utils.torch_utils import randn_tensor
from IPython import embed
import numpy as np
class MySchedulers(ABC):
    """Abstract base class for the custom diffusion schedulers in this file.

    Concrete subclasses (DDIM / Euler-ancestral variants) implement
    ``pred_prev`` to step a noisy latent from one timestep to an earlier one.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def pred_prev(self,
        noisy_images, noisy_residual, timestep, timestep_prev
    ):
        """Predict the sample at ``timestep_prev`` from the sample at ``timestep``."""
        raise NotImplementedError("not implemented")
def get_alpha(alphas_cumprod, timestep):
    """Look up cumulative alpha values for possibly out-of-range timesteps.

    Elementwise: returns ``alphas_cumprod[t]`` for ``t`` in ``[0, T-1]``,
    ``1.0`` for ``t < 0`` (a fully clean sample has alpha_bar = 1) and
    ``0.0`` for ``t > T-1`` (pure noise has alpha_bar = 0).

    Args:
        alphas_cumprod: 1-D tensor of cumulative alpha products (length T).
        timestep: integer tensor of timesteps; any shape.

    Returns:
        Tensor of alphas with the same shape as ``timestep`` and the dtype
        of ``alphas_cumprod``.
    """
    # Generalizes the original hard-coded 999 to any table length.
    t_max = alphas_cumprod.shape[0] - 1
    below_mask = torch.lt(timestep, 0).to(alphas_cumprod.dtype)
    above_mask = torch.gt(timestep, t_max).to(alphas_cumprod.dtype)
    normal_alpha = alphas_cumprod[torch.clip(timestep, 0, t_max)]
    # ones_like/zeros_like already carry normal_alpha's dtype; the original
    # chained two redundant .to(dtype) casts here.
    one_alpha = torch.ones_like(normal_alpha)
    zero_alpha = torch.zeros_like(normal_alpha)
    in_range_or_below = normal_alpha * (1 - below_mask) + one_alpha * below_mask
    return in_range_or_below * (1 - above_mask) + zero_alpha * above_mask
class MyDDIM(MySchedulers):
    """Deterministic DDIM stepping (eta = 0) driven by the ``alphas_cumprod``
    table of a diffusers DDPM/DDIM scheduler.

    The ``view(-1, 1, 1, 1, 1)`` reshapes indicate 5-D latents are expected
    (presumably (batch, channels, frames, height, width) — TODO confirm
    against callers).
    """

    def __init__(self, ddpm_or_ddim_scheduler, normnoise=False) -> None:
        super(MyDDIM, self).__init__()
        assert isinstance(ddpm_or_ddim_scheduler, (DDPMScheduler, DDIMScheduler))
        self.alphas_cumprod = ddpm_or_ddim_scheduler.alphas_cumprod
        # When True, predicted noise is renormalized to roughly unit std
        # per sample over dims (1, 2, 3) before being used.
        self.normnoise = normnoise
        self.prediction_type = ddpm_or_ddim_scheduler.config.prediction_type

    def pred_prev(self,
        noisy_images, model_output, timestep, timestep_prev
    ):
        """One DDIM step from ``timestep`` to ``timestep_prev`` with eta = 0.

        Args:
            noisy_images: noisy latents x_t.
            model_output: network prediction, interpreted per ``prediction_type``.
            timestep / timestep_prev: per-sample integer timesteps; values
                outside [0, 999] are handled by ``get_alpha``.

        Returns:
            ``(prev_sample, pred_original_sample)``, both cast back to the
            dtype of ``model_output``.

        Raises:
            ValueError: if ``prediction_type`` is not one of
                "epsilon" / "sample" / "v_prediction".
        """
        torch_dtype = model_output.dtype
        alphas_cumprod = self.alphas_cumprod.to(noisy_images.device)
        alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1)
        beta_prod_t = 1 - alpha_prod_t
        if self.prediction_type == "epsilon":
            if self.normnoise:
                model_output = model_output / (torch.std(model_output, dim=(1, 2, 3), keepdim=True) + 0.0001)
            pred_original_sample = (noisy_images - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (noisy_images - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * noisy_images - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * noisy_images
        else:
            # The original code fell through to a NameError here; fail loudly.
            raise ValueError(f"unsupported prediction_type: {self.prediction_type}")
        alpha_prod_t_prev = get_alpha(alphas_cumprod, timestep_prev).view(-1, 1, 1, 1, 1)
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # Formula (12) of https://arxiv.org/pdf/2010.02502.pdf with eta = 0:
        # x_{t-1} = sqrt(alpha_{t-1}) * x_0_hat + sqrt(1 - alpha_{t-1}) * eps_hat
        prev_sample = (alpha_prod_t_prev ** (0.5)) * pred_original_sample + beta_prod_t_prev ** (0.5) * pred_epsilon
        return prev_sample.to(torch_dtype), pred_original_sample.to(torch_dtype)

    def pred_prev_train(self,
        noisy_images, noisy_residual, timestep, timestep_prev
    ):
        """Training-time DDIM step, epsilon parametrization only.

        Unlike ``pred_prev`` the alpha terms are cast to the residual's dtype
        and detached so gradients flow only through ``noisy_residual``.
        Debug ``print`` calls from the original were removed.
        """
        torch_dtype = noisy_residual.dtype
        if self.normnoise:
            noisy_residual = noisy_residual / (torch.std(noisy_residual, dim=(1, 2, 3), keepdim=True) + 0.0001)
        alphas_cumprod = self.alphas_cumprod.to(noisy_images.device)
        alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1).to(torch_dtype).detach()
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (noisy_images - beta_prod_t ** (0.5) * noisy_residual) / alpha_prod_t ** (0.5)
        pred_epsilon = noisy_residual
        alpha_prod_t_prev = get_alpha(alphas_cumprod, timestep_prev).view(-1, 1, 1, 1, 1).to(torch_dtype).detach()
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # Formula (12) of https://arxiv.org/pdf/2010.02502.pdf, eta = 0.
        pred_sample_direction = beta_prod_t_prev ** (0.5) * pred_epsilon
        prev_sample = (alpha_prod_t_prev ** (0.5)) * pred_original_sample + pred_sample_direction
        assert prev_sample.dtype == torch_dtype
        assert pred_original_sample.dtype == torch_dtype
        return prev_sample, pred_original_sample

    def add_more_noise(self, noisy_latents, noise, timestep, timestep_more):
        """Move an already-noisy latent from ``timestep`` to the noisier
        ``timestep_more`` by rescaling and adding fresh ``noise``.

        Derivation: with x_t = sqrt(a_t) x_0 + sqrt(1-a_t) n, scaling x_t by
        sqrt(a_T)/sqrt(a_t) matches the signal term of x_T; ``noise_coe``
        tops the noise variance up to (1 - a_T).
        """
        alphas_cumprod = self.alphas_cumprod.to(noisy_latents.device)
        alpha_prod_t = get_alpha(alphas_cumprod, timestep).view(-1, 1, 1, 1, 1)
        alpha_prod_t_more = get_alpha(alphas_cumprod, timestep_more).view(-1, 1, 1, 1, 1)
        sqrt_alpha_prod = alpha_prod_t ** (0.5)
        sqrt_one_minus_alpha_prod = (1 - alpha_prod_t) ** (0.5)
        sqrt_alpha_prod_more = alpha_prod_t_more ** (0.5)
        sqrt_one_minus_alpha_prod_more = (1 - alpha_prod_t_more) ** (0.5)
        noise_coe = ((sqrt_one_minus_alpha_prod_more ** 2) - (sqrt_one_minus_alpha_prod * sqrt_alpha_prod_more / sqrt_alpha_prod) ** 2) ** (0.5)
        return noisy_latents * (sqrt_alpha_prod_more / sqrt_alpha_prod) + noise_coe * noise
def get_sigma(sigmas, timestep):
    """Look up sigma values, returning 0.0 for negative timesteps.

    Elementwise: returns ``sigmas[t]`` for ``t`` in ``[0, T-1]`` and ``0.0``
    for ``t < 0`` (a fully denoised sample has sigma = 0).  Timesteps above
    ``T-1`` are clamped to the last entry — the original only clamped the
    lower bound and would raise an IndexError there, unlike its ``get_alpha``
    counterpart.

    Args:
        sigmas: 1-D tensor of noise levels (length T).
        timestep: integer tensor of timesteps; any shape.

    Returns:
        Tensor of sigmas with the same shape as ``timestep``.
    """
    below_mask = torch.lt(timestep, 0).to(sigmas.dtype)
    normal_sigma = sigmas[torch.clip(timestep, 0, sigmas.shape[0] - 1)]
    # Multiplying by (1 - mask) zeroes the below-range entries directly;
    # the original added an explicit all-zeros tensor for the same effect.
    return normal_sigma * (1 - below_mask)
class MyEulerA(MySchedulers):
    """Euler-ancestral stepping that operates on DDPM-normalized latents.

    Sigmas are derived from the wrapped scheduler's ``alphas_cumprod`` table
    via sigma_t = sqrt((1 - a_t) / a_t).
    """

    def __init__(self, eulera_scheduler, normnoise=False) -> None:
        super(MyEulerA, self).__init__()
        assert isinstance(eulera_scheduler, EulerAncestralDiscreteScheduler)
        self.sigmas = ((1 - eulera_scheduler.alphas_cumprod) / eulera_scheduler.alphas_cumprod) ** 0.5
        assert len(self.sigmas) == 1000
        # Optional torch.Generator for reproducible ancestral noise.
        self.generator = None
        # When True, the model residual is renormalized to ~unit std per sample.
        self.normnoise = normnoise

    def pred_prev(self,
        noisy_images, noisy_residual, timestep, timestep_prev
    ):
        """One ancestral Euler step from ``timestep`` down to ``timestep_prev``.

        Returns ``(prev_sample, pred_original_sample)`` in the dtype of
        ``noisy_residual``; the intermediate math runs in float32.
        """
        out_dtype = noisy_residual.dtype
        noisy_images = noisy_images.to(torch.float32)
        noisy_residual = noisy_residual.to(torch.float32)
        if self.normnoise:
            residual_std = torch.std(noisy_residual, dim=(1, 2, 3), keepdim=True)
            noisy_residual = noisy_residual / (residual_std + 0.0001)
        sigmas = self.sigmas.to(noisy_images.device)
        sigma_from = get_sigma(sigmas, timestep).view(-1, 1, 1, 1, 1)
        sigma_to = get_sigma(sigmas, timestep_prev).view(-1, 1, 1, 1, 1)
        # Ancestral split of the target noise level into a deterministic part
        # (sigma_down) and a stochastic part (sigma_up).
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        # Undo the DDPM normalization: x_kdiff = x_ddpm * sqrt(sigma^2 + 1).
        sample = noisy_images * ((sigma_from**2 + 1) ** 0.5)
        pred_original_sample = sample - sigma_from * noisy_residual
        # ODE derivative and Euler step down to sigma_down.
        derivative = (sample - pred_original_sample) / sigma_from
        dt = sigma_down - sigma_from
        prev_sample = sample + derivative * dt
        fresh_noise = randn_tensor(
            noisy_residual.shape,
            dtype=out_dtype,
            device=noisy_residual.device,
            generator=self.generator,
        ).to(noisy_residual.dtype)
        prev_sample = prev_sample + fresh_noise * sigma_up
        # Re-apply the DDPM normalization at the destination noise level.
        prev_sample = prev_sample / ((sigma_to**2 + 1) ** 0.5)
        return prev_sample.to(out_dtype), pred_original_sample.to(out_dtype)
if __name__ == "__main__":
    # Smoke test.  The original called MyDDIM() with no arguments, which
    # raises TypeError (ddpm_or_ddim_scheduler is required) — likely the
    # cause of this Space's "Runtime error".  A default-configured
    # DDPMScheduler satisfies the constructor's isinstance assert.
    a = MyDDIM(DDPMScheduler())
|