# DiffLinker: src/noise.py (igashov/DiffLinker, commit 95ba5bc)
import torch
import torch.nn.functional as F
import math
import numpy as np


def clip_noise_schedule(alphas2, clip_value=0.001):
    """
    For a noise schedule given by alpha^2, clips the per-step ratio
    alpha_t^2 / alpha_{t-1}^2 from below. This may help improve stability
    during sampling.
    """
    # Prepend alpha_0^2 = 1 so that the first ratio is well defined.
    alphas2 = np.concatenate([np.ones(1), alphas2], axis=0)
    alphas_step = (alphas2[1:] / alphas2[:-1])

    alphas_step = np.clip(alphas_step, a_min=clip_value, a_max=1.)
    alphas2 = np.cumprod(alphas_step, axis=0)

    return alphas2
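

# Usage sketch (illustrative, not part of the original module): clip a schedule
# whose final step ratio collapses. The input values below are arbitrary examples.
if __name__ == '__main__':
    raw = np.array([1.0, 0.5, 1e-9])               # last step ratio is ~2e-9
    clipped = clip_noise_schedule(raw, clip_value=0.05)
    # Every consecutive ratio is now at least clip_value, so the schedule
    # cannot collapse to numerical zero in a single step.
    assert np.all(clipped[1:] / clipped[:-1] >= 0.05 - 1e-12)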


def polynomial_schedule(timesteps: int, s=1e-4, power=3.):
    """
    A noise schedule based on a simple polynomial equation: 1 - x^power.
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas2 = (1 - np.power(x / steps, power)) ** 2
    alphas2 = clip_noise_schedule(alphas2, clip_value=0.001)

    # Rescale so that alphas2 lies in [s, 1 - s] instead of [0, 1].
    precision = 1 - 2 * s
    alphas2 = precision * alphas2 + s

    return alphas2
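

# Usage sketch (illustrative): the schedule has timesteps + 1 entries, one per
# t in {0, ..., T}; it starts near 1 and decreases monotonically. The parameter
# values below are arbitrary examples.
if __name__ == '__main__':
    a2 = polynomial_schedule(timesteps=10, s=1e-4, power=2.)
    assert a2.shape == (11,)
    assert np.all(np.diff(a2) <= 0)                # noise only grows with t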


def cosine_beta_schedule(timesteps, s=0.008, raise_to_power: float = 1):
    """
    Cosine schedule as proposed in Improved DDPM
    (https://openreview.net/forum?id=-NEXDKk8gZ). Note that this returns the
    cumulative product of the alphas, not the betas themselves.
    """
    steps = timesteps + 2
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # Cap each beta so that no single step removes more than 99.9% of the signal.
    betas = np.clip(betas, a_min=0, a_max=0.999)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    if raise_to_power != 1:
        alphas_cumprod = np.power(alphas_cumprod, raise_to_power)

    return alphas_cumprod
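

# Usage sketch (illustrative): the returned array is the cumulative signal
# level for t in {0, ..., T}; the shape and monotonicity checks below follow
# from the code above, with arbitrary example parameters.
if __name__ == '__main__':
    a_bar = cosine_beta_schedule(timesteps=10)
    assert a_bar.shape == (11,)
    assert np.all(a_bar > 0) and np.all(np.diff(a_bar) < 0)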


class PositiveLinear(torch.nn.Module):
    """Linear layer with weights forced to be positive."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 weight_init_offset: int = -2):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(
            torch.empty((out_features, in_features)))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.weight_init_offset = weight_init_offset
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        # Shift the raw weights down so that softplus(weight) starts small.
        with torch.no_grad():
            self.weight.add_(self.weight_init_offset)

        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # softplus maps the unconstrained weights to strictly positive values.
        positive_weight = F.softplus(self.weight)
        return F.linear(x, positive_weight, self.bias)
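

# Usage sketch (illustrative): the effective weights are strictly positive, so
# each output is non-decreasing in each input coordinate. The shapes below are
# arbitrary example values.
if __name__ == '__main__':
    layer = PositiveLinear(4, 2)
    assert bool((F.softplus(layer.weight) > 0).all())
    out = layer(torch.randn(3, 4))                 # used like torch.nn.Linear
    assert out.shape == (3, 2)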


class PredefinedNoiseSchedule(torch.nn.Module):
    """
    Predefined noise schedule. Essentially creates a lookup array for
    predefined (non-learned) noise schedules.
    """

    def __init__(self, noise_schedule, timesteps, precision):
        super(PredefinedNoiseSchedule, self).__init__()
        self.timesteps = timesteps

        if noise_schedule == 'cosine':
            alphas2 = cosine_beta_schedule(timesteps)
        elif 'polynomial' in noise_schedule:
            # Expected format: 'polynomial_<power>', e.g. 'polynomial_2'.
            splits = noise_schedule.split('_')
            assert len(splits) == 2
            power = float(splits[1])
            alphas2 = polynomial_schedule(timesteps, s=precision, power=power)
        else:
            raise ValueError(noise_schedule)

        # print('alphas2', alphas2)

        sigmas2 = 1 - alphas2

        log_alphas2 = np.log(alphas2)
        log_sigmas2 = np.log(sigmas2)

        log_alphas2_to_sigmas2 = log_alphas2 - log_sigmas2

        # print('gamma', -log_alphas2_to_sigmas2)

        # gamma_t = -log(alpha_t^2 / sigma_t^2), i.e. the negative log-SNR,
        # stored as a fixed (non-trainable) lookup table.
        self.gamma = torch.nn.Parameter(
            torch.from_numpy(-log_alphas2_to_sigmas2).float(),
            requires_grad=False)

    def forward(self, t):
        # t is a normalized time in [0, 1]; map it to an integer table index.
        t_int = torch.round(t * self.timesteps).long()
        return self.gamma[t_int]
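

# Usage sketch (illustrative): gamma is looked up at normalized times
# t in [0, 1]. The schedule name and parameters below are example inputs.
if __name__ == '__main__':
    schedule = PredefinedNoiseSchedule('polynomial_2', timesteps=10, precision=1e-4)
    gamma = schedule(torch.tensor([0.0, 0.5, 1.0]))
    assert gamma.shape == (3,)
    # gamma grows with t because alpha^2 shrinks while sigma^2 grows.
    assert bool(gamma[0] < gamma[1]) and bool(gamma[1] < gamma[2])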


class GammaNetwork(torch.nn.Module):
    """The gamma network models a monotonically increasing function. Construction as in the VDM paper."""

    def __init__(self):
        super().__init__()
        self.l1 = PositiveLinear(1, 1)
        self.l2 = PositiveLinear(1, 1024)
        self.l3 = PositiveLinear(1024, 1)

        # Learnable endpoints gamma(0) and gamma(1).
        self.gamma_0 = torch.nn.Parameter(torch.tensor([-5.]))
        self.gamma_1 = torch.nn.Parameter(torch.tensor([10.]))
        self.show_schedule()

    def show_schedule(self, num_steps=50):
        t = torch.linspace(0, 1, num_steps).view(num_steps, 1)
        gamma = self.forward(t)
        print('Gamma schedule:')
        print(gamma.detach().cpu().numpy().reshape(num_steps))

    def gamma_tilde(self, t):
        # Monotone in t: all PositiveLinear weights are positive and sigmoid
        # is increasing, so the composition is increasing as well.
        l1_t = self.l1(t)
        return l1_t + self.l3(torch.sigmoid(self.l2(l1_t)))

    def forward(self, t):
        zeros, ones = torch.zeros_like(t), torch.ones_like(t)
        # Not super efficient.
        gamma_tilde_0 = self.gamma_tilde(zeros)
        gamma_tilde_1 = self.gamma_tilde(ones)
        gamma_tilde_t = self.gamma_tilde(t)

        # Normalize to [0, 1].
        normalized_gamma = (gamma_tilde_t - gamma_tilde_0) / (
                gamma_tilde_1 - gamma_tilde_0)

        # Rescale to [gamma_0, gamma_1].
        gamma = self.gamma_0 + (self.gamma_1 - self.gamma_0) * normalized_gamma
        return gamma
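

# Usage sketch (illustrative): by construction the network is pinned to
# gamma_0 at t = 0 and gamma_1 at t = 1, and is monotone in between.
if __name__ == '__main__':
    net = GammaNetwork()                           # prints its initial schedule
    t = torch.linspace(0, 1, 5).view(-1, 1)
    g = net(t)
    assert torch.allclose(g[0], net.gamma_0) and torch.allclose(g[-1], net.gamma_1)
    assert bool((g[1:] >= g[:-1]).all())           # non-decreasing in t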