import math

import numpy as np
import torch
import torch.nn.functional as F


def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    For a noise schedule given by alpha^2, this clips the step ratio
    alpha_t^2 / alpha_{t-1}^2. This may help improve stability during sampling.
    """
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)

    alphas_step = (alphas2[1:] / alphas2[:-1])

    alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
    alphas2 = np.cumprod(alphas_step, axis=0)

    return alphas2


def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
    """
    A noise schedule based on a simple polynomial equation: 1 - x^power.
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2

    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)

    precision = 1 - 2 * s

    alphas2 = precision * alphas2 + s

    return alphas2


def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """
    Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.
    """
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas = np.clip(betas, a_min=0, a_max=0.999)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    if raise_to_power != 1:
        alphas_cumprod = np.power(alphas_cumprod, raise_to_power)

    return alphas_cumprod


class PositiveLinear(torch.nn.Module):
    """Linear layer with weights forced to be positive."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 weight_init_offset: int = -2):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features)))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_init_offset = weight_init_offset
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # Shift the raw weights negative so that softplus(weight) starts small.
        with torch.no_grad():
            self.weight.add_(self.weight_init_offset)

        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        positive_weight = F.softplus(self.weight)
        return F.linear(x, positive_weight, self.bias)
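

# --- Usage sketch (illustrative only) -------------------------------------
# A minimal sanity check of the pieces above, assuming a typical diffusion
# setup with T = 1000 timesteps. `_demo_schedules` is a hypothetical helper
# added here for illustration; it is not part of the original module.
def _demo_schedules():
    T = 1000
    # Both schedules return alpha_t^2 arrays of length T + 1, indexed by
    # integer timestep t = 0..T.
    a2_poly = polynomial_schedule(T, s=1e-4, power=2.)
    a2_cos = cosine_beta_schedule(T)
    assert a2_poly.shape == (T + 1,) and a2_cos.shape == (T + 1,)

    # alpha_t^2 should decay monotonically from ~1 towards ~0.
    assert np.all(np.diff(a2_poly) <= 0)

    # PositiveLinear applies softplus to its raw weights, so a 1-in/1-out
    # layer is a monotonically increasing function of its input; stacking
    # such layers (as GammaNetwork does below) preserves monotonicity.
    layer = PositiveLinear(1, 1)
    t = torch.linspace(0, 1, 5).view(-1, 1)
    out = layer(t)
    assert torch.all(out[1:] >= out[:-1])
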
""" def __init__(self, noise_schedule, timesteps, precision): super(PredefinedNoiseSchedule, self).__init__() self.timesteps = timesteps if noise_schedule == 'cosine': alphas2 = cosine_beta_schedule(timesteps) elif 'polynomial' in noise_schedule: splits = noise_schedule.split('_') assert len(splits) == 2 power = float(splits[1]) alphas2 = polynomial_schedule(timesteps, s=precision, power=power) else: raise ValueError(noise_schedule) # print('alphas2', alphas2) sigmas2 = 1 - alphas2 log_alphas2 = np.log(alphas2) log_sigmas2 = np.log(sigmas2) log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2 # print('gamma', -log_alphas2_to_sigmas2) self.gamma = torch.nn.Parameter( torch.from_numpy(-log_alphas2_to_sigmas2).float(), requires_grad=False) def forward(self, t): t_int = torch.round(t * self.timesteps).long() return self.gamma[t_int] class GammaNetwork(torch.nn.Module): """The gamma network models a monotonic increasing function. Construction as in the VDM paper.""" def __init__(self): super().__init__() self.l1 = PositiveLinear(1, 1) self.l2 = PositiveLinear(1, 1024) self.l3 = PositiveLinear(1024, 1) self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.])) self.gamma_1 = torch.nn.Parameter(torch.tensor([10.])) self.show_schedule() def show_schedule(self, num_steps=50): t = torch.linspace(0, 1, num_steps).view(num_steps, 1) gamma = self.forward(t) print('Gamma schedule:') print(gamma.detach().cpu().numpy().reshape(num_steps)) def gamma_tilde(self, t): l1_t = self.l1(t) return l1_t + self.l3(torch.sigmoid(self.l2(l1_t))) def forward(self, t): zeros, ones = torch.zeros_like(t), torch.ones_like(t) # Not super efficient. gamma_tilde_0 = self.gamma_tilde(zeros) gamma_tilde_1 = self.gamma_tilde(ones) gamma_tilde_t = self.gamma_tilde(t) # Normalize to [0, 1] normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / ( gamma_tilde_1 - gamma_tilde_0) # Rescale to [gamma_0, gamma_1] gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma return gamma