""" | |
Main TTS Pipeline | |
================= | |
Orchestrates the complete TTS pipeline with optimization and error handling. | |
""" | |
import logging | |
import time | |
from typing import Tuple, List, Optional, Dict, Any | |
import numpy as np | |
from .preprocessing import TextProcessor | |
from .model import OptimizedTTSModel | |
from .audio_processing import AudioProcessor | |
# Configure logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | |
) | |
logger = logging.getLogger(__name__) | |


class TTSPipeline:
    """
    High-performance TTS pipeline with advanced optimization features.

    This pipeline combines:
    - Intelligent text preprocessing and chunking
    - Optimized model inference with caching
    - Advanced audio post-processing
    - Comprehensive error handling and logging
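
    Example (a minimal sketch; assumes the default checkpoint and the
    sibling ``preprocessing``/``model``/``audio_processing`` modules are
    available)::

        pipeline = TTSPipeline()
        sample_rate, audio = pipeline.synthesize("Բարև ձեզ")  # "Hello" in Armenian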
""" | |

    def __init__(self,
                 model_checkpoint: str = "Edmon02/TTS_NB_2",
                 max_chunk_length: int = 200,
                 crossfade_duration: float = 0.1,
                 use_mixed_precision: bool = True,
                 device: Optional[str] = None):
        """
        Initialize the TTS pipeline.

        Args:
            model_checkpoint: Path or hub ID of the TTS model checkpoint
            max_chunk_length: Maximum number of characters per text chunk
            crossfade_duration: Crossfade duration between audio chunks, in seconds
            use_mixed_precision: Whether to use mixed-precision inference
            device: Device to use for computation (e.g. "cuda" or "cpu")
        """
        self.model_checkpoint = model_checkpoint
        self.max_chunk_length = max_chunk_length
        self.crossfade_duration = crossfade_duration

        logger.info("Initializing TTS Pipeline...")

        # Initialize components
        self.text_processor = TextProcessor(max_chunk_length=max_chunk_length)
        self.model = OptimizedTTSModel(
            checkpoint=model_checkpoint,
            use_mixed_precision=use_mixed_precision,
            device=device
        )
        self.audio_processor = AudioProcessor(crossfade_duration=crossfade_duration)

        # Performance tracking
        self.total_inferences = 0
        self.total_processing_time = 0.0

        # Warm up the model
        self._warmup()

        logger.info("TTS Pipeline initialized successfully")

    def _warmup(self):
        """Warm up the pipeline with a test inference."""
        try:
            logger.info("Warming up TTS pipeline...")
            test_text = "Բարև ձեզ"  # Armenian for "hello"
            _ = self.synthesize(test_text, log_performance=False)
            logger.info("Pipeline warmup completed")
        except Exception as e:
            logger.warning(f"Pipeline warmup failed: {e}")

    def synthesize(self,
                   text: str,
                   speaker: str = "BDL",
                   enable_chunking: bool = True,
                   apply_audio_processing: bool = True,
                   log_performance: bool = True) -> Tuple[int, np.ndarray]:
        """
        Main synthesis function with automatic optimization.

        Args:
            text: Input text to synthesize
            speaker: Speaker identifier
            enable_chunking: Whether to use intelligent chunking for long texts
            apply_audio_processing: Whether to apply audio post-processing
            log_performance: Whether to log performance metrics

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        start_time = time.time()

        try:
            # Validate input
            if not text or not text.strip():
                logger.warning("Empty or invalid text provided")
                return 16000, np.zeros(0, dtype=np.int16)

            # Determine if chunking is needed
            should_chunk = enable_chunking and len(text) > self.max_chunk_length

            if should_chunk:
                logger.info(f"Processing long text ({len(text)} chars) with chunking")
                sample_rate, audio = self._synthesize_with_chunking(
                    text, speaker, apply_audio_processing
                )
            else:
                logger.debug(f"Processing short text ({len(text)} chars) directly")
                sample_rate, audio = self._synthesize_direct(
                    text, speaker, apply_audio_processing
                )

            # Track performance
            total_time = time.time() - start_time
            self.total_inferences += 1
            self.total_processing_time += total_time

            if log_performance:
                audio_duration = len(audio) / sample_rate if len(audio) > 0 else 0
                # Real-time factor: processing time divided by audio duration
                # (values below 1.0 mean faster-than-real-time synthesis)
                rtf = total_time / audio_duration if audio_duration > 0 else float('inf')
                logger.info(
                    f"Synthesis completed: {len(text)} chars → "
                    f"{audio_duration:.2f}s audio in {total_time:.3f}s "
                    f"(RTF: {rtf:.2f})"
                )

            return sample_rate, audio

        except Exception as e:
            logger.error(f"Synthesis failed: {e}")
            return 16000, np.zeros(0, dtype=np.int16)

    def _synthesize_direct(self,
                           text: str,
                           speaker: str,
                           apply_audio_processing: bool) -> Tuple[int, np.ndarray]:
        """
        Direct synthesis for short texts.

        Args:
            text: Input text
            speaker: Speaker identifier
            apply_audio_processing: Whether to apply post-processing

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        # Process text
        processed_text = self.text_processor.process_text(text)

        # Generate speech
        sample_rate, audio = self.model.generate_speech(processed_text, speaker)

        # Apply audio processing if requested
        if apply_audio_processing and len(audio) > 0:
            audio = self.audio_processor.process_audio(audio)
            audio = self.audio_processor.add_silence(audio)

        return sample_rate, audio

    def _synthesize_with_chunking(self,
                                  text: str,
                                  speaker: str,
                                  apply_audio_processing: bool) -> Tuple[int, np.ndarray]:
        """
        Synthesis with intelligent chunking for long texts.

        Args:
            text: Input text
            speaker: Speaker identifier
            apply_audio_processing: Whether to apply post-processing

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        # Process and chunk text
        chunks = self.text_processor.process_chunks(text)

        if not chunks:
            logger.warning("No valid chunks generated")
            return 16000, np.zeros(0, dtype=np.int16)

        # Generate speech for all chunks
        sample_rate, audio = self.model.generate_speech_chunks(chunks, speaker)

        # Apply audio processing if requested
        if apply_audio_processing and len(audio) > 0:
            audio = self.audio_processor.process_audio(audio)
            audio = self.audio_processor.add_silence(audio)

        return sample_rate, audio

    def batch_synthesize(self,
                         texts: List[str],
                         speaker: str = "BDL",
                         enable_chunking: bool = True) -> List[Tuple[int, np.ndarray]]:
        """
        Batch synthesis for multiple texts.

        Args:
            texts: List of input texts
            speaker: Speaker identifier
            enable_chunking: Whether to use chunking

        Returns:
            List of (sample_rate, audio_array) tuples
        """
        logger.info(f"Starting batch synthesis for {len(texts)} texts")

        results = []
        for i, text in enumerate(texts):
            logger.debug(f"Processing batch item {i+1}/{len(texts)}")
            result = self.synthesize(
                text,
                speaker,
                enable_chunking=enable_chunking,
                log_performance=False
            )
            results.append(result)

        logger.info(f"Batch synthesis completed: {len(results)} items processed")
        return results

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get comprehensive performance statistics."""
        stats = {
            "pipeline_stats": {
                "total_inferences": self.total_inferences,
                "total_processing_time": self.total_processing_time,
                "avg_processing_time": (
                    self.total_processing_time / self.total_inferences
                    if self.total_inferences > 0 else 0
                )
            },
            "text_processor_stats": self.text_processor.get_cache_stats(),
            "model_stats": self.model.get_performance_stats(),
        }
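        # Illustrative shape of the returned dict; the nested
        # "text_processor_stats" and "model_stats" entries come from the
        # respective components, so their exact keys depend on those
        # implementations:
        # {
        #     "pipeline_stats": {"total_inferences": 3, "total_processing_time": 1.2, ...},
        #     "text_processor_stats": {...},
        #     "model_stats": {...},
        # }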
        return stats

    def clear_caches(self):
        """Clear all caches to free memory."""
        self.text_processor.clear_cache()
        self.model.clear_performance_cache()
        logger.info("All caches cleared")

    def get_available_speakers(self) -> List[str]:
        """Get list of available speakers."""
        return self.model.get_available_speakers()

    def optimize_for_production(self):
        """Apply production-level optimizations."""
        logger.info("Applying production optimizations...")

        try:
            # Optimize model
            self.model.optimize_for_inference()

            # Clear any unnecessary caches
            self.clear_caches()

            logger.info("Production optimizations applied")
        except Exception as e:
            logger.warning(f"Some optimizations failed: {e}")

    def health_check(self) -> Dict[str, Any]:
        """
        Perform a health check of the pipeline.

        Returns:
            Dict with an overall "status" ("healthy", "degraded", or "error"),
            per-component results under "components", and a "timestamp"
        """
        health_status = {
            "status": "healthy",
            "components": {},
            "timestamp": time.time()
        }

        try:
            # Test text processor
            test_text = "Թեստ տեքստ"  # Armenian for "test text"
            processed = self.text_processor.process_text(test_text)
            health_status["components"]["text_processor"] = {
                "status": "ok" if processed else "error",
                "test_result": bool(processed)
            }

            # Test model
            try:
                _, audio = self.model.generate_speech("Բարև")  # Armenian for "hello"
                health_status["components"]["model"] = {
                    "status": "ok" if len(audio) > 0 else "error",
                    "test_audio_samples": len(audio)
                }
            except Exception as e:
                health_status["components"]["model"] = {
                    "status": "error",
                    "error": str(e)
                }

            # Check if any component failed
            if any(comp.get("status") == "error"
                   for comp in health_status["components"].values()):
                health_status["status"] = "degraded"

        except Exception as e:
            health_status["status"] = "error"
            health_status["error"] = str(e)

        return health_status
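

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: this module is run as part of
    # its package (the relative imports above must resolve, e.g. via
    # `python -m <package>.pipeline`) and the default "Edmon02/TTS_NB_2"
    # checkpoint is reachable; the sample sentence reuses the Armenian
    # warmup text from above.
    pipeline = TTSPipeline()

    sample_rate, audio = pipeline.synthesize("Բարև ձեզ")  # "Hello" in Armenian
    duration = len(audio) / sample_rate if len(audio) > 0 else 0.0
    print(f"Generated {duration:.2f}s of audio at {sample_rate} Hz")

    # Inspect pipeline health and accumulated performance counters.
    print(pipeline.health_check())
    print(pipeline.get_performance_stats())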