grad-svc / grad /solver.py
maxmax20160403's picture
Upload 39 files
3aa4060
raw
history blame
No virus
7.4 kB
import torch
class NoiseScheduleVP:
    """First-order DPM-Solver-style sampler for a VP SDE with a linear beta.

    The noise rate is ``beta(t) = beta_min + (beta_max - beta_min) * t`` and
    ``log(alpha_t) = -0.5 * integral_0^t beta(u) du``. Sampling runs from
    ``t = self.T`` (1.0) down to ``t = 1e-3``.
    """

    def __init__(self, beta_min=0.05, beta_max=20):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = 1.

    def get_noise(self, t, beta_init, beta_term, cumulative=False):
        """Return beta(t), or its integral from 0 to t when ``cumulative``."""
        if cumulative:
            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
        else:
            noise = beta_init + (beta_term - beta_init)*t
        return noise

    def marginal_log_mean_coeff(self, t):
        """Return log(alpha_t) = -0.5 * integral_0^t beta(u) du."""
        return -0.25 * t**2 * (self.beta_max -
                               self.beta_min) - 0.5 * t * self.beta_min

    def marginal_std(self, t):
        """Return sigma_t = sqrt(1 - alpha_t**2)."""
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """Return lambda_t = log(alpha_t) - log(sigma_t) (half log-SNR)."""
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """Invert ``marginal_lambda`` in closed form (linear-beta schedule).

        Solves the quadratic in t obtained from
        ``log(1 + exp(-2 * lamb)) = integral_0^t beta(u) du``.
        """
        tmp = 2. * (self.beta_max - self.beta_min) * torch.logaddexp(
            -2. * lamb,
            torch.zeros((1, )).to(lamb))
        Delta = self.beta_min**2 + tmp
        return tmp / (torch.sqrt(Delta) + self.beta_min) / (self.beta_max -
                                                             self.beta_min)

    def get_time_steps(self, t_T, t_0, N):
        """Return N + 1 times from t_T to t_0, uniformly spaced in lambda."""
        lambda_T = self.marginal_lambda(torch.tensor(t_T))
        lambda_0 = self.marginal_lambda(torch.tensor(t_0))
        logSNR_steps = torch.linspace(lambda_T, lambda_0, N + 1)
        return self.inverse_lambda(logSNR_steps)

    @torch.no_grad()
    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc):
        """Integrate the reverse process with a first-order exponential step.

        Works on ``y = x - mu``. Each step from s to t applies
        ``y_t = (alpha_t / alpha_s) * y_s
                + sigma_t * (exp(h) - 1) * sqrt(lt) * estimator(...)``,
        where ``h = lambda_t - lambda_s`` and
        ``lt = 1 - exp(-integral_0^s beta)``.

        Args:
            estimator: called as ``estimator(spk, x, mask, mu, s)`` with
                ``s`` of shape (B,).
            spk: passed through to ``estimator`` unchanged.
            z: starting sample; its first axis is the batch axis B.
            mask: multiplied into ``z`` once at the start.
            mu: offset the process is centred on.
            n_timesteps: number of solver steps.
            stoc: accepted for interface parity with the other samplers;
                unused (the update is deterministic).

        Returns:
            The final ``y + mu``.
        """
        print("use dpm-solver reverse")
        xt = z * mask
        yt = xt - mu
        eps = 1e-3  # end time of the integration
        time = self.get_time_steps(self.T, eps, n_timesteps)
        batch = xt.shape[0]
        # Per-sample coefficients are (B,). Reshape them to (B, 1, ..., 1) so
        # they broadcast along the batch axis of yt rather than its last axis.
        coeff_shape = (-1,) + (1,) * (yt.dim() - 1)
        for i in range(n_timesteps):
            s = torch.ones((batch, )).to(xt.device) * time[i]
            t = torch.ones((batch, )).to(xt.device) * time[i + 1]
            lambda_s = self.marginal_lambda(s)
            lambda_t = self.marginal_lambda(t)
            h = lambda_t - lambda_s
            log_alpha_s = self.marginal_log_mean_coeff(s)
            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            phi_1 = torch.expm1(h)
            noise_s = estimator(spk, yt + mu, mask, mu, s)
            # Analytically lt == marginal_std(s)**2.
            lt = 1 - torch.exp(-self.get_noise(s, self.beta_min, self.beta_max, cumulative=True))
            a = torch.exp(log_alpha_t - log_alpha_s).reshape(coeff_shape)
            b = (sigma_t * phi_1 * torch.sqrt(lt)).reshape(coeff_shape)
            yt = a * yt + (b * noise_s)
        xt = yt + mu
        return xt
