kaveh commited on
Commit
7b07ae4
·
1 Parent(s): 8a0f092

added cell boundary selection

Browse files
Files changed (1) hide show
  1. utils/segmentation.py +65 -0
utils/segmentation.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cell segmentation from force map for background exclusion."""
2
+ import numpy as np
3
+ from scipy.ndimage import gaussian_filter
4
+ from skimage.filters import threshold_otsu
5
+ from skimage.morphology import binary_closing, binary_opening, binary_dilation, remove_small_objects, disk
6
+ from skimage.measure import label, regionprops
7
+
8
+
9
+ def estimate_cell_mask(heatmap, sigma=2, min_size=200, exclude_full_image=True,
10
+ threshold_relax=0.85, dilate_radius=4):
11
+ """
12
+ Estimate cell region from force map using Otsu thresholding and morphological cleanup.
13
+
14
+ Args:
15
+ heatmap: 2D float array [0, 1] - predicted force map
16
+ sigma: Gaussian smoothing sigma to reduce noise. Default 2.
17
+ min_size: Minimum object size in pixels; smaller objects removed. Default 200.
18
+ exclude_full_image: If True, exclude the largest connected component when it
19
+ covers most of the image (>70%) and use the second largest. Default True.
20
+ threshold_relax: Multiply Otsu threshold by this (<1 = looser, include more pixels).
21
+ Default 0.85.
22
+ dilate_radius: Radius to dilate mask outward to include surrounding pixels.
23
+ Default 4.
24
+
25
+ Returns:
26
+ mask: Binary uint8 array, 1 = estimated cell, 0 = background
27
+ """
28
+ heatmap = np.clip(heatmap, 0, 1).astype(np.float64)
29
+ if np.max(heatmap) <= 0:
30
+ return np.zeros_like(heatmap, dtype=np.uint8)
31
+
32
+ # Smooth to reduce noise
33
+ smoothed = gaussian_filter(heatmap, sigma=sigma)
34
+
35
+ # Otsu automatic threshold, relaxed to include more pixels
36
+ thresh = threshold_otsu(smoothed) * threshold_relax
37
+ mask = (smoothed > thresh).astype(np.uint8)
38
+
39
+ # Morphological cleanup
40
+ mask = binary_closing(mask, disk(5)).astype(np.uint8)
41
+ mask = binary_opening(mask, disk(3)).astype(np.uint8)
42
+ mask = remove_small_objects(mask.astype(bool), min_size=min_size).astype(np.uint8)
43
+
44
+ # Select component: second largest if largest is whole image
45
+ labeled = label(mask)
46
+ props = list(regionprops(labeled))
47
+
48
+ if len(props) == 0:
49
+ return np.zeros_like(heatmap, dtype=np.uint8)
50
+
51
+ props_sorted = sorted(props, key=lambda x: x.area, reverse=True)
52
+ total_px = heatmap.shape[0] * heatmap.shape[1]
53
+
54
+ if exclude_full_image and len(props_sorted) >= 2 and props_sorted[0].area > 0.7 * total_px:
55
+ region = props_sorted[1]
56
+ else:
57
+ region = props_sorted[0]
58
+
59
+ mask = (labeled == region.label).astype(np.uint8)
60
+
61
+ # Dilate to include surrounding pixels
62
+ if dilate_radius > 0:
63
+ mask = binary_dilation(mask, disk(dilate_radius)).astype(np.uint8)
64
+
65
+ return mask