from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import base64
import io
from typing import Dict
from pydantic import BaseModel
import numpy as np
import re
import logging
from pathlib import Path
import time
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import os
from TTS.api import TTS

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TTSRequest(BaseModel):
    text: str
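
# Example request body accepted by the synthesis endpoint below (illustrative):
#   {"text": "Stay hard. Carry the boats."}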


class OptimizedTTSService:
    def __init__(self):
        logger.info("Initializing Optimized TTS Service...")
        try:
            # Set TTS home directory and accept license
            os.environ["HOME"] = "/tmp/home"
            os.environ["TTS_HOME"] = "/tmp/tts_home"
            os.environ["COQUI_TOS_AGREED"] = "1"  # Accept TTS license

            # Set number of threads for PyTorch
            n_threads = max(2, multiprocessing.cpu_count() - 1)
            torch.set_num_threads(n_threads)
            logger.info(f"Using {n_threads} CPU threads")

            # Initialize TTS with error handling
            try:
                model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
                logger.info(f"Loading TTS model: {model_name}")
                self.tts = TTS(model_name)
                logger.info("TTS model loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load TTS model: {str(e)}")
                raise

            # Load voice latents
            try:
                logger.info("Loading voice latents...")
                latents_path = "models/goggins_latents.pt"
                if not os.path.exists(latents_path):
                    raise FileNotFoundError(f"Latents file not found at {latents_path}")
                self.latents = torch.load(latents_path, map_location="cpu")
                logger.info("Latents loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load latents: {str(e)}")
                raise

            # Initialize thread pool for parallel processing
            self.executor = ThreadPoolExecutor(max_workers=n_threads)

            # Configure model for inference
            self.model = self.tts.synthesizer.tts_model
            self.model.eval()

            # Initialize device
            self.device = torch.device("cpu")
            logger.info(f"Using device: {self.device}")

            # Initialize cache
            self._setup_cache()
            logger.info("Service initialization complete!")
        except Exception as e:
            logger.error(f"Failed to initialize TTS service: {str(e)}")
            raise

    def _setup_cache(self):
        """Setup caching mechanisms with error handling"""
        try:
            self.chunk_cache = {}
            self.max_cache_size = 1000
            self.cache_dir = Path("tts_cache")
            self.cache_dir.mkdir(exist_ok=True)
            logger.info("Cache setup complete")
        except Exception as e:
            logger.error(f"Failed to setup cache: {str(e)}")
            raise
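
    # Note: chunk_cache above is scaffolding that _process_chunk does not consult
    # yet. A minimal wiring sketch (an assumption, not original behavior):
    #
    #   key = hash((chunk, "en"))
    #   if key in self.chunk_cache:
    #       return self.chunk_cache[key]
    #   wav = ...generate as in _process_chunk...
    #   if len(self.chunk_cache) < self.max_cache_size:
    #       self.chunk_cache[key] = wav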

    def _process_chunk(self, chunk: str) -> np.ndarray:
        """Process a single chunk of text with improved error handling"""
        try:
            # Convert stored latents to tensors on the target device
            speaker_embedding = torch.tensor(
                self.latents["speaker_embedding"],
                dtype=torch.float32,
                device=self.device,
            )
            gpt_cond_latent = torch.tensor(
                self.latents["gpt_cond_latent"], dtype=torch.float32, device=self.device
            )

            # Get optimized parameters based on chunk length
            params = self._get_params_for_length(len(chunk))

            # Generate speech
            with torch.no_grad():
                wav = self.model.inference(
                    text=chunk,
                    language="en",
                    gpt_cond_latent=gpt_cond_latent,
                    speaker_embedding=speaker_embedding,
                    **params,
                )
            return wav["wav"]
        except Exception as e:
            logger.error(f"Error processing chunk '{chunk[:50]}...': {str(e)}")
            raise

    def _get_params_for_length(self, chunk_length: int) -> Dict:
        """Get optimized parameters based on text length"""
        if chunk_length <= 80:
            return {
                "temperature": 0.75,
                "length_penalty": 0.8,
                "repetition_penalty": 1.8,
                "top_k": 40,
                "top_p": 0.80,
            }
        elif chunk_length <= 150:
            return {
                "temperature": 0.85,
                "length_penalty": 1.0,
                "repetition_penalty": 2.0,
                "top_k": 50,
                "top_p": 0.85,
            }
        else:
            return {
                "temperature": 0.95,
                "length_penalty": 1.2,
                "repetition_penalty": 2.2,
                "top_k": 60,
                "top_p": 0.90,
            }
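
    # The tiers above scale sampling looseness with input size: longer chunks get
    # higher temperature and stronger repetition penalties, a common way to keep
    # autoregressive TTS models from looping on long inputs.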

    def generate_speech(self, text: str) -> np.ndarray:
        """Generate speech with improved error handling"""
        try:
            # Clean and validate input
            if not text or not text.strip():
                raise ValueError("Empty text input")
            text = text.strip()
            if len(text) > 1000:  # Reasonable limit
                raise ValueError("Text too long (max 1000 characters)")

            # Process short text as a single chunk
            if len(text) <= 150:
                return self._process_chunk(text)

            # Split longer text into sentence-sized chunks, keeping end punctuation
            chunks = [c.strip() for c in re.split(r"(?<=[.!?])\s+", text) if c.strip()]

            # Process chunks sequentially
            wavs = []
            for i, chunk in enumerate(chunks, 1):
                logger.info(f"Processing chunk {i}/{len(chunks)}: {chunk[:50]}...")
                wavs.append(self._process_chunk(chunk))

            # Concatenate results into one waveform
            return np.concatenate(wavs)
        except Exception as e:
            logger.error(f"Error in generate_speech: {str(e)}")
            raise
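

# Sketch of how a latents file like models/goggins_latents.pt can be produced.
# This assumes the file was exported with XTTS's get_conditioning_latents API;
# the reference audio path is hypothetical and this helper is not called by the
# service itself.
def export_voice_latents(
    reference_wav: str, out_path: str = "models/goggins_latents.pt"
) -> None:
    """One-off helper: derive XTTS conditioning latents from a reference clip."""
    xtts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2").synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
        audio_path=[reference_wav]
    )
    torch.save(
        {"gpt_cond_latent": gpt_cond_latent, "speaker_embedding": speaker_embedding},
        out_path,
    )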


# Initialize FastAPI app
app = FastAPI(title="Goggins TTS API")

# Add CORS middleware (fully open so any browser front end can call the API;
# typical for a public demo, but worth restricting in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Service instance, created at startup
service = None


@app.on_event("startup")
async def startup_event():
    global service
    try:
        service = OptimizedTTSService()
    except Exception as e:
        logger.error(f"Failed to initialize service: {str(e)}")
        raise


@app.post("/tts")  # route path assumed
async def generate_speech(request: TTSRequest):
    """Generate speech from text with detailed timing"""
    if not service:
        raise HTTPException(status_code=503, detail="Service not initialized")
    try:
        total_start = time.time()
        logger.info(f"\nReceived request for text: {request.text[:50]}...")

        # Model processing time
        model_start = time.time()
        wav = service.generate_speech(request.text)
        model_time = time.time() - model_start

        # Audio conversion time
        conversion_start = time.time()
        buffer = io.BytesIO()
        np.save(buffer, wav.astype(np.float32))
        audio_base64 = base64.b64encode(buffer.getvalue()).decode()
        conversion_time = time.time() - conversion_start

        # Total processing time
        total_time = time.time() - total_start
        timing_info = {
            "total_processing_time": round(total_time, 2),
            "model_processing_time": round(model_time, 2),
            "audio_conversion_time": round(conversion_time, 2),
        }
        logger.info(f"Timing breakdown: {timing_info}")

        # Return audio payload and timing breakdown
        return {"status": "success", "audio": audio_base64, "timing": timing_info}
    except ValueError as e:
        # Input validation failures are client errors, not server errors
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error in generate_speech endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
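

# Client-side sketch for decoding the response above. The server np.save()s the
# float32 waveform into a buffer and base64-encodes it, so a client reverses both
# steps. The URL and route are assumptions for illustration; XTTS v2 outputs
# audio at 24 kHz.
#
#   import base64, io
#   import numpy as np
#   import requests
#
#   resp = requests.post("http://localhost:7860/tts", json={"text": "Stay hard."})
#   payload = resp.json()
#   wav = np.load(io.BytesIO(base64.b64decode(payload["audio"])))  # float32 array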


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    if not service:
        raise HTTPException(status_code=503, detail="Service not initialized")
    return {"status": "healthy"}