File size: 8,710 Bytes
f5ec497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
"""
Data models and types for the TTSFM package.

This module defines the core data structures used throughout the package,
including request/response models, enums, and error types.
"""

from enum import Enum
from typing import Optional, Dict, Any, Union
from dataclasses import dataclass
from datetime import datetime


class Voice(str, Enum):
    """
    Available voice options for TTS generation.

    Members mix in ``str``, so each member compares equal to its lowercase
    string value (e.g. ``Voice.ALLOY == "alloy"``) and can be used wherever
    a plain string is expected.
    """
    ALLOY = "alloy"
    ASH = "ash"
    BALLAD = "ballad"
    CORAL = "coral"
    ECHO = "echo"
    FABLE = "fable"
    NOVA = "nova"
    ONYX = "onyx"
    SAGE = "sage"
    SHIMMER = "shimmer"
    VERSE = "verse"


class AudioFormat(str, Enum):
    """
    Supported audio output formats.

    Members mix in ``str`` and compare equal to their lowercase value
    (e.g. ``AudioFormat.MP3 == "mp3"``). Each value is also used as the
    file extension (without the dot) when saving audio.
    """
    MP3 = "mp3"
    WAV = "wav"
    OPUS = "opus"
    AAC = "aac"
    FLAC = "flac"
    PCM = "pcm"


