Spaces:
Running
Running
| import uvicorn | |
| import os | |
| import asyncio | |
| import io | |
| import time | |
| import re | |
| import shutil | |
| from contextlib import asynccontextmanager | |
| from typing import Optional, AsyncGenerator, List | |
| import logging | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| import numpy as np | |
| from pydub import AudioSegment | |
| from kittentts import KittenTTS | |
# Resolve the log level from the LOG_LEVEL environment variable.
# logging.basicConfig raises ValueError for a level name it does not know,
# which would abort the import of this module, so any unrecognised value
# (typo, empty string, stray whitespace) falls back to WARNING instead.
_KNOWN_LOG_LEVELS = (
    "CRITICAL",
    "FATAL",
    "ERROR",
    "WARNING",
    "WARN",
    "INFO",
    "DEBUG",
    "NOTSET",
)
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING").strip().upper()
if LOG_LEVEL not in _KNOWN_LOG_LEVELS:
    LOG_LEVEL = "WARNING"

logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
| # --- FFmpeg Detection (IMPROVED) --- | |
def setup_ffmpeg():
    """Report whether FFmpeg is on PATH and can actually encode MP3.

    Looks the executable up with ``shutil.which`` and, when found, exports
    100 ms of silence to an in-memory MP3 as a smoke test.

    Returns:
        True when the executable is found and the test export succeeds,
        False when it is missing or the export raises.
    """
    located = shutil.which("ffmpeg")

    # Guard clause: nothing to test when the executable is absent.
    if not located:
        logger.warning("β FFmpeg not found in PATH")
        logger.warning("π‘ Make sure FFmpeg is installed and available in system PATH")
        return False

    logger.info(f"β FFmpeg found at: {located}")

    # Finding the binary is not enough; make sure an MP3 export really works.
    try:
        silence = AudioSegment.silent(duration=100)  # 100ms silence
        sink = io.BytesIO()
        silence.export(sink, format="mp3")
        print("β FFmpeg MP3 export test: PASSED")
        return True
    except Exception as export_error:
        logger.error(f"β FFmpeg MP3 export test failed: {export_error}")
        return False


# Evaluated once at import time; read by the audio export code and the health check.
ffmpeg_available = setup_ffmpeg()
| # --- Configuration --- | |
class Config:
    """Static service settings; two of them can be overridden via environment variables."""

    # Model identifier handed to KittenTTS (env: MODEL_NAME).
    MODEL_NAME = os.environ.get("MODEL_NAME", "KittenML/kitten-tts-nano-0.2")
    # Upper bound on the request text length (env: MAX_TEXT_LENGTH).
    MAX_TEXT_LENGTH = int(os.environ.get("MAX_TEXT_LENGTH", "2000"))

    # PCM parameters used when wrapping model output in an AudioSegment.
    SAMPLE_WIDTH = 2
    CHANNELS = 1
    FRAME_RATE = 24000

    # Voice identifiers accepted by the API; the first entry is the default voice.
    VOICES = [
        "expr-voice-2-f",
        "expr-voice-2-m",
        "expr-voice-3-f",
        "expr-voice-3-m",
        "expr-voice-4-f",
        "expr-voice-4-m",
        "expr-voice-5-f",
        "expr-voice-5-m",
    ]
| # --- Global State --- | |
class AppState:
    """Mutable process-wide state: the loaded TTS model and its readiness flag."""

    # Set to True only after the model has loaded and produced a warm-up clip.
    model_ready: bool = False
    # The KittenTTS instance; None until startup succeeds and again after shutdown.
    model: Optional[KittenTTS] = None


