import torch
from torch import nn

from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder
from TTS.tts.utils.helpers import sequence_mask


class Decoder(nn.Module):
    """Glow decoder wrapper that trims time lengths to a multiple of the squeeze factor.

    Flow structure of the wrapped decoder::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    Args:
        in_channels (int): number of channels of the input tensor.
        hidden_channels (int): number of hidden decoder channels.
        kernel_size (int): coupling block kernel size (WaveNet filter kernel size).
        dilation_rate (int): rate by which dilation increases per layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number of coupling layers (number of WaveNet layers).
        dropout_p (float): WaveNet dropout rate. Defaults to 0.0.
        num_splits (int): passed through to the Glow decoder (presumably the channel
            split count of the invertible 1x1 convolution; confirm in GlowDecoder).
            Defaults to 4.
        num_squeeze (int): squeeze factor. It is also stored as ``n_sqz`` and used to
            round sequence lengths down to a multiple of it. Defaults to 2.
        sigmoid_scale (bool): enable/disable sigmoid scaling in the coupling layers.
            Defaults to False.
        c_in_channels (int): passed through to the Glow decoder (presumably the channel
            count of the conditioning input ``g``; confirm in GlowDecoder). Defaults to 0.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()
        # Squeeze factor used by preprocess() to trim/round time lengths.
        self.n_sqz = num_squeeze
        # Arguments are forwarded positionally, in the same order as this signature.
        self.glow_decoder = GlowDecoder(
            in_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            num_flow_blocks,
            num_coupling_layers,
            dropout_p,
            num_splits,
            num_squeeze,
            sigmoid_scale,
            c_in_channels,
        )

    def forward(self, x, x_len, g=None, reverse=False):
        """Run the Glow decoder on length-trimmed input.

        Before decoding, the time axis of ``x`` and the lengths in ``x_len`` are
        rounded down to a multiple of ``n_sqz``.

        Input shapes:
            - x: :math:`[B, C, T]`
            - x_len: :math:`[B]`
            - g: :math:`[B, C]`

        Output shapes:
            - x: :math:`[B, C, T]`
            - x_len: :math:`[B]` (rounded down to a multiple of ``n_sqz``)
            - logdet_tot: :math:`[B]`
        """
        x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max())
        # Build a [B, 1, T] float mask that matches the input dtype.
        valid = sequence_mask(x_len, x_max_len)
        x_mask = valid.unsqueeze(1).to(x.dtype)
        z, logdet_tot = self.glow_decoder(x, x_mask, g, reverse)
        return z, x_len, logdet_tot

    def preprocess(self, y, y_lengths, y_max_length):
        """Round lengths down to a multiple of ``n_sqz`` and trim ``y`` to match.

        Args:
            y: input tensor, sliced along its last axis.
            y_lengths: per-sequence lengths.
            y_max_length: maximum length, or ``None`` to skip trimming ``y``.

        Returns:
            Tuple ``(y, y_lengths, y_max_length)`` with the rounded/trimmed values.
        """
        factor = self.n_sqz
        if y_max_length is not None:
            # Drop trailing frames past the largest multiple of the squeeze factor.
            y_max_length = factor * torch.div(y_max_length, factor, rounding_mode="floor")
            y = y[:, :, :y_max_length]
        y_lengths = factor * torch.div(y_lengths, factor, rounding_mode="floor")
        return y, y_lengths, y_max_length

    def store_inverse(self):
        """Delegate to the wrapped Glow decoder's ``store_inverse``.

        Presumably this caches inverted weights for the reverse pass (confirm in GlowDecoder).
        """
        self.glow_decoder.store_inverse()