import tempfile
from typing import Any, Dict

import torch

from TTS.api import TTS
class EndpointHandler:
    """Inference-endpoint handler wrapping the Coqui XTTS v2 TTS model.

    Implements the standard custom-handler interface: ``__init__`` loads the
    model once at startup; ``__call__`` serves one synthesis request and
    returns the path of the generated wav file.
    """

    def __init__(self, path=""):
        """Load the XTTS v2 model onto GPU when available, else CPU.

        Args:
            path: Unused; kept for handler-interface compatibility.
                Model weights are downloaded automatically by ``TTS``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts = TTS(
            "tts_models/multilingual/multi-dataset/xtts_v2").to(device)
        self.device = device

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for a single request.

        Expected input:
            {
                "text": "Hello world",
                "speaker_wav": "path/to/voice.wav" (optional, for cloning),
                "language": "en"
            }

        Returns:
            {"audio": <path to the generated .wav file>}

        Raises:
            ValueError: If "text" is missing or empty.
        """
        text = data.get("text", "")
        speaker_wav = data.get("speaker_wav", None)
        language = data.get("language", "en")

        # Fail fast on empty input instead of handing the model a useless
        # request and surfacing an opaque model-side error.
        if not text:
            raise ValueError("Input 'text' must be a non-empty string.")

        # One unique file per request: the previous fixed "output.wav" was
        # clobbered whenever two requests ran concurrently.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        output_path = tmp.name

        self.tts.tts_to_file(
            text=text,
            speaker_wav=speaker_wav,
            language=language,
            file_path=output_path
        )
        return {"audio": output_path}