import math

import numpy as np
import torch
import torch.nn.functional as F
def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    For a noise schedule given by alpha^2, this clips the stepwise ratio
    alpha_t^2 / alpha_{t-1}^2. This may help improve stability during sampling.
    """
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)

    alphas_step = (alphas2[1:] / alphas2[:-1])

    # Bound each ratio from below, then rebuild the schedule as a cumulative
    # product so that no single step removes more signal than clip_value allows.
    alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
    alphas2 = np.cumprod(alphas_step, axis=0)

    return alphas2
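
# A minimal usage sketch (illustrative, not part of the original module): with
# a very fast-decaying alpha^2 schedule, the raw stepwise ratios fall below
# clip_value, so clipping bounds how much signal any single step can remove.
if __name__ == '__main__':
    raw_alphas2 = np.exp(-np.linspace(0., 100., 10))  # fast-decaying alpha^2
    clipped = clip_noise_schedule(raw_alphas2, clip_value=0.001)
    print('raw alphas^2    :', raw_alphas2)
    print('clipped alphas^2:', clipped)  # every stepwise ratio is now >= 0.001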
def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
    """
    A noise schedule for alpha^2 based on a simple polynomial equation,
    (1 - x^power)^2, clipped for stability and rescaled to lie in [s, 1 - s].
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2

    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)

    # Rescale from [0, 1] to [s, 1 - s] so alpha^2 never reaches 0 or 1 exactly.
    precision = 1 - 2 * s

    alphas2 = precision * alphas2 + s

    return alphas2
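
# A quick sanity check (illustrative sketch, not original code): the rescaled
# polynomial schedule should start near 1 - s (almost noise-free data) and
# decay towards s (almost pure noise) at the final timestep.
if __name__ == '__main__':
    a2 = polynomial_schedule(timesteps=1000, s=1e-4, power=3.)
    print('alpha^2 at t=0:', a2[0])   # approximately 1 - s
    print('alpha^2 at t=T:', a2[-1])  # approximately s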
def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """
    Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Clipping betas avoids singularities at the end of the schedule.
    betas = np.clip(betas, a_min=0, a_max=0.999)

    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    if raise_to_power != 1:
        alphas_cumprod = np.power(alphas_cumprod, raise_to_power)

    return alphas_cumprod
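
# An illustrative comparison (not original code): both schedules yield a
# monotonically decreasing alpha_bar curve; the cosine variant decays more
# gently near the endpoints, which is its motivation in the linked paper.
if __name__ == '__main__':
    T = 1000
    cos_a2 = cosine_beta_schedule(T)
    poly_a2 = polynomial_schedule(T)
    print('cosine head/tail    :', cos_a2[:3], cos_a2[-3:])
    print('polynomial head/tail:', poly_a2[:3], poly_a2[-3:])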
class PositiveLinear(torch.nn.Module):
    """Linear layer whose effective weights are forced to be positive via softplus."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 weight_init_offset: int = -2):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features)))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_init_offset = weight_init_offset
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        with torch.no_grad():
            # Shift the raw weights down so softplus(weight) starts small.
            self.weight.add_(self.weight_init_offset)

        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # softplus guarantees strictly positive effective weights.
        positive_weight = F.softplus(self.weight)
        return F.linear(x, positive_weight, self.bias)
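
# A small sketch (assumed usage, not original code): softplus makes every
# effective weight strictly positive, so the layer's output can only grow
# when its input grows. This is the building block that keeps GammaNetwork
# below monotone.
if __name__ == '__main__':
    layer = PositiveLinear(1, 4)
    print('effective weights:', F.softplus(layer.weight).detach())
    t_lo, t_hi = torch.tensor([[0.2]]), torch.tensor([[0.8]])
    print('monotone:', bool((layer(t_hi) >= layer(t_lo)).all()))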
class PredefinedNoiseSchedule(torch.nn.Module):
    """
    Predefined noise schedule. Essentially creates a lookup array for
    predefined (non-learned) noise schedules.
    """

    def __init__(self, noise_schedule, timesteps, precision):
        super(PredefinedNoiseSchedule, self).__init__()
        self.timesteps = timesteps

        if noise_schedule == 'cosine':
            alphas2 = cosine_beta_schedule(timesteps)
        elif 'polynomial' in noise_schedule:
            # Expected format: 'polynomial_<power>', e.g. 'polynomial_2'.
            splits = noise_schedule.split('_')
            assert len(splits) == 2
            power = float(splits[1])
            alphas2 = polynomial_schedule(timesteps, s=precision, power=power)
        else:
            raise ValueError(noise_schedule)

        sigmas2 = 1 - alphas2

        log_alphas2 = np.log(alphas2)
        log_sigmas2 = np.log(sigmas2)

        # gamma_t = -log(alpha_t^2 / sigma_t^2), stored as a fixed lookup table.
        log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2

        self.gamma = torch.nn.Parameter(
            torch.from_numpy(-log_alphas2_to_sigmas2).float(),
            requires_grad=False)

    def forward(self, t):
        # Map continuous t in [0, 1] to the nearest integer timestep.
        t_int = torch.round(t * self.timesteps).long()
        return self.gamma[t_int]
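
# Usage sketch (illustrative, not original code): gamma(t) = -log(alpha_t^2 /
# sigma_t^2) is precomputed and looked up by rounding continuous t in [0, 1]
# to the nearest integer timestep.
if __name__ == '__main__':
    schedule = PredefinedNoiseSchedule('polynomial_2', timesteps=1000,
                                       precision=1e-4)
    t = torch.tensor([0.0, 0.5, 1.0])
    print('gamma(t):', schedule(t))  # increasing: low noise at t=0, high at t=1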
class GammaNetwork(torch.nn.Module):
    """The gamma network models a monotonically increasing function.
    Construction as in the VDM (Variational Diffusion Models) paper."""

    def __init__(self):
        super().__init__()

        self.l1 = PositiveLinear(1, 1)
        self.l2 = PositiveLinear(1, 1024)
        self.l3 = PositiveLinear(1024, 1)

        # Learnable endpoints: gamma(0) = gamma_0 and gamma(1) = gamma_1.
        self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
        self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
        self.show_schedule()

    def show_schedule(self, num_steps=50):
        t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
        gamma = self.forward(t)
        print('Gamma schedule:')
        print(gamma.detach().cpu().numpy().reshape(num_steps))

    def gamma_tilde(self, t):
        l1_t = self.l1(t)
        return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))

    def forward(self, t):
        zeros, ones = torch.zeros_like(t), torch.ones_like(t)
        # Not super efficient.
        gamma_tilde_0 = self.gamma_tilde(zeros)
        gamma_tilde_1 = self.gamma_tilde(ones)
        gamma_tilde_t = self.gamma_tilde(t)

        # Normalize to [0, 1].
        normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
            gamma_tilde_1 - gamma_tilde_0)

        # Rescale to [gamma_0, gamma_1].
        gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma

        return gamma
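
# A minimal check (illustrative, not original code): by construction gamma is
# pinned to gamma_0 at t=0 and gamma_1 at t=1, and the positive-weight layers
# keep it monotonically increasing in between.
if __name__ == '__main__':
    gamma_net = GammaNetwork()  # prints its initial schedule on construction
    t = torch.linspace(0, 1, 11).view(-1, 1)
    g = gamma_net(t).detach().flatten()
    print('gamma(0), gamma(1):', g[0].item(), g[-1].item())  # ~ -5.0 and 10.0
    print('monotone increasing:', bool((g[1:] >= g[:-1]).all()))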