# BOPBTL/Face_Enhancement/data/image_folder.py
# (Uploaded by manhkhanhUIT in commit 7fab858, "Add code"; file size 2.72 kB.)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data as data
from PIL import Image
import os
# File-name suffixes treated as images. Suffix matching against this list is
# case-sensitive (plain ``str.endswith``), so each format is listed in both
# lower and upper case; ".TIFF" and ".WEBP" were previously missing, which
# caused upper-case TIFF/WebP files to be skipped.
IMG_EXTENSIONS = [
    ".jpg",
    ".JPG",
    ".jpeg",
    ".JPEG",
    ".png",
    ".PNG",
    ".ppm",
    ".PPM",
    ".bmp",
    ".BMP",
    ".tiff",
    ".TIFF",
    ".webp",
    ".WEBP",
]
def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    This is a plain, case-sensitive suffix comparison on the name; the file
    itself is never opened or inspected.
    """
    for suffix in IMG_EXTENSIONS:
        if filename.endswith(suffix):
            return True
    return False
def make_dataset_rec(dir, images):
    """Append the path of every image file found under ``dir`` to ``images``.

    The tree is walked with ``os.walk(dir, followlinks=True)``, so symlinked
    sub-directories are descended into. Directories are visited in sorted
    order of their path and file names are sorted within each directory, so
    the resulting order does not depend on the file system's listing order.

    Args:
        dir: Root directory to scan.
        images: List that is extended in place with the matching paths.

    Raises:
        AssertionError: If ``dir`` is not an existing directory.
    """
    # Raised explicitly instead of via ``assert`` so the check still runs
    # under ``python -O``; the exception type callers see is unchanged.
    if not os.path.isdir(dir):
        raise AssertionError("%s is not a valid directory" % dir)

    for root, _dnames, fnames in sorted(os.walk(dir, followlinks=True)):
        # os.walk yields file names in arbitrary order; sort for determinism.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
    """Return the list of image paths found under ``dir``.

    Args:
        dir: Directory to scan.
        recursive: If true, delegate the scan to ``make_dataset_rec`` (which
            walks with ``followlinks=True``); the ordering in that mode is
            whatever ``make_dataset_rec`` produces. Otherwise ``dir`` is
            walked with ``os.walk`` without following symlinked
            sub-directories; note that ``os.walk`` still descends into real
            sub-directories.
        read_cache: If true and ``<dir>/files.list`` exists, return its lines
            as the result without scanning the directory.
        write_cache: If true, write the scanned paths to ``<dir>/files.list``
            (one per line) after scanning.

    Returns:
        A list of path strings.

    Raises:
        AssertionError: In non-recursive mode, if ``dir`` is neither a
            directory nor a symlink.
    """
    cache_path = os.path.join(dir, "files.list")

    if read_cache and os.path.isfile(cache_path):
        with open(cache_path, "r") as f:
            return f.read().splitlines()

    images = []
    if recursive:
        make_dataset_rec(dir, images)
    else:
        # Raised explicitly instead of via ``assert`` so the check still runs
        # under ``python -O``; the exception type callers see is unchanged.
        if not (os.path.isdir(dir) or os.path.islink(dir)):
            raise AssertionError("%s is not a valid directory" % dir)
        for root, _dnames, fnames in sorted(os.walk(dir)):
            # os.walk yields file names in arbitrary order; sort them so the
            # dataset order is reproducible across machines and file systems.
            for fname in sorted(fnames):
                if is_image_file(fname):
                    images.append(os.path.join(root, fname))

    if write_cache:
        with open(cache_path, "w") as f:
            f.writelines("%s\n" % path for path in images)
        print("wrote filelist cache at %s" % cache_path)

    return images
def default_loader(path):
    """Open the image at ``path`` and return a copy converted to RGB mode.

    Args:
        path: Anything ``PIL.Image.open`` accepts (normally a file path).

    Returns:
        A new ``PIL.Image.Image`` in mode "RGB".
    """
    # Image.open is lazy and keeps the underlying file open. Using the image
    # as a context manager closes it as soon as convert() has produced the
    # independent in-memory copy, rather than leaving the handle to be
    # released whenever the temporary image object is garbage-collected.
    with Image.open(path) as img:
        return img.convert("RGB")
class ImageFolder(data.Dataset):
    """Dataset over the image files that ``make_dataset`` finds under a root.

    Indexing loads the image with ``loader``, applies ``transform`` when one
    was given, and returns either the image alone or an ``(image, path)``
    pair, depending on ``return_paths``.
    """

    def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
        """Collect the image paths under ``root`` and store the options.

        Args:
            root: Directory passed to ``make_dataset``.
            transform: Optional callable applied to each loaded image.
            return_paths: When true, ``__getitem__`` also returns the path.
            loader: Callable that turns a path into an image.

        Raises:
            RuntimeError: If ``make_dataset`` finds no images under ``root``.
        """
        found = make_dataset(root)
        if not found:
            message = (
                "Found 0 images in: " + root + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
            )
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load, optionally transform, and return the item at ``index``."""
        image_path = self.imgs[index]
        sample = self.loader(image_path)
        if self.transform is not None:
            sample = self.transform(sample)
        if not self.return_paths:
            return sample
        return sample, image_path

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.imgs)