# -*- coding: utf-8 -*-
# Coordinate the datasets, used to select the right dataset with corresponding setting
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from typing import Callable, Optional

from torch.utils.data import Dataset

from cell_segmentation.datasets.conic import CoNicDataset
from cell_segmentation.datasets.pannuke import PanNukeDataset


def select_dataset(
    dataset_name: str,
    split: str,
    dataset_config: dict,
    transforms: Optional[Callable] = None,
) -> Dataset:
    """Select a cell segmentation dataset from the provided ones.

    Currently PanNuke and CoNic are implemented. Both ``dataset_name`` and
    ``split`` are matched case-insensitively.

    Args:
        dataset_name (str): Name of dataset to use. Must be one of: ["pannuke", "conic"]
        split (str): Split to use. Must be one of: ["train", "val", "validation", "test"]
        dataset_config (dict): Dictionary with dataset configuration settings.
            Read keys: "dataset_path", the fold key of the requested split
            ("train_folds", "val_folds" or "test_folds") and, optionally,
            "stardist" and "regression_loss" (both default to False).
        transforms (Callable, optional): PyTorch Image and Mask transformations. Defaults to None.

    Raises:
        AssertionError: Unknown split type
        NotImplementedError: Unknown dataset
        KeyError: A required key is missing in dataset_config

    Returns:
        Dataset: Cell segmentation dataset
    """
    # "val" and "validation" are aliases and share the same fold configuration key.
    fold_keys = {
        "train": "train_folds",
        "val": "val_folds",
        "validation": "val_folds",
        "test": "test_folds",
    }
    dataset_classes = {
        "pannuke": PanNukeDataset,
        "conic": CoNicDataset,
    }

    # Normalize once and reuse: validating the lower-cased split but comparing
    # the raw value afterwards would leave the folds undefined for e.g. "Train".
    normalized_split = split.lower()
    if normalized_split not in fold_keys:
        # Raised explicitly (instead of using an `assert` statement) so the
        # check survives `python -O`; the exception type is kept for callers.
        raise AssertionError("Unknown split type!")

    # Reject unknown datasets before touching the configuration dictionary.
    dataset_cls = dataset_classes.get(dataset_name.lower())
    if dataset_cls is None:
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")

    folds = dataset_config[fold_keys[normalized_split]]

    # TODO: Stardist and regression loss (note carried over from the CoNic branch)
    return dataset_cls(
        dataset_path=dataset_config["dataset_path"],
        folds=folds,
        transforms=transforms,
        stardist=dataset_config.get("stardist", False),
        regression=dataset_config.get("regression_loss", False),
    )