# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch.nn as nn

from modules.general.utils import Conv1d, zero_module
from .residual_block import ResidualBlock


class BiDilConv(nn.Module):
    r"""Dilated CNN architecture with residual connections, default diffusion decoder.

    Args:
        input_channel: The number of input channels.
        base_channel: The number of base channels.
        n_res_block: The number of residual blocks.
        conv_kernel_size: The kernel size of convolutional layers.
        dilation_cycle_length: The cycle length of dilation.
        conditioner_size: The size of the conditioner.
    """

    def __init__(
        self,
        input_channel,
        base_channel,
        n_res_block,
        conv_kernel_size,
        dilation_cycle_length,
        conditioner_size,
        output_channel: int = -1,
    ):
        super().__init__()

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        self.output_channel = output_channel if output_channel > 0 else input_channel

        # Project the input mel-spectrogram to the base channel dimension.
        self.input = nn.Sequential(
            Conv1d(
                input_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
        )

        # Stack of dilated residual blocks; the dilation cycles through
        # 2 ** (i % dilation_cycle_length).
        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (i % dilation_cycle_length),
                    d_context=conditioner_size,
                )
                for i in range(n_res_block)
            ]
        )

        # Output projection; the final convolution is zero-initialized.
        self.out_proj = nn.Sequential(
            Conv1d(
                base_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
            zero_module(
                Conv1d(
                    base_channel,
                    self.output_channel,
                    1,
                ),
            ),
        )

    def forward(self, x, y, context=None):
        """
        Args:
            x: Noisy mel-spectrogram with the shape of [B x ``n_mel`` x L].
            y: FiLM embeddings with the shape of (B, ``base_channel``).
            context: Context with the shape of [B x ``d_context`` x L], defaults to None.
        """
        h = self.input(x)

        skip = None
        for i in range(self.n_res_block):
            h, skip_connection = self.residual_blocks[i](h, y, context)
            skip = skip_connection if skip is None else skip_connection + skip

        # Normalize the accumulated skip connections before the output projection.
        out = skip / math.sqrt(self.n_res_block)

        out = self.out_proj(out)

        return out
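

if __name__ == "__main__":
    # Quick shape-check sketch (illustrative, not part of the original module).
    # It assumes the Amphion repo is importable so that Conv1d, zero_module and
    # ResidualBlock resolve; the hyperparameters and tensor sizes below are
    # arbitrary example values, not settings taken from any recipe.
    import torch

    B, n_mel, length = 2, 80, 128
    base_channel, d_context = 256, 64

    model = BiDilConv(
        input_channel=n_mel,
        base_channel=base_channel,
        n_res_block=6,
        conv_kernel_size=3,
        dilation_cycle_length=2,
        conditioner_size=d_context,
    )

    x = torch.randn(B, n_mel, length)            # noisy mel-spectrogram
    y = torch.randn(B, base_channel)             # FiLM / diffusion-step embedding
    context = torch.randn(B, d_context, length)  # optional conditioner

    out = model(x, y, context)
    # output_channel defaults to input_channel, so the expected shape is
    # torch.Size([2, 80, 128]).
    print(out.shape)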