import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from models.tts.naturalspeech2.diffusion import Diffusion
from models.tts.naturalspeech2.diffusion_flow import DiffusionFlow
from models.tts.naturalspeech2.wavenet import WaveNet
from models.tts.naturalspeech2.prior_encoder import PriorEncoder
from modules.naturalpseech2.transformers import TransformerEncoder
from encodec import EncodecModel
from einops import rearrange, repeat

import os
import json


class NaturalSpeech2(nn.Module):
    """Latent-diffusion TTS model conditioned on a reference (prompt) utterance.

    The module wires together:

    * a ``PriorEncoder`` that turns phone ids (plus duration / pitch during
      training) into a frame-level condition,
    * a diffusion model (``Diffusion`` or ``DiffusionFlow``, chosen by
      ``cfg.diffusion.diffusion_type``) operating on codec latents,
    * a prompt encoder and a small set of learned query tokens that attend
      over the encoded prompt to produce a fixed-size speaker summary,
    * the (frozen) residual quantizer of the 24 kHz EnCodec model, used to
      convert between discrete codes and continuous latents.
    """

    def __init__(self, cfg):
        """Build all sub-modules from ``cfg``.

        Raises:
            ValueError: if ``cfg.diffusion.diffusion_type`` is neither
                ``"diffusion"`` nor ``"flow"``.
        """
        super().__init__()
        self.cfg = cfg

        self.latent_dim = cfg.latent_dim
        self.query_emb_num = cfg.query_emb.query_token_num

        self.prior_encoder = PriorEncoder(cfg.prior_encoder)

        diffusion_type = cfg.diffusion.diffusion_type
        if diffusion_type == "diffusion":
            self.diffusion = Diffusion(cfg.diffusion)
        elif diffusion_type == "flow":
            self.diffusion = DiffusionFlow(cfg.diffusion)
        else:
            # Fail at construction time instead of leaving `self.diffusion`
            # undefined and crashing later with an opaque AttributeError.
            raise ValueError(
                f"Unsupported diffusion_type {diffusion_type!r}; "
                "expected 'diffusion' or 'flow'."
            )

        self.prompt_encoder = TransformerEncoder(cfg=cfg.prompt_encoder)
        if self.latent_dim != cfg.prompt_encoder.encoder_hidden:
            # Project prompt latents to the prompt encoder's hidden width.
            self.prompt_lin = nn.Linear(
                self.latent_dim, cfg.prompt_encoder.encoder_hidden
            )
            self.prompt_lin.weight.data.normal_(0.0, 0.02)
        else:
            # Widths already match: no projection is needed.
            self.prompt_lin = None

        self.query_emb = nn.Embedding(self.query_emb_num, cfg.query_emb.hidden_size)
        self.query_attn = nn.MultiheadAttention(
            cfg.query_emb.hidden_size, cfg.query_emb.head_num, batch_first=True
        )

        # Only the quantizer of the codec is kept; its parameters are frozen.
        codec_model = EncodecModel.encodec_model_24khz()
        codec_model.set_target_bandwidth(12.0)
        codec_model.requires_grad_(False)
        self.quantizer = codec_model.quantizer

    @torch.no_grad()
    def code_to_latent(self, code):
        """Decode discrete codec codes into a continuous latent.

        ``code`` has its first two dimensions swapped before being handed to
        the quantizer's ``decode``.
        """
        latent = self.quantizer.decode(code.transpose(0, 1))
        return latent

    def latent_to_code(self, latent, nq=16):
        """Residually quantize ``latent`` with the first ``nq`` quantizer layers.

        For each layer the negative squared Euclidean distance between the
        current residual and every codebook entry is computed; the arg-max
        gives the code index, whose decoded vector is subtracted from the
        residual before moving to the next layer.

        Args:
            latent: latent laid out as ``(b, d, n)`` (per the ``rearrange``
                pattern used below).
            nq: number of quantizer layers to use.

        Returns:
            ``(out_indices, out_dist)``: the per-layer indices and per-layer
            negative distances, each stacked along a new leading dimension.
            The original author annotates these as ``(nq, B, T)`` and
            ``(nq, B, T, 1024)``.
        """
        residual = latent
        all_indices = []
        all_dist = []
        for i in range(nq):
            layer = self.quantizer.vq.layers[i]
            x = rearrange(residual, "b d n -> b n d")
            x = layer.project_in(x)
            shape = x.shape
            x = layer._codebook.preprocess(x)
            embed = layer._codebook.embed.t()
            # -(||x||^2 - 2 x.e + ||e||^2): larger value == closer codebook entry.
            dist = -(
                x.pow(2).sum(1, keepdim=True)
                - 2 * x @ embed
                + embed.pow(2).sum(0, keepdim=True)
            )
            indices = dist.max(dim=-1).indices
            indices = layer._codebook.postprocess_emb(indices, shape)
            dist = dist.reshape(*shape[:-1], dist.shape[-1])
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
            all_dist.append(dist)

        out_indices = torch.stack(all_indices)
        out_dist = torch.stack(all_dist)

        return out_indices, out_dist  # (nq, B, T); (nq, B, T, 1024)

    @torch.no_grad()
    def latent_to_latent(self, latent, nq=16):
        """Quantize ``latent`` with ``nq`` layers and decode it back to a latent."""
        codes, _ = self.latent_to_code(latent, nq)
        latent = self.quantizer.vq.decode(codes)
        return latent

    def _encode_prompt(self, ref_code, ref_mask):
        """Encode the reference codes into speaker conditioning tensors.

        Shared by ``forward``, ``inference`` and ``reverse_diffusion_from_t``.

        Args:
            ref_code: discrete codes of the reference utterance.
            ref_mask: mask over reference frames; entries that are zero/False
                are passed to the attention as padded keys.

        Returns:
            ``(spk_emb, spk_query_emb)`` where ``spk_emb`` is the prompt
            encoder output with its last two dimensions swapped (annotated
            ``(B, d, T')`` in the original code) and ``spk_query_emb`` is the
            result of the learned query tokens attending over the prompt
            (annotated ``(B, query_emb_num, d)``).
        """
        ref_latent = self.code_to_latent(ref_code)

        # The prompt encoder consumes time-major features, so always swap the
        # last two dims. The linear projection exists only when the latent
        # width differs from the encoder width; the previous guard
        # (`self.latent_dim is not None`) called `self.prompt_lin` even when
        # it was None, which raised a TypeError.
        ref_latent = ref_latent.transpose(1, 2)
        if self.prompt_lin is not None:
            ref_latent = self.prompt_lin(ref_latent)

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(ref_latent.device)
        ).repeat(
            ref_latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        return spk_emb, spk_query_emb

    def forward(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        phone_id_frame=None,
        frame_nums=None,
        ref_code=None,
        ref_frame_nums=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
    ):
        """Training forward pass.

        ``phone_id_frame``, ``frame_nums`` and ``ref_frame_nums`` are accepted
        for interface compatibility but are not used by this method.

        Returns:
            ``(diff_out, prior_out)``: the output of the diffusion module and
            the output dict of the prior encoder.
        """
        latent = self.code_to_latent(code)
        spk_emb, spk_query_emb = self._encode_prompt(ref_code, ref_mask)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )

        prior_condition = prior_out["prior_out"]  # (B, T, d)

        diff_out = self.diffusion(latent, mask, prior_condition, spk_query_emb)

        return diff_out, prior_out

    @torch.no_grad()
    def inference(
        self, ref_code=None, phone_id=None, ref_mask=None, inference_steps=1000
    ):
        """Generate a latent for ``phone_id`` in the voice of ``ref_code``.

        The prior encoder is run in inference mode (no duration / pitch /
        masks supplied), Gaussian noise matching the prior condition's batch
        and length is drawn, and the diffusion module's ``reverse_diffusion``
        is run for ``inference_steps`` steps.

        Returns:
            ``(x0, prior_out)``: the generated latent and the prior encoder
            output dict.
        """
        spk_emb, spk_query_emb = self._encode_prompt(ref_code, ref_mask)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=None,
            pitch=None,
            phone_mask=None,
            mask=None,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=True,
        )

        prior_condition = prior_out["prior_out"]  # (B, T, d)

        # Initial noise is divided by 1.20, as in the original implementation.
        z = torch.randn(
            prior_condition.shape[0], self.latent_dim, prior_condition.shape[1]
        ).to(spk_emb.device) / (1.20)

        x0 = self.diffusion.reverse_diffusion(
            z, None, prior_condition, inference_steps, spk_query_emb
        )

        return x0, prior_out

    @torch.no_grad()
    def reverse_diffusion_from_t(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        ref_code=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
        n_timesteps=None,
        t=None,
    ):
        """Noise the ground-truth latent to step ``t`` and denoise it again.

        Only for debugging.

        Returns:
            ``(x0, prior_out, xt)``: the reconstructed latent, the prior
            encoder output dict, and the noised latent at step ``t``.
        """
        latent = self.code_to_latent(code)
        spk_emb, spk_query_emb = self._encode_prompt(ref_code, ref_mask)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )

        prior_condition = prior_out["prior_out"]  # (B, T, d)

        # One diffusion step value per batch item, kept strictly inside (0, 1).
        diffusion_step = (
            torch.ones(
                latent.shape[0],
                dtype=latent.dtype,
                device=latent.device,
                requires_grad=False,
            )
            * t
        )
        diffusion_step = torch.clamp(diffusion_step, 1e-5, 1.0 - 1e-5)
        xt, _ = self.diffusion.forward_diffusion(
            x0=latent, diffusion_step=diffusion_step
        )

        x0 = self.diffusion.reverse_diffusion_from_t(
            xt, mask, prior_condition, n_timesteps, spk_query_emb, t_start=t
        )

        return x0, prior_out, xt