import torch
from torch import nn
from torch.nn import functional as F

from vits import attentions
from vits import commons
from vits import modules
from vits.utils import f0_to_coarse
from vits_decoder.generator import Generator
from vits.modules_grl import SpeakerClassifier


class TextEncoder(nn.Module):
    """Prior encoder: fuses content features (PPG), auxiliary vectors and coarse F0."""

    def __init__(self,
                 in_channels,
                 vec_channels,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout):
        super().__init__()
        self.out_channels = out_channels
        self.pre = nn.Conv1d(in_channels, hidden_channels, kernel_size=5, padding=2)
        self.hub = nn.Conv1d(vec_channels, hidden_channels, kernel_size=5, padding=2)
        self.pit = nn.Embedding(256, hidden_channels)
        self.enc = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, v, f0):
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(
            commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        v = torch.transpose(v, 1, -1)  # [b, h, t]
        v = self.hub(v) * x_mask
        # Sum content, vector and pitch embeddings before the attention encoder.
        x = x + v + self.pit(f0).transpose(1, 2)
        x = self.enc(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization: sample z ~ N(m, exp(logs)^2).
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask, x


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.n_flows = n_flows  # needed by remove_weight_norm
        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        total_logdet = 0
        if not reverse:
            for flow in self.flows:
                x, log_det = flow(x, x_mask, g=g, reverse=reverse)
                total_logdet += log_det
        else:
            for flow in reversed(self.flows):
                x, log_det = flow(x, x_mask, g=g, reverse=reverse)
                total_logdet += log_det
        return x, total_logdet

    def remove_weight_norm(self):
        # Coupling layers sit at even indices; odd indices are Flip modules.
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(
            commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()


class SynthesizerTrn(nn.Module):
    def __init__(
        self,
        spec_channels,
        segment_size,
        hp
    ):
        super().__init__()
        self.segment_size = segment_size
        self.emb_g = nn.Linear(hp.vits.spk_dim, hp.vits.gin_channels)
        self.enc_p = TextEncoder(
            hp.vits.ppg_dim,
            hp.vits.vec_dim,
            hp.vits.inter_channels,
            hp.vits.hidden_channels,
            hp.vits.filter_channels,
            2,    # n_heads
            6,    # n_layers
            3,    # kernel_size
            0.1,  # p_dropout
        )
        self.speaker_classifier = SpeakerClassifier(
            hp.vits.hidden_channels,
            hp.vits.spk_dim,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            hp.vits.inter_channels,
            hp.vits.hidden_channels,
            5,   # kernel_size
            1,   # dilation_rate
            16,  # n_layers
            gin_channels=hp.vits.gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            hp.vits.inter_channels,
            hp.vits.hidden_channels,
            5,  # kernel_size
            1,  # dilation_rate
            4,  # n_layers
            gin_channels=hp.vits.spk_dim,
        )
        self.dec = Generator(hp=hp)

    def forward(self, ppg, vec, pit, spec, spk, ppg_l, spec_l):
        ppg = ppg + torch.randn_like(ppg) * 1  # Perturbation
        vec = vec + torch.randn_like(vec) * 2  # Perturbation
        g = self.emb_g(F.normalize(spk)).unsqueeze(-1)
        z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
            ppg, ppg_l, vec, f0=f0_to_coarse(pit))
        z_q, m_q, logs_q, spec_mask = self.enc_q(spec, spec_l, g=g)

        z_slice, pit_slice, ids_slice = commons.rand_slice_segments_with_pitch(
            z_q, pit, spec_l, self.segment_size)
        audio = self.dec(spk, z_slice, pit_slice)

        # SNAC to flow
        z_f, logdet_f = self.flow(z_q, spec_mask, g=spk)
        z_r, logdet_r = self.flow(z_p, spec_mask, g=spk, reverse=True)
        # speaker
        spk_preds = self.speaker_classifier(x)
        return audio, ids_slice, spec_mask, (
            z_f, z_r, z_p, m_p, logs_p, z_q, m_q, logs_q, logdet_f, logdet_r), spk_preds

    def infer(self, ppg, vec, pit, spk, ppg_l):
        ppg = ppg + torch.randn_like(ppg) * 0.0001  # Perturbation
        z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
            ppg, ppg_l, vec, f0=f0_to_coarse(pit))
        z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True)
        o = self.dec(spk, z * ppg_mask, f0=pit)
        return o


class SynthesizerInfer(nn.Module):
    def __init__(
        self,
        spec_channels,
        segment_size,
        hp
    ):
        super().__init__()
        self.segment_size = segment_size
        self.enc_p = TextEncoder(
            hp.vits.ppg_dim,
            hp.vits.vec_dim,
            hp.vits.inter_channels,
            hp.vits.hidden_channels,
            hp.vits.filter_channels,
            2,    # n_heads
            6,    # n_layers
            3,    # kernel_size
            0.1,  # p_dropout
        )
        self.flow = ResidualCouplingBlock(
            hp.vits.inter_channels,
            hp.vits.hidden_channels,
            5,  # kernel_size
            1,  # dilation_rate
            4,  # n_layers
            gin_channels=hp.vits.spk_dim,
        )
        self.dec = Generator(hp=hp)

    def remove_weight_norm(self):
        self.flow.remove_weight_norm()
        self.dec.remove_weight_norm()

    def pitch2source(self, f0):
        return self.dec.pitch2source(f0)

    def source2wav(self, source):
        return self.dec.source2wav(source)

    def inference(self, ppg, vec, pit, spk, ppg_l, source):
        z_p, m_p, logs_p, ppg_mask, x = self.enc_p(
            ppg, ppg_l, vec, f0=f0_to_coarse(pit))
        z, _ = self.flow(z_p, ppg_mask, g=spk, reverse=True)
        o = self.dec.inference(spk, z * ppg_mask, source)
        return o
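

# Minimal shape-check sketch for TextEncoder, kept under __main__ so importing the
# module is unaffected. The channel sizes below are illustrative stand-ins, not the
# project's configured hp.vits values, and it assumes the repo's vits.attentions /
# vits.commons / vits.modules packages are importable (run from the repo root).
if __name__ == "__main__":
    enc = TextEncoder(
        in_channels=1280,     # illustrative ppg_dim
        vec_channels=256,     # illustrative vec_dim
        out_channels=192,     # illustrative inter_channels
        hidden_channels=192,
        filter_channels=640,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.1,
    )
    enc.eval()
    ppg = torch.randn(2, 400, 1280)        # [batch, frames, ppg_dim]
    vec = torch.randn(2, 400, 256)         # [batch, frames, vec_dim]
    lengths = torch.tensor([400, 350])
    # Coarse pitch ids in [0, 256); training code derives them via f0_to_coarse(pit).
    f0 = torch.randint(0, 256, (2, 400))
    with torch.no_grad():
        z, m, logs, mask, x = enc(ppg, lengths, vec, f0)
    print(z.shape, mask.shape)             # expected: [2, 192, 400], [2, 1, 400]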