|
|
|
|
|
""" |
|
|
Unified Model Loader |
|
|
Coordinates separate SAM2 and MatAnyone loaders for cleaner architecture |
|
|
|
|
|
Notes: |
|
|
- SAM2: exposes set_image(...) and predict(...) |
|
|
- MatAnyone: our loader returns a stateful callable adapter: |
|
|
- callable: adapter(image_rgb01, mask01) on frame 0, then adapter(image_rgb01) on every later frame
|
|
- optional: adapter.reset() to clear per-video memory |
|
|
We therefore validate MatAnyone by accepting any callable object (or one exposing
.step/.process or reset()), not only objects with .step/.process.
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
import gc |
|
|
import time |
|
|
import logging |
|
|
from typing import Optional, Dict, Any, Tuple, Callable |
|
|
|
|
|
import torch |
|
|
|
|
|
from core.exceptions import ModelLoadingError |
|
|
from utils.hardware.device_manager import DeviceManager |
|
|
from utils.system.memory_manager import MemoryManager |
|
|
|
|
|
|
|
|
from models.loaders.sam2_loader import SAM2Loader |
|
|
from models.loaders.matanyone_loader import MatAnyoneLoader |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
class LoadedModel:
    """Container pairing a loaded model object with its load metadata.

    Attributes:
        model: The loaded model object, or ``None`` if nothing was loaded.
        model_id: Identifier reported by the loader that produced ``model``.
        load_time: Seconds the caller measured for loading ``model``.
        device: Device string the model was loaded for.
        framework: Short framework tag, e.g. ``"sam2"`` or ``"matanyone"``.
    """

    def __init__(self, model=None, model_id: str = "", load_time: float = 0.0,
                 device: str = "", framework: str = ""):
        self.model = model
        self.model_id = model_id
        self.load_time = load_time
        self.device = device
        self.framework = framework

    def __repr__(self) -> str:
        # Deliberately omits the model object itself: its repr may be huge.
        return (
            f"{type(self).__name__}(model_id={self.model_id!r}, "
            f"framework={self.framework!r}, device={self.device!r}, "
            f"load_time={self.load_time!r}, loaded={self.model is not None})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict summary of this container.

        ``model`` itself is not included; instead ``loaded`` reports whether it
        is set and ``model_type`` gives its class name (``None`` when unset).
        """
        return {
            "model_id": self.model_id,
            "framework": self.framework,
            "device": self.device,
            "load_time": self.load_time,
            "loaded": self.model is not None,
            "model_type": type(self.model).__name__ if self.model is not None else None,
        }
