# PUMP/datasets/image_set.py
# Philippe Weinzaepfel
# huggingface demo
# 3ef85e9
# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use
from pdb import set_trace as bb
import os
from os.path import *
from PIL import Image
class ImageSet(object):
    """ Base class for an image dataset.

    An image set is a root directory plus a non-empty sequence of image
    paths relative to that root. Images are opened on access and always
    returned converted to RGB.

    Attributes:
        root: directory that every entry of `imgs` is relative to.
        imgs: sequence of image paths, relative to `root`.
    """

    def __init__(self, root, imgs):
        """ Store `root` and `imgs`; raises AssertionError if `imgs` is empty.
        """
        self.root = root
        self.imgs = imgs
        assert imgs, f'Empty image set in {root}'

    def init_from_folder(self, *args, **kw):
        """ (Re-)initialize this instance from a folder.

        Accepts the same arguments as `ImageSet.from_folder`.
        """
        imset = ImageSet.from_folder(*args, **kw)
        ImageSet.__init__(self, imset.root, imset.imgs)

    def __len__(self):
        """ Number of images in the set.
        """
        return len(self.imgs)

    def get_image_path(self, idx):
        """ Return the path of image `idx`, i.e. `root` joined with `imgs[idx]`.
        """
        return os.path.join(self.root, self.imgs[idx])

    def get_image(self, idx):
        """ Load image `idx` and return it converted to RGB.

        Raises:
            IOError: if the image cannot be opened or decoded. The original
                exception is attached as the cause.
        """
        fname = self.get_image_path(idx)
        try:
            # convert() decodes the pixels and returns a new image, so the
            # underlying file can be closed as soon as we leave the block.
            with Image.open(fname) as img:
                return img.convert('RGB')
        except Exception as e:
            raise IOError("Could not load image %s (reason: %s)" % (fname, str(e))) from e

    __getitem__ = get_image

    @staticmethod
    def from_folder(root, exts=('.jpg','.jpeg','.png','.ppm'), recursive=False, listing=False, check_imgs=False):
        """ Build an image set from the image files found in `root`.

        root: directory to scan.
        exts: file extension(s) to keep, matched case-insensitively against
            the lower-cased file name. A single string or any iterable of
            strings.
        recursive: bool or func. If truthy, sub-directories are scanned too.
            If a function, it is called with the path of each sub-directory
            (`root` joined with its name) and must evaluate True for that
            sub-directory to be scanned. Sub-directories that contain no
            image are skipped.
        listing: False, True or a file name. If truthy, the image list is
            read from that file inside the scanned folder when it can be
            opened; otherwise the folder is scanned and the list is written
            there (best effort). True selects 'list_imgs.txt', or
            'list_imgs_recursive.txt' when `recursive` is truthy.
        check_imgs: False, True or func. If a function, it is called with
            each image path found directly in a folder and must evaluate
            True to keep it. True selects `verify_img`.

        Raises:
            AssertionError: if no image is found (including when `root`
                does not exist).
        """
        # str.endswith() only accepts a str or a tuple of str.
        exts = (exts,) if isinstance(exts, str) else tuple(exts)
        imgs = ImageSet._scan_folder(root, exts, recursive, listing, check_imgs)
        return ImageSet(root, imgs)

    @staticmethod
    def _scan_folder(root, exts, recursive, listing, check_imgs):
        """ Return the image paths under `root`, relative to it (possibly empty).
        """
        flist = None
        if listing:
            if listing is True:
                listing = f"list_imgs{'_recursive' if recursive else ''}.txt"
            flist = join(root, listing)
            try:
                return ImageSet._read_listing(flist)
            except IOError:
                print(f'>> ImageSet.from_folder(listing=True): entering {root}...')

        if check_imgs is True:  # default verification function
            check_imgs = verify_img

        imgs = []
        # Only the top level of the walk is used: recursion is handled
        # explicitly below so that `recursive` can filter sub-directories.
        # A missing or unreadable `root` yields nothing and leaves `imgs` empty.
        top = next(os.walk(root), None)
        if top is not None:
            _, dirnames, dirfiles = top
            imgs = sorted(f for f in dirfiles if f.lower().endswith(exts))
            if check_imgs:
                imgs = [img for img in imgs if check_imgs(join(root, img))]
            if recursive:
                for dirname in sorted(dirnames):
                    subdir = join(root, dirname)
                    if callable(recursive) and not recursive(subdir):
                        continue
                    found = ImageSet._scan_folder(subdir, exts, recursive, listing, check_imgs)
                    imgs += [join(dirname, f) for f in found]

        # An empty result is not cached: an empty listing file would hide
        # images added to the folder later.
        if flist is not None and imgs:
            try:
                with open(flist, 'w') as f:
                    f.write('\n'.join(imgs))
            except IOError:
                pass  # write permission denied
        return imgs

    @staticmethod
    def _read_listing(list_path):
        """ Return the lines of the text file `list_path`, closing it afterwards.
        """
        with open(list_path) as f:
            return f.read().splitlines()

    @staticmethod
    def from_listing(root, list_path):
        """ Build an image set from a text file holding one relative path per line.

        Raises:
            IOError: if `list_path` cannot be opened.
            AssertionError: if the file lists no image.
        """
        return ImageSet(root, ImageSet._read_listing(list_path))

    def circular_pad(self, min_size):
        """ Repeat the image list circularly until it holds `min_size` entries.

        The list is then cut to exactly `min_size` entries, so a set that is
        already larger is truncated. A `min_size` of 0 leaves the size
        unchanged. The sequence originally passed to the constructor is
        never modified. Returns `self`.
        """
        assert self.imgs, 'cannot pad an empty image set'
        imgs = self.imgs
        if len(imgs) < min_size:
            # Ceil division: smallest repeat count that reaches min_size.
            # `*` builds a new sequence instead of extending the shared one.
            imgs = imgs * -(-min_size // len(imgs))
        self.imgs = imgs[:min_size or None]
        return self

    def __repr__(self):
        """ Short description: class name, image count, common prefix of the first and last paths.
        """
        prefix = os.path.commonprefix((self.get_image_path(0), self.get_image_path(len(self)-1)))
        return f'{self.__class__.__name__}({len(self)} images from {prefix}...)'
def verify_img(path, exts=None):
    """ Return True if `path` can be opened and converted to an RGB image.

    path: path of the file to check.
    exts: optional extension(s), a string or an iterable of strings. When
        given, a path whose lower-cased form does not end with one of them
        is rejected without opening the file.

    Any error raised while opening or converting the file yields False.
    KeyboardInterrupt and SystemExit are not swallowed.
    """
    if exts:
        # str.endswith() only accepts a str or a tuple of str.
        suffixes = exts if isinstance(exts, str) else tuple(exts)
        if not path.lower().endswith(suffixes):
            return False
    try:
        with Image.open(path) as img:
            img.convert('RGB')  # try to fully decode it
        return True
    except Exception:
        return False