import torch from torch import nn from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock from TTS.tts.layers.generic.transformer import FFTransformerBlock from TTS.tts.layers.generic.wavenet import WNBlocks from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer class WaveNetDecoder(nn.Module): """WaveNet based decoder with a prenet and a postnet. prenet: conv1d_1x1 postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1 TODO: Integrate speaker conditioning vector. Note: default wavenet parameters; params = { "num_blocks": 12, "hidden_channels":192, "kernel_size": 5, "dilation_rate": 1, "num_layers": 4, "dropout_p": 0.05 } Args: in_channels (int): number of input channels. out_channels (int): number of output channels. hidden_channels (int): number of hidden channels for prenet and postnet. params (dict): dictionary for residual convolutional blocks. """ def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params): super().__init__() # prenet self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1) # wavenet layers self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params) # postnet self.postnet = [ torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1), torch.nn.ReLU(), torch.nn.Conv1d(hidden_channels, hidden_channels, 1), torch.nn.ReLU(), torch.nn.Conv1d(hidden_channels, hidden_channels, 1), torch.nn.ReLU(), torch.nn.Conv1d(hidden_channels, out_channels, 1), ] self.postnet = nn.Sequential(*self.postnet) def forward(self, x, x_mask=None, g=None): x = self.prenet(x) * x_mask x = self.wn(x, x_mask, g) o = self.postnet(x) * x_mask return o class RelativePositionTransformerDecoder(nn.Module): """Decoder with Relative Positional Transformer. Note: Default params params={ 'hidden_channels_ffn': 128, 'num_heads': 2, "kernel_size": 3, "dropout_p": 0.1, "num_layers": 8, "rel_attn_window_size": 4, "input_length": None } Args: in_channels (int): number of input channels. out_channels (int): number of output channels. hidden_channels (int): number of hidden channels including Transformer layers. params (dict): dictionary for residual convolutional blocks. """ def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1) self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument o = self.prenet(x) * x_mask o = self.rel_pos_transformer(o, x_mask) return o class FFTransformerDecoder(nn.Module): """Decoder with FeedForwardTransformer. Default params params={ 'hidden_channels_ffn': 1024, 'num_heads': 2, "dropout_p": 0.1, "num_layers": 6, } Args: in_channels (int): number of input channels. out_channels (int): number of output channels. hidden_channels (int): number of hidden channels including Transformer layers. params (dict): dictionary for residual convolutional blocks. """ def __init__(self, in_channels, out_channels, params): super().__init__() self.transformer_block = FFTransformerBlock(in_channels, **params) self.postnet = nn.Conv1d(in_channels, out_channels, 1) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument # TODO: handle multi-speaker x_mask = 1 if x_mask is None else x_mask o = self.transformer_block(x) * x_mask o = self.postnet(o) * x_mask return o class ResidualConv1dBNDecoder(nn.Module): """Residual Convolutional Decoder as in the original Speedy Speech paper TODO: Integrate speaker conditioning vector. Note: Default params params = { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 } Args: in_channels (int): number of input channels. out_channels (int): number of output channels. hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers. params (dict): dictionary for residual convolutional blocks. """ def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params) self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) self.postnet = nn.Sequential( Conv1dBNBlock( hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2 ), nn.Conv1d(hidden_channels, out_channels, 1), ) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument o = self.res_conv_block(x, x_mask) o = self.post_conv(o) + x return self.postnet(o) * x_mask class Decoder(nn.Module): """Decodes the expanded phoneme encoding into spectrograms Args: out_channels (int): number of output channels. in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. decoder_type (str): decoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. decoder_params (dict): model parameters for specified decoder type. c_in_channels (int): number of channels for conditional input. Shapes: - input: (B, C, T) """ # pylint: disable=dangerous-default-value def __init__( self, out_channels, in_hidden_channels, decoder_type="residual_conv_bn", decoder_params={ "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17, }, c_in_channels=0, ): super().__init__() if decoder_type.lower() == "relative_position_transformer": self.decoder = RelativePositionTransformerDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, params=decoder_params, ) elif decoder_type.lower() == "residual_conv_bn": self.decoder = ResidualConv1dBNDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, params=decoder_params, ) elif decoder_type.lower() == "wavenet": self.decoder = WaveNetDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, c_in_channels=c_in_channels, params=decoder_params, ) elif decoder_type.lower() == "fftransformer": self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) else: raise ValueError(f"[!] Unknown decoder type - {decoder_type}") def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument """ Args: x: [B, C, T] x_mask: [B, 1, T] g: [B, C_g, 1] """ # TODO: implement multi-speaker o = self.decoder(x, x_mask, g) return o