Spaces:
Runtime error
Runtime error
Peter Michael Gits
Use official HuggingFace documentation approach to override model detection
import asyncio
import json
import time
import logging
import os
from typing import Optional
from contextlib import asynccontextmanager
# CRITICAL: Set OMP_NUM_THREADS before any torch/numpy imports
# HuggingFace is overriding our Dockerfile ENV with CPU_CORES value
os.environ['OMP_NUM_THREADS'] = '1'
# Also ensure other environment variables are correct
# (all HF caches pointed at a directory the container user can write to)
os.environ['HF_HOME'] = '/app/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/hf_cache'
# NOTE(review): TRANSFORMERS_CACHE is a deprecated alias of HF_HOME in
# recent transformers releases -- harmless here, kept for older versions.
os.environ['TRANSFORMERS_CACHE'] = '/app/hf_cache'
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
import uvicorn
# Version tracking (COMMIT_SHA is presumably stamped at build time -- confirm)
VERSION = "2.0.3"
COMMIT_SHA = "TBD"
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create cache directory if it doesn't exist
os.makedirs('/app/hf_cache', exist_ok=True)
# Global Moshi model variables -- populated by load_moshi_models() at startup.
# Each model global is None before loading and the string sentinel "mock"
# when loading failed (the service then degrades to mock responses).
mimi = None
moshi = None
lm_gen = None
device = None
async def load_moshi_models():
    """Load Moshi STT models on startup.

    Populates the module-level globals ``mimi`` (audio codec), ``moshi``
    (language model), ``lm_gen`` (token generator) and ``device``.  On any
    failure all three model globals are set to the string sentinel "mock"
    so the rest of the service degrades to mock responses instead of
    crashing.

    Returns:
        bool: True when the real models loaded, False in mock mode.
    """
    global mimi, moshi, lm_gen, device
    try:
        logger.info("Loading Moshi models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        logger.info(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
        # Clear GPU memory and set memory management
        if device == "cuda":
            torch.cuda.empty_cache()
            # Enable memory efficient attention
            # (disabling flash SDP makes torch fall back to the
            # math/mem-efficient scaled-dot-product kernels)
            torch.backends.cuda.enable_flash_sdp(False)
            logger.info(f"GPU memory before loading: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        try:
            from huggingface_hub import hf_hub_download
            from moshi.models import loaders, LMGen
            # Load Mimi (audio codec) - using full Moshi model
            logger.info("Loading Mimi audio codec...")
            mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME, cache_dir='/app/hf_cache')
            mimi = loaders.get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(8)  # Limited to 8 for Moshi
            logger.info("โ Mimi loaded successfully")
            # Clear cache after Mimi loading
            if device == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"GPU memory after Mimi: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            # Load Moshi (full language model)
            logger.info("Loading Moshi language model...")
            moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME, cache_dir='/app/hf_cache')
            # Try loading with memory-efficient settings
            try:
                moshi = loaders.get_moshi_lm(moshi_weight, device=device)
                lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                logger.info("โ Moshi loaded successfully on GPU")
            except RuntimeError as cuda_error:
                # OOM detection is string-based because torch raises a plain
                # RuntimeError for CUDA out-of-memory conditions.
                if "CUDA out of memory" in str(cuda_error):
                    logger.warning(f"Moshi CUDA out of memory, trying CPU fallback: {cuda_error}")
                    # Move Mimi to CPU as well for consistency: both models
                    # must live on the same device for transcription.
                    mimi = loaders.get_mimi(mimi_weight, device="cpu")
                    mimi.set_num_codebooks(8)
                    device = "cpu"
                    moshi = loaders.get_moshi_lm(moshi_weight, device="cpu")
                    lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                    logger.info("โ Moshi loaded successfully on CPU (fallback)")
                    logger.info("โ Mimi also moved to CPU for device consistency")
                else:
                    raise
            logger.info("๐ All Moshi models loaded successfully!")
            return True
        except ImportError as import_error:
            # moshi / huggingface_hub packages missing: run in mock mode.
            logger.error(f"Moshi import failed: {import_error}")
            mimi = "mock"
            moshi = "mock"
            lm_gen = "mock"
            return False
        except Exception as model_error:
            logger.error(f"Failed to load Moshi models: {model_error}")
            # Set mock mode
            mimi = "mock"
            moshi = "mock"
            lm_gen = "mock"
            return False
    except Exception as e:
        # Defensive outer guard -- should be unreachable given the inner
        # handlers, but guarantees startup never propagates an exception.
        logger.error(f"Error in load_moshi_models: {e}")
        mimi = "mock"
        moshi = "mock"
        lm_gen = "mock"
        return False
