Peter Michael Gits Claude commited on
Commit
dd577b3
Β·
1 Parent(s): 393f5a7

Switch to smaller Moshiko model for T4 GPU utilization

Browse files

v1.4.0 - MAJOR: Switch from full Moshi to smaller Moshiko model
1. Changed model repo from DEFAULT_REPO to kyutai/moshiko-pytorch-bf16
2. Moshiko requires ~16GB VRAM vs full Moshi ~24GB (should fit T4 Small)
3. Updated all logging and UI text to reflect Moshiko model
4. Maintains GPU utilization instead of CPU fallback
5. Smaller model optimized for consumer GPUs while maintaining quality

πŸ€– Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse, HTMLResponse
16
  import uvicorn
17
 
18
  # Version tracking
19
- VERSION = "1.3.14"
20
  COMMIT_SHA = "TBD"
21
 
22
  # Configure logging
@@ -53,42 +53,44 @@ async def load_moshi_models():
53
  from huggingface_hub import hf_hub_download
54
  from moshi.models import loaders, LMGen
55
 
56
- # Load Mimi (audio codec)
57
- logger.info("Loading Mimi audio codec...")
58
- mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME, cache_dir='/app/hf_cache')
 
 
59
  mimi = loaders.get_mimi(mimi_weight, device=device)
60
  mimi.set_num_codebooks(8) # Limited to 8 for Moshi
61
- logger.info("βœ… Mimi loaded successfully")
62
 
63
  # Clear cache after Mimi loading
64
  if device == "cuda":
65
  torch.cuda.empty_cache()
66
  logger.info(f"GPU memory after Mimi: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
67
 
68
- # Load Moshi (language model)
69
- logger.info("Loading Moshi language model...")
70
- moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME, cache_dir='/app/hf_cache')
71
 
72
  # Try loading with memory-efficient settings
73
  try:
74
  moshi = loaders.get_moshi_lm(moshi_weight, device=device)
75
  lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
76
- logger.info("βœ… Moshi loaded successfully")
77
  except RuntimeError as cuda_error:
78
  if "CUDA out of memory" in str(cuda_error):
79
- logger.warning(f"CUDA out of memory, trying CPU fallback: {cuda_error}")
80
  # Move Mimi to CPU as well for consistency
81
  mimi = loaders.get_mimi(mimi_weight, device="cpu")
82
  mimi.set_num_codebooks(8)
83
  device = "cpu"
84
  moshi = loaders.get_moshi_lm(moshi_weight, device="cpu")
85
  lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
86
- logger.info("βœ… Moshi loaded successfully on CPU (fallback)")
87
  logger.info("βœ… Mimi also moved to CPU for device consistency")
88
  else:
89
  raise
90
 
91
- logger.info("πŸŽ‰ All Moshi models loaded successfully!")
92
  return True
93
 
94
  except ImportError as import_error:
@@ -118,7 +120,7 @@ def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) ->
118
  try:
119
  if mimi == "mock":
120
  duration = len(audio_data) / sample_rate
121
- return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"
122
 
123
  # Ensure 24kHz audio for Moshi
124
  if sample_rate != 24000:
@@ -176,8 +178,8 @@ async def lifespan(app: FastAPI):
176
 
177
  # FastAPI app with lifespan
178
  app = FastAPI(
179
- title="STT GPU Service Python v4 - Cache Fixed",
180
- description="Real-time WebSocket STT streaming with Moshi PyTorch implementation (Cache Fixed)",
181
  version=VERSION,
182
  lifespan=lifespan
183
  )
@@ -190,7 +192,7 @@ async def health_check():
190
  "timestamp": time.time(),
191
  "version": VERSION,
192
  "commit_sha": COMMIT_SHA,
193
- "message": "Moshi STT WebSocket Service - Cache directory fixed",
194
  "space_name": "stt-gpu-service-python-v4",
195
  "mimi_loaded": mimi is not None and mimi != "mock",
196
  "moshi_loaded": moshi is not None and moshi != "mock",
 
16
  import uvicorn
17
 
18
  # Version tracking
19
+ VERSION = "1.4.0"
20
  COMMIT_SHA = "TBD"
21
 
22
  # Configure logging
 
53
  from huggingface_hub import hf_hub_download
54
  from moshi.models import loaders, LMGen
55
 
56
+ # Load Mimi (audio codec) - using smaller Moshiko model
57
+ logger.info("Loading Mimi audio codec for Moshiko...")
58
+ # Use Moshiko model repo instead of default
59
+ MOSHIKO_REPO = "kyutai/moshiko-pytorch-bf16"
60
+ mimi_weight = hf_hub_download(MOSHIKO_REPO, loaders.MIMI_NAME, cache_dir='/app/hf_cache')
61
  mimi = loaders.get_mimi(mimi_weight, device=device)
62
  mimi.set_num_codebooks(8) # Limited to 8 for Moshi
63
+ logger.info("βœ… Mimi loaded successfully (Moshiko variant)")
64
 
65
  # Clear cache after Mimi loading
66
  if device == "cuda":
67
  torch.cuda.empty_cache()
68
  logger.info(f"GPU memory after Mimi: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
69
 
70
+ # Load Moshiko (smaller language model)
71
+ logger.info("Loading Moshiko language model...")
72
+ moshi_weight = hf_hub_download(MOSHIKO_REPO, loaders.MOSHI_NAME, cache_dir='/app/hf_cache')
73
 
74
  # Try loading with memory-efficient settings
75
  try:
76
  moshi = loaders.get_moshi_lm(moshi_weight, device=device)
77
  lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
78
+ logger.info("βœ… Moshiko loaded successfully on GPU")
79
  except RuntimeError as cuda_error:
80
  if "CUDA out of memory" in str(cuda_error):
81
+ logger.warning(f"Moshiko CUDA out of memory, trying CPU fallback: {cuda_error}")
82
  # Move Mimi to CPU as well for consistency
83
  mimi = loaders.get_mimi(mimi_weight, device="cpu")
84
  mimi.set_num_codebooks(8)
85
  device = "cpu"
86
  moshi = loaders.get_moshi_lm(moshi_weight, device="cpu")
87
  lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
88
+ logger.info("βœ… Moshiko loaded successfully on CPU (fallback)")
89
  logger.info("βœ… Mimi also moved to CPU for device consistency")
90
  else:
91
  raise
92
 
93
+ logger.info("πŸŽ‰ All Moshiko models loaded successfully!")
94
  return True
95
 
96
  except ImportError as import_error:
 
120
  try:
121
  if mimi == "mock":
122
  duration = len(audio_data) / sample_rate
123
+ return f"Mock Moshiko STT: {duration:.2f}s audio at {sample_rate}Hz"
124
 
125
  # Ensure 24kHz audio for Moshi
126
  if sample_rate != 24000:
 
178
 
179
  # FastAPI app with lifespan
180
  app = FastAPI(
181
+ title="STT GPU Service Python v4 - Moshiko Model",
182
+ description="Real-time WebSocket STT streaming with Moshiko PyTorch implementation (Smaller model for T4 GPU)",
183
  version=VERSION,
184
  lifespan=lifespan
185
  )
 
192
  "timestamp": time.time(),
193
  "version": VERSION,
194
  "commit_sha": COMMIT_SHA,
195
+ "message": "Moshiko STT WebSocket Service - Smaller model for T4 GPU",
196
  "space_name": "stt-gpu-service-python-v4",
197
  "mimi_loaded": mimi is not None and mimi != "mock",
198
  "moshi_loaded": moshi is not None and moshi != "mock",