"""Custom EndpointHandler serving a Coqui XTTS model (voice-cloning TTS)."""

import base64
import io
import os
import tempfile
import time

import requests
import torch
import torchaudio

# The Coqui license must be accepted before any TTS module is imported.
os.environ["COQUI_TOS_AGREED"] = "1"

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts


def convert_audio_urls_to_paths(audio_urls):
    """Download each reference-audio URL to a local temp file.

    Returns the list of local paths plus the temp file objects so the
    caller can delete them once inference is done.
    """
    temp_files = []
    audio_paths = []

    for url in audio_urls:
        filename = url.split("/")[-1]
        file_destination_path, file_object = download_tempfile(
            file_url=url, filename=filename
        )
        if file_destination_path is None:
            continue  # skip URLs that failed to download
        temp_files.append(file_object)
        audio_paths.append(file_destination_path)

    return audio_paths, temp_files


def download_tempfile(file_url, filename):
    """Fetch a remote file into a NamedTemporaryFile.

    Returns (path, file object) on success, or (None, None) on failure.
    """
    try:
        response = requests.get(file_url, timeout=30)
        response.raise_for_status()
        filetype = filename.split(".")[-1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{filetype}")
        temp_file.write(response.content)
        temp_file.close()  # flush to disk so the model can read it by path
        return temp_file.name, temp_file
    except Exception as e:
        print(f"Error downloading file: {e}")
        return None, None


class EndpointHandler:
    def __init__(self, path=""):
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load the XTTS checkpoint shipped with the endpoint repository.
        config = XttsConfig()
        config.load_json("/repository/model/config.json")
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path="/repository/model/model.pth",
            vocab_path="/repository/model/vocab.json",
            speaker_file_path="/repository/model/speakers_xtts.pth",
            eval=True,
            use_deepspeed=device == "cuda",  # DeepSpeed inference only on GPU
        )
        model.to(device)

        self.model = model

    def __call__(self, model_input):
        # Download the reference clips used for voice cloning.
        audio_paths, temp_files = convert_audio_urls_to_paths(model_input["audio_urls"])

        # Compute the speaker conditioning from the reference audio.
        (
            gpt_cond_latent,
            speaker_embedding,
        ) = self.model.get_conditioning_latents(
            audio_path=audio_paths,
            gpt_cond_len=model_input["gpt_cond_len"],
            gpt_cond_chunk_len=model_input["gpt_cond_chunk_len"],
            max_ref_length=model_input["max_ref_length"],
        )

        print("Generating audio")

        t0 = time.time()
        out = self.model.inference(
            text=model_input["text"],
            speaker_embedding=speaker_embedding,
            gpt_cond_latent=gpt_cond_latent,
            temperature=model_input["temperature"],
            repetition_penalty=model_input["repetition_penalty"],
            language=model_input["language"],
            enable_text_splitting=True,
        )

        # Serialize the 24 kHz waveform to WAV and base64-encode it.
        audio_file = io.BytesIO()
        torchaudio.save(
            audio_file, torch.tensor(out["wav"]).unsqueeze(0), 24000, format="wav"
        )
        inference_time = time.time() - t0
        print(f"I: Time to generate audio: {inference_time:.2f} seconds")
        audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")

        # Clean up the downloaded reference clips; the temp files were
        # created with delete=False, so unlink them by path.
        try:
            for temp_file in temp_files:
                os.remove(temp_file.name)
        except Exception as e:
            print(f"Error removing temp files: {e}")

        return {"data": audio_str, "format": "wav"}