#!/usr/bin/env python3
"""
Unified WebSocket/HTTP Whisper Transcription Server

Handles real-time audio streaming, transcription using Whisper, and HTTP serving.
"""
import asyncio
import json
import logging
import os
import traceback
from typing import Dict, Any

import numpy as np
import torch
from aiohttp import web, WSMsgType
from aiohttp.web_ws import WebSocketResponse

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    from whisper_stream import load_streaming_model_correct
    from whisper_stream.streaming_decoding import DecodingOptions
except ImportError:
    logger.error("whisper_stream not found. Please install it or use regular whisper")

# Regular whisper doubles as the fallback when streaming model loading fails,
# so import it unconditionally rather than only inside the except block above.
import whisper


class UnifiedTranscriptionServer:
    def __init__(self, host: str = "0.0.0.0", port: int = 8000):
        self.host = host
        self.port = port
        self.clients: Dict[str, Dict[str, Any]] = {}
        self.app = web.Application()
        self.setup_routes()

    def setup_routes(self):
        """Set up HTTP routes and the WebSocket endpoint"""
        # HTTP routes
        self.app.router.add_get('/', self.serve_index)
        self.app.router.add_get('/health', self.health_check)
        # WebSocket endpoint
        self.app.router.add_get('/ws', self.websocket_handler)
        # Static file serving (if needed)
        if os.path.exists('static'):
            self.app.router.add_static('/static/', 'static')

    async def serve_index(self, request):
        """Serve the main HTML page"""
        try:
            with open("./static/client.html", "r", encoding='utf-8') as f:
                html_content = f.read()
            return web.Response(text=html_content, content_type='text/html')
        except FileNotFoundError:
            return web.Response(text="client.html not found!", status=404)
        except Exception as e:
            logger.error(f"Error serving client.html! {e}")
            return web.Response(text="Error loading page...", status=500)

    async def health_check(self, request):
        """Health check endpoint"""
        return web.json_response({"status": "healthy", "cuda": torch.cuda.is_available()})

    async def websocket_handler(self, request):
        """Handle WebSocket connections"""
        ws = WebSocketResponse()
        await ws.prepare(request)

        # Generate client ID
        client_id = f"{request.remote}:{id(ws)}"
        logger.info(f"New WebSocket client connected: {client_id}")

        # Initialize client state
        self.clients[client_id] = {
            'websocket': ws,
            'model': None,
            'config': None,
            'buffer': bytearray(),
            'total_samples': 0,
            'is_first_chunk': True
        }

        try:
            await self.process_websocket_messages(client_id)
        except Exception as e:
            logger.error(f"Error handling WebSocket client {client_id}: {e}")
            logger.error(traceback.format_exc())
        finally:
            # Cleanup
            if client_id in self.clients:
                del self.clients[client_id]
            if not ws.closed:
                await ws.close()

        return ws

    async def process_websocket_messages(self, client_id: str):
        """Process messages from a WebSocket client"""
        client = self.clients[client_id]
        ws = client['websocket']

        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                # Handle configuration message
                await self.handle_config_message(client_id, msg.data)
            elif msg.type == WSMsgType.BINARY:
                # Handle audio data
                await self.handle_audio_data(client_id, msg.data)
            elif msg.type == WSMsgType.ERROR:
                logger.error(f'WebSocket error for client {client_id}: {ws.exception()}')
                break

    async def handle_config_message(self, client_id: str, message: str):
        """Handle configuration message from client"""
        client = self.clients[client_id]
        ws = client['websocket']

        try:
            config = json.loads(message)
            logger.info(f"Received config from {client_id}: {config}")

            # Validate config
            required_fields = ['model_size', 'chunk_size', 'beam_size', 'language']
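            # For reference, a config message a client might send (values are
            # illustrative; only the field names are required by the check below):
            #   {"model_size": "large-v2", "chunk_size": 300,
            #    "beam_size": 5, "language": "en"}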
            for field in required_fields:
                if field not in config:
                    await ws.send_str(json.dumps({"error": f"Missing required field: {field}"}))
                    return

            # Load model
            model_size = config['model_size']
            chunk_size = config['chunk_size']
            logger.info(f"Loading model {model_size} for client {client_id}")

            # Multilingual transcription is currently only supported on the
            # large-v2 model with a 300 ms chunk size; reject anything else.
            multilingual = config['language'] != "en"
            if multilingual and (model_size != "large-v2" or chunk_size != 300):
                await ws.send_str(json.dumps({"error": "Multilingual transcription is currently available only on the large-v2 model with a chunk size of 300 ms."}))
                return

            # Try to use whisper_stream, fall back to regular whisper
            try:
                model = load_streaming_model_correct(model_size, chunk_size, multilingual)
                client['is_first_chunk'] = True
                if torch.cuda.is_available():
                    model = model.to("cuda")
                    logger.info(f"Model loaded on GPU for client {client_id}")
                else:
                    logger.info(f"Model loaded on CPU for client {client_id}")
                model.reset(use_stream=True)
                model.eval()
                client['model'] = model
                client['config'] = config
                await ws.send_str(json.dumps({"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available()}))
            except Exception as e:
                logger.error(f"Error loading streaming model: {e}")
                # Fallback to regular whisper
                try:
                    model = whisper.load_model(model_size)
                    if torch.cuda.is_available():
                        model = model.to("cuda")
                    client['model'] = model
                    client['config'] = config
                    client['use_streaming'] = False
                    await ws.send_str(json.dumps({"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available(), "fallback": True}))
                except Exception as e2:
                    logger.error(f"Error loading fallback model: {e2}")
                    await ws.send_str(json.dumps({"error": f"Failed to load model: {e2}"}))

        except json.JSONDecodeError as e:
            await ws.send_str(json.dumps({"error": f"Invalid JSON: {e}"}))
        except Exception as e:
            logger.error(f"Error handling config for client {client_id}: {e}")
            await ws.send_str(json.dumps({"error": str(e)}))

    async def handle_audio_data(self, client_id: str, audio_data: bytes):
        """Handle audio data from client"""
        client = self.clients[client_id]
        ws = client['websocket']

        if client['config'] is None:
            await ws.send_str(json.dumps({"error": "Config not set"}))
            return
        if client['model'] is None:
            await ws.send_str(json.dumps({"error": "Model not loaded"}))
            return

        # Add audio data to buffer
        client['buffer'].extend(audio_data)

        # Calculate chunk size in bytes
        chunk_size_ms = client['config']['chunk_size']
        sample_rate = 16000
        chunk_samples = int(sample_rate * (chunk_size_ms / 1000))
        chunk_bytes = chunk_samples * 2  # 16-bit audio = 2 bytes per sample

        # Process complete chunks. The very first chunk carries an extra
        # 720 bytes, so recompute the expected size on every iteration
        # (a fixed chunk_bytes would keep that padding for later chunks too).
        while True:
            expected_bytes = chunk_bytes + (720 if client.get('is_first_chunk', True) else 0)
            if len(client['buffer']) < expected_bytes:
                break
            chunk = bytes(client['buffer'][:expected_bytes])
            client['buffer'] = client['buffer'][expected_bytes:]
            try:
                client['is_first_chunk'] = False
                await self.transcribe_chunk(client_id, chunk)
            except Exception as e:
                logger.error(f"Error transcribing chunk for client {client_id}: {e}")
                await ws.send_str(json.dumps({"error": f"Transcription error: {str(e)}"}))
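    # Worked example of the framing math in handle_audio_data, assuming the
    # default 16 kHz, 16-bit mono PCM and a 300 ms chunk:
    #   chunk_samples = 16000 * 0.300 = 4800 samples
    #   chunk_bytes   = 4800 * 2      = 9600 bytes per chunk
    # and the very first chunk is expected to be 9600 + 720 = 10320 bytes.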
    async def transcribe_chunk(self, client_id: str, chunk: bytes):
        """Transcribe audio chunk"""
        client = self.clients[client_id]
        ws = client['websocket']
        model = client['model']
        config = client['config']

        try:
            # Convert bytes to numpy array
            pcm = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
            # Convert to torch tensor
            audio = torch.tensor(pcm)
            if torch.cuda.is_available() and next(model.parameters()).is_cuda:
                audio = audio.to("cuda")

            # Transcribe based on model type: 'use_streaming' is only set
            # (to False) when the regular-whisper fallback was loaded.
            if client.get('use_streaming', True):
                # Using whisper_stream
                decoding_options = DecodingOptions(
                    language=config['language'],
                    gran=(config['chunk_size'] // 20),
                    single_frame_mel=True,
                    without_timestamps=True,
                    beam_size=config['beam_size'],
                    stream_decode=True,
                    use_ca_kv_cache=True,
                    look_ahead_blocks=model.extra_gran_blocks
                )
                result = model.decode(audio, decoding_options, use_frames=True)
                text = result.text
            else:
                # Using regular whisper
                # Pad audio to minimum length if needed
                min_length = 16000  # 1 second at 16kHz
                if len(audio) < min_length:
                    audio = torch.nn.functional.pad(audio, (0, min_length - len(audio)))
                result = model.transcribe(audio.cpu().numpy(),
                                          language=config['language'],
                                          beam_size=config['beam_size'],
                                          temperature=config.get('temperature', 0.0))
                text = result['text']

            # Send transcription result
            if text.strip():
                client['total_samples'] += len(pcm)
                duration = client['total_samples'] / 16000  # seconds
                await ws.send_str(json.dumps({
                    "text": text.strip(),
                    "timestamp": duration,
                    "chunk_duration": len(pcm) / 16000
                }))
        except Exception as e:
            logger.error(f"Error in transcription for client {client_id}: {e}")
            logger.exception("Exception occurred")
            raise

    async def start_server(self):
        """Start the unified server"""
        logger.info(f"Starting unified server on {self.host}:{self.port}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")

        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, self.host, self.port)
        await site.start()

        logger.info(f"Server running on http://{self.host}:{self.port}")
        logger.info(f"WebSocket endpoint: ws://{self.host}:{self.port}/ws")

        # Keep the server running
        try:
            await asyncio.Future()  # Run forever
        except KeyboardInterrupt:
            logger.info("Server stopped by user")
        finally:
            await runner.cleanup()


def main():
    import argparse
    parser = argparse.ArgumentParser(description='Unified WebSocket/HTTP Whisper Transcription Server')
    parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8000, help='Port to bind to')
    parser.add_argument('--log-level', default='INFO', help='Log level')
    args = parser.parse_args()

    # Set log level
    logging.getLogger().setLevel(getattr(logging, args.log_level.upper()))

    server = UnifiedTranscriptionServer(args.host, args.port)
    try:
        asyncio.run(server.start_server())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Server error: {e}")
        logger.error(traceback.format_exc())


if __name__ == '__main__':
    main()
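# ---------------------------------------------------------------------------
# Minimal client sketch (illustrative only, not part of the server). Assumes
# the third-party `websockets` package and a raw 16 kHz, 16-bit mono PCM file
# named `sample.pcm`; the filename and config values are placeholders.
#
#   import asyncio, json, websockets
#
#   async def run():
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           await ws.send(json.dumps({"model_size": "large-v2", "chunk_size": 300,
#                                     "beam_size": 5, "language": "en"}))
#           print(await ws.recv())  # expect {"status": "CONFIG_RECEIVED", ...}
#           with open("sample.pcm", "rb") as f:
#               while data := f.read(9600):  # any size works; server re-slices
#                   await ws.send(data)
#           # Results arrive asynchronously, and only for non-empty text:
#           async for message in ws:
#               print(message)
#
#   asyncio.run(run())
# ---------------------------------------------------------------------------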