|
import torch |
|
from torch.utils.data import Dataset |
|
import os |
|
from natsort import natsorted |
|
import cv2 |
|
import glob |
|
import numpy as np |
|
from PIL import Image |
|
from skimage import io as img |
|
|
|
class ImageAndMaskData(Dataset):
    """Dataset yielding one PIL image per index whose channels are the RGB
    image plus a single mask channel, built from two parallel directories.

    Images and masks are paired by natural-sort order of their filenames,
    so the two directories must contain matching files.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        # Natural sort keeps "img2" before "img10" so pairing stays aligned.
        self.images = natsorted(glob.glob(img_dir + "/*"))
        self.masks = natsorted(glob.glob(mask_dir + "/*"))
        self.imgs_and_masks = list(zip(self.images, self.masks))
        self.transform = transform

    def __len__(self):
        return len(self.imgs_and_masks)

    def __getitem__(self, idx):
        img_path, mask_path = self.imgs_and_masks[idx]

        # Load the image and the first channel of the mask, then stack them
        # into one (H, W, C + 1) array.
        # NOTE(review): assumes the mask decodes to a 3-D array — a 2-D
        # grayscale mask would make the [:, :, 0:1] slice fail; confirm.
        img_arr = np.array(Image.open(img_path))
        mask_arr = np.array(Image.open(mask_path))[:, :, 0:1]
        combined = np.concatenate((img_arr, mask_arr), axis=2)

        sample = Image.fromarray(combined)
        if self.transform:
            sample = self.transform(sample)
        return sample
|
|
|
|
|
|
|
|
|
def make_4_chs_img(image_path, mask_path):
    """Read an image and its mask and return them stacked channel-wise.

    The mask is binarized (values > 127 become 255, others 0) and its first
    channel is appended to the image, giving an (H, W, C + 1) array — four
    channels for an RGB input.

    Args:
        image_path: path to the image file.
        mask_path: path to the mask file (grayscale or multi-channel).

    Returns:
        np.ndarray of shape (H, W, C + 1) with the binarized mask last,
        in the same dtype as the input image.
    """
    im = img.imread(image_path)
    mask = img.imread(mask_path)

    # Binarize, then cast back to the image dtype: bool * 255 yields int64
    # in NumPy, which would silently promote the whole concatenated array.
    mask = ((mask > 127) * 255).astype(im.dtype)

    # Accept both 2-D grayscale masks and (H, W, C) masks: atleast_3d turns
    # (H, W) into (H, W, 1) and leaves 3-D input unchanged.
    mask = np.atleast_3d(mask)

    return np.concatenate((im, mask[:, :, 0:1]), axis=2)
|
|
|
def norm(x):
    """Map tensor values from [0, 1] to [-1, 1], clamped to that range."""
    return x.sub(0.5).mul(2).clamp(-1, 1)
|
|
|
def denorm(x):
    """Map tensor values from [-1, 1] back to [0, 1], clamped to that range."""
    return x.add(1).div(2).clamp(0, 1)
|
|
|
def np2torch(x):
    """Convert an (H, W, C) NumPy image to a normalized (C, H, W) tensor.

    Args:
        x: np.ndarray of shape (H, W, C) with values in [0, 255].

    Returns:
        torch.FloatTensor of shape (C, H, W) with values in [-1, 1]
        (scaled to [0, 1] then passed through `norm`).
    """
    # The original `x = x[:, :, :]` identity slice was a no-op and is removed.
    # HWC -> CHW, then scale to [0, 1].
    x = x.transpose((2, 0, 1)) / 255
    x = torch.from_numpy(x).type(torch.FloatTensor)
    return norm(x)
|
|
|
|
|
|
|
class ImageAndMaskDataFromSinGAN(Dataset):
    """Dataset yielding 4-channel float tensors (RGB + binarized mask) in
    [-1, 1], built from two parallel image/mask directories.

    Images and masks are paired by natural-sort order of their filenames.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        # Natural sort keeps "img2" before "img10" so pairing stays aligned.
        self.images = natsorted(glob.glob(img_dir + "/*"))
        self.masks = natsorted(glob.glob(mask_dir + "/*"))
        self.imgs_and_masks = list(zip(self.images, self.masks))
        self.transform = transform

    def __len__(self):
        return len(self.imgs_and_masks)

    def __getitem__(self, idx):
        image_path, mask_path = self.imgs_and_masks[idx]

        # Build the 4-channel (RGB + mask) array, convert to a normalized
        # CHW float tensor, then keep only the first four channels.
        tensor = np2torch(make_4_chs_img(image_path, mask_path))
        sample = tensor[0:4, :, :]

        if self.transform:
            sample = self.transform(sample)
        return sample
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset from hard-coded local paths and print
    # the shape of one sample (expected: a 4-channel CHW tensor).
    image_root = "/work/vajira/DATA/kvasir_seg/real_images_root/real_images"
    mask_root = "/work/vajira/DATA/kvasir_seg/real_masks_root/real_masks"
    dataset = ImageAndMaskDataFromSinGAN(image_root, mask_root)
    print(dataset[1].shape)
|
|
|
|
|
|
|
|
|
|