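# diffsingerkr / Modules / Diffusion.py
# Committed by codejin in 67d041f ("initial commit"); 14 kB.
# (Header recovered from a file-viewer page; its "raw", "history", "blame" and "No virus" labels were UI text.)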
import torch
import math
from argparse import Namespace
from typing import Optional, List, Dict, Union
from tqdm import tqdm
from .Layer import Conv1d, Lambda
class Diffusion(torch.nn.Module):
    '''
    Gaussian diffusion over acoustic features, conditioned on encodings.

    The beta schedule is linear from 1e-4 to 0.06 over `hp.Diffusion.Max_Step`
    steps. All derived coefficients are registered as buffers, so they move
    with the module across devices and are stored in the state dict. The
    `Denoiser` is trained to predict the noise that was added to the features.
    '''
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            raise ValueError(f'Unsupported Feature_Type "{self.hp.Feature_Type}"; expected "Mel" or "Spectrogram".')

        self.denoiser = Denoiser(
            hyper_parameters= self.hp
            )

        self.timesteps = self.hp.Diffusion.Max_Step
        betas = torch.linspace(1e-4, 0.06, self.timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim= 0)
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('alphas_cumprod', alphas_cumprod)  # [Diffusion_t]
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)  # [Diffusion_t]
        self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recip_alphas_cumprod', (1.0 / alphas_cumprod).sqrt())
        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1.0 / alphas_cumprod - 1.0).sqrt())

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # The posterior variance is 0 at t = 0, so clip it before taking the log.
        self.register_buffer('posterior_log_variance', torch.maximum(posterior_variance, torch.tensor([1e-20])).log())
        self.register_buffer('posterior_mean_coef1', betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod))

    def forward(
        self,
        encodings: torch.Tensor,
        features: Optional[torch.Tensor]= None
        ):
        '''
        Run one training step when `features` is given; otherwise run full DDPM sampling.

        encodings: [Batch, Enc_d, Enc_t]
        features: [Batch, Feature_d, Feature_t] or None

        Returns (features, noises, epsilons):
            train: (None, the sampled noise, the denoiser's noise prediction)
            inference: (sampled features [Batch, feature_size, Enc_t], None, None)
        '''
        if features is not None:    # train
            diffusion_steps = torch.randint(
                low= 0,
                high= self.timesteps,
                size= (encodings.size(0),),
                dtype= torch.long,
                device= encodings.device
                )   # one random step per sample
            noises, epsilons = self.Get_Noise_Epsilon_for_Train(
                features= features,
                encodings= encodings,
                diffusion_steps= diffusion_steps,
                )
            return None, noises, epsilons

        # inference
        features = self.Sampling(
            encodings= encodings,
            )
        return features, None, None

    def Sampling(
        self,
        encodings: torch.Tensor,
        ):
        '''
        Ancestral DDPM sampling: start from Gaussian noise and apply
        `P_Sampling` for every step from `timesteps - 1` down to 0.
        '''
        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )
        for diffusion_step in reversed(range(self.timesteps)):
            features = self.P_Sampling(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= diffusion_step,
                    dtype= torch.long,
                    device= encodings.device
                    ),
                )

        return features

    def P_Sampling(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        '''
        Draw x_{t-1} from the posterior q(x_{t-1} | x_t, x_0_hat).

        diffusion_steps: [Batch] integer steps t.
        '''
        posterior_means, posterior_log_variances = self.Get_Posterior(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps,
            )

        noises = torch.randn_like(features)  # [Batch, Feature_d, Feature_t]
        # No noise is added for samples whose step is 0 (the final step).
        masks = (diffusion_steps > 0).float().unsqueeze(1).unsqueeze(1)  # [Batch, 1, 1]
        return posterior_means + masks * (0.5 * posterior_log_variances).exp() * noises

    def Get_Posterior(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        Compute the mean and log-variance of q(x_{t-1} | x_t, x_0_hat).

        x_0_hat is reconstructed from x_t and the denoiser's noise prediction,
        then clipped to [-1, 1].
        Returns (posterior_means [Batch, Feature_d, Feature_t], posterior_log_variances [Batch, 1, 1]).
        '''
        noised_predictions = self.denoiser(
            features= features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        # x_0_hat = x_t / sqrt(a_bar) - eps * sqrt(1 / a_bar - 1)
        feature_starts = \
            features * self.sqrt_recip_alphas_cumprod[diffusion_steps][:, None, None] - \
            noised_predictions * self.sqrt_recipm1_alphas_cumprod[diffusion_steps][:, None, None]
        # NOTE(review): clipping assumes features are normalised to [-1, 1] — confirm.
        feature_starts.clamp_(-1.0, 1.0)

        posterior_means = \
            feature_starts * self.posterior_mean_coef1[diffusion_steps][:, None, None] + \
            features * self.posterior_mean_coef2[diffusion_steps][:, None, None]
        posterior_log_variances = \
            self.posterior_log_variance[diffusion_steps][:, None, None]

        return posterior_means, posterior_log_variances

    def Get_Noise_Epsilon_for_Train(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor,
        ):
        '''
        Noise `features` to the given steps with q(x_t | x_0).

        Returns (sampled noise, denoiser prediction of that noise), both shaped like `features`.
        '''
        noises = torch.randn_like(features)

        noised_features = \
            features * self.sqrt_alphas_cumprod[diffusion_steps][:, None, None] + \
            noises * self.sqrt_one_minus_alphas_cumprod[diffusion_steps][:, None, None]

        epsilons = self.denoiser(
            features= noised_features,
            encodings= encodings,
            diffusion_steps= diffusion_steps
            )

        return noises, epsilons

    def DDIM(
        self,
        encodings: torch.Tensor,
        ddim_steps: int,
        eta: float= 0.0,
        temperature: float= 1.0,
        use_tqdm: bool= False
        ):
        '''
        DDIM sampling (Song et al., 2021, Eq. 12) over a sub-sequence of steps.

        encodings: [Batch, Enc_d, Enc_t]
        ddim_steps: requested number of sampling steps; see `Get_DDIM_Steps`
            for the exact sub-sequence used (it may contain more entries).
        eta: 0 gives deterministic DDIM; 1 gives DDPM-like stochasticity.
        temperature: scales the stochastic noise term only.
        use_tqdm: wrap the step loop in a tqdm progress bar.
        Returns sampled features [Batch, feature_size, Enc_t].
        '''
        ddim_timesteps = self.Get_DDIM_Steps(
            ddim_steps= ddim_steps
            )
        sigmas, alphas, alphas_prev = self.Get_DDIM_Sampling_Parameters(
            ddim_timesteps= ddim_timesteps,
            eta= eta
            )
        sqrt_one_minus_alphas = (1.0 - alphas).sqrt()

        features = torch.randn(
            size= (encodings.size(0), self.feature_size, encodings.size(2)),
            device= encodings.device
            )

        # Iterate over every scheduled step, not `ddim_steps`; the two differ
        # when the number of timesteps is not divisible by `ddim_steps`.
        num_steps = len(ddim_timesteps)
        step_indices = reversed(range(num_steps))
        if use_tqdm:
            step_indices = tqdm(
                step_indices,
                desc= '[Diffusion]',
                total= num_steps
                )

        for index in step_indices:
            # The denoiser is conditioned on the real diffusion timestep, not the DDIM index.
            noised_predictions = self.denoiser(
                features= features,
                encodings= encodings,
                diffusion_steps= torch.full(
                    size= (encodings.size(0), ),
                    fill_value= int(ddim_timesteps[index]),
                    dtype= torch.long,
                    device= encodings.device
                    )
                )
            feature_starts = (features - sqrt_one_minus_alphas[index] * noised_predictions) / alphas[index].sqrt()
            # sqrt(1 - a_prev - sigma^2) * eps; clamp guards against tiny negative round-off.
            direction_pointings = (1.0 - alphas_prev[index] - sigmas[index].pow(2.0)).clamp(min= 0.0).sqrt() * noised_predictions
            noises = sigmas[index] * torch.randn_like(features) * temperature
            features = alphas_prev[index].sqrt() * feature_starts + direction_pointings + noises

        return features

    # https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
    def Get_DDIM_Steps(
        self,
        ddim_steps: int,
        ddim_discr_method: str= 'uniform'
        ):
        '''
        Build the ascending sub-sequence of diffusion timesteps used by DDIM.

        'uniform' uses a stride of `timesteps // ddim_steps` (at least 1), so
        it can contain more than `ddim_steps` entries when the division is not
        exact. 'quad' spaces steps quadratically. In both cases the last entry
        is forced to `timesteps - 1`.
        Raises ValueError if ddim_steps < 1, and NotImplementedError for unknown methods.
        '''
        if ddim_steps < 1:
            raise ValueError(f'ddim_steps must be a positive integer, got {ddim_steps}')

        if ddim_discr_method == 'uniform':
            stride = max(self.timesteps // ddim_steps, 1)
            ddim_timesteps = torch.arange(0, self.timesteps, stride).long()
        elif ddim_discr_method == 'quad':
            ddim_timesteps = torch.linspace(0, math.sqrt(self.timesteps * 0.8), ddim_steps).pow(2.0).long()
        else:
            raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

        ddim_timesteps[-1] = self.timesteps - 1
        return ddim_timesteps

    def Get_DDIM_Sampling_Parameters(self, ddim_timesteps, eta):
        '''
        Return (sigmas, alphas, alphas_prev) for each entry of `ddim_timesteps`.

        alphas[i] is a_bar at ddim_timesteps[i]. alphas_prev[i] is a_bar at the
        PREVIOUS DDIM step ddim_timesteps[i - 1]; for i == 0 it is
        alphas_cumprod_prev[ddim_timesteps[0]], which is 1.0 when the sequence
        starts at step 0.
        '''
        alphas = self.alphas_cumprod[ddim_timesteps]
        alphas_prev = torch.cat([
            self.alphas_cumprod_prev[ddim_timesteps[:1]],
            self.alphas_cumprod[ddim_timesteps[:-1]]
            ])
        sigmas = eta * ((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)).sqrt()
        return sigmas, alphas, alphas_prev
class Denoiser(torch.nn.Module):
    '''
    Noise-prediction network used by `Diffusion`.

    Pipeline: a 1x1 prenet on the noisy features; a sinusoidal step embedding
    followed by a 2-layer 1x1 FFN; `hp.Diffusion.Stack` gated residual blocks
    whose skip outputs are summed and scaled by 1/sqrt(Stack); and a 1x1
    projection back to the feature size.
    '''
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            raise ValueError(f'Unsupported Feature_Type "{self.hp.Feature_Type}"; expected "Mel" or "Spectrogram".')

        # NOTE: submodule names and creation order define the state-dict keys; keep them stable.
        self.prenet = torch.nn.Sequential(
            Conv1d(
                in_channels= feature_size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish()
            )

        self.step_ffn = torch.nn.Sequential(
            Diffusion_Embedding(
                channels= self.hp.Diffusion.Size
                ),
            Lambda(lambda x: x.unsqueeze(2)),   # [Batch, Size] -> [Batch, Size, 1]
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size * 4,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.Mish(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size * 4,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'linear'
                )
            )

        self.residual_blocks = torch.nn.ModuleList([
            Residual_Block(
                in_channels= self.hp.Diffusion.Size,
                kernel_size= self.hp.Diffusion.Kernel_Size,
                condition_channels= self.hp.Encoder.Size + feature_size
                )
            for _ in range(self.hp.Diffusion.Stack)
            ])

        self.projection = torch.nn.Sequential(
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= self.hp.Diffusion.Size,
                kernel_size= 1,
                w_init_gain= 'relu'
                ),
            torch.nn.ReLU(),
            Conv1d(
                in_channels= self.hp.Diffusion.Size,
                out_channels= feature_size,
                kernel_size= 1
                ),
            )
        # Zero the final projection weight so that, at initialisation, the output
        # does not depend on the input (the original author marks this as a key factor).
        torch.nn.init.zeros_(self.projection[-1].weight)

    def forward(
        self,
        features: torch.Tensor,
        encodings: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        features: [Batch, Feature_d, Feature_t]
        encodings: [Batch, Enc_d, Feature_t], where Enc_d must equal
            hp.Encoder.Size + feature_size (the residual blocks' condition input width)
        diffusion_steps: [Batch]
        Returns the predicted noise, [Batch, Feature_d, Feature_t].
        '''
        x = self.prenet(features)
        diffusion_steps = self.step_ffn(diffusion_steps)  # [Batch, Res_d, 1]

        skips_list = []
        for residual_block in self.residual_blocks:
            x, skips = residual_block(
                x= x,
                conditions= encodings,
                diffusion_steps= diffusion_steps
                )
            skips_list.append(skips)

        # Sum the skip connections, normalised by sqrt(number of blocks).
        x = torch.stack(skips_list, dim= 0).sum(dim= 0) / math.sqrt(self.hp.Diffusion.Stack)
        x = self.projection(x)

        return x
class Diffusion_Embedding(torch.nn.Module):
    '''
    Sinusoidal embedding of diffusion steps.

    Uses channels // 2 frequencies, geometrically spaced from 1 to 1/10000.
    The output is [sin(x * f), cos(x * f)], zero-padded by one channel when
    `channels` is odd, so the result always has exactly `channels` features.
    '''
    def __init__(
        self,
        channels: int
        ):
        super().__init__()
        self.channels = channels

    def forward(self, x: torch.Tensor):
        '''
        x: [Batch] diffusion steps
        Returns [Batch, channels].
        '''
        half_channels = self.channels // 2  # sine and cosine
        # max(..., 1) avoids ZeroDivisionError when only one frequency is used (channels < 4).
        scale = math.log(10000.0) / max(half_channels - 1, 1)
        frequencies = torch.exp(torch.arange(half_channels, device= x.device) * -scale)
        embeddings = x.unsqueeze(1) * frequencies.unsqueeze(0)
        embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim= -1)
        if self.channels % 2 == 1:
            # Keep the output width equal to `channels` so downstream layers sized by it still match.
            embeddings = torch.nn.functional.pad(embeddings, (0, 1))
        return embeddings
class Residual_Block(torch.nn.Module):
    '''
    Gated residual block conditioned on encodings and a diffusion-step vector.

    The input plus a projected step bias goes through a convolution. Projected
    conditions are added, and the result is split into sigmoid/tanh halves and
    gated. A 1x1 projection then splits into a residual half (added to the
    input and scaled by 1/sqrt(2)) and a skip half.
    '''
    def __init__(
        self,
        in_channels: int,
        kernel_size: int,
        condition_channels: int
        ):
        super().__init__()
        self.in_channels = in_channels
        doubled_channels = in_channels * 2

        # Submodules are created in the original order, so parameter registration
        # order (and therefore state-dict layout) is unchanged.
        self.condition = Conv1d(
            in_channels= condition_channels,
            out_channels= doubled_channels,
            kernel_size= 1
            )
        self.diffusion_step = Conv1d(
            in_channels= in_channels,
            out_channels= in_channels,
            kernel_size= 1
            )
        self.conv = Conv1d(
            in_channels= in_channels,
            out_channels= doubled_channels,
            kernel_size= kernel_size,
            padding= kernel_size // 2
            )
        self.projection = Conv1d(
            in_channels= in_channels,
            out_channels= doubled_channels,
            kernel_size= 1
            )

    def forward(
        self,
        x: torch.Tensor,
        conditions: torch.Tensor,
        diffusion_steps: torch.Tensor
        ):
        '''
        x: [Batch, in_channels, Time] (assumed)
        conditions: [Batch, condition_channels, Time] (assumed)
        diffusion_steps: step features that broadcast against x, e.g. [Batch, in_channels, 1]
        Returns (residual output shaped like x, skip output).
        '''
        projected_conditions = self.condition(conditions)
        step_bias = self.diffusion_step(diffusion_steps)

        pre_activation = self.conv(x + step_bias) + projected_conditions
        gate_half, filter_half = pre_activation.chunk(chunks= 2, dim= 1)
        activated = gate_half.sigmoid() * filter_half.tanh()

        residual_half, skips = self.projection(activated).chunk(chunks= 2, dim= 1)
        return (residual_half + x) / math.sqrt(2.0), skips