"""
MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
==================================================================================
- Always call InferenceCore UNBATCHED:
    image -> CHW float32 [0,1]
    mask  -> 1HW float32 [0,1]
- Aggressively strip extra LEADING dims:
    e.g. [B,T,C,H,W] -> [C,H,W]  (first slice is used when B/T > 1, with a warning)
    e.g. [B,C,H,W]   -> [C,H,W]
  Only leading dims are removed; trailing singleton dims are kept, so an input
  such as [H,W,C,1] is reduced by slicing its first (H) axis, not to [H,W,C].
- Value scaling: uint8 data is divided by 255; any other data is divided by
  255 only when its maximum exceeds 1.5; results are clipped to [0,1].
- Robust alpha extraction -> (H,W) float32 [0,1]
- Caller-provided arrays/tensors are never modified in place.
"""
from __future__ import annotations

import logging
from typing import Optional, Dict, Any, Tuple, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)

try:
    # Official import path
    from matanyone.inference.inference_core import InferenceCore
except Exception:  # keep the import error deferred until load()
    InferenceCore = None  # type: ignore


# ------------------------------ Helpers ------------------------------

# Axis sizes recognised as an explicit channel axis (gray, RGB, RGBA).
_CHANNEL_COUNTS = (1, 3, 4)
# Data whose maximum exceeds this is assumed to be on a 0..255 scale.
_SCALE_255_THRESHOLD = 1.5
# Side length of the neutral alpha returned when no spatial size is known.
_FALLBACK_SIZE = 512


def _to_float01_np(arr: np.ndarray) -> np.ndarray:
    """Return a float32 copy of ``arr`` scaled and clipped to [0,1].

    uint8 input is always divided by 255. Any other dtype is divided by 255
    only when its maximum exceeds ``_SCALE_255_THRESHOLD`` (the same rule the
    torch code paths use), so float data on a 0..255 scale is rescaled
    instead of being clipped to ~1.0.

    The result is always a new C-contiguous array. The in-place clip below
    therefore never writes into the caller's buffer (or into a tensor that
    ``arr`` is a view of), and ``torch.from_numpy`` accepts the result even
    when ``arr`` had negative strides.
    """
    src = np.asarray(arr)
    out = src.astype(np.float32, order="C")  # astype copies by default
    if src.dtype == np.uint8:
        out /= 255.0
    elif out.size and float(out.max()) > _SCALE_255_THRESHOLD:
        out /= 255.0
    np.clip(out, 0.0, 1.0, out=out)
    return out


def _scale_clamp_tensor(t: torch.Tensor) -> torch.Tensor:
    """Return ``t`` as float32 in [0,1], dividing by 255 when its max exceeds the threshold.

    All operations are out-of-place, so the caller's tensor is not modified.
    The max check is skipped for empty tensors (``torch.max`` raises on them).
    """
    t = t.to(dtype=torch.float32)
    if t.numel() > 0 and float(t.max()) > _SCALE_255_THRESHOLD:
        t = t / 255.0
    return torch.clamp(t, 0.0, 1.0)


def _strip_leading_extras_to_ndim(
    x: Union[np.ndarray, torch.Tensor], target_ndim: int
) -> Union[np.ndarray, torch.Tensor]:
    """Reduce ``x`` to at most ``target_ndim`` dims by removing LEADING dims.

    A leading dim of size 1 is dropped. A leading dim > 1 is reduced to its
    first slice and a warning is logged. This is repeated until
    ``ndim <= target_ndim``. Works for numpy arrays and torch tensors, which
    both support ``.ndim``, ``.shape`` and ``[0]`` indexing. Trailing dims are
    never touched.
    """
    while x.ndim > target_ndim:
        lead = x.shape[0]
        if lead != 1:
            logger.warning("Input has extra leading dim >1 (%s); taking the first slice.", lead)
        x = x[0]  # for a size-1 dim this is identical to squeezing it
    return x


