# diffsingerkr/Modules/Modules.py
from argparse import Namespace
import torch
import math
from typing import Union
from .Layer import Conv1d, LayerNorm, LinearAttention
from .Diffusion import Diffusion
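# DiffSinger acoustic model. The Encoder embeds tokens, notes, durations, genre,
# and singer, runs a stack of FFT blocks, and emits a one-shot linear feature
# prediction; DiffSinger concatenates that prediction with the encodings and
# conditions the Diffusion decoder on it, so the diffusion stage effectively
# refines the coarse linear prediction into the final features.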
class DiffSinger(torch.nn.Module):
def __init__(self, hyper_parameters: Namespace):
super().__init__()
self.hp = hyper_parameters
self.encoder = Encoder(self.hp)
self.diffusion = Diffusion(self.hp)
def forward(
self,
tokens: torch.LongTensor,
notes: torch.LongTensor,
durations: torch.LongTensor,
lengths: torch.LongTensor,
genres: torch.LongTensor,
singers: torch.LongTensor,
features: Union[torch.FloatTensor, None]= None,
ddim_steps: Union[int, None]= None
):
encodings, linear_predictions = self.encoder(
tokens= tokens,
notes= notes,
durations= durations,
lengths= lengths,
genres= genres,
singers= singers
) # [Batch, Enc_d, Feature_t]
encodings = torch.cat([encodings, linear_predictions], dim= 1) # [Batch, Enc_d + Feature_d, Feature_t]
        # Full diffusion pass when ground-truth features are given (training) or the
        # full step count is requested; otherwise accelerated DDIM sampling.
        if features is not None or ddim_steps is None or ddim_steps == self.hp.Diffusion.Max_Step:
diffusion_predictions, noises, epsilons = self.diffusion(
encodings= encodings,
features= features,
)
else:
noises, epsilons = None, None
diffusion_predictions = self.diffusion.DDIM(
encodings= encodings,
ddim_steps= ddim_steps
)
return linear_predictions, diffusion_predictions, noises, epsilons
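# A minimal usage sketch (illustrative shapes; `hp` is a Namespace carrying the
# hyper-parameters referenced above, e.g. hp.Tokens, hp.Encoder.Size):
#
#     model = DiffSinger(hp)
#     linear_predictions, diffusion_predictions, noises, epsilons = model(
#         tokens= torch.randint(0, hp.Tokens, (2, 100)),        # [Batch, Enc_t]
#         notes= torch.randint(0, hp.Notes, (2, 100)),          # [Batch, Enc_t]
#         durations= torch.randint(0, hp.Durations, (2, 100)),  # [Batch, Enc_t]
#         lengths= torch.tensor([100, 80]),                     # [Batch]
#         genres= torch.randint(0, hp.Genres, (2,)),            # [Batch]
#         singers= torch.randint(0, hp.Singers, (2,)),          # [Batch]
#         ddim_steps= 50                                        # accelerated DDIM sampling
#     )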
class Encoder(torch.nn.Module):
def __init__(
self,
hyper_parameters: Namespace
):
super().__init__()
self.hp = hyper_parameters
        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1
        else:
            raise ValueError(f'Unsupported feature type: {self.hp.Feature_Type}')
self.token_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Tokens,
embedding_dim= self.hp.Encoder.Size
)
self.note_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Notes,
embedding_dim= self.hp.Encoder.Size
)
self.duration_embedding = Duration_Positional_Encoding(
num_embeddings= self.hp.Durations,
embedding_dim= self.hp.Encoder.Size
)
self.genre_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Genres,
embedding_dim= self.hp.Encoder.Size,
)
self.singer_embedding = torch.nn.Embedding(
num_embeddings= self.hp.Singers,
embedding_dim= self.hp.Encoder.Size,
)
torch.nn.init.xavier_uniform_(self.token_embedding.weight)
torch.nn.init.xavier_uniform_(self.note_embedding.weight)
torch.nn.init.xavier_uniform_(self.genre_embedding.weight)
torch.nn.init.xavier_uniform_(self.singer_embedding.weight)
self.fft_blocks = torch.nn.ModuleList([
FFT_Block(
channels= self.hp.Encoder.Size,
num_head= self.hp.Encoder.ConvFFT.Head,
ffn_kernel_size= self.hp.Encoder.ConvFFT.FFN.Kernel_Size,
dropout_rate= self.hp.Encoder.ConvFFT.Dropout_Rate
)
for _ in range(self.hp.Encoder.ConvFFT.Stack)
])
self.linear_projection = Conv1d(
in_channels= self.hp.Encoder.Size,
out_channels= self.feature_size,
kernel_size= 1,
bias= True,
w_init_gain= 'linear'
)
def forward(
self,
tokens: torch.Tensor,
notes: torch.Tensor,
durations: torch.Tensor,
lengths: torch.Tensor,
genres: torch.Tensor,
singers: torch.Tensor
):
        x = \
            self.token_embedding(tokens) + \
            self.note_embedding(notes) + \
            self.duration_embedding(durations) + \
            self.genre_embedding(genres).unsqueeze(1) + \
            self.singer_embedding(singers).unsqueeze(1)    # genre/singer broadcast over time; [Batch, Enc_t, Enc_d]
x = x.permute(0, 2, 1) # [Batch, Enc_d, Enc_t]
for block in self.fft_blocks:
x = block(x, lengths) # [Batch, Enc_d, Enc_t]
linear_predictions = self.linear_projection(x) # [Batch, Feature_d, Enc_t]
return x, linear_predictions
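# FFT ('feed-forward Transformer') block in the FastSpeech style: self-attention
# (here a linear-complexity variant) followed by a convolutional FFN, with
# padded frames masked out of the FFN and the block output.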
class FFT_Block(torch.nn.Module):
def __init__(
self,
channels: int,
num_head: int,
ffn_kernel_size: int,
dropout_rate: float= 0.1,
) -> None:
super().__init__()
self.attention = LinearAttention(
channels= channels,
calc_channels= channels,
num_heads= num_head,
dropout_rate= dropout_rate
)
self.ffn = FFN(
channels= channels,
kernel_size= ffn_kernel_size,
dropout_rate= dropout_rate
)
def forward(
self,
x: torch.Tensor,
lengths: torch.Tensor
) -> torch.Tensor:
'''
x: [Batch, Dim, Time]
'''
        # Float mask: 1 for valid frames, 0 for padding. torch.ones_like(x[0, 0]).sum()
        # equals x.size(2), kept as a tensor rather than a Python int.
        masks = (~Mask_Generate(lengths= lengths, max_length= torch.ones_like(x[0, 0]).sum())).unsqueeze(1).float()    # [Batch, 1, Time]
        # Attention + Dropout + LayerNorm
        x = self.attention(x)
        # FFN + Dropout + LayerNorm
        x = self.ffn(x, masks)
        return x * masks
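# Position-wise feed-forward module: two same-padded 1D convolutions with ReLU
# and dropout, a residual connection, and LayerNorm.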
class FFN(torch.nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
dropout_rate: float= 0.1,
) -> None:
super().__init__()
self.conv_0 = Conv1d(
in_channels= channels,
out_channels= channels,
kernel_size= kernel_size,
padding= (kernel_size - 1) // 2,
w_init_gain= 'relu'
)
self.relu = torch.nn.ReLU()
self.dropout = torch.nn.Dropout(p= dropout_rate)
self.conv_1 = Conv1d(
in_channels= channels,
out_channels= channels,
kernel_size= kernel_size,
padding= (kernel_size - 1) // 2,
w_init_gain= 'linear'
)
self.norm = LayerNorm(
num_features= channels,
)
def forward(
self,
x: torch.Tensor,
masks: torch.Tensor
) -> torch.Tensor:
'''
x: [Batch, Dim, Time]
'''
        residuals = x
        x = self.conv_0(x * masks)    # mask before conv so padding cannot leak into valid frames
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv_1(x * masks)
        x = self.dropout(x)
        x = self.norm(x + residuals)    # residual connection + LayerNorm
        return x * masks
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
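# Sinusoidal positional encoding indexed by note duration, stored as a frozen
# Embedding table and scaled by a learnable alpha (initialized to 0.01):
#     PE(pos, 2i)   = sin(pos / 10000^(2i / embedding_dim))
#     PE(pos, 2i+1) = cos(pos / 10000^(2i / embedding_dim))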
class Duration_Positional_Encoding(torch.nn.Embedding):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
):
positional_embedding = torch.zeros(num_embeddings, embedding_dim)
position = torch.arange(0, num_embeddings, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
positional_embedding[:, 0::2] = torch.sin(position * div_term)
positional_embedding[:, 1::2] = torch.cos(position * div_term)
super().__init__(
num_embeddings= num_embeddings,
embedding_dim= embedding_dim,
_weight= positional_embedding
)
self.weight.requires_grad = False
self.alpha = torch.nn.Parameter(
data= torch.ones(1) * 0.01,
requires_grad= True
)
    def forward(self, durations):
        '''
        durations: [Batch, Length]
        returns: [Batch, Length, Dim]
        '''
        return self.alpha * super().forward(durations)    # learnable scale on the frozen sinusoidal table
@torch.jit.script
def get_pe(x: torch.Tensor, pe: torch.Tensor):
    '''
    Tile the positional-encoding table pe along the time axis until it covers x,
    then crop to x's length. (Not referenced elsewhere in this module.)
    '''
    pe = pe.repeat(1, 1, math.ceil(x.size(2) / pe.size(2)))
    return pe[:, :, :x.size(2)]
def Mask_Generate(lengths: torch.Tensor, max_length: Union[torch.Tensor, int, None]= None):
    '''
    lengths: [Batch]
    max_length: an int or scalar tensor. If None, max_length == max(lengths)
    '''
    max_length = max_length or torch.max(lengths)
    sequence = torch.arange(max_length)[None, :].to(lengths.device)
    return sequence >= lengths[:, None]    # [Batch, Time], True at padded positions
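if __name__ == '__main__':
    # Minimal smoke test for Mask_Generate; values are illustrative.
    example_lengths = torch.tensor([3, 5])
    print(Mask_Generate(lengths= example_lengths, max_length= 5))
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])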