#!/usr/bin/env python3
"""
Unified Model Loader
Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture

Notes:
- SAM2: exposes set_image(...) and predict(...)
- MatAnyone: our loader returns a stateful callable adapter:
    - callable(adapter) -> frame0: adapter(image_rgb01, mask01),
      frames>0: adapter(image_rgb01)
    - optional: adapter.reset() to clear per-video memory
  We therefore validate MatAnyone by checking "callable(...)" and/or presence
  of "reset", not only ".step/.process".
"""

from __future__ import annotations

import os
import gc
import time
import logging
from typing import Optional, Dict, Any, Tuple, Callable

import torch

from core.exceptions import ModelLoadingError
from utils.hardware.device_manager import DeviceManager
from utils.system.memory_manager import MemoryManager

# Import the specialized loaders
from models.loaders.sam2_loader import SAM2Loader
from models.loaders.matanyone_loader import MatAnyoneLoader

logger = logging.getLogger(__name__)


class LoadedModel:
    """Container for a loaded model object plus metadata about how it was loaded."""

    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
                 device: str = "", framework: str = ""):
        self.model = model
        self.model_id = model_id
        self.load_time = load_time
        self.device = device
        self.framework = framework

    def __repr__(self) -> str:
        # Only the model's type name is shown; the model object itself is not
        # rendered so the repr stays short.
        model_type = type(self.model).__name__ if self.model is not None else None
        return (
            f"{type(self).__name__}(model_id={self.model_id!r}, "
            f"framework={self.framework!r}, device={self.device!r}, "
            f"load_time={self.load_time!r}, model_type={model_type!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dict (the model object is not included)."""
        return {
            "model_id": self.model_id,
            "framework": self.framework,
            "device": self.device,
            "load_time": self.load_time,
            "loaded": self.model is not None,
            "model_type": type(self.model).__name__ if self.model is not None else None,
        }


class ModelLoader:
    """Main model loader that coordinates SAM2 and MatAnyone loaders"""

    def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
        """
        Create the specialized loaders for the device chosen by ``device_mgr``.

        Args:
            device_mgr: Provides ``get_optimal_device()``; its result is stored
                as ``self.device`` and passed (as a string) to both loaders.
            memory_mgr: Stored as ``self.memory_manager``; not otherwise used
                by this class.
        """
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        # Initialize specialized loaders
        self.sam2_loader = SAM2Loader(device=str(self.device))
        self.matanyone_loader = MatAnyoneLoader(device=str(self.device))

        # Model storage
        self.sam2_predictor: Optional[LoadedModel] = None
        self.matanyone_model: Optional[LoadedModel] = None

        # Statistics
        self.loading_stats = {
            "sam2_load_time": 0.0,
            "matanyone_load_time": 0.0,
            "total_load_time": 0.0,
            "models_loaded": False,
            "loading_attempts": 0,
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")

    def _finalize_stats(self, start_time: float) -> float:
        """Record total elapsed time and whether any model is currently held."""
        total_time = time.time() - start_time
        self.loading_stats["total_load_time"] = total_time
        self.loading_stats["models_loaded"] = bool(
            self.sam2_predictor or self.matanyone_model
        )
        return total_time

    def load_all_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """
        Load all models using specialized loaders

        Args:
            progress_callback: Optional ``callback(fraction, message)`` invoked
                at each stage.
            cancel_event: Optional object exposing ``is_set()``; it is checked
                once, after the SAM2 stage. When set, MatAnyone is skipped.

        Returns:
            Tuple of (sam2_model, matanyone_model). Either element is None if
            that model failed to load; ``(None, None)`` is returned when an
            exception aborts loading.
        """
        start_time = time.time()
        self.loading_stats["loading_attempts"] += 1

        # Start every attempt from a clean slate so that a failed or cancelled
        # run never reports timings/flags left over from a previous run.
        self.loading_stats.update(
            sam2_load_time=0.0,
            matanyone_load_time=0.0,
            total_load_time=0.0,
            models_loaded=False,
        )

        def notify(fraction: float, message: str) -> None:
            if progress_callback:
                progress_callback(fraction, message)

        try:
            logger.info("Starting model loading process...")
            notify(0.0, "Initializing model loading...")

            # Clean up any existing models
            self._cleanup_models()

            # -------------------- Load SAM2 -------------------- #
            notify(0.1, "Loading SAM2 model...")
            sam2_start = time.time()
            # The "tiny" variant is requested explicitly instead of relying on
            # auto-detection (intended for faster loading / lower memory use).
            sam2_model = self.sam2_loader.load("tiny")
            sam2_time = time.time() - sam2_start

            if sam2_model:
                self.sam2_predictor = LoadedModel(
                    model=sam2_model,
                    model_id=self.sam2_loader.model_id,
                    load_time=sam2_time,
                    device=str(self.device),
                    framework="sam2",
                )
                self.loading_stats["sam2_load_time"] = sam2_time
                logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
            else:
                logger.warning("SAM2 loading failed")

            # Cancellation check
            if cancel_event and cancel_event.is_set():
                # Keep the stats consistent with what is actually held.
                self._finalize_stats(start_time)
                notify(1.0, "Model loading cancelled")
                return self.sam2_predictor, None

            # ----------------- Load MatAnyone ------------------ #
            notify(0.6, "Loading MatAnyone model...")
            matanyone_start = time.time()
            # returns stateful callable adapter or None
            matanyone_model = self.matanyone_loader.load()
            matanyone_time = time.time() - matanyone_start

            if matanyone_model:
                self.matanyone_model = LoadedModel(
                    model=matanyone_model,
                    model_id=self.matanyone_loader.model_id,
                    load_time=matanyone_time,
                    device=str(self.device),
                    framework="matanyone",
                )
                self.loading_stats["matanyone_load_time"] = matanyone_time
                logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
            else:
                logger.warning("MatAnyone loading failed")

            # ----------------- Finalize stats ------------------ #
            total_time = self._finalize_stats(start_time)

            if self.loading_stats["models_loaded"]:
                notify(1.0, "Models loaded successfully")
            else:
                notify(1.0, "Model loading completed with failures")

            logger.info(f"Model loading completed in {total_time:.2f}s")
            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {str(e)}"
            # logger.exception keeps the traceback alongside the message.
            logger.exception(error_msg)
            self._cleanup_models()
            self.loading_stats["models_loaded"] = False
            if progress_callback:
                progress_callback(1.0, f"Error: {error_msg}")
            return None, None

    def reload_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """
        Reload all models from scratch

        Args:
            progress_callback: Forwarded to ``load_all_models``.
            cancel_event: Forwarded to ``load_all_models``; defaults to None,
                which matches the previous non-cancellable behavior.

        Returns:
            The tuple returned by ``load_all_models``.
        """
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats["models_loaded"] = False
        return self.load_all_models(progress_callback, cancel_event)

    @property
    def models_ready(self) -> bool:
        """Check if any models are loaded and ready"""
        return self.sam2_predictor is not None or self.matanyone_model is not None

    def get_sam2(self):
        """Get SAM2 predictor model, or None when SAM2 is not loaded"""
        return self.sam2_predictor.model if self.sam2_predictor else None

    def get_matanyone(self):
        """
        Get MatAnyone processor model, or None when it is not loaded.

        IMPORTANT: This returns the stateful callable adapter from MatAnyoneLoader:
          - callable(image_rgb01[, mask01]) -> 2D alpha
          - optional .reset() to clear memory per video
        """
        return self.matanyone_model.model if self.matanyone_model else None

    def validate_models(self) -> bool:
        """
        Validate that loaded models have expected interfaces

        Returns:
            True if at least one loaded model exposes a recognized interface
            (SAM2: ``set_image`` and ``predict``; MatAnyone: callable, or
            ``step``/``process``, or ``reset``). False otherwise, including
            when the checks themselves raise.
        """
        try:
            valid = False

            # Validate SAM2
            if self.sam2_predictor:
                model = self.sam2_predictor.model
                if hasattr(model, "set_image") and hasattr(model, "predict"):
                    valid = True
                    logger.info("SAM2 model validated")

            # Validate MatAnyone (stateful adapter OR raw core)
            if self.matanyone_model:
                model = self.matanyone_model.model
                if callable(model):
                    valid = True
                    logger.info("MatAnyone adapter validated (callable)")
                elif hasattr(model, "step") or hasattr(model, "process"):
                    valid = True
                    logger.info("MatAnyone core validated (.step/.process)")
                elif hasattr(model, "reset"):
                    # still accept an adapter exposing reset but not callable (unlikely)
                    valid = True
                    logger.info("MatAnyone object validated via reset()")
                else:
                    logger.warning("MatAnyone present but interface not recognized")

            return valid

        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get detailed information about loaded models

        Returns:
            Dict with ``models_loaded``, ``device``, a copy of
            ``loading_stats``, and per-loader ``sam2`` / ``matanyone`` info.
            The MatAnyone entry is augmented with interface hints
            (``callable``, ``has_reset``, ``has_step``, ``has_process``).
        """
        info: Dict[str, Any] = {
            "models_loaded": self.loading_stats["models_loaded"],
            "device": str(self.device),
            "loading_stats": self.loading_stats.copy(),
        }

        # Add SAM2 info
        info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}

        # Add MatAnyone info (augment with interface hints).
        # Work on a copy so the loader's own dict is never mutated from here.
        raw_info = self.matanyone_loader.get_info() if self.matanyone_loader else {}
        mat_info: Dict[str, Any] = dict(raw_info or {})
        try:
            m = self.get_matanyone()
            mat_info["callable"] = bool(callable(m))
            mat_info["has_reset"] = bool(hasattr(m, "reset"))
            mat_info["has_step"] = bool(hasattr(m, "step"))
            mat_info["has_process"] = bool(hasattr(m, "process"))
        except Exception as e:
            # Best-effort only: the hints are optional, but do not hide why
            # they are missing.
            logger.debug(f"MatAnyone interface inspection failed: {e}")
        info["matanyone"] = mat_info

        return info

    def get_load_summary(self) -> str:
        """Get human-readable loading summary"""
        if not self.loading_stats["models_loaded"]:
            return "No models loaded"

        lines = [f"Models loaded in {self.loading_stats['total_load_time']:.1f}s"]

        if self.sam2_predictor:
            lines.append(f"✓ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
            lines.append(f"  Model: {self.sam2_predictor.model_id}")
        else:
            lines.append("✗ SAM2: Failed to load")

        if self.matanyone_model:
            # Describe adapter/callable for clarity
            m = self.matanyone_model.model
            iface = []
            if callable(m):
                iface.append("callable")
            iface.extend(
                name for name in ("reset", "step", "process") if hasattr(m, name)
            )
            iface_str = f" ({', '.join(iface)})" if iface else ""
            lines.append(
                f"✓ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s{iface_str}"
            )
            lines.append(f"  Model: {self.matanyone_model.model_id}")
        else:
            lines.append("✗ MatAnyone: Failed to load")

        lines.append(f"Device: {self.device}")
        return "\n".join(lines)

    def cleanup(self):
        """Clean up all resources"""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    def _cleanup_models(self):
        """
        Internal cleanup of loaded models

        Asks each loader to release its model (best-effort; errors are logged
        at debug level and never raised), drops this object's references, then
        runs garbage collection and clears the CUDA cache.
        """
        # Clean up SAM2
        if self.sam2_loader:
            try:
                if hasattr(self.sam2_loader, "cleanup"):
                    self.sam2_loader.cleanup()
                else:
                    logger.debug("SAM2 loader has no cleanup method")
            except Exception as e:
                logger.debug(f"SAM2 loader cleanup error: {e}")
        self.sam2_predictor = None

        # Clean up MatAnyone
        if self.matanyone_loader:
            try:
                if hasattr(self.matanyone_loader, "cleanup"):
                    self.matanyone_loader.cleanup()
                else:
                    # No cleanup() on the loader: reset the wrapper if there is
                    # one, then drop the loader's cached objects by hand.
                    wrapper = getattr(self.matanyone_loader, "_wrapper", None)
                    if wrapper and hasattr(wrapper, "reset"):
                        wrapper.reset()
                    self.matanyone_loader._processor = None
                    self.matanyone_loader._wrapper = None
                    logger.debug("MatAnyone loader cleaned up manually")
            except Exception as e:
                logger.debug(f"MatAnyone loader cleanup error: {e}")
        self.matanyone_model = None

        # Collect garbage first: objects that are only reclaimable by the
        # collector still own their memory until it runs, and empty_cache()
        # can only return blocks that are no longer in use.
        gc.collect()

        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.debug("Model cleanup completed")