# ruth-tts/serve.py
import hashlib
import random
import ray
import torch
import torch.nn.functional as F
import torchaudio
from copy import copy
from datetime import datetime
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pathlib import Path
from pydantic import BaseModel
from ray import serve
from time import time
from typing import Any, Dict, List, Text, Tuple
from constants import (
AUTO_REGRESSIVE_BATCH_SIZE,
DIFFUSION,
DIFFUSION_TEMPERATURE,
GPT,
LENGTH_PENALTY,
MAX_MEL_TOKENS,
NUM_AUTOREGRESSIVE_SAMPLES,
REPETITION_PENALTY,
TEMPERATURE,
TOP_P,
CLVP_const,
)
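# The constants above are string keys into config-model.yaml. A plausible shape
# for that file (key spellings and values here are assumptions; see
# constants.py for the real definitions):
#
#   gpt:
#     num_autoregressive_samples: 16
#     top_p: 0.8
#     temperature: 0.8
#     length_penalty: 1
#     repetition_penalty: 2.0
#     max_mel_tokens: 500
#     autoregressive_batch_size: 16
#   clvp:
#     k: 1
#   diffusion:
#     diffusion_temperature: 1.0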
from ruth_tts_transformer.models.autoregressive import UnifiedVoice
from ruth_tts_transformer.models.clvp import CLVP
from ruth_tts_transformer.models.diffusion_decoder import DiffusionTts
from ruth_tts_transformer.models.vocoder import UnivNetGenerator
from ruth_tts_transformer.utils.audio import load_voice
from ruth_tts_transformer.utils.tokenizer import VoiceBpeTokenizer
from ruth_tts_transformer.utils.wav2vec_alignment import Wav2VecAlignment
from utils import (
MODELS_DIR,
get_config_file,
get_model_path,
load_discrete_vocoder_diffuser,
)
app = FastAPI()
class Item(BaseModel):
text: str
voice: str
seed: int = 3
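# Item mirrors the payload the service expects; the /convert handler below
# currently reads query parameters instead, so this model is unused for now.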
class Gpt:
def __init__(
self,
num_autoregressive_samples: int,
top_p: float,
temperature: float,
length_penalty: int,
repetition_penalty: float,
max_mel_tokens: int,
autoregressive_batch_size: int,
):
self.num_autoregressive_samples = num_autoregressive_samples
self.top_p = top_p
self.temperature = temperature
self.length_penalty = length_penalty
self.repetition_penalty = repetition_penalty
self.max_mel_tokens = max_mel_tokens
self.autoregressive_batch_size = autoregressive_batch_size
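        # Build the autoregressive mel-token transformer on CPU, load its
        # checkpoint, then move it to the GPU.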
self.gpt = (
UnifiedVoice(
max_mel_tokens=604,
max_text_tokens=402,
max_conditioning_inputs=2,
layers=30,
model_dim=1024,
heads=16,
number_text_tokens=255,
start_text_token=255,
checkpointing=False,
train_solo_embeddings=False,
)
.cpu()
.eval()
)
self.gpt.load_state_dict(
torch.load(get_model_path("autoregressive.pth", MODELS_DIR))
)
self.gpt = self.gpt.to("cuda")
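    # Integer division: assumes num_autoregressive_samples is a multiple of
    # autoregressive_batch_size; any remainder is silently dropped.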
def __num_batches(self):
return self.num_autoregressive_samples // self.autoregressive_batch_size
@staticmethod
def deterministic_state(seed=None):
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
return seed
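    # Re-run the GPT forward pass over the winning codes to obtain the latents
    # that condition the diffusion decoder.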
def parse(self, auto_conditioning, text_tokens, best_results, seed, k=1):
self.deterministic_state(seed=seed)
auto_conditioning = copy(auto_conditioning).to("cuda")
text_tokens = copy(text_tokens).to("cuda")
best_results = copy(best_results).to("cuda")
best_latents = self.gpt(
auto_conditioning.repeat(k, 1),
text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
best_results,
torch.tensor(
[best_results.shape[-1] * self.gpt.mel_length_compression],
device=text_tokens.device,
),
return_latent=True,
clip_inputs=False,
)
return best_latents
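    # Autoregressively sample candidate mel-token sequences, one batch at a
    # time, padding each batch to max_mel_tokens with the stop token.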
def parse_inference(
self, auto_conditioning: torch.Tensor, text_tokens: torch.Tensor, seed
) -> Tuple[List[torch.Tensor], int]:
self.deterministic_state(seed=seed)
auto_conditioning = copy(auto_conditioning).to("cuda")
text_tokens = copy(text_tokens).to("cuda")
with torch.no_grad():
samples = []
num_batches = self.__num_batches()
for b in range(num_batches):
codes = self.gpt.inference_speech(
auto_conditioning,
text_tokens,
do_sample=True,
top_p=self.top_p,
temperature=self.temperature,
num_return_sequences=self.autoregressive_batch_size,
length_penalty=self.length_penalty,
repetition_penalty=self.repetition_penalty,
max_generate_length=self.max_mel_tokens,
)
padding_needed = self.max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=self.gpt.stop_mel_token)
samples.append(codes)
return samples, self.gpt.stop_mel_token
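# CLVP scores each GPT candidate against the input text; only the top-K
# candidates survive to the diffusion stage.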
class Clvp:
def __init__(self, K):
self.clvp = (
CLVP(
dim_text=768,
dim_speech=768,
dim_latent=768,
num_text_tokens=256,
text_enc_depth=20,
text_seq_len=350,
text_heads=12,
num_speech_tokens=8192,
speech_enc_depth=20,
speech_heads=12,
speech_seq_len=430,
use_xformers=True,
)
.cpu()
.eval()
)
self.clvp.load_state_dict(torch.load(get_model_path("clvp2.pth", MODELS_DIR)))
self.clvp.to("cuda")
self.K = K
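    # Pads everything from the first stop token onward with token 83
    # (calm/silence) and closes with the 45, 45, 248 tail to mimic a natural
    # ending, mirroring upstream tortoise-tts, from which this helper appears
    # to derive.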
@staticmethod
def fix_gpt_output(codes, stop_token, complain=True):
stop_token_indices = (codes == stop_token).nonzero()
if len(stop_token_indices) == 0:
if complain:
                print(
                    "No stop tokens found in one of the generated voice clips. This typically "
                    "means the spoken audio is too long. In some cases, the output will still "
                    "be good, though. Listen to it and if it is missing words, try breaking up "
                    "your input text."
                )
return codes
else:
codes[stop_token_indices] = 83
stm = stop_token_indices.min().item()
codes[stm:] = 83
if stm - 3 < codes.shape[0]:
codes[-3] = 45
codes[-2] = 45
codes[-1] = 248
return codes
def parse(
self,
text_tokens: torch.Tensor,
samples: List[torch.Tensor],
stop_mel_token: int,
seed: int,
) -> torch.Tensor:
self.deterministic_state(seed=seed)
clip_results = []
text_tokens = copy(text_tokens).to("cuda")
samples = [copy(batch).to("cuda") for batch in samples]
for batch in samples:
for i in range(batch.shape[0]):
batch[i] = self.fix_gpt_output(batch[i], stop_mel_token)
            clvp_scores = self.clvp(
                text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False
            )
            clip_results.append(clvp_scores)
clip_results = torch.cat(clip_results, dim=0)
samples = torch.cat(samples, dim=0)
return samples[torch.topk(clip_results, self.K).indices]
@staticmethod
def deterministic_state(seed=None):
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
return seed
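# The diffusion decoder converts GPT latents into a mel spectrogram; the
# UnivNet vocoder then renders the spectrogram as a waveform.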
class Diffusion:
def __init__(
self,
diffusion_temperature,
diffusion_iterations=30,
cond_free=True,
cond_free_k=2,
):
self.diffusion_temperature = diffusion_temperature
self.diffusion = (
DiffusionTts(
model_channels=1024,
num_layers=10,
in_channels=100,
out_channels=200,
in_latent_channels=1024,
in_tokens=8193,
dropout=0,
use_fp16=False,
num_heads=16,
layer_drop=0,
unconditioned_percentage=0,
)
.cpu()
.eval()
)
self.diffusion.load_state_dict(
torch.load(get_model_path("diffusion_decoder.pth", MODELS_DIR))
)
self.diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=diffusion_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
)
self.vocoder = UnivNetGenerator().cpu()
self.vocoder.load_state_dict(
torch.load(
get_model_path("vocoder.pth", MODELS_DIR),
map_location=torch.device("cpu"),
)["model_g"]
)
self.vocoder.eval(inference=True)
self.diffusion.to("cuda")
self.vocoder.to("cuda")
self.aligner = Wav2VecAlignment()
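        # Normalization bounds of the Tacotron-style log-mel representation,
        # used to map diffusion outputs from [-1, 1] back to real mel values.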
self.TACOTRON_MEL_MAX = 2.3143386840820312
self.TACOTRON_MEL_MIN = -11.512925148010254
def denormalize_tacotron_mel(self, norm_mel):
return ((norm_mel + 1) / 2) * (
self.TACOTRON_MEL_MAX - self.TACOTRON_MEL_MIN
) + self.TACOTRON_MEL_MIN
def potentially_redact(self, clip, text):
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
@staticmethod
def deterministic_state(seed=None):
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
return seed
def do_spectrogram_diffusion(
self,
diffusion_model,
diffuser,
latents,
conditioning_latents,
seed,
temperature=1,
verbose=False,
):
self.deterministic_state(seed=seed)
with torch.no_grad():
output_seq_len = (
latents.shape[1] * 4 * 24000 // 22050
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(
latents, conditioning_latents, output_seq_len, False
)
noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.p_sample_loop(
diffusion_model,
output_shape,
noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=verbose,
)
return self.denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
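    # For each candidate: truncate the latents once a sustained run (more than
    # 8) of calm tokens is seen, diffuse to a mel spectrogram, then vocode.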
def parse(
self, best_results, best_latents, calm_token, diffusion_conditioning, text, seed
):
self.deterministic_state(seed=seed)
best_results = copy(best_results).to("cuda")
best_latents = copy(best_latents).to("cuda")
diffusion_conditioning = copy(diffusion_conditioning).to("cuda")
wav_candidates = []
for b in range(best_results.shape[0]):
codes = best_results[b].unsqueeze(0)
latents = best_latents[b].unsqueeze(0)
ctokens = 0
for k in range(codes.shape[-1]):
if codes[0, k] == calm_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
latents = latents[:, :k]
break
mel = self.do_spectrogram_diffusion(
self.diffusion,
self.diffuser,
latents,
diffusion_conditioning,
seed,
temperature=self.diffusion_temperature,
verbose=False,
)
wav = self.vocoder.inference(mel)
wav_candidates.append(wav)
        # TODO: check whether wav candidates should instead be returned as
        # numpy arrays, i.e. self.potentially_redact(...).cpu().detach().numpy().
wav_candidates = [
self.potentially_redact(wav_candidate, text)
for wav_candidate in wav_candidates
]
return wav_candidates
@serve.deployment(
    name="orchestrator",
    num_replicas=4,
    ray_actor_options={"num_cpus": 8, "num_gpus": 0.5},
)
class Orchestrator:
def __init__(self, config: Dict[Text, Any]):
self.calm_token = 83
self.tokenizer = VoiceBpeTokenizer()
_, conditioning_latent_1 = load_voice("gabby_reading", map_location="cpu")
_, conditioning_latent_2 = load_voice("gabby_conversation", map_location="cpu")
        self.auto_conditioning1, self.diffusion_conditioning1 = conditioning_latent_1
        self.auto_conditioning2, self.diffusion_conditioning2 = conditioning_latent_2
self.auto_conditioning = None
self.diffusion_conditioning = None
self.gpt = Gpt(
config[GPT][NUM_AUTOREGRESSIVE_SAMPLES],
config[GPT][TOP_P],
config[GPT][TEMPERATURE],
config[GPT][LENGTH_PENALTY],
config[GPT][REPETITION_PENALTY],
config[GPT][MAX_MEL_TOKENS],
config[GPT][AUTO_REGRESSIVE_BATCH_SIZE],
)
        self.clvp = Clvp(config[CLVP_const]["k"])
self.diffusion = Diffusion(config[DIFFUSION][DIFFUSION_TEMPERATURE])
print("orchestrator setup completed")
@staticmethod
def __check_for_long_sentence(text_tokens):
assert (
text_tokens.shape[-1] < 400
), "Too much text provided. Break the text up into separate segments and re-try inference."
# TODO: split the text into several pieces and do the generation and combine them last
@staticmethod
def deterministic_state(seed=None):
seed = int(time()) if seed is None else seed
torch.manual_seed(seed)
random.seed(seed)
return seed
def preprocess_text(self, text: Text):
torch_tensor = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)
return torch_tensor
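    # Flatten and concatenate the per-sentence waveforms, then save a single
    # 24 kHz wav named by the SHA-1 of the current timestamp.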
def parse(self, res):
print("parsing")
file_name = hashlib.sha1(str(datetime.now()).encode("UTF-8"))
res = [torch.Tensor(copy(split)).squeeze(0).cpu() for split in res]
res = [torch.flatten(split) for split in res]
merged_audio_tensor = torch.cat(res).reshape(1, -1)
torchaudio.save(f"./{file_name.hexdigest()}.wav", merged_audio_tensor, 24000)
return file_name.hexdigest()
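    # Full pipeline: GPT samples candidates -> CLVP picks the best -> GPT
    # extracts latents for it -> diffusion + vocoder render the waveform.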
def generate(self, text, voice, seed):
        if voice == "gabby_reading":
            self.auto_conditioning = self.auto_conditioning1
            self.diffusion_conditioning = self.diffusion_conditioning1
        elif voice == "gabby_conversation":
            self.auto_conditioning = self.auto_conditioning2
            self.diffusion_conditioning = self.diffusion_conditioning2
        else:
            raise ValueError(f"Unknown voice: {voice!r}")
self.deterministic_state(seed=seed)
        # Tokenize the incoming text.
        text_tokens = self.preprocess_text(text)
self.__check_for_long_sentence(text_tokens)
samples, stop_mel_token = self.gpt.parse_inference(
self.auto_conditioning, text_tokens, seed
)
best_sample = self.clvp.parse(text_tokens, samples, stop_mel_token, seed)
best_latent = self.gpt.parse(
self.auto_conditioning, text_tokens, best_sample, seed
)
wav_candidates = self.diffusion.parse(
best_sample,
best_latent,
self.calm_token,
self.diffusion_conditioning,
text,
seed,
)
        if len(wav_candidates) > 1:
            return [candidate.cpu() for candidate in wav_candidates]
        return wav_candidates[0].cpu()
# Earlier approach: deploy once at app startup instead of inside the handler.
# @app.on_event("startup")
# def startup_event():
#     ray.init(address="auto")
#     serve.start()
#     config = get_config_file(Path("config-model.yaml"))
#     Orchestrator.deploy(config)
#     orchestrator = serve.get_deployment("orchestrator")
#     orchestrator = orchestrator.get_handle()
#     app.deploy = orchestrator
# @app.on_event("shutdown")
# def shutdown_event():
#     ray.shutdown()
@app.post("/convert")
def model1_deployment(voice="gabby_reading", text="hello how are you!", seed=3):
serve.start(detached=True)
config = get_config_file(Path("config-model.yaml"))
    Orchestrator.deploy(config)
orchestrator = serve.get_deployment("orchestrator")
orchestrator = orchestrator.get_handle()
app.deploy = orchestrator
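    # The Serve handle is stashed on the FastAPI app object; note that
    # serve.start and deploy run on every request, which redeploys the
    # orchestrator each time.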
    if voice in ("gabby_reading", "gabby_conversation"):
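        # Fan out one Serve call per sentence; note that split(". ") drops the
        # trailing period from every sentence except the last.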
sentences = text.split(". ")
if len(sentences) > 1:
values = ray.get(
[
app.deploy.generate.remote(
text=sentence, voice=voice, seed=seed
)
for sentence in sentences
]
)
else:
values = [
ray.get(app.deploy.generate.remote(text, voice, seed))
]
file_name = ray.get(app.deploy.parse.remote(values))
return FileResponse(f"./{file_name}.wav")
else:
return f"{voice} not available!"
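# Example request once the app is running (e.g. `uvicorn serve:app --port 8000`;
# host and port here are illustrative):
#   curl -X POST "http://localhost:8000/convert?voice=gabby_reading&text=hello+there&seed=3" \
#        --output out.wav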
if __name__ == "__main__":
    # Smoke test via the Serve path; the class can no longer be instantiated
    # directly once it is wrapped by @serve.deployment.
    model1_deployment(
        voice="gabby_reading", text="hello how are you doing from prakash!", seed=3
    )