Spaces:
Runtime error
Runtime error
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data | |
from PIL import Image | |
import os | |
# File-name suffixes treated as images. is_image_file() compares them without
# regard to case, so the upper-case duplicates are redundant for matching;
# they are kept so the list contents stay exactly as before.
IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tiff",
    ".webp",
]


def is_image_file(filename):
    """Return True if *filename* ends with one of IMG_EXTENSIONS.

    The comparison is case-insensitive, so names such as ``scan.TIFF`` or
    ``photo.Jpg`` are accepted even though only some case variants appear
    in IMG_EXTENSIONS.

    Args:
        filename: File name or path as a ``str``.

    Returns:
        bool: Whether the name carries a recognised image suffix.
    """
    # The suffix tuple is built per call so later changes to IMG_EXTENSIONS
    # are still honoured.
    suffixes = tuple(extension.lower() for extension in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
def make_dataset_rec(dir, images):
    """Append the path of every image file found under *dir* to *images*.

    The tree rooted at *dir* is walked with symbolic links followed. Each
    file name accepted by is_image_file() is joined with its directory and
    appended to the caller-supplied list; nothing is returned.

    Args:
        dir: Directory to scan.
        images: List extended in place with the discovered paths.

    Raises:
        AssertionError: If *dir* is not an existing directory.
    """
    # Raised explicitly instead of via ``assert`` so the check is not removed
    # under ``python -O``; exception type and message are unchanged.
    if not os.path.isdir(dir):
        raise AssertionError("%s is not a valid directory" % dir)

    for root, _dnames, fnames in sorted(os.walk(dir, followlinks=True)):
        # os.walk gives no ordering guarantee for the names inside a
        # directory; sort them so the result is reproducible.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    """Return the list of image file paths found under *dir*.

    Args:
        dir: Directory to scan.
        recursive: If true, delegate the scan to make_dataset_rec(), which
            walks with symbolic links followed. Otherwise the tree is walked
            with os.walk's default of not following links (sub-directories
            are still visited).
        read_cache: If true and ``<dir>/files.list`` exists, return its lines
            instead of scanning the directory.
        write_cache: If true, write the scanned paths, one per line, to
            ``<dir>/files.list`` and print a confirmation. Not reached when
            the read cache was used.

    Returns:
        list: Paths of the image files (or the cached lines).

    Raises:
        AssertionError: If a scan is needed and *dir* is not a directory
            (or, in the non-recursive case, a symbolic link).
    """
    if read_cache:
        possible_filelist = os.path.join(dir, "files.list")
        if os.path.isfile(possible_filelist):
            with open(possible_filelist, "r") as f:
                return f.read().splitlines()

    images = []
    if recursive:
        make_dataset_rec(dir, images)
    else:
        # Raised explicitly instead of via ``assert`` so the check is not
        # removed under ``python -O``; type and message are unchanged.
        if not (os.path.isdir(dir) or os.path.islink(dir)):
            raise AssertionError("%s is not a valid directory" % dir)
        for root, _dnames, fnames in sorted(os.walk(dir)):
            # os.walk gives no ordering guarantee for the names inside a
            # directory; sort them so the result is reproducible.
            for fname in sorted(fnames):
                if is_image_file(fname):
                    images.append(os.path.join(root, fname))

    if write_cache:
        filelist_cache = os.path.join(dir, "files.list")
        with open(filelist_cache, "w") as f:
            f.writelines("%s\n" % path for path in images)
        print("wrote filelist cache at %s" % filelist_cache)

    return images
def default_loader(path):
    """Open the image at *path* and return it converted to RGB mode.

    Args:
        path: Location of the image file, passed to ``Image.open``.

    Returns:
        The ``PIL.Image.Image`` produced by ``convert("RGB")``.
    """
    # Image.open keeps the underlying file open until the image is closed.
    # The context manager releases it as soon as convert() has produced its
    # result, rather than leaving the handle to garbage collection.
    with Image.open(path) as img:
        return img.convert("RGB")
class ImageFolder(data.Dataset):
    """Dataset over the image files that make_dataset() finds under a root."""

    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        """Collect the image paths under *root* and store the options.

        Args:
            root: Directory handed to make_dataset().
            transform: Optional callable applied to each loaded image.
            return_paths: If true, items are ``(image, path)`` pairs.
            loader: Callable that turns a path into an image.

        Raises:
            RuntimeError: If no image files are found under *root*.
        """
        found = make_dataset(root)
        if not found:
            message = (
                "Found 0 images in: "
                + root
                + "\nSupported image extensions are: "
                + ",".join(IMG_EXTENSIONS)
            )
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.loader = loader
        self.transform = transform
        self.return_paths = return_paths

    def __getitem__(self, index):
        """Load item *index*, apply the transform if set, and return it.

        Returns the image alone, or ``(image, path)`` when ``return_paths``
        is true.
        """
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.imgs)