# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py

import math

import torch
from torch import nn
from torch.nn import functional as F

from modules.flow.modules import *


class StochasticDurationPredictor(nn.Module):
    """Flow-based duration predictor from VITS.

    Models the distribution of (log-)durations conditioned on text-encoder
    hidden states, so durations can be sampled stochastically at inference.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        # NOTE: kept from upstream VITS, which overrides the argument here.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow stack: maps durations to a standard normal (and back).
        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())

        # Posterior network over the auxiliary latent used during training.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(Flip())

        # Condition encoder applied to the (detached) text hidden states.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Training (reverse=False): return the NLL term `nll + logq` per
        utterance, given ground-truth durations `w`. Inference (reverse=True):
        sample log-durations `logw` from noise scaled by `noise_scale`."""
        # Detach so duration loss gradients do not flow into the text encoder.
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            # Posterior pass: encode durations, then sample the auxiliary
            # latent (z_u, z1) and accumulate its log-determinant.
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            # Forward pass through the main flows: transform (z0, z1) to a
            # standard normal and accumulate the log-determinant.
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # shape: [b]
        else:
            # Invert the flow stack; drop the final ElementwiseAffine/Flip
            # pair except for the trailing Flip (as in upstream VITS).
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
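
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module): a smoke test of the
# two forward modes, assuming `modules.flow.modules` provides Log,
# ElementwiseAffine, ConvFlow, Flip, and DDSConv as in the VITS reference.
# The shapes below (batch=2, channels=192, frames=50) are illustrative
# assumptions, not values mandated by this file.
if __name__ == "__main__":
    B, C, T = 2, 192, 50
    sdp = StochasticDurationPredictor(
        in_channels=C, filter_channels=C, kernel_size=3, p_dropout=0.5
    )
    x = torch.randn(B, C, T)  # text-encoder hidden states
    x_mask = torch.ones(B, 1, T)  # all frames valid
    w = torch.rand(B, 1, T) * 4 + 1  # positive ground-truth durations

    # Training direction: per-utterance NLL term, shape [B].
    nll = sdp(x, x_mask, w=w)
    # Inference direction: sampled log-durations, shape [B, 1, T].
    logw = sdp(x, x_mask, reverse=True, noise_scale=0.8)
    print(nll.shape, logw.shape)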