narrator / handler.py
simdi's picture
Create handler.py
cbf8a35 verified
raw
history blame
No virus
1.9 kB
import time
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir
import torch
import os
from TTS.tts.configs.xtts_config import XttsConfig
import torchaudio
from TTS.tts.models.xtts import Xtts
import io
import base64
class EndpointHandler:
    """Inference-endpoint handler that narrates text with an XTTS model.

    The XTTS checkpoint is loaded once at construction time, and every request
    is synthesised in the voice of a single fixed reference recording.
    """

    # Reference recording whose voice is cloned for every request
    # (resolved relative to ``path`` passed to ``__init__``).
    SPEAKER_AUDIO = "attenborough.mp3"
    # Sample rate written into the WAV header; same value the handler has
    # always used. NOTE(review): assumed to match the XTTS output rate.
    SAMPLE_RATE = 24000

    def __init__(self, path=""):
        """Load the XTTS checkpoint and precompute the speaker conditioning.

        Args:
            path: Directory that contains the ``model/`` folder and the
                reference audio file. The default ``""`` resolves both relative
                to the current working directory, matching the previous
                hard-coded ``./model/...`` and ``attenborough.mp3`` paths.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_dir = os.path.join(path, "model")

        config = XttsConfig()
        config.load_json(os.path.join(model_dir, "config.json"))
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path=os.path.join(model_dir, "model.pth"),
            vocab_path=os.path.join(model_dir, "vocab.json"),
            speaker_file_path=os.path.join(model_dir, "speakers_xtts.pth"),
            eval=True,
            # Only enable DeepSpeed when running on a CUDA device.
            use_deepspeed=device == "cuda",
        )
        model.to(device)
        self.model = model

        # The reference voice and its parameters never change between
        # requests, so compute the conditioning latents once here rather than
        # re-reading and re-encoding the audio on every call.
        (
            self.gpt_cond_latent,
            self.speaker_embedding,
        ) = model.get_conditioning_latents(
            audio_path=os.path.join(path, self.SPEAKER_AUDIO),
            gpt_cond_len=30,
            gpt_cond_chunk_len=4,
            max_ref_length=60,
        )

    def __call__(self, model_input):
        """Synthesise speech for ``model_input["text"]``.

        Args:
            model_input: Request payload; must provide a ``"text"`` entry with
                the text to narrate (synthesised as English).

        Returns:
            dict with ``"data"`` (base64-encoded WAV bytes as a UTF-8 string)
            and ``"format"`` (always ``"wav"``).

        Raises:
            KeyError: if ``model_input`` is a mapping without a ``"text"`` key.
        """
        print("Generating audio")
        t0 = time.perf_counter()
        out = self.model.inference(
            text=model_input["text"],
            speaker_embedding=self.speaker_embedding,
            gpt_cond_latent=self.gpt_cond_latent,
            temperature=0.75,
            repetition_penalty=2.5,
            language="en",
            enable_text_splitting=True,
        )
        # Measure right after inference (and before logging it) so the figure
        # covers generation only, not WAV encoding.
        inference_time = time.perf_counter() - t0
        print(f"I: Time to generate audio: {inference_time} seconds")

        audio_file = io.BytesIO()
        # torchaudio cannot infer the container type from a file-like object,
        # so the format has to be passed explicitly.
        torchaudio.save(
            audio_file,
            torch.as_tensor(out["wav"]).unsqueeze(0),  # add channel dim -> (1, n)
            self.SAMPLE_RATE,
            format="wav",
        )
        audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
        return {"data": audio_str, "format": "wav"}