medicalSegmentation / src /medicalDataLoader.py
thov's picture
add training
e6f4cd4
raw
history blame contribute delete
No virus
3.37 kB
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image, ImageOps
from random import random, randint
import warnings
warnings.filterwarnings("ignore")
def make_dataset(root, mode):
    """Collect (image_path, mask_path) pairs for one dataset split.

    The split is read from ``<root>/<mode>/Img`` (images) and
    ``<root>/<mode>/GT`` (ground-truth masks). Both directory listings are
    sorted by file name and then paired positionally.

    Args:
        root (string): Dataset root directory.
        mode (string): Split to load; one of 'train', 'val' or 'test'.

    Returns:
        list: ``(image_path, mask_path)`` tuples, ordered by file name.

    Raises:
        AssertionError: If ``mode`` is not one of the known splits.
        OSError: If one of the split directories cannot be listed.
    """
    if mode not in ('train', 'val', 'test'):
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`, while keeping the exception type unchanged.
        raise AssertionError(
            "mode must be 'train', 'val' or 'test', got {!r}".format(mode))

    # The sub-directory name is the mode itself, so one code path serves
    # all three splits.
    img_dir = os.path.join(root, mode, 'Img')
    mask_dir = os.path.join(root, mode, 'GT')

    images = sorted(os.listdir(img_dir))
    labels = sorted(os.listdir(mask_dir))

    # NOTE: pairing is purely positional; zip() stops at the shorter listing,
    # so surplus files in either directory are silently left out.
    return [
        (os.path.join(img_dir, it_im), os.path.join(mask_dir, it_gt))
        for it_im, it_gt in zip(images, labels)
    ]
class MedicalImageDataset(Dataset):
    """Dataset of (image, segmentation mask) pairs loaded from disk with PIL.

    The file pairs come from ``make_dataset(root_dir, mode)``. Each item is
    optionally histogram-equalized, randomly augmented, and transformed
    before being returned.
    """

    def __init__(self, mode, root_dir, transform=None, mask_transform=None, augment=False, equalize=False):
        """
        Args:
            mode (string): Dataset split, forwarded to ``make_dataset``.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on the image.
            mask_transform (callable, optional): Optional transform to be
                applied on the mask.
            augment (bool): If true, apply the random flip / mirror / rotate
                of ``augment()`` to every item.
            equalize (bool): If true, apply ``ImageOps.equalize`` to the
                image (the mask is left untouched).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.imgs = make_dataset(root_dir, mode)
        # Stored as `augmentation` because `augment` is the method name.
        self.augmentation = augment
        self.equalize = equalize
        self.mode = mode

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.imgs)

    def augment(self, img, mask):
        """Randomly flip, mirror and rotate an image together with its mask.

        Each of the three operations is applied independently with
        probability 0.5, and always to both inputs so they stay aligned.
        The rotation angle is drawn uniformly from [-30, 30).

        Args:
            img: PIL image.
            mask: PIL image holding the matching mask.

        Returns:
            tuple: The ``(img, mask)`` pair after augmentation.
        """
        if random() > 0.5:
            img = ImageOps.flip(img)
            mask = ImageOps.flip(mask)
        if random() > 0.5:
            img = ImageOps.mirror(img)
            mask = ImageOps.mirror(mask)
        if random() > 0.5:
            angle = random() * 60 - 30
            img = img.rotate(angle)
            mask = mask.rotate(angle)
        return img, mask

    def _apply_transforms(self, img, mask):
        """Apply ``transform`` to img and ``mask_transform`` to mask, each only if set."""
        if self.transform:
            img = self.transform(img)
        # Checked on its own: mask_transform defaults to None, and calling
        # it unconditionally would raise TypeError.
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return img, mask

    def __getitem__(self, index):
        """Load and preprocess the item at ``index``.

        Returns:
            list: ``[img, mask, img_path]`` where ``img`` and ``mask`` are the
            processed image and mask, and ``img_path`` is the path the image
            was read from.
        """
        img_path, mask_path = self.imgs[index]
        img = Image.open(img_path)
        # The mask is converted to single-channel 8-bit grayscale.
        mask = Image.open(mask_path).convert('L')
        if self.equalize:
            img = ImageOps.equalize(img)
        if self.augmentation:
            img, mask = self.augment(img, mask)
        img, mask = self._apply_transforms(img, mask)
        return [img, mask, img_path]