Shadhil's picture
voice-clone with single audio sample input
9b2107c
raw
history blame
No virus
4.65 kB
import torch.nn as nn # pylint: disable=consider-using-from-import
from torch.nn.utils import parametrize
class KernelPredictor(nn.Module):
"""Kernel predictor for the location-variable convolutions
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int): number of layers
"""
def __init__( # pylint: disable=dangerous-default-value
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.parametrizations.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
padding = (kpnet_conv_size - 1) // 2
for _ in range(3):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
self.bias_conv = nn.utils.parametrizations.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
def forward(self, c):
"""
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
"""
batch, _, cond_length = c.shape
c = self.input_conv(c)
for residual_conv in self.residual_convs:
residual_conv.to(c.device)
c = c + residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length,
)
bias = b.contiguous().view(
batch,
self.conv_layers,
self.conv_out_channels,
cond_length,
)
return kernels, bias
def remove_weight_norm(self):
parametrize.remove_parametrizations(self.input_conv[0], "weight")
parametrize.remove_parametrizations(self.kernel_conv, "weight")
parametrize.remove_parametrizations(self.bias_conv, "weight")
for block in self.residual_convs:
parametrize.remove_parametrizations(block[1], "weight")
parametrize.remove_parametrizations(block[3], "weight")