Spaces:
Runtime error
Runtime error
Peter Michael Gits
Fix Dockerfile directory permissions - create /app as root before switching users
26096f4 | import asyncio | |
| import json | |
| import time | |
| import logging | |
| from typing import Optional | |
| import torch | |
| import numpy as np | |
| import librosa | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse | |
| import uvicorn | |
# Release identifiers surfaced in API responses and on the HTML test page
VERSION = "1.1.1"
COMMIT_SHA = "TBD"

# Root logging at INFO level, plus a logger named after this module
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared STT state. All three stay None until load_model() assigns them;
# model/processor are set to the string "mock" when the real model cannot
# be loaded.
model = processor = device = None
async def load_model():
    """Populate the module-level ``model``, ``processor`` and ``device``.

    Picks "cuda" when torch reports a CUDA device, otherwise "cpu", then
    tries to load the ``kyutai/stt-1b-en_fr`` processor and model through
    ``transformers``. If anything goes wrong, ``model`` and ``processor``
    are both set to the sentinel string ``"mock"`` so the rest of the
    service can keep running in mock mode. This coroutine never raises.
    """
    global model, processor, device

    try:
        logger.info("Loading STT model...")

        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        logger.info(f"Using device: {device}")

        # Try to load the actual model - fall back to mock if not available
        try:
            # Imported here rather than at module level so that a missing or
            # incompatible transformers install lands in the except below.
            from transformers import (
                KyutaiSpeechToTextProcessor,
                KyutaiSpeechToTextForConditionalGeneration,
            )

            model_id = "kyutai/stt-1b-en_fr"

            logger.info(f"Loading processor from {model_id}...")
            processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)

            logger.info(f"Loading model from {model_id}...")
            loaded = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id)
            model = loaded.to(device)

            logger.info(f"Model {model_id} loaded successfully on {device}")
        except Exception as load_failure:
            logger.warning(f"Could not load actual model: {load_failure}")
            logger.info("Using mock STT for development")
            # Reset both, even if the processor alone loaded successfully.
            model = processor = "mock"
    except Exception as unexpected:
        logger.error(f"Error loading model: {unexpected}")
        model = processor = "mock"
