import math

import torch
import torch.nn.functional as F


def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine beta schedule proposed in https://arxiv.org/abs/2102.09672.

    Args:
        timesteps: Number of diffusion steps T.
        s: Offset added to ``t / T`` inside the cosine (0.008 by default).

    Returns:
        1-D tensor of ``timesteps`` betas, clipped to ``[0.0001, 0.9999]``.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    # Normalize so that alpha_bar at step 0 is exactly 1.
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Keep every beta strictly inside (0, 1).
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas spaced linearly from ``beta_start`` to ``beta_end``."""
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas whose square roots are linearly spaced.

    The first beta is ``beta_start`` and the last is ``beta_end``.
    """
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas following a sigmoid evaluated on [-6, 6].

    The sigmoid output in (0, 1) is rescaled to (``beta_start``, ``beta_end``).
    """
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start


class NoiseSchedule:
    """Precomputed per-timestep coefficients for a diffusion process.

    Args:
        timesteps: Number of diffusion steps T.
        beta_schedule: Callable taking ``timesteps`` and returning a 1-D tensor
            of betas. Defaults to ``linear_beta_schedule``; pass e.g.
            ``cosine_beta_schedule`` to switch schedules.
    """

    def __init__(self, timesteps=200, beta_schedule=linear_beta_schedule):
        self.timesteps = timesteps

        # define beta schedule
        self.betas = beta_schedule(timesteps)

        # define alphas
        self.alphas = 1.0 - self.betas
        # alphas_cumprod: alpha bar
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}: shifted right by one, with 1.0 used for the first step.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # calculations for diffusion q(x_t | x_0) (used by q_sample) and others
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # calculations for posterior q(x_{t-1} | x_t, x_0):
        # beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )


def extract(a, t, x_shape):
    """Pick ``a[t]`` for each batch element, shaped to broadcast against ``x_shape``.

    Args:
        a: 1-D tensor of per-timestep values (e.g. a ``NoiseSchedule`` attribute).
        t: 1-D integer (int64) tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        placed on ``t.device``.
    """
    batch_size = t.shape[0]
    # gather requires the index on the same device as ``a``. Moving ``t`` to
    # ``a.device`` matches the old ``t.cpu()`` for CPU schedules and also works
    # when the schedule tensors live on another device.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# forward diffusion (using the nice property)
def q_sample(x_start, t, noise_schedule, noise=None):
    """Sample x_t from q(x_t | x_0) in closed form.

    Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.

    Args:
        x_start: Clean input x_0; its first dimension is the batch.
        t: 1-D integer tensor of timestep indices, one per batch element.
        noise_schedule: A ``NoiseSchedule`` providing the coefficient tensors.
        noise: Optional noise tensor; standard normal noise shaped like
            ``x_start`` is drawn when omitted.

    Returns:
        Tensor shaped like ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(noise_schedule.sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        noise_schedule.sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise