Spaces:
Runtime error
Runtime error
File size: 2,062 Bytes
4a285f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
# Filename suffixes treated as images by is_image_file. The suffix match is
# case-sensitive, so lower- and upper-case spellings are listed separately
# ('.tiff' is listed in lower case only).
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
    '.tiff',
]
def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-sensitive: only the exact spellings listed in
    ``IMG_EXTENSIONS`` match.
    """
    # str.endswith accepts a tuple of candidate suffixes and returns True
    # when any of them matches.
    suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(suffixes)
def make_dataset(dir):
    """List every entry of directory ``dir`` as a joined path.

    No extension filtering is applied: every name returned by
    ``os.listdir`` is included (files and sub-directories alike), in the
    order ``os.listdir`` yields them.

    Args:
        dir: Path of an existing directory, as a string.

    Returns:
        A list of ``os.path.join(dir, name)`` strings, one per entry.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    # Last '_'-separated token of the final '/'-separated path component;
    # it is only used by the diagnostic print below.
    last_component = dir.split('/')[-1]
    suffix = last_component.split('_')[-1]
    print(dir, suffix)
    return [os.path.join(dir, entry) for entry in os.listdir(dir)]
def make_dataset_test(dir):
    """Collect the paths of the regular files directly inside ``dir``.

    Entries that are not regular files (for example sub-directories) are
    skipped. The result follows the order of ``os.listdir``, which Python
    documents as arbitrary; no sorting is applied.

    Args:
        dir: Path of an existing directory.

    Returns:
        A list of ``os.path.join(dir, name)`` paths, one per regular file.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    if not os.path.isdir(dir):
        # Raised explicitly (rather than via an ``assert`` statement) so the
        # check still runs under ``python -O``; the exception type is kept
        # as AssertionError for callers that catch it.
        raise AssertionError('%s is not a valid directory' % dir)

    # Filter each entry individually. Counting the files and then taking
    # that many names from the unfiltered listing could pick up
    # sub-directories and drop real files.
    candidates = [os.path.join(dir, name) for name in os.listdir(dir)]
    return [path for path in candidates if os.path.isfile(path)]
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB mode.

    Args:
        path: Location of the image file, passed to ``Image.open``.

    Returns:
        The result of ``Image.convert('RGB')`` on the opened image.
    """
    # Per Pillow's documentation, Image.open is lazy and keeps the underlying
    # file open. The context manager closes it once convert() has produced
    # the converted copy, so iterating a large dataset does not accumulate
    # open file handles.
    with Image.open(path) as img:
        return img.convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset over the paths that ``make_dataset`` returns for ``root``.

    Args:
        root: Directory passed to ``make_dataset``.
        transform: Optional callable applied to each loaded image.
        return_paths: When true, ``__getitem__`` returns ``(image, path)``
            instead of the image alone.
        loader: Callable that turns a path into an image; defaults to
            ``default_loader``.

    Raises:
        RuntimeError: If ``make_dataset`` returns no paths for ``root``.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        found = make_dataset(root)
        if not found:
            message = "Found 0 images in: " + root + "\n"
            message += "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load item ``index``, apply the transform, optionally add its path."""
        path = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return (image, path) if self.return_paths else image

    def __len__(self):
        """Return the number of collected paths."""
        return len(self.imgs)
|