import torch.nn as nn # pylint: disable=consider-using-from-import from torch.nn.utils import parametrize class KernelPredictor(nn.Module): """Kernel predictor for the location-variable convolutions Args: cond_channels (int): number of channel for the conditioning sequence, conv_in_channels (int): number of channel for the input sequence, conv_out_channels (int): number of channel for the output sequence, conv_layers (int): number of layers """ def __init__( # pylint: disable=dangerous-default-value self, cond_channels, conv_in_channels, conv_out_channels, conv_layers, conv_kernel_size=3, kpnet_hidden_channels=64, kpnet_conv_size=3, kpnet_dropout=0.0, kpnet_nonlinear_activation="LeakyReLU", kpnet_nonlinear_activation_params={"negative_slope": 0.1}, ): super().__init__() self.conv_in_channels = conv_in_channels self.conv_out_channels = conv_out_channels self.conv_kernel_size = conv_kernel_size self.conv_layers = conv_layers kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w kpnet_bias_channels = conv_out_channels * conv_layers # l_b self.input_conv = nn.Sequential( nn.utils.parametrizations.weight_norm( nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) ), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), ) self.residual_convs = nn.ModuleList() padding = (kpnet_conv_size - 1) // 2 for _ in range(3): self.residual_convs.append( nn.Sequential( nn.Dropout(kpnet_dropout), nn.utils.parametrizations.weight_norm( nn.Conv1d( kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True, ) ), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), nn.utils.parametrizations.weight_norm( nn.Conv1d( kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True, ) ), getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), ) ) self.kernel_conv = nn.utils.parametrizations.weight_norm( nn.Conv1d( kpnet_hidden_channels, kpnet_kernel_channels, kpnet_conv_size, padding=padding, bias=True, ) ) self.bias_conv = nn.utils.parametrizations.weight_norm( nn.Conv1d( kpnet_hidden_channels, kpnet_bias_channels, kpnet_conv_size, padding=padding, bias=True, ) ) def forward(self, c): """ Args: c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) """ batch, _, cond_length = c.shape c = self.input_conv(c) for residual_conv in self.residual_convs: residual_conv.to(c.device) c = c + residual_conv(c) k = self.kernel_conv(c) b = self.bias_conv(c) kernels = k.contiguous().view( batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length, ) bias = b.contiguous().view( batch, self.conv_layers, self.conv_out_channels, cond_length, ) return kernels, bias def remove_weight_norm(self): parametrize.remove_parametrizations(self.input_conv[0], "weight") parametrize.remove_parametrizations(self.kernel_conv, "weight") parametrize.remove_parametrizations(self.bias_conv, "weight") for block in self.residual_convs: parametrize.remove_parametrizations(block[1], "weight") parametrize.remove_parametrizations(block[3], "weight")