import os

from PIL import Image
from torch.utils import data


class Dataset(data.Dataset):
    """Image dataset built from every JPG/PNG file found under a directory.

    The directory tree is scanned once, at construction time, and the
    resulting list of paths is kept in ``self.file_names``. Images are only
    opened when they are indexed.
    """

    # File-name suffixes (lower case) treated as images. Matching is
    # case-insensitive, so ``.JPG`` / ``.PNG`` are accepted as well.
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, path, transform=None):
        """Collect the image paths under *path*.

        Args:
            path: Root directory that is walked recursively for images.
            transform: Optional callable applied to each RGB image before it
                is returned by ``__getitem__``. ``None`` means the PIL image
                is returned as-is.
        """
        self.file_names = self.get_filenames(path)
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Load the image at position *index*.

        Returns:
            The image converted to RGB, passed through ``self.transform``
            when one was supplied.
        """
        # Image.open is lazy and keeps the file open; convert() reads the
        # pixel data and returns an independent image, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(self.file_names[index]) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

    def get_filenames(self, data_path):
        """Return the sorted paths of all image files below *data_path*.

        A file qualifies when its name ends with one of
        ``IMAGE_EXTENSIONS`` (case-insensitive) and it is a regular file.
        A missing or empty directory yields an empty list.
        """
        images = []
        for path, subdirs, files in os.walk(data_path):
            for name in files:
                # Test the suffix rather than "contains 'jpg'": a substring
                # test would also pick up files such as 'jpg_notes.txt'.
                if name.lower().endswith(self.IMAGE_EXTENSIONS):
                    filename = os.path.join(path, name)
                    if os.path.isfile(filename):
                        images.append(filename)
        # os.walk yields entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        images.sort()
        return images