nikunjkdtechnoland
init commit some more files
89c278d
raw
history blame
2.82 kB
import sys
import torch.utils.data as data
from os import listdir
from utils.tools import default_loader, is_image_file, normalize
import os
import torchvision.transforms as transforms
class Dataset(data.Dataset):
    """Image dataset that loads files from a directory and yields normalized tensors.

    Images are listed either directly inside ``data_path`` or, when
    ``with_subfolder`` is true, recursively inside each of its immediate
    sub-directories. Every item is loaded with ``default_loader``, resized and/or
    cropped to ``image_shape`` (minus its last element), converted to a tensor and
    passed through ``normalize``.
    """

    def __init__(self, data_path, image_shape, with_subfolder=False, random_crop=True, return_name=False):
        """Index the image files found under ``data_path``.

        Args:
            data_path: Root directory containing the images (or the sub-directories
                holding them).
            image_shape: Sequence whose last element is dropped; the remaining
                elements are used as the output ``(height, width)``
                (presumably ``[H, W, C]`` -- confirm against the caller).
            with_subfolder: If true, collect images recursively from every
                immediate sub-directory of ``data_path`` instead of from
                ``data_path`` itself.
            random_crop: If true, take a random crop of the target size
                (upscaling first when the image is too small); otherwise resize
                the image to the target size.
            return_name: If true, ``__getitem__`` returns ``(sample_name, tensor)``
                instead of just the tensor.
        """
        super(Dataset, self).__init__()
        if with_subfolder:
            self.samples = self._find_samples_in_subfolders(data_path)
        else:
            self.samples = [x for x in listdir(data_path) if is_image_file(x)]
        self.data_path = data_path
        # Remembered so __getitem__ knows whether samples already contain data_path.
        self.with_subfolder = with_subfolder
        self.image_shape = image_shape[:-1]
        self.random_crop = random_crop
        self.return_name = return_name

    @staticmethod
    def _cover_size(img_h, img_w, target_h, target_w):
        """Return the ``(height, width)`` that scales an image to cover the target.

        The aspect ratio is kept (up to integer truncation) and both returned
        dimensions are at least as large as the corresponding target dimension,
        so a ``target_h x target_w`` crop always fits.
        """
        # Height is the limiting side when target_h / img_h >= target_w / img_w;
        # compare cross products to stay in integer arithmetic.
        if target_h * img_w >= target_w * img_h:
            new_h = target_h
            new_w = max(target_w, int(target_h * img_w / img_h))
        else:
            new_w = target_w
            new_h = max(target_h, int(target_w * img_h / img_w))
        return new_h, new_w

    def __getitem__(self, index):
        """Load, resize/crop and normalize the image at ``index``.

        Returns:
            The normalized image tensor, or ``(sample_name, tensor)`` when
            ``return_name`` is set.
        """
        sample = self.samples[index]
        if self.with_subfolder:
            # Sub-folder samples are already rooted at data_path; joining again
            # would duplicate the prefix whenever data_path is a relative path.
            path = sample
        else:
            path = os.path.join(self.data_path, sample)
        img = default_loader(path)

        if self.random_crop:
            imgw, imgh = img.size
            target_h, target_w = self.image_shape
            if imgh < target_h or imgw < target_w:
                # Upscale just enough that BOTH sides can hold the crop window;
                # resizing only the short edge is insufficient for non-square targets.
                img = transforms.Resize(self._cover_size(imgh, imgw, target_h, target_w))(img)
            img = transforms.RandomCrop(self.image_shape)(img)
        else:
            img = transforms.Resize(self.image_shape)(img)
            # Kept from the original pipeline; expected to leave an image that is
            # already exactly the target size unchanged.
            img = transforms.RandomCrop(self.image_shape)(img)

        img = transforms.ToTensor()(img)  # turn the image to a tensor
        img = normalize(img)

        if self.return_name:
            return sample, img
        return img

    def _find_samples_in_subfolders(self, dir):
        """Collect image files from every immediate sub-directory of ``dir``.

        Sub-directories are visited in sorted order and each one is walked
        recursively; files are kept when ``is_image_file`` accepts their name.

        Args:
            dir (string): Root directory path.

        Returns:
            list: Paths of the image files found, each prefixed with ``dir``.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()

        samples = []
        for target in classes:
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    if is_image_file(fname):
                        samples.append(os.path.join(root, fname))
        return samples

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)