# Author: Peter Michael Gits
# Change: Use official HuggingFace documentation approach to override model detection
# Commit: 2a50bb0
import asyncio
import json
import time
import logging
import os
from typing import Optional
from contextlib import asynccontextmanager
# CRITICAL: Set OMP_NUM_THREADS before any torch/numpy imports.
# HuggingFace is overriding our Dockerfile ENV with CPU_CORES value.
os.environ['OMP_NUM_THREADS'] = '1'
# Point every HuggingFace cache variable at a writable in-container path
# (the default home-directory cache is not writable in this deployment).
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases in favor of HF_HOME — presumably kept here for compatibility.
os.environ['HF_HOME'] = '/app/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/hf_cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/hf_cache'
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
import uvicorn
# Version tracking (reported by /health and embedded in the test page)
VERSION = "2.0.3"
COMMIT_SHA = "TBD"  # placeholder value; not populated in this build
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create cache directory if it doesn't exist
os.makedirs('/app/hf_cache', exist_ok=True)
# Global Moshi model handles, populated by load_moshi_models().
# Each is None (not loaded yet) or the sentinel string "mock" (load failed).
mimi = None     # Mimi neural audio codec
moshi = None    # Moshi language model
lm_gen = None   # LMGen wrapper around the language model
device = None   # "cuda" or "cpu"
def _enter_mock_mode() -> bool:
    """Degrade the service to mock mode after a model-loading failure.

    Sets all model globals to the sentinel string "mock" so that the
    transcription path returns placeholder text instead of crashing.

    Returns:
        bool: Always False, so callers can propagate "models not loaded".
    """
    global mimi, moshi, lm_gen
    mimi = "mock"
    moshi = "mock"
    lm_gen = "mock"
    return False


async def load_moshi_models():
    """Load the Mimi codec and Moshi language model on startup.

    Tries the GPU first; on a CUDA out-of-memory error both models are
    reloaded on CPU so they stay on a single, consistent device.  On any
    failure (import, download, or load error) the service enters mock mode
    instead of failing startup.

    Returns:
        bool: True if the real models loaded, False if mock mode was entered.
    """
    global mimi, moshi, lm_gen, device
    try:
        logger.info("Loading Moshi models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        logger.info(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
        # Clear GPU memory and set memory management
        if device == "cuda":
            torch.cuda.empty_cache()
            # NOTE: this call *disables* the flash scaled-dot-product-attention
            # kernel; the fallback kernels used instead are slower but have
            # more predictable memory behavior on this hardware.
            torch.backends.cuda.enable_flash_sdp(False)
            logger.info(f"GPU memory before loading: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        try:
            from huggingface_hub import hf_hub_download
            from moshi.models import loaders, LMGen
            # Mimi is the neural audio codec that turns waveforms into tokens.
            logger.info("Loading Mimi audio codec...")
            mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME, cache_dir='/app/hf_cache')
            mimi = loaders.get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(8)  # Limited to 8 for Moshi
            logger.info("โœ… Mimi loaded successfully")
            # Clear cache after Mimi loading
            if device == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"GPU memory after Mimi: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
            # Moshi is the full language model that produces text tokens.
            logger.info("Loading Moshi language model...")
            moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME, cache_dir='/app/hf_cache')
            # Try loading with memory-efficient settings
            try:
                moshi = loaders.get_moshi_lm(moshi_weight, device=device)
                lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                logger.info("โœ… Moshi loaded successfully on GPU")
            except RuntimeError as cuda_error:
                if "CUDA out of memory" in str(cuda_error):
                    logger.warning(f"Moshi CUDA out of memory, trying CPU fallback: {cuda_error}")
                    # Reload Mimi on CPU as well so both models share a device.
                    mimi = loaders.get_mimi(mimi_weight, device="cpu")
                    mimi.set_num_codebooks(8)
                    device = "cpu"
                    moshi = loaders.get_moshi_lm(moshi_weight, device="cpu")
                    lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                    logger.info("โœ… Moshi loaded successfully on CPU (fallback)")
                    logger.info("โœ… Mimi also moved to CPU for device consistency")
                else:
                    raise
            logger.info("๐ŸŽ‰ All Moshi models loaded successfully!")
            return True
        except ImportError as import_error:
            logger.error(f"Moshi import failed: {import_error}")
            return _enter_mock_mode()
        except Exception as model_error:
            logger.error(f"Failed to load Moshi models: {model_error}")
            return _enter_mock_mode()
    except Exception as e:
        logger.error(f"Error in load_moshi_models: {e}")
        return _enter_mock_mode()
