File size: 1,557 Bytes
dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 2586f05 dbe3751 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# utils/mask_bridge.py
from __future__ import annotations
import torch
def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line summary of a tensor: shape, dtype, device and value range.

    This is a best-effort diagnostic helper and never raises. If part of the
    summary cannot be computed (for example ``min``/``max`` on a dtype without
    an ordering, or ``t`` not being a tensor at all), the line is still printed
    with whatever was gathered so far, followed by the type of the exception
    that interrupted it.

    Args:
        tag: Label inserted after the ``[mask_bridge]`` prefix.
        t: Tensor to describe. An empty tensor reports its range as ``nan``.
    """
    line = f"[mask_bridge] {tag}:"
    try:
        line += f" shape={tuple(t.shape)} dtype={t.dtype} device={t.device}"
        if t.numel():
            lo, hi = float(t.min()), float(t.max())
        else:
            # No elements to reduce over; min()/max() would raise.
            lo = hi = float("nan")
        line += f" range=[{lo:.4f},{hi:.4f}]"
    except Exception as exc:
        # Diagnostics must not break the caller, but dropping the line
        # silently hides the problem: report what failed instead.
        line += f" <unavailable: {type(exc).__name__}>"
    try:
        print(line)
    except (OSError, ValueError):
        # stdout is closed or broken; there is nowhere left to report to.
        pass
def sam2_to_matanyone_mask(
    sam2_masks: torch.Tensor,  # (B,M,H,W) after post_process
    iou_scores: torch.Tensor | None,  # (B,M) or None
    threshold: float = 0.5,
    return_mode: str = "single",  # "single"→(1,H,W) or "multi"→(C,H,W)
    keep_soft: bool = False,
) -> torch.Tensor:
    """Turn a batch-of-one stack of candidate masks into a float32 mask tensor.

    The input tensor is never modified, and the returned tensor does not share
    storage with it.

    Args:
        sam2_masks: 4-D tensor laid out as (B, M, H, W); B must be 1.
        iou_scores: Optional per-candidate scores. They are used to pick the
            best candidate only when the tensor is 2-D with a leading
            dimension of 1; otherwise (including ``None``) the candidate with
            the largest summed value over its last two dimensions is chosen.
        threshold: Cut-off applied with ``>=`` when ``keep_soft`` is false.
        return_mode: ``"multi"`` returns every candidate as (M, H, W); any
            other value returns only the selected candidate as (1, H, W).
        keep_soft: If true, values are kept and clamped to [0, 1]; if false,
            the mask is binarised to 0.0 / 1.0 against ``threshold``.

    Returns:
        A contiguous float32 tensor with values in [0, 1].

    Raises:
        AssertionError: If ``sam2_masks`` is not 4-D or its first dimension is
            not 1. These checks are plain ``assert`` statements, so they are
            skipped when Python runs with ``-O``.
    """
    assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
    assert sam2_masks.shape[0] == 1, "Bridge expects B=1 for first-frame bootstrapping"
    candidates = sam2_masks[0]  # (M,H,W)

    if return_mode == "multi":
        out = candidates  # (M,H,W) treat as (C,H,W)
    else:
        # Candidate selection is only needed here; doing it lazily avoids an
        # argmax and a host sync (.item()) when every candidate is returned.
        if iou_scores is not None and iou_scores.ndim == 2 and iou_scores.shape[0] == 1:
            best_idx = int(torch.argmax(iou_scores[0]).item())
        else:
            # Reduce half-precision masks in float32: a float16 result
            # overflows to inf past 65504, which would make every large mask
            # tie, and bfloat16 rounds areas too coarsely to rank them.
            low_precision = candidates.dtype in (torch.float16, torch.bfloat16)
            reduce_dtype = torch.float32 if low_precision else None
            areas = candidates.sum(dim=(-2, -1), dtype=reduce_dtype)
            best_idx = int(torch.argmax(areas).item())
        out = candidates[best_idx:best_idx + 1]  # (1,H,W)

    # .to() returns the same tensor when the dtype already matches, so `out`
    # may still be a view into the caller's `sam2_masks` at this point.
    out = out.to(torch.float32)
    if keep_soft:
        # Out-of-place clamp: an in-place clamp_ here would write through the
        # view and silently alter the caller's tensor.
        out = out.clamp(0.0, 1.0)
    else:
        # The comparison already yields exactly 0.0 / 1.0; no clamp needed.
        out = (out >= threshold).float()
    out = out.contiguous()
    log_shape("sam2→mat.mask", out)
    return out
|