def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Return a transcription string for ``audio_data``.

    In mock mode (module-level ``model == "mock"``) the result is a
    descriptive placeholder built from the sample count and ``sample_rate``.
    Otherwise the audio is resampled to 24000 Hz when ``sample_rate``
    differs, run through the module-level ``processor`` and ``model``, and
    the first decoded sequence is returned.

    Any exception is logged and reported as the string ``"Error: <message>"``
    instead of being raised.
    """
    try:
        if model != "mock":
            # Real transcription - Kyutai STT expects 24kHz
            waveform = audio_data
            if sample_rate != 24000:
                logger.info(f"Resampling from {sample_rate}Hz to 24000Hz")
                waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=24000)

            features = processor(waveform, sampling_rate=24000, return_tensors="pt")
            # Move every processor output onto the device chosen at load time
            features = {name: value.to(device) for name, value in features.items()}

            with torch.no_grad():
                token_ids = model.generate(**features)

            decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
            return decoded[0]

        # Mock transcription for development
        duration = len(audio_data) / sample_rate
        return f"Mock transcription: {duration:.2f}s audio at {sample_rate}Hz ({len(audio_data)} samples)"
    except Exception as failure:
        logger.error(f"Transcription error: {failure}")
        return f"Error: {str(failure)}"
# Application object; title, description and version are passed to FastAPI
# as the service's metadata.
app = FastAPI(
    title="STT GPU Service Python v4",
    description="Real-time WebSocket STT streaming with kyutai/stt-1b-en_fr (24kHz)",
    version=VERSION,
)
async def startup_event():
    """Run load_model() so the STT model (or its mock) is ready for requests."""
    # NOTE(review): no startup-hook registration (decorator or
    # add_event_handler call) is visible in this view - confirm this
    # coroutine is actually wired to the app.
    await load_model()
async def health_check():
    """Return a status payload describing the running service.

    The payload always reports ``"status": "healthy"``. ``model_loaded`` is
    True whenever the module-level ``model`` is not None, which includes the
    ``"mock"`` fallback. ``device`` is the string form of the module-level
    device, or ``"unknown"`` while it is still unset.
    """
    # NOTE(review): no route decorator is visible in this view - confirm the
    # path this handler is registered under.
    device_label = str(device) if device else "unknown"
    is_model_set = model is not None

    payload = {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "STT WebSocket Service - Real-time streaming ready",
        "space_name": "stt-gpu-service-python-v4",
        "model_loaded": is_model_set,
        "device": device_label,
        "expected_sample_rate": "24000Hz",
    }
    return payload
async def get_index():
    """Serve a minimal HTML page for manually exercising the WebSocket.

    The page offers Connect/Disconnect buttons, opens a socket to
    ``/ws/stream`` on the current host (``wss:`` when the page itself is on
    ``https:``), sends one test ``audio_chunk`` message on connect and
    appends every received message to the output panel. ``VERSION`` and
    ``COMMIT_SHA`` are interpolated into the footer.

    Received messages are inserted with ``textContent`` rather than
    ``innerHTML`` so that text arriving over the socket is never parsed as
    markup, and so the output panel is not re-parsed on every message.

    Returns:
        HTMLResponse containing the rendered page.
    """
    # NOTE(review): no route decorator is visible in this view - confirm the
    # path this handler is registered under.
    # Literal braces in the CSS/JS below are doubled because this is an
    # f-string; only {VERSION} and {COMMIT_SHA} are interpolated.
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <title>STT GPU Service Python v4</title>
    <style>
    body {{ font-family: Arial, sans-serif; margin: 40px; }}
    .container {{ max-width: 800px; margin: 0 auto; }}
    .status {{ background: #f0f0f0; padding: 20px; border-radius: 8px; margin: 20px 0; }}
    button {{ padding: 10px 20px; margin: 5px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; }}
    button:disabled {{ background: #ccc; }}
    #output {{ background: #f8f9fa; padding: 15px; border-radius: 4px; margin-top: 20px; }}
    .version {{ font-size: 0.8em; color: #666; margin-top: 20px; }}
    </style>
    </head>
    <body>
    <div class="container">
    <h1>🎙️ STT GPU Service Python v4</h1>
    <p>Real-time WebSocket speech transcription service (24kHz audio)</p>
    <div class="status">
    <h3>WebSocket Streaming Test</h3>
    <button onclick="startWebSocket()">Connect WebSocket</button>
    <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
    <p>Status: <span id="wsStatus">Disconnected</span></p>
    <p><small>Expected: 24kHz audio chunks (80ms = ~1920 samples)</small></p>
    </div>
    <div id="output">
    <p>Transcription output will appear here...</p>
    </div>
    <div class="version">
    v{VERSION} (SHA: {COMMIT_SHA})
    </div>
    </div>
    <script>
    let ws = null;
    function appendOutput(text, color) {{
        // textContent keeps socket-supplied text from being parsed as HTML
        const entry = document.createElement('p');
        entry.textContent = text;
        if (color) {{
            entry.style.color = color;
        }}
        document.getElementById('output').appendChild(entry);
    }}
    function startWebSocket() {{
        const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
        const wsUrl = `${{protocol}}//${{window.location.host}}/ws/stream`;
        ws = new WebSocket(wsUrl);
        ws.onopen = function(event) {{
            document.getElementById('wsStatus').textContent = 'Connected';
            document.querySelector('button').disabled = true;
            document.getElementById('stopBtn').disabled = false;
            // Send test message
            ws.send(JSON.stringify({{
                type: 'audio_chunk',
                data: 'test_audio_data_24khz',
                timestamp: Date.now()
            }}));
        }};
        ws.onmessage = function(event) {{
            const data = JSON.parse(event.data);
            appendOutput(JSON.stringify(data, null, 2));
        }};
        ws.onclose = function(event) {{
            document.getElementById('wsStatus').textContent = 'Disconnected';
            document.querySelector('button').disabled = false;
            document.getElementById('stopBtn').disabled = true;
        }};
        ws.onerror = function(error) {{
            appendOutput(`WebSocket Error: ${{error}}`, 'red');
        }};
    }}
    function stopWebSocket() {{
        if (ws) {{
            ws.close();
        }}
    }}
    </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
async def websocket_endpoint(websocket: WebSocket):
    """Handle one WebSocket connection for real-time audio streaming.

    Protocol, as implemented here:
      * After accepting, the server sends a ``connection`` message that
        advertises the expected chunk format (80 ms, 24000 Hz, 1920 samples).
      * ``{"type": "audio_chunk", ...}`` is answered with a ``transcription``
        message. The audio payload is not decoded yet: the text is a
        placeholder built from the client's ``timestamp`` field.
      * ``{"type": "ping"}`` is answered with a ``pong`` message.
      * A JSON payload that is not an object is answered with an ``error``
        message and the connection stays open.
      * Messages with any other ``type`` are ignored.

    A client disconnect ends the loop quietly. Any other exception is logged
    with its traceback and the socket is closed with code 1011.

    Args:
        websocket: The incoming connection; accepted by this handler.
    """
    # NOTE(review): no route decorator is visible in this view - the HTML
    # test page connects to /ws/stream, so confirm this handler is
    # registered under that path.
    await websocket.accept()
    logger.info("WebSocket connection established")

    try:
        # Send initial connection confirmation
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "STT WebSocket ready for audio chunks",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920,  # 80ms at 24kHz = 1920 samples
        })

        while True:
            # Receive audio data
            data = await websocket.receive_json()

            # Valid JSON is not necessarily an object (e.g. a list or a
            # number); answer with an error instead of failing on .get().
            if not isinstance(data, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Processing error: expected a JSON object",
                    "timestamp": time.time(),
                })
                continue

            message_type = data.get("type")

            if message_type == "audio_chunk":
                try:
                    # Process 80ms audio chunk (1920 samples at 24kHz)
                    # In real implementation, you would:
                    # 1. Decode base64 audio data
                    # 2. Convert to numpy array (24kHz)
                    # 3. Process with STT model
                    # 4. Return transcription
                    # For now, mock processing
                    transcription = f"Mock transcription for 24kHz chunk at {data.get('timestamp', 'unknown')}"

                    # Send transcription result
                    await websocket.send_json({
                        "type": "transcription",
                        "text": transcription,
                        "timestamp": time.time(),
                        "chunk_id": data.get("timestamp"),
                        "confidence": 0.95,
                    })
                except WebSocketDisconnect:
                    # The peer is gone; do not try to report an error to it.
                    raise
                except Exception as e:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Processing error: {str(e)}",
                        "timestamp": time.time(),
                    })
            elif message_type == "ping":
                # Respond to ping
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                })
    except WebSocketDisconnect:
        logger.info("WebSocket connection closed")
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
        # A close frame is a control frame (payload <= 125 bytes, 2 of which
        # hold the status code), so cap the reason at 123 UTF-8 bytes and
        # drop any multi-byte character cut by the truncation.
        reason = f"Server error: {e}".encode("utf-8")[:123].decode("utf-8", errors="ignore")
        try:
            await websocket.close(code=1011, reason=reason)
        except (RuntimeError, WebSocketDisconnect) as close_error:
            # The connection was already closed or lost; nothing left to do.
            logger.debug("WebSocket already closed: %s", close_error)
async def api_transcribe(audio_file: Optional[str] = None):
    """Return a mock REST transcription result for ``audio_file``.

    No audio is processed; the response only echoes the first 50 characters
    of ``audio_file`` together with a timestamp and the service version.

    Raises:
        HTTPException: status 400 when ``audio_file`` is missing or empty.
    """
    # NOTE(review): no route decorator is visible in this view - confirm the
    # path and HTTP method this handler is registered under.
    if audio_file:
        # Mock transcription
        preview = audio_file[:50]
        return {
            "transcription": f"REST API transcription result for: {preview}...",
            "timestamp": time.time(),
            "version": VERSION,
            "method": "REST",
            "expected_sample_rate": "24kHz",
        }

    raise HTTPException(status_code=400, detail="No audio data provided")
if __name__ == "__main__":
    # Start uvicorn when this file is executed directly. The application is
    # referenced by the import string "app:app".
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "info",
        "access_log": True,
    }
    uvicorn.run("app:app", **server_options)