|
import hashlib |
|
import random |
|
import ray |
|
import torch |
|
import torch.nn.functional as F |
|
import torchaudio |
|
from copy import copy |
|
from datetime import datetime |
|
from fastapi import FastAPI |
|
from fastapi.responses import FileResponse |
|
from pathlib import Path |
|
from pydantic import BaseModel |
|
from ray import serve |
|
from time import time |
|
from typing import Any, Dict, List, Text, Tuple |
|
|
|
from constants import ( |
|
AUTO_REGRESSIVE_BATCH_SIZE, |
|
DIFFUSION, |
|
DIFFUSION_TEMPERATURE, |
|
GPT, |
|
LENGTH_PENALTY, |
|
MAX_MEL_TOKENS, |
|
NUM_AUTOREGRESSIVE_SAMPLES, |
|
REPETITION_PENALTY, |
|
TEMPERATURE, |
|
TOP_P, |
|
CLVP_const, |
|
) |
|
from ruth_tts_transformer.models.autoregressive import UnifiedVoice |
|
from ruth_tts_transformer.models.clvp import CLVP |
|
from ruth_tts_transformer.models.diffusion_decoder import DiffusionTts |
|
from ruth_tts_transformer.models.vocoder import UnivNetGenerator |
|
from ruth_tts_transformer.utils.audio import load_voice |
|
from ruth_tts_transformer.utils.tokenizer import VoiceBpeTokenizer |
|
from ruth_tts_transformer.utils.wav2vec_alignment import Wav2VecAlignment |
|
from utils import ( |
|
MODELS_DIR, |
|
get_config_file, |
|
get_model_path, |
|
load_discrete_vocoder_diffuser, |
|
) |
|
|
|
# Module-level FastAPI application. ``model1_deployment`` also stores the Ray
# Serve handle on this object (as ``app.deploy``).
app = FastAPI()
|
|
|
|
|
class Item(BaseModel):
    """Pydantic model describing one synthesis request.

    NOTE(review): nothing in the visible source references this model; its
    fields mirror the parameters of ``model1_deployment``, so it is presumably
    the request body of the HTTP endpoint -- confirm against the route
    definition.
    """

    # Text to be synthesised.
    text: str
    # Name of the voice to synthesise with.
    voice: str
    # Random seed forwarded to the pipeline; defaults to 3.
    seed: int = 3
