# -*- coding: utf-8 -*- # MoNuSeg Dataset # # Dataset information: https://monuseg.grand-challenge.org/Home/ # Please Prepare Dataset as described here: docs/readmes/monuseg.md # # @ Fabian Hörst, fabian.hoerst@uk-essen.de # Institute for Artifical Intelligence in Medicine, # University Medicine Essen import logging from pathlib import Path from typing import Callable, Union, Tuple import numpy as np import torch from PIL import Image from torch.utils.data import Dataset from cell_segmentation.datasets.pannuke import PanNukeDataset from einops import rearrange logger = logging.getLogger() logger.addHandler(logging.NullHandler()) class MoNuSegDataset(Dataset): def __init__( self, dataset_path: Union[Path, str], transforms: Callable = None, patching: bool = False, overlap: int = 0, ) -> None: """MoNuSeg Dataset Args: dataset_path (Union[Path, str]): Path to dataset transforms (Callable, optional): Transformations to apply on images. Defaults to None. patching (bool, optional): If patches with size 256px should be used Otherwise, the entire MoNuSeg images are loaded. Defaults to False. overlap: (bool, optional): If overlap should be used for patch sampling. Overlap in pixels. Recommended value other than 0 is 64. Defaults to 0. Raises: FileNotFoundError: If no ground-truth annotation file was found in path """ self.dataset = Path(dataset_path).resolve() self.transforms = transforms self.masks = [] self.img_names = [] self.patching = patching self.overlap = overlap image_path = self.dataset / "images" label_path = self.dataset / "labels" self.images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()] self.masks = [f for f in sorted(label_path.glob("*.npy")) if f.is_file()] # sanity_check for idx, image in enumerate(self.images): image_name = image.stem mask_name = self.masks[idx].stem if image_name != mask_name: raise FileNotFoundError(f"Annotation for file {image_name} is missing") def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str]: """Get one item from dataset Args: index (int): Item to get Returns: Tuple[torch.Tensor, dict, str]: Trainings-Batch * torch.Tensor: Image * dict: Ground-Truth values: keys are "instance map", "nuclei_binary_map" and "hv_map" * str: filename """ img_path = self.images[index] img = np.array(Image.open(img_path)).astype(np.uint8) mask_path = self.masks[index] mask = np.load(mask_path, allow_pickle=True) mask = mask.astype(np.int64) if self.transforms is not None: transformed = self.transforms(image=img, mask=mask) img = transformed["image"] mask = transformed["mask"] hv_map = PanNukeDataset.gen_instance_hv_map(mask) np_map = mask.copy() np_map[np_map > 0] = 1 # torch convert img = torch.Tensor(img).type(torch.float32) img = img.permute(2, 0, 1) if torch.max(img) >= 5: img = img / 255 if self.patching and self.overlap == 0: img = rearrange(img, "c (h i) (w j) -> c h w i j", i=256, j=256) if self.patching and self.overlap != 0: img = img.unfold(1, 256, 256 - self.overlap).unfold( 2, 256, 256 - self.overlap ) masks = { "instance_map": torch.Tensor(mask).type(torch.int64), "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64), "hv_map": torch.Tensor(hv_map).type(torch.float32), } return img, masks, Path(img_path).name def __len__(self) -> int: """Length of Dataset Returns: int: Length of Dataset """ return len(self.images) def set_transforms(self, transforms: Callable) -> None: """Set the transformations, can be used tp exchange transformations Args: transforms (Callable): PyTorch transformations """ self.transforms = transforms