from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import base64
import io
from typing import Dict
from pydantic import BaseModel
import numpy as np
import re
import logging
from pathlib import Path
import time
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import os

from TTS.api import TTS

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TTSRequest(BaseModel):
    text: str


class OptimizedTTSService:
    def __init__(self):
        logger.info("Initializing Optimized TTS Service...")
        try:
            # Set TTS home directory and accept the Coqui license
            os.environ["HOME"] = "/tmp/home"
            os.environ["TTS_HOME"] = "/tmp/tts_home"
            os.environ["COQUI_TOS_AGREED"] = "1"  # Accept TTS license

            # Set number of threads for PyTorch, leaving one core free
            n_threads = max(2, multiprocessing.cpu_count() - 1)
            torch.set_num_threads(n_threads)
            logger.info(f"Using {n_threads} CPU threads")

            # Initialize TTS with error handling
            try:
                model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
                logger.info(f"Loading TTS model: {model_name}")
                self.tts = TTS(model_name)
                logger.info("TTS model loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load TTS model: {str(e)}")
                raise

            # Load pre-computed speaker latents
            try:
                logger.info("Loading voice latents...")
                latents_path = "models/goggins_latents.pt"
                if not os.path.exists(latents_path):
                    raise FileNotFoundError(f"Latents file not found at {latents_path}")
                self.latents = torch.load(latents_path, map_location="cpu")
                logger.info("Latents loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load latents: {str(e)}")
                raise

            # Initialize thread pool for parallel processing
            self.executor = ThreadPoolExecutor(max_workers=n_threads)

            # Configure model for inference
            self.model = self.tts.synthesizer.tts_model
            self.model.eval()

            # Initialize device
            self.device = torch.device("cpu")
            logger.info(f"Using device: {self.device}")

            # Initialize cache
            self._setup_cache()

            logger.info("Service initialization complete!")
        except Exception as e:
            logger.error(f"Failed to initialize TTS service: {str(e)}")
            raise

    def _setup_cache(self):
        """Set up caching mechanisms with error handling."""
        try:
            self.chunk_cache = {}
            self.max_cache_size = 1000
            self.cache_dir = Path("tts_cache")
            self.cache_dir.mkdir(exist_ok=True)
            logger.info("Cache setup complete")
        except Exception as e:
            logger.error(f"Failed to setup cache: {str(e)}")
            raise

    def _process_chunk(self, chunk: str) -> np.ndarray:
        """Process a single chunk of text with improved error handling."""
        try:
            # Convert the stored latents to tensors on the target device
            speaker_embedding = torch.tensor(
                self.latents["speaker_embedding"],
                dtype=torch.float32,
                device=self.device,
            )
            gpt_cond_latent = torch.tensor(
                self.latents["gpt_cond_latent"],
                dtype=torch.float32,
                device=self.device,
            )

            # Get sampling parameters tuned to the chunk length
            params = self._get_params_for_length(len(chunk))

            # Generate speech
            with torch.no_grad():
                wav = self.model.inference(
                    text=chunk,
                    language="en",
                    gpt_cond_latent=gpt_cond_latent,
                    speaker_embedding=speaker_embedding,
                    **params,
                )
            return wav["wav"]
        except Exception as e:
            logger.error(f"Error processing chunk '{chunk[:50]}...': {str(e)}")
            raise

    def _get_params_for_length(self, chunk_length: int) -> Dict:
        """Get sampling parameters tuned to the text length."""
        if chunk_length <= 80:
            return {
                "temperature": 0.75,
                "length_penalty": 0.8,
                "repetition_penalty": 1.8,
                "top_k": 40,
                "top_p": 0.80,
            }
        elif chunk_length <= 150:
            return {
                "temperature": 0.85,
                "length_penalty": 1.0,
                "repetition_penalty": 2.0,
                "top_k": 50,
                "top_p": 0.85,
            }
        else:
            return {
                "temperature": 0.95,
                "length_penalty": 1.2,
                "repetition_penalty": 2.2,
                "top_k": 60,
                "top_p": 0.90,
            }

    def generate_speech(self, text: str) -> np.ndarray:
        """Generate speech with improved error handling."""
        try:
            # Clean and validate input
            if not text or not text.strip():
                raise ValueError("Empty text input")

            text = text.strip()
            if len(text) > 1000:  # Guard against overly long requests
                raise ValueError("Text too long (max 1000 characters)")

            # Process short text as a single chunk
            if len(text) <= 150:
                return self._process_chunk(text)

            # Split longer text at sentence boundaries; the lookbehind keeps
            # the terminal punctuation attached to each chunk, so nothing is
            # re-appended (the old split-on-". " approach doubled the final
            # period)
            chunks = [c.strip() for c in re.split(r"(?<=[.!?])\s+", text) if c.strip()]

            # Process chunks sequentially
            wavs = []
            for i, chunk in enumerate(chunks, 1):
                logger.info(f"Processing chunk {i}/{len(chunks)}: {chunk[:50]}...")
                wavs.append(self._process_chunk(chunk))

            # Concatenate the per-chunk waveforms into one signal
            return np.concatenate(wavs)
        except Exception as e:
            logger.error(f"Error in generate_speech: {str(e)}")
            raise


# Initialize FastAPI app
app = FastAPI(title="Goggins TTS API")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize service
service = None


@app.on_event("startup")
async def startup_event():
    global service
    try:
        service = OptimizedTTSService()
    except Exception as e:
        logger.error(f"Failed to initialize service: {str(e)}")
        raise


@app.post("/generate")
async def generate_speech(request: TTSRequest):
    """Generate speech from text with detailed timing."""
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")
    try:
        total_start = time.time()
        logger.info(f"Received request for text: {request.text[:50]}...")

        # Model processing time
        model_start = time.time()
        wav = service.generate_speech(request.text)
        model_time = time.time() - model_start

        # Audio conversion time: serialize the waveform as .npy bytes, then base64
        conversion_start = time.time()
        buffer = io.BytesIO()
        np.save(buffer, wav.astype(np.float32))
        audio_base64 = base64.b64encode(buffer.getvalue()).decode()
        conversion_time = time.time() - conversion_start

        # Total processing time
        total_time = time.time() - total_start
        timing_info = {
            "total_processing_time": round(total_time, 2),
            "model_processing_time": round(model_time, 2),
            "audio_conversion_time": round(conversion_time, 2),
        }
        logger.info(f"Timing breakdown: {timing_info}")

        return {"status": "success", "audio": audio_base64, "timing": timing_info}
    except Exception as e:
        logger.error(f"Error in generate_speech endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    if not service:
        raise HTTPException(status_code=503, detail="Service not initialized")
    return {"status": "healthy"}
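

# ---------------------------------------------------------------------------
# Example client (illustrative sketch, not part of the service): /generate
# returns the waveform as a base64-encoded NumPy ``.npy`` payload, so a caller
# has to reverse that encoding before it can play or save the audio. This
# helper assumes the third-party ``requests`` package is installed and that
# the server is reachable at ``base_url``; the helper name and URL are
# hypothetical, not defined anywhere above.
def example_fetch_audio(text: str, base_url: str = "http://localhost:8000") -> np.ndarray:
    import requests  # assumed available; used only by this example

    resp = requests.post(f"{base_url}/generate", json={"text": text}, timeout=300)
    resp.raise_for_status()
    payload = resp.json()
    # Reverse the server-side encoding: base64 -> .npy bytes -> float32 array
    return np.load(io.BytesIO(base64.b64decode(payload["audio"])))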
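

# Local entry point, a minimal sketch assuming ``uvicorn`` is installed and
# port 8000 is free; in production you would typically launch the app via
# ``uvicorn <module>:app`` instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)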