# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np
from torch import nn
from torch.nn import functional as F


class Stretch2d(nn.Module):
    """Rescale the last two axes of a tensor with ``F.interpolate``.

    Args:
        x_scale: Scale factor applied to the last axis.
        y_scale: Scale factor applied to the second-to-last axis.
        mode: Interpolation mode forwarded to ``F.interpolate``.
    """

    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x):
        """Return ``x`` interpolated by ``(y_scale, x_scale)``."""
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
        )


def _get_activation(upsample_activation):
    """Look up an activation class on ``torch.nn`` by its attribute name.

    Args:
        upsample_activation: Name of an attribute of ``torch.nn``
            (for example ``"ReLU"``).

    Returns:
        The attribute itself (not an instance); the caller instantiates it.

    Raises:
        AttributeError: If ``torch.nn`` has no attribute with that name.
    """
    nonlinear = getattr(nn, upsample_activation)
    return nonlinear


class UpsampleNetwork(nn.Module):
    """Upsample conditioning features along the last (time) axis.

    For every entry of ``upsample_scales`` the network appends a
    ``Stretch2d`` layer followed by a weight-normalised, bias-free
    single-channel ``Conv2d`` whose weights are initialised to a uniform
    averaging kernel, and optionally an activation layer.

    Args:
        upsample_scales: Sequence of per-stage scale factors for the last axis.
        upsample_activation: Name of a ``torch.nn`` activation class, or
            ``"none"`` to add no activation layers.
        upsample_activation_params: Keyword arguments used to instantiate the
            activation. ``None`` (the default) means no keyword arguments.
        mode: Interpolation mode passed to every ``Stretch2d``.
        freq_axis_kernel_size: Convolution kernel size along the
            second-to-last axis.
        cin_pad: Number of input frames of context; ``cin_pad * prod(scales)``
            samples are trimmed from each end of the output.
        cin_channels: Accepted for signature compatibility; not used here.
    """

    def __init__(
        self,
        upsample_scales,
        upsample_activation="none",
        upsample_activation_params=None,
        mode="nearest",
        freq_axis_kernel_size=1,
        cin_pad=0,
        cin_channels=128,
    ):
        super().__init__()
        # Avoid a shared mutable default dict across instances.
        if upsample_activation_params is None:
            upsample_activation_params = {}

        self.up_layers = nn.ModuleList()
        # np.prod of an empty sequence is the float 1.0; coerce to int so
        # that ``indent`` is always usable as a slice bound in ``forward``.
        total_scale = int(np.prod(upsample_scales))
        self.indent = cin_pad * total_scale

        # Same for every stage, so computed once outside the loop.
        freq_axis_padding = (freq_axis_kernel_size - 1) // 2

        for scale in upsample_scales:
            k_size = (freq_axis_kernel_size, scale * 2 + 1)
            # Padding of ``scale`` with kernel ``2 * scale + 1`` keeps the
            # stretched length unchanged through the convolution.
            padding = (freq_axis_padding, scale)
            stretch = Stretch2d(scale, 1, mode)
            conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
            # Start as a uniform averaging filter.
            conv.weight.data.fill_(1.0 / np.prod(k_size))
            conv = nn.utils.weight_norm(conv)
            self.up_layers.append(stretch)
            self.up_layers.append(conv)
            if upsample_activation != "none":
                nonlinear = _get_activation(upsample_activation)
                self.up_layers.append(nonlinear(**upsample_activation_params))

    def forward(self, c):
        """
        Args:
            c : B x C x T

        Returns:
            Tensor with the last axis upsampled by the product of the scales
            and ``self.indent`` samples removed from each end.
        """
        # B x 1 x C x T
        c = c.unsqueeze(1)
        for f in self.up_layers:
            c = f(c)
        # B x C x T
        c = c.squeeze(1)

        if self.indent > 0:
            c = c[:, :, self.indent : -self.indent]
        return c


class ConvInUpsampleNetwork(nn.Module):
    """``UpsampleNetwork`` preceded by a bias-free ``Conv1d`` over the input.

    Args:
        upsample_scales: Forwarded to ``UpsampleNetwork``.
        upsample_activation: Forwarded to ``UpsampleNetwork``.
        upsample_activation_params: Forwarded to ``UpsampleNetwork``;
            ``None`` (the default) means no keyword arguments.
        mode: Forwarded to ``UpsampleNetwork``.
        freq_axis_kernel_size: Forwarded to ``UpsampleNetwork``.
        cin_pad: Sets the input convolution's kernel size
            (``2 * cin_pad + 1``) and padding, and is forwarded to
            ``UpsampleNetwork``.
        cin_channels: Number of input and output channels of the input
            convolution.
    """

    def __init__(
        self,
        upsample_scales,
        upsample_activation="none",
        upsample_activation_params=None,
        mode="nearest",
        freq_axis_kernel_size=1,
        cin_pad=0,
        cin_channels=128,
    ):
        super().__init__()
        # Avoid a shared mutable default dict across instances.
        if upsample_activation_params is None:
            upsample_activation_params = {}

        # To capture wide-context information in conditional features
        # meaningless if cin_pad == 0
        ks = 2 * cin_pad + 1
        self.conv_in = nn.Conv1d(
            cin_channels, cin_channels, kernel_size=ks, padding=cin_pad, bias=False
        )
        self.upsample = UpsampleNetwork(
            upsample_scales,
            upsample_activation,
            upsample_activation_params,
            mode,
            freq_axis_kernel_size,
            cin_pad=cin_pad,
            cin_channels=cin_channels,
        )

    def forward(self, c):
        """Apply the input convolution, then the upsampling network, to ``c``."""
        c_up = self.upsample(self.conv_in(c))
        return c_up