from typing import List, Tuple

import torch
import torch.nn as nn  # pylint: disable=consider-using-from-import
import torch.nn.functional as F

from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention
from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d
from TTS.tts.layers.delightful_tts.networks import STL


def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    """Return a boolean padding mask: True at padded positions, False at valid frames."""
    batch_size = lengths.shape[0]
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
    return mask


def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    # ceil(len / stride): the output length of a strided conv with an odd kernel and "same" padding.
    return torch.ceil(lens / stride).int()
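

# A minimal, illustrative sanity check of the two helpers above (plain torch,
# nothing from TTS). The concrete lengths are made up for demonstration.
def _example_length_helpers() -> None:
    lengths = torch.tensor([2, 4])
    # Padded positions are True:
    # tensor([[False, False,  True,  True],
    #         [False, False, False, False]])
    print(get_mask_from_lengths(lengths))
    # One stride-2 stage shortens each length to ceil(len / 2): [2, 4] -> [1, 2]
    print(stride_lens(lengths))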


class ReferenceEncoder(nn.Module):
    """
    Reference encoder for the utterance and phoneme prosody encoders, made up
    of convolution and RNN layers.

    Args:
        num_mels (int): Number of mel channels in the input spectrogram.
        ref_enc_filters (List[int]): List of channel sizes for the encoder layers.
        ref_enc_size (int): Kernel size for the conv layers.
        ref_enc_strides (List[int]): List of strides to use for the conv layers.
        ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit.

    Inputs: inputs, lengths
        - **inputs** (batch, dim, time): Tensor containing the mel spectrogram.
        - **lengths** (batch): Tensor containing the mel lengths.

    Returns:
        - **outputs** (batch, time', dim): GRU output for every downsampled frame.
        - **memory** (1, batch, dim): Final GRU hidden state.
        - **masks** (batch, time'): Padding mask for the downsampled frames.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
    ):
        super().__init__()

        n_mel_channels = num_mels
        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels] + ref_enc_filters
        strides = [1] + ref_enc_strides
        # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            )
        ]
        convs2 = [
            nn.Conv1d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=ref_enc_size,
                stride=strides[i],
                padding=ref_enc_size // 2,
            )
            for i in range(1, K)
        ]
        convs.extend(convs2)
        self.convs = nn.ModuleList(convs)
        self.norms = nn.ModuleList(
            [nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)]
        )

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        x --- [N, n_mels, timesteps]
        Returns x: [N, T', gru_size], memory: [1, N, gru_size], masks: [N, T'],
        where T' is the sequence length after the two stride-2 conv layers.
        """
        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x)
            x = F.leaky_relu(x, 0.3)  # [N, C_out, T'] after each conv
            x = norm(x)

        # The conv stack contains exactly two stride-2 layers, so each length shrinks by ceil(len / 2) twice.
        for _ in range(2):
            mel_lens = stride_lens(mel_lens)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))
        x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False)

        self.gru.flatten_parameters()
        x, memory = self.gru(x)  # x --- [N, T', gru_size], memory --- [1, N, gru_size]
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x, memory, mel_masks

    def calculate_channels(  # pylint: disable=no-self-use
        self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
    ) -> int:
        # Standard conv output-length recurrence, applied n_convs times.
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
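

# A hedged usage sketch for ReferenceEncoder. The hyperparameters below are
# illustrative stand-ins, not values taken from a shipped DelightfulTTS config.
# Note that forward() assumes exactly two stride-2 layers in the conv stack, so
# ref_enc_strides must contain exactly two 2s for the masks to line up.
def _example_reference_encoder() -> None:
    encoder = ReferenceEncoder(
        num_mels=80,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],
        ref_enc_size=3,
        ref_enc_strides=[1, 2, 1, 2, 1],
        ref_enc_gru_size=32,
    )
    mels = torch.randn(2, 80, 100)  # [N, n_mels, T]
    mel_lens = torch.tensor([100, 60])
    out, memory, masks = encoder(mels, mel_lens)
    print(out.shape)     # torch.Size([2, 25, 32]), 25 = ceil(ceil(100 / 2) / 2)
    print(memory.shape)  # torch.Size([1, 2, 32])
    print(masks.shape)   # torch.Size([2, 25])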


class UtteranceLevelProsodyEncoder(nn.Module):
    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        bottleneck_size_u: int,
        token_num: int,
    ):
| """ | |
| Encoder to extract prosody from utterance. it is made up of a reference encoder | |
| with a couple of linear layers and style token layer with dropout. | |
| Args: | |
| num_mels (int): Number of mel frames to produce. | |
| ref_enc_filters (list[int]): List of channel sizes for ref encoder layers. | |
| ref_enc_size (int): Size of the kernel for the ref encoder conv layers. | |
| ref_enc_strides (List[int]): List of strides to use for teh ref encoder conv layers. | |
| ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. | |
| dropout (float): Probability of dropout. | |
| n_hidden (int): Size of hidden layers. | |
| bottleneck_size_u (int): Size of the bottle neck layer. | |
| Inputs: inputs, mask | |
| - **inputs** (batch, dim, time): Tensor containing mel vector | |
| - **lengths** (batch): Tensor containing the mel lengths. | |
| Returns: | |
| - **outputs** (batch, 1, dim): Tensor produced by Utterance Level Prosody Encoder. | |
| """ | |
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_u

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2)
        self.stl = STL(n_hidden=n_hidden, token_num=token_num)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            mels: :math:`[B, C, T]`
            mel_lens: :math:`[B]`
            out: :math:`[B, 1, bottleneck_size_u]`
        """
        _, embedded_prosody, _ = self.encoder(mels, mel_lens)

        # Project the final GRU state down to E // 2
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Style token attention followed by the bottleneck projection
        out = self.encoder_bottleneck(self.stl(embedded_prosody))
        out = self.dropout(out)

        out = out.view((-1, 1, out.shape[3]))
        return out
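

# A hedged usage sketch for UtteranceLevelProsodyEncoder, mirroring how an
# acoustic model would call it. All sizes are illustrative assumptions; the
# encoder yields one prosody vector per utterance, hence the middle dim of 1.
def _example_utterance_prosody_encoder() -> None:
    encoder = UtteranceLevelProsodyEncoder(
        num_mels=80,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],
        ref_enc_size=3,
        ref_enc_strides=[1, 2, 1, 2, 1],
        ref_enc_gru_size=32,
        dropout=0.1,
        n_hidden=128,
        bottleneck_size_u=256,
        token_num=32,
    )
    mels = torch.randn(2, 80, 100)  # [B, C, T]
    mel_lens = torch.tensor([100, 60])
    out = encoder(mels, mel_lens)
    print(out.shape)  # torch.Size([2, 1, 256])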


class PhonemeLevelProsodyEncoder(nn.Module):
    """
    Encoder to extract per-phoneme prosody. A reference encoder summarizes the
    mel spectrogram, and multi-headed cross-attention aligns those frames with
    the phoneme sequence before a bottleneck projection.
    """

    def __init__(
        self,
        num_mels: int,
        ref_enc_filters: List[int],
        ref_enc_size: int,
        ref_enc_strides: List[int],
        ref_enc_gru_size: int,
        dropout: float,
        n_hidden: int,
        n_heads: int,
        bottleneck_size_p: int,
    ):
        super().__init__()

        self.E = n_hidden
        self.d_q = self.d_k = n_hidden
        bottleneck_size = bottleneck_size_p

        self.encoder = ReferenceEncoder(
            ref_enc_filters=ref_enc_filters,
            ref_enc_gru_size=ref_enc_gru_size,
            ref_enc_size=ref_enc_size,
            ref_enc_strides=ref_enc_strides,
            num_mels=num_mels,
        )
        self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=n_hidden,
            num_heads=n_heads,
            dropout_p=dropout,
        )
        self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        """
        x --- [N, seq_len, encoder_embedding_dim]
        mels --- [N, n_mels, T]
        out --- [N, seq_len, bottleneck_size]
        """
        embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens)

        # Project the reference-encoder frames up to the attention width
        embedded_prosody = self.encoder_prj(embedded_prosody)

        attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1))
        # Cross-attention: phoneme queries attend over the prosody frames
        x, _ = self.attention(
            query=x,
            key=embedded_prosody,
            value=embedded_prosody,
            mask=attn_mask,
            encoding=encoding,
        )
        x = self.encoder_bottleneck(x)
        x = x.masked_fill(src_mask.unsqueeze(-1), 0.0)
        return x
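

# A hedged usage sketch for PhonemeLevelProsodyEncoder. Sizes are illustrative;
# `encoding` stands in for the relative positional encoding that the model
# normally computes (any [1, T, n_hidden] tensor at least as long as the
# downsampled mel length should do for a smoke test).
def _example_phoneme_prosody_encoder() -> None:
    encoder = PhonemeLevelProsodyEncoder(
        num_mels=80,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],
        ref_enc_size=3,
        ref_enc_strides=[1, 2, 1, 2, 1],
        ref_enc_gru_size=32,
        dropout=0.1,
        n_hidden=128,
        n_heads=4,
        bottleneck_size_p=4,
    )
    x = torch.randn(2, 13, 128)  # [N, seq_len, n_hidden] phoneme encodings
    src_mask = get_mask_from_lengths(torch.tensor([13, 10]))  # [N, seq_len]
    mels = torch.randn(2, 80, 100)
    mel_lens = torch.tensor([100, 60])
    encoding = torch.randn(1, 100, 128)  # stand-in positional encoding
    out = encoder(x, src_mask, mels, mel_lens, encoding)
    print(out.shape)  # torch.Size([2, 13, 4])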