"""
Model caching utilities for Depth Anything 3.
Provides model caching functionality to avoid reloading model weights on every instantiation.
This significantly reduces latency for repeated model creation (2-5s gain).
"""
from __future__ import annotations

import threading
from typing import Callable, Dict, Optional, Tuple

import torch
import torch.nn as nn

from depth_anything_3.utils.logger import logger


class ModelCache:
"""
Thread-safe singleton cache for Depth Anything 3 models.
Caches loaded model weights to avoid reloading from disk on every instantiation.
Each unique combination of (model_name, device) is cached separately.
Usage:
cache = ModelCache()
model = cache.get(model_name, device, loader_fn)
# loader_fn is only called if cache miss
Thread Safety:
Uses threading.Lock to ensure thread-safe access to cache.
Memory Management:
- Models are kept in cache until explicitly cleared
- Use clear() to free memory when needed
- Use clear_device() to clear specific device models
"""
_instance: Optional["ModelCache"] = None
_lock = threading.Lock()
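    # Note: cls._lock guards singleton construction in __new__; the per-instance
    # _cache_lock created in __init__ guards access to the cached models.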
    def __new__(cls):
        """Singleton pattern to ensure single cache instance."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Initialize cache storage."""
        if self._initialized:
            return
        self._cache: Dict[Tuple[str, str], nn.Module] = {}
        self._cache_lock = threading.Lock()
        self._initialized = True
        logger.info("ModelCache initialized")

    def get(
        self,
        model_name: str,
        device: torch.device | str,
        loader_fn: Callable[[], nn.Module],
    ) -> nn.Module:
        """
        Get a cached model, or load it if it is not in the cache.

        Args:
            model_name: Name of the model (e.g., "da3-large")
            device: Target device (cuda, mps, cpu)
            loader_fn: Zero-argument function that loads and returns the model
                (an nn.Module); only called on a cache miss

        Returns:
            Cached or freshly loaded model on the specified device

        Example:
            >>> cache = ModelCache()
            >>> model = cache.get(
            ...     "da3-large",
            ...     "cuda",
            ...     lambda: create_model()
            ... )
        """
        device_str = str(device)
        cache_key = (model_name, device_str)

        with self._cache_lock:
            if cache_key in self._cache:
                logger.debug(f"Model cache HIT: {model_name} on {device_str}")
                return self._cache[cache_key]

            logger.info(f"Model cache MISS: {model_name} on {device_str}. Loading...")
            model = loader_fn()
            self._cache[cache_key] = model
            logger.info(f"Model cached: {model_name} on {device_str}")
            return model

    def clear(self) -> None:
        """
        Clear entire cache and free memory.

        Removes all cached models and forces garbage collection.
        Useful when switching between many different models.
        """
        with self._cache_lock:
            num_cached = len(self._cache)
            self._cache.clear()

            # Force garbage collection to free GPU memory
            import gc

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if hasattr(torch, "mps") and torch.backends.mps.is_available():
                torch.mps.empty_cache()

            logger.info(f"Model cache cleared ({num_cached} models removed)")

    def clear_device(self, device: torch.device | str) -> None:
        """
        Clear all models on a specific device.

        Args:
            device: Device to clear (e.g., "cuda", "mps", "cpu")

        Example:
            >>> cache = ModelCache()
            >>> cache.clear_device("cuda")  # Clear all CUDA models
        """
        device_str = str(device)

        with self._cache_lock:
            keys_to_remove = [key for key in self._cache if key[1] == device_str]
            for key in keys_to_remove:
                del self._cache[key]

            # Free device memory
            if "cuda" in device_str and torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif "mps" in device_str and hasattr(torch, "mps") and torch.backends.mps.is_available():
                torch.mps.empty_cache()

            logger.info(f"Model cache cleared for device {device_str} ({len(keys_to_remove)} models removed)")

    def get_cache_info(self) -> Dict[str, int | Dict[str, int]]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache info:
            - total: Total number of cached models
            - by_device: Number of cached models per device
        """
        with self._cache_lock:
            info = {
                "total": len(self._cache),
                "by_device": {},
            }
            for _model_name, device_str in self._cache:
                if device_str not in info["by_device"]:
                    info["by_device"][device_str] = 0
                info["by_device"][device_str] += 1
            return info


# Global singleton instance
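# Instantiated at import time, so every module that imports this file shares the
# same cache object; ModelCache.__new__ guarantees it is the only instance.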
_global_cache = ModelCache()


def get_model_cache() -> ModelCache:
    """
    Get global model cache instance.

    Returns:
        Singleton ModelCache instance

    Example:
        >>> from depth_anything_3.cache import get_model_cache
        >>> cache = get_model_cache()
        >>> cache.clear()
    """
    return _global_cache
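

# Usage sketch: how a caller might wire the cache into model construction. The
# loader below is an illustrative placeholder, not part of the library API; a
# real caller would pass a function that builds Depth Anything 3 and loads its
# weights onto the target device.
if __name__ == "__main__":
    def _build_model() -> nn.Module:
        # Hypothetical loader used only for this demonstration.
        return nn.Identity()

    cache = get_model_cache()
    first = cache.get("da3-large", "cpu", _build_model)   # miss: _build_model runs
    second = cache.get("da3-large", "cpu", _build_model)  # hit: same object returned
    assert first is second
    print(cache.get_cache_info())
    cache.clear()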