|
|
|
|
|
class Gpt:
    """Autoregressive stage of the pipeline, wrapping a ``UnifiedVoice`` model.

    The model is built on the CPU, its weights are read from
    ``autoregressive.pth`` and it is then moved to CUDA.  Two entry points are
    offered: ``parse_inference`` samples candidate code sequences and
    ``parse`` re-runs the model on chosen sequences to obtain their latents.
    """

    def __init__(
        self,
        num_autoregressive_samples: int,
        top_p: float,
        temperature: float,
        length_penalty: int,
        repetition_penalty: float,
        max_mel_tokens: int,
        autoregressive_batch_size: int,
    ):
        """Store the sampling settings and load the model onto the GPU.

        Args:
            num_autoregressive_samples: Total number of candidates to sample.
            top_p: Forwarded to ``inference_speech`` as ``top_p``.
            temperature: Forwarded to ``inference_speech`` as ``temperature``.
            length_penalty: Forwarded to ``inference_speech``.
            repetition_penalty: Forwarded to ``inference_speech``.
            max_mel_tokens: Generation length limit; sampled batches are also
                padded to this width.
            autoregressive_batch_size: Number of sequences sampled per call
                to ``inference_speech``.
        """
        self.num_autoregressive_samples = num_autoregressive_samples
        self.top_p = top_p
        self.temperature = temperature
        self.length_penalty = length_penalty
        self.repetition_penalty = repetition_penalty
        self.max_mel_tokens = max_mel_tokens
        self.autoregressive_batch_size = autoregressive_batch_size

        model = (
            UnifiedVoice(
                max_mel_tokens=604,
                max_text_tokens=402,
                max_conditioning_inputs=2,
                layers=30,
                model_dim=1024,
                heads=16,
                number_text_tokens=255,
                start_text_token=255,
                checkpointing=False,
                train_solo_embeddings=False,
            )
            .cpu()
            .eval()
        )
        # The model still lives on the CPU here, so deserialise the checkpoint
        # there as well instead of onto whatever device it was saved from.
        model.load_state_dict(
            torch.load(
                get_model_path("autoregressive.pth", MODELS_DIR),
                map_location=torch.device("cpu"),
            )
        )
        self.gpt = model.to("cuda")

    def __num_batches(self):
        """Return how many full batches cover the requested sample count.

        Floor division is used, so a remainder smaller than one batch is
        dropped (and the result is 0 when the batch size exceeds the count).
        """
        return self.num_autoregressive_samples // self.autoregressive_batch_size

    @staticmethod
    def deterministic_state(seed=None):
        """Seed ``torch`` and ``random``; return the seed that was used.

        When ``seed`` is ``None`` the current Unix time (whole seconds) is
        used instead.
        """
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    def parse(self, auto_conditioning, text_tokens, best_results, seed, k=1):
        """Run the model on ``best_results`` and return its latents.

        Args:
            auto_conditioning: Conditioning tensor; repeated ``k`` times along
                the first dimension.
            text_tokens: Text token tensor; repeated ``k`` times along the
                first dimension.
            best_results: Code sequences to compute latents for.
            seed: Seed applied before the forward pass.
            k: Repeat count for the conditioning and text inputs.

        Returns:
            Whatever ``self.gpt(..., return_latent=True)`` returns.
        """
        self.deterministic_state(seed=seed)
        auto_conditioning = copy(auto_conditioning).to("cuda")
        text_tokens = copy(text_tokens).to("cuda")
        best_results = copy(best_results).to("cuda")
        device = text_tokens.device

        text_lengths = torch.tensor([text_tokens.shape[-1]], device=device)
        # NOTE(review): presumably the length of the audio the codes stand
        # for -- confirm the meaning of ``mel_length_compression``.
        wav_lengths = torch.tensor(
            [best_results.shape[-1] * self.gpt.mel_length_compression],
            device=device,
        )

        # Pure inference: skip autograd bookkeeping, as parse_inference does.
        with torch.no_grad():
            best_latents = self.gpt(
                auto_conditioning.repeat(k, 1),
                text_tokens.repeat(k, 1),
                text_lengths,
                best_results,
                wav_lengths,
                return_latent=True,
                clip_inputs=False,
            )

        return best_latents

    def parse_inference(
        self, auto_conditioning: torch.Tensor, text_tokens: torch.Tensor, seed
    ) -> Tuple[List[torch.Tensor], int]:
        """Sample candidate code sequences for ``text_tokens``.

        Args:
            auto_conditioning: Conditioning tensor passed to
                ``inference_speech``.
            text_tokens: Text token tensor passed to ``inference_speech``.
            seed: Seed applied before sampling starts.

        Returns:
            A tuple ``(samples, stop_mel_token)``: one tensor per batch, each
            padded along dimension 1 up to ``max_mel_tokens`` with the
            model's stop token, plus that stop token.
        """
        self.deterministic_state(seed=seed)
        auto_conditioning = copy(auto_conditioning).to("cuda")
        text_tokens = copy(text_tokens).to("cuda")
        stop_mel_token = self.gpt.stop_mel_token

        samples = []
        with torch.no_grad():
            for _ in range(self.__num_batches()):
                codes = self.gpt.inference_speech(
                    auto_conditioning,
                    text_tokens,
                    do_sample=True,
                    top_p=self.top_p,
                    temperature=self.temperature,
                    num_return_sequences=self.autoregressive_batch_size,
                    length_penalty=self.length_penalty,
                    repetition_penalty=self.repetition_penalty,
                    max_generate_length=self.max_mel_tokens,
                )
                # Give every batch the same width so the batches can be
                # concatenated by the next stage.
                padding_needed = self.max_mel_tokens - codes.shape[1]
                codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                samples.append(codes)

        return samples, stop_mel_token
