| """ |
| WebRTC WebSocket Handler for Real-time Audio Streaming |
| Integrates with FastAPI for unmute.sh-style voice interaction |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| from typing import Dict, Optional |
| import websockets |
| from fastapi import WebSocket, WebSocketDisconnect |
| import numpy as np |
| import soundfile as sf |
| import tempfile |
| import os |
| from datetime import datetime |
|
|
logger = logging.getLogger(__name__)


class WebRTCHandler:
    """Handles WebRTC WebSocket connections for real-time audio streaming.

    Tracks one FastAPI WebSocket per client, buffers incoming audio chunks,
    and forwards them to remote STT/TTS services over WebSocket, with HTTP
    (Gradio-style) API fallbacks when the WebSocket path is unavailable.
    """

    def __init__(self):
        # Client-facing FastAPI WebSocket connections, keyed by client id.
        # (Annotations are forward-ref strings so this module's class body
        # does not require fastapi/websockets to be importable.)
        self.active_connections: Dict[str, "WebSocket"] = {}
        # Raw audio chunks (bytes) accumulated per client between
        # start_recording and stop_recording.
        self.audio_buffers: Dict[str, list] = {}
        # Remote STT service endpoints (HTTP API + streaming WebSocket).
        self.stt_service_url = "https://pgits-stt-gpu-service.hf.space"
        self.stt_websocket_url = "wss://pgits-stt-gpu-service.hf.space/ws/stt"
        self.stt_connections: Dict[str, "websockets.WebSocketClientProtocol"] = {}
        # Remote TTS service endpoints (HTTP API + streaming WebSocket).
        self.tts_service_url = "https://pgits-tts-gpu-service.hf.space"
        self.tts_websocket_url = "wss://pgits-tts-gpu-service.hf.space/ws/tts"
        self.tts_connections: Dict[str, "websockets.WebSocketClientProtocol"] = {}

    async def connect(self, websocket: "WebSocket", client_id: str):
        """Accept a WebSocket connection and initialize the client's audio buffer."""
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.audio_buffers[client_id] = []

        logger.info(f"π WebRTC client {client_id} connected")

        # Tell the client which backend services it will be routed to.
        await self.send_message(client_id, {
            "type": "connection_confirmed",
            "client_id": client_id,
            "timestamp": datetime.now().isoformat(),
            "services": {
                "stt": self.stt_service_url,
                "status": "ready"
            }
        })

    async def disconnect(self, client_id: str):
        """Clean up the client connection, buffers, and downstream service links."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
        if client_id in self.audio_buffers:
            del self.audio_buffers[client_id]

        # Close any per-client downstream service connections.
        await self.disconnect_from_stt_service(client_id)
        await self.disconnect_from_tts_service(client_id)

        logger.info(f"π WebRTC client {client_id} disconnected")

    async def send_message(self, client_id: str, message: dict):
        """Send a JSON message to the client; drop the client on send failure."""
        if client_id in self.active_connections:
            websocket = self.active_connections[client_id]
            try:
                await websocket.send_text(json.dumps(message))
            except Exception as e:
                logger.error(f"Failed to send message to {client_id}: {e}")
                await self.disconnect(client_id)

    async def handle_audio_chunk(self, client_id: str, audio_data: bytes, sample_rate: int = 16000):
        """Buffer one incoming audio chunk (unmute.sh streaming methodology).

        Chunks are accumulated per client and only transcribed together when
        the client sends stop_recording (the "flush trick").
        """
        try:
            logger.info(f"π€ Received {len(audio_data)} bytes from {client_id}")

            # Lazily create the buffer in case connect() was skipped.
            if client_id not in self.audio_buffers:
                self.audio_buffers[client_id] = []

            self.audio_buffers[client_id].append(audio_data)

            # Acknowledge receipt so the client can track buffering progress.
            await self.send_message(client_id, {
                "type": "chunk_buffered",
                "chunk_size": len(audio_data),
                "buffer_chunks": len(self.audio_buffers[client_id]),
                "timestamp": datetime.now().isoformat()
            })

            logger.info(f"π¦ Buffered chunk for {client_id} ({len(self.audio_buffers[client_id])} total chunks)")

        except Exception as e:
            logger.error(f"Error buffering audio chunk for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "error",
                "message": f"Audio buffering error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })

    async def process_buffered_audio_with_flush(self, client_id: str):
        """Transcribe all buffered audio chunks using the unmute.sh flush trick.

        The concatenated chunks are written to a temporary WebM container,
        handed to the STT pipeline, and the result (or error) is pushed to
        the client. The buffer is always cleared afterwards — including on
        failure — so stale audio never leaks into the next utterance.
        """
        if client_id not in self.audio_buffers or not self.audio_buffers[client_id]:
            logger.info(f"No audio chunks to process for {client_id}")
            return

        try:
            all_audio_data = b''.join(self.audio_buffers[client_id])
            total_chunks = len(self.audio_buffers[client_id])

            logger.info(f"π Processing {total_chunks} buffered chunks ({len(all_audio_data)} bytes total) with flush trick")

            # The browser sends WebM/Opus; keep the container so the STT
            # service's decoder can read it.
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
                tmp_file.write(all_audio_data)
                tmp_file_path = tmp_file.name

            try:
                transcription = await self.process_audio_file_webrtc_with_flush(tmp_file_path)

                if transcription and transcription.strip() and not transcription.startswith("ERROR"):
                    await self.send_message(client_id, {
                        "type": "transcription",
                        "text": transcription.strip(),
                        "timestamp": datetime.now().isoformat(),
                        "audio_size": len(all_audio_data),
                        "format": "webm/audio",
                        "is_final": True,
                        "chunks_processed": total_chunks
                    })

                    logger.info(f"π Final transcription sent to {client_id}: {transcription[:50]}...")
                else:
                    await self.send_message(client_id, {
                        "type": "transcription_error",
                        "message": f"Audio processing failed: {transcription if transcription else 'No result'}",
                        "timestamp": datetime.now().isoformat()
                    })
            finally:
                # Never leave the temp file behind.
                if os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)

        except Exception as e:
            logger.error(f"Error processing buffered audio for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "transcription_error",
                "message": f"Buffered audio processing error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })
        finally:
            # Clear the buffer on success AND failure (previously only on
            # success, which replayed stale chunks after an error).
            self.audio_buffers[client_id] = []
            logger.info(f"π§Ή Cleared audio buffer for {client_id}")

    async def process_audio_file_webrtc_with_flush(self, audio_file_path: str) -> Optional[str]:
        """Run the given audio file through the project STT handler.

        Returns the transcription text, or an "ERROR: ..." string on failure
        (callers treat that prefix as a failure marker).
        """
        try:
            # Imported lazily to avoid a hard dependency at module import time.
            from core.mcp_audio_handler import mcp_audio_handler

            result = await mcp_audio_handler.speech_to_text(audio_file_path)

            logger.info(f"π FLUSH TRICK: STT service returned: {result[:100] if result else 'None'}...")

            return result

        except Exception as e:
            logger.error(f"Error in flush trick audio processing: {e}")
            return f"ERROR: Flush trick processing failed - {str(e)}"

    async def connect_to_stt_service(self, client_id: str) -> bool:
        """Open a WebSocket to the STT service and wait for its confirmation.

        Returns True only when the service answers with
        "stt_connection_confirmed".
        """
        try:
            logger.info(f"π Connecting to STT service for client {client_id}: {self.stt_websocket_url}")

            stt_ws = await asyncio.wait_for(
                websockets.connect(self.stt_websocket_url),
                timeout=5.0
            )
            self.stt_connections[client_id] = stt_ws

            # The service sends a handshake message first.
            confirmation = await asyncio.wait_for(stt_ws.recv(), timeout=10.0)
            confirmation_data = json.loads(confirmation)

            if confirmation_data.get("type") == "stt_connection_confirmed":
                logger.info(f"β STT service connected for client {client_id}")
                return True
            else:
                logger.warning(f"β οΈ Unexpected STT confirmation: {confirmation_data}")
                return False

        except asyncio.TimeoutError:
            logger.error(f"β STT service connection timeout for {client_id} - service may be cold starting or WebSocket endpoints not available")
            return False
        except websockets.exceptions.WebSocketException as e:
            if "503" in str(e):
                logger.error(f"β STT service unavailable (HTTP 503) for {client_id} - service may be cold starting")
                logger.info(f"π Try again in a few moments - Hugging Face services need time to start")
            else:
                logger.error(f"β STT WebSocket error for {client_id}: {e}")
                logger.info(f"π Debug: Attempted connection to {self.stt_websocket_url}")
            return False
        except Exception as e:
            logger.error(f"β Failed to connect to STT service for {client_id}: {e}")
            logger.info(f"π Debug: STT service URL: {self.stt_websocket_url}")
            return False

    async def disconnect_from_stt_service(self, client_id: str):
        """Close and forget the client's STT WebSocket, if any."""
        if client_id in self.stt_connections:
            try:
                stt_ws = self.stt_connections[client_id]
                await stt_ws.close()
                del self.stt_connections[client_id]
                logger.info(f"π Disconnected from STT service for client {client_id}")
            except Exception as e:
                logger.error(f"Error disconnecting from STT service: {e}")

    async def send_audio_to_stt_service(self, client_id: str, audio_data: bytes) -> Optional[str]:
        """Send audio to the STT service and return the transcription (or None).

        Connects on demand; on any communication error the STT connection is
        torn down so the next call reconnects cleanly.
        """
        if client_id not in self.stt_connections:
            success = await self.connect_to_stt_service(client_id)
            if not success:
                return None

        try:
            stt_ws = self.stt_connections[client_id]

            # The service expects base64-encoded audio in a JSON envelope.
            import base64
            audio_b64 = base64.b64encode(audio_data).decode('utf-8')

            message = {
                "type": "stt_audio_chunk",
                "audio_data": audio_b64,
                "language": "auto",
                "model_size": "base"
            }

            await stt_ws.send(json.dumps(message))
            logger.info(f"π€ Sent {len(audio_data)} bytes to STT service")

            response = await stt_ws.recv()
            response_data = json.loads(response)

            if response_data.get("type") == "stt_transcription":
                transcription = response_data.get("text", "")
                logger.info(f"π STT transcription received: {transcription[:50]}...")
                return transcription
            elif response_data.get("type") == "stt_error":
                error_msg = response_data.get("message", "Unknown STT error")
                logger.error(f"β STT service error: {error_msg}")
                return None
            else:
                logger.warning(f"β οΈ Unexpected STT response: {response_data}")
                return None

        except Exception as e:
            logger.error(f"β Error communicating with STT service: {e}")
            # Drop the broken connection so a later call reconnects.
            await self.disconnect_from_stt_service(client_id)
            return None

    async def connect_to_tts_service(self, client_id: str) -> bool:
        """Open a WebSocket to the TTS service and wait for its confirmation.

        Returns True only when the service answers with
        "tts_connection_confirmed".
        """
        try:
            logger.info(f"π Connecting to TTS service for client {client_id}: {self.tts_websocket_url}")

            tts_ws = await asyncio.wait_for(
                websockets.connect(self.tts_websocket_url),
                timeout=10.0
            )
            self.tts_connections[client_id] = tts_ws

            # TTS cold starts are slower than STT, hence the longer timeout.
            confirmation = await asyncio.wait_for(tts_ws.recv(), timeout=15.0)
            confirmation_data = json.loads(confirmation)

            if confirmation_data.get("type") == "tts_connection_confirmed":
                logger.info(f"β TTS service connected for client {client_id}")
                return True
            else:
                logger.warning(f"β οΈ Unexpected TTS confirmation: {confirmation_data}")
                return False

        except asyncio.TimeoutError:
            logger.error(f"β TTS service connection timeout - service may not be in WebSocket mode")
            logger.info(f"π‘ TTS service needs TTS_SERVICE_MODE=websocket environment variable")
            return False
        except websockets.exceptions.InvalidStatusCode as e:
            logger.error(f"β TTS WebSocket endpoint not available: {e}")
            logger.info(f"π‘ TTS service may be running in Gradio-only mode instead of WebSocket mode")
            return False
        except Exception as e:
            logger.error(f"β Failed to connect to TTS service for {client_id}: {e}")
            logger.info(f"π‘ Check if TTS service is running and configured with TTS_SERVICE_MODE=websocket")
            return False

    async def disconnect_from_tts_service(self, client_id: str):
        """Close and forget the client's TTS WebSocket, if any."""
        if client_id in self.tts_connections:
            try:
                tts_ws = self.tts_connections[client_id]
                await tts_ws.close()
                del self.tts_connections[client_id]
                logger.info(f"π Disconnected from TTS service for client {client_id}")
            except Exception as e:
                logger.error(f"Error disconnecting from TTS service: {e}")

    async def send_text_to_tts_service(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Send text to the TTS service and return the synthesized audio bytes.

        Connects on demand; on any communication error the TTS connection is
        torn down so the next call reconnects cleanly. Returns None on failure.
        """
        if client_id not in self.tts_connections:
            success = await self.connect_to_tts_service(client_id)
            if not success:
                return None

        try:
            tts_ws = self.tts_connections[client_id]

            message = {
                "type": "tts_synthesize",
                "text": text,
                "voice_preset": voice_preset
            }

            await tts_ws.send(json.dumps(message))
            logger.info(f"π€ Sent text to TTS service: {text[:50]}...")

            response = await tts_ws.recv()
            response_data = json.loads(response)

            if response_data.get("type") == "tts_audio_response":
                # Audio comes back base64-encoded inside the JSON envelope.
                import base64
                audio_b64 = response_data.get("audio_data", "")
                audio_bytes = base64.b64decode(audio_b64)
                logger.info(f"π TTS audio received: {len(audio_bytes)} bytes")
                return audio_bytes
            elif response_data.get("type") == "tts_error":
                error_msg = response_data.get("message", "Unknown TTS error")
                logger.error(f"β TTS service error: {error_msg}")
                return None
            else:
                logger.warning(f"β οΈ Unexpected TTS response: {response_data}")
                return None

        except Exception as e:
            logger.error(f"β Error communicating with TTS service: {e}")
            # Drop the broken connection so a later call reconnects.
            await self.disconnect_from_tts_service(client_id)
            return None

    async def play_tts_response(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6"):
        """Generate TTS audio (WebSocket first, HTTP fallback) and send it to the client."""
        try:
            logger.info(f"π Generating TTS response for client {client_id}: {text[:50]}...")

            logger.info("π Attempting WebSocket TTS (PRIMARY)")
            audio_data = await self.send_text_to_tts_service(client_id, text, voice_preset)

            if not audio_data:
                logger.info("π WebSocket failed, trying HTTP API fallback")
                audio_data = await self.try_http_tts_fallback(text, voice_preset)

            if audio_data:
                import base64
                audio_b64 = base64.b64encode(audio_data).decode('utf-8')

                await self.send_message(client_id, {
                    "type": "tts_playback",
                    "audio_data": audio_b64,
                    "audio_format": "wav",
                    "text": text,
                    "voice_preset": voice_preset,
                    "timestamp": datetime.now().isoformat(),
                    "audio_size": len(audio_data)
                })

                logger.info(f"π TTS playback sent to {client_id} ({len(audio_data)} bytes)")
            else:
                logger.warning(f"β οΈ TTS service failed to generate audio for: {text[:50]}...")

                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "TTS audio generation failed",
                    "text": text,
                    "timestamp": datetime.now().isoformat()
                })

        except Exception as e:
            logger.error(f"β TTS playback error for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "tts_error",
                "message": f"TTS playback error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })

    async def process_audio_file_webrtc(self, audio_file_path: str, sample_rate: int) -> Optional[str]:
        """Transcribe an audio file via the STT WebSocket, with HTTP fallback.

        Uses a throwaway client id for the STT connection and always tears it
        down afterwards. Returns the transcription, a canned apology string
        when every backend fails, or None on an unexpected local error.
        """
        try:
            logger.info(f"π€ WebRTC: Processing audio file {audio_file_path} with real STT")

            with open(audio_file_path, 'rb') as f:
                audio_data = f.read()

            file_size = len(audio_data)
            logger.info(f"π€ Audio file size: {file_size} bytes")

            # One-shot connection id, unique per call.
            temp_client_id = f"temp_{datetime.now().isoformat()}"

            try:
                logger.info("π Attempting WebSocket STT (PRIMARY)")
                transcription = await self.send_audio_to_stt_service(temp_client_id, audio_data)

                if transcription:
                    logger.info(f"β WebSocket STT transcription: {transcription}")
                    return transcription

                logger.info("π WebSocket failed, trying HTTP API fallback")
                http_transcription = await self.try_http_stt_fallback(audio_file_path)
                if http_transcription:
                    logger.info(f"β HTTP STT transcription (fallback): {http_transcription}")
                    return f"[HTTP] {http_transcription}"
                else:
                    logger.error("β Both WebSocket and HTTP STT failed - using minimal fallback")
                    return "I'm having trouble processing that audio. Could you please try again?"

            finally:
                await self.disconnect_from_stt_service(temp_client_id)

        except Exception as e:
            logger.error(f"WebRTC audio file processing failed: {e}")
            return None

    async def try_http_stt_fallback(self, audio_file_path: str) -> Optional[str]:
        """Fallback: post the audio file to the STT service's Gradio HTTP API.

        The blocking requests call runs in a thread-pool executor so the event
        loop is not stalled. Returns the transcription or None.
        """
        try:
            import requests

            def make_request():
                api_url = f"{self.stt_service_url}/api/predict"
                with open(audio_file_path, 'rb') as audio_file:
                    files = {'data': audio_file}
                    # Gradio expects a JSON-encoded argument list.
                    data = {'data': json.dumps(["auto", "base", True])}
                    response = requests.post(api_url, files=files, data=data, timeout=30)
                    return response

            # get_running_loop: we are inside a coroutine (get_event_loop is
            # deprecated here).
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, make_request)

            if response.status_code == 200:
                result = response.json()
                logger.info(f"π HTTP STT result: {result}")

                # Gradio returns {"data": [..., transcription, ...]}.
                if result and 'data' in result and len(result['data']) > 1:
                    transcription = result['data'][1]
                    if transcription and transcription.strip():
                        logger.info(f"β HTTP STT transcription: {transcription}")
                        return transcription

        except Exception as e:
            logger.error(f"β HTTP STT fallback failed: {e}")

        return None

    async def try_http_tts_fallback(self, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Fallback: post the text to the TTS service's Gradio HTTP API.

        Returns the synthesized audio bytes (downloaded from the URL the
        service hands back) or None.
        """
        try:
            import requests

            def make_request():
                api_url = f"{self.tts_service_url}/api/predict"
                # json.dumps, NOT an f-string: text may contain quotes,
                # backslashes, or newlines that would corrupt hand-built JSON.
                data = {'data': json.dumps([text, voice_preset])}
                response = requests.post(api_url, data=data, timeout=60)
                return response

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, make_request)

            if response.status_code == 200:
                result = response.json()
                logger.info(f"π HTTP TTS result received")

                # Gradio returns {"data": [audio_file_url_or_path, ...]}.
                if result and 'data' in result and len(result['data']) > 0:
                    audio_file_path = result['data'][0]
                    if audio_file_path and isinstance(audio_file_path, str):
                        if audio_file_path.startswith('http'):
                            audio_response = requests.get(audio_file_path, timeout=30)
                            if audio_response.status_code == 200:
                                logger.info(f"β HTTP TTS audio downloaded: {len(audio_response.content)} bytes")
                                return audio_response.content

        except Exception as e:
            logger.error(f"β HTTP TTS fallback failed: {e}")

        return None

    async def process_audio_chunk_real_time(self, audio_array: np.ndarray, sample_rate: int) -> Optional[str]:
        """Legacy method - kept for compatibility; returns a synthetic result string."""
        try:
            logger.info(f"π€ WebRTC: Processing {len(audio_array)} samples at {sample_rate}Hz")
            duration = len(audio_array) / sample_rate
            transcription = f"WebRTC test: Audio array ({duration:.1f}s, {sample_rate}Hz)"
            return transcription
        except Exception as e:
            logger.error(f"WebRTC audio processing failed: {e}")
            return None

    async def handle_message(self, client_id: str, message_data: dict):
        """Dispatch one client message by its "type" field.

        Supported types: audio_chunk, start_recording, stop_recording,
        tts_request, get_tts_voices. Unknown types are logged and ignored.
        """
        message_type = message_data.get("type")

        if message_type == "audio_chunk":
            audio_data = message_data.get("audio_data")
            sample_rate = message_data.get("sample_rate", 16000)

            if audio_data:
                # Client sends audio base64-encoded inside the JSON message.
                import base64
                audio_bytes = base64.b64decode(audio_data)
                await self.handle_audio_chunk(client_id, audio_bytes, sample_rate)

        elif message_type == "start_recording":
            await self.send_message(client_id, {
                "type": "recording_started",
                "timestamp": datetime.now().isoformat()
            })
            logger.info(f"π€ Recording started for {client_id}")

        elif message_type == "stop_recording":
            logger.info(f"π€ Recording stopped for {client_id} - applying unmute.sh flush trick")

            # Flush trick: transcribe everything buffered so far in one shot.
            await self.process_buffered_audio_with_flush(client_id)

            await self.send_message(client_id, {
                "type": "recording_stopped",
                "timestamp": datetime.now().isoformat()
            })

        elif message_type == "tts_request":
            text = message_data.get("text", "")
            voice_preset = message_data.get("voice_preset", "v2/en_speaker_6")

            if text.strip():
                await self.play_tts_response(client_id, text, voice_preset)
            else:
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "Empty text provided for TTS",
                    "timestamp": datetime.now().isoformat()
                })

        elif message_type == "get_tts_voices":
            await self.send_message(client_id, {
                "type": "tts_voices_list",
                "voices": ["v2/en_speaker_6", "v2/en_speaker_9", "v2/en_speaker_3", "v2/en_speaker_1"],
                "timestamp": datetime.now().isoformat()
            })

        else:
            logger.warning(f"Unknown message type from {client_id}: {message_type}")
|
|
|
|
| |
# Module-level singleton shared by the FastAPI WebSocket routes.
webrtc_handler = WebRTCHandler()