def _last_axis_first(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """Move the last axis of a 3-D array/tensor to the front (HWC -> CHW)."""
    return x.permute(2, 0, 1) if torch.is_tensor(x) else x.transpose(2, 0, 1)


def _as_array_or_tensor(x: Any) -> Union[np.ndarray, torch.Tensor]:
    """Pass tensors through unchanged; convert anything else (lists, arrays) to an ndarray."""
    return x if torch.is_tensor(x) else np.asarray(x)


def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "image") -> torch.Tensor:
    """Convert ``image`` to a float32 CHW tensor in [0,1], stripping extra dims.

    Accepts numpy arrays, torch tensors or nested sequences of any rank
    (e.g. B,T,C,H,W / B,C,H,W / H,W,C / C,H,W / H,W). Leading dims beyond 3 are
    removed by ``_strip_leading_extras_to_ndim``. A 3-D input is interpreted as
    follows:
    - first dim in (1, 3, 4): treated as CHW;
    - otherwise last dim in (1, 3, 4): treated as HWC;
    - any other 3-D shape: treated as HWC, keeping only its first channel
      (with a warning).

    Args:
        image: image data to convert.
        name: label used in log and error messages.

    Returns:
        A float32 tensor of shape (C, H, W) with values in [0, 1]. numpy
        inputs produce a CPU tensor; tensors stay on their original device.

    Raises:
        ValueError: if fewer than 2 dims remain after stripping.
    """
    is_tensor = torch.is_tensor(image)
    x = _as_array_or_tensor(image)
    orig_shape = tuple(x.shape)
    x = _strip_leading_extras_to_ndim(x, 3)

    if x.ndim == 3:
        if x.shape[0] in _CHANNEL_COUNTS:
            pass  # already CHW
        elif x.shape[-1] in _CHANNEL_COUNTS:
            x = _last_axis_first(x)  # HWC -> CHW
        else:
            if is_tensor:
                logger.warning(
                    "%s: ambiguous 3D shape %s; attempting HWC->CHW then selecting first channel.",
                    name, tuple(x.shape),
                )
            else:
                logger.warning(
                    "%s: ambiguous 3D shape %s; trying HWC->CHW and selecting first channel.",
                    name, tuple(x.shape),
                )
            x = _last_axis_first(x)
            if x.shape[0] > 1:
                x = x[0:1]
    elif x.ndim == 2:
        x = x[None]  # HW -> 1HW
    else:
        kind = "tensor" if is_tensor else "numpy"
        raise ValueError(f"{name}: unsupported {kind} dims {tuple(x.shape)} after stripping.")

    out = _scale_clamp_tensor(x) if is_tensor else torch.from_numpy(_to_float01_np(x))
    logger.debug("%s: %s -> %s (CHW)", name, orig_shape, tuple(out.shape))
    return out


def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "mask") -> torch.Tensor:
    """Convert ``mask`` to a float32 1HW tensor in [0,1], stripping extra dims.

    Accepts numpy arrays, torch tensors or nested sequences of any rank.
    Leading dims beyond 3 are removed by ``_strip_leading_extras_to_ndim``. A
    3-D input is handled as follows:
    - 1HW: kept as is;
    - HW1: moved to 1HW;
    - anything else: only the first channel is kept (with a warning). If the
      first dim is not 3/4 but the last dim is, the input is treated as HWC.

    Args:
        mask: mask data to convert.
        name: label used in log and error messages.

    Returns:
        A float32 tensor of shape (1, H, W) with values in [0, 1].

    Raises:
        ValueError: if fewer than 2 dims remain after stripping.
    """
    is_tensor = torch.is_tensor(mask)
    x = _as_array_or_tensor(mask)
    orig_shape = tuple(x.shape)
    x = _strip_leading_extras_to_ndim(x, 3)

    if x.ndim == 3:
        if x.shape[0] == 1:
            pass  # already 1HW
        elif x.shape[-1] == 1:
            x = _last_axis_first(x)  # HW1 -> 1HW
        else:
            logger.warning("%s: multi-channel %s; using first channel.", name, tuple(x.shape))
            if x.shape[0] not in (3, 4) and x.shape[-1] in (3, 4):
                x = _last_axis_first(x)  # HWC -> CHW before picking channel 0
            x = x[0:1]
    elif x.ndim == 2:
        x = x[None]  # HW -> 1HW
    else:
        kind = "tensor" if is_tensor else "numpy"
        raise ValueError(f"{name}: unsupported {kind} dims {tuple(x.shape)} after stripping.")

    out = _scale_clamp_tensor(x) if is_tensor else torch.from_numpy(_to_float01_np(x))
    logger.debug("%s: %s -> %s (1HW)", name, orig_shape, tuple(out.shape))
    return out


