Spaces:
Sleeping
Sleeping
File size: 3,366 Bytes
1df7042 e6f4cd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image, ImageOps
from random import random, randint
import warnings
# Suppress every warning category for the whole interpreter as soon as this
# module is imported.
# NOTE(review): this also hides warnings raised by other libraries
# (torch, PIL, skimage); consider scoping it if that is not intended.
warnings.simplefilter("ignore")
def make_dataset(root, mode):
    """Build the list of ``(image_path, mask_path)`` pairs for one split.

    Images are read from ``<root>/<mode>/Img`` and masks from
    ``<root>/<mode>/GT``.

    Args:
        root (string): Dataset root directory containing one folder per split.
        mode (string): Split to load: ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        list of tuple: ``(image_path, mask_path)`` pairs, in sorted
        image-filename order.

    Raises:
        ValueError: If ``mode`` is not one of the recognised splits.
    """
    # An explicit check (not ``assert``) so validation survives ``python -O``.
    if mode not in ('train', 'val', 'test'):
        raise ValueError(
            "mode must be one of 'train', 'val', 'test'; got %r" % (mode,))
    img_dir = os.path.join(root, mode, 'Img')
    mask_dir = os.path.join(root, mode, 'GT')
    images = sorted(os.listdir(img_dir))
    labels = sorted(os.listdir(mask_dir))
    # Images and masks are paired purely by sorted filename order, so the two
    # folders must list corresponding files in the same order; zip() stops at
    # the shorter listing.
    return [(os.path.join(img_dir, it_im), os.path.join(mask_dir, it_gt))
            for it_im, it_gt in zip(images, labels)]
class MedicalImageDataset(Dataset):
    """Segmentation dataset yielding ``[img, mask, img_path]`` samples.

    File pairs come from ``make_dataset(root_dir, mode)``: images under
    ``<root_dir>/<mode>/Img`` matched by sorted filename order with masks
    under ``<root_dir>/<mode>/GT``.
    """

    def __init__(self, mode, root_dir, transform=None, mask_transform=None, augment=False, equalize=False):
        """
        Args:
            mode (string): Split to load: 'train', 'val' or 'test'.
            root_dir (string): Dataset root with one sub-directory per split.
            transform (callable, optional): Applied to the image of each
                sample.
            mask_transform (callable, optional): Applied to the mask of each
                sample.
            augment (bool): If True, apply the same random flips/rotation to
                image and mask (see ``augment``).
            equalize (bool): If True, histogram-equalize the image before any
                augmentation.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.imgs = make_dataset(root_dir, mode)
        # Stored as ``augmentation`` so it does not shadow the ``augment`` method.
        self.augmentation = augment
        self.equalize = equalize
        self.mode = mode

    def __len__(self):
        """Return the number of (image, mask) pairs in this split."""
        return len(self.imgs)

    def augment(self, img, mask):
        """Apply one identical random geometric augmentation to image and mask.

        Each step fires independently when ``random() > 0.5``: vertical flip,
        horizontal mirror, then rotation by an angle drawn uniformly from
        [-30, 30) degrees. Returns the ``(img, mask)`` pair.
        """
        if random() > 0.5:
            img = ImageOps.flip(img)      # top <-> bottom
            mask = ImageOps.flip(mask)
        if random() > 0.5:
            img = ImageOps.mirror(img)    # left <-> right
            mask = ImageOps.mirror(mask)
        if random() > 0.5:
            # One shared angle keeps the mask aligned with the image.
            angle = random() * 60 - 30
            img = img.rotate(angle)
            mask = mask.rotate(angle)
        return img, mask

    def __getitem__(self, index):
        """Load sample ``index`` and return ``[img, mask, img_path]``.

        The mask is converted to mode 'L'. Equalization, augmentation and the
        transforms are applied in that order, each only if configured.
        """
        img_path, mask_path = self.imgs[index]
        img = Image.open(img_path)
        mask = Image.open(mask_path).convert('L')
        if self.equalize:
            img = ImageOps.equalize(img)
        if self.augmentation:
            img, mask = self.augment(img, mask)
        # Image and mask transforms are applied independently, so either one
        # may be omitted without breaking the other.
        if self.transform:
            img = self.transform(img)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return [img, mask, img_path]
|