#!/usr/bin/env python3
"""
Unified Model Loader
Coordinates the separate SAM2 and MatAnyone loaders for a cleaner architecture.
Notes:
- SAM2: exposes set_image(...) and predict(...).
- MatAnyone: our loader returns a stateful callable adapter:
  - frame 0: adapter(image_rgb01, mask01); later frames: adapter(image_rgb01)
  - optional: adapter.reset() clears per-video memory
We therefore validate MatAnyone by checking callable(...) and/or the presence of
"reset", not only ".step"/".process".
"""
from __future__ import annotations
import os
import gc
import time
import logging
from typing import Optional, Dict, Any, Tuple, Callable
import torch
from core.exceptions import ModelLoadingError
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager
# Import the specialized loaders
from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader
logger = logging.getLogger(__name__)
class LoadedModel:
"""Container for loaded model information"""
def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
device: str = "", framework: str = ""):
self.model = model
self.model_id = model_id
self.load_time = load_time
self.device = device
self.framework = framework
def to_dict(self) -> Dict[str, Any]:
return {
"model_id": self.model_id,
"framework": self.framework,
"device": self.device,
"load_time": self.load_time,
"loaded": self.model is not None,
"model_type": type(self.model).__name__ if self.model is not None else None,
}
class ModelLoader:
"""Main model loader that coordinates SAM2 and MatAnyone loaders"""
def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
self.device_manager = device_mgr
self.memory_manager = memory_mgr
self.device = self.device_manager.get_optimal_device()
# Initialize specialized loaders
self.sam2_loader = SAM2Loader(device=str(self.device))
self.matanyone_loader = MatAnyoneLoader(device=str(self.device))
# Model storage
self.sam2_predictor: Optional[LoadedModel] = None
self.matanyone_model: Optional[LoadedModel] = None
# Statistics
self.loading_stats = {
"sam2_load_time": 0.0,
"matanyone_load_time": 0.0,
"total_load_time": 0.0,
"models_loaded": False,
"loading_attempts": 0,
}
logger.info(f"ModelLoader initialized for device: {self.device}")
def load_all_models(
self,
progress_callback: Optional[Callable[[float, str], None]] = None,
cancel_event=None
) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
"""
        Load all models using the specialized loaders.
        Returns:
            Tuple of (sam2 LoadedModel or None, matanyone LoadedModel or None)
"""
start_time = time.time()
self.loading_stats["loading_attempts"] += 1
try:
logger.info("Starting model loading process...")
if progress_callback:
progress_callback(0.0, "Initializing model loading...")
# Clean up any existing models
self._cleanup_models()
# -------------------- Load SAM2 -------------------- #
if progress_callback:
progress_callback(0.1, "Loading SAM2 model...")
sam2_start = time.time()
            # Force the "tiny" SAM2 model (instead of auto-detection) for faster
            # loading and lower memory usage.
            sam2_model = self.sam2_loader.load("tiny")
sam2_time = time.time() - sam2_start
if sam2_model:
self.sam2_predictor = LoadedModel(
model=sam2_model,
model_id=self.sam2_loader.model_id,
load_time=sam2_time,
device=str(self.device),
framework="sam2"
)
self.loading_stats["sam2_load_time"] = sam2_time
logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
else:
logger.warning("SAM2 loading failed")
# Cancellation check
if cancel_event and cancel_event.is_set():
if progress_callback:
progress_callback(1.0, "Model loading cancelled")
return self.sam2_predictor, None
# ----------------- Load MatAnyone ------------------ #
if progress_callback:
progress_callback(0.6, "Loading MatAnyone model...")
matanyone_start = time.time()
matanyone_model = self.matanyone_loader.load() # returns stateful callable adapter or None
matanyone_time = time.time() - matanyone_start
if matanyone_model:
self.matanyone_model = LoadedModel(
model=matanyone_model,
model_id=self.matanyone_loader.model_id,
load_time=matanyone_time,
device=str(self.device),
framework="matanyone"
)
self.loading_stats["matanyone_load_time"] = matanyone_time
logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
else:
logger.warning("MatAnyone loading failed")
# ----------------- Finalize stats ------------------ #
total_time = time.time() - start_time
self.loading_stats["total_load_time"] = total_time
self.loading_stats["models_loaded"] = bool(self.sam2_predictor or self.matanyone_model)
if progress_callback:
if self.loading_stats["models_loaded"]:
progress_callback(1.0, "Models loaded successfully")
else:
progress_callback(1.0, "Model loading completed with failures")
logger.info(f"Model loading completed in {total_time:.2f}s")
return self.sam2_predictor, self.matanyone_model
except Exception as e:
error_msg = f"Model loading failed: {str(e)}"
logger.error(error_msg)
self._cleanup_models()
self.loading_stats["models_loaded"] = False
if progress_callback:
progress_callback(1.0, f"Error: {error_msg}")
return None, None
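    # Illustrative sketch: driving load_all_models() with a progress callback and a
    # cancellation flag. Using threading.Event for cancel_event is an assumption;
    # any object exposing is_set() works. `loader` is a hypothetical ModelLoader.
    #
    #     import threading
    #     cancel = threading.Event()
    #
    #     def on_progress(fraction: float, message: str) -> None:
    #         print(f"[{fraction:4.0%}] {message}")
    #
    #     sam2, matanyone = loader.load_all_models(
    #         progress_callback=on_progress, cancel_event=cancel)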
def reload_models(
self,
progress_callback: Optional[Callable[[float, str], None]] = None
) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
"""Reload all models from scratch"""
logger.info("Reloading models...")
self._cleanup_models()
self.loading_stats["models_loaded"] = False
return self.load_all_models(progress_callback)
@property
def models_ready(self) -> bool:
"""Check if any models are loaded and ready"""
return self.sam2_predictor is not None or self.matanyone_model is not None
def get_sam2(self):
"""Get SAM2 predictor model"""
return self.sam2_predictor.model if self.sam2_predictor else None
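    # Illustrative sketch of consuming the returned predictor. The exact predict()
    # arguments are an assumption based on the standard SAM2 image-predictor API
    # (set_image + point/box prompts); adjust to the loader's actual predictor.
    #
    #     predictor = loader.get_sam2()
    #     if predictor is not None:
    #         predictor.set_image(image_rgb)                     # HxWx3 uint8 frame
    #         masks, scores, _ = predictor.predict(
    #             point_coords=points, point_labels=labels)      # hypothetical prompts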
def get_matanyone(self):
"""
        Get the MatAnyone processor.
        IMPORTANT: This returns the stateful callable adapter from MatAnyoneLoader:
        - adapter(image_rgb01[, mask01]) -> 2D alpha matte
        - optional: .reset() clears per-video memory
"""
return self.matanyone_model.model if self.matanyone_model else None
def validate_models(self) -> bool:
"""Validate that loaded models have expected interfaces"""
try:
valid = False
# Validate SAM2
if self.sam2_predictor:
model = self.sam2_predictor.model
if hasattr(model, "set_image") and hasattr(model, "predict"):
valid = True
logger.info("SAM2 model validated")
# Validate MatAnyone (stateful adapter OR raw core)
if self.matanyone_model:
model = self.matanyone_model.model
if callable(model):
valid = True
logger.info("MatAnyone adapter validated (callable)")
elif hasattr(model, "step") or hasattr(model, "process"):
valid = True
logger.info("MatAnyone core validated (.step/.process)")
elif hasattr(model, "reset"):
# still accept an adapter exposing reset but not callable (unlikely)
valid = True
logger.info("MatAnyone object validated via reset()")
else:
logger.warning("MatAnyone present but interface not recognized")
return valid
except Exception as e:
logger.error(f"Model validation failed: {e}")
return False
def get_model_info(self) -> Dict[str, Any]:
"""Get detailed information about loaded models"""
info: Dict[str, Any] = {
"models_loaded": self.loading_stats["models_loaded"],
"device": str(self.device),
"loading_stats": self.loading_stats.copy(),
}
# Add SAM2 info
info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}
# Add MatAnyone info (augment with interface hints)
mat_info = self.matanyone_loader.get_info() if self.matanyone_loader else {}
        try:
            m = self.get_matanyone()
            if m is not None:
                mat_info["callable"] = callable(m)
                mat_info["has_reset"] = hasattr(m, "reset")
                mat_info["has_step"] = hasattr(m, "step")
                mat_info["has_process"] = hasattr(m, "process")
        except Exception:
            pass
info["matanyone"] = mat_info
return info
def get_load_summary(self) -> str:
"""Get human-readable loading summary"""
if not self.loading_stats["models_loaded"]:
return "No models loaded"
lines = []
lines.append(f"Models loaded in {self.loading_stats['total_load_time']:.1f}s")
if self.sam2_predictor:
lines.append(f"βœ“ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
lines.append(f" Model: {self.sam2_predictor.model_id}")
else:
lines.append("βœ— SAM2: Failed to load")
if self.matanyone_model:
# Describe adapter/callable for clarity
iface = []
m = self.matanyone_model.model
            if callable(m):
                iface.append("callable")
            if hasattr(m, "reset"):
                iface.append("reset")
            if hasattr(m, "step"):
                iface.append("step")
            if hasattr(m, "process"):
                iface.append("process")
iface_str = f" ({', '.join(iface)})" if iface else ""
lines.append(f"βœ“ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s{iface_str}")
lines.append(f" Model: {self.matanyone_model.model_id}")
else:
lines.append("βœ— MatAnyone: Failed to load")
lines.append(f"Device: {self.device}")
return "\n".join(lines)
def cleanup(self):
"""Clean up all resources"""
self._cleanup_models()
logger.info("ModelLoader cleanup completed")
def _cleanup_models(self):
"""Internal cleanup of loaded models"""
# Clean up SAM2
if self.sam2_loader:
try:
if hasattr(self.sam2_loader, 'cleanup'):
self.sam2_loader.cleanup()
else:
logger.debug("SAM2 loader has no cleanup method")
except Exception as e:
logger.debug(f"SAM2 loader cleanup error: {e}")
if self.sam2_predictor:
try:
del self.sam2_predictor
except Exception:
pass
self.sam2_predictor = None
# Clean up MatAnyone
if self.matanyone_loader:
try:
if hasattr(self.matanyone_loader, 'cleanup'):
self.matanyone_loader.cleanup()
else:
# MatAnyone doesn't have cleanup, but we can clean the wrapper
if hasattr(self.matanyone_loader, '_wrapper') and self.matanyone_loader._wrapper:
if hasattr(self.matanyone_loader._wrapper, 'reset'):
self.matanyone_loader._wrapper.reset()
self.matanyone_loader._processor = None
self.matanyone_loader._wrapper = None
logger.debug("MatAnyone loader cleaned up manually")
except Exception as e:
logger.debug(f"MatAnyone loader cleanup error: {e}")
if self.matanyone_model:
try:
del self.matanyone_model
except Exception:
pass
self.matanyone_model = None
# Clear CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Garbage collection
gc.collect()
logger.debug("Model cleanup completed")