# SmokeScan - models/loader.py
"""Model loading with mock/real switching based on environment.
Supports two loading modes:
- MOCK_MODELS=true: Loads mock models (fast, for local dev on RTX 4090)
- MOCK_MODELS=false: Loads all real models at startup (~38-43GB total)
Memory Strategy (Simultaneous Loading for 4xL4 GPUs with 88GB total):
- Vision 30B-A3B FP8 via vLLM: ~30-35GB
- Embedding 2B: ~4GB
- Reranker 2B: ~4GB
- Total: ~38-43GB, leaving ~45GB+ headroom
"""
import logging
import time
from typing import Union

from config.settings import settings

logger = logging.getLogger(__name__)
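
# Assumed shape of the settings object (defined in config.settings, not shown
# here; the attribute names below are inferred from their use in this module):
#   settings.mock_models     -> bool, parsed from the MOCK_MODELS env var
#   settings.vision_model    -> str, id of the 30B-A3B FP8 vision model
#   settings.embedding_model -> str, id of the 2B embedding model
#   settings.reranker_model  -> str, id of the 2B reranker model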
# Type alias for the model stack (mock or real)
ModelStack = Union["MockModelStack", "RealModelStack"]  # noqa: F821

# Lazy singleton holding the loaded model stack
_model_stack: ModelStack | None = None


def get_model_stack() -> ModelStack:
    """Build a model stack based on environment configuration.

    For mock models: loads mock models immediately (fast, for local dev).
    For real models: loads all 3 models at startup (~38-43GB total).

    This is a factory that always builds a new stack; most callers should
    go through get_models(), which caches the result.
    """
    start_time = time.time()

    if settings.mock_models:
        logger.info("Loading MOCK model stack (development mode)")
        from models.mock import MockModelStack

        stack = MockModelStack().load_all()
        elapsed = time.time() - start_time
        logger.info(f"Mock model stack loaded in {elapsed:.2f}s")
        return stack
    else:
        logger.info("Loading REAL model stack (production mode)")
        logger.info(f"Vision model: {settings.vision_model} (FP8 via vLLM)")
        logger.info(f"Embedding model: {settings.embedding_model}")
        logger.info(f"Reranker model: {settings.reranker_model}")
        from models.real import RealModelStack

        # Load all models at startup (simultaneous loading)
        stack = RealModelStack().load_all()
        elapsed = time.time() - start_time
        logger.info(f"Real model stack loaded in {elapsed:.2f}s")
        return stack


def get_models() -> ModelStack:
    """Get or create the singleton model stack.

    Returns the fully loaded model stack (all models ready for inference).
    """
    global _model_stack
    if _model_stack is None:
        logger.debug("Model stack not initialized, creating new stack")
        _model_stack = get_model_stack()
    else:
        logger.debug("Returning cached model stack")
    return _model_stack
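
# Typical call pattern (a sketch; the calling application is not part of this
# module): pay the load cost once at startup, then reuse the cached stack.
#
#   stack = get_models()   # first call loads the mock or real models
#   stack = get_models()   # later calls return the same cached stack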


def reset_models() -> None:
    """Reset the model stack singleton (useful for testing)."""
    global _model_stack
    _model_stack = None
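

# The block below is an optional usage sketch rather than part of the loader
# API: a quick manual smoke check of the caching and reset behaviour.
if __name__ == "__main__":
    # Run with MOCK_MODELS=true so nothing heavy is loaded; assumes
    # models.mock is importable in this environment.
    logging.basicConfig(level=logging.DEBUG)
    first = get_models()
    assert get_models() is first, "repeat calls should return the cached stack"
    reset_models()
    get_models()  # rebuilds the stack after the cache was cleared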