# Spaces: Running  (Hugging Face Spaces page-status text captured together with this file; not part of the module)
# -*- coding: utf-8 -*-
# CoNIC / Lizard Dataset
#
# Dataset information: https://arxiv.org/abs/2108.11195
# Please prepare the dataset as described here: docs/readmes/pannuke.md  # TODO: write own documentation
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import logging
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from PIL import Image

from cell_segmentation.datasets.base_cell import CellDataset
from cell_segmentation.datasets.pannuke import PanNukeDataset
# Module-wide logger. Note that ``logging.getLogger()`` without a name returns
# the *root* logger, so the NullHandler below is attached to the root logger.
logger = logging.getLogger()
# A NullHandler discards records; attaching it means this module emits nothing
# by itself until the application installs its own handlers.
logger.addHandler(logging.NullHandler())
class CoNicDataset(CellDataset):
    """CoNIC / Lizard dataset.

    This dataset is always cached: every image and mask of the selected folds
    is loaded into memory during construction, and ``__getitem__`` only works
    on the cached arrays.

    Expected layout below ``dataset_path`` (as read by ``__init__``)::

        fold{N}/images/<name>.png
        fold{N}/labels/<name>.npy
        fold{N}/cell_count.csv      (only needed for ``load_cell_count``)

    Args:
        dataset_path (Union[Path, str]): Path to Lizard dataset. Structure is described under ./docs/readmes/cell_segmentation.md
        folds (Union[int, List[int]]): Folds to use for this dataset
        transforms (Callable, optional): Transformations. The callable is invoked as
            ``transforms(image=img, mask=mask)`` and must return a mapping with the
            keys "image" and "mask". Defaults to None.
        stardist (bool, optional): Return StarDist labels. Defaults to False
        regression (bool, optional): Return Regression of cells in x and y direction. Defaults to False
        **kwargs are ignored
    """

    def __init__(
        self,
        dataset_path: Union[Path, str],
        folds: Union[int, List[int]],
        transforms: Optional[Callable] = None,
        stardist: bool = False,
        regression: bool = False,
        **kwargs,
    ) -> None:
        # A single fold index is accepted for convenience and normalised to a list.
        if isinstance(folds, int):
            folds = [folds]

        self.dataset = Path(dataset_path).resolve()
        self.transforms = transforms
        self.images = []
        self.masks = []
        self.img_names = []
        self.folds = folds
        self.stardist = stardist
        self.regression = regression

        for fold in folds:
            fold_dir = self.dataset / f"fold{fold}"
            image_dir = fold_dir / "images"
            label_dir = fold_dir / "labels"
            # Sorted so that the sample order is deterministic.
            fold_images = [f for f in sorted(image_dir.glob("*.png")) if f.is_file()]

            # sanity_check: mask must exist for image
            for fold_image in fold_images:
                mask_path = label_dir / f"{fold_image.stem}.npy"
                if not mask_path.is_file():
                    # f-string is required here, otherwise the placeholder is
                    # logged literally instead of the offending path.
                    logger.debug(
                        f"Found image {fold_image}, but no corresponding annotation file!"
                    )
                    continue
                self.images.append(fold_image)
                self.masks.append(mask_path)
                self.img_names.append(fold_image.name)

        # load everything in advance to speedup, as the dataset is rather small
        self.loaded_imgs = []
        self.loaded_masks = []
        for img_path, mask_path in zip(self.images, self.masks):
            # Context manager closes the underlying file handle as soon as the
            # pixel data has been copied into the numpy array.
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img).astype(np.uint8)

            # SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects.
            # Only use label files from a trusted source.
            raw_label = np.load(mask_path, allow_pickle=True)
            # The file is indexed with an empty tuple and then by key, i.e. it is
            # expected to hold a 0-d object array wrapping a mapping with the
            # entries "inst_map" and "type_map".
            label = raw_label[()]
            inst_map = label["inst_map"].astype(np.int32)
            type_map = label["type_map"].astype(np.int32)

            self.loaded_imgs.append(img)
            # Channel 0: instance map, channel 1: type map (stacked on the last axis).
            self.loaded_masks.append(np.stack([inst_map, type_map], axis=-1))

        logger.info(f"Created CoNic Dataset by using fold(s) {self.folds}")
        logger.info(f"Resulting dataset length: {len(self)}")

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str, str]:
        """Get one dataset item consisting of transformed image,
        masks (instance_map, nuclei_type_map, nuclei_binary_map, hv_map) and tissue type as string

        H and W below are the spatial size of the (optionally transformed) sample;
        they are not fixed by this class.

        Args:
            index (int): Index of element to retrieve

        Returns:
            Tuple[torch.Tensor, dict, str, str]:
                torch.Tensor: Image as float32, channels first (the last array axis is
                    moved to the front). Divided by 255 if its maximum is >= 5.
                dict:
                    "instance_map": Instance-Map, channel 0 of the mask, int64. Shape (H, W)
                    "nuclei_type_map": Nuclei-Type-Map, channel 1 of the mask, int64. Shape (H, W)
                    "nuclei_binary_map": Binary Nuclei-Mask (instance map with all values > 0 set to 1), int64. Shape (H, W)
                    "hv_map": Horizontal and vertical instance map as returned by
                        PanNukeDataset.gen_instance_hv_map, float32
                    [Optional if stardist]
                    "dist_map": Distance map as returned by PanNukeDataset.gen_distance_prob_maps, float32
                    "stardist_map": Stardist map as returned by PanNukeDataset.gen_stardist_maps, float32
                    [Optional if regression]
                    "regression_map": Regression map as returned by PanNukeDataset.gen_regression_map
                str: Tissue type, always "Colon"
                str: Image Name
        """
        img_path = self.images[index]
        img = self.loaded_imgs[index]
        mask = self.loaded_masks[index]

        if self.transforms is not None:
            transformed = self.transforms(image=img, mask=mask)
            img = transformed["image"]
            mask = transformed["mask"]

        # Copies keep the cached arrays untouched by the in-place edit below.
        inst_map = mask[:, :, 0].copy()
        type_map = mask[:, :, 1].copy()
        np_map = mask[:, :, 0].copy()
        np_map[np_map > 0] = 1
        hv_map = PanNukeDataset.gen_instance_hv_map(inst_map)

        # torch convert
        img = torch.Tensor(img).type(torch.float32)
        img = img.permute(2, 0, 1)
        # Heuristic: values this large are taken to be on a 0-255 scale and are
        # rescaled; smaller maxima are treated as already normalised.
        if torch.max(img) >= 5:
            img = img / 255

        masks = {
            "instance_map": torch.Tensor(inst_map).type(torch.int64),
            "nuclei_type_map": torch.Tensor(type_map).type(torch.int64),
            "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64),
            "hv_map": torch.Tensor(hv_map).type(torch.float32),
        }
        if self.stardist:
            dist_map = PanNukeDataset.gen_distance_prob_maps(inst_map)
            stardist_map = PanNukeDataset.gen_stardist_maps(inst_map)
            masks["dist_map"] = torch.Tensor(dist_map).type(torch.float32)
            masks["stardist_map"] = torch.Tensor(stardist_map).type(torch.float32)
        if self.regression:
            # NOTE(review): unlike the other entries, this one is stored exactly as
            # returned by the helper, without an explicit torch conversion - confirm
            # that this is intended.
            masks["regression_map"] = PanNukeDataset.gen_regression_map(inst_map)

        return img, masks, "Colon", Path(img_path).name

    def __len__(self) -> int:
        """Length of Dataset

        Returns:
            int: Length of Dataset
        """
        return len(self.images)

    def set_transforms(self, transforms: Callable) -> None:
        """Set the transformations, can be used to exchange transformations

        Args:
            transforms (Callable): Transformations, called as ``transforms(image=..., mask=...)``
        """
        self.transforms = transforms

    def load_cell_count(self):
        """Load Cell count from cell_count.csv file. File must be located inside the fold folder
        and named "cell_count.csv"

        The per-fold tables are concatenated and re-ordered to match ``self.img_names``;
        the result is stored in ``self.cell_count``. Images without an entry in the
        CSV files end up as all-NaN rows, which is reported with a warning.

        Example file beginning:
            Image,Neutrophil,Epithelial,Lymphocyte,Plasma,Eosinophil,Connective
            consep_1_0000.png,0,117,0,0,0,0
            consep_1_0001.png,0,95,1,0,0,8
            consep_1_0002.png,0,172,3,0,0,2
            ...
        """
        fold_counts = [
            pd.read_csv(self.dataset / f"fold{fold}" / "cell_count.csv", index_col=0)
            for fold in self.folds
        ]
        combined = pd.concat(fold_counts)

        # reindex() silently inserts NaN rows for unknown labels; those NaNs would
        # later propagate into the sampling weights, so make them visible.
        missing = [name for name in self.img_names if name not in combined.index]
        if missing:
            logger.warning(
                f"{len(missing)} image(s) have no entry in cell_count.csv, "
                f"their cell counts are NaN (first: {missing[0]})"
            )

        self.cell_count = combined.reindex(self.img_names)

    def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor:
        """Get sampling weights calculated by cell type statistics

        Requires ``load_cell_count()`` to have been called before.

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample. Samples whose weight would be zero
                receive the smallest non-zero weight instead. If no sample has a
                non-zero weight, all samples receive weight 1 (uniform sampling);
                for an empty dataset an empty tensor is returned.
        """
        assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
        assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!"

        # One factor per cell-count column (six columns expected).
        # NOTE(review): presumably the number of images containing each cell type
        # in the reference split - confirm the origin of these constants.
        binary_weight_factors = np.array([1069, 4189, 4356, 3103, 1025, 4527])
        k = np.sum(binary_weight_factors)

        # Clipping to [0, 1] turns counts into per-type presence indicators.
        cell_counts_imgs = np.clip(self.cell_count.to_numpy(), 0, 1)
        weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k)
        img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum(
            cell_counts_imgs * weight_vector, axis=-1
        )

        # Images without any cell would get weight 0 and never be sampled; lift
        # them to the smallest non-zero weight.
        nonzero_weights = img_weight[np.nonzero(img_weight)]
        if nonzero_weights.size > 0:
            img_weight[np.where(img_weight == 0)] = np.min(nonzero_weights)
        else:
            # Empty dataset or no image contains a cell: np.min() on an empty
            # selection would raise, so fall back to uniform weights.
            img_weight[...] = 1.0

        return torch.Tensor(img_weight)
# def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor: | |
# """Get sampling weights calculated by cell type statistics | |
# Args: | |
# gamma (float, optional): Gamma scaling factor, between 0 and 1. | |
# 1 means total balancing, 0 means original weights. Defaults to 1. | |
# Returns: | |
# torch.Tensor: Weights for each sample | |
# """ | |
# assert 0 <= gamma <= 1, "Gamma must be between 0 and 1" | |
# assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!" | |
# binary_weight_factors = np.array([4012, 222017, 93612, 24793, 2999, 98783]) | |
# k = np.sum(binary_weight_factors) | |
# cell_counts_imgs = self.cell_count.to_numpy() | |
# weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k) | |
# img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum( | |
# cell_counts_imgs * weight_vector, axis=-1 | |
# ) | |
# img_weight[np.where(img_weight == 0)] = np.min( | |
# img_weight[np.nonzero(img_weight)] | |
# ) | |
# return torch.Tensor(img_weight) | |