@dataclass
class TTSRequest:
    """
    Request model for TTS generation.

    Attributes:
        input: Text to convert to speech (must contain non-whitespace)
        voice: Voice to use for generation; strings are matched
            case-insensitively and normalized to a ``Voice`` member
        response_format: Audio format for output; strings are matched
            case-insensitively and normalized to an ``AudioFormat`` member
        instructions: Optional instructions for voice modulation
        model: Model to use (for OpenAI compatibility, usually ignored)
        speed: Speech speed (for OpenAI compatibility, usually ignored);
            when given it must lie in the closed range [0.25, 4.0]
        max_length: Maximum allowed text length (default: 4096 characters);
            must be positive
        validate_length: Whether to validate text length (default: True)

    Raises:
        ValueError: From ``__post_init__`` when any field fails validation.
    """
    input: str
    voice: Union[Voice, str] = Voice.ALLOY
    response_format: Union[AudioFormat, str] = AudioFormat.MP3
    instructions: Optional[str] = None
    model: Optional[str] = None
    speed: Optional[float] = None
    max_length: int = 4096
    validate_length: bool = True

    def __post_init__(self):
        """Validate and normalize fields after initialization."""
        # Normalize voice to the Voice enum. Voice members are themselves str
        # instances, so they also pass through here and map back to themselves.
        if isinstance(self.voice, str):
            try:
                self.voice = Voice(self.voice.lower())
            except ValueError:
                # `from None` drops the internal enum-lookup traceback; this
                # message already explains the problem to the caller.
                raise ValueError(
                    f"Invalid voice: {self.voice}. "
                    f"Must be one of {[v.value for v in Voice]}"
                ) from None

        # Normalize response_format to the AudioFormat enum.
        if isinstance(self.response_format, str):
            try:
                self.response_format = AudioFormat(self.response_format.lower())
            except ValueError:
                raise ValueError(
                    f"Invalid format: {self.response_format}. "
                    f"Must be one of {[f.value for f in AudioFormat]}"
                ) from None

        # Validate input text: whitespace-only input counts as empty.
        if not self.input or not self.input.strip():
            raise ValueError("Input text cannot be empty")

        # Validate max_length before it is used as a limit below; otherwise a
        # non-positive max_length would surface as a misleading "too long" error.
        if self.max_length <= 0:
            raise ValueError("max_length must be a positive integer")

        # Validate text length if enabled
        if self.validate_length:
            text_length = len(self.input)
            if text_length > self.max_length:
                raise ValueError(
                    f"Input text is too long ({text_length} characters). "
                    f"Maximum allowed length is {self.max_length} characters. "
                    f"Consider splitting your text into smaller chunks or disable "
                    f"length validation with validate_length=False."
                )

        # Validate speed if provided. The chained, negated comparison also
        # rejects NaN, for which both `<` and `>` would be False.
        if self.speed is not None and not (0.25 <= self.speed <= 4.0):
            raise ValueError("Speed must be between 0.25 and 4.0")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert request to dictionary for API calls.

        Enum members are emitted as their plain string values. Optional
        fields are included only when set: ``instructions`` and ``model``
        when truthy, ``speed`` when not None.
        """
        data = {
            "input": self.input,
            "voice": self.voice.value if isinstance(self.voice, Voice) else self.voice,
            "response_format": (
                self.response_format.value
                if isinstance(self.response_format, AudioFormat)
                else self.response_format
            ),
        }

        if self.instructions:
            data["instructions"] = self.instructions

        if self.model:
            data["model"] = self.model

        if self.speed is not None:
            data["speed"] = self.speed

        return data


@dataclass
class TTSResponse:
    """
    Response model for TTS generation.

    Attributes:
        audio_data: Generated audio as bytes
        content_type: MIME type of the audio data
        format: Audio format used (an ``AudioFormat`` member; a plain format
            string is tolerated by ``save_to_file``)
        size: Size of audio data in bytes; computed from ``audio_data`` when
            omitted or passed as None
        duration: Estimated duration in seconds (if available)
        metadata: Additional response metadata
    """
    audio_data: bytes
    content_type: str
    format: AudioFormat
    # Optional so callers may omit it; __post_init__ derives it from audio_data.
    size: Optional[int] = None
    duration: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Fill in ``size`` from ``audio_data`` when it was not supplied."""
        if self.size is None:
            self.size = len(self.audio_data)

    def save_to_file(self, filename: str) -> str:
        """
        Save audio data to a file.

        The extension is taken from the format actually returned
        (``self.format``), not from any requested format. A trailing known
        audio extension on ``filename`` (matched case-insensitively) is
        replaced with the correct one. Missing parent directories are
        created, and an existing file at the target path is overwritten.

        Args:
            filename: Target filename (extension will be added if missing)

        Returns:
            str: Final filename used

        Raises:
            OSError: If the directory cannot be created or the file cannot
                be written.
        """
        import os

        # Accept both AudioFormat members and plain format strings.
        if isinstance(self.format, AudioFormat):
            format_value = self.format.value
        else:
            format_value = str(self.format).lower()
        expected_extension = f".{format_value}"

        # Compare case-insensitively so "clip.MP3" is not turned into "clip.MP3.mp3".
        lowered = filename.lower()
        if lowered.endswith(expected_extension):
            final_filename = filename
        else:
            # Strip one known audio extension (if any) before appending the right one.
            base_name = filename
            for known_format in AudioFormat:
                known_extension = f".{known_format.value}"
                if lowered.endswith(known_extension):
                    base_name = filename[:-len(known_extension)]
                    break
            final_filename = f"{base_name}{expected_extension}"

        # Create the parent directory only when the path actually has one.
        directory = os.path.dirname(final_filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Write audio data
        with open(final_filename, "wb") as f:
            f.write(self.audio_data)

        return final_filename


@dataclass
class TTSError:
    """
    Structured error information reported by the TTS API.

    Serves as the base for the more specific error records defined below it.

    Attributes:
        code: Error code
        message: Human-readable error message
        type: Error type/category
        details: Additional error details
        timestamp: When the error occurred; if left as None it is set to the
            current local (naive) time during initialization
    """
    code: str
    message: str
    type: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
    timestamp: Optional[datetime] = None

    def __post_init__(self):
        """Stamp the error with the current time unless a timestamp was given."""
        if self.timestamp is not None:
            return
        self.timestamp = datetime.now()


@dataclass
class APIError(TTSError):
    """
    API-specific error information.

    Adds response details to the fields inherited from TTSError. The new
    fields come after the inherited ones in the generated ``__init__``.
    """
    # Status code for the failed call (presumably the HTTP status); defaults to 500.
    status_code: int = 500
    # Headers associated with the failure, if captured (presumably response headers).
    headers: Optional[Dict[str, str]] = None


@dataclass
class NetworkError(TTSError):
    """
    Network-related error information.

    Adds transport details to the fields inherited from TTSError. The new
    fields come after the inherited ones in the generated ``__init__``.
    """
    # Timeout in effect when the error occurred, if any (units not shown here; presumably seconds).
    timeout: Optional[float] = None
    # Number of retries attempted before giving up; defaults to 0.
    retry_count: int = 0


@dataclass
class ValidationError(TTSError):
    """
    Validation error information.

    Adds the offending field/value to the fields inherited from TTSError.
    Note: this is a plain dataclass record, not an exception class.
    """
    # Name of the field that failed validation, if known.
    field: Optional[str] = None
    # The rejected value, if captured.
    value: Optional[Any] = None


# Content type mappings for audio formats.
# Maps every AudioFormat member to the MIME type used for it; consulted by
# get_content_type(), which falls back to "audio/mpeg" for missing keys.
CONTENT_TYPE_MAP = {
    AudioFormat.MP3: "audio/mpeg",
    AudioFormat.OPUS: "audio/opus",
    AudioFormat.AAC: "audio/aac",
    AudioFormat.FLAC: "audio/flac",
    AudioFormat.WAV: "audio/wav",
    AudioFormat.PCM: "audio/pcm"
}

# Reverse mapping for content type to format.
# The MIME values above are all distinct, so inverting the dict loses no
# entries. Keys are the exact lowercase MIME strings listed above.
FORMAT_FROM_CONTENT_TYPE = {v: k for k, v in CONTENT_TYPE_MAP.items()}


def get_content_type(format: Union[AudioFormat, str]) -> str:
    """
    Return the MIME content type for an audio format.

    Args:
        format: An ``AudioFormat`` member or a case-insensitive format name.

    Returns:
        str: The MIME type, or "audio/mpeg" when the format has no entry
        in ``CONTENT_TYPE_MAP``.

    Raises:
        ValueError: If ``format`` is a string that names no ``AudioFormat``.
    """
    resolved = AudioFormat(format.lower()) if isinstance(format, str) else format
    return CONTENT_TYPE_MAP.get(resolved, "audio/mpeg")


def get_format_from_content_type(content_type: str) -> AudioFormat:
    """
    Get audio format from MIME content type.

    Real ``Content-Type`` headers may carry parameters (``"audio/wav; codecs=1"``),
    arbitrary casing or surrounding whitespace, and some servers use
    non-canonical aliases (``"audio/x-wav"``, ``"audio/mp3"``). These are
    normalized before lookup so such headers are not silently misreported
    as MP3.

    Args:
        content_type: A MIME type or full Content-Type header value.

    Returns:
        AudioFormat: The matching format, or ``AudioFormat.MP3`` when the
        value is missing, not a string, or unrecognized.
    """
    if not isinstance(content_type, str) or not content_type:
        return AudioFormat.MP3

    # Keep only the "type/subtype" part; media types are case-insensitive.
    mime_type = content_type.split(";", 1)[0].strip().lower()

    matched = FORMAT_FROM_CONTENT_TYPE.get(mime_type)
    if matched is not None:
        return matched

    # Common non-canonical spellings seen in the wild.
    aliases = {
        "audio/x-wav": AudioFormat.WAV,
        "audio/wave": AudioFormat.WAV,
        "audio/mp3": AudioFormat.MP3,
    }
    return aliases.get(mime_type, AudioFormat.MP3)


def get_supported_format(requested_format: AudioFormat) -> AudioFormat:
    """
    Map a requested format onto one that is actually produced.

    MP3 is passed through unchanged; every other value (WAV, OPUS, AAC,
    FLAC, PCM, or anything else) collapses to WAV.

    Args:
        requested_format: The requested audio format

    Returns:
        AudioFormat: ``AudioFormat.MP3`` or ``AudioFormat.WAV``
    """
    wants_mp3 = requested_format == AudioFormat.MP3
    return AudioFormat.MP3 if wants_mp3 else AudioFormat.WAV


def maps_to_wav(format_value: str) -> bool:
    """
    Report whether a format name is served as WAV.

    The comparison is case-insensitive. "mp3" and names outside the known
    format set both yield False.

    Args:
        format_value: Format string to check

    Returns:
        bool: True for wav, opus, aac, flac or pcm; False otherwise
    """
    wav_backed_formats = {"wav", "opus", "aac", "flac", "pcm"}
    normalized = format_value.lower()
    return normalized in wav_backed_formats