Spaces:
Paused
Paused
| """ | |
| STT Lifecycle Manager for Flare | |
| =============================== | |
| Manages STT instances lifecycle per session | |
| """ | |
| import asyncio | |
| from typing import Dict, Optional, Any | |
| from datetime import datetime | |
| import traceback | |
| import base64 | |
| from event_bus import EventBus, Event, EventType, publish_error | |
| from resource_manager import ResourceManager, ResourceType | |
| from stt.stt_factory import STTFactory | |
| from stt.stt_interface import STTInterface, STTConfig, TranscriptionResult | |
| from utils.logger import log_info, log_error, log_debug, log_warning | |
class STTSession:
    """Per-session bookkeeping around a single STT provider instance.

    Holds the streaming flag, the active ``STTConfig`` (if any), naive-UTC
    creation / last-activity timestamps and running audio counters.
    """

    def __init__(self, session_id: str, stt_instance: STTInterface):
        """Bind *stt_instance* to *session_id* with all counters at zero."""
        # Identity of the session and the provider instance it owns.
        self.session_id = session_id
        self.stt_instance = stt_instance

        # Streaming state; ``config`` stays None until the owner assigns one.
        self.config: Optional[STTConfig] = None
        self.is_streaming = False

        # Running totals of audio pushed through this session.
        self.total_chunks = self.total_bytes = 0

        # Naive UTC timestamps.
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()

    def update_activity(self):
        """Stamp ``last_activity`` with the current (naive) UTC time."""
        self.last_activity = datetime.utcnow()
