from typing import List, Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL


def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    batch_size = lengths.shape[0]
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask


def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    return torch.ceil(lens / stride).int()


class ReferenceEncoder(nn.Module):
    """
    Reference encoder for the utterance and phoneme prosody encoders,
    made up of convolution and RNN layers.

    Args:
        num_mels (int): Number of mel channels in the input spectrogram.
        ref_enc_filters (List[int]): List of channel sizes for the encoder conv layers.
        ref_enc_size (int): Kernel size for the conv layers.
        ref_enc_strides (List[int]): List of strides for the conv layers.
        ref_enc_gru_size (int): Number of hidden features of the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing the mel spectrogram.
        - **lengths** (batch): Tensor containing the mel lengths.

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the reference encoder.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        n_mel_channels = num_mels
        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides
        # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[0 + 1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            )
        ]
        convs2 = [
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=ref_enc_size // 2,
            )
            for i in range(1, K)
        ]
        convs.extend(convs2)
        self.convs = nn.ModuleList(convs)
        self.norms = nn.ModuleList(
            [nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)]
        )

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        inputs --- [N, n_mels, timesteps]
        outputs --- [N, E//2]
        """
        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)  # [N, C_i, T_i] -- time axis downsampled by the conv strides
            x = norm(x)

        for _ in range(2):
            mel_lens = stride_lens(mel_lens)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)  # x --- [N, Ty//4, E//2] after unpacking, memory --- [1, N, E//2]
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(  # pylint: disable=no-self-use
        self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
    ) -> int:
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
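
# Minimal usage sketch for ReferenceEncoder. The hyperparameter values below are
# assumptions chosen for demonstration only, not values taken from a released
# DelightfulTTS config; the two stride-2 conv layers match the fixed 4x length
# downsampling applied to `mel_lens` in forward(). Call the helper manually to
# check tensor shapes (it is never invoked by the library).
def _example_reference_encoder() -> None:
    encoder = ReferenceEncoder(
        num_mels=80,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],  # assumed channel sizes
        ref_enc_size=3,
        ref_enc_strides=[1, 2, 1, 2, 1],  # assumed strides (two stride-2 layers)
        ref_enc_gru_size=32,
    )
    mels = torch.randn(4, 80, 120)  # [batch, n_mels, frames]
    mel_lens = torch.tensor([120, 100, 90, 60])
    outputs, memory, mel_masks = encoder(mels, mel_lens)
    # outputs: [4, 30, 32] padded GRU outputs, memory: [1, 4, 32] final hidden state,
    # mel_masks: [4, 30] padding mask over the downsampled time axis.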

class UtteranceLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
        """
        Encoder to extract prosody at the utterance level. It is made up of a
        reference encoder, a couple of linear layers, and a style token layer
        with dropout.

        Args:
            num_mels (int): Number of mel channels in the input spectrogram.
            ref_enc_filters (List[int]): List of channel sizes for the ref encoder conv layers.
            ref_enc_size (int): Kernel size for the ref encoder conv layers.
            ref_enc_strides (List[int]): List of strides for the ref encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features of the gated recurrent unit.
            dropout (float): Probability of dropout.
            n_hidden (int): Size of the hidden layers.
            bottleneck_size_u (int): Size of the bottleneck layer.
            token_num (int): Number of style tokens.

        Inputs: inputs, lengths
            - **inputs** (batch, dim, time): Tensor containing the mel spectrogram.
            - **lengths** (batch): Tensor containing the mel lengths.

        Returns:
            - **outputs** (batch, 1, dim): Tensor produced by the utterance level prosody encoder.
        """
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_u

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math:`[B, C, T]`
            mel_lens: :math:`[B]`

        out --- [N, 1, bottleneck_size]
        """
        _, embedded_prosody, _ = self.encoder(mels, mel_lens)

        # Bottleneck
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Style Token
        out = self.encoder_bottleneck(self.stl(embedded_prosody))
        out = self.dropout(out)

        out = out.view((-1, 1, out.shape[3]))
        return out
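
# Minimal usage sketch for UtteranceLevelProsodyEncoder: the whole utterance is
# summarised into a single bottleneck vector per sample. The hyperparameter values
# below (n_hidden, token_num, bottleneck size, dropout) are assumptions chosen for
# demonstration, not values taken from a released DelightfulTTS config. Call the
# helper manually to check tensor shapes (it is never invoked by the library).
def _example_utterance_prosody_encoder() -> None:
    encoder = UtteranceLevelProsodyEncoder(
        num_mels=80,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],  # assumed
        ref_enc_size=3,
        ref_enc_strides=[1, 2, 1, 2, 1],  # assumed
        ref_enc_gru_size=32,
        dropout=0.1,  # assumed
        n_hidden=512,  # assumed; must match the acoustic model's hidden size
        bottleneck_size_u=256,  # assumed
        token_num=32,  # assumed number of style tokens
    )
    mels = torch.randn(4, 80, 120)  # [batch, n_mels, frames]
    mel_lens = torch.tensor([120, 100, 90, 60])
    out = encoder(mels, mel_lens)
    # out: [4, 1, 256] -- one prosody bottleneck vector per utterance.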

class PhonemeLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
        """
        Encoder to extract prosody at the phoneme level. The phoneme encodings attend
        over the reference encoder output (query = phoneme encodings, key/value =
        reference encoding), and the result is projected to a small bottleneck per
        phoneme position.

        Args:
            num_mels (int): Number of mel channels in the input spectrogram.
            ref_enc_filters (List[int]): List of channel sizes for the ref encoder conv layers.
            ref_enc_size (int): Kernel size for the ref encoder conv layers.
            ref_enc_strides (List[int]): List of strides for the ref encoder conv layers.
            ref_enc_gru_size (int): Number of hidden features of the gated recurrent unit.
            dropout (float): Probability of dropout.
            n_hidden (int): Size of the hidden layers.
            n_heads (int): Number of attention heads.
            bottleneck_size_p (int): Size of the bottleneck layer.
        """
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_p

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        x --- [N, seq_len, encoder_embedding_dim]
        mels --- [N, n_mels, Ty]
        out --- [N, seq_len, bottleneck_size]
        attn --- [N, seq_len, ref_len] (computed inside the attention layer, not returned)
        """
        embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens)

        # Bottleneck
        embedded_prosody = self.encoder_prj(embedded_prosody)

        attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1))
        x, _ = self.attention(
            query=x,
            key=embedded_prosody,
            value=embedded_prosody,
            mask=attn_mask,
            encoding=encoding,
        )
        x = self.encoder_bottleneck(x)
        x = x.masked_fill(src_mask.unsqueeze(-1), 0.0)
        return x
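
# Minimal usage sketch for PhonemeLevelProsodyEncoder: phoneme encodings attend over
# the reference encoder output to pick up per-phoneme prosody. All hyperparameter
# values are assumptions for demonstration, and the shape of `encoding` (the
# positional encoding expected by ConformerMultiHeadedSelfAttention, assumed here to
# be [1, max_len, n_hidden]) is likewise an assumption, not taken from a released
# config. Call the helper manually to check tensor shapes (it is never invoked by
# the library).
def _example_phoneme_prosody_encoder() -> None:
    encoder = PhonemeLevelProsodyEncoder(
        num_mels=80,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],  # assumed
        ref_enc_size=3,
        ref_enc_strides=[1, 2, 1, 2, 1],  # assumed
        ref_enc_gru_size=32,
        dropout=0.1,  # assumed
        n_hidden=512,  # assumed; must match the phoneme encoder's embedding dim
        n_heads=8,  # assumed; must divide n_hidden
        bottleneck_size_p=4,  # assumed
    )
    x = torch.randn(2, 13, 512)  # [batch, phoneme seq_len, n_hidden]
    src_mask = torch.zeros(2, 13, dtype=torch.bool)  # no padded phonemes in this sketch
    mels = torch.randn(2, 80, 120)  # [batch, n_mels, frames]
    mel_lens = torch.tensor([120, 96])
    encoding = torch.randn(1, 200, 512)  # assumed positional-encoding shape
    out = encoder(x, src_mask, mels, mel_lens, encoding)
    # out: [2, 13, 4] -- a small prosody bottleneck per phoneme position.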