"""Dataloaders for the JSRT dataset paired with lung masks stored under ``SCR/masks``."""
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import pandas as pd
import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Anything accepted as a filesystem location below. This is a Union rather
# than a constrained TypeVar: the same function is routinely called with a
# mix of str and Path arguments, which a constrained TypeVar would reject.
PathLike = Union[str, Path, None]

PROJECT_DIR = Path(os.path.realpath(__file__)).parent.parent
DATADIR = Path("/JSRT")  # can be found at http://db.jsrt.or.jp/eng.php


def build_dataloaders(
    data_dir: PathLike = DATADIR,
    img_size: int = 128,
    batch_size: int = 16,
    num_workers: int = 1,
    n_labelled_images: Optional[int] = None,
    **kwargs,
) -> Dict[str, DataLoader]:
    """Build train/val/test dataloaders for the JSRT dataset.

    The split CSVs are read from ``PROJECT_DIR / "data"``.

    Args:
        data_dir: Root directory against which image and mask paths are resolved.
        img_size: Side length that images and masks are resized to (square).
        batch_size: Batch size used by all three loaders.
        num_workers: Number of DataLoader worker processes.
        n_labelled_images: If given, keep only the first ``n_labelled_images``
            samples of the training split.
        **kwargs: Accepted and ignored.

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'``. Only the
        training loader shuffles; all loaders use ``pin_memory=True``.

    Raises:
        ValueError: If ``n_labelled_images`` is outside
            ``[1, len(training split)]``.
    """
    split_dir = PROJECT_DIR / "data"

    def make_split(csv_name: str) -> "JSRTDataset":
        # One dataset per split CSV; everything else is shared.
        return JSRTDataset(data_dir, split_dir, csv_name, img_size)

    train_ds: Dataset = make_split("JSRT_train_split.csv")
    if n_labelled_images is not None:
        n_available = len(train_ds)
        # Fail here rather than with an IndexError halfway through an epoch.
        if not 1 <= n_labelled_images <= n_available:
            raise ValueError(
                f"n_labelled_images must be between 1 and {n_available} "
                f"(size of the training split), got {n_labelled_images}"
            )
        train_ds = torch.utils.data.Subset(train_ds, range(n_labelled_images))
        print(f"Using {n_labelled_images} labelled images")

    val_ds = make_split("JSRT_val_split.csv")
    test_ds = make_split("JSRT_test_split.csv")

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    return {
        "train": DataLoader(train_ds, shuffle=True, **loader_kwargs),
        "val": DataLoader(val_ds, shuffle=False, **loader_kwargs),
        "test": DataLoader(test_ds, shuffle=False, **loader_kwargs),
    }


class JSRTDataset(Dataset):
    """JSRT images paired with a merged binary mask.

    Each row of the split CSV must provide a ``path`` column (image path
    relative to ``base_path``) and an ``id`` column. For every name in
    ``labels``, the mask for a row is read from
    ``base_path / "SCR/masks" / <label> / f"{id}.gif"``. The masks are then
    merged into a single foreground mask.
    """

    def __init__(
        self,
        base_path: PathLike,
        csv_path: PathLike,
        csv_name: str,
        img_size: int = 128,
        labels: Sequence[str] = ("right lung", "left lung"),
        **kwargs,
    ) -> None:
        """
        Args:
            base_path: Root directory that image and mask paths are relative to.
            csv_path: Directory containing the split CSV.
            csv_name: File name of the split CSV inside ``csv_path``.
            img_size: Side length that images and masks are resized to.
            labels: Mask sub-directories under ``SCR/masks`` to merge.
            **kwargs: Accepted and ignored.
        """
        self.df = pd.read_csv(os.path.join(csv_path, csv_name))
        self.base_path = Path(base_path)
        self.labels = labels
        self.img_size = img_size

    def _load_tensor(self, fname: str) -> Tensor:
        """Read ``base_path / fname`` as grayscale, resize it, and return a float tensor."""
        # The context manager closes the underlying file. The original code
        # leaked one handle per image/mask read, once in every DataLoader worker.
        with Image.open(self.base_path / fname) as raw:
            img = raw.convert("L").resize((self.img_size, self.img_size))
        # ToTensor maps an 'L' image to a (1, H, W) float tensor scaled to [0, 1].
        return transforms.ToTensor()(img).float()

    def load_image(self, fname: str) -> Tensor:
        """Load an image as a float tensor of shape (1, img_size, img_size)."""
        return self._load_tensor(fname)

    def load_labels(self, fnames: Sequence[str]) -> Tensor:
        """Load the masks in ``fnames`` and merge them into one boolean mask.

        Each mask is resized, scaled to [0, 1] and thresholded at 0.5. The
        binary masks are then summed; any pixel covered by at least one mask
        is foreground.

        Returns:
            Boolean tensor of shape (1, img_size, img_size).

        Raises:
            ValueError: If ``fnames`` is empty.
        """
        if not fnames:
            raise ValueError("load_labels needs at least one mask path")
        masks: List[Tensor] = [(self._load_tensor(fname) > 0.5).float() for fname in fnames]
        merged = torch.stack(masks).sum(0)
        # The structures are not expected to overlap. Report it if they do,
        # but still count the overlapping pixels as foreground.
        if (merged > 1).any():
            print("overlapping lungs!", fnames)
        return merged > 0.5

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Return ``(image, mask)`` for the ``index``-th row of the split CSV."""
        row = self.df.iloc[index]
        image = self.load_image(row["path"])
        # str() so a CSV whose ids parse as numbers does not break path building.
        sample_id = str(row["id"])
        mask_paths = [f"SCR/masks/{label}/{sample_id}.gif" for label in self.labels]
        return image, self.load_labels(mask_paths)

    def __len__(self) -> int:
        return len(self.df)