import hashlib
import random
from copy import copy
from datetime import datetime
from pathlib import Path
from time import time
from typing import Any, Dict, List, Text, Tuple

import ray
import torch
import torch.nn.functional as F
import torchaudio
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from ray import serve

from constants import (
    AUTO_REGRESSIVE_BATCH_SIZE,
    CLVP_const,
    DIFFUSION,
    DIFFUSION_TEMPERATURE,
    GPT,
    LENGTH_PENALTY,
    MAX_MEL_TOKENS,
    NUM_AUTOREGRESSIVE_SAMPLES,
    REPETITION_PENALTY,
    TEMPERATURE,
    TOP_P,
)
from ruth_tts_transformer.models.autoregressive import UnifiedVoice
from ruth_tts_transformer.models.clvp import CLVP
from ruth_tts_transformer.models.diffusion_decoder import DiffusionTts
from ruth_tts_transformer.models.vocoder import UnivNetGenerator
from ruth_tts_transformer.utils.audio import load_voice
from ruth_tts_transformer.utils.tokenizer import VoiceBpeTokenizer
from ruth_tts_transformer.utils.wav2vec_alignment import Wav2VecAlignment
from utils import (
    MODELS_DIR,
    get_config_file,
    get_model_path,
    load_discrete_vocoder_diffuser,
)

app = FastAPI()


class Item(BaseModel):
    """Request schema for the (currently commented-out) /convert endpoint."""

    text: str
    voice: str
    seed: int = 3


class Gpt:
    """Autoregressive stage: samples candidate mel-code sequences from text."""

    def __init__(
        self,
        num_autoregressive_samples: int,
        top_p: float,
        temperature: float,
        length_penalty: int,
        repetition_penalty: float,
        max_mel_tokens: int,
        autoregressive_batch_size: int,
    ):
        self.num_autoregressive_samples = num_autoregressive_samples
        self.top_p = top_p
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.max_mel_tokens = max_mel_tokens
        self.autoregressive_batch_size = autoregressive_batch_size
        self.gpt = (
            UnifiedVoice(
                max_mel_tokens=604,
                max_text_tokens=402,
                max_conditioning_inputs=2,
                layers=30,
                model_dim=1024,
                heads=16,
                number_text_tokens=255,
                start_text_token=255,
                checkpointing=False,
                train_solo_embeddings=False,
            )
            .cpu()
            .eval()
        )
        self.gpt.load_state_dict(
            torch.load(get_model_path("autoregressive.pth", MODELS_DIR))
        )
        self.gpt = self.gpt.to("cuda")

    def __num_batches(self):
        return self.num_autoregressive_samples // self.autoregressive_batch_size

    @staticmethod
    def deterministic_state(seed=None):
        """Seed torch and random so generation is reproducible for a given seed."""
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    def parse(self, auto_conditioning, text_tokens, best_results, seed, k=1):
        """Re-run the GPT on the winning codes to recover their latents."""
        self.deterministic_state(seed=seed)
        auto_conditioning = copy(auto_conditioning).to("cuda")
        text_tokens = copy(text_tokens).to("cuda")
        best_results = copy(best_results).to("cuda")
        best_latents = self.gpt(
            auto_conditioning.repeat(k, 1),
            text_tokens.repeat(k, 1),
            torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
            best_results,
            torch.tensor(
                [best_results.shape[-1] * self.gpt.mel_length_compression],
                device=text_tokens.device,
            ),
            return_latent=True,
            clip_inputs=False,
        )
        return best_latents

    def parse_inference(
        self, auto_conditioning: torch.Tensor, text_tokens: torch.Tensor, seed
    ) -> Tuple[List[torch.Tensor], int]:
        """Sample candidate code batches; each batch is padded to max_mel_tokens."""
        self.deterministic_state(seed=seed)
        auto_conditioning = copy(auto_conditioning).to("cuda")
        text_tokens = copy(text_tokens).to("cuda")
        with torch.no_grad():
            samples = []
            num_batches = self.__num_batches()
            for b in range(num_batches):
                codes = self.gpt.inference_speech(
                    auto_conditioning,
                    text_tokens,
                    do_sample=True,
                    top_p=self.top_p,
                    temperature=self.temperature,
                    num_return_sequences=self.autoregressive_batch_size,
                    length_penalty=self.length_penalty,
                    repetition_penalty=self.repetition_penalty,
                    max_generate_length=self.max_mel_tokens,
                )
                padding_needed = self.max_mel_tokens - codes.shape[1]
                codes = F.pad(
                    codes, (0, padding_needed), value=self.gpt.stop_mel_token
                )
                samples.append(codes)
        return samples, self.gpt.stop_mel_token
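

# Every batch returned by Gpt.parse_inference is right-padded with the model's
# stop token to a fixed width of max_mel_tokens, so clvp.parse can later
# torch.cat the batches along dim 0. A minimal, self-contained sketch of that
# invariant; the token value 99 below is a stand-in for illustration only (the
# real id comes from UnifiedVoice.stop_mel_token):
def _example_pad_codes_to_fixed_width():
    """Hedged illustration only; not used by the serving path."""
    stop_mel_token = 99  # placeholder value, not the real model token id
    max_mel_tokens = 8
    codes = torch.tensor([[5, 7, 9]])  # as if 3 tokens were generated
    padded = F.pad(codes, (0, max_mel_tokens - codes.shape[1]), value=stop_mel_token)
    assert padded.shape == (1, max_mel_tokens)
    assert (padded[0, 3:] == stop_mel_token).all()
    return padded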


class clvp:
    """CLVP re-ranking stage: scores GPT candidates against the text, keeps top K."""

    def __init__(self, K):
        self.clvp = (
            CLVP(
                dim_text=768,
                dim_speech=768,
                dim_latent=768,
                num_text_tokens=256,
                text_enc_depth=20,
                text_seq_len=350,
                text_heads=12,
                num_speech_tokens=8192,
                speech_enc_depth=20,
                speech_heads=12,
                speech_seq_len=430,
                use_xformers=True,
            )
            .cpu()
            .eval()
        )
        self.clvp.load_state_dict(torch.load(get_model_path("clvp2.pth", MODELS_DIR)))
        self.clvp.to("cuda")
        self.K = K

    @staticmethod
    def fix_gpt_output(codes, stop_token, complain=True):
        """Trim everything after the first stop token and close the clip cleanly."""
        stop_token_indices = (codes == stop_token).nonzero()
        if len(stop_token_indices) == 0:
            if complain:
                print(
                    "No stop tokens found in one of the generated voice clips. This "
                    "typically means the spoken audio is too long. In some cases, the "
                    "output will still be good, though. Listen to it and if it is "
                    "missing words, try breaking up your input text."
                )
            return codes
        codes[stop_token_indices] = 83
        stm = stop_token_indices.min().item()
        codes[stm:] = 83
        if stm - 3 < codes.shape[0]:
            codes[-3] = 45
            codes[-2] = 45
            codes[-1] = 248
        return codes

    def parse(
        self,
        text_tokens: torch.Tensor,
        samples: List[torch.Tensor],
        stop_mel_token: int,
        seed: int,
    ) -> torch.Tensor:
        self.deterministic_state(seed=seed)
        clip_results = []
        text_tokens = copy(text_tokens).to("cuda")
        samples = [copy(batch).to("cuda") for batch in samples]
        for batch in samples:
            for i in range(batch.shape[0]):
                batch[i] = self.fix_gpt_output(batch[i], stop_mel_token)
            clvp_scores = self.clvp(
                text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False
            )
            clip_results.append(clvp_scores)
        clip_results = torch.cat(clip_results, dim=0)
        samples = torch.cat(samples, dim=0)
        return samples[torch.topk(clip_results, self.K).indices]

    @staticmethod
    def deterministic_state(seed=None):
        """Seed torch and random so generation is reproducible for a given seed."""
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed
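

# clvp.parse scores every candidate code sequence against the text and keeps
# the K best via torch.topk. A small sketch of just that selection step, with
# random stand-in scores in place of real CLVP similarities:
def _example_topk_selection(k: int = 2):
    """Hedged illustration only; mirrors the indexing at the end of clvp.parse."""
    candidates = torch.randn(6, 10)  # 6 candidate sequences of length 10
    scores = torch.randn(6)  # stand-in for CLVP similarity scores
    best = candidates[torch.topk(scores, k).indices]
    assert best.shape == (k, 10)
    return best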


class Diffusion:
    """Diffusion + vocoder stage: turns GPT latents into a 24kHz waveform."""

    def __init__(
        self,
        diffusion_temperature,
        diffusion_iterations=30,
        cond_free=True,
        cond_free_k=2,
    ):
        self.diffusion_temperature = diffusion_temperature
        self.diffusion = (
            DiffusionTts(
                model_channels=1024,
                num_layers=10,
                in_channels=100,
                out_channels=200,
                in_latent_channels=1024,
                in_tokens=8193,
                dropout=0,
                use_fp16=False,
                num_heads=16,
                layer_drop=0,
                unconditioned_percentage=0,
            )
            .cpu()
            .eval()
        )
        self.diffusion.load_state_dict(
            torch.load(get_model_path("diffusion_decoder.pth", MODELS_DIR))
        )
        self.diffuser = load_discrete_vocoder_diffuser(
            desired_diffusion_steps=diffusion_iterations,
            cond_free=cond_free,
            cond_free_k=cond_free_k,
        )
        self.vocoder = UnivNetGenerator().cpu()
        self.vocoder.load_state_dict(
            torch.load(
                get_model_path("vocoder.pth", MODELS_DIR),
                map_location=torch.device("cpu"),
            )["model_g"]
        )
        self.vocoder.eval(inference=True)
        self.diffusion.to("cuda")
        self.vocoder.to("cuda")
        self.aligner = Wav2VecAlignment()
        self.TACOTRON_MEL_MAX = 2.3143386840820312
        self.TACOTRON_MEL_MIN = -11.512925148010254

    def denormalize_tacotron_mel(self, norm_mel):
        """Map the diffusion output from [-1, 1] back to the Tacotron log-mel range."""
        return ((norm_mel + 1) / 2) * (
            self.TACOTRON_MEL_MAX - self.TACOTRON_MEL_MIN
        ) + self.TACOTRON_MEL_MIN

    def potentially_redact(self, clip, text):
        return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)

    @staticmethod
    def deterministic_state(seed=None):
        """Seed torch and random so generation is reproducible for a given seed."""
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    def do_spectrogram_diffusion(
        self,
        diffusion_model,
        diffuser,
        latents,
        conditioning_latents,
        seed,
        temperature=1,
        verbose=False,
    ):
        self.deterministic_state(seed=seed)
        with torch.no_grad():
            # This diffusion model converts from 22kHz spectrogram codes to a
            # 24kHz spectrogram signal.
            output_seq_len = latents.shape[1] * 4 * 24000 // 22050
            output_shape = (latents.shape[0], 100, output_seq_len)
            precomputed_embeddings = diffusion_model.timestep_independent(
                latents, conditioning_latents, output_seq_len, False
            )
            noise = torch.randn(output_shape, device=latents.device) * temperature
            mel = diffuser.p_sample_loop(
                diffusion_model,
                output_shape,
                noise=noise,
                model_kwargs={
                    "precomputed_aligned_embeddings": precomputed_embeddings
                },
                progress=verbose,
            )
            return self.denormalize_tacotron_mel(mel)[:, :, :output_seq_len]

    def parse(
        self, best_results, best_latents, calm_token, diffusion_conditioning, text, seed
    ):
        self.deterministic_state(seed=seed)
        best_results = copy(best_results).to("cuda")
        best_latents = copy(best_latents).to("cuda")
        diffusion_conditioning = copy(diffusion_conditioning).to("cuda")
        wav_candidates = []
        for b in range(best_results.shape[0]):
            codes = best_results[b].unsqueeze(0)
            latents = best_latents[b].unsqueeze(0)
            # Trim trailing silence: a run of more than 8 calm tokens marks the
            # point where the clip has effectively ended.
            ctokens = 0
            for k in range(codes.shape[-1]):
                if codes[0, k] == calm_token:
                    ctokens += 1
                else:
                    ctokens = 0
                if ctokens > 8:
                    latents = latents[:, :k]
                    break
            mel = self.do_spectrogram_diffusion(
                self.diffusion,
                self.diffuser,
                latents,
                diffusion_conditioning,
                seed,
                temperature=self.diffusion_temperature,
                verbose=False,
            )
            wav = self.vocoder.inference(mel)
            wav_candidates.append(wav)
        # TODO: Check whether wav candidates should be in numpy.
        wav_candidates = [
            self.potentially_redact(wav_candidate, text)
            for wav_candidate in wav_candidates
        ]
        return wav_candidates
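

# Diffusion.denormalize_tacotron_mel linearly maps the diffusion output from
# [-1, 1] back into the Tacotron log-mel range. A short sketch of the mapping
# and its inverse, using the same constants as the class above:
def _example_mel_denormalization():
    """Hedged illustration only; round-trips a normalized mel tensor."""
    mel_max, mel_min = 2.3143386840820312, -11.512925148010254
    norm_mel = torch.tensor([-1.0, 0.0, 1.0])
    mel = ((norm_mel + 1) / 2) * (mel_max - mel_min) + mel_min
    recovered = 2 * (mel - mel_min) / (mel_max - mel_min) - 1  # inverse mapping
    assert torch.allclose(recovered, norm_mel, atol=1e-6)
    return mel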


# @serve.deployment(
#     name="orchestrator",
#     num_replicas=4,
#     ray_actor_options={"num_cpus": 8, "num_gpus": 0.5},
# )
class Orchestractor:
    """Wires the stages together: tokenize -> GPT -> CLVP -> diffusion -> wav."""

    def __init__(self, config: Dict[Text, Any]):
        self.tokenizer = VoiceBpeTokenizer()
        _, conditioning_latent_1 = load_voice("gabby_reading", map_location="cpu")
        _, conditioning_latent_2 = load_voice("gabby_conversation", map_location="cpu")
        self.auto_conditioning1, self.diffusion_conditioning1 = conditioning_latent_1
        self.auto_conditioning2, self.diffusion_conditioning2 = conditioning_latent_2
        self.auto_conditioning = None
        self.diffusion_conditioning = None
        self.gpt = Gpt(
            config[GPT][NUM_AUTOREGRESSIVE_SAMPLES],
            config[GPT][TOP_P],
            config[GPT][TEMPERATURE],
            config[GPT][LENGTH_PENALTY],
            config[GPT][REPETITION_PENALTY],
            config[GPT][MAX_MEL_TOKENS],
            config[GPT][AUTO_REGRESSIVE_BATCH_SIZE],
        )
        self.clvp = clvp(config[CLVP_const]["k"])
        self.diffusion = Diffusion(config[DIFFUSION][DIFFUSION_TEMPERATURE])
        self.calm_token = 83
        print("orchestrator setup completed")

    @staticmethod
    def __check_for_long_sentence(text_tokens):
        assert (
            text_tokens.shape[-1] < 400
        ), "Too much text provided. Break the text up into separate segments and re-try inference."
        # TODO: split the text into several pieces, generate each, and combine them at the end.

    @staticmethod
    def deterministic_state(seed=None):
        """Seed torch and random so generation is reproducible for a given seed."""
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    def preprocess_text(self, text: Text):
        return torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)

    def parse(self, res):
        print("parsing")
        file_name = hashlib.sha1(str(datetime.now()).encode("UTF-8"))
        res = [torch.Tensor(copy(split)).squeeze(0).cpu() for split in res]
        res = [torch.flatten(split) for split in res]
        merged_audio_tensor = torch.cat(res).reshape(1, -1)
        torchaudio.save(f"./{file_name.hexdigest()}.wav", merged_audio_tensor, 24000)
        return file_name.hexdigest()

    def generate(self, text, voice, seed):
        if voice == "gabby_reading":
            self.auto_conditioning = self.auto_conditioning1
            self.diffusion_conditioning = self.diffusion_conditioning1
        elif voice == "gabby_conversation":
            self.auto_conditioning = self.auto_conditioning2
            self.diffusion_conditioning = self.diffusion_conditioning2
        self.deterministic_state(seed=seed)
        # Preprocess the incoming text into tokens.
        text_tokens = self.preprocess_text(text)
        self.__check_for_long_sentence(text_tokens)
        samples, stop_mel_token = self.gpt.parse_inference(
            self.auto_conditioning, text_tokens, seed
        )
        best_sample = self.clvp.parse(text_tokens, samples, stop_mel_token, seed)
        best_latent = self.gpt.parse(
            self.auto_conditioning, text_tokens, best_sample, seed
        )
        wav_candidates = self.diffusion.parse(
            best_sample,
            best_latent,
            self.calm_token,
            self.diffusion_conditioning,
            text,
            seed,
        )
        if len(wav_candidates) > 1:
            return [candidate.cpu() for candidate in wav_candidates]
        return wav_candidates[0].cpu()


# @app.on_event("startup")
# def startup_event():
#     ray.init(address="auto")
#     serve.start()
#     config = get_config_file(Path("config-model.yaml"))
#     Orchestractor.deploy(config)
#     orchestrator = serve.get_deployment("orchestrator")
#     orchestrator = orchestrator.get_handle()
#     app.deploy = orchestrator
#
#
# @app.on_event("shutdown")
# def shutdown_event():
#     ray.shutdown()
#
#
# @app.post("/convert")
def model1_deployment(voice="gabby_reading", text="hello how are you!", seed=3):
    # NOTE: Orchestractor.deploy / serve.get_deployment require the
    # @serve.deployment decorator above the class to be uncommented.
    serve.start(detached=True)
    config = get_config_file(Path("config-model.yaml"))
    Orchestractor.deploy(config)
    orchestrator = serve.get_deployment("orchestrator")
    orchestrator = orchestrator.get_handle()
    app.deploy = orchestrator
    # Voice names must match those handled in Orchestractor.generate.
    if voice in ("gabby_reading", "gabby_conversation"):
        sentences = text.split(". ")
        if len(sentences) > 1:
            # Fan each sentence out to the deployment and gather the results.
            values = ray.get(
                [
                    app.deploy.generate.remote(text=sentence, voice=voice, seed=seed)
                    for sentence in sentences
                ]
            )
        else:
            values = [ray.get(app.deploy.generate.remote(text, voice, seed))]
        file_name = ray.get(app.deploy.parse.remote(values))
        return FileResponse(f"./{file_name}.wav")
    return f"{voice} not available!"


if __name__ == "__main__":
    config = get_config_file(Path("config-model.yaml"))
    orches = Orchestractor(config)
    orches.generate(
        text="hello how are you doing from prakash!",
        voice="gabby_reading",
        seed=3,
    )
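

# The Item model at the top of this file is declared but never wired to a
# route here; it documents the request body the commented-out /convert
# endpoint would accept. A hedged sketch of that payload (assumes pydantic
# v1's .dict(), which this FastAPI-era codebase appears to target):
def _example_convert_payload():
    """Hedged illustration only; shows the request schema Item describes."""
    item = Item(text="hello how are you!", voice="gabby_reading", seed=3)
    return item.dict()  # -> {"text": ..., "voice": "gabby_reading", "seed": 3}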