def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Extract a 2-D alpha (H,W) float32 in [0,1] from various outputs.

    Leading dims beyond 3 are sliced away (first slice). A 3-D result yields
    its first channel, read from the last axis only when the first axis is
    not 1/3/4 but the last one is. ``None``, or results with fewer than 2
    dims, produce a neutral 0.5 alpha of size 512x512.

    The returned array is always a fresh copy, so the core's output tensor
    (or an echoed input mask) is never modified.
    """
    if result is None:
        return np.full((_FALLBACK_SIZE, _FALLBACK_SIZE), 0.5, dtype=np.float32)
    if torch.is_tensor(result):
        result = result.detach().float().cpu()
    arr = np.asarray(result)

    while arr.ndim > 3:
        if arr.shape[0] > 1:
            logger.warning("Result has leading dim %s; taking first slice.", arr.shape[0])
        arr = arr[0]

    if arr.ndim == 2:
        alpha = arr
    elif arr.ndim == 3:
        if arr.shape[0] not in _CHANNEL_COUNTS and arr.shape[-1] in _CHANNEL_COUNTS:
            alpha = arr[..., 0]  # HWC / HW1
        else:
            alpha = arr[0]  # CHW / 1HW, or ambiguous -> first leading slice
    else:
        return np.full((_FALLBACK_SIZE, _FALLBACK_SIZE), 0.5, dtype=np.float32)

    return _to_float01_np(alpha)


def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
    """Best-effort inference of (H, W) for fallback mask sizing.

    Inputs with more than 3 dims are reduced to their last 3 dims, which
    mirrors the leading-dim stripping used by the converters. This makes both
    B,C,H,W and B,T,C,H,W yield (H, W). Returns (512, 512) when no spatial
    size can be determined. Never raises for unreadable input, because it runs
    on fallback paths.
    """
    try:
        shape = tuple(x.shape) if torch.is_tensor(x) else np.asarray(x).shape
    except (TypeError, ValueError):
        # e.g. ragged nested sequences that numpy refuses to convert
        return _FALLBACK_SIZE, _FALLBACK_SIZE

    if len(shape) > 3:
        shape = shape[-3:]
    if len(shape) == 2:
        return shape[0], shape[1]
    if len(shape) == 3:
        if shape[0] in _CHANNEL_COUNTS:
            return shape[1], shape[2]
        if shape[-1] in _CHANNEL_COUNTS:
            return shape[0], shape[1]
        return shape[1], shape[2]
    return _FALLBACK_SIZE, _FALLBACK_SIZE


# --------------------------- Callable Wrapper ---------------------------

class MatAnyoneCallableWrapper:
    """
    Callable session-like wrapper around an InferenceCore instance.

    Contract:
    - The first call SHOULD include a mask (any shape accepted by the mask
      converter, normalised to 1HW). Without one, a neutral 0.5 alpha is
      returned and the wrapper stays uninitialised.
    - Subsequent calls do not require a mask.
    - Returns a 2-D alpha (H,W) float32 in [0,1].
    - Strips any extra dims from inputs before calling the core.
    - Never raises from ``__call__``. On failure it logs the traceback and
      falls back to echoing the mask, or to a neutral alpha.
    """

    def __init__(self, inference_core, device: Optional[str] = None):
        self.core = inference_core
        self.initialized = False
        # Best-effort device selection if none is given.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def __call__(self, image: Any, mask: Any = None, **kwargs) -> np.ndarray:
        """Run one frame through the core and return its alpha as (H,W) float32.

        Args:
            image: frame data (numpy array, tensor or nested sequence).
            mask: initial mask. Required on the first call; ignored afterwards.
            **kwargs: forwarded unchanged to the core's frame method.

        NOTE(review): this assumes ``step``/``process_frame`` return an
        alpha-like map. If the installed MatAnyone returns per-object
        probabilities with the background first, ``_alpha_from_result`` would
        pick channel 0. Confirm against the installed version.
        """
        try:
            if not self.initialized and mask is None:
                # Checked before any conversion/device transfer, which would be wasted here.
                h, w = _hw_from_image_like(image)
                logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
                return np.full((h, w), 0.5, dtype=np.float32)

            img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)

            if not self.initialized:
                m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)
                with torch.inference_mode():
                    if hasattr(self.core, "step"):
                        result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
                    elif hasattr(self.core, "process_frame"):
                        result = self.core.process_frame(img_chw, m_1hw, **kwargs)
                    else:
                        logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
                        return _alpha_from_result(mask)
                self.initialized = True
                return _alpha_from_result(result)

            # Subsequent frames (no mask)
            with torch.inference_mode():
                if hasattr(self.core, "step"):
                    result = self.core.step(image=img_chw, **kwargs)
                elif hasattr(self.core, "process_frame"):
                    result = self.core.process_frame(img_chw, **kwargs)
                else:
                    h, w = _hw_from_image_like(image)
                    logger.warning(
                        "InferenceCore has no recognized frame API on subsequent call; returning neutral alpha."
                    )
                    return np.full((h, w), 0.5, dtype=np.float32)
            return _alpha_from_result(result)

        except Exception as e:
            # logger.exception logs at ERROR level and keeps the traceback.
            logger.exception("MatAnyone wrapper call failed: %s", e)
            return self._fallback_alpha(image, mask)

    @staticmethod
    def _fallback_alpha(image: Any, mask: Any) -> np.ndarray:
        """Echo ``mask`` as alpha if possible; otherwise return a neutral 0.5 alpha sized like ``image``."""
        if mask is not None:
            try:
                return _alpha_from_result(mask)
            except Exception as exc:
                logger.debug("Mask echo fallback failed: %s", exc)
        h, w = _hw_from_image_like(image)
        return np.full((h, w), 0.5, dtype=np.float32)

    def reset(self):
        """Reset state between videos.

        Clears ``initialized`` and calls the core's ``reset()``, or else its
        ``clear_memory()``, whichever exists first. Failures are logged at
        debug level and otherwise ignored.
        """
        self.initialized = False
        for method_name in ("reset", "clear_memory"):
            if hasattr(self.core, method_name):
                try:
                    getattr(self.core, method_name)()
                except Exception as e:
                    logger.debug("Core %s() failed: %s", method_name, e)
                break


# --------------------------- Main Loader Class ---------------------------

class MatAnyoneLoader:
    """
    Loader for MatAnyone InferenceCore with cleanup support.

    Provides a consistent interface with other model loaders, including proper
    resource cleanup.
    """

    def __init__(self, device: str = "auto", model_id: str = "PeiqingYang/MatAnyone"):
        self.device = device
        self.model_id = model_id
        self._processor: Optional[InferenceCore] = None  # type: ignore
        self._wrapper: Optional[MatAnyoneCallableWrapper] = None

    @staticmethod
    def _resolve_device(device: Any) -> str:
        """Map a requested device ("auto", "cpu", "cuda", "cuda:N", ...) to a usable device string.

        An explicit CUDA device, including its index, is kept when CUDA is
        available. "cpu" stays CPU. Anything else uses CUDA if available and
        CPU otherwise.
        """
        requested = str(device)
        cuda_ok = torch.cuda.is_available()
        if requested.startswith("cuda") and cuda_ok:
            return requested  # keep e.g. "cuda:1" instead of collapsing it to device 0
        if requested == "cpu":
            return "cpu"
        return "cuda" if cuda_ok else "cpu"

    def load(self) -> Optional[Any]:
        """
        Initialize and return a callable wrapper around InferenceCore.

        Returns MatAnyoneCallableWrapper if successful, else None (the error
        is logged and any partial state is cleared).
        """
        global InferenceCore
        try:
            if InferenceCore is None:
                from matanyone.inference.inference_core import InferenceCore as _IC  # type: ignore
                InferenceCore = _IC  # type: ignore

            logger.info("Loading MatAnyone InferenceCore ...")
            # NOTE(review): no device is passed to InferenceCore; the wrapper only
            # moves inputs to `dev`. Confirm the core places its network on the same device.
            self._processor = InferenceCore(self.model_id)  # type: ignore
            logger.info("MatAnyone InferenceCore loaded successfully")

            dev = self._resolve_device(self.device)
            self._wrapper = MatAnyoneCallableWrapper(self._processor, device=dev)
            logger.info("MatAnyone wrapped with dimension-safe callable")
            return self._wrapper
        except Exception as e:
            logger.error("Failed to load MatAnyone InferenceCore: %s", e)
            self._processor = None
            self._wrapper = None
            return None

    def get(self) -> Optional[Any]:
        """Return the cached callable if loaded (the wrapper, else the raw processor)."""
        return self._wrapper if self._wrapper is not None else self._processor

    def get_info(self) -> Dict[str, Any]:
        """Metadata for diagnostics."""
        return {
            "model_id": self.model_id,
            "loaded": self._wrapper is not None or self._processor is not None,
            "wrapped": self._wrapper is not None,
        }

    def cleanup(self):
        """
        Clean up all resources associated with MatAnyone.

        - Resets the wrapper (and through it the core), then drops it.
        - Calls the processor's ``cleanup``, ``clear`` or ``reset`` (the first
          one that exists), then drops it.
        - Calls ``torch.cuda.empty_cache()`` when not on CPU. This releases
          cached, unoccupied allocator blocks; tensors still referenced
          elsewhere are not freed.

        All steps are best-effort: failures are logged at debug level.
        """
        logger.debug("Starting MatAnyone cleanup...")

        if self._wrapper is not None:
            try:
                self._wrapper.reset()
                logger.debug("MatAnyone wrapper reset completed")
            except Exception as e:
                logger.debug("Wrapper reset failed (non-critical): %s", e)
            self._wrapper = None

        if self._processor is not None:
            try:
                # Try the cleanup methods that might exist, in order of preference.
                for method_name in ("cleanup", "clear", "reset"):
                    if hasattr(self._processor, method_name):
                        getattr(self._processor, method_name)()
                        break
                logger.debug("MatAnyone processor cleanup attempted")
            except Exception as e:
                logger.debug("Processor cleanup failed (non-critical): %s", e)
            self._processor = None

        if str(self.device) != "cpu" and torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
                logger.debug("CUDA cache cleared for MatAnyone")
            except Exception as e:
                logger.debug("CUDA cache clear failed: %s", e)

        logger.info("MatAnyone resources cleaned up")