import os
from typing import List, Optional, Tuple
import pandas as pd
from skimage import io
import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.transforms import transforms


class FocusDataSet(Dataset):
    """Dataset for z-stacked images of neglected tropical diseases."""

    def __init__(
        self,
        csv_file,
        root_dir,
        transform=None,
        in_memory=True,
        additional_col_list: Optional[List[str]] = None,
    ):
        """Initialize the focus stack dataset.

        Args:
            csv_file (string): Path to the csv file with annotations. It must
                contain the columns ``image_path`` and ``focus_height``.
            root_dir (string): Directory with all the images; each
                ``image_path`` entry is joined onto it.
            transform (callable, optional): Optional transform to be applied
                on a sample image.
            in_memory (bool): If True, every image is read once here and kept
                in memory; otherwise images are read from disk on access.
            additional_col_list (list of str, optional): Extra metadata
                columns returned with every sample, keyed by column name.
                Defaults to no extra columns.

        Raises:
            KeyError: If a required or requested column is missing from the
                csv file.
        """
        self.metadata = pd.read_csv(csv_file)
        self.in_memory = in_memory
        self.root_dir = root_dir
        self.transform = transform

        # Resolve column positions once so __getitem__ can use positional iloc.
        columns = self.metadata.columns
        self.additional_col_index = {
            attribute: columns.get_loc(attribute)
            for attribute in (additional_col_list or [])
        }
        self.col_index_path = columns.get_loc("image_path")
        self.col_index_focus = columns.get_loc("focus_height")

        self.images = []
        if self.in_memory:
            # np.array stacks the images into one array, so all images are
            # expected to share the same shape (ragged inputs either raise or
            # yield an object array, depending on the NumPy version).
            self.images = np.array(
                [
                    self._load_img(img_path)
                    for img_path in self.metadata["image_path"].tolist()
                ]
            )

    def _load_img(self, img_path):
        """Read one image from ``root_dir`` joined with ``img_path``."""
        path = os.path.join(self.root_dir, img_path)
        return io.imread(path)

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns:
            int: the number of rows in the metadata csv
        """
        return len(self.metadata)

    def __getitem__(self, idx):
        """Get one item from the dataset.

        Args:
            idx (int or tensor): The index of the sample that is to be
                retrieved. Tensors are converted with ``tolist()``.

        Returns:
            dict: ``"image"`` (transformed when a transform is set),
            ``"focus_height"`` as a float tensor, plus one entry per
            requested additional column.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.in_memory:
            image = self.images[idx]
        else:
            image = self._load_img(self.metadata.iloc[idx, self.col_index_path])

        if self.transform is not None:
            image = self.transform(image)

        focus_height = torch.from_numpy(
            np.asarray(self.metadata.iloc[idx, self.col_index_focus])
        ).float()

        sample = {"image": image, "focus_height": focus_height}
        for attr, col_idx in self.additional_col_index.items():
            sample[attr] = self.metadata.iloc[idx, col_idx]
        return sample


class FocusDataModule(LightningDataModule):
    """LightningDataModule for the FocusStack dataset.

    Builds train/validation/test ``FocusDataSet`` instances from three csv
    files. Only the training split receives the (optional) augmentations;
    validation and test use the base tensor conversion only.
    """

    def __init__(
        self,
        data_dir: str = "data/",
        csv_train_file: str = "data/train_metadata.csv",
        csv_val_file: str = "data/validation_metadata.csv",
        csv_test_file: str = "data/test_metadata.csv",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        in_memory: bool = True,
        augmentation: bool = False,
        additional_col_list: Optional[List[str]] = None,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        base_list = [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
        ]
        # Validation/test transforms: tensor conversion only (own list copy,
        # so the augmentation extension below never leaks into them).
        self.base_transforms = transforms.Compose(list(base_list))

        train_list = list(base_list)
        if augmentation:
            train_list.extend(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                    # RandomChoice picks one of the three fixed rotations;
                    # the chosen rotation is then applied with p=0.5.
                    transforms.RandomChoice(
                        [
                            transforms.RandomApply(
                                [transforms.RandomRotation((90, 90))], p=0.5
                            ),
                            transforms.RandomApply(
                                [transforms.RandomRotation((180, 180))], p=0.5
                            ),
                            transforms.RandomApply(
                                [transforms.RandomRotation((270, 270))], p=0.5
                            ),
                        ]
                    ),
                ]
            )

        # data transformations (used for the training split)
        self.transforms = transforms.Compose(train_list)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.in_memory = in_memory
        # Own copy: avoids sharing a default/caller list across instances.
        self.additional_col_list = (
            list(additional_col_list) if additional_col_list is not None else []
        )

    def prepare_data(self):
        """This method is not implemented as of yet.

        Download data if needed. This method is called only from a single GPU.
        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning twice for `trainer.fit()` and
        `trainer.test()`, so be careful if you do a random split! The `stage`
        can be used to differentiate whether it's called before
        `trainer.fit()` or `trainer.test()`.
        """
        # Load datasets only if they're not loaded already. Compare against
        # None explicitly: truthiness of a Dataset calls __len__, so an empty
        # split would otherwise trigger a full reload on every call.
        if (
            self.data_train is None
            and self.data_val is None
            and self.data_test is None
        ):
            self.data_train = FocusDataSet(
                self.hparams.csv_train_file,
                self.hparams.data_dir,
                transform=self.transforms,
                in_memory=self.in_memory,
                additional_col_list=self.additional_col_list,
            )
            self.data_val = FocusDataSet(
                self.hparams.csv_val_file,
                self.hparams.data_dir,
                transform=self.base_transforms,
                in_memory=self.in_memory,
                additional_col_list=self.additional_col_list,
            )
            self.data_test = FocusDataSet(
                self.hparams.csv_test_file,
                self.hparams.data_dir,
                transform=self.base_transforms,
                in_memory=self.in_memory,
                additional_col_list=self.additional_col_list,
            )

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` with the shared loader settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled training DataLoader."""
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation DataLoader."""
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the (unshuffled) test DataLoader."""
        return self._make_dataloader(self.data_test, shuffle=False)