class MaxLikelihood:
    """Stochastic reverse sampler with maximum-likelihood step coefficients.

    Uses the linear schedule ``beta(t) = beta_min + (beta_max - beta_min) * t``
    and ``gamma(s, t) = exp(-0.5 * integral_s^t beta(u) du)``. Presumably this
    follows the ML-SDE solver of Popov et al. (Diff-VC); verify against the
    paper before changing coefficients.
    """

    def __init__(self, beta_min=0.05, beta_max=20):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def get_noise(self, t, beta_init, beta_term, cumulative=False):
        """Return beta(t), or its integral from 0 to t when ``cumulative``."""
        if cumulative:
            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
        else:
            noise = beta_init + (beta_term - beta_init)*t
        return noise

    def get_gamma(self, s, t, beta_init, beta_term):
        """Return exp(-0.5 * integral_s^t beta(u) du) for the linear schedule."""
        gamma = beta_init*(t-s) + 0.5*(beta_term-beta_init)*(t**2-s**2)
        gamma = torch.exp(-0.5*gamma)
        return gamma

    def get_mu(self, s, t):
        """Return gamma(s,t) * (1 - gamma(0,s)^2) / (1 - gamma(0,t)^2)."""
        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)
        mu = gamma_s_t * ((1-gamma_0_s**2) / (1-gamma_0_t**2))
        return mu

    def get_nu(self, s, t):
        """Return gamma(0,s) * (1 - gamma(s,t)^2) / (1 - gamma(0,t)^2)."""
        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)
        nu = gamma_0_s * ((1-gamma_s_t**2) / (1-gamma_0_t**2))
        return nu

    def get_sigma(self, s, t):
        """Return sqrt((1 - gamma(0,s)^2)(1 - gamma(s,t)^2) / (1 - gamma(0,t)^2)).

        This is zero when s == 0, so the final step adds no noise.
        """
        gamma_0_s = self.get_gamma(0, s, self.beta_min, self.beta_max)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        gamma_s_t = self.get_gamma(s, t, self.beta_min, self.beta_max)
        sigma = torch.sqrt(((1 - gamma_0_s**2) * (1 - gamma_s_t**2)) / (1 - gamma_0_t**2))
        return sigma

    def get_kappa(self, t, h, noise):
        """Return nu(t-h,t) * (1 - gamma(0,t)^2) / (gamma(0,t) * beta_t * h) - 1."""
        nu = self.get_nu(t-h, t)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        kappa = (nu*(1-gamma_0_t**2)/(gamma_0_t*noise*h) - 1)
        return kappa

    def get_omega(self, t, h, noise):
        """Return (mu(t-h,t) - 1)/(beta_t h) + (1 + kappa)/(1 - gamma(0,t)^2) - 1/2."""
        mu = self.get_mu(t-h, t)
        kappa = self.get_kappa(t, h, noise)
        gamma_0_t = self.get_gamma(0, t, self.beta_min, self.beta_max)
        omega = (mu-1)/(noise*h) + (1+kappa)/(1-gamma_0_t**2) - 0.5
        return omega

    @torch.no_grad()
    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc=False):
        """Run ``n_timesteps`` uniform steps from t = 1 down to t = 0.

        Each step computes
        ``x += ((0.5 + omega)(x - mu) + (1 + kappa) * est) * beta_t * h
               + sigma * N(0, I)``
        and re-applies ``mask``.

        Args:
            estimator: called as ``estimator(spk, x, mask, mu, t)`` with ``t``
                of shape (B,).
            spk: passed through to ``estimator`` unchanged.
            z: starting sample; assumed 3-D (B, C, L), as the (B, 1, 1) time
                tensor implies (TODO confirm for other ranks).
            mask: multiplied into the state after every step.
            mu: offset the process is centred on.
            n_timesteps: number of solver steps.
            stoc: accepted for interface parity with the other samplers but
                ignored; noise scaled by ``get_sigma`` is always added.

        Returns:
            The final masked state.
        """
        print("use MaxLikelihood reverse")
        h = 1.0 / n_timesteps
        xt = z * mask
        for i in range(n_timesteps):
            t = (1.0 - i*h) * torch.ones(z.shape[0], dtype=z.dtype,
                                         device=z.device)
            # Evaluate every per-sample coefficient at shape (B, 1, 1) so it
            # broadcasts along the batch axis of xt. Mixing (B,) and (B, 1, 1)
            # here yields (B, 1, B) and mis-broadcasts for B > 1.
            time = t.unsqueeze(-1).unsqueeze(-1)
            noise_t = self.get_noise(time, self.beta_min, self.beta_max,
                                     cumulative=False)
            kappa_t_h = self.get_kappa(time, h, noise_t)
            omega_t_h = self.get_omega(time, h, noise_t)
            sigma_t_h = self.get_sigma(time - h, time)
            es = estimator(spk, xt, mask, mu, t)
            dxt = ((0.5+omega_t_h)*(xt - mu) + (1+kappa_t_h) * es)
            dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                   requires_grad=False)
            dxt_stoc = dxt_stoc * sigma_t_h
            dxt = dxt * noise_t * h + dxt_stoc
            xt = (xt + dxt) * mask
        return xt
