|
import os |
|
from typing import List, Optional, Tuple |
|
import pandas as pd |
|
from skimage import io |
|
|
|
import numpy as np |
|
|
|
import torch |
|
from pytorch_lightning import LightningDataModule |
|
from torch.utils.data import DataLoader, Dataset, random_split |
|
from torchvision.transforms import transforms |
|
|
|
|
|
class FocusDataSet(Dataset):
    """Dataset for z-stacked images of neglected tropical diseases.

    Each row of the metadata csv describes one image: the ``image_path``
    column holds the image path relative to ``root_dir`` and the
    ``focus_height`` column holds the value returned as ``focus_height``.
    """

    def __init__(
        self,
        csv_file: str,
        root_dir: str,
        transform=None,
        in_memory: bool = True,
        additional_col_list: Optional[List[str]] = None,
    ):
        """Initialize focus stack dataset.

        Args:
            csv_file (string): Path to the csv file with annotations. It must
                contain the columns ``image_path`` and ``focus_height``.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            in_memory (bool): If True, every image is read once during
                construction and samples are served from memory; otherwise
                the image is read from disk on each access.
            additional_col_list (list of str, optional): Names of extra
                metadata columns whose values are copied into every sample.
                ``None`` (the default) means no extra columns.

        Raises:
            KeyError: If a required or requested column is missing from the
                csv file.
        """
        self.metadata = pd.read_csv(csv_file)
        self.in_memory = in_memory

        # ``None`` instead of a mutable ``[]`` default: a default list would
        # be shared between all calls of this constructor.
        extra_columns = [] if additional_col_list is None else additional_col_list

        # Resolve column positions once so that ``__getitem__`` can use
        # positional (``iloc``) lookups.
        columns = self.metadata.columns
        self.additional_col_index = {
            attribute: columns.get_loc(attribute) for attribute in extra_columns
        }
        self.col_index_path = columns.get_loc("image_path")
        self.col_index_focus = columns.get_loc("focus_height")

        self.root_dir = root_dir
        self.transform = transform

        # Stays an empty list when images are read lazily from disk.
        self.images = []
        if self.in_memory:
            self.images = np.array(
                [
                    self._load_img(img_path)
                    for img_path in self.metadata["image_path"].tolist()
                ]
            )

    def _load_img(self, img_path):
        """Read the image stored at ``img_path`` relative to ``root_dir``."""
        return io.imread(os.path.join(self.root_dir, img_path))

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns:
            int: the number of rows in the metadata csv
        """
        return len(self.metadata)

    def __getitem__(self, idx):
        """Get one item from the dataset.

        Args:
            idx (int or tensor): The index of the sample that is to be
                retrieved. Tensors are converted with ``tolist()`` first.

        Returns:
            dict: contains "image" (after ``transform``, if one was given),
            "focus_height" (a float tensor) and one entry per column named in
            ``additional_col_list``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.in_memory:
            image = self.images[idx]
        else:
            image = self._load_img(self.metadata.iloc[idx, self.col_index_path])

        # Compare against None: a callable may define ``__len__``/``__bool__``
        # and evaluate to False even though it should be applied.
        if self.transform is not None:
            image = self.transform(image)

        focus_height = torch.from_numpy(
            np.asarray(self.metadata.iloc[idx, self.col_index_focus])
        ).float()

        sample = {"image": image, "focus_height": focus_height}
        for attr, col_idx in self.additional_col_index.items():
            sample[attr] = self.metadata.iloc[idx, col_idx]

        return sample
|
|
|
|
|
class FocusDataModule(LightningDataModule):
    """LightningDataModule for FocusStack dataset.

    Builds one ``FocusDataSet`` per split (train / validation / test), each
    from its own metadata csv file. Only the training split can be augmented;
    validation and test always use the plain tensor conversion.
    """

    def __init__(
        self,
        data_dir: str = "data/",
        csv_train_file: str = "data/train_metadata.csv",
        csv_val_file: str = "data/validation_metadata.csv",
        csv_test_file: str = "data/test_metadata.csv",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        in_memory: bool = True,
        augmentation: bool = False,
        additional_col_list: Optional[List[str]] = None,
    ):
        """Store the configuration and build the image transforms.

        Args:
            data_dir: Root directory the image paths in the csv files are
                relative to.
            csv_train_file: Metadata csv of the training split.
            csv_val_file: Metadata csv of the validation split.
            csv_test_file: Metadata csv of the test split.
            batch_size: Batch size used by all three dataloaders.
            num_workers: Number of worker processes per dataloader.
            pin_memory: Forwarded to every ``DataLoader``.
            in_memory: Forwarded to every ``FocusDataSet``.
            augmentation: If True, random flips and right-angle rotations are
                appended to the training transform.
            additional_col_list: Extra metadata columns to expose in every
                sample. ``None`` (the default) means no extra columns.
        """
        super().__init__()

        # Exposes the init arguments through ``self.hparams`` (used by
        # ``setup`` and the dataloaders below).
        self.save_hyperparameters(logger=False)

        base_steps = [
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
        ]

        # Validation / test pipeline: built from a copy so that the
        # augmentation steps added below never end up in it.
        self.base_transforms = transforms.Compose(list(base_steps))

        train_steps = list(base_steps)
        if augmentation:
            # RandomChoice picks one of the three entries; that entry then
            # applies its fixed-angle rotation with probability 0.5.
            right_angle_rotation = transforms.RandomChoice(
                [
                    transforms.RandomApply(
                        [transforms.RandomRotation((angle, angle))], p=0.5
                    )
                    for angle in (90, 180, 270)
                ]
            )
            train_steps.extend(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomVerticalFlip(p=0.5),
                    right_angle_rotation,
                ]
            )

        self.transforms = transforms.Compose(train_steps)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
        self.in_memory = in_memory

        # Always a fresh list owned by this instance: a mutable default
        # argument would be shared by every instance created without it.
        self.additional_col_list = (
            [] if additional_col_list is None else list(additional_col_list)
        )

    def prepare_data(self):
        """This method is not implemented as of yet.

        Download data if needed. This method is called only from a single GPU.
        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        Lightning may call this method more than once (e.g. for
        `trainer.fit()` and `trainer.test()`), so the datasets are only
        created while none of them has been loaded yet. `stage` is accepted
        for API compatibility but not used: all three splits are loaded
        together.
        """
        # Explicit ``is None`` checks: ``not dataset`` would call ``__len__``
        # and treat an already loaded but empty split as "not loaded".
        already_loaded = (
            self.data_train is not None
            or self.data_val is not None
            or self.data_test is not None
        )
        if already_loaded:
            return

        self.data_train = self._make_dataset(
            self.hparams.csv_train_file, self.transforms
        )
        self.data_val = self._make_dataset(
            self.hparams.csv_val_file, self.base_transforms
        )
        self.data_test = self._make_dataset(
            self.hparams.csv_test_file, self.base_transforms
        )

    def _make_dataset(self, csv_file, transform):
        """Create the ``FocusDataSet`` for one split."""
        return FocusDataSet(
            csv_file,
            self.hparams.data_dir,
            transform=transform,
            in_memory=self.in_memory,
            additional_col_list=self.additional_col_list,
        )

    def _make_dataloader(self, dataset, shuffle: bool):
        """Wrap ``dataset`` in a ``DataLoader`` using the shared settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test split."""
        return self._make_dataloader(self.data_test, shuffle=False)
|
|