# -*- coding: utf-8 -*-
# PostProcessing Pipeline
#
# Adapted from HoverNet
# HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import warnings
from typing import List, Literal, Tuple

import cv2
import numpy as np
import torch
from scipy.ndimage import measurements
from scipy.ndimage.morphology import binary_fill_holes
from skimage.segmentation import watershed

from .tools import get_bounding_box, remove_small_objects
def noop(*args, **kargs):
    """Accept any arguments and do nothing.

    Used below as a stand-in for ``warnings.warn`` to silence all warnings.
    """
    return None


# Deliberate monkey-patch: mute every warnings.warn call process-wide.
warnings.warn = noop
class DetectionCellPostProcessor:
    """Post-process combined HoverNet prediction maps into individual cell instances.

    Takes the stacked output of the type (tp), nuclei (np) and horizontal/vertical
    (hv) branches, separates touching nuclei with a gradient-based watershed and
    derives per-instance metadata (bounding box, centroid, contour, cell type).
    """

    def __init__(
        self,
        nr_types: int = None,
        magnification: Literal[20, 40] = 40,
        gt: bool = False,
    ) -> None:
        """DetectionCellPostProcessor for postprocessing prediction maps and get detected cells

        Args:
            nr_types (int, optional): Number of cell types, including background (background = 0). Defaults to None.
            magnification (Literal[20, 40], optional): Which magnification the data has. Defaults to 40.
            gt (bool, optional): If this is gt data (used that we do not suppress tiny cells that may be noise in a prediction map).
                Defaults to False.

        Raises:
            NotImplementedError: Unknown magnification
        """
        self.nr_types = nr_types
        self.magnification = magnification
        self.gt = gt

        # Magnification-dependent parameters: smallest object (px) kept after
        # marker extraction and Sobel kernel size used on the hv maps.
        if magnification == 40:
            self.object_size = 10
            self.k_size = 21
        elif magnification == 20:
            self.object_size = 3  # 3 or 40, we used 5
            self.k_size = 11  # 11 or 41, we used 13
        else:
            raise NotImplementedError("Unknown magnification")

        if gt:  # do not suppress anything in ground-truth maps
            self.object_size = 100
            self.k_size = 21

    def post_process_cell_segmentation(
        self,
        pred_map: np.ndarray,
    ) -> Tuple[np.ndarray, dict]:
        """Post processing of one image tile

        Args:
            pred_map (np.ndarray): Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4)

        Returns:
            Tuple[np.ndarray, dict]:
                np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
                dict: Instance dictionary. Main Key is the nuclei instance number (int), with a dict as value.
                    For each instance, the dictionary contains the keys: bbox (bounding box), centroid (centroid coordinates),
                    contour, type_prob (probability), type (nuclei type)
        """
        if self.nr_types is not None:
            # first channel: type map, remaining channels: np + hv branches
            pred_type = pred_map[..., :1].astype(np.int32)
            pred_inst = pred_map[..., 1:]
        else:
            # BUGFIX: define pred_type explicitly so the type-assignment step
            # below can be skipped instead of raising NameError.
            pred_type = None
            pred_inst = pred_map

        pred_inst = np.squeeze(pred_inst)
        pred_inst = self.__proc_np_hv(
            pred_inst, object_size=self.object_size, ksize=self.k_size
        )

        inst_id_list = np.unique(pred_inst)[1:]  # exclude background (label 0)
        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
            inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
            # crop to the bounding box before computing moments/contours
            inst_map = inst_map[
                inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
            ]
            inst_map = inst_map.astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            inst_contour = cv2.findContours(
                inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            # * opencv protocol format may break
            inst_contour = np.squeeze(inst_contour[0][0].astype("int32"))
            # < 3 points dont make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small or sthg
            if inst_contour.shape[0] < 3:
                continue
            if len(inst_contour.shape) != 2:
                continue  # ! check for trickery shape
            inst_centroid = [
                (inst_moment["m10"] / inst_moment["m00"]),
                (inst_moment["m01"] / inst_moment["m00"]),
            ]
            inst_centroid = np.array(inst_centroid)
            # shift contour/centroid from crop coordinates back to tile coordinates
            inst_contour[:, 0] += inst_bbox[0][1]  # X
            inst_contour[:, 1] += inst_bbox[0][0]  # Y
            inst_centroid[0] += inst_bbox[0][1]  # X
            inst_centroid[1] += inst_bbox[0][0]  # Y
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "bbox": inst_bbox,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "type_prob": None,
                "type": None,
            }

        #### * Get class of each instance id, stored at index id-1 (inst_id = number of deteced nucleus)
        # Only possible when a type branch was predicted at all; without this
        # guard the loop dereferenced an undefined pred_type for nr_types=None.
        if self.nr_types is not None:
            for inst_id in list(inst_info_dict.keys()):
                rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten()
                inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
                inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
                inst_map_crop = inst_map_crop == inst_id
                inst_type = inst_type_crop[inst_map_crop]
                # majority vote over the per-pixel type labels of this instance
                type_list, type_pixels = np.unique(inst_type, return_counts=True)
                type_list = list(zip(type_list, type_pixels))
                type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
                inst_type = type_list[0][0]
                if inst_type == 0:  # ! pick the 2nd most dominant if exist
                    if len(type_list) > 1:
                        inst_type = type_list[1][0]
                type_dict = {v[0]: v[1] for v in type_list}
                # fraction of instance pixels carrying the winning type label
                type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
                inst_info_dict[inst_id]["type"] = int(inst_type)
                inst_info_dict[inst_id]["type_prob"] = float(type_prob)

        return pred_inst, inst_info_dict

    def __proc_np_hv(
        self, pred: np.ndarray, object_size: int = 10, ksize: int = 21
    ) -> np.ndarray:
        """Process Nuclei Prediction with XY Coordinate Map and generate instance map (each instance has unique integer)

        Separate Instances (also overlapping ones) from binary nuclei map and hv map by using morphological operations and watershed

        Args:
            pred (np.ndarray): Prediction output, assuming. Shape: (H, W, 3)
                * channel 0 contain probability map of nuclei
                * channel 1 containing the regressed X-map
                * channel 2 containing the regressed Y-map
            object_size (int, optional): Smallest object size for filtering. Defaults to 10
            ksize (int, optional): Sobel Kernel size. Defaults to 21

        Returns:
            np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
        """
        pred = np.array(pred, dtype=np.float32)
        blb_raw = pred[..., 0]
        h_dir_raw = pred[..., 1]
        v_dir_raw = pred[..., 2]

        # threshold the nuclei probability map and drop tiny foreground specks
        blb = np.array(blb_raw >= 0.5, dtype=np.int32)
        blb = measurements.label(blb)[0]  # ndimage.label(blb)[0]
        blb = remove_small_objects(blb, min_size=10)  # 10
        blb[blb > 0] = 1  # background is 0 already

        # rescale both direction maps to [0, 1] before gradient computation
        h_dir = cv2.normalize(
            h_dir_raw,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )
        v_dir = cv2.normalize(
            v_dir_raw,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )

        # ksize = int((20 * scale_factor) + 1) # 21 vs 41
        # obj_size = math.ceil(10 * (scale_factor**2)) #10 vs 40
        # Sobel gradients of the hv maps are large at instance boundaries;
        # invert so boundaries become low values.
        sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
        sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
        sobelh = 1 - (
            cv2.normalize(
                sobelh,
                None,
                alpha=0,
                beta=1,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_32F,
            )
        )
        sobelv = 1 - (
            cv2.normalize(
                sobelv,
                None,
                alpha=0,
                beta=1,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_32F,
            )
        )

        overall = np.maximum(sobelh, sobelv)
        overall = overall - (1 - blb)
        overall[overall < 0] = 0

        # nuclei values form mountains so inverse to get basins for watershed
        dist = (1.0 - overall) * blb
        dist = -cv2.GaussianBlur(dist, (3, 3), 0)

        # markers: foreground minus strong-boundary pixels, cleaned up
        overall = np.array(overall >= 0.4, dtype=np.int32)
        marker = blb - overall
        marker[marker < 0] = 0
        marker = binary_fill_holes(marker).astype("uint8")
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
        marker = measurements.label(marker)[0]
        marker = remove_small_objects(marker, min_size=object_size)

        proced_pred = watershed(dist, markers=marker, mask=blb)
        return proced_pred
