# goggins-chat / app.py
# Author: shriniket73 -- commit c6bc3fd ("Update app.py", verified), 8.46 kB
import asyncio
import base64
import io
import logging
import multiprocessing
import os
import re
import textwrap
import time
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from pathlib import Path
from typing import Dict

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from TTS.api import TTS
from TTS.utils.manage import ModelManager
# Configure logging: the root logger emits INFO-and-above records (basicConfig's
# default handler writes to stderr), and this module logs through its own
# named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
class TTSRequest(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    # The text to synthesize; length/emptiness checks happen in the service.
    text: str
class OptimizedTTSService:
    """CPU-only XTTS v2 speech synthesizer using pre-computed voice latents.

    The voice is defined by the conditioning latents stored in
    ``models/goggins_latents.pt``. They are converted to tensors once during
    start-up and reused for every synthesis call.
    """

    # Longest piece of text handed to the model in a single inference call.
    # NOTE(review): Coqui's XTTS tokenizer warns (and may truncate audio) above
    # roughly 250 characters for English -- confirm for the installed TTS version.
    MAX_CHUNK_CHARS = 250

    def __init__(self):
        """Load the XTTS model and the voice latents, and configure CPU inference.

        Raises:
            FileNotFoundError: if ``models/goggins_latents.pt`` does not exist.
            KeyError: if the latents file lacks ``gpt_cond_latent`` or
                ``speaker_embedding``.
            Exception: any other model/latents loading error is logged and
                re-raised.
        """
        logger.info("Initializing Optimized TTS Service...")
        try:
            # Set TTS home directory and accept license
            os.environ["HOME"] = "/tmp/home"
            os.environ["TTS_HOME"] = "/tmp/tts_home"
            os.environ["COQUI_TOS_AGREED"] = "1"  # Accept TTS license

            # One core fewer than available (presumably left for the web
            # server), but never fewer than two threads.
            n_threads = max(2, multiprocessing.cpu_count() - 1)
            torch.set_num_threads(n_threads)
            logger.info(f"Using {n_threads} CPU threads")

            # Defined before the latents are loaded because they are placed on it.
            self.device = torch.device("cpu")
            logger.info(f"Using device: {self.device}")

            try:
                model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
                logger.info(f"Loading TTS model: {model_name}")
                self.tts = TTS(model_name)
                logger.info("TTS model loaded successfully")
            except Exception:
                logger.exception("Failed to load TTS model")
                raise

            try:
                logger.info("Loading voice latents...")
                latents_path = "models/goggins_latents.pt"
                if not os.path.exists(latents_path):
                    raise FileNotFoundError(f"Latents file not found at {latents_path}")
                self.latents = torch.load(latents_path, map_location="cpu")
                # Convert once here instead of on every chunk. as_tensor avoids
                # the copy (and UserWarning) torch.tensor() makes for tensor input.
                self.gpt_cond_latent = torch.as_tensor(
                    self.latents["gpt_cond_latent"],
                    dtype=torch.float32,
                    device=self.device,
                )
                self.speaker_embedding = torch.as_tensor(
                    self.latents["speaker_embedding"],
                    dtype=torch.float32,
                    device=self.device,
                )
                logger.info("Latents loaded successfully")
            except Exception:
                logger.exception("Failed to load latents")
                raise

            # NOTE(review): no method of this class submits work to this pool.
            self.executor = ThreadPoolExecutor(max_workers=n_threads)

            # Underlying XTTS model; eval() disables training-only behaviour.
            self.model = self.tts.synthesizer.tts_model
            self.model.eval()

            self._setup_cache()
            logger.info("Service initialization complete!")
        except Exception as e:
            logger.error(f"Failed to initialize TTS service: {e}")
            raise

    def _setup_cache(self):
        """Create the in-memory chunk cache fields and the on-disk cache directory.

        NOTE(review): ``chunk_cache``, ``max_cache_size`` and ``cache_dir`` are
        not read anywhere else in this class.
        """
        try:
            self.chunk_cache = {}
            self.max_cache_size = 1000
            self.cache_dir = Path("tts_cache")  # relative to the working directory
            self.cache_dir.mkdir(exist_ok=True)
            logger.info("Cache setup complete")
        except Exception:
            logger.exception("Failed to setup cache")
            raise

    def _process_chunk(self, chunk: str) -> np.ndarray:
        """Synthesize one chunk of English text with the stored voice latents.

        Args:
            chunk: Text to speak in a single model call.

        Returns:
            The ``"wav"`` entry of the model's inference output.
        """
        try:
            params = self._get_params_for_length(len(chunk))
            with torch.no_grad():
                output = self.model.inference(
                    text=chunk,
                    language="en",
                    gpt_cond_latent=self.gpt_cond_latent,
                    speaker_embedding=self.speaker_embedding,
                    **params,
                )
            return output["wav"]
        except Exception:
            logger.exception(f"Error processing chunk '{chunk[:50]}...'")
            raise

    def _get_params_for_length(self, chunk_length: int) -> Dict:
        """Return sampling parameters for a chunk of ``chunk_length`` characters.

        Longer chunks get progressively higher temperature, length/repetition
        penalties, top_k and top_p (tiers: <=80, <=150, >150 characters).
        """
        if chunk_length <= 80:
            return {
                "temperature": 0.75,
                "length_penalty": 0.8,
                "repetition_penalty": 1.8,
                "top_k": 40,
                "top_p": 0.80,
            }
        if chunk_length <= 150:
            return {
                "temperature": 0.85,
                "length_penalty": 1.0,
                "repetition_penalty": 2.0,
                "top_k": 50,
                "top_p": 0.85,
            }
        return {
            "temperature": 0.95,
            "length_penalty": 1.2,
            "repetition_penalty": 2.2,
            "top_k": 60,
            "top_p": 0.90,
        }

    @staticmethod
    def _split_into_chunks(text: str, max_chars: int = MAX_CHUNK_CHARS) -> list:
        """Split ``text`` into sentence chunks of at most ``max_chars`` characters.

        Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace.
        A terminal ``.`` is added only when a sentence has no closing
        punctuation, which avoids results like ``"end.."`` or ``"hard!."``.
        Sentences longer than ``max_chars`` are word-wrapped into smaller
        pieces so no single model call receives over-long text.
        """
        chunks = []
        for sentence in re.split(r"(?<=[.!?])\s+", text):
            sentence = sentence.strip()
            if not sentence:
                continue
            if sentence[-1] not in ".!?":
                sentence += "."
            if len(sentence) <= max_chars:
                chunks.append(sentence)
            else:
                chunks.extend(textwrap.wrap(sentence, width=max_chars))
        return chunks

    def generate_speech(self, text: str) -> np.ndarray:
        """Synthesize ``text`` (at most 1000 characters after stripping).

        Text of 150 characters or fewer is synthesized in one call. Longer text
        is split into sentence chunks whose waveforms are concatenated in order.

        Raises:
            ValueError: if ``text`` is empty/whitespace or longer than 1000
                characters after stripping.
        """
        try:
            if not text or not text.strip():
                raise ValueError("Empty text input")
            text = text.strip()
            if len(text) > 1000:  # Add reasonable limit
                raise ValueError("Text too long (max 1000 characters)")

            if len(text) <= 150:
                return self._process_chunk(text)

            chunks = self._split_into_chunks(text)
            wavs = []
            for i, chunk in enumerate(chunks, 1):
                logger.info(f"Processing chunk {i}/{len(chunks)}: {chunk[:50]}...")
                wavs.append(self._process_chunk(chunk))
            return np.concatenate(wavs)
        except Exception:
            logger.exception("Error in generate_speech")
            raise
# The ASGI application served by this module.
app = FastAPI(title="Goggins TTS API")

# Cross-origin policy: every origin, method and header is permitted.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site make credentialed requests -- confirm against Starlette's
# CORSMiddleware behaviour and restrict origins if auth/cookies are added.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Shared TTS service instance; stays None until the startup hook succeeds.
service = None


@app.on_event("startup")
async def startup_event():
    """Construct the TTS service once, when the application starts."""
    global service
    try:
        instance = OptimizedTTSService()
    except Exception as exc:
        logger.error(f"Failed to initialize service: {exc}")
        raise
    service = instance
# Dedicated single-worker pool for synthesis. Running inference here keeps the
# event loop free (e.g. /health stays responsive) while still processing one
# request at a time, as happened when inference ran inside the coroutine itself.
_inference_executor = ThreadPoolExecutor(max_workers=1)


@app.post("/generate")
async def generate_speech(request: TTSRequest):
    """Synthesize ``request.text`` and return base64 audio plus a timing breakdown.

    The waveform is cast to float32, serialized with ``np.save`` (an ``.npy``
    payload) and base64-encoded into the ``"audio"`` field.

    Raises:
        HTTPException: 503 if the service is not initialized; 400 when the
            service rejects the input (ValueError: empty or over-long text);
            500 for any other failure.
    """
    if service is None:
        raise HTTPException(status_code=503, detail="Service not initialized")
    try:
        total_start = time.perf_counter()
        logger.info(f"\nReceived request for text: {request.text[:50]}...")

        # Model processing time (includes any wait for an earlier request).
        model_start = time.perf_counter()
        loop = asyncio.get_running_loop()
        wav = await loop.run_in_executor(
            _inference_executor, service.generate_speech, request.text
        )
        model_time = time.perf_counter() - model_start

        # Audio conversion time
        conversion_start = time.perf_counter()
        buffer = io.BytesIO()
        np.save(buffer, wav.astype(np.float32))
        audio_base64 = base64.b64encode(buffer.getvalue()).decode()
        conversion_time = time.perf_counter() - conversion_start

        total_time = time.perf_counter() - total_start
        timing_info = {
            "total_processing_time": round(total_time, 2),
            "model_processing_time": round(model_time, 2),
            "audio_conversion_time": round(conversion_time, 2),
        }
        logger.info(f"Timing breakdown: {timing_info}")
        return {"status": "success", "audio": audio_base64, "timing": timing_info}
    except ValueError as e:
        # The service raises ValueError for empty or over-long text: a client
        # error, not a server failure.
        logger.warning(f"Rejected generate_speech request: {e}")
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"Error in generate_speech endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Liveness probe: healthy once the startup hook has created the service.

    Raises:
        HTTPException: 503 while the service is still None.
    """
    if service is not None:
        return {"status": "healthy"}
    raise HTTPException(status_code=503, detail="Service not initialized")