import torch
import torch.nn as nn
import torch.nn.functional as F
# 2 ** 15, i.e. the magnitude of the most negative signed 16-bit integer.
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably used by callers to scale between float waveforms in [-1, 1] and
# int16 PCM samples -- confirm against the callers before relying on it.
MAX_WAV_VALUE = 32768.0
class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions.

    A small 1-D convolutional network that maps a conditioning sequence to one
    convolution kernel and one bias vector per conditioning frame, for each of
    ``conv_layers`` location-variable convolution layers.
    """

    def __init__(
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params=None,
    ):
        """
        Args:
            cond_channels (int): number of channels of the conditioning sequence.
            conv_in_channels (int): number of input channels of the predicted
                convolutions.
            conv_out_channels (int): number of output channels of the predicted
                convolutions.
            conv_layers (int): number of convolution layers to predict kernels
                and biases for.
            conv_kernel_size (int): kernel size of the predicted convolutions.
            kpnet_hidden_channels (int): hidden width of the predictor network.
            kpnet_conv_size (int): kernel size of the predictor's own
                convolutions (except the input convolution, which uses 5).
            kpnet_dropout (float): dropout probability at the start of every
                residual block.
            kpnet_nonlinear_activation (str): name of the activation class,
                looked up on ``torch.nn``.
            kpnet_nonlinear_activation_params (dict | None): keyword arguments
                for the activation class. ``None`` (the default) means
                ``{"negative_slope": 0.1}``.
        """
        super().__init__()

        # A ``None`` sentinel replaces the former mutable dict default so the
        # default object is never shared between instances.
        if kpnet_nonlinear_activation_params is None:
            kpnet_nonlinear_activation_params = {"negative_slope": 0.1}

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        # Channels needed to hold every kernel / bias coefficient of every
        # predicted layer for a single conditioning frame.
        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        kpnet_bias_channels = conv_out_channels * conv_layers

        def make_activation():
            return getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params)

        # NOTE: sub-modules are created in the same order as before so that
        # seeded parameter initialisation and state-dict keys are unchanged.
        self.input_conv = nn.Sequential(
            nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
            make_activation(),
        )

        self.residual_convs = nn.ModuleList()
        padding = (kpnet_conv_size - 1) // 2
        for _ in range(3):
            # Layout relied upon by remove_weight_norm():
            # [0] dropout, [1] conv, [2] activation, [3] conv, [4] activation.
            self.residual_convs.append(
                nn.Sequential(
                    nn.Dropout(kpnet_dropout),
                    nn.utils.weight_norm(
                        nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size,
                                  padding=padding, bias=True)),
                    make_activation(),
                    nn.utils.weight_norm(
                        nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size,
                                  padding=padding, bias=True)),
                    make_activation(),
                )
            )

        self.kernel_conv = nn.utils.weight_norm(
            nn.Conv1d(kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True))
        self.bias_conv = nn.utils.weight_norm(
            nn.Conv1d(kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True))

    def forward(self, c):
        """Predict the kernels and biases for a conditioning sequence.

        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            tuple[Tensor, Tensor]:
                kernels of shape (batch, conv_layers, conv_in_channels,
                conv_out_channels, conv_kernel_size, cond_length) and
                bias of shape (batch, conv_layers, conv_out_channels, cond_length).
        """
        batch, _, cond_length = c.shape
        c = self.input_conv(c)
        for residual_conv in self.residual_convs:
            # Kept from the original implementation: makes sure the block lives
            # on the activation's device (a no-op when it already does).
            residual_conv.to(c.device)
            c = c + residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)

        # Split the flat channel axis into (layer, in, out, tap) / (layer, out).
        kernels = k.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_in_channels,
            self.conv_out_channels,
            self.conv_kernel_size,
            cond_length,
        )
        bias = b.contiguous().view(
            batch,
            self.conv_layers,
            self.conv_out_channels,
            cond_length,
        )

        return kernels, bias

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrization from every convolution."""
        nn.utils.remove_weight_norm(self.input_conv[0])
        nn.utils.remove_weight_norm(self.kernel_conv)
        nn.utils.remove_weight_norm(self.bias_conv)
        for block in self.residual_convs:
            # Indices 1 and 3 are the two weight-normed convolutions of a block.
            nn.utils.remove_weight_norm(block[1])
            nn.utils.remove_weight_norm(block[3])