def calculate_instances(
    pred_types: torch.Tensor, pred_insts: torch.Tensor
) -> List[dict]:
    """Best used for GT

    Args:
        pred_types (torch.Tensor): Binary or type map ground-truth.
            Shape must be (B, C, H, W) with C=1 for binary or num_nuclei_types for multi-class.
        pred_insts (torch.Tensor): Ground-Truth instance map with shape (B, H, W)

    Returns:
        list[dict]: Dictionary with nuclei informations, output similar to post_process_cell_segmentation
    """
    type_preds = []
    pred_types = pred_types.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
    # PERF: hoisted out of the loop — the original recomputed the argmax over
    # the whole batch once per sample. Result is identical.
    type_maps = torch.argmax(pred_types, dim=-1).detach().cpu().numpy()

    for i in range(pred_types.shape[0]):
        pred_type = type_maps[i]
        pred_inst = pred_insts[i].detach().cpu().numpy()
        inst_id_list = np.unique(pred_inst)[1:]  # exclude background (label 0)
        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
            inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
            # crop to the bounding box before computing moments/contours
            inst_map = inst_map[
                inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
            ]
            inst_map = inst_map.astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            inst_contour = cv2.findContours(
                inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            # * opencv protocol format may break
            inst_contour = np.squeeze(inst_contour[0][0].astype("int32"))
            # < 3 points dont make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small or sthg
            if inst_contour.shape[0] < 3:
                continue
            if len(inst_contour.shape) != 2:
                continue  # ! check for trickery shape
            inst_centroid = [
                (inst_moment["m10"] / inst_moment["m00"]),
                (inst_moment["m01"] / inst_moment["m00"]),
            ]
            inst_centroid = np.array(inst_centroid)
            # shift contour/centroid from crop coordinates back to tile coordinates
            inst_contour[:, 0] += inst_bbox[0][1]  # X
            inst_contour[:, 1] += inst_bbox[0][0]  # Y
            inst_centroid[0] += inst_bbox[0][1]  # X
            inst_centroid[1] += inst_bbox[0][0]  # Y
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "bbox": inst_bbox,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "type_prob": None,
                "type": None,
            }

        #### * Get class of each instance id, stored at index id-1 (inst_id = number of deteced nucleus)
        for inst_id in list(inst_info_dict.keys()):
            rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten()
            inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
            inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
            inst_map_crop = inst_map_crop == inst_id
            inst_type = inst_type_crop[inst_map_crop]
            # majority vote over the per-pixel type labels of this instance
            type_list, type_pixels = np.unique(inst_type, return_counts=True)
            type_list = list(zip(type_list, type_pixels))
            type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
            inst_type = type_list[0][0]
            if inst_type == 0:  # ! pick the 2nd most dominant if exist
                if len(type_list) > 1:
                    inst_type = type_list[1][0]
            type_dict = {v[0]: v[1] for v in type_list}
            # fraction of instance pixels carrying the winning type label
            type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
            inst_info_dict[inst_id]["type"] = int(inst_type)
            inst_info_dict[inst_id]["type_prob"] = float(type_prob)

        type_preds.append(inst_info_dict)

    return type_preds