# utils/interop.py from __future__ import annotations import torch def log_shape(tag: str, t: torch.Tensor) -> None: try: mn = float(t.min()) if t.numel() else float("nan") mx = float(t.max()) if t.numel() else float("nan") print(f"[interop] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} " f"range=[{mn:.4f},{mx:.4f}]") except Exception as e: print(f"[interop] {tag}: ") def _to_float01(x: torch.Tensor) -> torch.Tensor: x = x.to(torch.float32) if x.max() > 1.0: x = x / 255.0 return x.clamp_(0.0, 1.0) def _squeeze_bt(x: torch.Tensor) -> torch.Tensor: # Drop singleton Time and extra Batch: (B,T,C,H,W) → (B,C,H,W) or (C,H,W) if x.ndim == 5: if x.shape[1] == 1: x = x.squeeze(1) # drop T if x.ndim == 5 and x.shape[0] == 1: x = x.squeeze(0) # drop B # Edge case: (1,1,3,H,W) if x.ndim == 4 and x.shape[0] == 1 and x.shape[1] == 1 and x.shape[-3] == 3: x = x.squeeze(1) # → (1,3,H,W) return x def ensure_image_nchw( img: torch.Tensor, device: torch.device | str = "cuda", want_batched: bool = True, ) -> torch.Tensor: img = img.to(device) img = _squeeze_bt(img) if img.ndim == 3: # CHW or HWC if img.shape[0] in (1,3): chw = img else: chw = img.permute(2,0,1) # HWC→CHW chw = _to_float01(chw.contiguous()) return chw.unsqueeze(0) if want_batched else chw if img.ndim == 4: N,A,B,C = img.shape if A == 3: nchw = img elif C == 3: nchw = img.permute(0,3,1,2) # NHWC→NCHW else: raise AssertionError(f"Cannot infer channels in image: {tuple(img.shape)}") return _to_float01(nchw.contiguous()) raise AssertionError(f"Image must be 3D/4D; got {tuple(img.shape)}") def ensure_mask_for_matanyone( mask: torch.Tensor, *, idx_mask: bool = False, threshold: float = 0.5, keep_soft: bool = False, device: torch.device | str = "cuda", ) -> torch.Tensor: mask = mask.to(device) mask = _squeeze_bt(mask) if idx_mask: # Return (H,W) labels {0,1} if mask.ndim == 3: if mask.shape[0] == 1: idx = (mask[0] >= threshold).to(torch.long) else: idx = torch.argmax(mask, dim=0).to(torch.long) idx = (idx > 0).to(torch.long) elif mask.ndim == 2: idx = (mask >= threshold).to(torch.long) else: raise AssertionError(f"idx mask must be 2D/3D; got {tuple(mask.shape)}") return idx # Channel mask path → (1,H,W) float [0,1] if mask.ndim == 2: out = mask.unsqueeze(0) elif mask.ndim == 3: if mask.shape[0] == 1: out = mask else: # choose largest area channel areas = mask.sum(dim=(-2,-1)) out = mask[areas.argmax():areas.argmax()+1] else: raise AssertionError(f"mask must be 2D/3D; got {tuple(mask.shape)}") out = out.to(torch.float32) if not keep_soft: out = (out >= threshold).to(torch.float32) return out.clamp_(0.0, 1.0).contiguous()