narrator / handler.py
sim04ful
added import os
cfe276e
raw
history blame
No virus
1.99 kB
import os
os.environ["COQUI_TOS_AGREED"] = "1"
import torch
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import time
import torchaudio
import io
import base64
class EndpointHandler:
    """Hugging Face Inference Endpoints handler wrapping a Coqui XTTS model.

    Loads the XTTS checkpoint once at construction time, then synthesizes
    speech for each request, returning the audio as a base64-encoded WAV.
    """

    def __init__(self, path=""):
        """Load the XTTS checkpoint from the baked-in /repository/model path.

        Args:
            path: Unused; kept for Inference Endpoints handler compatibility.
        """
        # Prefer GPU when available; DeepSpeed-accelerated inference is
        # only enabled on CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        config = XttsConfig()
        config.load_json("/repository/model/config.json")
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path="/repository/model/model.pth",
            vocab_path="/repository/model/vocab.json",
            speaker_file_path="/repository/model/speakers_xtts.pth",
            eval=True,
            use_deepspeed=device == "cuda",
        )
        model.to(device)
        self.model = model

    def __call__(self, model_input):
        """Synthesize speech for ``model_input["text"]``.

        Args:
            model_input: Mapping with a ``"text"`` key holding the text to speak.

        Returns:
            dict with ``"data"`` (base64-encoded WAV bytes, UTF-8 string) and
            ``"format"`` (always ``"wav"``).
        """
        # Condition the model on the fixed reference speaker clip shipped
        # with the repository. NOTE(review): latents are recomputed on every
        # request — they could be cached in __init__ since the reference
        # audio never changes.
        (
            gpt_cond_latent,
            speaker_embedding,
        ) = self.model.get_conditioning_latents(
            audio_path="/repository/attenborough.mp3",
            gpt_cond_len=30,
            gpt_cond_chunk_len=4,
            max_ref_length=60,
        )
        print("Generating audio")
        t0 = time.time()
        out = self.model.inference(
            text=model_input["text"],
            speaker_embedding=speaker_embedding,
            gpt_cond_latent=gpt_cond_latent,
            temperature=0.75,
            repetition_penalty=2.5,
            language="en",
            enable_text_splitting=True,
        )
        # BUG FIX: `inference_time` was referenced in the print below before
        # it was assigned (NameError on every call). Compute the elapsed
        # time right after inference completes, then log it.
        inference_time = time.time() - t0
        print(f"I: Time to generate audio: {inference_time} seconds")
        # Serialize the waveform to an in-memory WAV at XTTS's native
        # 24 kHz sample rate, then base64-encode it for the JSON response.
        audio_file = io.BytesIO()
        torchaudio.save(audio_file, torch.tensor(out["wav"]).unsqueeze(0), 24000)
        audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
        return {"data": audio_str, "format": "wav"}