|
import scipy |
|
import skimage |
|
from skimage import feature |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from skimage.color import label2rgb |
|
from skimage.segmentation import mark_boundaries |
|
|
|
import os |
|
import sys |
|
import platform |
|
from pathlib import Path |
|
# Absolute path of this module and of the directory that contains it.
FILE = Path(__file__).resolve()
ROOT = FILE.parent

# Make sibling modules importable regardless of the current working directory.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Everywhere except Windows, re-express ROOT relative to the working directory.
ROOT = ROOT if platform.system() == 'Windows' else Path(os.path.relpath(ROOT, Path.cwd()))
|
from segmentation_functions import generate_mask, normalize |
|
|
|
|
|
|
|
|
|
def cytof_nuclei_segmentation(im_nuclei, show_process=False, size_hole=50, size_obj=7,
                              start_coords=(0, 0), side=100, colors=None, min_distance=2,
                              fg_marker_dilate=2, bg_marker_dilate=2
                              ):
    """ Segment nuclei based on the input nuclei image

    The image is clipped at its 95% quantile, thresholded with `generate_mask`, cleaned of small
    holes / objects, and then split into instances with a marker-based watershed on the local
    intensity gradient. Foreground markers are dilated peaks of the smoothed distance transform;
    the background marker is the eroded complement of the mask.

    Inputs:
        im_nuclei        = raw cytof image correspond to nuclei, size=(h, w)
        show_process     = flag of whether show the process (default=False)
        size_hole        = size of the hole to be removed (default=50)
        size_obj         = size of the small objects to be removed (default=7)
        start_coords     = the starting coordinates of visualizing process (default=(0,0)); the first
                           value indexes axis 0 (rows) and the second axis 1 (columns) of the image
        side             = the side length of visualizing process (default=100)
        colors           = a list of colors used to visualize segmentation results; None or an empty
                           list selects the "tab20c" colors followed by the "Set3" colors (default=None)
        min_distance     = minimal distance between two peaks of the smoothed distance map, passed to
                           skimage.feature.peak_local_max (default=2)
        fg_marker_dilate = radius of the disk used to dilate the foreground (nuclei) markers (default=2)
        bg_marker_dilate = radius of the disk used to erode the background marker (default=2)
    Returns:
        labels = nuclei segmentation result, where background is represented by 1, size=(h, w)
        colors = the list of colors used to visualize segmentation results

    :param im_nuclei: numpy.ndarray
    :param show_process: bool
    :param size_hole: int
    :param size_obj: int
    :param start_coords: tuple
    :param side: int
    :param colors: list
    :param min_distance: int
    :param fg_marker_dilate: int
    :param bg_marker_dilate: int
    :return labels: numpy.ndarray
    :return colors: list
    """

    if colors is None or len(colors) == 0:
        colors = list(plt.get_cmap("tab20c").colors) + list(plt.get_cmap("Set3").colors)

    x0, y0 = start_coords
    # Crop window shared by every "show_process" panel.
    window = (slice(x0, x0 + side), slice(y0, y0 + side))

    # The 95% quantile is the clipping ceiling everywhere below; compute it once.
    clip_max = np.quantile(im_nuclei, 0.95)

    mask = generate_mask(np.clip(im_nuclei, 0, clip_max), fill_hole=False, use_watershed=False)
    mask = skimage.morphology.remove_small_holes(mask.astype(bool), size_hole)
    mask = skimage.morphology.remove_small_objects(mask.astype(bool), size_obj)
    if show_process:
        plt.figure(figsize=(4, 4))
        plt.imshow(mask[window], cmap='gray')
        plt.show()

    # Foreground markers: peaks of the smoothed distance-to-background map.
    distance = scipy.ndimage.distance_transform_edt(mask)
    distance = scipy.ndimage.gaussian_filter(distance, 1)
    peak_idx = skimage.feature.peak_local_max(distance, exclude_border=False, min_distance=min_distance,
                                              labels=None)
    peaks = np.zeros_like(distance, dtype=bool)
    peaks[tuple(peak_idx.T)] = True

    # Dilate the peaks so that nearby ones merge, then give every blob its own id.
    markers = skimage.morphology.dilation(peaks, skimage.morphology.disk(fg_marker_dilate))
    markers = skimage.morphology.label(markers)
    # Shift nuclei ids to start at 2; id 1 is reserved for the background marker.
    markers[markers > 0] = markers[markers > 0] + 1

    # Background marker (id 1): pixels safely away from the mask. It is only written where no
    # foreground marker exists, so an overlap can never shift a nucleus id onto another one.
    background_seed = skimage.morphology.erosion(1 - mask, skimage.morphology.disk(bg_marker_dilate))
    markers = markers + np.where(markers > 0, 0, background_seed)

    # Watershed landscape: local gradient of the clipped, normalized 8-bit image.
    temp_im = skimage.util.img_as_ubyte(normalize(np.clip(im_nuclei, 0, clip_max)))
    gradient = skimage.filters.rank.gradient(temp_im, skimage.morphology.disk(3))

    labels = skimage.segmentation.watershed(gradient, markers)
    labels = skimage.morphology.closing(labels)

    if show_process:
        # The RGB rendering is only needed for display, so build it lazily.
        labels_rgb = label2rgb(labels, bg_label=1, colors=colors)
        labels_rgb[labels == 1, ...] = (0, 0, 0)

        fig, axes = plt.subplots(3, 2, figsize=(8, 12), sharex=False, sharey=False)
        ax = axes.ravel()
        ax[0].set_title("original grayscale")
        ax[0].imshow(np.clip(im_nuclei[window], 0, clip_max), interpolation='nearest')
        ax[1].set_title("markers")
        ax[1].imshow(label2rgb(markers[window], bg_label=1, colors=colors), interpolation='nearest')
        ax[2].set_title("distance")
        ax[2].imshow(-distance[window], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        ax[3].set_title("gradient")
        ax[3].imshow(gradient[window], interpolation='nearest')
        ax[4].set_title("Watershed Labels")
        ax[4].imshow(labels_rgb[window], interpolation='nearest')
        ax[5].set_title("Watershed Labels")
        ax[5].imshow(labels_rgb, interpolation='nearest')
        plt.show()

    return labels, colors
|
|
|
|
|
def cytof_cell_segmentation(nuclei_seg, radius=5, membrane_channel=None, show_process=False,
                            start_coords=(0, 0), side=100, colors=None):
    """ Cell segmentation based on nuclei segmentation; membrane-guided cell segmentation if membrane_channel provided.

    Nuclei are used as watershed seeds. A cell may grow at most `radius` pixels away from its
    nucleus (the "unknown" band); when a membrane channel is given, membrane pixels additionally
    act as background that limits the growth.

    Inputs:
        nuclei_seg       = an index image containing nuclei instance segmentation information, where the background is
                           represented by 1, size=(h,w). Typically, the output of calling the cytof_nuclei_segmentation
                           function.
        radius           = assumed radius of cells (default=5)
        membrane_channel = membrane image channel of original cytof image (default=None)
        show_process     = a flag indicating whether or not showing the segmentation process (default=False)
        start_coords     = the starting coordinates of visualizing process (default=(0,0)); the first value
                           indexes axis 0 (rows) and the second axis 1 (columns) of the image
        side             = the side length of visualizing process (default=100)
        colors           = a list of colors used to visualize segmentation results; None or an empty list
                           selects the "tab20c" colors followed by the "Set3" colors (default=None)
    Returns:
        labels = an index image containing cell instance segmentation information, where the background is
                 represented by 1
        colors = the list of colors used to visualize segmentation results

    :param nuclei_seg: numpy.ndarray
    :param radius: int
    :param membrane_channel: numpy.ndarray
    :param show_process: bool
    :param start_coords: tuple
    :param side: int
    :param colors: list
    :return labels: numpy.ndarray
    :return colors: list
    """

    if colors is None or len(colors) == 0:
        colors = list(plt.get_cmap("tab20c").colors) + list(plt.get_cmap("Set3").colors)

    x0, y0 = start_coords
    # Crop window shared by every "show_process" panel.
    window = (slice(x0, x0 + side), slice(y0, y0 + side))

    # Every label above 1 is a nucleus (1 is the background label).
    nuclei_mask = nuclei_seg > 1
    if show_process:
        nuclei_bg = nuclei_seg.min()
        fig, ax = plt.subplots(1, 2, figsize=(8, 4))
        nuclei_seg_vis = label2rgb(nuclei_seg[window], bg_label=nuclei_bg, colors=colors)
        nuclei_seg_vis[nuclei_seg[window] == nuclei_bg, ...] = (0, 0, 0)

        ax[0].imshow(nuclei_seg_vis), ax[0].set_title('nuclei segmentation')
        ax[1].imshow(nuclei_mask[window], cmap='gray'), ax[1].set_title('nuclei mask')

    if membrane_channel is not None:
        membrane_mask = generate_mask(np.clip(membrane_channel, 0, np.quantile(membrane_channel, 0.95)),
                                      fill_hole=False, use_watershed=False)
        if show_process:
            # Overlay: nuclei in the red channel, raw membrane mask in the green channel.
            nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
            nuclei_membrane[..., 0] = nuclei_mask * 255
            nuclei_membrane[..., 1] = membrane_mask

            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(membrane_mask[window], cmap='gray'), ax[0].set_title('membrane mask')
            ax[1].imshow(nuclei_membrane[window]), ax[1].set_title('nuclei - membrane')

        # Clean the membrane mask: close gaps, drop specks, then thin it.
        membrane_mask_close = skimage.morphology.closing(membrane_mask, skimage.morphology.disk(1))
        membrane_mask_open = skimage.morphology.opening(membrane_mask_close, skimage.morphology.disk(1))
        membrane_mask_erode = skimage.morphology.erosion(membrane_mask_open, skimage.morphology.disk(3))

        # Skeleton of the membrane outside the nuclei; it keeps thin membranes that the erosion
        # above may have removed.
        membrane_for_skeleton = (membrane_mask_open > 0) & np.logical_not(nuclei_mask)
        membrane_skeleton = skimage.morphology.skeletonize(membrane_for_skeleton)

        # Keep a handle on the un-processed mask so the 'raw' panel really shows it.
        raw_membrane_mask = membrane_mask
        membrane_mask = membrane_mask_erode
        membrane_mask_2 = (membrane_mask_erode > 0) | membrane_skeleton

        if show_process:
            fig, axs = plt.subplots(1, 4, figsize=(16, 4))
            axs[0].imshow(raw_membrane_mask[window], cmap='gray')
            axs[0].set_title('raw membrane mask')
            axs[1].imshow(membrane_mask_close[window], cmap='gray')
            axs[1].set_title('membrane mask - closed')
            axs[2].imshow(membrane_mask_open[window], cmap='gray')
            axs[2].set_title('membrane mask - opened')
            axs[3].imshow(membrane_mask_erode[window], cmap='gray')
            axs[3].set_title('membrane mask - erosion')
            plt.show()

            fig, axs = plt.subplots(1, 3, figsize=(12, 4))
            axs[0].imshow(membrane_skeleton[window], cmap='gray')
            axs[0].set_title('skeleton')
            axs[1].imshow(membrane_mask[window], cmap='gray')
            axs[1].set_title('membrane mask (final)')
            axs[2].imshow(membrane_mask_2[window], cmap='gray')
            axs[2].set_title('membrane mask 2')
            plt.show()

            # Same overlay as above, now with the final (eroded) membrane mask.
            nuclei_membrane = np.zeros((membrane_mask.shape[0], membrane_mask.shape[1], 3), dtype=np.uint8)
            nuclei_membrane[..., 0] = nuclei_mask * 255
            nuclei_membrane[..., 1] = membrane_mask
            fig, ax = plt.subplots(1, 2, figsize=(8, 4))
            ax[0].imshow(membrane_mask[window], cmap='gray'), ax[0].set_title('membrane mask')
            ax[1].imshow(nuclei_membrane[window]), ax[1].set_title('nuclei - membrane')

    # Cells may extend at most `radius` pixels beyond their nucleus.
    dilate_nuclei_mask = skimage.morphology.dilation(nuclei_mask, skimage.morphology.disk(radius))
    if show_process:
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(nuclei_mask[window], cmap='gray')
        axs[0].set_title('nuclei mask')
        axs[1].imshow(dilate_nuclei_mask[window], cmap='gray')
        axs[1].set_title('dilated nuclei mask')
        if membrane_channel is not None:
            axs[2].imshow(membrane_mask[window] > 0, cmap='gray')
            axs[2].set_title('membrane mask')

    # Sure foreground: the nuclei themselves.
    sure_fg = nuclei_mask.copy()
    not_fg = np.logical_not(sure_fg)
    beyond_reach = np.logical_not(dilate_nuclei_mask)

    # Sure background: anything out of reach of a nucleus, plus (if available) the membrane.
    # `sure_bg` defines the unknown band; `sure_bg2` (which also contains the membrane skeleton)
    # is only used to shape the distance map.
    if membrane_channel is not None:
        sure_bg = ((membrane_mask > 0) | beyond_reach) & not_fg
        sure_bg2 = ((membrane_mask_2 > 0) | beyond_reach) & not_fg
    else:
        sure_bg = beyond_reach & not_fg

    # Pixels that are neither sure foreground nor sure background; the watershed assigns them.
    unknown = np.logical_not(np.logical_or(sure_fg, sure_bg))

    if show_process:
        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(sure_fg[window], cmap='gray')
        axs[0].set_title('sure fg')
        axs[1].imshow(sure_bg[window], cmap='gray')
        if membrane_channel is not None:
            axs[1].set_title('sure bg: membrane | not (dilated nuclei)')
        else:
            axs[1].set_title('sure bg: not (dilated nuclei)')
        axs[2].imshow(unknown[window], cmap='gray')
        axs[2].set_title('unknown')

        # RGB composite: red = sure fg, green = sure bg, blue = unknown.
        fg_bg_un = np.zeros((unknown.shape[0], unknown.shape[1], 3), dtype=np.uint8)
        fg_bg_un[..., 0] = sure_fg * 255
        fg_bg_un[..., 1] = sure_bg * 255
        fg_bg_un[..., 2] = unknown * 255
        axs[3].imshow(fg_bg_un[window])
        plt.show()

    # Watershed landscape: low near nuclei; with a membrane it is additionally raised towards
    # the sure background (the negated distance to it is added).
    if membrane_channel is not None:
        distance_bg = -scipy.ndimage.distance_transform_edt(1 - sure_bg2)
        distance_fg = scipy.ndimage.distance_transform_edt(1 - sure_fg)
        distance = distance_bg + distance_fg
    else:
        distance = scipy.ndimage.distance_transform_edt(1 - sure_fg)
    distance = scipy.ndimage.gaussian_filter(distance, 1)

    # Seeds: nuclei ids and the background id (1); unknown pixels are left unlabeled (0).
    markers = nuclei_seg.copy()
    markers[unknown] = 0
    if show_process:
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        axs[0].set_title("markers")
        axs[0].imshow(label2rgb(markers[window], bg_label=1, colors=colors), interpolation='nearest')
        axs[1].set_title("distance")
        im = axs[1].imshow(distance[window], cmap=plt.cm.nipy_spectral, interpolation='nearest')
        plt.colorbar(im, ax=axs[1])

    labels = skimage.segmentation.watershed(distance, markers)

    if show_process:
        # imshow puts columns on the x axis and rows on the y axis, so the axis limits below are
        # (y0, x0)-ordered to show the same window that `window` slices in the other panels.
        col_limits = (y0, y0 + side - 1)
        row_limits = (x0 + side - 1, x0)

        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        axs[0].imshow(unknown[window])
        axs[0].set_title('cytoplasm')

        nuclei_lb = label2rgb(nuclei_seg, bg_label=1, colors=colors)
        nuclei_lb[nuclei_seg == 1, ...] = (0, 0, 0)
        axs[1].imshow(nuclei_lb)
        axs[1].set_xlim(*col_limits), axs[1].set_ylim(*row_limits)
        axs[1].set_title('nuclei')

        cell_lb = label2rgb(labels, bg_label=1, colors=colors)
        cell_lb[labels == 1, ...] = (0, 0, 0)
        axs[2].imshow(cell_lb)
        axs[2].set_title('cells')
        axs[2].set_xlim(*col_limits), axs[2].set_ylim(*row_limits)

        # Darken the cell colors and paint brightened nuclei on top of them.
        merge_lb = cell_lb ** 2
        merge_lb[nuclei_mask, ...] = np.clip(nuclei_lb[nuclei_mask, ...].astype(float) * 1.2, 0, 1)
        axs[3].imshow(merge_lb)
        axs[3].set_title('nuclei-cells')
        axs[3].set_xlim(*col_limits), axs[3].set_ylim(*row_limits)
        plt.show()

    return labels, colors
|
|
|
|
|
def visualize_segmentation(raw_image, channels, seg, channel_ids, bound_color=(1, 1, 1), bound_mode='inner', show=True, bg_label=0):
    """ Draw instance boundaries of a segmentation on top of merged image channels
    Inputs:
        raw_image   = raw cytof image
        channels    = a list of channels correspond to each channel in raw_image
        seg         = instance segmentation result (index image)
        channel_ids = indices of desired channels to visualize results
        bound_color = desired color in RGB to show boundaries (default=(1,1,1), white color)
        bound_mode  = the mode for finding boundaries, string in {'thick', 'inner', 'outer', 'subpixel'}.
                      (default="inner"). For more details, see
                      [skimage.segmentation.mark_boundaries](https://scikit-image.org/docs/stable/api/skimage.segmentation.html)
        show        = a flag indicating whether or not print result image on screen (default=True)
        bg_label    = label of seg treated as background, forwarded to mark_boundaries as
                      `background_label` (default=0)
    Returns:
        marked_image = the merged image with the boundaries of seg drawn on it

    :param raw_image: numpy.ndarray
    :param channels: list
    :param seg: numpy.ndarray
    :param channel_ids: indices of channels (e.g. a list of int)
    :param bound_color: tuple
    :param bound_mode: string
    :param show: bool
    :param bg_label: int
    :return marked_image
    """
    # Imported here rather than at module level, as in the original implementation.
    from cytof.hyperion_preprocess import cytof_merge_channels

    # Only the first element returned by cytof_merge_channels is drawn on
    # (presumably the merged image -- confirm against cytof.hyperion_preprocess).
    merged = cytof_merge_channels(raw_image, channels, channel_ids)[0]
    marked_image = mark_boundaries(merged, seg, color=bound_color, mode=bound_mode,
                                   background_label=bg_label)

    if not show:
        return marked_image

    plt.figure(figsize=(8,8))
    plt.imshow(marked_image)
    plt.show()
    return marked_image
|
|
|
|