class LVCBlock(torch.nn.Module):
    """One upsampling stage built from location-variable convolutions (LVC).

    The input is first upsampled by ``stride`` with a transposed convolution.
    Then, for every dilation, a dilated convolution is followed by an LVC whose
    kernels are predicted from the conditioning sequence, and the result is
    added back through a gated (sigmoid * tanh) residual connection.
    """

    def __init__(
        self,
        in_channels,
        cond_channels,
        stride,
        dilations=(1, 3, 9, 27),
        lReLU_slope=0.2,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        """
        Args:
            in_channels (int): number of channels of the input/output sequence.
            cond_channels (int): number of channels of the conditioning sequence.
            stride (int): upsampling factor of the transposed convolution.
            dilations (Sequence[int]): dilation of each convolution layer; its
                length sets the number of layers. The default is an immutable
                tuple (any sequence, e.g. a list, is accepted).
            lReLU_slope (float): negative slope of every LeakyReLU.
            conv_kernel_size (int): kernel size of the dilated and of the
                location-variable convolutions.
            cond_hop_length (int): number of samples of the upsampled sequence
                covered by one conditioning frame.
            kpnet_hidden_channels (int): hidden width of the kernel predictor.
            kpnet_conv_size (int): kernel size inside the kernel predictor.
            kpnet_dropout (float): dropout probability inside the kernel predictor.
        """
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = len(dilations)
        self.conv_kernel_size = conv_kernel_size

        # The LVC produces 2 * in_channels so that the output can be split into
        # the sigmoid and tanh halves of the gate in forward().
        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=len(dilations),
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
        )

        # Kernel 2 * stride with this padding / output_padding yields an output
        # of exactly in_length * stride samples for both even and odd strides.
        self.convt_pre = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.weight_norm(
                nn.ConvTranspose1d(in_channels, in_channels, 2 * stride, stride=stride,
                                   padding=stride // 2 + stride % 2, output_padding=stride % 2)),
        )

        self.conv_blocks = nn.ModuleList()
        for dilation in dilations:
            # Layout relied upon by remove_weight_norm(): [1] is the convolution.
            self.conv_blocks.append(
                nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    nn.utils.weight_norm(
                        nn.Conv1d(in_channels, in_channels, conv_kernel_size,
                                  padding=dilation * (conv_kernel_size - 1) // 2, dilation=dilation)),
                    nn.LeakyReLU(lReLU_slope),
                )
            )

    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length);
                ``in_length * stride`` must equal ``cond_length * cond_hop_length``.

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length * stride)
        """
        _, in_channels, _ = x.shape

        x = self.convt_pre(x)  # upsample by `stride`
        kernels, bias = self.kernel_predictor(c)

        for i, conv in enumerate(self.conv_blocks):
            output = conv(x)

            # Kernels / biases predicted for the i-th layer.
            k = kernels[:, i, :, :, :, :]
            b = bias[:, i, :, :]

            # The dilation is applied by `conv` above; the LVC itself runs with
            # its default dilation of 1.
            output = self.location_variable_convolution(output, k, b,
                                                        hop_size=self.cond_hop_length)
            # Gated activation over the two channel halves, added residually.
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
                output[:, in_channels:, :])

        return x

    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
        """Perform the location-variable convolution of ``x`` with local kernels.

        Every block of ``hop_size`` consecutive samples is convolved with its
        own kernel and gets its own bias.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.

        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).

        Raises:
            AssertionError: if ``in_length != kernel_length * hop_size``.
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), 'constant', 0)
        # One window per kernel position:
        # (batch, in_channels, kernel_length, hop_size + 2 * padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), 'constant', 0)
        # Split every window into groups of `dilation` samples:
        # (batch, in_channels, kernel_length, groups, dilation)
        x = x.unfold(3, dilation,
                     dilation)
        x = x[:, :, :, :, :hop_size]
        # (batch, in_channels, kernel_length, dilation, groups)
        x = x.transpose(3, 4)
        # Sliding taps: (batch, in_channels, kernel_length, dilation, _, kernel_size)
        x = x.unfold(4, kernel_size, 1)

        # Contract input channels (i) and kernel taps (k) with the kernel that
        # belongs to each position (l).
        o = torch.einsum('bildsk,biokl->bolsd', x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        # Broadcast the per-position bias over the two trailing axes of `o`.
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)

        return o

    def remove_weight_norm(self):
        """Strip weight norm from the kernel predictor and all convolutions."""
        self.kernel_predictor.remove_weight_norm()
        nn.utils.remove_weight_norm(self.convt_pre[1])
        for block in self.conv_blocks:
            nn.utils.remove_weight_norm(block[1])
