Spaces:
Runtime error
Runtime error
import torch.utils.data as data | |
from PIL import Image | |
import os | |
import os.path | |
from io import BytesIO | |
import lmdb | |
from torch.utils.data import Dataset | |
class MultiResolutionDataset(Dataset):
    """Dataset of encoded images stored in an LMDB database.

    The database must contain a ``b'length'`` entry holding the number of
    images as a UTF-8 decimal string, and one entry per image keyed
    ``f'{resolution}-{str(index).zfill(5)}'`` whose value is the encoded
    image bytes (decoded with ``PIL.Image.open``).

    Args:
        path: location of the LMDB environment; opened read-only.
        transform: callable applied to every decoded PIL image.
        resolution: prefix used when building the per-image keys.

    Raises:
        IOError: if the LMDB environment object is falsy after opening.
        KeyError: if the database has no ``length`` entry.
    """

    def __init__(self, path, transform, resolution=256):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        # NOTE(review): lmdb.open most likely raises on failure by itself,
        # which would make this check purely defensive — confirm.
        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        if raw_length is None:
            # Without this check a missing entry surfaces as an opaque
            # "'NoneType' object has no attribute 'decode'".
            raise KeyError(f"LMDB dataset at {path!r} has no 'length' entry")
        self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return the transformed image stored for ``index``.

        Raises:
            IndexError: if the database holds no entry for ``index``.
                Raising IndexError (instead of failing inside PIL on empty
                bytes) lets plain ``for img in dataset`` iteration stop
                cleanly at the end.
        """
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)
        if img_bytes is None:
            raise IndexError(f'index {index} out of range (no LMDB key {key!r})')

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img
def has_file_allowed_extension(filename, extensions):
    """Tell whether ``filename`` ends with one of ``extensions``.

    Only the filename is lower-cased before comparing; the extensions are
    used as given, so they should be lower-case (e.g. ``'.jpg'``).

    Args:
        filename (string): name or path of the file to test
        extensions (iterable of string): accepted suffixes, e.g. ``['.png']``

    Returns:
        bool: True if the lower-cased filename ends with any of the suffixes
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False
def find_classes(dir):
    """List the immediate subdirectories of ``dir`` and assign each an index.

    Args:
        dir (string): directory whose direct subdirectories are the classes.
            (The name shadows the builtin but is kept for keyword callers.)

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of subdirectory names and ``class_to_idx`` maps each name to its
        position in that list.
    """
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx
def make_dataset(dir, extensions):
    """Collect every file below ``dir`` whose name has an allowed extension.

    ``os.walk`` results are sorted by directory path and file names are
    sorted within each directory, so the output order is deterministic.

    Args:
        dir (string): root directory, walked recursively.
        extensions (iterable of string): accepted file suffixes.

    Returns:
        list: ``(path, 0)`` tuples; the label is always ``0``.
    """
    return [
        (os.path.join(root, fname), 0)
        for root, _, fnames in sorted(os.walk(dir))
        for fname in sorted(fnames)
        if has_file_allowed_extension(fname, extensions)
    ]
class DatasetFolder(data.Dataset):
    """Flat dataset of every file under ``root`` that has an allowed extension.

    Subdirectories are not mapped to classes: every sample is stored with
    target ``0`` and ``__getitem__`` returns only the (transformed) sample.

    Args:
        root: directory searched recursively for samples (str or os.PathLike).
        loader: callable turning a file path into a sample.
        extensions: accepted file suffixes (compared against the lower-cased
            filename).
        transform: optional callable applied to each loaded sample.
        target_transform: optional callable applied to the target; its
            result is not returned.

    Raises:
        RuntimeError: if no matching file is found under ``root``.
    """

    def __init__(self, root, loader, extensions, transform=None, target_transform=None):
        samples = make_dataset(root, extensions)
        if not samples:
            # Formatting via an f-string keeps the message working for
            # os.PathLike roots; plain ``+ root`` raised TypeError for them.
            raise RuntimeError(
                f"Found 0 files in subfolders of: {root}\n"
                "Supported extensions are: " + ",".join(extensions)
            )

        self.root = root
        self.loader = loader
        self.extensions = extensions
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            The sample loaded by ``loader``, passed through ``transform``
            when one is set. The target is not returned.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            # NOTE(review): the transformed target is discarded; the call is
            # kept only so any side effects of target_transform still happen.
            target = self.target_transform(target)
        return sample

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, repr(self.transform).replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, repr(self.target_transform).replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
# Lower-case file suffixes that ImageFolder accepts when scanning its root.
IMG_EXTENSIONS = [
    '.jpg',
    '.jpeg',
    '.png',
    '.ppm',
    '.bmp',
    '.pgm',
    '.tif',
]
def pil_loader(path):
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly and closed on exit to avoid a
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835);
    the RGB conversion happens while the file is still open.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def default_loader(path):
    """Default ImageFolder loader: delegate to ``pil_loader``."""
    image = pil_loader(path)
    return image
class ImageFolder(DatasetFolder):
    """DatasetFolder over image files with a second, switchable transform.

    Args:
        root: directory searched recursively for files in IMG_EXTENSIONS.
        transform1: transform used from construction onwards.
        transform2: transform installed by ``set_stage('last')``.
        target_transform: forwarded unchanged to DatasetFolder.
        loader: callable mapping a path to an image (``default_loader``).
    """

    def __init__(self, root, transform1=None, transform2=None, target_transform=None,
                 loader=default_loader):
        super().__init__(
            root,
            loader,
            IMG_EXTENSIONS,
            transform=transform1,
            target_transform=target_transform,
        )
        # ``imgs`` is the very same list object as ``samples``.
        self.imgs = self.samples
        self.transform2 = transform2

    def set_stage(self, stage):
        """Replace the active transform with ``transform2`` for stage 'last'.

        Any other ``stage`` value leaves the dataset unchanged.
        """
        if stage != 'last':
            return
        self.transform = self.transform2
class ListFolder(Dataset):
    """Dataset of images whose paths are listed one per line in a text file.

    Args:
        txt: path to the list file. Each line is stripped of surrounding
            whitespace, and blank lines are skipped.
        transform: callable applied to each image opened with
            ``PIL.Image.open``.
    """

    def __init__(self, txt, transform):
        with open(txt) as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines (e.g. a trailing empty line): an empty path
            # would otherwise count towards len() and fail when loaded.
            self.imgpaths = [path for path in stripped if path]
        self.transform = transform

    def __getitem__(self, idx):
        path = self.imgpaths[idx]
        image = Image.open(path)
        return self.transform(image)

    def __len__(self):
        return len(self.imgpaths)