"""
MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
=================================================================================
- Always call InferenceCore UNBATCHED:
image -> CHW float32 [0,1]
mask -> 1HW float32 [0,1]
- Aggressively strip extra dims:
e.g. [B,T,C,H,W] -> [C,H,W] (use first slice when B/T > 1 with a warning)
e.g. [B,C,H,W] -> [C,H,W]
e.g. [H,W,C,1] -> [H,W,C]
- Robust alpha extraction -> (H,W) float32 [0,1]
"""
from __future__ import annotations
import logging
from typing import Optional, Dict, Any, Tuple, Union
import numpy as np
import torch
logger = logging.getLogger(__name__)
try:
# Official import path
from matanyone.inference.inference_core import InferenceCore
except Exception:  # defer the import error until load()
InferenceCore = None # type: ignore
# ------------------------------ Helpers ------------------------------
def _to_float01_np(arr: np.ndarray) -> np.ndarray:
    """Return `arr` as float32 scaled and clipped to [0, 1]."""
    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    else:
        arr = arr.astype(np.float32, copy=False)
        # Mirror the tensor path: values above 1.5 are treated as 0-255 range.
        if arr.size and float(arr.max()) > 1.5:
            arr = arr / 255.0
    # Clip into a fresh array so a caller-owned input is never mutated in place.
    return np.clip(arr, 0.0, 1.0)
def _strip_leading_extras_to_ndim(x: Union[np.ndarray, torch.Tensor], target_ndim: int) -> Union[np.ndarray, torch.Tensor]:
"""
Reduce x to at most target_ndim by removing leading dims.
- If a leading dim == 1, squeeze it.
- If a leading dim > 1, take the first slice and log a warning.
Repeat until ndim <= target_ndim.
"""
is_tensor = torch.is_tensor(x)
get_shape = (lambda t: tuple(t.shape)) if is_tensor else (lambda a: a.shape)
index_first = (lambda t: t[0]) if is_tensor else (lambda a: a[0])
squeeze_first = (lambda t: t.squeeze(0)) if is_tensor else (lambda a: np.squeeze(a, axis=0))
while len(get_shape(x)) > target_ndim:
dim0 = get_shape(x)[0]
if dim0 == 1:
x = squeeze_first(x)
else:
logger.warning(f"Input has extra leading dim >1 ({dim0}); taking the first slice.")
x = index_first(x)
return x
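# Shape sketch (hypothetical sizes, not taken from real inputs):
#   _strip_leading_extras_to_ndim(torch.zeros(1, 1, 3, 64, 64), 3)  # -> [3, 64, 64]
#   _strip_leading_extras_to_ndim(torch.zeros(2, 3, 64, 64), 3)     # -> [3, 64, 64] (first slice, warns)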
def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "image") -> torch.Tensor:
"""
Convert image to torch.FloatTensor CHW in [0,1], stripping extras.
Accepts shapes up to 5D (e.g. B,T,C,H,W / B,C,H,W / H,W,C / CHW / HW / ...).
If ambiguous multi-channel, picks first channel with a warning.
"""
    orig_shape = tuple(image.shape)
    # Reduce to <= 3 dims (e.g. [B,T,C,H,W] -> [C,H,W])
    image = _strip_leading_extras_to_ndim(image, 3)
    if torch.is_tensor(image):
        t = image  # already <= 3 dims after the stripping above
if t.ndim == 3:
c0, c1, c2 = t.shape
if c0 in (1, 3, 4):
pass # CHW
elif c2 in (1, 3, 4):
t = t.permute(2, 0, 1) # HWC -> CHW
else:
logger.warning(f"{name}: ambiguous 3D shape {tuple(t.shape)}; attempting HWC->CHW then selecting first channel.")
t = t.permute(2, 0, 1)
if t.shape[0] > 1:
t = t[0]
t = t.unsqueeze(0)
elif t.ndim == 2:
t = t.unsqueeze(0) # 1HW
else:
raise ValueError(f"{name}: unsupported tensor dims {tuple(t.shape)} after stripping.")
t = t.to(dtype=torch.float32)
if torch.max(t) > 1.5:
t = t / 255.0
t = torch.clamp(t, 0.0, 1.0)
logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
return t
    arr = np.asarray(image)  # already <= 3 dims after the stripping above
if arr.ndim == 3:
if arr.shape[0] in (1, 3, 4):
pass # CHW
elif arr.shape[-1] in (1, 3, 4):
arr = arr.transpose(2, 0, 1) # HWC -> CHW
else:
logger.warning(f"{name}: ambiguous 3D shape {arr.shape}; trying HWC->CHW and selecting first channel.")
arr = arr.transpose(2, 0, 1)
if arr.shape[0] > 1:
arr = arr[0:1, ...]
elif arr.ndim == 2:
arr = arr[None, ...] # 1HW
else:
raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")
arr = _to_float01_np(arr)
t = torch.from_numpy(arr)
logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
return t
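# Illustrative conversions (hypothetical sizes):
#   _ensure_chw_float01(np.zeros((64, 64, 3), np.uint8)).shape  # HWC  -> (3, 64, 64)
#   _ensure_chw_float01(torch.zeros(1, 3, 64, 64)).shape        # BCHW -> (3, 64, 64)
#   _ensure_chw_float01(np.zeros((64, 64))).shape               # HW   -> (1, 64, 64)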
def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "mask") -> torch.Tensor:
"""
Convert mask to torch.FloatTensor 1HW in [0,1], stripping extras.
Accepts up to 4D inputs; collapses leading dims; picks first slice/channel if needed.
"""
    orig_shape = tuple(mask.shape)
    mask = _strip_leading_extras_to_ndim(mask, 3)
if torch.is_tensor(mask):
m = mask
if m.ndim == 3:
if m.shape[0] == 1:
pass # 1HW
elif m.shape[-1] == 1:
m = m.permute(2, 0, 1) # HW1 -> 1HW
else:
logger.warning(f"{name}: multi-channel {tuple(m.shape)}; using first channel.")
if m.shape[0] in (3, 4):
m = m[0:1, ...]
elif m.shape[-1] in (3, 4):
m = m.permute(2, 0, 1)[0:1, ...]
else:
m = m[0:1, ...]
elif m.ndim == 2:
m = m.unsqueeze(0)
else:
raise ValueError(f"{name}: unsupported tensor dims {tuple(m.shape)} after stripping.")
m = m.to(dtype=torch.float32)
if torch.max(m) > 1.5:
m = m / 255.0
m = torch.clamp(m, 0.0, 1.0)
logger.debug(f"{name}: {orig_shape} -> {tuple(m.shape)} (1HW)")
return m
arr = np.asarray(mask)
if arr.ndim == 3:
if arr.shape[0] == 1:
pass # 1HW
elif arr.shape[-1] == 1:
arr = arr.transpose(2, 0, 1) # HW1 -> 1HW
else:
logger.warning(f"{name}: multi-channel {arr.shape}; using first channel.")
if arr.shape[0] in (3, 4):
arr = arr[0:1, ...]
elif arr.shape[-1] in (3, 4):
arr = arr.transpose(2, 0, 1)[0:1, ...]
else:
arr = arr[0:1, ...]
elif arr.ndim == 2:
arr = arr[None, ...]
else:
raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")
arr = _to_float01_np(arr)
t = torch.from_numpy(arr)
logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (1HW)")
return t
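# e.g. an HW1 uint8 mask becomes a float 1HW tensor (hypothetical size):
#   _ensure_1hw_float01(np.zeros((64, 64, 1), np.uint8)).shape  # -> (1, 64, 64)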
def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
"""Extract a 2D alpha (H,W) float32 [0,1] from various outputs."""
if result is None:
return np.full((512, 512), 0.5, dtype=np.float32)
if torch.is_tensor(result):
result = result.detach().float().cpu()
arr = np.asarray(result)
while arr.ndim > 3:
if arr.shape[0] > 1:
logger.warning(f"Result has leading dim {arr.shape[0]}; taking first slice.")
arr = arr[0]
if arr.ndim == 2:
alpha = arr
elif arr.ndim == 3:
if arr.shape[0] in (1, 3, 4):
alpha = arr[0]
elif arr.shape[-1] in (1, 3, 4):
alpha = arr[..., 0]
else:
alpha = arr[0]
else:
alpha = np.full((512, 512), 0.5, dtype=np.float32)
    # Clip into a fresh array so a caller-owned input (e.g. an echoed mask)
    # is never mutated in place.
    alpha = np.clip(alpha.astype(np.float32, copy=False), 0.0, 1.0)
    return alpha
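# e.g. a batched single-channel prediction collapses to (H, W):
#   _alpha_from_result(torch.zeros(1, 1, 64, 64)).shape  # -> (64, 64)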
def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
"""Best-effort infer (H, W) for fallback mask sizing."""
shape = tuple(x.shape) if torch.is_tensor(x) else np.asarray(x).shape
if len(shape) == 2:
return shape[0], shape[1]
if len(shape) == 3:
if shape[0] in (1, 3, 4):
return shape[1], shape[2]
if shape[-1] in (1, 3, 4):
return shape[0], shape[1]
return shape[1], shape[2]
    if len(shape) >= 4:
        if shape[1] in (1, 3, 4):  # e.g. [B,C,H,W]
            return shape[2], shape[3]
return shape[-3], shape[-2]
return 512, 512
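# Illustration (hypothetical size):
#   _hw_from_image_like(torch.zeros(1, 3, 480, 640))  # -> (480, 640)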
# --------------------------- Callable Wrapper ---------------------------
class MatAnyoneCallableWrapper:
"""
Callable session-like wrapper around an InferenceCore instance.
Contract:
- First call SHOULD include a mask (1HW). If not, returns neutral 0.5 alpha.
- Subsequent calls do not require mask.
- Returns 2D alpha (H,W) float32 in [0,1].
- Strips any extra dims from inputs before calling core.
"""
    def __init__(self, inference_core, device: Optional[str] = None):
self.core = inference_core
self.initialized = False
# Best-effort device selection if available
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
try:
img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)
if not self.initialized:
if mask is None:
h, w = _hw_from_image_like(image)
logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
return np.full((h, w), 0.5, dtype=np.float32)
m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)
with torch.inference_mode():
if hasattr(self.core, "step"):
result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
elif hasattr(self.core, "process_frame"):
result = self.core.process_frame(img_chw, m_1hw, **kwargs)
else:
logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
return _alpha_from_result(mask)
self.initialized = True
return _alpha_from_result(result)
# Subsequent frames (no mask)
with torch.inference_mode():
if hasattr(self.core, "step"):
result = self.core.step(image=img_chw, **kwargs)
elif hasattr(self.core, "process_frame"):
result = self.core.process_frame(img_chw, **kwargs)
else:
h, w = _hw_from_image_like(image)
logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
return np.full((h, w), 0.5, dtype=np.float32)
return _alpha_from_result(result)
except Exception as e:
logger.error(f"MatAnyone wrapper call failed: {e}")
# Fallbacks
if mask is not None:
try:
return _alpha_from_result(mask)
except Exception:
pass
h, w = _hw_from_image_like(image)
return np.full((h, w), 0.5, dtype=np.float32)
def reset(self):
"""Reset state between videos."""
self.initialized = False
if hasattr(self.core, "reset"):
try:
self.core.reset()
except Exception as e:
logger.debug(f"Core reset() failed: {e}")
elif hasattr(self.core, "clear_memory"):
try:
self.core.clear_memory()
except Exception as e:
logger.debug(f"Core clear_memory() failed: {e}")
# --------------------------- Main Loader Class ---------------------------
class MatAnyoneLoader:
"""
Loader for MatAnyone InferenceCore with cleanup support.
Provides a consistent interface with other model loaders,
including proper resource cleanup.
"""
def __init__(self, device: str = "auto", model_id: str = "PeiqingYang/MatAnyone"):
self.device = device
self.model_id = model_id
self._processor: Optional[InferenceCore] = None # type: ignore
self._wrapper: Optional[MatAnyoneCallableWrapper] = None
def load(self) -> Optional[Any]:
"""
Initialize and return a callable wrapper around InferenceCore.
Returns MatAnyoneCallableWrapper if successful, else None.
"""
global InferenceCore
try:
if InferenceCore is None:
from matanyone.inference.inference_core import InferenceCore as _IC # type: ignore
InferenceCore = _IC # type: ignore
logger.info("Loading MatAnyone InferenceCore ...")
self._processor = InferenceCore(self.model_id) # type: ignore
logger.info("MatAnyone InferenceCore loaded successfully")
            # Resolve "auto" / "cuda" / "cpu" to a concrete runtime device.
            if str(self.device) == "cpu":
                dev = "cpu"
            else:
                dev = "cuda" if torch.cuda.is_available() else "cpu"
self._wrapper = MatAnyoneCallableWrapper(self._processor, device=dev)
logger.info("MatAnyone wrapped with dimension-safe callable")
return self._wrapper
except Exception as e:
logger.error(f"Failed to load MatAnyone InferenceCore: {e}")
self._processor = None
self._wrapper = None
return None
def get(self) -> Optional[Any]:
"""Return the cached callable if loaded."""
return self._wrapper or self._processor
def get_info(self) -> Dict[str, Any]:
"""Metadata for diagnostics."""
return {
"model_id": self.model_id,
"loaded": self._wrapper is not None or self._processor is not None,
"wrapped": self._wrapper is not None,
}
def cleanup(self):
"""
Clean up all resources associated with MatAnyone.
This method ensures proper cleanup of:
- The wrapper's state and memory
- The InferenceCore processor
- Any CUDA tensors in memory
"""
logger.debug("Starting MatAnyone cleanup...")
# Clean up wrapper first
if self._wrapper:
try:
self._wrapper.reset()
logger.debug("MatAnyone wrapper reset completed")
except Exception as e:
logger.debug(f"Wrapper reset failed (non-critical): {e}")
self._wrapper = None
# Clean up processor
if self._processor:
try:
# Try various cleanup methods that might exist
if hasattr(self._processor, 'cleanup'):
self._processor.cleanup()
elif hasattr(self._processor, 'clear'):
self._processor.clear()
elif hasattr(self._processor, 'reset'):
self._processor.reset()
logger.debug("MatAnyone processor cleanup attempted")
except Exception as e:
logger.debug(f"Processor cleanup failed (non-critical): {e}")
self._processor = None
# Clear any CUDA cache if using GPU
if self.device != "cpu" and torch.cuda.is_available():
try:
torch.cuda.empty_cache()
logger.debug("CUDA cache cleared for MatAnyone")
except Exception as e:
logger.debug(f"CUDA cache clear failed: {e}")
logger.info("MatAnyone resources cleaned up")