"""
MatAnyone Loader - Stable Callable Wrapper for InferenceCore (extra-dim stripping)
===================================================================================

- Always call InferenceCore UNBATCHED:
    image -> CHW float32 [0,1]
    mask  -> 1HW float32 [0,1]
- Aggressively strip extra dims:
    e.g. [B,T,C,H,W] -> [C,H,W]  (use the first slice when B/T > 1, with a warning)
    e.g. [B,C,H,W]   -> [C,H,W]
    e.g. [H,W,C,1]   -> [H,W,C]
- Robust alpha extraction -> (H,W) float32 [0,1]
"""

from __future__ import annotations

import logging
from typing import Optional, Dict, Any, Tuple, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)

try:
    from matanyone.inference.inference_core import InferenceCore
except Exception:
    InferenceCore = None


def _to_float01_np(arr: np.ndarray) -> np.ndarray:
    """Ensure a numpy array is float32 in [0, 1]."""
    if arr.dtype == np.uint8:
        arr = arr.astype(np.float32) / 255.0
    else:
        # Copy read-only inputs; the in-place clip below would fail on them.
        arr = arr.astype(np.float32, copy=not arr.flags.writeable)
    np.clip(arr, 0.0, 1.0, out=arr)
    return arr


def _strip_leading_extras_to_ndim(
    x: Union[np.ndarray, torch.Tensor], target_ndim: int
) -> Union[np.ndarray, torch.Tensor]:
    """
    Reduce x to at most target_ndim dims by removing leading dims.

    - If a leading dim == 1, squeeze it.
    - If a leading dim > 1, take the first slice and log a warning.
    Repeat until ndim <= target_ndim.
    """
    is_tensor = torch.is_tensor(x)
    get_shape = (lambda t: tuple(t.shape)) if is_tensor else (lambda a: a.shape)
    index_first = (lambda t: t[0]) if is_tensor else (lambda a: a[0])
    squeeze_first = (lambda t: t.squeeze(0)) if is_tensor else (lambda a: np.squeeze(a, axis=0))

    while len(get_shape(x)) > target_ndim:
        dim0 = get_shape(x)[0]
        if dim0 == 1:
            x = squeeze_first(x)
        else:
            logger.warning(f"Input has extra leading dim >1 ({dim0}); taking the first slice.")
            x = index_first(x)
    return x
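
# Example of the stripping behaviour (shapes are illustrative):
#   _strip_leading_extras_to_ndim(np.zeros((1, 1, 3, 64, 64)), 3)  -> shape (3, 64, 64)
#   _strip_leading_extras_to_ndim(np.zeros((2, 3, 64, 64)), 3)     -> shape (3, 64, 64), warns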


def _ensure_chw_float01(image: Union[np.ndarray, torch.Tensor], *, name: str = "image") -> torch.Tensor:
    """
    Convert an image to a torch.FloatTensor CHW in [0,1], stripping extras.

    Accepts shapes up to 5D (e.g. B,T,C,H,W / B,C,H,W / H,W,C / CHW / HW).
    If the layout is ambiguous and multi-channel, picks the first channel
    with a warning.
    """
    orig_shape = tuple(image.shape)

    image = _strip_leading_extras_to_ndim(image, 3)

    if torch.is_tensor(image):
        t = image
        if t.ndim == 3:
            c0, _, c2 = t.shape
            if c0 in (1, 3, 4):
                pass  # already CHW
            elif c2 in (1, 3, 4):
                t = t.permute(2, 0, 1)  # HWC -> CHW
            else:
                logger.warning(
                    f"{name}: ambiguous 3D shape {tuple(t.shape)}; "
                    "attempting HWC->CHW then selecting first channel."
                )
                t = t.permute(2, 0, 1)
                if t.shape[0] > 1:
                    t = t[0:1, ...]
        elif t.ndim == 2:
            t = t.unsqueeze(0)
        else:
            raise ValueError(f"{name}: unsupported tensor dims {tuple(t.shape)} after stripping.")

        t = t.to(dtype=torch.float32)
        if torch.max(t) > 1.5:  # heuristic: values look like 0-255
            t = t / 255.0
        t = torch.clamp(t, 0.0, 1.0)
        logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
        return t

    arr = np.asarray(image)
    if arr.ndim == 3:
        if arr.shape[0] in (1, 3, 4):
            pass  # already CHW
        elif arr.shape[-1] in (1, 3, 4):
            arr = arr.transpose(2, 0, 1)  # HWC -> CHW
        else:
            logger.warning(f"{name}: ambiguous 3D shape {arr.shape}; trying HWC->CHW and selecting first channel.")
            arr = arr.transpose(2, 0, 1)
            if arr.shape[0] > 1:
                arr = arr[0:1, ...]
    elif arr.ndim == 2:
        arr = arr[None, ...]
    else:
        raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")

    arr = _to_float01_np(arr)
    t = torch.from_numpy(arr)
    logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (CHW)")
    return t
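
# Example (shapes are illustrative):
#   _ensure_chw_float01(np.zeros((1, 1, 480, 640, 3), dtype=np.uint8))
#       -> torch.float32 tensor of shape (3, 480, 640), values in [0, 1]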


def _ensure_1hw_float01(mask: Union[np.ndarray, torch.Tensor], *, name: str = "mask") -> torch.Tensor:
    """
    Convert a mask to a torch.FloatTensor 1HW in [0,1], stripping extras.

    Accepts up to 4D inputs; collapses leading dims; picks the first
    slice/channel if needed.
    """
    orig_shape = tuple(mask.shape)
    mask = _strip_leading_extras_to_ndim(mask, 3)

    if torch.is_tensor(mask):
        m = mask
        if m.ndim == 3:
            if m.shape[0] == 1:
                pass  # already 1HW
            elif m.shape[-1] == 1:
                m = m.permute(2, 0, 1)  # HW1 -> 1HW
            else:
                logger.warning(f"{name}: multi-channel {tuple(m.shape)}; using first channel.")
                if m.shape[0] in (3, 4):
                    m = m[0:1, ...]
                elif m.shape[-1] in (3, 4):
                    m = m.permute(2, 0, 1)[0:1, ...]
                else:
                    m = m[0:1, ...]
        elif m.ndim == 2:
            m = m.unsqueeze(0)
        else:
            raise ValueError(f"{name}: unsupported tensor dims {tuple(m.shape)} after stripping.")

        m = m.to(dtype=torch.float32)
        if torch.max(m) > 1.5:  # heuristic: values look like 0-255
            m = m / 255.0
        m = torch.clamp(m, 0.0, 1.0)
        logger.debug(f"{name}: {orig_shape} -> {tuple(m.shape)} (1HW)")
        return m

    arr = np.asarray(mask)
    if arr.ndim == 3:
        if arr.shape[0] == 1:
            pass  # already 1HW
        elif arr.shape[-1] == 1:
            arr = arr.transpose(2, 0, 1)  # HW1 -> 1HW
        else:
            logger.warning(f"{name}: multi-channel {arr.shape}; using first channel.")
            if arr.shape[0] in (3, 4):
                arr = arr[0:1, ...]
            elif arr.shape[-1] in (3, 4):
                arr = arr.transpose(2, 0, 1)[0:1, ...]
            else:
                arr = arr[0:1, ...]
    elif arr.ndim == 2:
        arr = arr[None, ...]
    else:
        raise ValueError(f"{name}: unsupported numpy dims {arr.shape} after stripping.")

    arr = _to_float01_np(arr)
    t = torch.from_numpy(arr)
    logger.debug(f"{name}: {orig_shape} -> {tuple(t.shape)} (1HW)")
    return t
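
# Example (shapes are illustrative):
#   _ensure_1hw_float01(np.zeros((480, 640), dtype=np.uint8))  -> shape (1, 480, 640)
#   _ensure_1hw_float01(torch.zeros(1, 480, 640, 1))           -> shape (1, 480, 640)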


def _alpha_from_result(result: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Extract a 2D alpha (H,W) float32 [0,1] from various outputs."""
    if result is None:
        return np.full((512, 512), 0.5, dtype=np.float32)

    if torch.is_tensor(result):
        result = result.detach().float().cpu()

    arr = np.asarray(result)
    while arr.ndim > 3:
        if arr.shape[0] > 1:
            logger.warning(f"Result has leading dim {arr.shape[0]}; taking first slice.")
        arr = arr[0]

    if arr.ndim == 2:
        alpha = arr
    elif arr.ndim == 3:
        if arr.shape[0] in (1, 3, 4):
            alpha = arr[0]
        elif arr.shape[-1] in (1, 3, 4):
            alpha = arr[..., 0]
        else:
            alpha = arr[0]
    else:
        alpha = np.full((512, 512), 0.5, dtype=np.float32)

    alpha = alpha.astype(np.float32, copy=False)
    np.clip(alpha, 0.0, 1.0, out=alpha)
    return alpha
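
# Example of the extraction rules (shapes are illustrative):
#   _alpha_from_result(torch.rand(1, 1, 480, 640))  -> np.ndarray of shape (480, 640)
#   _alpha_from_result(np.random.rand(480, 640, 1)) -> np.ndarray of shape (480, 640)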


def _hw_from_image_like(x: Union[np.ndarray, torch.Tensor]) -> Tuple[int, int]:
    """Best-effort inference of (H, W) for fallback mask sizing."""
    shape = tuple(x.shape) if torch.is_tensor(x) else np.asarray(x).shape
    if len(shape) == 2:
        return shape[0], shape[1]
    if len(shape) == 3:
        if shape[0] in (1, 3, 4):  # CHW
            return shape[1], shape[2]
        if shape[-1] in (1, 3, 4):  # HWC
            return shape[0], shape[1]
        return shape[1], shape[2]
    if len(shape) >= 4:
        if shape[1] in (1, 3, 4):  # e.g. B,C,H,W
            return shape[2], shape[3]
        return shape[-3], shape[-2]  # e.g. B,H,W,C
    return 512, 512
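
# Example (illustrative): (1, 3, 480, 640) -> (480, 640); (480, 640, 3) -> (480, 640)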


class MatAnyoneCallableWrapper:
    """
    Callable session-like wrapper around an InferenceCore instance.

    Contract:
    - The first call SHOULD include a mask (1HW); without one, a neutral
      0.5 alpha is returned.
    - Subsequent calls do not require a mask.
    - Returns a 2D alpha (H,W) float32 in [0,1].
    - Strips any extra dims from inputs before calling the core.
    """

    def __init__(self, inference_core, device: Optional[str] = None):
        self.core = inference_core
        self.initialized = False

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

    def __call__(self, image, mask=None, **kwargs) -> np.ndarray:
        try:
            img_chw = _ensure_chw_float01(image, name="image").to(self.device, non_blocking=True)

            if not self.initialized:
                if mask is None:
                    h, w = _hw_from_image_like(image)
                    logger.warning("MatAnyone first frame called without mask; returning neutral alpha.")
                    return np.full((h, w), 0.5, dtype=np.float32)

                m_1hw = _ensure_1hw_float01(mask, name="mask").to(self.device, non_blocking=True)

                with torch.inference_mode():
                    if hasattr(self.core, "step"):
                        result = self.core.step(image=img_chw, mask=m_1hw, **kwargs)
                    elif hasattr(self.core, "process_frame"):
                        result = self.core.process_frame(img_chw, m_1hw, **kwargs)
                    else:
                        logger.warning("InferenceCore has no recognized frame API; echoing input mask.")
                        return _alpha_from_result(mask)

                self.initialized = True
                return _alpha_from_result(result)

            with torch.inference_mode():
                if hasattr(self.core, "step"):
                    result = self.core.step(image=img_chw, **kwargs)
                elif hasattr(self.core, "process_frame"):
                    result = self.core.process_frame(img_chw, **kwargs)
                else:
                    h, w = _hw_from_image_like(image)
                    logger.warning("InferenceCore has no recognized frame API on subsequent call; returning neutral alpha.")
                    return np.full((h, w), 0.5, dtype=np.float32)

            return _alpha_from_result(result)

        except Exception as e:
            logger.error(f"MatAnyone wrapper call failed: {e}")
            # Fall back to echoing the mask when available, else neutral alpha.
            if mask is not None:
                try:
                    return _alpha_from_result(mask)
                except Exception:
                    pass
            h, w = _hw_from_image_like(image)
            return np.full((h, w), 0.5, dtype=np.float32)

    def reset(self):
        """Reset state between videos."""
        self.initialized = False
        if hasattr(self.core, "reset"):
            try:
                self.core.reset()
            except Exception as e:
                logger.debug(f"Core reset() failed: {e}")
        elif hasattr(self.core, "clear_memory"):
            try:
                self.core.clear_memory()
            except Exception as e:
                logger.debug(f"Core clear_memory() failed: {e}")


class MatAnyoneLoader:
    """
    Loader for the MatAnyone InferenceCore with cleanup support.

    Provides a consistent interface with other model loaders,
    including proper resource cleanup.
    """

    def __init__(self, device: str = "auto", model_id: str = "PeiqingYang/MatAnyone"):
        self.device = device
        self.model_id = model_id
        self._processor: Optional[InferenceCore] = None
        self._wrapper: Optional[MatAnyoneCallableWrapper] = None

    def load(self) -> Optional[Any]:
        """
        Initialize and return a callable wrapper around InferenceCore.

        Returns a MatAnyoneCallableWrapper if successful, else None.
        """
        global InferenceCore
        try:
            if InferenceCore is None:
                from matanyone.inference.inference_core import InferenceCore as _IC
                InferenceCore = _IC

            logger.info("Loading MatAnyone InferenceCore ...")
            self._processor = InferenceCore(self.model_id)
            logger.info("MatAnyone InferenceCore loaded successfully")

            # Resolve the requested device; "auto" and unrecognized values
            # prefer CUDA, and everything falls back to CPU when CUDA is absent.
            if str(self.device) == "cpu":
                dev = "cpu"
            elif torch.cuda.is_available():
                dev = "cuda"
            else:
                dev = "cpu"

            self._wrapper = MatAnyoneCallableWrapper(self._processor, device=dev)
            logger.info("MatAnyone wrapped with dimension-safe callable")
            return self._wrapper
        except Exception as e:
            logger.error(f"Failed to load MatAnyone InferenceCore: {e}")
            self._processor = None
            self._wrapper = None
            return None

    def get(self) -> Optional[Any]:
        """Return the cached callable if loaded."""
        return self._wrapper or self._processor

    def get_info(self) -> Dict[str, Any]:
        """Metadata for diagnostics."""
        return {
            "model_id": self.model_id,
            "loaded": self._wrapper is not None or self._processor is not None,
            "wrapped": self._wrapper is not None,
        }

    def cleanup(self):
        """
        Clean up all resources associated with MatAnyone.

        This method ensures proper cleanup of:
        - The wrapper's state and memory
        - The InferenceCore processor
        - Any CUDA tensors in memory
        """
        logger.debug("Starting MatAnyone cleanup...")

        if self._wrapper:
            try:
                self._wrapper.reset()
                logger.debug("MatAnyone wrapper reset completed")
            except Exception as e:
                logger.debug(f"Wrapper reset failed (non-critical): {e}")
            self._wrapper = None

        if self._processor:
            try:
                if hasattr(self._processor, "cleanup"):
                    self._processor.cleanup()
                elif hasattr(self._processor, "clear"):
                    self._processor.clear()
                elif hasattr(self._processor, "reset"):
                    self._processor.reset()
                logger.debug("MatAnyone processor cleanup attempted")
            except Exception as e:
                logger.debug(f"Processor cleanup failed (non-critical): {e}")
            self._processor = None

        if self.device != "cpu" and torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
                logger.debug("CUDA cache cleared for MatAnyone")
            except Exception as e:
                logger.debug(f"CUDA cache clear failed: {e}")

        logger.info("MatAnyone resources cleaned up")