|
|
|
|
|
|
|
|
class ModelLoader:
    """Main model loader that coordinates the SAM2 and MatAnyone loaders.

    Owns one ``SAM2Loader`` and one ``MatAnyoneLoader`` bound to the device
    returned by ``DeviceManager.get_optimal_device()``. It wraps whatever they
    load in ``LoadedModel`` containers and records timings in ``loading_stats``.
    """

    def __init__(self, device_mgr: DeviceManager, memory_mgr: MemoryManager):
        self.device_manager = device_mgr
        self.memory_manager = memory_mgr
        self.device = self.device_manager.get_optimal_device()

        # Specialized per-model loaders; both receive the device as a string.
        self.sam2_loader = SAM2Loader(device=str(self.device))
        self.matanyone_loader = MatAnyoneLoader(device=str(self.device))

        # Filled in by load_all_models(); None means "not loaded".
        self.sam2_predictor: Optional[LoadedModel] = None
        self.matanyone_model: Optional[LoadedModel] = None

        # Timing/bookkeeping for load attempts.
        self.loading_stats = {
            "sam2_load_time": 0.0,
            "matanyone_load_time": 0.0,
            "total_load_time": 0.0,
            "models_loaded": False,
            "loading_attempts": 0,
        }

        logger.info(f"ModelLoader initialized for device: {self.device}")

    def load_all_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None,
        sam2_model_size: str = "tiny",
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """Load SAM2 and then MatAnyone using the specialized loaders.

        Previously loaded models are released first. A model whose loader
        returns a falsy value is logged as failed and left as ``None``; the
        other model is still attempted. An unexpected exception aborts the
        attempt, releases everything and returns ``(None, None)``.

        Args:
            progress_callback: Optional ``callback(fraction, message)``; called
                with fractions between 0.0 and 1.0.
            cancel_event: Optional object with an ``is_set()`` method (e.g. a
                ``threading.Event``). It is checked before SAM2 and before
                MatAnyone. When it is set, loading stops and whatever is
                loaded so far is returned.
            sam2_model_size: Size key passed to ``SAM2Loader.load``.

        Returns:
            Tuple of (sam2_model, matanyone_model) ``LoadedModel`` containers;
            an entry is ``None`` if that model is not loaded.
        """
        # perf_counter is monotonic, unlike time.time(), so durations are
        # immune to wall-clock adjustments.
        start_time = time.perf_counter()
        self.loading_stats["loading_attempts"] += 1
        # Per-model timings describe the current attempt only.
        self.loading_stats["sam2_load_time"] = 0.0
        self.loading_stats["matanyone_load_time"] = 0.0

        def report(fraction: float, message: str) -> None:
            if progress_callback:
                progress_callback(fraction, message)

        def cancelled() -> bool:
            return cancel_event is not None and bool(cancel_event.is_set())

        def finish_cancelled() -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
            # Keep stats consistent with what is actually loaded right now.
            self._record_totals(start_time)
            logger.info("Model loading cancelled")
            report(1.0, "Model loading cancelled")
            return self.sam2_predictor, self.matanyone_model

        try:
            logger.info("Starting model loading process...")
            report(0.0, "Initializing model loading...")

            # Release models from any previous attempt before loading again.
            self._cleanup_models()

            if cancelled():
                return finish_cancelled()

            # --- SAM2 ---
            report(0.1, "Loading SAM2 model...")
            sam2_start = time.perf_counter()
            sam2_model = self.sam2_loader.load(sam2_model_size)
            sam2_time = time.perf_counter() - sam2_start

            if sam2_model:
                self.sam2_predictor = LoadedModel(
                    model=sam2_model,
                    model_id=self.sam2_loader.model_id,
                    load_time=sam2_time,
                    device=str(self.device),
                    framework="sam2",
                )
                self.loading_stats["sam2_load_time"] = sam2_time
                logger.info(f"SAM2 loaded in {sam2_time:.2f}s")
            else:
                logger.warning("SAM2 loading failed")

            if cancelled():
                return finish_cancelled()

            # --- MatAnyone ---
            report(0.6, "Loading MatAnyone model...")
            matanyone_start = time.perf_counter()
            matanyone_model = self.matanyone_loader.load()
            matanyone_time = time.perf_counter() - matanyone_start

            if matanyone_model:
                self.matanyone_model = LoadedModel(
                    model=matanyone_model,
                    model_id=self.matanyone_loader.model_id,
                    load_time=matanyone_time,
                    device=str(self.device),
                    framework="matanyone",
                )
                self.loading_stats["matanyone_load_time"] = matanyone_time
                logger.info(f"MatAnyone loaded in {matanyone_time:.2f}s")
            else:
                logger.warning("MatAnyone loading failed")

            self._record_totals(start_time)
            if self.loading_stats["models_loaded"]:
                report(1.0, "Models loaded successfully")
            else:
                report(1.0, "Model loading completed with failures")

            logger.info(
                f"Model loading completed in {self.loading_stats['total_load_time']:.2f}s"
            )
            return self.sam2_predictor, self.matanyone_model

        except Exception as e:
            error_msg = f"Model loading failed: {str(e)}"
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(error_msg)
            self._cleanup_models()
            self.loading_stats["models_loaded"] = False
            self.loading_stats["total_load_time"] = time.perf_counter() - start_time
            report(1.0, f"Error: {error_msg}")
            return None, None

    def _record_totals(self, start_time: float) -> None:
        """Store elapsed time since ``start_time`` and whether any model is loaded."""
        self.loading_stats["total_load_time"] = time.perf_counter() - start_time
        self.loading_stats["models_loaded"] = self.models_ready

    def reload_models(
        self,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        cancel_event=None,
    ) -> Tuple[Optional[LoadedModel], Optional[LoadedModel]]:
        """Release all models and load them again from scratch.

        Args:
            progress_callback: Forwarded to ``load_all_models``.
            cancel_event: Forwarded to ``load_all_models`` (optional).
        """
        logger.info("Reloading models...")
        self._cleanup_models()
        self.loading_stats["models_loaded"] = False
        return self.load_all_models(progress_callback, cancel_event)

    @property
    def models_ready(self) -> bool:
        """True if at least one of SAM2 / MatAnyone is currently loaded."""
        return self.sam2_predictor is not None or self.matanyone_model is not None

    def get_sam2(self):
        """Return the raw SAM2 predictor object, or ``None`` if not loaded."""
        return self.sam2_predictor.model if self.sam2_predictor else None

    def get_matanyone(self):
        """
        Get MatAnyone processor model.

        IMPORTANT: This returns the stateful callable adapter from MatAnyoneLoader:
          - callable(image_rgb01[, mask01]) -> 2D alpha
          - optional .reset() to clear memory per video
        Returns ``None`` if MatAnyone is not loaded.
        """
        return self.matanyone_model.model if self.matanyone_model else None

    @staticmethod
    def _describe_interface(model: Any) -> list[str]:
        """List the MatAnyone-style entry points ``model`` exposes.

        The result contains, in this order, whichever of ``"callable"``,
        ``"reset"``, ``"step"`` and ``"process"`` apply; ``None`` gives ``[]``.
        """
        if model is None:
            return []
        found = ["callable"] if callable(model) else []
        found.extend(name for name in ("reset", "step", "process") if hasattr(model, name))
        return found

    def validate_models(self) -> bool:
        """Check that loaded models expose the interfaces callers rely on.

        SAM2 must have ``set_image`` and ``predict``. MatAnyone is accepted if
        it is callable, or has ``step``/``process``, or has ``reset``.

        Returns:
            True if at least one loaded model has a recognized interface;
            False otherwise, or if validation itself raised.
        """
        try:
            valid = False

            if self.sam2_predictor:
                model = self.sam2_predictor.model
                if hasattr(model, "set_image") and hasattr(model, "predict"):
                    valid = True
                    logger.info("SAM2 model validated")
                else:
                    logger.warning("SAM2 present but missing set_image()/predict()")

            if self.matanyone_model:
                model = self.matanyone_model.model
                if callable(model):
                    valid = True
                    logger.info("MatAnyone adapter validated (callable)")
                elif hasattr(model, "step") or hasattr(model, "process"):
                    valid = True
                    logger.info("MatAnyone core validated (.step/.process)")
                elif hasattr(model, "reset"):
                    # Accept objects exposing only reset() as well.
                    valid = True
                    logger.info("MatAnyone object validated via reset()")
                else:
                    logger.warning("MatAnyone present but interface not recognized")

            return valid

        except Exception as e:
            logger.error(f"Model validation failed: {e}")
            return False

    def get_model_info(self) -> Dict[str, Any]:
        """Return device, stats and per-loader info for the loaded models.

        The ``"matanyone"`` entry is a copy of ``matanyone_loader.get_info()``
        with interface flags (``callable``, ``has_reset``, ``has_step``,
        ``has_process``) added.
        """
        info: Dict[str, Any] = {
            "models_loaded": self.loading_stats["models_loaded"],
            "device": str(self.device),
            "loading_stats": self.loading_stats.copy(),
        }

        info["sam2"] = self.sam2_loader.get_info() if self.sam2_loader else {}

        # Copy so the flags added below never mutate the loader's own dict.
        mat_info: Dict[str, Any] = {}
        if self.matanyone_loader:
            mat_info = dict(self.matanyone_loader.get_info() or {})
        try:
            iface = self._describe_interface(self.get_matanyone())
            mat_info["callable"] = "callable" in iface
            mat_info["has_reset"] = "reset" in iface
            mat_info["has_step"] = "step" in iface
            mat_info["has_process"] = "process" in iface
        except Exception as e:
            # Best-effort: report what we have rather than failing the query.
            logger.debug(f"MatAnyone interface probe failed: {e}")
        info["matanyone"] = mat_info

        return info

    def get_load_summary(self) -> str:
        """Return a human-readable, multi-line summary of the last load."""
        if not self.loading_stats["models_loaded"]:
            return "No models loaded"

        lines = [f"Models loaded in {self.loading_stats['total_load_time']:.1f}s"]

        if self.sam2_predictor:
            lines.append(f"✓ SAM2: {self.loading_stats['sam2_load_time']:.1f}s")
            lines.append(f" Model: {self.sam2_predictor.model_id}")
        else:
            lines.append("✗ SAM2: Failed to load")

        if self.matanyone_model:
            iface = self._describe_interface(self.matanyone_model.model)
            iface_str = f" ({', '.join(iface)})" if iface else ""
            lines.append(
                f"✓ MatAnyone: {self.loading_stats['matanyone_load_time']:.1f}s{iface_str}"
            )
            lines.append(f" Model: {self.matanyone_model.model_id}")
        else:
            lines.append("✗ MatAnyone: Failed to load")

        lines.append(f"Device: {self.device}")
        return "\n".join(lines)

    def cleanup(self):
        """Release all loaded models and loader resources."""
        self._cleanup_models()
        logger.info("ModelLoader cleanup completed")

    def _cleanup_models(self):
        """Release loaded models and ask both loaders to free their resources.

        Best-effort: loader cleanup errors are logged at debug level and are
        not re-raised.
        """
        if self.sam2_loader:
            try:
                if hasattr(self.sam2_loader, "cleanup"):
                    self.sam2_loader.cleanup()
                else:
                    logger.debug("SAM2 loader has no cleanup method")
            except Exception as e:
                logger.debug(f"SAM2 loader cleanup error: {e}")
        # Drop this object's reference to the container (and thus the model).
        self.sam2_predictor = None

        if self.matanyone_loader:
            try:
                if hasattr(self.matanyone_loader, "cleanup"):
                    self.matanyone_loader.cleanup()
                else:
                    # Fallback for loaders without cleanup(): reset the
                    # adapter, then drop the loader's private references.
                    wrapper = getattr(self.matanyone_loader, "_wrapper", None)
                    if wrapper and hasattr(wrapper, "reset"):
                        wrapper.reset()
                    self.matanyone_loader._processor = None
                    self.matanyone_loader._wrapper = None
                    logger.debug("MatAnyone loader cleaned up manually")
            except Exception as e:
                logger.debug(f"MatAnyone loader cleanup error: {e}")
        self.matanyone_model = None

        # Collect first: tensors kept alive only by reference cycles are freed
        # here, and empty_cache() can only release memory that is unreferenced.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.debug("Model cleanup completed")