def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe audio using Moshi models.

    Encodes the waveform with the Mimi codec in streaming mode, then feeds
    the first token step into the Moshi language-model generator.  Any
    failure is converted into a human-readable string rather than raised,
    so the WebSocket handler can always forward *something* to the client.

    Args:
        audio_data: 1-D float32 waveform samples.
        sample_rate: Sample rate of ``audio_data``; resampled to 24 kHz
            (Moshi's expected rate) when different.

    Returns:
        A transcription / diagnostic string (never raises).
    """
    try:
        logger.info(f"๐๏ธ Starting transcription - Audio length: {len(audio_data)} samples at {sample_rate}Hz")
        # Fix: guard the startup race -- before load_moshi_models() finishes,
        # the globals are still None and `None == "mock"` is False, which
        # previously fell through to an AttributeError on mimi.streaming().
        if mimi is None:
            logger.warning("Transcription requested before Moshi models finished loading")
            return "Error: Moshi models not loaded yet"
        if mimi == "mock":
            duration = len(audio_data) / sample_rate
            return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"
        # Ensure 24kHz audio for Moshi
        if sample_rate != 24000:
            import librosa
            logger.info(f"๐ Resampling from {sample_rate}Hz to 24000Hz")
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)
        # Determine actual device of the models (might have fallen back to CPU)
        model_device = next(mimi.parameters()).device if hasattr(mimi, 'parameters') else device
        logger.info(f"Using device for transcription: {model_device}")
        # Convert to torch tensor [B=1, C=1, T] and put on same device as models.
        # Copy array to avoid PyTorch writable tensor warning.
        wav = torch.from_numpy(audio_data.copy()).unsqueeze(0).unsqueeze(0).to(model_device)
        logger.info(f"๐ Tensor shape: {wav.shape}, device: {wav.device}")
        # Process with Mimi codec in streaming mode
        logger.info("๐ง Starting Mimi audio encoding...")
        with torch.no_grad(), mimi.streaming(batch_size=1):
            all_codes = []
            frame_size = mimi.frame_size
            logger.info(f"๐ Frame size: {frame_size}")
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset: offset + frame_size]
                if frame.shape[-1] == 0:
                    break
                # Pad last frame if needed (Mimi expects full frames)
                if frame.shape[-1] < frame_size:
                    padding = frame_size - frame.shape[-1]
                    frame = torch.nn.functional.pad(frame, (0, padding))
                codes = mimi.encode(frame)
                all_codes.append(codes)
            logger.info(f"๐ต Encoded {len(all_codes)} audio frames")
        # Concatenate all codes along the time axis
        if all_codes:
            audio_tokens = torch.cat(all_codes, dim=-1)
            logger.info(f"๐ Audio tokens shape: {audio_tokens.shape}")
            # Generate text with Moshi language model
            logger.info("๐ง Starting Moshi text generation...")
            with torch.no_grad():
                try:
                    # Use the actual language model for generation
                    if lm_gen and lm_gen != "mock":
                        logger.info(f"๐ง LMGen type: {type(lm_gen)}")
                        logger.info(f"๐ง LMGen methods: {[m for m in dir(lm_gen) if not m.startswith('_')]}")
                        # NOTE(review): exploratory code -- probes LMGen.step()
                        # with and without the streaming context because the
                        # correct invocation was still being established.
                        try:
                            # First try without streaming context
                            logger.info("๐งช Trying step() without streaming context...")
                            code_step = audio_tokens[:, :, 0:1]  # Just first timestep [B, 8, 1]
                            tokens_out = lm_gen.step(code_step)
                            logger.info(f"๐ Direct step result: {type(tokens_out)}, value: {tokens_out}")
                            if tokens_out is None:
                                # Try with streaming context
                                logger.info("๐งช Trying with streaming context...")
                                with lm_gen.streaming(1):
                                    tokens_out = lm_gen.step(code_step)
                                logger.info(f"๐ Streaming step result: {type(tokens_out)}, value: {tokens_out}")
                            if tokens_out is None:
                                # Maybe we need to call a different method or check state
                                logger.error("๐จ Both approaches returned None - checking LMGen state")
                                logger.info(f"๐ง LMGen attributes: {vars(lm_gen) if hasattr(lm_gen, '__dict__') else 'No __dict__'}")
                                text_output = "Moshiko: LMGen step() returns None - API issue"
                            else:
                                logger.info(f"โ Got tokens! Shape: {tokens_out.shape if hasattr(tokens_out, 'shape') else 'No shape'}")
                                text_output = f"Moshiko CPU: Successfully generated tokens with shape {tokens_out.shape if hasattr(tokens_out, 'shape') else 'unknown'}"
                        except Exception as step_error:
                            logger.error(f"๐จ LMGen step error: {step_error}")
                            text_output = f"Moshiko: LMGen step error: {str(step_error)}"
                    else:
                        text_output = "Moshiko fallback: LM generator not available"
                        logger.warning("โ ๏ธ LM generator not available, using fallback")
                    return text_output
                except Exception as gen_error:
                    logger.error(f"โ Text generation failed: {gen_error}")
                    return f"Moshiko encoding successful but text generation failed: {str(gen_error)}"
        # Reached only when the encode loop produced no frames (empty input).
        logger.warning("โ ๏ธ No audio tokens were generated")
        return "No audio tokens generated"
    except Exception as e:
        logger.error(f"Moshi transcription error: {e}")
        return f"Error: {str(e)}"
# Use lifespan instead of deprecated on_event
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the Moshi models once at startup.

    Fix: FastAPI's ``lifespan`` parameter requires an async context manager
    factory; a bare async generator function is not accepted.  The
    ``asynccontextmanager`` import at the top of the file was clearly meant
    for this decorator, which was missing here.
    """
    # Startup
    await load_moshi_models()
    yield
    # Shutdown (if needed)
# FastAPI app with lifespan (models are loaded during startup, see lifespan)
app = FastAPI(
    title="STT GPU Service Python v4 - Full Moshi Model",
    description="Real-time WebSocket STT streaming with full Moshi PyTorch implementation (L4 GPU with 30GB VRAM)",
    version=VERSION,
    lifespan=lifespan
)
# Fix: the route decorator was missing, so /health was never registered even
# though the test page's "Test Health" button fetches it.
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Reports service version plus whether the real models (vs. the "mock"
    sentinel) are loaded and which device they ended up on.
    """
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "Moshi STT WebSocket Service - Full model on L4 GPU",
        "space_name": "stt-gpu-service-python-v4",
        # "mock" is the failure sentinel set by load_moshi_models()
        "mimi_loaded": mimi is not None and mimi != "mock",
        "moshi_loaded": moshi is not None and moshi != "mock",
        "device": str(device) if device else "unknown",
        "expected_sample_rate": "24000Hz",
        "cache_dir": "/app/hf_cache",
        "cache_status": "writable"
    }
# Fix: the route decorator was missing, so the test page was never served.
@app.get("/")
async def get_index():
    """Simple HTML interface for testing.

    Serves a self-contained page with JS helpers to open the STT WebSocket
    (/ws/stream), send a synthetic 440 Hz test chunk, and hit /health.
    """
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>STT GPU Service Python v4 - Cache Fixed</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 40px; }}
            .container {{ max-width: 800px; margin: 0 auto; }}
            .status {{ background: #f0f0f0; padding: 20px; border-radius: 8px; margin: 20px 0; }}
            .success {{ background: #d4edda; border-left: 4px solid #28a745; }}
            .info {{ background: #d1ecf1; border-left: 4px solid #17a2b8; }}
            .warning {{ background: #fff3cd; border-left: 4px solid #ffc107; }}
            button {{ padding: 10px 20px; margin: 5px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; }}
            button:disabled {{ background: #ccc; }}
            button.success {{ background: #28a745; }}
            button.warning {{ background: #ffc107; color: #212529; }}
            #output {{ background: #f8f9fa; padding: 15px; border-radius: 4px; margin-top: 20px; max-height: 400px; overflow-y: auto; }}
            .version {{ font-size: 0.8em; color: #666; margin-top: 20px; }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>๐๏ธ STT GPU Service Python v4 - Cache Fixed</h1>
            <p>Real-time WebSocket speech transcription with Moshi PyTorch implementation</p>
            <div class="status success">
                <h3>โ Fixed Issues</h3>
                <ul>
                    <li>โ Cache directory permissions (/.cache โ /app/hf_cache)</li>
                    <li>โ Moshi package installation (GitHub repository)</li>
                    <li>โ Dependency conflicts (numpy>=1.26.0)</li>
                    <li>โ FastAPI lifespan handlers</li>
                    <li>โ OpenMP configuration</li>
                </ul>
            </div>
            <div class="status warning">
                <h3>๐ง Progress Status</h3>
                <p>๐ฏ <strong>Almost there!</strong> Moshi models should now load properly with writable cache directory.</p>
                <p>๐ <strong>Latest:</strong> Fixed cache permissions - HF models can now download properly.</p>
            </div>
            <div class="status info">
                <h3>๐ Moshi WebSocket Streaming Test</h3>
                <button onclick="startWebSocket()">Connect WebSocket</button>
                <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
                <button onclick="testHealth()" class="success">Test Health</button>
                <button onclick="clearOutput()" class="warning">Clear Output</button>
                <p>Status: <span id="wsStatus">Disconnected</span></p>
                <p><small>Expected: 24kHz audio chunks (80ms = ~1920 samples)</small></p>
            </div>
            <div id="output">
                <p>Moshi transcription output will appear here...</p>
            </div>
            <div class="version">
                v{VERSION} (SHA: {COMMIT_SHA}) - Cache Fixed Moshi STT Implementation
            </div>
        </div>
        <script>
            let ws = null;
            function startWebSocket() {{
                const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
                const wsUrl = `${{protocol}}//${{window.location.host}}/ws/stream`;
                ws = new WebSocket(wsUrl);
                ws.onopen = function(event) {{
                    document.getElementById('wsStatus').textContent = 'Connected to Moshi STT (Cache Fixed)';
                    document.querySelector('button').disabled = true;
                    document.getElementById('stopBtn').disabled = false;
                    // Send test audio data (1920 samples = 80ms at 24kHz)
                    // Generate a simple test audio signal (sine wave)
                    const testAudio = [];
                    for (let i = 0; i < 1920; i++) {{
                        testAudio.push(Math.sin(2 * Math.PI * 440 * i / 24000) * 0.1); // 440Hz sine wave
                    }}
                    ws.send(JSON.stringify({{
                        type: 'audio_chunk',
                        data: testAudio,
                        sample_rate: 24000,
                        timestamp: Date.now()
                    }}));
                }};
                ws.onmessage = function(event) {{
                    const data = JSON.parse(event.data);
                    const output = document.getElementById('output');
                    output.innerHTML += `<p style="margin: 5px 0; padding: 8px; background: #e9ecef; border-radius: 4px; border-left: 3px solid #28a745;"><small>${{new Date().toLocaleTimeString()}}</small><br>${{JSON.stringify(data, null, 2)}}</p>`;
                    output.scrollTop = output.scrollHeight;
                }};
                ws.onclose = function(event) {{
                    document.getElementById('wsStatus').textContent = 'Disconnected';
                    document.querySelector('button').disabled = false;
                    document.getElementById('stopBtn').disabled = true;
                }};
                ws.onerror = function(error) {{
                    const output = document.getElementById('output');
                    output.innerHTML += `<p style="color: red; padding: 8px; background: #f8d7da; border-radius: 4px;">WebSocket Error: ${{error}}</p>`;
                }};
            }}
            function stopWebSocket() {{
                if (ws) {{
                    ws.close();
                }}
            }}
            function testHealth() {{
                fetch('/health')
                    .then(response => response.json())
                    .then(data => {{
                        const output = document.getElementById('output');
                        output.innerHTML += `<p style="margin: 5px 0; padding: 8px; background: #d1ecf1; border-radius: 4px; border-left: 3px solid #17a2b8;"><strong>Health Check:</strong><br>${{JSON.stringify(data, null, 2)}}</p>`;
                        output.scrollTop = output.scrollHeight;
                    }})
                    .catch(error => {{
                        const output = document.getElementById('output');
                        output.innerHTML += `<p style="color: red; padding: 8px; background: #f8d7da; border-radius: 4px;">Health Check Error: ${{error}}</p>`;
                    }});
            }}
            function clearOutput() {{
                document.getElementById('output').innerHTML = '<p>Output cleared...</p>';
            }}
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
# Fix: the route decorator was missing; the path matches the frontend JS
# (`${protocol}//${host}/ws/stream`).
@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time Moshi STT streaming.

    Protocol: the client sends JSON messages of type "audio_chunk"
    (``data`` as a float list or base64 string, optional ``sample_rate``)
    or "ping"; the server replies with "transcription" / "pong" / "error"
    JSON messages.
    """
    await websocket.accept()
    logger.info("Moshi WebSocket connection established (cache fixed)")
    try:
        # Send initial connection confirmation
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "Moshi STT WebSocket ready (Cache directory fixed)",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920,  # 80ms at 24kHz
            "model": "Moshi PyTorch implementation (Cache Fixed)",
            "version": VERSION,
            "cache_status": "writable"
        })
        while True:
            # Receive audio data
            data = await websocket.receive_json()
            if data.get("type") == "audio_chunk":
                try:
                    # Extract audio data from WebSocket message
                    audio_data = data.get("data")
                    sample_rate = data.get("sample_rate", 24000)
                    if audio_data is not None:
                        # Convert audio data to numpy array if it's a list
                        if isinstance(audio_data, list):
                            audio_array = np.array(audio_data, dtype=np.float32)
                        elif isinstance(audio_data, str):
                            # Handle base64 encoded audio data
                            import base64
                            audio_bytes = base64.b64decode(audio_data)
                            audio_array = np.frombuffer(audio_bytes, dtype=np.float32)
                        else:
                            # Handle other formats
                            audio_array = np.array(audio_data, dtype=np.float32)
                        # Fix: transcription is blocking (torch inference);
                        # run it off the event loop so pings and other
                        # connections are not starved while it runs.
                        transcription = await asyncio.to_thread(
                            transcribe_audio_moshi, audio_array, sample_rate
                        )
                        # Send real transcription result
                        await websocket.send_json({
                            "type": "transcription",
                            "text": transcription,
                            "timestamp": time.time(),
                            "chunk_id": data.get("timestamp"),
                            "confidence": 0.95 if not transcription.startswith("Mock") else 0.5,
                            "model": "moshi_real_processing",
                            "version": VERSION,
                            "audio_samples": len(audio_array),
                            "sample_rate": sample_rate
                        })
                    else:
                        # No audio data provided
                        await websocket.send_json({
                            "type": "error",
                            "message": "No audio data provided in chunk",
                            "timestamp": time.time(),
                            "expected_format": "audio_data as list/array or base64 string"
                        })
                except Exception as e:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Cache-fixed Moshi processing error: {str(e)}",
                        "timestamp": time.time(),
                        "version": VERSION
                    })
            elif data.get("type") == "ping":
                # Respond to ping
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                    "model": "moshi_cache_fixed",
                    "version": VERSION
                })
    except WebSocketDisconnect:
        logger.info("Moshi WebSocket connection closed (cache fixed)")
    except Exception as e:
        logger.error(f"Moshi WebSocket error (cache fixed): {e}")
        try:
            await websocket.close(code=1011, reason=f"Cache-fixed Moshi server error: {str(e)}")
        except Exception:
            # Socket may already be closed/broken; nothing more to do.
            pass
# NOTE(review): the route decorator was not present in the reviewed source
# (every @-line appears to have been stripped); the path below is assumed --
# confirm against the original repository / API clients.
@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """REST API endpoint for testing Moshi STT.

    Currently returns a mock transcription echoing the first 50 characters
    of the supplied ``audio_file`` string; raises 400 when it is missing.
    """
    if not audio_file:
        raise HTTPException(status_code=400, detail="No audio data provided")
    # Mock transcription
    result = {
        "transcription": f"Cache-fixed Moshi STT API transcription for: {audio_file[:50]}...",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
        "model": "moshi_cache_fixed",
        "expected_sample_rate": "24kHz",
        "cache_status": "writable"
    }
    return result
if __name__ == "__main__":
    # Run the server - disable reload to prevent restart loop
    # Port 7860 is the standard HuggingFace Spaces port; "app:app" implies
    # this file is named app.py.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
        reload=False
    )