# Spaces: Sleeping  (appears to be a hosting-page status banner captured with the source; not part of the module)
import asyncio
import json
import os
from pprint import pprint
from time import time
from typing import Optional

import aiofiles
from smallestai.waves import WavesClient, AsyncWavesClient
class SmallestAITTS:
    """Text-to-speech wrapper around the SmallestAI Waves clients.

    Depending on ``is_async`` the instance holds either a ``WavesClient`` or
    an ``AsyncWavesClient``. ``synthesize()`` is the blocking entry point for
    both modes; ``tts`` is bound to the mode-specific implementation
    (``_tts`` or the coroutine function ``_async_tts``).
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        provider: str,
        endpoint_url: str,
        voice_id: Optional[str] = None,
        sample_rate: int = 24000,
        speed: float = 1.0,
        is_async: bool = False,
    ):
        """Create the underlying client and remember the synthesis settings.

        Args:
            model_name: Model identifier forwarded to the client as ``model``.
            api_key: API key handed to the Waves client constructor.
            provider: Stored on the instance; not used by this class itself.
            endpoint_url: Stored on the instance; not used by this class itself.
            voice_id: Voice to synthesize with. May be ``None`` here and set
                later through ``load_voice()``.
            sample_rate: Sample rate forwarded to the client.
            speed: Speed value forwarded to the client.
            is_async: Use ``AsyncWavesClient`` when true, ``WavesClient``
                otherwise.
        """
        if is_async:
            self.client = AsyncWavesClient(api_key=api_key)
        else:
            self.client = WavesClient(api_key=api_key)
        self.model_name = model_name
        self.api_key = api_key
        self.provider = provider
        self.endpoint_url = endpoint_url
        # If passed as None, must be set via `load_voice()` before synthesis.
        self.voice_id = voice_id
        self.sample_rate = sample_rate
        self.speed = speed
        self.tts = self._async_tts if is_async else self._tts
        self.is_async = is_async

    def load_voice(self, voice_id: str):
        """Set (or replace) the voice used for subsequent synthesis calls."""
        self.voice_id = voice_id

    def synthesize(self, text: str, output_filepath: str):
        """Synthesize ``text`` and save the audio to ``output_filepath``.

        Blocking entry point for both modes: the sync client is called
        directly, the async implementation is driven to completion on a
        fresh event loop.

        Args:
            text: The text to synthesize.
            output_filepath: Path to save the audio file.

        Raises:
            RuntimeError: In async mode, when called from a thread that
                already has a running event loop (a blocking call cannot be
                nested there; await ``tts(...)`` instead).
            AssertionError: If no voice has been set.
        """
        if not self.is_async:
            return self._tts(text, output_filepath)

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread, so it is safe to block.
            return asyncio.run(self._async_tts(text, output_filepath))

        # Checked before creating the coroutine so nothing is left un-awaited.
        raise RuntimeError(
            "synthesize() cannot block inside a running event loop; "
            "use `await tts(text, output_filepath)` instead."
        )

    def _require_voice(self):
        """Fail fast when synthesis is attempted without a voice.

        Raised explicitly (rather than via ``assert``) so the check survives
        ``python -O``; the exception type and message are kept as
        ``AssertionError`` for callers that already handle it.
        """
        if self.voice_id is None:
            raise AssertionError("Please set a voice style.")

    def _tts(self, text: str, output_filepath: str):
        """Synchronous synthesis: the client saves the audio itself via ``save_as``."""
        self._require_voice()
        self.client.synthesize(
            text,
            save_as=output_filepath,
            model=self.model_name,
            voice_id=self.voice_id,
            speed=self.speed,
            sample_rate=self.sample_rate,
        )

    async def _async_tts(self, text: str, output_filepath: str):
        """Asynchronous synthesis: await the audio bytes, then write them with aiofiles."""
        self._require_voice()
        async with self.client:
            audio_bytes = await self.client.synthesize(
                text,
                model=self.model_name,
                voice_id=self.voice_id,
                speed=self.speed,
                sample_rate=self.sample_rate,
            )
        async with aiofiles.open(output_filepath, "wb") as f:
            await f.write(audio_bytes)

    # Wrappers for the SmallestAI client's default functions

    def get_languages(self):
        """Return the client's ``get_languages()`` result unchanged."""
        return self.client.get_languages()

    def get_voices(self, model="lightning", voiceId=None, **kwargs) -> list:
        """List voices for ``model``, optionally filtered.

        The client response is parsed as JSON and its ``"voices"`` entry is
        filtered as follows:

        * ``voiceId`` given: keep only voices whose ``"voiceId"`` matches;
          ``kwargs`` are ignored in that case.
        * otherwise: apply each ``tag=value`` keyword in turn, keeping voices
          whose ``"tags"`` mapping has that exact value. Voices without the
          tag (or without a ``"tags"`` mapping) are excluded, not an error.

        Returns:
            The list of matching voice dicts (possibly empty).
        """
        voices = json.loads(self.client.get_voices(model))["voices"]
        if voiceId is not None:
            return [voice for voice in voices if voice["voiceId"] == voiceId]

        for key, wanted in kwargs.items():
            voices = [
                voice for voice in voices if self._tag_matches(voice, key, wanted)
            ]
        return voices

    @staticmethod
    def _tag_matches(voice, key, wanted) -> bool:
        """Return True when ``voice`` carries tag ``key`` equal to ``wanted``."""
        tags = voice.get("tags")
        return isinstance(tags, dict) and key in tags and tags[key] == wanted

    def get_models(self):
        """Return the client's ``get_models()`` result unchanged."""
        return self.client.get_models()