wondervictor commited on
Commit
059012f
·
verified ·
1 Parent(s): 9095b80

Update model/segment_anything_2/sam2/utils/misc.py

Browse files
model/segment_anything_2/sam2/utils/misc.py CHANGED
@@ -43,25 +43,25 @@ def get_sdpa_settings():
43
 
44
  return old_gpu, use_flash_attn, math_kernel_on
45
 
46
- from sam2.utils.misc import get_connected_components
47
 
48
- # def get_connected_components(mask):
49
- # """
50
- # Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
51
-
52
- # Inputs:
53
- # - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
54
- # background.
55
 
56
- # Outputs:
57
- # - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
58
- # for foreground pixels and 0 for background pixels.
59
- # - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
60
- # components for foreground pixels and 0 for background pixels.
61
- # """
62
- # from model.segment_anything_2.sam2 import _C
 
 
 
 
63
 
64
- # return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
65
 
66
 
67
  def mask_to_box(masks: torch.Tensor):
 
43
 
44
  return old_gpu, use_flash_attn, math_kernel_on
45
 
46
+ # from sam2.utils.misc import get_connected_components
47
 
48
+ def get_connected_components(mask):
49
+ """
50
+ Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
 
 
 
 
51
 
52
+ Inputs:
53
+ - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
54
+ background.
55
+
56
+ Outputs:
57
+ - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
58
+ for foreground pixels and 0 for background pixels.
59
+ - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
60
+ components for foreground pixels and 0 for background pixels.
61
+ """
62
+ from model.segment_anything_2.sam2 import _C
63
 
64
+ return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
65
 
66
 
67
  def mask_to_box(masks: torch.Tensor):