dillonlaird commited on
Commit
d9697ef
1 Parent(s): 199c85f

added postprocessing

Browse files
Files changed (1) hide show
  1. app/sam/postprocess.py +21 -0
app/sam/postprocess.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+
6
+ from torch import Tensor
7
+ from kornia.morphology import erosion, dilation
8
+
9
+
10
+ def clean_mask_torch(mask: Tensor) -> Tensor:
11
+ kernel = torch.ones(2, 2).to(mask.device)
12
+ if len(mask.shape) == 2:
13
+ mask = mask[None, None, :, :]
14
+ if mask.dtype == torch.bool:
15
+ mask = mask.int()
16
+ return dilation(erosion(mask, kernel), kernel)
17
+
18
+
19
+ def clean_mask_np(mask: npt.NDArray) -> npt.NDArray:
20
+ kernel = np.ones((2, 2), np.uint8)
21
+ return cv2.dilate(cv2.erode(mask, kernel, iterations=1), kernel, iterations=1)