# -*- coding: utf-8 -*-
# PanNuke Dataset
#
# Dataset information: https://arxiv.org/abs/2003.10778
# Please prepare the dataset as described here: docs/readmes/pannuke.md
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import logging
from pathlib import Path
from typing import Callable, List, Tuple, Union

import numpy as np
import pandas as pd
import torch
import yaml
from natsort import natsorted
from numba import njit
from PIL import Image
from scipy.ndimage import center_of_mass, distance_transform_edt

from cell_segmentation.datasets.base_cell import CellDataset
from cell_segmentation.utils.tools import fix_duplicates, get_bounding_box

logger = logging.getLogger()
logger.addHandler(logging.NullHandler())
class PanNukeDataset(CellDataset):
    """PanNuke dataset

    Args:
        dataset_path (Union[Path, str]): Path to the PanNuke dataset. The structure is described under ./docs/readmes/cell_segmentation.md
        folds (Union[int, List[int]]): Folds to use for this dataset
        transforms (Callable, optional): PyTorch transformations. Defaults to None.
        stardist (bool, optional): Return StarDist labels. Defaults to False.
        regression (bool, optional): Return regression of cells in x and y direction. Defaults to False.
        cache_dataset (bool, optional): Whether the dataset should be loaded into host memory during the first epoch.
            Be careful: workers in the DataLoader need to be persistent to gain a speedup.
            Recommended: False; only use it if you have enough RAM and your I/O operations might be the bottleneck.
            Defaults to False.
    """
    def __init__(
        self,
        dataset_path: Union[Path, str],
        folds: Union[int, List[int]],
        transforms: Callable = None,
        stardist: bool = False,
        regression: bool = False,
        cache_dataset: bool = False,
    ) -> None:
        if isinstance(folds, int):
            folds = [folds]

        self.dataset = Path(dataset_path).resolve()
        self.transforms = transforms
        self.images = []
        self.masks = []
        self.types = {}
        self.img_names = []
        self.folds = folds
        self.cache_dataset = cache_dataset
        self.stardist = stardist
        self.regression = regression

        for fold in folds:
            image_path = self.dataset / f"fold{fold}" / "images"
            fold_images = [
                f for f in natsorted(image_path.glob("*.png")) if f.is_file()
            ]

            # sanity check: a mask must exist for every image
            for fold_image in fold_images:
                mask_path = (
                    self.dataset / f"fold{fold}" / "labels" / f"{fold_image.stem}.npy"
                )
                if mask_path.is_file():
                    self.images.append(fold_image)
                    self.masks.append(mask_path)
                    self.img_names.append(fold_image.name)
                else:
                    logger.debug(
                        f"Found image {fold_image}, but no corresponding annotation file!"
                    )

            fold_types = pd.read_csv(self.dataset / f"fold{fold}" / "types.csv")
            fold_type_dict = fold_types.set_index("img")["type"].to_dict()
            self.types = {
                **self.types,
                **fold_type_dict,
            }  # careful - all images across folds must be named differently

        logger.info(f"Created PanNuke dataset by using fold(s) {self.folds}")
        logger.info(f"Resulting dataset length: {self.__len__()}")

        if self.cache_dataset:
            self.cached_idx = []  # list of idx that should be cached
            self.cached_imgs = {}  # keys: idx, values: numpy array of imgs
            self.cached_masks = {}  # keys: idx, values: numpy array of masks
            logger.info("Using cached dataset. Cache is built up during the first epoch.")
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str, str]:
        """Get one dataset item consisting of transformed image,
        masks (instance_map, nuclei_type_map, nuclei_binary_map, hv_map) and tissue type as string

        Args:
            index (int): Index of element to retrieve

        Returns:
            Tuple[torch.Tensor, dict, str, str]:
                torch.Tensor: Image, with shape (3, H, W), in this case (3, 256, 256)
                dict:
                    "instance_map": Instance map, each instance has a unique integer starting at 1 (zero is background). Shape (256, 256)
                    "nuclei_type_map": Nuclei type map, for each nucleus (instance) the class is indicated by an integer. Shape (256, 256)
                    "nuclei_binary_map": Binary nuclei mask. Shape (256, 256)
                    "hv_map": Horizontal and vertical instance map.
                        Shape (2, H, W). First dimension is horizontal (horizontal gradient (-1 to 1)),
                        last is vertical (vertical gradient (-1 to 1)). Shape (2, 256, 256)
                    [Optional if stardist]
                    "dist_map": Probability distance map. Shape (256, 256)
                    "stardist_map": StarDist vector map. Shape (n_rays, 256, 256)
                    [Optional if regression]
                    "regression_map": Regression map. Shape (2, 256, 256). First is vertical, second horizontal.
                str: Tissue type
                str: Image name
        """
        img_path = self.images[index]

        if self.cache_dataset:
            if index in self.cached_idx:
                img = self.cached_imgs[index]
                mask = self.cached_masks[index]
            else:
                # cache file on first access
                img = self.load_imgfile(index)
                mask = self.load_maskfile(index)
                self.cached_imgs[index] = img
                self.cached_masks[index] = mask
                self.cached_idx.append(index)
        else:
            img = self.load_imgfile(index)
            mask = self.load_maskfile(index)

        if self.transforms is not None:
            transformed = self.transforms(image=img, mask=mask)
            img = transformed["image"]
            mask = transformed["mask"]

        tissue_type = self.types[img_path.name]
        inst_map = mask[:, :, 0].copy()
        type_map = mask[:, :, 1].copy()
        np_map = mask[:, :, 0].copy()
        np_map[np_map > 0] = 1
        hv_map = PanNukeDataset.gen_instance_hv_map(inst_map)

        # torch conversion
        img = torch.Tensor(img).type(torch.float32)
        img = img.permute(2, 0, 1)
        if torch.max(img) >= 5:  # heuristic: rescale only if values are not already in [0, 1]
            img = img / 255

        masks = {
            "instance_map": torch.Tensor(inst_map).type(torch.int64),
            "nuclei_type_map": torch.Tensor(type_map).type(torch.int64),
            "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64),
            "hv_map": torch.Tensor(hv_map).type(torch.float32),
        }

        # load StarDist transforms if necessary
        if self.stardist:
            dist_map = PanNukeDataset.gen_distance_prob_maps(inst_map)
            stardist_map = PanNukeDataset.gen_stardist_maps(inst_map)
            masks["dist_map"] = torch.Tensor(dist_map).type(torch.float32)
            masks["stardist_map"] = torch.Tensor(stardist_map).type(torch.float32)
        if self.regression:
            masks["regression_map"] = PanNukeDataset.gen_regression_map(inst_map)

        return img, masks, tissue_type, img_path.name
    def __len__(self) -> int:
        """Length of dataset

        Returns:
            int: Length of dataset
        """
        return len(self.images)

    def set_transforms(self, transforms: Callable) -> None:
        """Set the transformations; can be used to exchange transformations

        Args:
            transforms (Callable): PyTorch transformations
        """
        self.transforms = transforms
    def load_imgfile(self, index: int) -> np.ndarray:
        """Load image from file (disk)

        Args:
            index (int): Index of file

        Returns:
            np.ndarray: Image as array with shape (H, W, 3)
        """
        img_path = self.images[index]
        return np.array(Image.open(img_path)).astype(np.uint8)

    def load_maskfile(self, index: int) -> np.ndarray:
        """Load mask from file (disk)

        Args:
            index (int): Index of file

        Returns:
            np.ndarray: Mask as array with shape (H, W, 2)
        """
        mask_path = self.masks[index]
        mask = np.load(mask_path, allow_pickle=True)
        inst_map = mask[()]["inst_map"].astype(np.int32)
        type_map = mask[()]["type_map"].astype(np.int32)
        mask = np.stack([inst_map, type_map], axis=-1)
        return mask
    def load_cell_count(self):
        """Load cell count from the cell_count.csv file. The file must be located inside the fold folder
        and named "cell_count.csv"

        Example file beginning:
            Image,Neoplastic,Inflammatory,Connective,Dead,Epithelial
            0_0.png,4,2,2,0,0
            0_1.png,8,1,1,0,0
            0_10.png,17,0,1,0,0
            0_100.png,10,0,11,0,0
            ...
        """
        df_placeholder = []
        for fold in self.folds:
            csv_path = self.dataset / f"fold{fold}" / "cell_count.csv"
            cell_count = pd.read_csv(csv_path, index_col=0)
            df_placeholder.append(cell_count)
        self.cell_count = pd.concat(df_placeholder)
        self.cell_count = self.cell_count.reindex(self.img_names)
    def get_sampling_weights_tissue(self, gamma: float = 1) -> torch.Tensor:
        """Get sampling weights calculated from tissue type statistics

        For this, a file named "weight_config.yaml" with the content:
            tissue:
                tissue_1: xxx
                tissue_2: xxx (name of tissue: count)
                ...
        must exist in the dataset main folder (parent path, not inside the folds)

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample
        """
        assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
        with open(
            (self.dataset / "weight_config.yaml").resolve(), "r"
        ) as run_config_file:
            yaml_config = yaml.safe_load(run_config_file)
            tissue_counts = dict(yaml_config)["tissue"]

        # calculate the weight for each tissue: w = k / (gamma * count + (1 - gamma) * k)
        weights_dict = {}
        k = np.sum(list(tissue_counts.values()))
        for tissue, count in tissue_counts.items():
            w = k / (gamma * count + (1 - gamma) * k)
            weights_dict[tissue] = w

        weights = []
        for idx in range(self.__len__()):
            img_idx = self.img_names[idx]
            type_str = self.types[img_idx]
            weights.append(weights_dict[type_str])

        return torch.Tensor(weights)
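
    # Worked example for the weighting formula above (illustration, not part of
    # the original code): with gamma = 1 and tissue counts
    # {"Breast": 100, "Colon": 300}, k = 400 and the weights become
    # w_Breast = 400 / 100 = 4.0 and w_Colon = 400 / 300 ≈ 1.33, so rarer
    # tissues are sampled proportionally more often. With gamma = 0 every
    # weight collapses to k / k = 1.0 (original, unbalanced sampling).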
    def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor:
        """Get sampling weights calculated from cell type statistics

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample
        """
        assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
        assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!"
        # global per-class weighting factors; the order matches the
        # cell_count.csv columns (Neoplastic, Inflammatory, Connective, Dead, Epithelial)
        binary_weight_factors = np.array([4191, 4132, 6140, 232, 1528])
        k = np.sum(binary_weight_factors)
        # clip per-image counts to a 0/1 presence vector
        cell_counts_imgs = np.clip(self.cell_count.to_numpy(), 0, 1)
        weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k)
        img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum(
            cell_counts_imgs * weight_vector, axis=-1
        )
        # images without any annotated cells get the smallest nonzero weight
        img_weight[np.where(img_weight == 0)] = np.min(
            img_weight[np.nonzero(img_weight)]
        )

        return torch.Tensor(img_weight)
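
    # Illustration of the cell weighting above (assumed reading of the code):
    # since cell_counts_imgs only records class presence, with gamma = 1 an
    # image's weight is the sum of k / count over every class present in it;
    # images containing rare classes (e.g. "Dead", count 232) therefore
    # receive the largest sampling weights.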
    def get_sampling_weights_cell_tissue(self, gamma: float = 1) -> torch.Tensor:
        """Get combined sampling weights by calculating tissue and cell sampling weights,
        normalizing them and adding them up to yield one score.

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample
        """
        assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
        tw = self.get_sampling_weights_tissue(gamma)
        cw = self.get_sampling_weights_cell(gamma)
        weights = tw / torch.max(tw) + cw / torch.max(cw)

        return weights
    @staticmethod
    def gen_instance_hv_map(inst_map: np.ndarray) -> np.ndarray:
        """Obtain the horizontal and vertical distance maps for each
        nuclear instance.

        Args:
            inst_map (np.ndarray): Instance map with each instance labelled as a unique integer
                Shape: (H, W)

        Returns:
            np.ndarray: Horizontal and vertical instance map.
                Shape: (2, H, W). First dimension is horizontal (horizontal gradient (-1 to 1)),
                last is vertical (vertical gradient (-1 to 1))
        """
        orig_inst_map = inst_map.copy()  # instance ID map

        x_map = np.zeros(orig_inst_map.shape[:2], dtype=np.float32)
        y_map = np.zeros(orig_inst_map.shape[:2], dtype=np.float32)

        inst_list = list(np.unique(orig_inst_map))
        inst_list.remove(0)  # 0 is background
        for inst_id in inst_list:
            inst_map = np.array(orig_inst_map == inst_id, np.uint8)
            inst_box = get_bounding_box(inst_map)  # [rmin, rmax, cmin, cmax]

            # expand the box by 2px while staying inside the image bounds
            if inst_box[0] >= 2:
                inst_box[0] -= 2
            if inst_box[2] >= 2:
                inst_box[2] -= 2
            if inst_box[1] <= orig_inst_map.shape[0] - 2:
                inst_box[1] += 2
            # note: the column bound must be checked against shape[1]
            if inst_box[3] <= orig_inst_map.shape[1] - 2:
                inst_box[3] += 2

            inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]

            if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
                continue

            # instance center of mass, rounded to nearest pixel
            inst_com = list(center_of_mass(inst_map))

            inst_com[0] = int(inst_com[0] + 0.5)
            inst_com[1] = int(inst_com[1] + 0.5)

            inst_x_range = np.arange(1, inst_map.shape[1] + 1)
            inst_y_range = np.arange(1, inst_map.shape[0] + 1)
            # shifting the center of the pixel grid to the instance center of mass
            inst_x_range -= inst_com[1]
            inst_y_range -= inst_com[0]

            inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)

            # remove coordinates outside of the instance
            inst_x[inst_map == 0] = 0
            inst_y[inst_map == 0] = 0
            inst_x = inst_x.astype("float32")
            inst_y = inst_y.astype("float32")

            # normalize min into -1 scale
            if np.min(inst_x) < 0:
                inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
            if np.min(inst_y) < 0:
                inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
            # normalize max into +1 scale
            if np.max(inst_x) > 0:
                inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
            if np.max(inst_y) > 0:
                inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])

            # write the normalized gradients back into the full-size maps
            x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
            x_map_box[inst_map > 0] = inst_x[inst_map > 0]
            y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
            y_map_box[inst_map > 0] = inst_y[inst_map > 0]

        hv_map = np.stack([x_map, y_map])
        return hv_map
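
    # Minimal sketch of what gen_instance_hv_map produces (illustration only):
    # for a single horizontal run of one instance, the x-channel holds a
    # gradient from -1 at the left border through 0 at the center of mass to
    # +1 at the right border, e.g. [-1.0, -0.5, 0.0, 0.5, 1.0]; the y-channel
    # behaves analogously in the vertical direction. HoVer-Net-style decoders
    # exploit the sharp sign flips of these gradients between adjacent
    # instances to separate touching nuclei.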
    @staticmethod
    def gen_distance_prob_maps(inst_map: np.ndarray) -> np.ndarray:
        """Generate distance probability maps

        Args:
            inst_map (np.ndarray): Instance map, each instance has a unique integer starting at 1 (zero is background). Shape (H, W)

        Returns:
            np.ndarray: Distance probability map, shape (H, W)
        """
        inst_map = fix_duplicates(inst_map)
        dist = np.zeros_like(inst_map, dtype=np.float64)
        inst_list = list(np.unique(inst_map))
        if 0 in inst_list:
            inst_list.remove(0)

        for inst_id in inst_list:
            inst = np.array(inst_map == inst_id, np.uint8)

            y1, y2, x1, x2 = get_bounding_box(inst)
            # expand the box by 2px while staying inside the image bounds
            y1 = y1 - 2 if y1 - 2 >= 0 else y1
            x1 = x1 - 2 if x1 - 2 >= 0 else x1
            x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
            y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2

            inst = inst[y1:y2, x1:x2]

            if inst.shape[0] < 2 or inst.shape[1] < 2:
                continue

            # Euclidean distance transform, normalized to [0, 1] per instance
            inst_dist = distance_transform_edt(inst)
            inst_dist = inst_dist.astype("float64")

            max_value = np.amax(inst_dist)
            if max_value <= 0:
                continue
            inst_dist = inst_dist / (max_value + 1e-10)

            dist_map_box = dist[y1:y2, x1:x2]
            dist_map_box[inst > 0] = inst_dist[inst > 0]

        return dist
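
    # Example of the resulting distance probability map (illustration only):
    # a circular nucleus yields values close to 0.0 at its boundary that
    # increase monotonically to 1.0 at its innermost pixel, because each
    # instance's Euclidean distance transform is normalized by its own maximum.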
    @staticmethod
    @njit
    def gen_stardist_maps(inst_map: np.ndarray) -> np.ndarray:
        """Generate StarDist map with 32 rays

        Args:
            inst_map (np.ndarray): Instance map, each instance has a unique integer starting at 1 (zero is background). Shape (H, W)

        Returns:
            np.ndarray: StarDist vector map, shape (n_rays, H, W)
        """
        n_rays = 32
        dist = np.empty(inst_map.shape + (n_rays,), np.float32)
        st_rays = np.float32((2 * np.pi) / n_rays)
        for i in range(inst_map.shape[0]):
            for j in range(inst_map.shape[1]):
                value = inst_map[i, j]
                if value == 0:
                    dist[i, j] = 0
                else:
                    # cast one ray per angular step until it leaves the instance
                    for k in range(n_rays):
                        phi = np.float32(k * st_rays)
                        dy = np.cos(phi)
                        dx = np.sin(phi)
                        x, y = np.float32(0), np.float32(0)
                        while True:
                            x += dx
                            y += dy
                            ii = int(round(i + x))
                            jj = int(round(j + y))
                            if (
                                ii < 0
                                or ii >= inst_map.shape[0]
                                or jj < 0
                                or jj >= inst_map.shape[1]
                                or value != inst_map[ii, jj]
                            ):
                                # small correction as we overshoot the boundary
                                t_corr = 1 - 0.5 / max(np.abs(dx), np.abs(dy))
                                x -= t_corr * dx
                                y -= t_corr * dy
                                dst = np.sqrt(x**2 + y**2)
                                dist[i, j, k] = dst
                                break

        return dist.transpose(2, 0, 1)
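
    # Sketch of the ray-casting idea above (illustration only): from every
    # foreground pixel, n_rays = 32 rays are stepped outwards in angular
    # increments of 2*pi/32 until they leave the instance; the traveled
    # (boundary-corrected) Euclidean length per ray is stored, so a pixel at
    # the center of a circular nucleus of radius r yields 32 values close to r.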
    @staticmethod
    def gen_regression_map(inst_map: np.ndarray):
        """Generate a regression map holding per-pixel offsets to the instance center of mass

        Args:
            inst_map (np.ndarray): Instance map, each instance has a unique integer starting at 1 (zero is background). Shape (H, W)

        Returns:
            np.ndarray: Regression map, shape (2, H, W)
        """
        n_directions = 2
        dist = np.zeros(inst_map.shape + (n_directions,), np.float32).transpose(2, 0, 1)
        inst_map = fix_duplicates(inst_map)
        inst_list = list(np.unique(inst_map))
        if 0 in inst_list:
            inst_list.remove(0)
        for inst_id in inst_list:
            inst = np.array(inst_map == inst_id, np.uint8)
            y1, y2, x1, x2 = get_bounding_box(inst)
            # expand the box by 2px while staying inside the image bounds
            y1 = y1 - 2 if y1 - 2 >= 0 else y1
            x1 = x1 - 2 if x1 - 2 >= 0 else x1
            x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
            y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2

            inst = inst[y1:y2, x1:x2]
            y_mass, x_mass = center_of_mass(inst)

            x_map = np.repeat(np.arange(1, x2 - x1 + 1)[None, :], y2 - y1, axis=0)
            y_map = np.repeat(np.arange(1, y2 - y1 + 1)[:, None], x2 - x1, axis=1)
            # we use a transposed coordinate system to align with the HV map;
            # strictly correct would be -1 * x_dist_map and -1 * y_dist_map
            x_dist_map = (x_map - x_mass) * np.clip(inst, 0, 1)
            y_dist_map = (y_map - y_mass) * np.clip(inst, 0, 1)
            dist[0, y1:y2, x1:x2] = x_dist_map
            dist[1, y1:y2, x1:x2] = y_dist_map

        return dist
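

# Minimal usage sketch (not part of the original module). The dataset root and
# fold layout follow docs/readmes/pannuke.md; the path below is hypothetical.
if __name__ == "__main__":
    dataset = PanNukeDataset(
        dataset_path="/path/to/PanNuke",  # hypothetical path, adjust to your setup
        folds=[0, 1],
        stardist=False,
        regression=False,
    )
    img, masks, tissue_type, img_name = dataset[0]
    print(img.shape)                    # torch.Size([3, 256, 256])
    print(masks["instance_map"].shape)  # torch.Size([256, 256])
    print(tissue_type, img_name)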