import torch
from torch import nn

from TTS.tts.layers.generic.normalization import ActNorm
from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear


def squeeze(x, x_mask=None, num_sqz=2):
    """GlowTTS squeeze operation.

    Increase the number of channels and reduce the number of time steps by the
    same factor ``num_sqz``. Trailing time steps that do not fill a complete
    group of ``num_sqz`` frames are dropped.

    Note: each 's' is an n-dimensional vector.
    ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``

    Args:
        x (Tensor): input of shape ``[B, C, T]``.
        x_mask (Tensor, optional): mask of shape ``[B, 1, T]``. When ``None``
            an all-ones mask is created.
        num_sqz (int): squeeze factor.

    Returns:
        Tuple[Tensor, Tensor]: the squeezed input of shape
        ``[B, C * num_sqz, T // num_sqz]`` (multiplied by the mask) and the
        squeezed mask of shape ``[B, 1, T // num_sqz]``.
    """
    b, c, t = x.size()

    # Drop trailing frames so the time axis divides evenly into groups.
    t = (t // num_sqz) * num_sqz
    x = x[:, :, :t]

    # [B, C, T] -> [B, C, T/n, n] -> [B, n, C, T/n] -> [B, C*n, T/n]
    x_sqz = x.view(b, c, t // num_sqz, num_sqz)
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)

    if x_mask is not None:
        # Keep the mask value of the last frame of each group; for a
        # prefix-style (padding) mask a group is valid only if all its frames are.
        x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz]
    else:
        x_mask = torch.ones(b, 1, t // num_sqz, device=x.device, dtype=x.dtype)
    return x_sqz * x_mask, x_mask


def unsqueeze(x, x_mask=None, num_sqz=2):
    """GlowTTS unsqueeze operation (the inverse of ``squeeze``).

    Note: each 's' is an n-dimensional vector.
    ``[[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]``

    Args:
        x (Tensor): squeezed input of shape ``[B, C, T]``; ``C`` must be
            divisible by ``num_sqz``.
        x_mask (Tensor, optional): squeezed mask of shape ``[B, 1, T]``. When
            ``None`` an all-ones mask is created.
        num_sqz (int): squeeze factor used by the matching ``squeeze`` call.

    Returns:
        Tuple[Tensor, Tensor]: the unsqueezed input of shape
        ``[B, C // num_sqz, T * num_sqz]`` (multiplied by the mask) and the
        expanded mask of shape ``[B, 1, T * num_sqz]``.
    """
    b, c, t = x.size()

    # [B, C, T] -> [B, n, C/n, T] -> [B, C/n, T, n] -> [B, C/n, T*n]
    x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz)

    if x_mask is not None:
        # Repeat each mask frame num_sqz times to restore the original time resolution.
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz)
    else:
        x_mask = torch.ones(b, 1, t * num_sqz, device=x.device, dtype=x.dtype)
    return x_unsqz * x_mask, x_mask


class Decoder(nn.Module):
    """Stack of Glow Decoder Modules.

    ::

        Squeeze -> [ActNorm -> InvertibleConv1x1 -> AffineCoupling] x num_flow_blocks -> Unsqueeze

    Args:
        in_channels (int): channels of input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): ``num_splits`` argument forwarded to each ``InvConvNear``.
        num_squeeze (int): squeeze factor. The flows operate on
            ``in_channels * num_squeeze`` channels; squeezing is skipped when it is 1.
        sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
        c_in_channels (int): channels of the conditioning input ``g``, forwarded
            to each ``CouplingBlock``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_flow_blocks = num_flow_blocks
        self.num_coupling_layers = num_coupling_layers
        self.dropout_p = dropout_p
        self.num_splits = num_splits
        self.num_squeeze = num_squeeze
        self.sigmoid_scale = sigmoid_scale
        self.c_in_channels = c_in_channels

        # Flows are stored flat as repeating (ActNorm, InvConvNear, CouplingBlock)
        # triples so that reverse mode can simply iterate the list backwards.
        self.flows = nn.ModuleList()
        for _ in range(num_flow_blocks):
            self.flows.append(ActNorm(channels=in_channels * num_squeeze))
            self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits))
            self.flows.append(
                CouplingBlock(
                    in_channels * num_squeeze,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    num_layers=num_coupling_layers,
                    c_in_channels=c_in_channels,
                    dropout_p=dropout_p,
                    sigmoid_scale=sigmoid_scale,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flow stack forward, or backwards when ``reverse`` is True.

        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
            - g: :math:`[B, C]`

        Returns:
            Tuple[Tensor, Tensor | None]: the transformed ``x`` (its time axis
            truncated to a multiple of ``num_squeeze`` when ``num_squeeze > 1``)
            and the summed log-determinant of all flows in the forward
            direction, or ``None`` when ``reverse`` is True.
        """
        if reverse:
            flows = reversed(self.flows)
            logdet_tot = None
        else:
            flows = self.flows
            logdet_tot = 0

        if self.num_squeeze > 1:
            x, x_mask = squeeze(x, x_mask, self.num_squeeze)

        for f in flows:
            x, logdet = f(x, x_mask, g=g, reverse=reverse)
            # Log-determinants are only accumulated in the forward direction.
            if not reverse:
                logdet_tot += logdet

        if self.num_squeeze > 1:
            x, x_mask = unsqueeze(x, x_mask, self.num_squeeze)
        return x, logdet_tot

    def store_inverse(self):
        """Call ``store_inverse()`` on every flow module in the stack."""
        for f in self.flows:
            f.store_inverse()