class STTLifecycleManager:
    """Manage the lifecycle of per-session STT instances.

    Subscribes to STT-related events on the event bus, acquires STT
    instances through the resource manager, forwards audio chunks to the
    active stream and publishes transcription results back on the bus.
    """

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        """Store collaborators, subscribe to events and register the pool."""
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        # session_id -> STTSession wrapper for every session that has STT.
        self.stt_sessions: Dict[str, STTSession] = {}
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to STT-related events."""
        self.event_bus.subscribe(EventType.STT_STARTED, self._handle_stt_start)
        self.event_bus.subscribe(EventType.STT_STOPPED, self._handle_stt_stop)
        self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the STT instance pool with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.STT_INSTANCE,
            factory=self._create_stt_instance,
            max_idle=5,
            max_age_seconds=300,  # 5 minutes
        )

    async def _create_stt_instance(self) -> STTInterface:
        """Pool factory: create a new STT provider instance.

        Raises:
            ValueError: if the factory returns a falsy instance.
            Exception: anything raised by ``STTFactory.create_provider`` is
                logged and re-raised.
        """
        try:
            stt_instance = STTFactory.create_provider()
            if not stt_instance:
                raise ValueError("Failed to create STT instance")
            log_debug("🎤 Created new STT instance")
            return stt_instance
        except Exception as e:
            log_error("❌ Failed to create STT instance", error=str(e))
            raise

    def _build_stt_config(self, config_data: Dict[str, Any]) -> STTConfig:
        """Build the STT config for continuous listening from event data."""
        locale = config_data.get("locale", "tr")
        return STTConfig(
            language=self._get_language_code(locale),
            sample_rate=config_data.get("sample_rate", 16000),
            encoding=config_data.get("encoding", "WEBM_OPUS"),  # Try "LINEAR16" if WEBM fails
            enable_punctuation=config_data.get("enable_punctuation", True),
            enable_word_timestamps=False,
            model=config_data.get("model", "latest_long"),
            use_enhanced=config_data.get("use_enhanced", True),
            single_utterance=False,  # must stay False for continuous listening
            interim_results=True,  # interim results are always enabled
        )

    async def _handle_stt_start(self, event: Event):
        """Handle an STT start request.

        Reuses an existing non-streaming session or acquires a new STT
        instance, starts streaming and publishes ``STT_READY``. On any
        failure the session is cleaned up and an ``stt_error`` is published.
        """
        session_id = event.session_id
        # A missing payload means "use all defaults".
        config_data = event.data or {}

        try:
            log_info("🎤 Starting STT", session_id=session_id)

            stt_session = self.stt_sessions.get(session_id)
            if stt_session is not None:
                if stt_session.is_streaming:
                    log_warning("⚠️ STT already streaming", session_id=session_id)
                    return
            else:
                # Acquire STT instance from pool
                stt_instance = await self.resource_manager.acquire(
                    resource_id=f"stt_{session_id}",
                    session_id=session_id,
                    resource_type=ResourceType.STT_INSTANCE,
                    cleanup_callback=self._cleanup_stt_instance,
                )
                # Register the wrapper before streaming starts so the error
                # path below can find and clean it up.
                stt_session = STTSession(session_id, stt_instance)
                self.stt_sessions[session_id] = stt_session

            stt_config = self._build_stt_config(config_data)

            # Log the exact config being used
            log_info(f"📋 STT Config: encoding={stt_config.encoding}, "
                     f"sample_rate={stt_config.sample_rate}, "
                     f"single_utterance={stt_config.single_utterance}, "
                     f"interim_results={stt_config.interim_results}")

            stt_session.config = stt_config

            # Start streaming
            await stt_session.stt_instance.start_streaming(stt_config)
            stt_session.is_streaming = True
            stt_session.update_activity()

            log_info(
                "✅ STT started in continuous mode with interim results",
                session_id=session_id,
                language=stt_config.language,
            )

            # Notify STT is ready
            await self.event_bus.publish(Event(
                type=EventType.STT_READY,
                session_id=session_id,
                data={"language": stt_config.language},
            ))

        except Exception as e:
            log_error(
                "❌ Failed to start STT",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc(),
            )

            # Clean up on error
            if session_id in self.stt_sessions:
                await self._cleanup_session(session_id)

            # Publish error event
            await publish_error(
                session_id=session_id,
                error_type="stt_error",
                error_message=f"Failed to start STT: {str(e)}",
            )

    async def _handle_stt_stop(self, event: Event):
        """Handle an STT stop request.

        Stops the stream (if active) and publishes any final transcript.
        The session itself is kept so STT can be restarted later.
        """
        session_id = event.session_id
        reason = (event.data or {}).get("reason", "unknown")

        log_info("🛑 Stopping STT", session_id=session_id, reason=reason)

        stt_session = self.stt_sessions.get(session_id)
        if not stt_session:
            log_warning("⚠️ No STT session found", session_id=session_id)
            return

        try:
            if stt_session.is_streaming:
                try:
                    final_result = await stt_session.stt_instance.stop_streaming()
                finally:
                    # Clear the flag even when the provider fails to stop:
                    # otherwise every later start request is rejected as
                    # "already streaming" and the session can never recover.
                    stt_session.is_streaming = False

                # If we got a final result, publish it
                if final_result and final_result.text:
                    await self.event_bus.publish(Event(
                        type=EventType.STT_RESULT,
                        session_id=session_id,
                        data={
                            "text": final_result.text,
                            "is_final": True,
                            "confidence": final_result.confidence,
                        },
                    ))

            # Don't remove session immediately - might restart
            stt_session.update_activity()

            log_info("✅ STT stopped", session_id=session_id)

        except Exception as e:
            log_error(
                "❌ Error stopping STT",
                session_id=session_id,
                error=str(e),
            )

    async def _handle_audio_chunk(self, event: Event):
        """Decode a base64 audio chunk, stream it to STT and publish results.

        Chunks for sessions without an active stream are silently dropped.
        Failures are reported as ``stt_timeout`` (restart needed) or
        ``stt_error`` via ``publish_error``.
        """
        session_id = event.session_id

        stt_session = self.stt_sessions.get(session_id)
        if not stt_session or not stt_session.is_streaming:
            # STT not ready, ignore chunk
            return

        try:
            # Decode audio data
            audio_data = base64.b64decode(event.data.get("audio_data", ""))

            # Update stats
            stt_session.total_chunks += 1
            stt_session.total_bytes += len(audio_data)
            stt_session.update_activity()

            # Stream to STT
            async for result in stt_session.stt_instance.stream_audio(audio_data):
                # Publish transcription results
                await self.event_bus.publish(Event(
                    type=EventType.STT_RESULT,
                    session_id=session_id,
                    data={
                        "text": result.text,
                        "is_final": result.is_final,
                        "confidence": result.confidence,
                        "timestamp": result.timestamp,
                    },
                ))

                # Log final results (text truncated to 50 chars)
                if result.is_final:
                    log_info(
                        "📝 STT final result",
                        session_id=session_id,
                        text=(result.text[:50] + "...") if len(result.text) > 50 else result.text,
                        confidence=result.confidence,
                    )

            # Log progress periodically
            if stt_session.total_chunks % 100 == 0:
                log_debug(
                    "📊 STT progress",
                    session_id=session_id,
                    chunks=stt_session.total_chunks,
                    bytes=stt_session.total_bytes,
                )

        except Exception as e:
            log_error(
                "❌ Error processing audio chunk",
                session_id=session_id,
                error=str(e),
            )

            # Match both markers case-insensitively so differently
            # capitalised provider messages are still seen as timeouts.
            message = str(e).lower()
            if "stream duration" in message or "timeout" in message:
                # STT timeout, restart needed
                await publish_error(
                    session_id=session_id,
                    error_type="stt_timeout",
                    error_message="STT stream timeout, restart needed",
                )
            else:
                # Other STT error
                await publish_error(
                    session_id=session_id,
                    error_type="stt_error",
                    error_message=str(e),
                )

    async def _handle_session_ended(self, event: Event):
        """Clean up STT resources when the session ends."""
        await self._cleanup_session(event.session_id)

    async def _cleanup_session(self, session_id: str):
        """Remove the session, stop its stream and release its resource.

        Stopping and releasing are guarded separately so that a provider
        failing to stop never prevents the resource from being released.
        """
        stt_session = self.stt_sessions.pop(session_id, None)
        if not stt_session:
            return

        # Stop streaming if active
        if stt_session.is_streaming:
            try:
                await stt_session.stt_instance.stop_streaming()
            except Exception as e:
                log_error(
                    "❌ Error cleaning up STT session",
                    session_id=session_id,
                    error=str(e),
                )
            finally:
                stt_session.is_streaming = False

        try:
            # Release resource
            # NOTE(review): delay_seconds=60 presumably keeps the instance
            # around briefly for reuse - confirm against ResourceManager.
            await self.resource_manager.release(f"stt_{session_id}", delay_seconds=60)

            log_info(
                "🧹 STT session cleaned up",
                session_id=session_id,
                total_chunks=stt_session.total_chunks,
                total_bytes=stt_session.total_bytes,
            )
        except Exception as e:
            log_error(
                "❌ Error cleaning up STT session",
                session_id=session_id,
                error=str(e),
            )

    async def _cleanup_stt_instance(self, stt_instance: STTInterface):
        """Cleanup callback passed to the resource manager on acquire."""
        try:
            # Ensure streaming is stopped; not every provider is guaranteed
            # to expose ``is_streaming``, hence the default.
            if getattr(stt_instance, "is_streaming", False):
                await stt_instance.stop_streaming()

            log_debug("🧹 STT instance cleaned up")

        except Exception as e:
            log_error("❌ Error cleaning up STT instance", error=str(e))

    def _get_language_code(self, locale: str) -> str:
        """Convert a locale to an STT language code.

        Accepts bare language codes (``"tr"``), full tags with either
        separator and any case (``"en-US"``, ``"en_us"``) and tags with
        longer subtags (``"zh-Hant"``). Empty or ``None`` input falls back
        to the default locale ``"tr"``.
        """
        # Map common locales to STT language codes
        locale_map = {
            "tr": "tr-TR",
            "en": "en-US",
            "de": "de-DE",
            "fr": "fr-FR",
            "es": "es-ES",
            "it": "it-IT",
            "pt": "pt-BR",
            "ru": "ru-RU",
            "ja": "ja-JP",
            "ko": "ko-KR",
            "zh": "zh-CN",
            "ar": "ar-SA",
        }

        normalized = (locale or "").strip().replace("_", "-")
        if not normalized:
            return locale_map["tr"]

        language, *subtags = normalized.split("-")
        language = language.lower()

        if not subtags:
            # Bare language: use the known mapping, else the xx-XX form.
            return locale_map.get(language, f"{language}-{language.upper()}")

        # Already a full tag: canonicalise case of two-letter region
        # subtags and leave other subtags (e.g. scripts) untouched.
        subtags = [
            tag.upper() if len(tag) == 2 and tag.isalpha() else tag
            for tag in subtags
        ]
        return "-".join([language, *subtags])

    def get_stats(self) -> Dict[str, Any]:
        """Return aggregate and per-session STT statistics."""
        now = datetime.utcnow()
        session_stats = {
            session_id: {
                "is_streaming": stt_session.is_streaming,
                "total_chunks": stt_session.total_chunks,
                "total_bytes": stt_session.total_bytes,
                "uptime_seconds": (now - stt_session.created_at).total_seconds(),
                "last_activity": stt_session.last_activity.isoformat(),
            }
            for session_id, stt_session in self.stt_sessions.items()
        }

        return {
            "active_sessions": len(self.stt_sessions),
            "streaming_sessions": sum(1 for s in self.stt_sessions.values() if s.is_streaming),
            "sessions": session_stats,
        }