class GradRaw:
    """Original Grad-TTS fixed-step reverse sampler (Euler / Euler-Maruyama).

    The noise rate is linear in time:
    ``beta(t) = beta_min + (beta_max - beta_min) * t``.
    """

    def __init__(self, beta_min=0.05, beta_max=20):
        self.beta_min, self.beta_max = beta_min, beta_max

    def get_noise(self, t, beta_init, beta_term, cumulative=False):
        """Give beta(t); with ``cumulative`` give its integral over [0, t]."""
        slope = beta_term - beta_init
        if not cumulative:
            return beta_init + slope*t
        return beta_init*t + 0.5*slope*(t**2)

    @torch.no_grad()
    def reverse_diffusion(self, estimator, spk, z, mask, mu, n_timesteps, stoc=False):
        """Step from t = 1 to t = 0 using midpoint times ``1 - (i + 0.5) / n``.

        With ``stoc`` the update is
        ``(0.5 (mu - x) - est) * beta * h + sqrt(beta * h) * N(0, I)``.
        Otherwise it is the deterministic
        ``0.5 (mu - x - est) * beta * h``.
        Either way the update is subtracted from x, and ``mask`` is
        re-applied after each step.
        """
        print("use grad-raw reverse")
        step = 1.0 / n_timesteps
        xt = z * mask
        unit = torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
        for k in range(n_timesteps):
            t = (1.0 - (k + 0.5)*step) * unit
            # (B,) -> (B, 1, 1) so beta broadcasts over the feature/time axes.
            beta_t = self.get_noise(t[:, None, None], self.beta_min, self.beta_max,
                                    cumulative=False)
            if not stoc:
                drift = 0.5 * (mu - xt - estimator(spk, xt, mask, mu, t))
                update = drift * beta_t * step
            else:
                drift = 0.5 * (mu - xt) - estimator(spk, xt, mask, mu, t)
                brownian = torch.randn(z.shape, dtype=z.dtype, device=z.device,
                                       requires_grad=False)
                update = drift * beta_t * step + brownian * torch.sqrt(beta_t * step)
            xt = (xt - update) * mask
        return xt