File size: 3,720 Bytes
ce2ed27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import json
import asyncio
import aiofiles
from time import time
import json
from pprint import pprint
from smallestai.waves import WavesClient, AsyncWavesClient


class SmallestAITTS:
    """Thin text-to-speech wrapper around the SmallestAI Waves clients.

    Depending on ``is_async`` the instance holds either a ``WavesClient`` or
    an ``AsyncWavesClient``. ``synthesize`` is a blocking entry point that
    works in both modes; ``tts`` is bound to the mode-specific implementation
    (``_tts`` or the coroutine function ``_async_tts``).
    """

    def __init__(
        self,
        model_name: str,
        api_key: str,
        provider: str,
        endpoint_url: str,
        voice_id: str = None,
        sample_rate: int = 24000,
        speed: float = 1.0,
        is_async: bool = False,
    ):
        """Create the underlying client and remember the synthesis settings.

        Args:
            model_name: Model identifier forwarded to the client as ``model``.
            api_key: API key handed to the Waves client constructor.
            provider: Stored on the instance; not used by this class itself.
            endpoint_url: Stored on the instance; not used by this class itself.
            voice_id: Voice to synthesize with. May be ``None`` here and set
                later through ``load_voice()``.
            sample_rate: Sample rate forwarded to the client on every call.
            speed: Speed value forwarded to the client on every call.
            is_async: When true, use ``AsyncWavesClient`` and the coroutine
                based implementation.
        """
        if is_async:
            self.client = AsyncWavesClient(api_key=api_key)
        else:
            self.client = WavesClient(api_key=api_key)

        self.model_name = model_name
        self.api_key = api_key
        self.provider = provider
        self.endpoint_url = endpoint_url
        self.voice_id = voice_id  # if None, set later using `load_voice()`
        self.sample_rate = sample_rate
        self.speed = speed
        self.is_async = is_async
        # Mode-specific entry point: a coroutine function in async mode,
        # a plain method otherwise.
        self.tts = self._async_tts if is_async else self._tts

    def load_voice(self, voice_id: str):
        """Set (or replace) the voice used by subsequent synthesis calls."""
        self.voice_id = voice_id

    def synthesize(self, text: str, output_filepath: str):
        """Blocking text-to-speech call that works in sync and async mode.

        In async mode the coroutine is driven to completion with
        ``asyncio.run``; it is executed exactly once, so an error raised by
        the synthesis itself propagates instead of triggering a retry.

        Args:
            text: The text to synthesize.
            output_filepath: Path to save the audio file.

        Raises:
            AssertionError: If no voice id has been set.
            RuntimeError: In async mode, if called while an event loop is
                already running in this thread (await ``tts(...)`` instead).
        """
        if not self.is_async:
            return self._tts(text, output_filepath)

        # Probe for a running loop *before* creating the coroutine so that we
        # never leave an un-awaited coroutine object behind.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_is_running = False
        else:
            loop_is_running = True

        if loop_is_running:
            raise RuntimeError(
                "synthesize() cannot be called from a running event loop; "
                "await `tts(text, output_filepath)` instead."
            )
        return asyncio.run(self._async_tts(text, output_filepath))

    def _require_voice(self):
        """Raise ``AssertionError`` if no voice id has been configured.

        An explicit ``raise`` is used instead of an ``assert`` statement so
        the check still runs under ``python -O``; the exception type and
        message are unchanged for existing callers.
        """
        if self.voice_id is None:
            raise AssertionError("Please set a voice style.")

    def _tts(self, text: str, output_filepath: str):
        """Synchronously synthesize ``text``; the client saves the result.

        The output path is passed to the client as ``save_as``.

        Raises:
            AssertionError: If no voice id has been set.
        """
        self._require_voice()
        self.client.synthesize(
            text,
            save_as=output_filepath,
            model=self.model_name,
            voice_id=self.voice_id,
            speed=self.speed,
            sample_rate=self.sample_rate,
        )

    async def _async_tts(self, text: str, output_filepath: str):
        """Asynchronously synthesize ``text`` and write it to ``output_filepath``.

        The client is used as an async context manager for the duration of
        the request, and the value it returns (presumably raw audio bytes;
        confirm against the SDK) is written in binary mode with ``aiofiles``.

        Raises:
            AssertionError: If no voice id has been set.
        """
        self._require_voice()
        async with self.client:
            audio_bytes = await self.client.synthesize(
                text,
                model=self.model_name,
                voice_id=self.voice_id,
                speed=self.speed,
                sample_rate=self.sample_rate,
            )
            async with aiofiles.open(output_filepath, "wb") as f:
                await f.write(audio_bytes)

    # Wrappers for the SmallestAI client's default functions
    def get_languages(self):
        """Return whatever the underlying client's ``get_languages()`` returns."""
        return self.client.get_languages()

    @staticmethod
    def _tag_matches(voice: dict, key: str, wanted) -> bool:
        """Tell whether ``voice`` carries tag ``key`` with value ``wanted``.

        A voice without a ``tags`` entry, or without the requested tag, is a
        non-match rather than an error.
        """
        tags = voice.get("tags") or {}
        return key in tags and tags[key] == wanted

    def get_voices(self, model="lightning", voiceId=None, **kwargs) -> list:
        """List the voices of ``model``, optionally filtered.

        The client's ``get_voices(model)`` result is parsed as JSON and its
        ``"voices"`` list is filtered.

        Args:
            model: Model name passed to the client.
            voiceId: If given, keep only voices whose ``"voiceId"`` equals
                this value; tag filters are ignored in that case.
            **kwargs: Tag filters. A voice is kept only if, for every
                ``key=value`` pair, ``voice["tags"][key] == value``.

        Returns:
            The list of matching voice dictionaries.
        """
        voices = json.loads(self.client.get_voices(model))["voices"]

        if voiceId is not None:
            return [voice for voice in voices if voice.get("voiceId") == voiceId]

        # Apply every tag filter; a voice must satisfy all of them.
        return [
            voice
            for voice in voices
            if all(
                self._tag_matches(voice, key, wanted)
                for key, wanted in kwargs.items()
            )
        ]

    def get_models(self):
        """Return whatever the underlying client's ``get_models()`` returns."""
        return self.client.get_models()