class UnivNetGenerator(nn.Module):
    """UnivNet Generator.

    Turns a noise sequence into a waveform, conditioned on a mel-spectrogram,
    through a stack of upsampling :class:`LVCBlock` stages.
    """

    def __init__(self, noise_dim=64, channel_size=32, dilations=(1, 3, 9, 27), strides=(8, 8, 4),
                 lReLU_slope=.2, kpnet_conv_size=3, hop_length=256, n_mel_channels=100):
        """
        Args:
            noise_dim (int): number of channels of the input noise sequence.
            channel_size (int): number of hidden channels in every stage.
            dilations (Sequence[int]): dilations used inside every LVC block.
            strides (Sequence[int]): upsampling factor of each LVC block; the
                generator upsamples by their product.
            lReLU_slope (float): negative slope of the LeakyReLU activations.
            kpnet_conv_size (int): kernel size inside the kernel predictors.
            hop_length (int): samples per mel frame; only used by
                :meth:`inference` to trim the padded tail.
            n_mel_channels (int): number of mel channels of the conditioning.
        """
        super().__init__()
        self.mel_channel = n_mel_channels
        self.noise_dim = noise_dim
        self.hop_length = hop_length

        # NOTE: sub-modules are created in the same order as before so that
        # seeded parameter initialisation and state-dict keys are unchanged.
        self.res_stack = nn.ModuleList()
        # Samples per conditioning frame after each stage: the running product
        # of the strides seen so far.
        cond_hop_length = 1
        for stride in strides:
            cond_hop_length = stride * cond_hop_length
            self.res_stack.append(
                LVCBlock(
                    channel_size,
                    n_mel_channels,
                    stride=stride,
                    dilations=dilations,
                    lReLU_slope=lReLU_slope,
                    cond_hop_length=cond_hop_length,
                    kpnet_conv_size=kpnet_conv_size,
                )
            )

        self.conv_pre = \
            nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode='reflect'))

        self.conv_post = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode='reflect')),
            nn.Tanh(),
        )

    def forward(self, c, z):
        """Generate a waveform from noise, conditioned on a mel-spectrogram.

        Args:
            c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
            z (Tensor): the noise sequence (batch, noise_dim, in_length)

        Returns:
            Tensor: the generated waveform (batch, 1, in_length * product of strides),
                squashed into [-1, 1] by the final tanh.
        """
        z = self.conv_pre(z)

        for res_block in self.res_stack:
            # Kept from the original implementation: makes sure the block lives
            # on the activation's device (a no-op when it already does).
            res_block.to(z.device)
            z = res_block(z, c)

        z = self.conv_post(z)

        return z

    def eval(self, inference=False):
        """Switch to evaluation mode, optionally stripping weight norm.

        Args:
            inference (bool): when True, also remove the weight-norm
                reparametrization from every layer. This is one-way: it is not
                restored by a later ``train()`` call.

        Returns:
            UnivNetGenerator: ``self``, like ``nn.Module.eval`` (the original
                override returned ``None``, which broke ``model = model.eval()``).
        """
        super().eval()

        if inference:
            self.remove_weight_norm()

        return self

    def remove_weight_norm(self):
        """Strip the weight-norm reparametrization from every layer."""
        print('Removing weight norm...')

        nn.utils.remove_weight_norm(self.conv_pre)

        for layer in self.conv_post:
            # Only the convolution carries state; the activations have an empty
            # state dict and no weight norm to remove.
            if layer.state_dict():
                nn.utils.remove_weight_norm(layer)

        for res_block in self.res_stack:
            res_block.remove_weight_norm()

    def inference(self, c, z=None):
        """Synthesize audio for a mel-spectrogram, sampling the noise if needed.

        The mel is extended with a few constant-valued frames, the waveform is
        generated, the samples belonging to the extra frames are cut off again
        and the result is clamped to [-1, 1]. Gradient tracking is not disabled
        here; wrap the call in ``torch.no_grad()`` if gradients are not needed.

        Args:
            c (Tensor): mel-spectrogram (batch, mel_channels, in_length).
            z (Tensor | None): noise (batch, noise_dim, in_length + 10); drawn
                from a standard normal distribution when omitted.

        Returns:
            Tensor: the waveform with the padded tail removed (batch, 1, samples).
        """
        pad_frames = 10
        # NOTE(review): -11.5129 ~= log(1e-5), presumably the log-mel value of
        # silence used by the feature extractor -- confirm against it.
        pad_value = -11.5129

        # Created directly on the target device instead of on the CPU followed
        # by a copy; dtype is the default dtype, as before.
        zero = torch.full((c.shape[0], self.mel_channel, pad_frames), pad_value, device=c.device)
        mel = torch.cat((c, zero), dim=2)

        if z is None:
            # Sampled on the CPU and then moved, so seeded runs keep producing
            # the same noise as before.
            z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)

        audio = self.forward(mel, z)
        # Drop the samples generated for the padding frames.
        audio = audio[:, :, :-(self.hop_length * pad_frames)]
        audio = audio.clamp(min=-1, max=1)
        return audio
if __name__ == '__main__':
    # Smoke test: build a generator with the default configuration and check
    # that 10 conditioning frames are upsampled to 10 * 256 = 2560 samples.
    model = UnivNetGenerator()

    c = torch.randn(3, 100, 10)  # (batch, n_mel_channels, frames)
    z = torch.randn(3, 64, 10)   # (batch, noise_dim, frames)
    print(c.shape)

    y = model(c, z)
    print(y.shape)
    assert y.shape == torch.Size([3, 1, 2560])

    # Number of trainable parameters.
    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)