from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image, ImageOps
from random import random, randint

import warnings
warnings.filterwarnings("ignore")


def make_dataset(root, mode):
    """Build the list of (image path, mask path) pairs for one data split.

    Files are read from ``<root>/<mode>/Img`` and ``<root>/<mode>/GT``. Both
    listings are sorted by file name and paired positionally, so the i-th
    image is matched with the i-th mask.

    Args:
        root (string): Dataset root directory.
        mode (string): Split to load; one of 'train', 'val' or 'test'.

    Returns:
        list: ``(image_path, mask_path)`` tuples in sorted file-name order.

    Raises:
        AssertionError: If ``mode`` is not a supported split name.
        OSError: If either split directory cannot be listed.
    """
    # Kept as an assert so callers that rely on AssertionError still work.
    assert mode in ['train', 'val', 'test'], \
        "mode must be 'train', 'val' or 'test'"

    # The three splits share one layout, so the split name is the only
    # thing that varies in the paths.
    img_dir = os.path.join(root, mode, 'Img')
    mask_dir = os.path.join(root, mode, 'GT')

    images = sorted(os.listdir(img_dir))
    labels = sorted(os.listdir(mask_dir))

    # NOTE: zip() stops at the shorter listing. If the two directories hold
    # a different number of files the extra entries are dropped silently and
    # the pairing may be misaligned; the pairing relies on matching names.
    return [(os.path.join(img_dir, it_im), os.path.join(mask_dir, it_gt))
            for it_im, it_gt in zip(images, labels)]


class MedicalImageDataset(Dataset):
    """Dataset of paired images and ground-truth masks stored on disk.

    Each sample is read from ``<root_dir>/<mode>/Img`` (image) and
    ``<root_dir>/<mode>/GT`` (mask); see ``make_dataset`` for the pairing.
    """

    def __init__(self, mode, root_dir, transform=None, mask_transform=None,
                 augment=False, equalize=False):
        """
        Args:
            mode (string): Split to load; one of 'train', 'val' or 'test'.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on the image of a sample.
            mask_transform (callable, optional): Optional transform to be
                applied on the mask of a sample.
            augment (bool): If True, apply the random flip / mirror /
                rotation augmentation to every sample.
            equalize (bool): If True, histogram-equalize the image (not the
                mask) before any augmentation.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.imgs = make_dataset(root_dir, mode)
        # Stored as `augmentation` because `augment` is the method name.
        self.augmentation = augment
        self.equalize = equalize
        self.mode = mode

    def __len__(self):
        """Return the number of (image, mask) pairs in the split."""
        return len(self.imgs)

    def augment(self, img, mask):
        """Apply the same random geometric changes to an image and its mask.

        Each of the three steps is applied independently with probability
        0.5: vertical flip, horizontal mirror, and a rotation by an angle
        drawn uniformly from [-30, 30) degrees.

        Args:
            img (PIL.Image.Image): Input image.
            mask (PIL.Image.Image): Mask matching ``img``.

        Returns:
            tuple: The augmented ``(img, mask)`` pair.
        """
        if random() > 0.5:
            img = ImageOps.flip(img)
            mask = ImageOps.flip(mask)
        if random() > 0.5:
            img = ImageOps.mirror(img)
            mask = ImageOps.mirror(mask)
        if random() > 0.5:
            # One angle for both so the mask stays aligned with the image.
            angle = random() * 60 - 30
            img = img.rotate(angle)
            mask = mask.rotate(angle)
        return img, mask

    def __getitem__(self, index):
        """Load, optionally equalize / augment, and transform one sample.

        Args:
            index (int): Position of the sample in the split.

        Returns:
            list: ``[img, mask, img_path]``. ``img`` and ``mask`` are the
            outputs of ``transform`` / ``mask_transform`` when those are
            given, otherwise PIL images; ``img_path`` is the image file path.
        """
        img_path, mask_path = self.imgs[index]
        img = Image.open(img_path)
        # Masks are forced to single-channel 8-bit ('L') mode.
        mask = Image.open(mask_path).convert('L')

        if self.equalize:
            img = ImageOps.equalize(img)

        if self.augmentation:
            img, mask = self.augment(img, mask)

        # Guard each transform on its own: mask_transform defaults to None,
        # so calling it unconditionally whenever `transform` is set raised
        # "TypeError: 'NoneType' object is not callable".
        if self.transform is not None:
            img = self.transform(img)
        if self.mask_transform is not None:
            mask = self.mask_transform(mask)

        return [img, mask, img_path]