Spaces:
Build error
Build error
| import torch.nn as nn # pylint: disable=consider-using-from-import | |
| from torch.nn.utils import parametrize | |
| class KernelPredictor(nn.Module): | |
| """Kernel predictor for the location-variable convolutions | |
| Args: | |
| cond_channels (int): number of channel for the conditioning sequence, | |
| conv_in_channels (int): number of channel for the input sequence, | |
| conv_out_channels (int): number of channel for the output sequence, | |
| conv_layers (int): number of layers | |
| """ | |
| def __init__( # pylint: disable=dangerous-default-value | |
| self, | |
| cond_channels, | |
| conv_in_channels, | |
| conv_out_channels, | |
| conv_layers, | |
| conv_kernel_size=3, | |
| kpnet_hidden_channels=64, | |
| kpnet_conv_size=3, | |
| kpnet_dropout=0.0, | |
| kpnet_nonlinear_activation="LeakyReLU", | |
| kpnet_nonlinear_activation_params={"negative_slope": 0.1}, | |
| ): | |
| super().__init__() | |
| self.conv_in_channels = conv_in_channels | |
| self.conv_out_channels = conv_out_channels | |
| self.conv_kernel_size = conv_kernel_size | |
| self.conv_layers = conv_layers | |
| kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w | |
| kpnet_bias_channels = conv_out_channels * conv_layers # l_b | |
| self.input_conv = nn.Sequential( | |
| nn.utils.parametrizations.weight_norm( | |
| nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) | |
| ), | |
| getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), | |
| ) | |
| self.residual_convs = nn.ModuleList() | |
| padding = (kpnet_conv_size - 1) // 2 | |
| for _ in range(3): | |
| self.residual_convs.append( | |
| nn.Sequential( | |
| nn.Dropout(kpnet_dropout), | |
| nn.utils.parametrizations.weight_norm( | |
| nn.Conv1d( | |
| kpnet_hidden_channels, | |
| kpnet_hidden_channels, | |
| kpnet_conv_size, | |
| padding=padding, | |
| bias=True, | |
| ) | |
| ), | |
| getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), | |
| nn.utils.parametrizations.weight_norm( | |
| nn.Conv1d( | |
| kpnet_hidden_channels, | |
| kpnet_hidden_channels, | |
| kpnet_conv_size, | |
| padding=padding, | |
| bias=True, | |
| ) | |
| ), | |
| getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), | |
| ) | |
| ) | |
| self.kernel_conv = nn.utils.parametrizations.weight_norm( | |
| nn.Conv1d( | |
| kpnet_hidden_channels, | |
| kpnet_kernel_channels, | |
| kpnet_conv_size, | |
| padding=padding, | |
| bias=True, | |
| ) | |
| ) | |
| self.bias_conv = nn.utils.parametrizations.weight_norm( | |
| nn.Conv1d( | |
| kpnet_hidden_channels, | |
| kpnet_bias_channels, | |
| kpnet_conv_size, | |
| padding=padding, | |
| bias=True, | |
| ) | |
| ) | |
| def forward(self, c): | |
| """ | |
| Args: | |
| c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) | |
| """ | |
| batch, _, cond_length = c.shape | |
| c = self.input_conv(c) | |
| for residual_conv in self.residual_convs: | |
| residual_conv.to(c.device) | |
| c = c + residual_conv(c) | |
| k = self.kernel_conv(c) | |
| b = self.bias_conv(c) | |
| kernels = k.contiguous().view( | |
| batch, | |
| self.conv_layers, | |
| self.conv_in_channels, | |
| self.conv_out_channels, | |
| self.conv_kernel_size, | |
| cond_length, | |
| ) | |
| bias = b.contiguous().view( | |
| batch, | |
| self.conv_layers, | |
| self.conv_out_channels, | |
| cond_length, | |
| ) | |
| return kernels, bias | |
| def remove_weight_norm(self): | |
| parametrize.remove_parametrizations(self.input_conv[0], "weight") | |
| parametrize.remove_parametrizations(self.kernel_conv, "weight") | |
| parametrize.remove_parametrizations(self.bias_conv, "weight") | |
| for block in self.residual_convs: | |
| parametrize.remove_parametrizations(block[1], "weight") | |
| parametrize.remove_parametrizations(block[3], "weight") | |