def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe an audio chunk using the Moshi models.

    Args:
        audio_data: 1-D waveform samples (any numeric dtype; coerced to
            float32 before tensor conversion).
        sample_rate: Sample rate of ``audio_data``; resampled to 24 kHz
            (Moshi's expected rate) when different.

    Returns:
        A transcription/status string.  This function never raises: failures
        are returned as "Error: ..." strings so the WebSocket loop survives.
    """
    try:
        logger.info(f"๐ŸŽ™๏ธ Starting transcription - Audio length: {len(audio_data)} samples at {sample_rate}Hz")
        # Guard: models not loaded yet (startup still running or failed early).
        # Without this, attribute access on None would raise below.
        if mimi is None:
            return "Error: Moshi models are not loaded yet"
        if mimi == "mock":
            duration = len(audio_data) / sample_rate
            return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"
        # Ensure 24kHz audio for Moshi
        if sample_rate != 24000:
            import librosa
            logger.info(f"๐Ÿ”„ Resampling from {sample_rate}Hz to 24000Hz")
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)
        # Determine actual device of the models (might have fallen back to CPU)
        model_device = next(mimi.parameters()).device if hasattr(mimi, 'parameters') else device
        logger.info(f"Using device for transcription: {model_device}")
        # Convert to a [B=1, C=1, T] float32 tensor on the models' device.
        # astype(..., copy=True) both forces float32 (the models are float32;
        # float64 input would otherwise dtype-mismatch) and copies the buffer,
        # avoiding PyTorch's non-writable-array warning.
        wav = torch.from_numpy(audio_data.astype(np.float32, copy=True)).unsqueeze(0).unsqueeze(0).to(model_device)
        logger.info(f"๐Ÿ“Š Tensor shape: {wav.shape}, device: {wav.device}")
        # Process with Mimi codec in streaming mode, frame by frame.
        logger.info("๐Ÿ”ง Starting Mimi audio encoding...")
        with torch.no_grad(), mimi.streaming(batch_size=1):
            all_codes = []
            frame_size = mimi.frame_size
            logger.info(f"๐Ÿ“ Frame size: {frame_size}")
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset: offset + frame_size]
                if frame.shape[-1] == 0:
                    break
                # Pad last frame if needed so Mimi always sees full frames.
                if frame.shape[-1] < frame_size:
                    padding = frame_size - frame.shape[-1]
                    frame = torch.nn.functional.pad(frame, (0, padding))
                codes = mimi.encode(frame)
                all_codes.append(codes)
            logger.info(f"๐ŸŽต Encoded {len(all_codes)} audio frames")
            # Concatenate all codes along the time axis.
            if all_codes:
                audio_tokens = torch.cat(all_codes, dim=-1)
                logger.info(f"๐Ÿ”— Audio tokens shape: {audio_tokens.shape}")
                # Generate text with the Moshi language model.
                logger.info("๐Ÿง  Starting Moshi text generation...")
                with torch.no_grad():
                    try:
                        # Use the actual language model for generation
                        if lm_gen and lm_gen != "mock":
                            logger.info(f"๐Ÿ”ง LMGen type: {type(lm_gen)}")
                            logger.info(f"๐Ÿ”ง LMGen methods: {[m for m in dir(lm_gen) if not m.startswith('_')]}")
                            # Diagnostic path: probe whether LMGen.step() needs
                            # the streaming context to produce tokens.
                            try:
                                logger.info("๐Ÿงช Trying step() without streaming context...")
                                code_step = audio_tokens[:, :, 0:1]  # Just first timestep [B, 8, 1]
                                tokens_out = lm_gen.step(code_step)
                                logger.info(f"๐Ÿ” Direct step result: {type(tokens_out)}, value: {tokens_out}")
                                if tokens_out is None:
                                    # Try with streaming context
                                    logger.info("๐Ÿงช Trying with streaming context...")
                                    with lm_gen.streaming(1):
                                        tokens_out = lm_gen.step(code_step)
                                    logger.info(f"๐Ÿ” Streaming step result: {type(tokens_out)}, value: {tokens_out}")
                                if tokens_out is None:
                                    # Both call styles produced nothing; dump state for debugging.
                                    logger.error("๐Ÿšจ Both approaches returned None - checking LMGen state")
                                    logger.info(f"๐Ÿ”ง LMGen attributes: {vars(lm_gen) if hasattr(lm_gen, '__dict__') else 'No __dict__'}")
                                    text_output = "Moshiko: LMGen step() returns None - API issue"
                                else:
                                    logger.info(f"โœ… Got tokens! Shape: {tokens_out.shape if hasattr(tokens_out, 'shape') else 'No shape'}")
                                    text_output = f"Moshiko CPU: Successfully generated tokens with shape {tokens_out.shape if hasattr(tokens_out, 'shape') else 'unknown'}"
                            except Exception as step_error:
                                logger.error(f"๐Ÿšจ LMGen step error: {step_error}")
                                text_output = f"Moshiko: LMGen step error: {str(step_error)}"
                        else:
                            text_output = "Moshiko fallback: LM generator not available"
                            logger.warning("โš ๏ธ LM generator not available, using fallback")
                        return text_output
                    except Exception as gen_error:
                        logger.error(f"โŒ Text generation failed: {gen_error}")
                        return f"Moshiko encoding successful but text generation failed: {str(gen_error)}"
            logger.warning("โš ๏ธ No audio tokens were generated")
            return "No audio tokens generated"
    except Exception as e:
        logger.error(f"Moshi transcription error: {e}")
        return f"Error: {str(e)}"
# Use lifespan instead of deprecated on_event
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: load Moshi models before serving requests.

    Args:
        app: The FastAPI application instance (required by the lifespan
            protocol; unused here).
    """
    # Startup
    await load_moshi_models()
    yield
    # Shutdown (if needed)
# FastAPI app with lifespan (models are loaded once at startup, see lifespan)
app = FastAPI(
    title="STT GPU Service Python v4 - Full Moshi Model",
    description="Real-time WebSocket STT streaming with full Moshi PyTorch implementation (L4 GPU with 30GB VRAM)",
    version=VERSION,
    lifespan=lifespan
)
@app.get("/health")
async def health_check():
    """Report service liveness, model load state, and cache status."""
    payload = {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "Moshi STT WebSocket Service - Full model on L4 GPU",
        "space_name": "stt-gpu-service-python-v4",
    }
    # A model counts as loaded only when it is a real object — neither the
    # initial None nor the "mock" sentinel set by the fallback path.
    payload["mimi_loaded"] = mimi is not None and mimi != "mock"
    payload["moshi_loaded"] = moshi is not None and moshi != "mock"
    payload["device"] = str(device) if device else "unknown"
    payload["expected_sample_rate"] = "24000Hz"
    payload["cache_dir"] = "/app/hf_cache"
    payload["cache_status"] = "writable"
    return payload
@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Serve a self-contained HTML test page for the service.

    The page can open the /ws/stream WebSocket, send one 80 ms 440 Hz test
    chunk (1920 samples at 24 kHz), hit /health, and display responses.
    Interpolates VERSION and COMMIT_SHA; all literal braces in the HTML/JS
    are doubled because this is an f-string.
    """
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>STT GPU Service Python v4 - Cache Fixed</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 40px; }}
            .container {{ max-width: 800px; margin: 0 auto; }}
            .status {{ background: #f0f0f0; padding: 20px; border-radius: 8px; margin: 20px 0; }}
            .success {{ background: #d4edda; border-left: 4px solid #28a745; }}
            .info {{ background: #d1ecf1; border-left: 4px solid #17a2b8; }}
            .warning {{ background: #fff3cd; border-left: 4px solid #ffc107; }}
            button {{ padding: 10px 20px; margin: 5px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; }}
            button:disabled {{ background: #ccc; }}
            button.success {{ background: #28a745; }}
            button.warning {{ background: #ffc107; color: #212529; }}
            #output {{ background: #f8f9fa; padding: 15px; border-radius: 4px; margin-top: 20px; max-height: 400px; overflow-y: auto; }}
            .version {{ font-size: 0.8em; color: #666; margin-top: 20px; }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>๐ŸŽ™๏ธ STT GPU Service Python v4 - Cache Fixed</h1>
            <p>Real-time WebSocket speech transcription with Moshi PyTorch implementation</p>
            <div class="status success">
                <h3>โœ… Fixed Issues</h3>
                <ul>
                    <li>โœ… Cache directory permissions (/.cache โ†’ /app/hf_cache)</li>
                    <li>โœ… Moshi package installation (GitHub repository)</li>
                    <li>โœ… Dependency conflicts (numpy>=1.26.0)</li>
                    <li>โœ… FastAPI lifespan handlers</li>
                    <li>โœ… OpenMP configuration</li>
                </ul>
            </div>
            <div class="status warning">
                <h3>๐Ÿ”ง Progress Status</h3>
                <p>๐ŸŽฏ <strong>Almost there!</strong> Moshi models should now load properly with writable cache directory.</p>
                <p>๐Ÿ“Š <strong>Latest:</strong> Fixed cache permissions - HF models can now download properly.</p>
            </div>
            <div class="status info">
                <h3>๐Ÿ”— Moshi WebSocket Streaming Test</h3>
                <button onclick="startWebSocket()">Connect WebSocket</button>
                <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
                <button onclick="testHealth()" class="success">Test Health</button>
                <button onclick="clearOutput()" class="warning">Clear Output</button>
                <p>Status: <span id="wsStatus">Disconnected</span></p>
                <p><small>Expected: 24kHz audio chunks (80ms = ~1920 samples)</small></p>
            </div>
            <div id="output">
                <p>Moshi transcription output will appear here...</p>
            </div>
            <div class="version">
                v{VERSION} (SHA: {COMMIT_SHA}) - Cache Fixed Moshi STT Implementation
            </div>
        </div>
        <script>
            let ws = null;
            function startWebSocket() {{
                const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
                const wsUrl = `${{protocol}}//${{window.location.host}}/ws/stream`;
                ws = new WebSocket(wsUrl);
                ws.onopen = function(event) {{
                    document.getElementById('wsStatus').textContent = 'Connected to Moshi STT (Cache Fixed)';
                    document.querySelector('button').disabled = true;
                    document.getElementById('stopBtn').disabled = false;
                    // Send test audio data (1920 samples = 80ms at 24kHz)
                    // Generate a simple test audio signal (sine wave)
                    const testAudio = [];
                    for (let i = 0; i < 1920; i++) {{
                        testAudio.push(Math.sin(2 * Math.PI * 440 * i / 24000) * 0.1); // 440Hz sine wave
                    }}
                    ws.send(JSON.stringify({{
                        type: 'audio_chunk',
                        data: testAudio,
                        sample_rate: 24000,
                        timestamp: Date.now()
                    }}));
                }};
                ws.onmessage = function(event) {{
                    const data = JSON.parse(event.data);
                    const output = document.getElementById('output');
                    output.innerHTML += `<p style="margin: 5px 0; padding: 8px; background: #e9ecef; border-radius: 4px; border-left: 3px solid #28a745;"><small>${{new Date().toLocaleTimeString()}}</small><br>${{JSON.stringify(data, null, 2)}}</p>`;
                    output.scrollTop = output.scrollHeight;
                }};
                ws.onclose = function(event) {{
                    document.getElementById('wsStatus').textContent = 'Disconnected';
                    document.querySelector('button').disabled = false;
                    document.getElementById('stopBtn').disabled = true;
                }};
                ws.onerror = function(error) {{
                    const output = document.getElementById('output');
                    output.innerHTML += `<p style="color: red; padding: 8px; background: #f8d7da; border-radius: 4px;">WebSocket Error: ${{error}}</p>`;
                }};
            }}
            function stopWebSocket() {{
                if (ws) {{
                    ws.close();
                }}
            }}
            function testHealth() {{
                fetch('/health')
                    .then(response => response.json())
                    .then(data => {{
                        const output = document.getElementById('output');
                        output.innerHTML += `<p style="margin: 5px 0; padding: 8px; background: #d1ecf1; border-radius: 4px; border-left: 3px solid #17a2b8;"><strong>Health Check:</strong><br>${{JSON.stringify(data, null, 2)}}</p>`;
                        output.scrollTop = output.scrollHeight;
                    }})
                    .catch(error => {{
                        const output = document.getElementById('output');
                        output.innerHTML += `<p style="color: red; padding: 8px; background: #f8d7da; border-radius: 4px;">Health Check Error: ${{error}}</p>`;
                    }});
            }}
            function clearOutput() {{
                document.getElementById('output').innerHTML = '<p>Output cleared...</p>';
            }}
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time Moshi STT streaming.

    Protocol (JSON messages):
      - client -> server: {"type": "audio_chunk", "data": <float list |
        base64 str>, "sample_rate": int, "timestamp": any}
      - client -> server: {"type": "ping"}
      - server -> client: "connection", "transcription", "error", "pong"
        messages; unknown message types are silently ignored.

    Args:
        websocket: The accepted FastAPI WebSocket connection.
    """
    await websocket.accept()
    logger.info("Moshi WebSocket connection established (cache fixed)")
    try:
        # Send initial connection confirmation with the expected audio format.
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "Moshi STT WebSocket ready (Cache directory fixed)",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920,  # 80ms at 24kHz
            "model": "Moshi PyTorch implementation (Cache Fixed)",
            "version": VERSION,
            "cache_status": "writable"
        })
        while True:
            # Receive the next client message (blocks until one arrives).
            data = await websocket.receive_json()
            if data.get("type") == "audio_chunk":
                try:
                    # Extract audio data from WebSocket message
                    audio_data = data.get("data")
                    sample_rate = data.get("sample_rate", 24000)
                    if audio_data is not None:
                        # Accept either a JSON list of floats or a base64
                        # string of raw float32 bytes.
                        if isinstance(audio_data, list):
                            audio_array = np.array(audio_data, dtype=np.float32)
                        elif isinstance(audio_data, str):
                            # Handle base64 encoded audio data
                            import base64
                            audio_bytes = base64.b64decode(audio_data)
                            audio_array = np.frombuffer(audio_bytes, dtype=np.float32)
                        else:
                            # Handle other formats
                            audio_array = np.array(audio_data, dtype=np.float32)
                        # Process audio chunk with actual Moshi transcription
                        transcription = transcribe_audio_moshi(audio_array, sample_rate)
                        # Send real transcription result
                        await websocket.send_json({
                            "type": "transcription",
                            "text": transcription,
                            "timestamp": time.time(),
                            "chunk_id": data.get("timestamp"),
                            "confidence": 0.95 if not transcription.startswith("Mock") else 0.5,
                            "model": "moshi_real_processing",
                            "version": VERSION,
                            "audio_samples": len(audio_array),
                            "sample_rate": sample_rate
                        })
                    else:
                        # No audio data provided
                        await websocket.send_json({
                            "type": "error",
                            "message": "No audio data provided in chunk",
                            "timestamp": time.time(),
                            "expected_format": "audio_data as list/array or base64 string"
                        })
                except Exception as e:
                    # Per-chunk failures are reported to the client; the
                    # connection itself stays open.
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Cache-fixed Moshi processing error: {str(e)}",
                        "timestamp": time.time(),
                        "version": VERSION
                    })
            elif data.get("type") == "ping":
                # Respond to ping
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                    "model": "moshi_cache_fixed",
                    "version": VERSION
                })
    except WebSocketDisconnect:
        logger.info("Moshi WebSocket connection closed (cache fixed)")
    except Exception as e:
        logger.error(f"Moshi WebSocket error (cache fixed): {e}")
        try:
            # Best-effort close: the socket may already be dead, and the
            # close-frame reason must stay under the 123-byte limit imposed
            # by RFC 6455, so the message is truncated and any failure here
            # is swallowed rather than masking the original error.
            await websocket.close(code=1011, reason=f"Cache-fixed Moshi server error: {str(e)}"[:120])
        except Exception:
            pass
@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """REST endpoint returning a mock Moshi STT transcription for testing.

    Args:
        audio_file: Arbitrary audio identifier/data string; only its first
            50 characters are echoed back.

    Raises:
        HTTPException: 400 when no audio data is supplied.
    """
    if not audio_file:
        raise HTTPException(status_code=400, detail="No audio data provided")
    # Mock transcription: echo a truncated preview of the input.
    preview = audio_file[:50]
    return {
        "transcription": f"Cache-fixed Moshi STT API transcription for: {preview}...",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
        "model": "moshi_cache_fixed",
        "expected_sample_rate": "24kHz",
        "cache_status": "writable",
    }
if __name__ == "__main__":
    # Reload is disabled deliberately: auto-reload caused a restart loop
    # in this deployment.
    server_options = dict(
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
        reload=False,
    )
    uvicorn.run("app:app", **server_options)