File size: 1,557 Bytes
dbe3751
 
2586f05
 
 
dbe3751
 
 
 
 
 
 
2586f05
 
dbe3751
 
 
 
 
2586f05
 
 
dbe3751
2586f05
dbe3751
 
2586f05
 
dbe3751
2586f05
 
 
dbe3751
2586f05
dbe3751
2586f05
dbe3751
2586f05
 
dbe3751
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# utils/mask_bridge.py
from __future__ import annotations
import torch

def log_shape(tag: str, t: torch.Tensor) -> None:
    """Print a one-line summary of tensor ``t`` prefixed by ``tag``.

    The line reports shape, dtype, device and the min/max value range.
    Logging is strictly best-effort: this function never raises.

    Args:
        tag: Label identifying the tensor in the log line.
        t: Tensor to describe.

    The range is printed as ``[nan,nan]`` when the tensor is empty or when
    its min/max cannot be reduced to a Python float (e.g. complex dtypes),
    so shape/dtype/device are still reported in those cases.
    """
    try:
        try:
            if t.numel():
                lo, hi = float(t.min()), float(t.max())
            else:
                lo = hi = float("nan")
        except (RuntimeError, TypeError, ValueError):
            # min/max or scalar conversion is unsupported for this tensor;
            # keep the rest of the log line instead of dropping it entirely.
            lo = hi = float("nan")
        print(f"[mask_bridge] {tag}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
              f"range=[{lo:.4f},{hi:.4f}]")
    except Exception:
        # Diagnostics only: a failure to describe or print (non-tensor input,
        # stdout encoding errors, ...) must never break the caller.
        pass

def sam2_to_matanyone_mask(
    sam2_masks: torch.Tensor,         # (B,M,H,W) after post_process
    iou_scores: torch.Tensor | None,  # (B,M) or None
    threshold: float = 0.5,
    return_mode: str = "single",      # "single"→(1,H,W) or "multi"→(C,H,W)
    keep_soft: bool = False,
) -> torch.Tensor:
    """Turn a batch-of-one stack of candidate masks into a float32 mask in [0, 1].

    Args:
        sam2_masks: 4-D tensor laid out as (B, M, H, W); B must be 1.
        iou_scores: Optional per-candidate scores of shape (1, M). When it is
            ``None`` or has any other shape, the candidate with the largest
            sum over its spatial dimensions is selected instead.
        threshold: Values ``>= threshold`` become 1.0 and the rest 0.0
            (ignored when ``keep_soft`` is true).
        return_mode: ``"multi"`` returns all M candidates as (M, H, W); any
            other value returns only the selected candidate as (1, H, W).
        keep_soft: Skip binarization and only clamp the values to [0, 1].

    Returns:
        A new contiguous float32 tensor; ``sam2_masks`` is never modified.

    Raises:
        AssertionError: If ``sam2_masks`` is not 4-D or its batch size is not 1.
    """
    assert sam2_masks.ndim == 4, f"Expect (B,M,H,W). Got {tuple(sam2_masks.shape)}"
    B, M = sam2_masks.shape[:2]
    assert B == 1, "Bridge expects B=1 for first-frame bootstrapping"

    candidates = sam2_masks[0]  # (M,H,W)

    if return_mode == "multi":
        out = candidates  # (M,H,W) treat as (C,H,W)
    else:
        # Selection is only needed here; skipping it in "multi" mode avoids a
        # pointless argmax and host sync via .item().
        if iou_scores is not None and tuple(iou_scores.shape) == (1, M):
            best_idx = int(torch.argmax(iou_scores[0]).item())
        else:
            # Scores missing or not matching the M candidates: an index taken
            # from them could fall outside the stack and slice to an empty
            # mask, so fall back to the largest-area heuristic.
            areas = candidates.sum(dim=(-2, -1))
            best_idx = int(torch.argmax(areas).item())
        out = candidates[best_idx:best_idx + 1]  # (1,H,W)

    out = out.to(torch.float32)
    if not keep_soft:
        out = (out >= threshold).float()
    # Out-of-place clamp: for float32 input with keep_soft=True, `out` is still
    # a view of the caller's tensor and an in-place clamp would overwrite it.
    out = out.clamp(0.0, 1.0).contiguous()
    log_shape("sam2→mat.mask", out)
    return out