# Single shared instance used by the lifespan hook and the request handlers.
app_state = AppState()
| # --- Lifespan Management --- | |
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Load and warm up the TTS model on startup; release it on shutdown.

    FastAPI's ``lifespan`` argument expects an async context manager, so this
    generator is wrapped with ``asynccontextmanager`` (already imported by the
    file) instead of being passed as a bare async generator function.

    A failure while loading or warming up the model is logged and leaves
    ``app_state.model_ready`` False rather than aborting startup, so the
    service still comes up and can report itself as unhealthy.

    Args:
        app: The FastAPI application being started (not used directly).

    Yields:
        None. Control returns to the framework while the app serves requests.
    """
    # Startup
    print("π Starting Kitten TTS API...")
    try:
        print(f"π¦ Loading model: {Config.MODEL_NAME}")
        app_state.model = KittenTTS(Config.MODEL_NAME)

        # Run one short synthesis so the first real request does not pay the
        # warm-up cost, and to prove the model actually produces output.
        print("π₯ Warming up model...")
        test_audio = app_state.model.generate(text="Hello", voice=Config.VOICES[0])
        print(f"β Model warm-up complete. Test audio shape: {test_audio.shape}")

        app_state.model_ready = True
        print("β Model loaded and ready!")
    except Exception as e:
        logger.critical(f"β Model loading failed: {e}", exc_info=True)
        app_state.model_ready = False

    try:
        yield
    finally:
        # Shutdown: runs even when the application exits through an exception.
        print("π Shutting down Kitten TTS API...")
        app_state.model_ready = False
        app_state.model = None
| # --- App Initialization --- | |
app = FastAPI(
    title="Kitten TTS API",
    version="1.1.0",
    description="High-quality Text-to-Speech API with streaming support",
    lifespan=lifespan,
)

# --- CORS Middleware ---
# The API is open to every origin. Credentialed requests are disabled because
# the CORS rules (and the FastAPI documentation) do not allow wildcard
# origins/methods/headers together with allow_credentials=True; nothing in
# this service reads cookies or auth headers, so no client depends on it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # --- Pydantic Models --- | |
class SpeechRequest(BaseModel):
    # Request body for speech synthesis.
    # (Documented with comments rather than a docstring so the generated
    # schema description stays exactly as it was.)

    # Text to speak: at least one character, at most Config.MAX_TEXT_LENGTH.
    input: str = Field(
        ...,
        min_length=1,
        max_length=Config.MAX_TEXT_LENGTH,
    )
    # Model label supplied by the client.
    model: str = Field("kitten-nano-0.2")
    # Voice identifier; defaults to the first configured voice.
    voice: str = Field(Config.VOICES[0])
    # Playback speed multiplier, restricted to the range 0.5 - 2.0 inclusive.
    speed: float = Field(1.0, ge=0.5, le=2.0)
    # Output container: either "mp3" or "wav".
    response_format: str = Field("mp3", pattern="^(mp3|wav)$")
class HealthResponse(BaseModel):
    # Response body returned by the health-check endpoint.
    #
    # `model_ready` starts with "model_", which falls inside Pydantic's
    # protected "model_" namespace; clearing protected_namespaces keeps the
    # field name without a warning. The setting is declared through
    # `model_config`, the Pydantic v2 mechanism (the file already relies on v2
    # via Field(pattern=...)), instead of the deprecated nested `class Config`,
    # which also shadowed the module-level Config inside this class body.
    model_config = {"protected_namespaces": ()}

    # "healthy" or "unhealthy", as set by the health-check handler.
    status: str
    # Whether the TTS model finished loading and warming up.
    model_ready: bool
    # Number of configured voices.
    voices_available: int
    # API version string.
    version: str
    # Result of the import-time FFmpeg MP3 export test.
    ffmpeg_available: bool
| # --- Text Chunking --- | |
# Whitespace that follows sentence-ending punctuation; compiled once at import.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?;:])\s+')


def _split_long_sentence(sentence: str, limit: int) -> List[str]:
    """Break one over-long sentence into pieces of at most ``limit`` characters at word boundaries."""
    pieces = []
    current = ""
    for word in sentence.split():
        # +1 accounts for the space that would join `current` and `word`.
        if current and len(current) + 1 + len(word) > limit:
            pieces.append(current)
            current = word
        else:
            current = f"{current} {word}" if current else word
    if current:
        pieces.append(current)
    return pieces


def split_text_for_streaming(
    text: str,
    short_text_limit: int = 150,
    max_chunk_length: int = 200,
) -> List[str]:
    """Split text into natural speaking chunks.

    Text no longer than ``short_text_limit`` is returned as a single chunk.
    Longer text is split after sentence punctuation (``. ! ? ; :``) and the
    sentences are packed greedily into chunks of at most ``max_chunk_length``
    characters, counting the joining space. A sentence that is itself longer
    than the limit is broken at word boundaries so that unpunctuated input
    cannot produce an arbitrarily large chunk; only a single word longer than
    the limit can still exceed it.

    Args:
        text: The text to split.
        short_text_limit: Length up to which the text is not split at all.
        max_chunk_length: Target maximum length of each chunk.

    Returns:
        The non-blank chunks in reading order; an empty list when ``text`` is
        empty or whitespace only.
    """
    if len(text) <= short_text_limit:
        # Blank input yields no chunks, matching the non-blank guarantee below.
        return [text] if text.strip() else []

    chunks = []
    current = ""
    for sentence in _SENTENCE_BOUNDARY.split(text):
        sentence = sentence.strip()
        if not sentence:
            continue

        if len(sentence) > max_chunk_length:
            pieces = _split_long_sentence(sentence, max_chunk_length)
        else:
            pieces = [sentence]

        for piece in pieces:
            if current and len(current) + 1 + len(piece) > max_chunk_length:
                chunks.append(current)
                current = piece
            else:
                current = f"{current} {piece}" if current else piece

    if current:
        chunks.append(current)

    logger.info(f"π Split text into {len(chunks)} chunks")
    return chunks
| # --- Audio Generation (IMPROVED) --- | |
def _generate_audio_chunk(text: str, voice: str, speed: float, format: str) -> Optional[bytes]:
    """Generate audio chunk in specified format.

    Synthesises ``text`` with the loaded model, converts the samples to
    16-bit PCM, optionally changes the playback speed and encodes the result.

    Args:
        text: Text to synthesise.
        voice: Voice identifier passed to the model.
        speed: Playback speed multiplier; 1.0 leaves the audio untouched.
        format: "mp3" to request MP3; any other value produces WAV. MP3 is
            only attempted when the import-time FFmpeg check passed.

    Returns:
        The encoded audio bytes, or None when generation failed (the error is
        logged with its traceback). If an MP3 export fails, WAV bytes are
        returned instead.
    """
    try:
        if not app_state.model or not app_state.model_ready:
            raise RuntimeError("Model not ready")

        logger.info(f"π΅ Generating audio for: '{text[:50]}...'")

        # Generate audio
        numpy_audio_data = app_state.model.generate(text=text, voice=voice)

        # Debug audio range
        audio_range = np.abs(numpy_audio_data).max()
        logger.debug(f"π Audio range: {audio_range:.6f}")
        if audio_range < 0.001:
            logger.warning("β οΈ WARNING: Generated audio appears to be silent!")

        # Convert to 16-bit PCM. Samples are clipped to [-1, 1] first: without
        # that, any value outside the range overflows the int16 cast and wraps
        # around to the opposite sign, which is heard as a loud click.
        clipped = np.clip(numpy_audio_data, -1.0, 1.0)
        raw_pcm_bytes = (clipped * 32767).astype(np.int16).tobytes()

        # Create AudioSegment
        audio_segment = AudioSegment(
            data=raw_pcm_bytes,
            sample_width=Config.SAMPLE_WIDTH,
            frame_rate=Config.FRAME_RATE,
            channels=Config.CHANNELS,
        )

        # Apply speed adjustment
        if speed != 1.0:
            logger.info(f"β‘ Applying speed: {speed}x")
            try:
                audio_segment = audio_segment.speedup(playback_speed=speed)
            except Exception as speed_error:
                # A failed speed change (pydub raises, for instance, when the
                # clip is too short to process) used to discard the whole
                # chunk; keeping it at normal speed loses no spoken content.
                # NOTE(review): pydub's speedup appears to be designed for
                # factors above 1.0 - confirm the result for speed < 1.0.
                logger.warning(f"Speed adjustment failed, keeping original speed: {speed_error}")

        # Export to requested format
        if format == "mp3" and ffmpeg_available:
            try:
                mp3_buffer = io.BytesIO()
                audio_segment.export(mp3_buffer, format="mp3", bitrate="64k")
                mp3_data = mp3_buffer.getvalue()
                logger.debug(f"π¦ Generated MP3 chunk: {len(mp3_data)} bytes")
                return mp3_data
            except Exception as e:
                logger.warning(f"β MP3 export failed, falling back to WAV: {e}")

        # WAV format (fallback or requested); always written to a fresh buffer
        # so a partially written MP3 can never leak into the output.
        wav_buffer = io.BytesIO()
        audio_segment.export(wav_buffer, format="wav")
        wav_data = wav_buffer.getvalue()
        logger.debug(f"π¦ Generated WAV chunk: {len(wav_data)} bytes")
        return wav_data

    except Exception as e:
        # logger.exception already records the traceback.
        logger.exception(f"β Audio generation error: {e}")
        return None
async def audio_stream_generator(text: str, voice: str, speed: float, format: str) -> AsyncGenerator[bytes, None]:
    """Yield encoded audio for ``text`` one text chunk at a time.

    The text is divided by ``split_text_for_streaming``; each non-blank piece
    is synthesised in a worker thread via ``_generate_audio_chunk``. Pieces
    for which no audio came back are skipped. When the split yields nothing,
    a single empty bytes object is emitted.
    """
    pieces = split_text_for_streaming(text)
    if not pieces:
        yield b""
        return

    total = len(pieces)
    for position, piece in enumerate(pieces, start=1):
        if piece.strip():
            logger.info(f"π΅ Processing chunk {position}/{total}")

            # Synthesis is blocking work, so it runs off the event loop.
            encoded = await asyncio.to_thread(
                _generate_audio_chunk,
                text=piece,
                voice=voice,
                speed=speed,
                format=format,
            )

            if encoded:
                yield encoded
                # Brief pause after each emitted chunk.
                await asyncio.sleep(0.01)
| # --- WAV Generation --- | |
def generate_wav_audio(text: str, voice: str, speed: float) -> bytes:
    """Generate WAV audio without streaming.

    Delegates to ``_generate_audio_chunk`` with ``format="wav"`` instead of
    repeating the synthesis, PCM conversion, speed adjustment and export
    steps, so both code paths share a single implementation.

    Args:
        text: Text to synthesise (processed in one piece, not chunked).
        voice: Voice identifier passed to the model.
        speed: Playback speed multiplier.

    Returns:
        The complete WAV file as bytes.

    Raises:
        RuntimeError: With the message "Audio generation failed" when the
            model is not ready or synthesis/encoding fails. The underlying
            error is logged by ``_generate_audio_chunk``.
    """
    wav_bytes = _generate_audio_chunk(text=text, voice=voice, speed=speed, format="wav")
    if wav_bytes is None:
        logger.error("WAV generation error: synthesis returned no audio")
        raise RuntimeError("Audio generation failed")
    return wav_bytes
| # --- API Endpoints --- | |
# NOTE(review): no route registration for this handler is visible in the file,
# so it is registered here. The path is an assumption based on the request
# shape (input / model / voice / speed / response_format) - confirm it.
@app.post("/v1/audio/speech")
async def generate_speech(speech_request: SpeechRequest):
    """Generate speech audio with streaming support."""
    if speech_request.voice not in Config.VOICES:
        raise HTTPException(
            status_code=400,
            detail=f"Voice must be one of {Config.VOICES}"
        )

    if not app_state.model_ready:
        raise HTTPException(
            status_code=503,
            detail="Service temporarily unavailable."
        )

    logger.info(f"π― TTS Request: voice={speech_request.voice}, speed={speech_request.speed}, format={speech_request.response_format}")

    # MP3 is streamed chunk by chunk, but only when FFmpeg can encode it.
    # Without FFmpeg the chunk generator emits WAV data, which must not be
    # sent under an audio/mpeg content type with an .mp3 file name.
    if speech_request.response_format == "mp3" and ffmpeg_available:
        return StreamingResponse(
            audio_stream_generator(
                text=speech_request.input,
                voice=speech_request.voice,
                speed=speech_request.speed,
                format="mp3"
            ),
            media_type="audio/mpeg",
            headers={"Content-Disposition": "attachment; filename=speech.mp3"}
        )

    if speech_request.response_format == "mp3":
        logger.warning("MP3 requested but FFmpeg is unavailable; returning WAV instead")

    # WAV (requested, or the fallback above) is generated in one piece in a
    # worker thread. Only this call is guarded: the validation errors above
    # must reach the client with their own status codes.
    try:
        wav_data = await asyncio.to_thread(
            generate_wav_audio,
            speech_request.input,
            speech_request.voice,
            speech_request.speed
        )
    except Exception as e:
        # logger.exception already records the traceback.
        logger.exception(f"β Endpoint error: {e}")
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}") from e

    return StreamingResponse(
        io.BytesIO(wav_data),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=speech.wav"}
    )
# NOTE(review): no route registration for this handler is visible in the file,
# so it is registered here; the path is an assumption - confirm it.
@app.get("/v1/audio/voices")
async def list_voices():
    """List available voices."""
    return {"voices": Config.VOICES}
# NOTE(review): no route registration for this handler is visible in the file,
# so it is registered here; the path is an assumption - confirm it.
@app.get("/health")
async def health_check() -> HealthResponse:
    """Health check endpoint."""
    return HealthResponse(
        status="healthy" if app_state.model_ready else "unhealthy",
        model_ready=app_state.model_ready,
        voices_available=len(Config.VOICES),
        # Read from the application object so the reported version cannot
        # drift from the one declared when `app` was created.
        version=app.version,
        ffmpeg_available=ffmpeg_available
    )
# Script entry point: serve `app` with uvicorn when the file is run directly.
if __name__ == "__main__":
    server_options = {
        "host": "0.0.0.0",  # listen on all interfaces
        "port": 7860,
        "workers": 1,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)