|
|
|
|
|
class clvp:
    """Re-ranking stage: scores sampled code sequences with a ``CLVP`` model.

    The model is built on the CPU, loaded from ``clvp2.pth`` and moved to
    CUDA.  ``parse`` keeps the ``K`` highest-scoring candidates.
    """

    def __init__(self, K):
        """Load the CLVP model and remember how many candidates to keep.

        Args:
            K: Number of top-scoring samples returned by ``parse``.
        """
        self.clvp = (
            CLVP(
                dim_text=768,
                dim_speech=768,
                dim_latent=768,
                num_text_tokens=256,
                text_enc_depth=20,
                text_seq_len=350,
                text_heads=12,
                num_speech_tokens=8192,
                speech_enc_depth=20,
                speech_heads=12,
                speech_seq_len=430,
                use_xformers=True,
            )
            .cpu()
            .eval()
        )
        # The model is still on the CPU at this point, so load the checkpoint
        # there too rather than onto the device it was saved from.
        self.clvp.load_state_dict(
            torch.load(
                get_model_path("clvp2.pth", MODELS_DIR),
                map_location=torch.device("cpu"),
            )
        )
        self.clvp.to("cuda")
        self.K = K

    @staticmethod
    def fix_gpt_output(codes, stop_token, complain=True):
        """Overwrite the stop-token tail of ``codes`` in place.

        Everything from the first ``stop_token`` onwards is set to 83 and the
        last three entries are then set to 45, 45 and 248.  When no stop
        token is present the input is returned untouched (after an optional
        printed warning).

        NOTE(review): 83 equals ``Orchestractor.calm_token``; the meaning of
        45 and 248 is not visible in this file -- confirm before changing.

        Args:
            codes: Code tensor, modified in place.
            stop_token: Token value that marks the end of generation.
            complain: Print a warning when no stop token is found.

        Returns:
            The same ``codes`` tensor.
        """
        stop_token_indices = (codes == stop_token).nonzero()
        if len(stop_token_indices) == 0:
            if complain:
                print(
                    "No stop tokens found in one of the generated voice clips. This typically means the spoken audio "
                    "is "
                    "too long. In some cases, the output will still be good, though. Listen to it and if it is "
                    "missing words, "
                    "try breaking up your input text."
                )
            return codes

        codes[stop_token_indices] = 83
        stm = stop_token_indices.min().item()
        codes[stm:] = 83
        # stm is a valid index, so this condition always holds; it is kept
        # verbatim to avoid changing established behaviour.
        if stm - 3 < codes.shape[0]:
            codes[-3] = 45
            codes[-2] = 45
            codes[-1] = 248

        return codes

    def parse(
        self,
        text_tokens: torch.Tensor,
        samples: List[torch.Tensor],
        stop_mel_token: int,
        seed: int,
    ) -> torch.Tensor:
        """Score every sample against the text and return the best ``K``.

        Args:
            text_tokens: Text token tensor; repeated once per row of a batch.
            samples: Batches of sampled code sequences.
            stop_mel_token: Stop token used to clean each sequence.
            seed: Seed applied before scoring.

        Returns:
            The rows of the concatenated samples selected by
            ``torch.topk`` over the concatenated scores.
        """
        self.deterministic_state(seed=seed)
        text_tokens = copy(text_tokens).to("cuda")
        samples = [copy(batch).to("cuda") for batch in samples]

        clip_results = []
        # Scoring only: no gradients are needed.
        with torch.no_grad():
            for batch in samples:
                for i in range(batch.shape[0]):
                    batch[i] = self.fix_gpt_output(batch[i], stop_mel_token)

                scores = self.clvp(
                    text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False
                )
                clip_results.append(scores)

        clip_results = torch.cat(clip_results, dim=0)
        samples = torch.cat(samples, dim=0)

        return samples[torch.topk(clip_results, self.K).indices]

    @staticmethod
    def deterministic_state(seed=None):
        """Seed ``torch`` and ``random``; return the seed that was used.

        When ``seed`` is ``None`` the current Unix time (whole seconds) is
        used instead.
        """
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed
|
|
|
|
|
class Diffusion:
    """Final stage: turns code latents into waveforms.

    Holds a ``DiffusionTts`` model, the diffuser returned by
    ``load_discrete_vocoder_diffuser``, a ``UnivNetGenerator`` vocoder (both
    models on CUDA) and a ``Wav2VecAlignment`` aligner used for redaction.
    """

    # Value range used by ``denormalize_tacotron_mel`` to map [-1, 1] back to
    # mel values.
    TACOTRON_MEL_MAX = 2.3143386840820312
    TACOTRON_MEL_MIN = -11.512925148010254

    def __init__(
        self,
        diffusion_temperature,
        diffusion_iterations=30,
        cond_free=True,
        cond_free_k=2,
    ):
        """Load the diffusion model, diffuser, vocoder and aligner.

        Args:
            diffusion_temperature: Scale applied to the initial noise.
            diffusion_iterations: Passed to the diffuser factory as
                ``desired_diffusion_steps``.
            cond_free: Passed through to the diffuser factory.
            cond_free_k: Passed through to the diffuser factory.
        """
        self.diffusion_temperature = diffusion_temperature
        cpu = torch.device("cpu")

        self.diffusion = (
            DiffusionTts(
                model_channels=1024,
                num_layers=10,
                in_channels=100,
                out_channels=200,
                in_latent_channels=1024,
                in_tokens=8193,
                dropout=0,
                use_fp16=False,
                num_heads=16,
                layer_drop=0,
                unconditioned_percentage=0,
            )
            .cpu()
            .eval()
        )
        # Load onto the CPU, the same way the vocoder checkpoint below is.
        self.diffusion.load_state_dict(
            torch.load(
                get_model_path("diffusion_decoder.pth", MODELS_DIR),
                map_location=cpu,
            )
        )
        self.diffuser = load_discrete_vocoder_diffuser(
            desired_diffusion_steps=diffusion_iterations,
            cond_free=cond_free,
            cond_free_k=cond_free_k,
        )

        self.vocoder = UnivNetGenerator().cpu()
        self.vocoder.load_state_dict(
            torch.load(
                get_model_path("vocoder.pth", MODELS_DIR),
                map_location=cpu,
            )["model_g"]
        )
        self.vocoder.eval(inference=True)
        self.diffusion.to("cuda")
        self.vocoder.to("cuda")
        self.aligner = Wav2VecAlignment()

    def denormalize_tacotron_mel(self, norm_mel):
        """Map ``norm_mel`` linearly from [-1, 1] to [MEL_MIN, MEL_MAX]."""
        mel_range = self.TACOTRON_MEL_MAX - self.TACOTRON_MEL_MIN
        return ((norm_mel + 1) / 2) * mel_range + self.TACOTRON_MEL_MIN

    def potentially_redact(self, clip, text):
        """Pass ``clip`` through the aligner's ``redact`` for ``text``.

        Dimension 1 is squeezed before the call and restored afterwards.
        """
        return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)

    @staticmethod
    def deterministic_state(seed=None):
        """Seed ``torch`` and ``random``; return the seed that was used.

        When ``seed`` is ``None`` the current Unix time (whole seconds) is
        used instead.
        """
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    def do_spectrogram_diffusion(
        self,
        diffusion_model,
        diffuser,
        latents,
        conditioning_latents,
        seed,
        temperature=1,
        verbose=False,
    ):
        """Sample a mel spectrogram for ``latents`` with the diffuser.

        Args:
            diffusion_model: Model handed to ``diffuser.p_sample_loop``.
            diffuser: Object providing ``p_sample_loop``.
            latents: Latents to decode; dimension 1 sets the output length.
            conditioning_latents: Conditioning passed to
                ``timestep_independent``.
            seed: Seed applied before the noise is drawn.
            temperature: Multiplier on the initial Gaussian noise.
            verbose: Forwarded as ``progress`` to the sampling loop.

        Returns:
            The denormalised mel, cropped to the computed output length.
        """
        self.deterministic_state(seed=seed)
        with torch.no_grad():
            # NOTE(review): presumably converts a 22.05 kHz code rate to a
            # 24 kHz spectrogram rate -- confirm.
            output_seq_len = latents.shape[1] * 4 * 24000 // 22050
            output_shape = (latents.shape[0], 100, output_seq_len)
            precomputed_embeddings = diffusion_model.timestep_independent(
                latents, conditioning_latents, output_seq_len, False
            )

            noise = torch.randn(output_shape, device=latents.device) * temperature
            mel = diffuser.p_sample_loop(
                diffusion_model,
                output_shape,
                noise=noise,
                model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
                progress=verbose,
            )
            return self.denormalize_tacotron_mel(mel)[:, :, :output_seq_len]

    def parse(
        self, best_results, best_latents, calm_token, diffusion_conditioning, text, seed
    ):
        """Decode every selected candidate into a waveform.

        Args:
            best_results: Selected code sequences, one row per candidate.
            best_latents: Latents matching ``best_results`` row by row.
            calm_token: Token whose long runs mark where to cut the latents.
            diffusion_conditioning: Conditioning for the diffusion model.
            text: Text handed to the aligner for redaction.
            seed: Seed applied here and before each diffusion run.

        Returns:
            A list with one redacted waveform per candidate.
        """
        self.deterministic_state(seed=seed)
        best_results = copy(best_results).to("cuda")
        best_latents = copy(best_latents).to("cuda")
        diffusion_conditioning = copy(diffusion_conditioning).to("cuda")

        wav_candidates = []
        for b in range(best_results.shape[0]):
            codes = best_results[b].unsqueeze(0)
            latents = best_latents[b].unsqueeze(0)

            # Cut the latents where a run of more than 8 consecutive calm
            # tokens is reached.  The row is converted to Python ints once
            # instead of comparing one tensor element per iteration.
            calm_run = 0
            for position, token in enumerate(codes[0].tolist()):
                calm_run = calm_run + 1 if token == calm_token else 0
                if calm_run > 8:
                    latents = latents[:, :position]
                    break

            mel = self.do_spectrogram_diffusion(
                self.diffusion,
                self.diffuser,
                latents,
                diffusion_conditioning,
                seed,
                temperature=self.diffusion_temperature,
                verbose=False,
            )
            wav_candidates.append(self.vocoder.inference(mel))

        return [
            self.potentially_redact(wav_candidate, text)
            for wav_candidate in wav_candidates
        ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Orchestractor:
    """Runs the whole text-to-speech pipeline for two preloaded voices.

    Chains tokenisation, autoregressive sampling (``Gpt``), re-ranking
    (``clvp``) and waveform decoding (``Diffusion``), and can write the
    resulting audio to a WAV file.
    """

    def __init__(self, config: Dict[Text, Any]):
        """Load the voice conditioning and build the three pipeline stages.

        Args:
            config: Nested mapping read with the ``GPT``, ``CLVP_const`` and
                ``DIFFUSION`` keys imported from ``constants``.
        """
        self.calm_token = 83
        self.tokenizer = VoiceBpeTokenizer()
        _, conditioning_latent_1 = load_voice("gabby_reading", map_location="cpu")
        _, conditioning_latent_2 = load_voice("gabby_conversation", map_location="cpu")

        # Stored as tuples so the attributes stay usable after unpacking; a
        # generator would be exhausted by the unpacking itself.
        self.conditioning_latents1 = tuple(conditioning_latent_1)
        self.conditioning_latents2 = tuple(conditioning_latent_2)
        (
            self.auto_conditioning1,
            self.diffusion_conditioning1,
        ) = self.conditioning_latents1
        (
            self.auto_conditioning2,
            self.diffusion_conditioning2,
        ) = self.conditioning_latents2

        # Conditioning of the most recent request; set by ``generate``.
        self.auto_conditioning = None
        self.diffusion_conditioning = None

        gpt_config = config[GPT]
        self.gpt = Gpt(
            gpt_config[NUM_AUTOREGRESSIVE_SAMPLES],
            gpt_config[TOP_P],
            gpt_config[TEMPERATURE],
            gpt_config[LENGTH_PENALTY],
            gpt_config[REPETITION_PENALTY],
            gpt_config[MAX_MEL_TOKENS],
            gpt_config[AUTO_REGRESSIVE_BATCH_SIZE],
        )
        self.clvp = clvp(config[CLVP_const]["k"])
        self.diffusion = Diffusion(config[DIFFUSION][DIFFUSION_TEMPERATURE])
        print("orchestrator setup completed")

    @staticmethod
    def __check_for_long_sentence(text_tokens):
        """Reject inputs whose last dimension holds 400 or more tokens.

        Raises:
            AssertionError: When the text is too long.  Raised explicitly so
                the check still runs under ``python -O``, which strips
                ``assert`` statements.
        """
        if not text_tokens.shape[-1] < 400:
            raise AssertionError(
                "Too much text provided. Break the text up into separate segments and re-try inference."
            )

    @staticmethod
    def deterministic_state(seed=None):
        """Seed ``torch`` and ``random``; return the seed that was used.

        When ``seed`` is ``None`` the current Unix time (whole seconds) is
        used instead.
        """
        seed = int(time()) if seed is None else seed
        torch.manual_seed(seed)
        random.seed(seed)
        return seed

    def preprocess_text(self, text: Text):
        """Tokenise ``text`` into an ``IntTensor`` with a leading dimension of 1."""
        return torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0)

    def parse(self, res):
        """Concatenate audio pieces and save them as one 24 kHz WAV file.

        The file is written to the current directory and named after the
        SHA-1 hex digest of the current timestamp.

        Args:
            res: Iterable of audio tensors, one per text segment.  Each item
                must be a single tensor; the list that ``generate`` returns
                when several candidates are kept is not handled here.

        Returns:
            The hex digest used as the file name (without ``.wav``).
        """
        print("parsing")
        file_name = hashlib.sha1(str(datetime.now()).encode("UTF-8"))
        res = [torch.Tensor(copy(split)).squeeze(0).cpu() for split in res]
        res = [torch.flatten(split) for split in res]
        merged_audio_tensor = torch.cat(res).reshape(1, -1)
        torchaudio.save(f"./{file_name.hexdigest()}.wav", merged_audio_tensor, 24000)

        return file_name.hexdigest()

    def generate(self, text, voice, seed):
        """Synthesise ``text`` with the requested voice.

        Args:
            text: Text to synthesise.
            voice: ``"gabby_reading"`` or ``"gabby_conversation"``
                (``"gabby_convo"`` is accepted as an alias of the latter).
            seed: Seed forwarded to every pipeline stage.

        Returns:
            A CPU tensor when one candidate is produced, otherwise a list of
            CPU tensors.

        Raises:
            ValueError: If ``voice`` is not one of the known voices.
            AssertionError: If the tokenised text is too long.
        """
        if voice == "gabby_reading":
            self.auto_conditioning = self.auto_conditioning1
            self.diffusion_conditioning = self.diffusion_conditioning1
        elif voice in ("gabby_conversation", "gabby_convo"):
            self.auto_conditioning = self.auto_conditioning2
            self.diffusion_conditioning = self.diffusion_conditioning2
        else:
            # Falling through would silently reuse the conditioning of the
            # previous request (or None on the first one).
            raise ValueError(f"{voice} not available!")

        self.deterministic_state(seed=seed)
        text_tokens = self.preprocess_text(text)
        self.__check_for_long_sentence(text_tokens)

        samples, stop_mel_token = self.gpt.parse_inference(
            self.auto_conditioning, text_tokens, seed
        )
        best_sample = self.clvp.parse(text_tokens, samples, stop_mel_token, seed)
        # Repeat the conditioning and text once per kept sample so the batch
        # sizes agree when more than one sample is kept.
        best_latent = self.gpt.parse(
            self.auto_conditioning,
            text_tokens,
            best_sample,
            seed,
            k=best_sample.shape[0],
        )
        wav_candidates = self.diffusion.parse(
            best_sample,
            best_latent,
            self.calm_token,
            self.diffusion_conditioning,
            text,
            seed,
        )

        # A list has no ``.cpu()``; move each candidate individually.
        if len(wav_candidates) > 1:
            return [wav.cpu() for wav in wav_candidates]
        return wav_candidates[0].cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model1_deployment(voice="gabby_reading", text="hello how are you!", seed=3):
    """Deploy the orchestrator on Ray Serve and synthesise ``text``.

    The text is split on ``". "``; when that yields several non-empty
    sentences they are submitted to the deployment together and the
    resulting audio is merged into one WAV file.

    NOTE(review): ``Orchestractor.deploy`` and the deployment name
    ``"orchestrator"`` require ``Orchestractor`` to be registered as a Ray
    Serve deployment; that registration is not visible in this file.
    NOTE(review): the deployment is (re)created on every call -- confirm
    this is intended.

    Args:
        voice: ``"gabby_reading"``, ``"gabby_conversation"`` or the alias
            ``"gabby_convo"``.
        text: Text to synthesise.
        seed: Seed forwarded to the pipeline.

    Returns:
        A ``FileResponse`` for the generated WAV file, or the string
        ``"<voice> not available!"`` for an unknown voice.
    """
    serve.start(detached=True)
    config = get_config_file(Path("config-model.yaml"))
    Orchestractor.deploy(config)
    app.deploy = serve.get_deployment("orchestrator").get_handle()

    # The orchestrator knows the second voice as "gabby_conversation"; map
    # the short alias onto it so both spellings work.
    known_voices = {
        "gabby_reading": "gabby_reading",
        "gabby_conversation": "gabby_conversation",
        "gabby_convo": "gabby_conversation",
    }
    if voice not in known_voices:
        return f"{voice} not available!"
    resolved_voice = known_voices[voice]

    # Drop empty fragments (for example after a trailing ". ") so that no
    # empty string is sent for synthesis.
    sentences = [part for part in text.split(". ") if part.strip()]
    if len(sentences) > 1:
        values = ray.get(
            [
                app.deploy.generate.remote(
                    text=sentence, voice=resolved_voice, seed=seed
                )
                for sentence in sentences
            ]
        )
    else:
        single_text = sentences[0] if sentences else text
        values = [
            ray.get(app.deploy.generate.remote(single_text, resolved_voice, seed))
        ]

    file_name = ray.get(app.deploy.parse.remote(values))
    return FileResponse(f"./{file_name}.wav")
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke run: build the pipeline from the YAML config and synthesise
    # one sentence.  The returned audio is discarded.
    config = get_config_file(Path("config-model.yaml"))
    orches = Orchestractor(config)
    orches.generate(
        text="hello how are you doing from prakash!",
        voice="gabby_reading",
        seed=3,
    )
|
|