# Copyright (c) 2023 Amphion. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math from torch import nn from torch.nn import functional as F from .modules import Conv1d1x1, ResidualConv1dGLU from .upsample import ConvInUpsampleNetwork def receptive_field_size( total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x ): """Compute receptive field size Args: total_layers (int): total layers num_cycles (int): cycles kernel_size (int): kernel size dilation (lambda): lambda to compute dilation factor. ``lambda x : 1`` to disable dilated convolution. Returns: int: receptive field size in sample """ assert total_layers % num_cycles == 0 layers_per_cycle = total_layers // num_cycles dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)] return (kernel_size - 1) * sum(dilations) + 1 class WaveNet(nn.Module): """The WaveNet model that supports local and global conditioning. Args: out_channels (int): Output channels. If input_type is mu-law quantized one-hot vecror. this must equal to the quantize channels. Other wise num_mixtures x 3 (pi, mu, log_scale). layers (int): Number of total layers stacks (int): Number of dilation cycles residual_channels (int): Residual input / output channels gate_channels (int): Gated activation channels. skip_out_channels (int): Skip connection channels. kernel_size (int): Kernel size of convolution layers. dropout (float): Dropout probability. input_dim (int): Number of mel-spec dimension. upsample_scales (list): List of upsample scale. ``np.prod(upsample_scales)`` must equal to hop size. Used only if upsample_conditional_features is enabled. freq_axis_kernel_size (int): Freq-axis kernel_size for transposed convolution layers for upsampling. If you only care about time-axis upsampling, set this to 1. scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise quantized one-hot vector is expected.. """ def __init__(self, cfg): super(WaveNet, self).__init__() self.cfg = cfg self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT self.out_channels = self.cfg.VOCODER.OUT_CHANNELS self.cin_channels = self.cfg.VOCODER.INPUT_DIM self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS self.layers = self.cfg.VOCODER.LAYERS self.stacks = self.cfg.VOCODER.STACKS self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS self.dropout = self.cfg.VOCODER.DROPOUT self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD assert self.layers % self.stacks == 0 layers_per_stack = self.layers // self.stacks if self.scalar_input: self.first_conv = Conv1d1x1(1, self.residual_channels) else: self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels) self.conv_layers = nn.ModuleList() for layer in range(self.layers): dilation = 2 ** (layer % layers_per_stack) conv = ResidualConv1dGLU( self.residual_channels, self.gate_channels, kernel_size=self.kernel_size, skip_out_channels=self.skip_out_channels, bias=True, dilation=dilation, dropout=self.dropout, cin_channels=self.cin_channels, ) self.conv_layers.append(conv) self.last_conv_layers = nn.ModuleList( [ nn.ReLU(inplace=True), Conv1d1x1(self.skip_out_channels, self.skip_out_channels), nn.ReLU(inplace=True), Conv1d1x1(self.skip_out_channels, self.out_channels), ] ) self.upsample_net = ConvInUpsampleNetwork( upsample_scales=self.upsample_scales, cin_pad=self.mel_frame_pad, cin_channels=self.cin_channels, ) self.receptive_field = receptive_field_size( self.layers, self.stacks, self.kernel_size ) def forward(self, x, mel, softmax=False): """Forward step Args: x (Tensor): One-hot encoded audio signal, shape (B x C x T) mel (Tensor): Local conditioning features, shape (B x cin_channels x T) softmax (bool): Whether applies softmax or not. Returns: Tensor: output, shape B x out_channels x T """ B, _, T = x.size() mel = self.upsample_net(mel) assert mel.shape[-1] == x.shape[-1] x = self.first_conv(x) skips = 0 for f in self.conv_layers: x, h = f(x, mel) skips += h skips *= math.sqrt(1.0 / len(self.conv_layers)) x = skips for f in self.last_conv_layers: x = f(x) x = F.softmax(x, dim=1) if softmax else x return x def clear_buffer(self): self.first_conv.clear_buffer() for f in self.conv_layers: f.clear_buffer() for f in self.last_conv_layers: try: f.clear_buffer() except AttributeError: pass def make_generation_fast_(self): def remove_weight_norm(m): try: nn.utils.remove_weight_norm(m) except ValueError: # this module didn't have weight norm return self.apply(remove_weight_norm)