from typing import Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F
from torch.nn.utils import parametrize

from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor


def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
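# A quick sanity check of the padding rule (not part of the public API):
#   calc_same_padding(3) -> (1, 1)  # odd kernel: symmetric "same" padding
#   calc_same_padding(4) -> (2, 1)  # even kernel: asymmetric padding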


class ConvNorm(nn.Module):
    """A 1-dimensional convolutional layer with optional weight normalization.

    This layer wraps a 1D convolutional layer from PyTorch and applies
    optional weight normalization. The layer can be used in a similar way to
    the convolutional layers in PyTorch's `torch.nn` module.

    Args:
        in_channels (int): The number of channels in the input signal.
        out_channels (int): The number of channels in the output signal.
        kernel_size (int, optional): The size of the convolving kernel.
            Defaults to 1.
        stride (int, optional): The stride of the convolution. Defaults to 1.
        padding (int, optional): Zero-padding added to both sides of the input.
            If `None`, the padding is calculated so that the output has the
            same length as the input (this requires an odd `kernel_size`).
            Defaults to `None`.
        dilation (int, optional): Spacing between kernel elements. Defaults to 1.
        bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`.
        w_init_gain (str, optional): The weight initialization function to use.
            Can be either 'linear' or 'relu'. Defaults to 'linear'.
        use_weight_norm (bool, optional): If `True`, apply weight normalization
            to the convolutional weights. Defaults to `False`.

    Shapes:
        - Input: :math:`[N, D, T]`
        - Output: :math:`[N, out_channels, T]`
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        use_weight_norm=False,
    ):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.use_weight_norm = use_weight_norm
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))
        if self.use_weight_norm:
            self.conv = nn.utils.parametrizations.weight_norm(self.conv)

    def forward(self, signal, mask=None):
        conv_signal = self.conv(signal)
        if mask is not None:
            # re-zero the masked positions so the output matches zero-padding
            conv_signal = conv_signal * mask
        return conv_signal
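# Minimal usage sketch (channel and length values are illustrative
# assumptions): with kernel_size=5 and padding=None, the padding is inferred
# as 2, so the sequence length is preserved.
#   conv = ConvNorm(80, 256, kernel_size=5)
#   y = conv(torch.randn(4, 80, 100))  # -> (4, 256, 100)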


class ConvLSTMLinear(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super().__init__()
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            # keep only the underlying Conv1d, wrapped with weight normalization
            conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight")
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                lstm_channels = n_channels // 2
            self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm)
            lstm_norm_fn_pntr = nn.utils.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)
    def run_padded_sequence(self, context, lens):
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context
    def run_unsorted_inputs(self, fn, context, lens):  # pylint: disable=no-self-use
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        unsort_ids = [0] * lens.size(0)
        for i, sorted_idx in enumerate(ids_sorted):
            unsort_ids[sorted_idx] = i
        lens_sorted = lens_sorted.long().cpu()

        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True)
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]

        # map back to original indices
        context = context[unsort_ids]
        return context
    def forward(self, context, lens):
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to (B, D, T)
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)

        return x_hat
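# Minimal usage sketch (dimensions are illustrative assumptions): inputs are
# channel-first (B, in_dim, T) with per-item lengths; the output keeps the
# padded length.
#   net = ConvLSTMLinear(in_dim=80, out_dim=16)
#   lens = torch.tensor([100, 80])
#   y = net(torch.randn(2, 80, 100), lens)  # -> (2, 16, 100)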


class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class PointwiseConv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class BSConv1d(nn.Module):
    """Blueprint-separable convolution: https://arxiv.org/pdf/2003.13549.pdf"""

    def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1)
        self.depthwise = nn.Conv1d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.pointwise(x)
        x2 = self.depthwise(x1)
        return x2
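# Sketch (channel counts are illustrative): unlike a classic depthwise-
# separable block, BSConv applies the pointwise (1x1) convolution first and
# the depthwise convolution second. With padding = kernel_size // 2 the
# sequence length is preserved.
#   bsconv = BSConv1d(64, 128, kernel_size=3, padding=1)
#   y = bsconv(torch.randn(4, 64, 50))  # -> (4, 128, 50)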


class BSConv2d(nn.Module):
    """Blueprint-separable convolution: https://arxiv.org/pdf/2003.13549.pdf"""

    def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int):
        super().__init__()
        self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1)
        self.depthwise = nn.Conv2d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
            groups=channels_out,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.pointwise(x)
        x2 = self.depthwise(x1)
        return x2


class Conv1dGLU(nn.Module):
    """Gated convolution block adapted from Deep Voice 3."""

    def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int):
        super().__init__()
        self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding)
        self.embedding_proj = nn.Linear(embedding_dim, d_model)
        self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0))
        self.softsign = torch.nn.Softsign()

    def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        x = x.permute((0, 2, 1))
        residual = x
        x = self.conv(x)
        splitdim = 1
        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
        embeddings = self.embedding_proj(embeddings).unsqueeze(2)
        softsign = self.softsign(embeddings)
        softsign = softsign.expand_as(a)
        a = a + softsign
        x = a * torch.sigmoid(b)
        x = x + residual
        x = x * self.sqrt
        x = x.permute((0, 2, 1))
        return x
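# Sketch (sizes are assumptions): the input is channel-last (B, T, d_model),
# and a conditioning embedding of shape (B, embedding_dim) shifts the content
# half of the GLU split before gating.
#   glu = Conv1dGLU(d_model=256, kernel_size=3, padding=1, embedding_dim=64)
#   y = glu(torch.randn(2, 50, 256), torch.randn(2, 64))  # -> (2, 50, 256)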


class ConvTransposed(nn.Module):
    """A convolution applied over the channel axis of a channel-last tensor.

    Despite its name, this is not a transposed (fractionally strided)
    convolution: the input of shape `[N, T, in_channels]` is transposed to
    channel-first layout, passed through a `BSConv1d`, and transposed back.

    Args:
        in_channels (int): The number of channels in the input tensor.
        out_channels (int): The number of channels in the output tensor.
        kernel_size (int): The size of the convolutional kernel. Defaults to 1.
        padding (int): Zero-padding added to both sides of the input. Defaults to 0.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        padding: int = 0,
    ):
        super().__init__()
        self.conv = BSConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.contiguous().transpose(1, 2)
        x = self.conv(x)
        x = x.contiguous().transpose(1, 2)
        return x


class DepthwiseConvModule(nn.Module):
    def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3):
        super().__init__()
        padding = calc_same_padding(kernel_size)
        self.depthwise = nn.Conv1d(
            dim,
            dim * expansion,
            kernel_size=kernel_size,
            padding=padding[0],
            groups=dim,
        )
        self.act = nn.LeakyReLU(lrelu_slope)
        self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln(x)
        x = x.permute((0, 2, 1))
        x = self.depthwise(x)
        x = self.act(x)
        x = self.out(x)
        x = x.permute((0, 2, 1))
        return x
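# Sketch (sizes are assumptions): channel-last in and out; LayerNorm runs over
# the feature axis before the depthwise expansion.
#   m = DepthwiseConvModule(dim=256)
#   y = m(torch.randn(2, 50, 256))  # -> (2, 50, 256)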


class AddCoords(nn.Module):
    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.rank == 1:
            batch_size_shape, channel_in_shape, dim_x = x.shape  # pylint: disable=unused-variable
            xx_range = torch.arange(dim_x, dtype=torch.int32)
            xx_channel = xx_range[None, None, :]

            # normalize coordinates to [-1, 1]
            xx_channel = xx_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1

            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)
            xx_channel = xx_channel.to(x.device)
            out = torch.cat([x, xx_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)
        elif self.rank == 2:
            batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
            yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)

            xx_range = torch.arange(dim_y, dtype=torch.int32)
            yy_range = torch.arange(dim_x, dtype=torch.int32)
            xx_range = xx_range[None, None, :, None]
            yy_range = yy_range[None, None, :, None]

            xx_channel = torch.matmul(xx_range, xx_ones)
            yy_channel = torch.matmul(yy_range, yy_ones)

            # transpose y
            yy_channel = yy_channel.permute(0, 1, 3, 2)

            xx_channel = xx_channel.float() / (dim_y - 1)
            yy_channel = yy_channel.float() / (dim_x - 1)

            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1

            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)

            xx_channel = xx_channel.to(x.device)
            yy_channel = yy_channel.to(x.device)

            out = torch.cat([x, xx_channel, yy_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)
        elif self.rank == 3:
            batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32)
            yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32)
            zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32)

            xy_range = torch.arange(dim_y, dtype=torch.int32)
            xy_range = xy_range[None, None, None, :, None]

            yz_range = torch.arange(dim_z, dtype=torch.int32)
            yz_range = yz_range[None, None, None, :, None]

            zx_range = torch.arange(dim_x, dtype=torch.int32)
            zx_range = zx_range[None, None, None, :, None]

            xy_channel = torch.matmul(xy_range, xx_ones)
            xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)

            yz_channel = torch.matmul(yz_range, yy_ones)
            yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
            yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)

            zx_channel = torch.matmul(zx_range, zz_ones)
            zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
            zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)

            xx_channel = xx_channel.to(x.device)
            yy_channel = yy_channel.to(x.device)
            zz_channel = zz_channel.to(x.device)
            out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2)
                )
                out = torch.cat([out, rr], dim=1)
        else:
            raise NotImplementedError
        return out
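# Sketch: for rank 1, one normalized coordinate channel is concatenated
# (two when with_r=True). Shapes here are illustrative.
#   ac = AddCoords(rank=1)
#   y = ac(torch.randn(2, 8, 100))  # -> (2, 9, 100)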


class CoordConv1d(nn.modules.conv.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.rank = 1
        self.addcoords = AddCoords(self.rank, with_r)
        # the actual convolution runs on the coordinate-augmented input
        self.conv = nn.Conv1d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.addcoords(x)
        x = self.conv(x)
        return x
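# Sketch (sizes are assumptions): the coordinate channel is added before the
# convolution, so the convolution sees in_channels + 1 inputs.
#   cc = CoordConv1d(8, 16, kernel_size=3, padding=1)
#   y = cc(torch.randn(2, 8, 100))  # -> (2, 16, 100)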


class CoordConv2d(nn.modules.conv.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )
        self.rank = 2
        self.addcoords = AddCoords(self.rank, with_r)
        self.conv = nn.Conv2d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.addcoords(x)
        x = self.conv(x)
        return x
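# Sketch: the 2D variant concatenates two coordinate channels (x and y).
#   cc2 = CoordConv2d(3, 8, kernel_size=3, padding=1)
#   y = cc2(torch.randn(2, 3, 32, 32))  # -> (2, 8, 32, 32)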


class LVCBlock(torch.nn.Module):
    """The location-variable convolution block."""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        in_channels,
        cond_channels,
        stride,
        dilations=[1, 3, 9, 27],
        lReLU_slope=0.2,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = len(dilations)
        self.conv_kernel_size = conv_kernel_size

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=len(dilations),
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
        )

        self.convt_pre = nn.Sequential(
            nn.LeakyReLU(lReLU_slope),
            nn.utils.parametrizations.weight_norm(
                nn.ConvTranspose1d(
                    in_channels,
                    in_channels,
                    2 * stride,
                    stride=stride,
                    padding=stride // 2 + stride % 2,
                    output_padding=stride % 2,
                )
            ),
        )

        self.conv_blocks = nn.ModuleList()
        for dilation in dilations:
            self.conv_blocks.append(
                nn.Sequential(
                    nn.LeakyReLU(lReLU_slope),
                    nn.utils.parametrizations.weight_norm(
                        nn.Conv1d(
                            in_channels,
                            in_channels,
                            conv_kernel_size,
                            padding=dilation * (conv_kernel_size - 1) // 2,
                            dilation=dilation,
                        )
                    ),
                    nn.LeakyReLU(lReLU_slope),
                )
            )
    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        """
        _, in_channels, _ = x.shape  # (B, c_g, L')

        x = self.convt_pre(x)  # (B, c_g, stride * L')
        kernels, bias = self.kernel_predictor(c)

        for i, conv in enumerate(self.conv_blocks):
            output = conv(x)  # (B, c_g, stride * L')

            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)

            output = self.location_variable_convolution(
                output, k, b, hop_size=self.cond_hop_length
            )  # (B, 2 * c_g, stride * L'): LVC
            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
                output[:, in_channels:, :]
            )  # (B, c_g, stride * L'): gated activation unit (GAU)
        return x
    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):  # pylint: disable=no-self-use
        """Perform location-variable convolution on the input sequence (x) using the local convolution kernel.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.

        Returns:
            (Tensor): the output sequence after performing local convolution (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
        assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o.to(memory_format=torch.channels_last_3d)
        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
        o = o + bias
        o = o.contiguous().view(batch, out_channels, -1)
        return o
    def remove_weight_norm(self):
        self.kernel_predictor.remove_weight_norm()
        parametrize.remove_parametrizations(self.convt_pre[1], "weight")
        for block in self.conv_blocks:
            parametrize.remove_parametrizations(block[1], "weight")
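# Minimal LVCBlock sketch (all sizes are illustrative assumptions, and this
# assumes KernelPredictor preserves the conditioning length). The input must
# satisfy stride * in_length == cond_hop_length * cond_length, since convt_pre
# upsamples x by `stride` and the LVC step asserts that the upsampled length
# equals kernel_length * hop_size.
#   block = LVCBlock(in_channels=32, cond_channels=80, stride=8, cond_hop_length=64)
#   x = torch.randn(1, 32, 800)   # 8 * 800 == 64 * 100
#   c = torch.randn(1, 80, 100)
#   y = block(x, c)               # -> (1, 32, 6400)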