"""Adapted from https://github.com/lucidrains/denoising-diffusion-pytorch""" from argparse import Namespace import math from typing import List, Tuple, Optional import torch import torch.nn.functional as F from einops import reduce, rearrange from torch import nn, Tensor from models.unet_model import Unet from trainers.utils import default, get_index_from_list, normalize_to_neg_one_to_one def linear_beta_schedule( timesteps: int, start: float = 0.0001, end: float = 0.02 ) -> Tensor: """ :param timesteps: Number of time steps :return schedule: betas at every timestep, (timesteps,) """ scale = 1000 / timesteps beta_start = scale * start beta_end = scale * end return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32) def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> Tensor: """ cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ :param timesteps: Number of time steps :param s: scaling factor :return schedule: betas at every timestep, (timesteps,) """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps, dtype=torch.float32) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) class DiffusionModel(nn.Module): def __init__(self, config: Namespace): super().__init__() # Default parameters self.config = config dim: int = self.default('dim', 64) dim_mults: List[int] = self.default('dim_mults', [1, 2, 4, 8]) channels: int = self.default('channels', 1) timesteps: int = self.default('timesteps', 1000) beta_schedule: str = self.default('beta_schedule', 'cosine') objective: str = self.default('objective', 'pred_noise') # 'pred_noise' or 'pred_x_0' p2_loss_weight_gamma: float = self.default('p2_loss_weight_gamma', 0.) # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended p2_loss_weight_k: float = self.default('p2_loss_weight_k', 1.) dynamic_threshold_percentile: float = self.default('dynamic_threshold_percentile', 0.995) self.timesteps = timesteps self.objective = objective self.dynamic_threshold_percentile = dynamic_threshold_percentile self.model = Unet( dim, dim_mults=dim_mults, channels=channels ) if beta_schedule == 'linear': betas = linear_beta_schedule(timesteps) elif beta_schedule == 'cosine': betas = cosine_beta_schedule(timesteps) else: raise ValueError(f'unknown beta schedule {beta_schedule}') alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) # Calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) # Calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) self.register_buffer('posterior_variance', posterior_variance) self.register_buffer( 'posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)) ) self.register_buffer( 'posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod) ) self.register_buffer( 'posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod) ) # p2 reweighting p2_loss_weight = ((p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** (-p2_loss_weight_gamma)) self.register_buffer('p2_loss_weight', p2_loss_weight) def default(self, val, d): return vars(self.config)[val] if val in self.config else d def train_step(self, x_0: Tensor, cond: Optional[Tensor] = None, t:Optional[Tensor] = None) -> Tensor: N, device = x_0.shape[0], x_0.device # If t is not none, use it, otherwise sample from uniform if t is not None: t = t.long().to(device) else: t = torch.randint(0, self.timesteps, (N,), device=device).long() # (N) model_out, noise = self(x_0, t, cond=cond) if self.objective == 'pred_noise': target = noise # (N, C, H, W) elif self.objective == 'pred_x_0': target = x_0 # (N, C, H, W) else: raise ValueError(f'unknown objective {self.objective}') loss = F.l1_loss(model_out, target, reduction='none') # (N, C, H, W) loss = reduce(loss, 'b ... -> b (...)', 'mean') # (N, (C x H x W)) # p2 reweighting loss = loss * get_index_from_list(self.p2_loss_weight, t, loss.shape) return loss.mean() def val_step(self, x_0: Tensor, cond: Optional[Tensor] = None, t_steps:Optional[int] = None) -> Tensor: if not t_steps: t_steps = self.timesteps step_size = self.timesteps // t_steps N, device = x_0.shape[0], x_0.device losses = [] for t in range(0, self.timesteps, step_size): t = torch.ones((N,)) * t t = t.long().to(device) losses.append(self.train_step(x_0, cond, t)) return torch.stack(losses).mean() def forward(self, x_0: Tensor, t: Tensor, cond: Optional[Tensor] = None) -> Tensor: """ Noise x_0 for t timestep and get the model prediction. :param x_0: Clean image, (N, C, H, W) :param t: Timestep, (N,) :param cond: element to condition the reconstruction on - eg image when x_0 is a segmentation (N, C', H, W) :return pred: Model output, predicted noise or image, (N, C, H, W) :return noise: Added noise, (N, C, H, W) """ if self.config.normalize: x_0 = normalize_to_neg_one_to_one(x_0) if cond is not None and self.config.normalize: cond = normalize_to_neg_one_to_one(cond) x_t, noise = self.forward_diffusion_model(x_0, t) return self.model(x_t, t, cond), noise def forward_diffusion_model( self, x_0: Tensor, t: Tensor, noise: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: """ Takes an image and a timestep as input and returns the noisy version of it. :param x_0: Image at timestep 0, (N, C, H, W) :param t: Timestep, (N) :param cond: element to condition the reconstruction on - eg image when x_0 is a segmentation (N, C', H, W) :return x_t: Noisy image at timestep t, (N, C, H, W) :return noise: Noise added to the image, (N, C, H, W) """ noise = default(noise, lambda: torch.randn_like(x_0)) sqrt_alphas_cumprod_t = get_index_from_list( self.sqrt_alphas_cumprod, t, x_0.shape) sqrt_one_minus_alphas_cumprod_t = get_index_from_list( self.sqrt_one_minus_alphas_cumprod, t, x_0.shape) # mean + variance x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise return x_t, noise @torch.no_grad() def sample_timestep(self, x_t: Tensor, t: int, cond=Optional[Tensor]) -> Tensor: """ Sample from the model. :param x_t: Image noised t times, (N, C, H, W) :param t: Timestep :return: Sampled image, (N, C, H, W) """ N = x_t.shape[0] device = x_t.device batched_t = torch.full((N,), t, device=device, dtype=torch.long) # (N) model_mean, model_log_variance, _ = self.p_mean_variance(x_t, batched_t, cond=cond) noise = torch.randn_like(x_t) if t > 0 else 0. pred_img = model_mean + (0.5 * model_log_variance).exp() * noise return pred_img def p_mean_variance(self, x_t: Tensor, t: Tensor, clip_denoised: bool = True, cond:Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: _, pred_x_0 = self.model_predictions(x_t, t, cond=cond) if clip_denoised: # pred_x_0.clamp_(-1., 1.) # Dynamic thrsholding s = torch.quantile(rearrange(pred_x_0, 'b ... -> b (...)').abs(), self.dynamic_threshold_percentile, dim=1) s = torch.max(s, torch.tensor(1.0))[:, None, None, None] pred_x_0 = torch.clip(pred_x_0, -s, s) / s (model_mean, posterior_log_variance) = self.q_posterior(pred_x_0, x_t, t) return model_mean, posterior_log_variance, pred_x_0 def model_predictions(self, x_t: Tensor, t: Tensor, cond:Optional[Tensor] = None) \ -> Tuple[Tensor, Tensor]: """ Return the predicted noise and x_0 for a given x_t and t. :param x_t: Noised image at timestep t, (N, C, H, W) :param t: Timestep, (N,) :return pred_noise: Predicted noise, (N, C, H, W) :return pred_x_0: Predicted x_0, (N, C, H, W) """ model_output = self.model(x_t, t, cond) if self.objective == 'pred_noise': pred_noise = model_output pred_x_0 = self.predict_x_0_from_noise(x_t, t, model_output) elif self.objective == 'pred_x_start': pred_noise = self.predict_noise_from_x_0(x_t, t, model_output) pred_x_0 = model_output return pred_noise, pred_x_0 def q_posterior(self, x_start: Tensor, x_t: Tensor, t: Tensor) \ -> Tuple[Tensor, Tensor]: posterior_mean = ( get_index_from_list(self.posterior_mean_coef1, t, x_t.shape) * x_start + get_index_from_list(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_log_variance_clipped = get_index_from_list( self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_log_variance_clipped def predict_x_0_from_noise(self, x_t: Tensor, t: Tensor, noise: Tensor) \ -> Tensor: """ Get x_0 given x_t, t, and the known or predicted noise. :param x_t: Noised image at timestep t, (N, C, H, W) :param t: Timestep, (N,) :param noise: Noise, (N, C, H, W) :return: Predicted x_0, (N, C, H, W) """ return ( get_index_from_list( self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - get_index_from_list( self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def predict_noise_from_x_0(self, x_t: Tensor, t: Tensor, x_0: Tensor) \ -> Tensor: """ Get noise given the known or predicted x_0, x_t, and t :param x_t: Noised image at timestep t, (N, C, H, W) :param t: Timestep, (N,) :param noise: Noise, (N, C, H, W) :return: Predicted noise, (N, C, H, W) """ return ( (get_index_from_list(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x_0) / get_index_from_list(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) )