import math

import torch
from torch import nn

from TTS.tts.layers.generic.gated_conv import GatedConvBlock
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.utils.helpers import sequence_mask


class Encoder(nn.Module):
    """Glow-TTS encoder module.

    Embeds character ids, applies an optional prenet and one of four encoder
    backbones, then projects to the mean (and, unless ``mean_only``, the
    log-scale) of the latent distribution. A duration predictor runs on a
    detached copy of the encoder output, optionally concatenated with a
    speaker embedding.

    ::

        embedding -> <prenet> -> encoder_module -> --> proj_mean
                                                 |
                                                 |-> proj_var
                                                 |
                                                 |-> concat -> duration_predictor
                                                        ↑
                                                  speaker_embed

    Args:
        num_chars (int): number of characters.
        out_channels (int): number of output channels.
        hidden_channels (int): encoder's embedding size.
        hidden_channels_dp (int): hidden channels of the duration predictor.
        encoder_type (str): one of ``'rel_pos_transformer'``, ``'gated_conv'``,
            ``'residual_conv_bn'`` or ``'time_depth_separable'`` (case-insensitive).
        encoder_params (dict): keyword arguments forwarded to the chosen encoder module.
        dropout_p_dp (float): dropout rate of the duration predictor.
        mean_only (bool): if True, output only mean values and use constant std.
        use_prenet (bool): if True, use pre-convolutional layers before the encoder layers.
        c_in_channels (int): number of channels in the conditional (speaker) input.

    Shapes:
        - input: :math:`[B, T]` (character ids)

    ::

        suggested encoder params...

        for encoder_type == 'rel_pos_transformer'
            encoder_params={
                'kernel_size': 3,
                'dropout_p': 0.1,
                'num_layers': 6,
                'num_heads': 2,
                'hidden_channels_ffn': 768,  # 4 times the hidden_channels
                'input_length': None
            }

        for encoder_type == 'gated_conv'
            encoder_params={
                'kernel_size': 5,
                'dropout_p': 0.1,
                'num_layers': 9,
            }

        for encoder_type == 'residual_conv_bn'
            encoder_params={
                "kernel_size": 4,
                "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            }

        for encoder_type == 'time_depth_separable'
            encoder_params={
                "kernel_size": 5,
                'num_layers': 9,
            }
    """

    def __init__(
        self,
        num_chars,
        out_channels,
        hidden_channels,
        hidden_channels_dp,
        encoder_type,
        encoder_params,
        dropout_p_dp=0.1,
        mean_only=False,
        use_prenet=True,
        c_in_channels=0,
    ):
        super().__init__()
        # class arguments
        self.num_chars = num_chars
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.hidden_channels_dp = hidden_channels_dp
        self.dropout_p_dp = dropout_p_dp
        self.mean_only = mean_only
        self.use_prenet = use_prenet
        self.c_in_channels = c_in_channels
        self.encoder_type = encoder_type
        # embedding layer, init scaled by 1/sqrt(hidden) as in Transformer embeddings
        self.emb = nn.Embedding(num_chars, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        # init encoder module (hoist the repeated .lower() call)
        encoder_type = encoder_type.lower()
        if encoder_type == "rel_pos_transformer":
            if use_prenet:
                self.prenet = ResidualConv1dLayerNormBlock(
                    hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
                )
            self.encoder = RelativePositionTransformer(
                hidden_channels, hidden_channels, hidden_channels, **encoder_params
            )
        elif encoder_type == "gated_conv":
            self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
        elif encoder_type == "residual_conv_bn":
            if use_prenet:
                self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
            self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params)
            self.postnet = nn.Sequential(
                nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels)
            )
        elif encoder_type == "time_depth_separable":
            if use_prenet:
                self.prenet = ResidualConv1dLayerNormBlock(
                    hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
                )
            self.encoder = TimeDepthSeparableConvBlock(
                hidden_channels, hidden_channels, hidden_channels, **encoder_params
            )
        else:
            # fixed typo in the original message ("Unkown")
            raise ValueError(" [!] Unknown encoder type.")
        # final projection layers
        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
        # duration predictor (kernel_size=3 fixed by the original implementation)
        self.duration_predictor = DurationPredictor(
            hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp
        )

    def forward(self, x, x_lengths, g=None):
        """Encode a batch of character sequences.

        Shapes:
            - x: :math:`[B, T]` (integer character ids for the embedding layer)
            - x_lengths: :math:`[B]`
            - g (optional): :math:`[B, C, 1]` (singleton time dim, expanded to T below)

        Returns:
            Tuple of (x_m, x_logs, logw, x_mask): projected mean, log-scale
            (zeros when ``mean_only``), predicted log-durations, and the
            sequence mask, all masked to valid time steps.
        """
        # embedding layer, scaled by sqrt(hidden) to balance magnitudes
        # [B, T, D]
        x = self.emb(x) * math.sqrt(self.hidden_channels)
        # [B, D, T]
        x = torch.transpose(x, 1, -1)
        # compute input sequence mask from the per-item lengths
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        # prenet (only some encoder types create one)
        if hasattr(self, "prenet") and self.use_prenet:
            x = self.prenet(x, x_mask)
        # encoder
        x = self.encoder(x, x_mask)
        # postnet (only created for 'residual_conv_bn')
        if hasattr(self, "postnet"):
            x = self.postnet(x) * x_mask
        # duration predictor input: detach so the duration loss does not
        # backprop into the encoder; concat speaker embedding when given
        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([x.detach(), g_exp], 1)
        else:
            x_dp = x.detach()
        # final projection layers
        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            # constant (unit) std -> log-scale of zeros
            x_logs = torch.zeros_like(x_m)
        # duration predictor
        logw = self.duration_predictor(x_dp, x_mask)
        return x_m, x_logs, logw, x_mask