Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Base cell segmentation dataset, based on torch Dataset implementation | |
| # | |
| # @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
| # Institute for Artificial Intelligence in Medicine, | |
| # University Medicine Essen | |
| import logging | |
| from typing import Callable | |
| import torch | |
| from torch.utils.data import Dataset | |
| logger = logging.getLogger() | |
| logger.addHandler(logging.NullHandler()) | |
| from abc import abstractmethod | |
class CellDataset(Dataset):
    """Base class for cell segmentation datasets.

    Holds the transform setter and the combined cell/tissue sampling-weight
    computation. The dataset-specific parts (``load_cell_count``,
    ``get_sampling_weights_tissue`` and ``get_sampling_weights_cell``) are
    placeholders here and are meant to be overridden by subclasses.
    """

    def set_transforms(self, transforms: Callable) -> None:
        """Store the transformation callable on the dataset.

        Args:
            transforms (Callable): Callable that is stored as ``self.transforms``.
        """
        self.transforms = transforms

    def load_cell_count(self):
        """Load Cell count from cell_count.csv file. File must be located inside the fold folder

        The base implementation does nothing and returns None; subclasses are
        expected to override it.

        Example file beginning:
            Image,Neoplastic,Inflammatory,Connective,Dead,Epithelial
            0_0.png,4,2,2,0,0
            0_1.png,8,1,1,0,0
            0_10.png,17,0,1,0,0
            0_100.png,10,0,11,0,0
            ...
        """

    def get_sampling_weights_tissue(self, gamma: float = 1) -> torch.Tensor:
        """Get sampling weights calculated by tissue type statistics

        For this, a file named "weight_config.yaml" with the content:
            tissue:
                tissue_1: xxx
                tissue_2: xxx (name of tissue: count)
                ...
        Must exists in the dataset main folder (parent path, not inside the folds)

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample

        Raises:
            NotImplementedError: Always, in this base class. Subclasses must
                override this method.
        """
        # Fail loudly instead of silently returning None, which would only
        # surface later as an unrelated TypeError in the combined computation.
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_sampling_weights_tissue"
        )

    def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor:
        """Get sampling weights calculated by cell type statistics

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample

        Raises:
            NotImplementedError: Always, in this base class. Subclasses must
                override this method.
        """
        # Same reasoning as for the tissue weights: make a missing override obvious.
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_sampling_weights_cell"
        )

    def get_sampling_weights_cell_tissue(self, gamma: float = 1) -> torch.Tensor:
        """Get combined sampling weights by calculating tissue and cell sampling weights,
        normalizing them and adding them up to yield one score.

        Args:
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Returns:
            torch.Tensor: Weights for each sample

        Raises:
            AssertionError: If gamma is not within [0, 1].
        """
        if not 0 <= gamma <= 1:
            # Explicit raise instead of an ``assert`` statement so the check is
            # not stripped under ``python -O``. AssertionError is kept so that
            # callers catching the previous exception type keep working.
            raise AssertionError("Gamma must be between 0 and 1")

        tissue_weights = self.get_sampling_weights_tissue(gamma)
        cell_weights = self.get_sampling_weights_cell(gamma)

        # Divide each weight vector by its own maximum before summing.
        # NOTE(review): there is no guard for a weight vector whose maximum is 0.
        weights = tissue_weights / torch.max(tissue_weights) + cell_weights / torch.max(
            cell_weights
        )
        return weights