Spaces:
Runtime error
Runtime error
"""Adapted from https://github.com/lucidrains/denoising-diffusion-pytorch""" | |
from argparse import Namespace | |
import math | |
from typing import List, Tuple, Optional | |
import torch | |
import torch.nn.functional as F | |
from einops import reduce, rearrange | |
from torch import nn, Tensor | |
from models.unet_model import Unet | |
from trainers.utils import default, get_index_from_list, normalize_to_neg_one_to_one | |
def linear_beta_schedule(
    timesteps: int,
    start: float = 0.0001,
    end: float = 0.02,
) -> Tensor:
    """
    Linearly spaced beta schedule.

    Both endpoints are multiplied by ``1000 / timesteps``, so ``start`` and
    ``end`` are the exact endpoint values of a 1000-step schedule and are
    rescaled for any other number of steps.

    :param timesteps: Number of time steps
    :param start: beta at the first step for a 1000-step schedule
    :param end: beta at the last step for a 1000-step schedule
    :return schedule: betas at every timestep, float32, (timesteps,)
    """
    ratio = 1000 / timesteps
    first, last = ratio * start, ratio * end
    return torch.linspace(first, last, steps=timesteps, dtype=torch.float32)
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> Tensor:
    """
    Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    The cumulative product of alphas is modelled as a squared cosine of
    normalised time (shifted by ``s``) and renormalised to 1 at t = 0.
    Each beta is one minus the ratio of consecutive cumulative products.

    :param timesteps: Number of time steps
    :param s: Offset added to normalised time before taking the cosine
    :return schedule: betas at every timestep, clipped to [0, 0.999], (timesteps,)
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float32)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    alpha_bar = torch.cos(phase).pow(2)
    alpha_bar = alpha_bar / alpha_bar[0]
    # Ratio of consecutive cumulative products is the per-step alpha.
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]
    return (1 - step_alphas).clamp(0, 0.999)
class DiffusionModel(nn.Module):
    """Gaussian diffusion (DDPM) model wrapped around a conditional ``Unet``.

    Hyper-parameters are read from ``config``. Any key missing from it falls
    back to the default given in ``__init__``. The beta schedule and every
    quantity derived from it are registered as buffers, so they follow the
    module on ``.to(...)`` and are stored in its ``state_dict``.
    """

    def __init__(self, config: Namespace):
        """
        :param config: Namespace of hyper-parameters. Keys read here, with
            their defaults: dim (64), dim_mults ([1, 2, 4, 8]), channels (1),
            timesteps (1000), beta_schedule ('cosine'), objective
            ('pred_noise'), p2_loss_weight_gamma (0.), p2_loss_weight_k (1.),
            dynamic_threshold_percentile (0.995). ``forward`` additionally
            reads ``config.normalize`` directly (no default).
        :raises ValueError: if beta_schedule is neither 'linear' nor 'cosine'.
        """
        super().__init__()
        # Default parameters (self.config must be set before self.default is used)
        self.config = config
        dim: int = self.default('dim', 64)
        dim_mults: List[int] = self.default('dim_mults', [1, 2, 4, 8])
        channels: int = self.default('channels', 1)
        timesteps: int = self.default('timesteps', 1000)
        beta_schedule: str = self.default('beta_schedule', 'cosine')
        objective: str = self.default('objective', 'pred_noise')  # 'pred_noise' or 'pred_x_0'
        # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent
        # to weight of 1 across time - 1. is recommended
        p2_loss_weight_gamma: float = self.default('p2_loss_weight_gamma', 0.)
        p2_loss_weight_k: float = self.default('p2_loss_weight_k', 1.)
        dynamic_threshold_percentile: float = self.default('dynamic_threshold_percentile', 0.995)

        self.timesteps = timesteps
        self.objective = objective
        self.dynamic_threshold_percentile = dynamic_threshold_percentile

        self.model = Unet(
            dim,
            dim_mults=dim_mults,
            channels=channels,
        )

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alphas_cumprod shifted right by one, with 1 at t = 0.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1. / alphas_cumprod - 1))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1. - alphas_cumprod))

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The posterior variance is 0 at t = 0, so clamp before taking the log.
        self.register_buffer(
            'posterior_log_variance_clipped',
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        self.register_buffer(
            'posterior_mean_coef1',
            betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod),
        )
        self.register_buffer(
            'posterior_mean_coef2',
            (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod),
        )

        # p2 reweighting
        p2_loss_weight = ((p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
                          ** (-p2_loss_weight_gamma))
        self.register_buffer('p2_loss_weight', p2_loss_weight)

    def default(self, val, d):
        """Return ``config.<val>`` if the key is present in the config, else ``d``."""
        return vars(self.config)[val] if val in self.config else d

    def train_step(self, x_0: Tensor, cond: Optional[Tensor] = None,
                   t: Optional[Tensor] = None) -> Tensor:
        """
        Compute the p2-weighted L1 training loss for a batch.

        :param x_0: Clean image, (N, C, H, W)
        :param cond: Optional conditioning input forwarded to the network
        :param t: Optional timesteps, (N,). If None, sampled uniformly from
            [0, timesteps).
        :return: Scalar loss tensor
        :raises ValueError: if ``self.objective`` is not 'pred_noise' or 'pred_x_0'
        """
        N, device = x_0.shape[0], x_0.device
        # If t is not none, use it, otherwise sample from uniform
        if t is not None:
            t = t.long().to(device)
        else:
            t = torch.randint(0, self.timesteps, (N,), device=device).long()  # (N)

        model_out, noise = self(x_0, t, cond=cond)

        if self.objective == 'pred_noise':
            target = noise  # (N, C, H, W)
        elif self.objective == 'pred_x_0':
            target = x_0  # (N, C, H, W)
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.l1_loss(model_out, target, reduction='none')  # (N, C, H, W)
        loss = reduce(loss, 'b ... -> b (...)', 'mean')  # (N, (C x H x W))
        # p2 reweighting: one weight per sample, selected by its timestep
        loss = loss * get_index_from_list(self.p2_loss_weight, t, loss.shape)
        return loss.mean()

    def val_step(self, x_0: Tensor, cond: Optional[Tensor] = None,
                 t_steps: Optional[int] = None) -> Tensor:
        """
        Average ``train_step`` loss over a strided sweep of timesteps.

        :param x_0: Clean image, (N, C, H, W)
        :param cond: Optional conditioning input forwarded to the network
        :param t_steps: Approximate number of timesteps to evaluate. A falsy
            value evaluates every timestep. The stride is
            ``timesteps // t_steps`` (at least 1), so more than ``t_steps``
            steps may be evaluated when the division is not exact.
        :return: Scalar mean loss over the evaluated timesteps
        """
        if not t_steps:
            t_steps = self.timesteps
        # Without the lower bound, t_steps > timesteps gives a zero stride and
        # range() raises ValueError.
        step_size = max(self.timesteps // t_steps, 1)
        N, device = x_0.shape[0], x_0.device

        losses = []
        for step in range(0, self.timesteps, step_size):
            t = torch.full((N,), step, dtype=torch.long, device=device)
            losses.append(self.train_step(x_0, cond, t))
        return torch.stack(losses).mean()

    def forward(self, x_0: Tensor, t: Tensor,
                cond: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Noise x_0 for t timestep and get the model prediction.

        :param x_0: Clean image, (N, C, H, W)
        :param t: Timestep, (N,)
        :param cond: element to condition the reconstruction on - eg image when
            x_0 is a segmentation (N, C', H, W)
        :return pred: Model output, predicted noise or image, (N, C, H, W)
        :return noise: Added noise, (N, C, H, W)
        """
        if self.config.normalize:
            x_0 = normalize_to_neg_one_to_one(x_0)
        if cond is not None and self.config.normalize:
            cond = normalize_to_neg_one_to_one(cond)
        x_t, noise = self.forward_diffusion_model(x_0, t)
        return self.model(x_t, t, cond), noise

    def forward_diffusion_model(
        self,
        x_0: Tensor,
        t: Tensor,
        noise: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Sample x_t ~ q(x_t | x_0), the noisy version of x_0 at timestep t.

        :param x_0: Image at timestep 0, (N, C, H, W)
        :param t: Timestep, (N,)
        :param noise: Optional pre-drawn noise, same shape as x_0. When None,
            it is drawn with ``torch.randn_like(x_0)`` (through the
            ``default`` helper from trainers.utils).
        :return x_t: Noisy image at timestep t, (N, C, H, W)
        :return noise: Noise added to the image, (N, C, H, W)
        """
        noise = default(noise, lambda: torch.randn_like(x_0))
        sqrt_alphas_cumprod_t = get_index_from_list(
            self.sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape)
        # mean + variance
        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
        return x_t, noise

    def sample_timestep(self, x_t: Tensor, t: int,
                        cond: Optional[Tensor] = None) -> Tensor:
        """
        Perform one reverse-diffusion step from x_t to x_{t-1}.

        ``cond`` is passed to the network unchanged. Unlike ``forward``, this
        method does not apply ``normalize_to_neg_one_to_one`` to it.
        NOTE(review): confirm callers normalize ``cond`` when ``config.normalize`` is set.

        :param x_t: Image noised t times, (N, C, H, W)
        :param t: Timestep, shared by every sample in the batch
        :param cond: Optional conditioning input forwarded to the network
        :return: Sampled image, (N, C, H, W). No noise is added when t == 0.
        """
        N = x_t.shape[0]
        device = x_t.device
        batched_t = torch.full((N,), t, device=device, dtype=torch.long)  # (N)
        model_mean, model_log_variance, _ = self.p_mean_variance(x_t, batched_t, cond=cond)
        noise = torch.randn_like(x_t) if t > 0 else 0.
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img

    def p_mean_variance(self, x_t: Tensor, t: Tensor, clip_denoised: bool = True,
                        cond: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute the reverse-step posterior from the model's x_0 estimate.

        :param x_t: Noised image at timestep t, (N, C, ...)
        :param t: Timestep, (N,)
        :param clip_denoised: If True, apply dynamic thresholding to the
            predicted x_0 before computing the posterior.
        :param cond: Optional conditioning input forwarded to the network
        :return: (posterior mean, clipped posterior log variance, predicted x_0)
        """
        _, pred_x_0 = self.model_predictions(x_t, t, cond=cond)
        if clip_denoised:
            # Dynamic thresholding: s is a per-sample percentile of |x_0|.
            flat_abs = rearrange(pred_x_0, 'b ... -> b (...)').abs()
            s = torch.quantile(flat_abs, self.dynamic_threshold_percentile, dim=1)
            # Never threshold tighter than the static [-1, 1] range. clamp keeps
            # s on its own device instead of mixing in a CPU tensor.
            s = s.clamp(min=1.0)
            # Broadcast one threshold per sample over all non-batch dimensions.
            s = s.reshape(-1, *([1] * (pred_x_0.dim() - 1)))
            pred_x_0 = torch.clip(pred_x_0, -s, s) / s
        model_mean, posterior_log_variance = self.q_posterior(pred_x_0, x_t, t)
        return model_mean, posterior_log_variance, pred_x_0

    def model_predictions(self, x_t: Tensor, t: Tensor, cond: Optional[Tensor] = None) \
            -> Tuple[Tensor, Tensor]:
        """
        Return the predicted noise and x_0 for a given x_t and t.

        :param x_t: Noised image at timestep t, (N, C, H, W)
        :param t: Timestep, (N,)
        :param cond: Optional conditioning input forwarded to the network
        :return pred_noise: Predicted noise, (N, C, H, W)
        :return pred_x_0: Predicted x_0, (N, C, H, W)
        :raises ValueError: if ``self.objective`` is not a known objective
        """
        model_output = self.model(x_t, t, cond)
        if self.objective == 'pred_noise':
            pred_noise = model_output
            pred_x_0 = self.predict_x_0_from_noise(x_t, t, model_output)
        elif self.objective in ('pred_x_0', 'pred_x_start'):
            # 'pred_x_0' is the spelling train_step() accepts; 'pred_x_start'
            # is kept because it was the only spelling this method matched.
            pred_noise = self.predict_noise_from_x_0(x_t, t, model_output)
            pred_x_0 = model_output
        else:
            raise ValueError(f'unknown objective {self.objective}')
        return pred_noise, pred_x_0

    def q_posterior(self, x_start: Tensor, x_t: Tensor, t: Tensor) \
            -> Tuple[Tensor, Tensor]:
        """
        Mean and clipped log variance of q(x_{t-1} | x_t, x_0).

        :param x_start: (Predicted) clean image x_0, same shape as x_t
        :param x_t: Noised image at timestep t
        :param t: Timestep, (N,)
        :return: (posterior mean, clipped posterior log variance)
        """
        posterior_mean = (
            get_index_from_list(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + get_index_from_list(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_log_variance_clipped = get_index_from_list(
            self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_log_variance_clipped

    def predict_x_0_from_noise(self, x_t: Tensor, t: Tensor, noise: Tensor) \
            -> Tensor:
        """
        Get x_0 given x_t, t, and the known or predicted noise.

        Computes ``sqrt(1 / a_t) * x_t - sqrt(1 / a_t - 1) * noise``, where
        a_t is alphas_cumprod at t.

        :param x_t: Noised image at timestep t, (N, C, H, W)
        :param t: Timestep, (N,)
        :param noise: Noise, (N, C, H, W)
        :return: Predicted x_0, (N, C, H, W)
        """
        return (
            get_index_from_list(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - get_index_from_list(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_x_0(self, x_t: Tensor, t: Tensor, x_0: Tensor) \
            -> Tensor:
        """
        Get noise given the known or predicted x_0, x_t, and t.

        Inverts ``predict_x_0_from_noise``:
        ``(sqrt(1 / a_t) * x_t - x_0) / sqrt(1 / a_t - 1)``.

        :param x_t: Noised image at timestep t, (N, C, H, W)
        :param t: Timestep, (N,)
        :param x_0: Clean (or predicted) image, (N, C, H, W)
        :return: Predicted noise, (N, C, H, W)
        """
        return (
            (get_index_from_list(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x_0)
            / get_index_from_list(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )