Spaces:
Sleeping
Sleeping
File size: 8,710 Bytes
f5ec497 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 |
"""
Data models and types for the TTSFM package.
This module defines the core data structures used throughout the package,
including request/response models, enums, and error types.
"""
from enum import Enum
from typing import Optional, Dict, Any, Union
from dataclasses import dataclass
from datetime import datetime
class Voice(str, Enum):
    """Closed set of voice identifiers accepted for TTS generation.

    Members mix in ``str``, so each one compares equal to its lowercase
    value (for example ``Voice.ALLOY == "alloy"``).
    """

    ALLOY = "alloy"
    ASH = "ash"
    BALLAD = "ballad"
    CORAL = "coral"
    ECHO = "echo"
    FABLE = "fable"
    NOVA = "nova"
    ONYX = "onyx"
    SAGE = "sage"
    SHIMMER = "shimmer"
    VERSE = "verse"
class AudioFormat(str, Enum):
    """Closed set of audio container/codec names usable as output formats.

    Members mix in ``str``, so each one compares equal to its lowercase
    value (for example ``AudioFormat.MP3 == "mp3"``).
    """

    MP3 = "mp3"
    WAV = "wav"
    OPUS = "opus"
    AAC = "aac"
    FLAC = "flac"
    PCM = "pcm"
@dataclass
class TTSRequest:
    """
    Request model for TTS generation.

    Attributes:
        input: Text to convert to speech
        voice: Voice to use for generation; strings are lower-cased and
            converted to a ``Voice`` member
        response_format: Audio format for output; strings are lower-cased
            and converted to an ``AudioFormat`` member
        instructions: Optional instructions for voice modulation
        model: Model to use (for OpenAI compatibility, usually ignored)
        speed: Speech speed (for OpenAI compatibility, usually ignored)
        max_length: Maximum allowed text length (default: 4096 characters)
        validate_length: Whether to validate text length (default: True)

    Raises:
        ValueError: If any field fails the checks in ``__post_init__``.
    """

    input: str
    voice: Union[Voice, str] = Voice.ALLOY
    response_format: Union[AudioFormat, str] = AudioFormat.MP3
    instructions: Optional[str] = None
    model: Optional[str] = None
    speed: Optional[float] = None
    max_length: int = 4096
    validate_length: bool = True

    def __post_init__(self):
        """Validate and normalize fields after initialization.

        Raises:
            ValueError: On an unknown voice or format name, empty or
                whitespace-only input, a non-positive ``max_length``,
                input longer than ``max_length`` (when ``validate_length``
                is true), or a ``speed`` outside 0.25-4.0 (NaN included).
        """
        # Normalize the voice name to a Voice member.
        if isinstance(self.voice, str):
            try:
                self.voice = Voice(self.voice.lower())
            except ValueError:
                raise ValueError(
                    f"Invalid voice: {self.voice}. Must be one of {list(Voice)}"
                ) from None

        # Normalize the format name to an AudioFormat member.
        if isinstance(self.response_format, str):
            try:
                self.response_format = AudioFormat(self.response_format.lower())
            except ValueError:
                raise ValueError(
                    f"Invalid format: {self.response_format}. Must be one of {list(AudioFormat)}"
                ) from None

        # Reject empty or whitespace-only text.
        if not self.input or not self.input.strip():
            raise ValueError("Input text cannot be empty")

        # Check max_length *before* using it as a limit: any non-empty input
        # is longer than a non-positive limit, so checking afterwards would
        # always surface the misleading "too long" error instead.
        if self.max_length <= 0:
            raise ValueError("max_length must be a positive integer")

        # Validate text length if enabled.
        if self.validate_length:
            text_length = len(self.input)
            if text_length > self.max_length:
                raise ValueError(
                    f"Input text is too long ({text_length} characters). "
                    f"Maximum allowed length is {self.max_length} characters. "
                    f"Consider splitting your text into smaller chunks or disable "
                    f"length validation with validate_length=False."
                )

        # A chained comparison is False for NaN, so NaN is rejected too
        # (separate "< 0.25 or > 4.0" tests would both be False for NaN).
        if self.speed is not None and not (0.25 <= self.speed <= 4.0):
            raise ValueError("Speed must be between 0.25 and 4.0")

    def to_dict(self) -> Dict[str, Any]:
        """Convert request to dictionary for API calls.

        Returns:
            Dict[str, Any]: Always contains ``input``, ``voice`` and
            ``response_format`` (enum members are emitted as their string
            values). ``instructions`` and ``model`` are included only when
            truthy; ``speed`` is included whenever it is not ``None``.
        """
        data: Dict[str, Any] = {
            "input": self.input,
            "voice": self.voice.value if isinstance(self.voice, Voice) else self.voice,
            "response_format": (
                self.response_format.value
                if isinstance(self.response_format, AudioFormat)
                else self.response_format
            ),
        }
        if self.instructions:
            data["instructions"] = self.instructions
        if self.model:
            data["model"] = self.model
        if self.speed is not None:
            data["speed"] = self.speed
        return data
@dataclass
class TTSResponse:
    """
    Response model for TTS generation.

    Attributes:
        audio_data: Generated audio as bytes
        content_type: MIME type of the audio data
        format: Audio format used
        size: Size of audio data in bytes; when omitted (or ``None``) it is
            derived from ``len(audio_data)``
        duration: Estimated duration in seconds (if available)
        metadata: Additional response metadata
    """

    audio_data: bytes
    content_type: str
    format: AudioFormat
    # Optional with a None default so the derivation in __post_init__ is
    # reachable without passing None explicitly; positional order is unchanged.
    size: Optional[int] = None
    duration: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Calculate derived fields after initialization."""
        if self.size is None:
            self.size = len(self.audio_data)

    def save_to_file(self, filename: str) -> str:
        """
        Save audio data to a file.

        The extension always reflects ``self.format`` (the format actually
        returned). A different known audio extension on ``filename`` is
        replaced; any other suffix is kept and the extension is appended.
        Missing parent directories are created.

        Args:
            filename: Target filename (extension will be added if missing)

        Returns:
            str: Final filename used
        """
        import os

        # Use the actual returned format for the extension, not any requested format.
        expected_extension = f".{self.format.value}"

        if filename.endswith(expected_extension):
            final_filename = filename
        else:
            base_name = filename
            # Strip at most one known audio extension before appending ours.
            for ext in ('.mp3', '.wav', '.opus', '.aac', '.flac', '.pcm'):
                if base_name.endswith(ext):
                    base_name = base_name[:-len(ext)]
                    break
            final_filename = f"{base_name}{expected_extension}"

        # Only create a directory when the path actually names one; a bare
        # filename is written to the current directory, which already exists.
        directory = os.path.dirname(final_filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(final_filename, "wb") as f:
            f.write(self.audio_data)

        return final_filename
@dataclass
class TTSError:
    """
    Structured description of an error reported for a TTS operation.

    Attributes:
        code: Error code
        message: Human-readable error message
        type: Error type/category
        details: Additional error details
        timestamp: When the error occurred; filled in at construction time
            when not supplied
    """

    code: str
    message: str
    type: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
    timestamp: Optional[datetime] = None

    def __post_init__(self):
        """Stamp the error with ``datetime.now()`` unless a timestamp was given."""
        if self.timestamp is not None:
            return
        self.timestamp = datetime.now()
@dataclass
class APIError(TTSError):
    """Error record for a failed API call.

    Attributes:
        status_code: Status code associated with the failure (default: 500)
        headers: Optional string-to-string header mapping
    """

    status_code: int = 500
    headers: Optional[Dict[str, str]] = None
@dataclass
class NetworkError(TTSError):
    """Error record for a network-level failure.

    Attributes:
        timeout: Optional timeout value associated with the failure
        retry_count: Retry counter (default: 0)
    """

    timeout: Optional[float] = None
    retry_count: int = 0
@dataclass
class ValidationError(TTSError):
    """Error record for a failed validation.

    Attributes:
        field: Optional name of the field concerned
        value: Optional value concerned
    """

    field: Optional[str] = None
    value: Optional[Any] = None
# MIME content type reported for each audio format.
CONTENT_TYPE_MAP = {
    AudioFormat.MP3: "audio/mpeg",
    AudioFormat.OPUS: "audio/opus",
    AudioFormat.AAC: "audio/aac",
    AudioFormat.FLAC: "audio/flac",
    AudioFormat.WAV: "audio/wav",
    AudioFormat.PCM: "audio/pcm",
}

# Inverse lookup: MIME content type -> audio format.
FORMAT_FROM_CONTENT_TYPE = {
    mime_type: audio_format
    for audio_format, mime_type in CONTENT_TYPE_MAP.items()
}
def get_content_type(format: Union[AudioFormat, str]) -> str:
    """Return the MIME content type for an audio format.

    Args:
        format: An ``AudioFormat`` member or a format name; strings are
            lower-cased and converted to ``AudioFormat`` first.

    Returns:
        str: The entry from ``CONTENT_TYPE_MAP``, or ``"audio/mpeg"`` when
        the format has no entry.

    Raises:
        ValueError: If a string does not name an ``AudioFormat`` member.
    """
    resolved = AudioFormat(format.lower()) if isinstance(format, str) else format
    try:
        return CONTENT_TYPE_MAP[resolved]
    except KeyError:
        return "audio/mpeg"
def get_format_from_content_type(content_type: str) -> AudioFormat:
    """Get audio format from MIME content type.

    The lookup ignores letter case, surrounding whitespace and any media
    type parameters (everything from the first ``;``), so a header value
    such as ``"Audio/WAV; codecs=1"`` resolves the same as ``"audio/wav"``.

    Args:
        content_type: MIME content type, optionally with parameters.

    Returns:
        AudioFormat: The matching format, or ``AudioFormat.MP3`` when the
        value is not a string or has no entry in ``FORMAT_FROM_CONTENT_TYPE``.
    """
    # Non-string input (e.g. a missing header passed as None) keeps the
    # original fallback instead of failing on the string operations below.
    if not isinstance(content_type, str):
        return AudioFormat.MP3

    # Reduce the header value to its bare, lower-case "type/subtype" form
    # so it can match the keys of the lookup table.
    media_type = content_type.split(";", 1)[0].strip().lower()
    return FORMAT_FROM_CONTENT_TYPE.get(media_type, AudioFormat.MP3)
def get_supported_format(requested_format: AudioFormat) -> AudioFormat:
    """
    Collapse a requested format onto one of the two formats returned here.

    Args:
        requested_format: The requested audio format

    Returns:
        AudioFormat: ``MP3`` when MP3 was requested; ``WAV`` for every
        other value (WAV, OPUS, AAC, FLAC, PCM)
    """
    return (
        AudioFormat.MP3
        if requested_format == AudioFormat.MP3
        else AudioFormat.WAV
    )
def maps_to_wav(format_value: str) -> bool:
    """
    Report whether a format name is one of those that map to WAV.

    The comparison is case-insensitive.

    Args:
        format_value: Format string to check

    Returns:
        bool: True for wav, opus, aac, flac or pcm; False otherwise
    """
    wav_backed = ('wav', 'opus', 'aac', 'flac', 'pcm')
    normalized = format_value.lower()
    return normalized in wav_backed
|