# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import math
from torch import nn
from torch.nn import functional as F

from .conv import Conv1d as conv_Conv1d


def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d with Kaiming-initialized weights."""
    m = conv_Conv1d(in_channels, out_channels, kernel_size, **kwargs)
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    return nn.utils.weight_norm(m)


def Conv1d1x1(in_channels, out_channels, bias=True):
    """1x1 convolution (pointwise projection along the channel axis)."""
    return Conv1d(
        in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
    )


def _conv1x1_forward(conv, x, is_incremental):
    """Apply a 1x1 conv, using its incremental path during autoregressive inference."""
    if is_incremental:
        x = conv.incremental_forward(x)
    else:
        x = conv(x)
    return x


class ResidualConv1dGLU(nn.Module):
    """Residual dilated conv1d + gated linear unit.

    Args:
        residual_channels (int): Residual input / output channels.
        gate_channels (int): Gated activation channels.
        kernel_size (int): Kernel size of convolution layers.
        skip_out_channels (int): Skip connection channels. If None, set to the
            same value as ``residual_channels``.
        cin_channels (int): Local conditioning channels (e.g. mel-spectrogram
            bins). Local conditioning is required by this implementation.
        dropout (float): Dropout probability.
        padding (int): Padding for convolution layers. If None, proper padding
            is computed from ``dilation`` and ``kernel_size``.
        dilation (int): Dilation factor.
        causal (bool): If True, pad only on the left so no future time steps
            are used.
        bias (bool): Whether convolution layers use a bias term.
    """

    def __init__(
        self,
        residual_channels,
        gate_channels,
        kernel_size,
        skip_out_channels=None,
        cin_channels=-1,
        dropout=1 - 0.95,
        padding=None,
        dilation=1,
        causal=True,
        bias=True,
        *args,
        **kwargs,
    ):
        super(ResidualConv1dGLU, self).__init__()
        self.dropout = dropout

        if skip_out_channels is None:
            skip_out_channels = residual_channels
        if padding is None:
            if causal:
                # no future time steps available: pad only on the left
                padding = (kernel_size - 1) * dilation
            else:
                padding = (kernel_size - 1) // 2 * dilation
        self.causal = causal

        self.conv = Conv1d(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias=bias,
            *args,
            **kwargs,
        )

        # mel conditioning
        self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)

        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)

    def forward(self, x, c=None):
        return self._forward(x, c, False)

    def incremental_forward(self, x, c=None):
        return self._forward(x, c, True)

    def clear_buffer(self):
        for c in [
            self.conv,
            self.conv1x1_out,
            self.conv1x1_skip,
            self.conv1x1c,
        ]:
            if c is not None:
                c.clear_buffer()

    def _forward(self, x, c, is_incremental):
        """Forward

        Args:
            x (Tensor): B x C x T
            c (Tensor): B x C x T, mel conditioning features

        Returns:
            tuple: (output, skip) tensors; output is B x residual_channels x T
            and skip is B x skip_out_channels x T.
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        if is_incremental:
            splitdim = -1
            x = self.conv.incremental_forward(x)
        else:
            splitdim = 1
            x = self.conv(x)
            # remove future time steps introduced by causal padding
            x = x[:, :, : residual.size(-1)] if self.causal else x

        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)

        assert self.conv1x1c is not None
        c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
        ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
        a, b = a + ca, b + cb

        # gated activation unit
        x = torch.tanh(a) * torch.sigmoid(b)

        # For skip connection
        s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)

        # For residual connection; scale so the variance of the sum stays
        # roughly constant
        x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)
        x = (x + residual) * math.sqrt(0.5)
        return x, s
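
# --------------------------------------------------------------------------- #
# Minimal usage sketch (illustrative; not part of the original module). It
# assumes this file is run as part of its package so the relative import of
# ``.conv`` above resolves (e.g. via ``python -m <package>.<module>``), and
# that ``.conv.Conv1d`` behaves like ``nn.Conv1d`` on the plain forward path.
# Channel sizes and the sequence length below are arbitrary examples.
if __name__ == "__main__":
    layer = ResidualConv1dGLU(
        residual_channels=64,
        gate_channels=128,
        kernel_size=3,
        cin_channels=80,  # e.g. number of mel bins
        dilation=2,
    )
    x = torch.randn(1, 64, 100)  # B x residual_channels x T
    c = torch.randn(1, 80, 100)  # B x cin_channels x T (local conditioning)
    out, skip = layer(x, c)
    # The residual output keeps the input shape; the skip output uses
    # skip_out_channels (here defaulted to residual_channels).
    print(out.shape, skip.shape)  # torch.Size([1, 64, 100]) torch.Size([1, 64, 100])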