TextureScraping / swapae /data /lmdb_dataset.py
sunshineatnoon
Add application file
1b2a9b1
raw
history blame
2.35 kB
import random
import sys
import os.path
from PIL import Image
from swapae.data.base_dataset import BaseDataset, get_transform
import cv2
import numpy as np
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import torchvision.transforms as transforms
class LMDBDataset(BaseDataset):
    """Dataset serving images stored as encoded byte blobs in an LMDB database.

    Every LMDB key is one sample. Values are decoded with ``cv2.imdecode``,
    converted to a PIL image and passed through the transform built by
    ``get_transform``. Items are returned as ``{"real_A": ..., "path_A": ...}``.
    """

    def __init__(self, opt):
        """Open the LMDB database at ``opt.dataroot`` and build the key list.

        The key list is cached in a pickle file named ``_cache_`` inside the
        database directory so later runs can skip the full cursor scan. Keys
        are shuffled with a fixed seed (0), so the order is reproducible.

        Args:
            opt: options object; ``opt.dataroot`` is read here and ``opt`` is
                forwarded to ``get_transform``.
        """
        import lmdb  # imported lazily so this module loads without lmdb installed

        self.opt = opt
        write_cache = True
        # Expand "~" once and use the expanded path for every filesystem access
        # (LMDB environment and cache file alike).
        self.root = os.path.expanduser(opt.dataroot)
        self.env = lmdb.open(self.root, readonly=True, lock=False)
        with self.env.begin(write=False) as txn:
            # NOTE(review): __len__ reports the LMDB entry count while
            # __getitem__ indexes self.keys; these agree only if the cache file
            # matches the database contents.
            self.length = txn.stat()['entries']
        print('lmdb file at %s opened.' % self.root)

        cache_file = os.path.join(self.root, '_cache_')
        if os.path.isfile(cache_file):
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # dataroot at trusted directories.
            with open(cache_file, "rb") as f:
                self.keys = pickle.load(f)
        elif write_cache:
            print('generating keys')
            with self.env.begin(write=False) as txn:
                # Iterate keys only; reading every value (image blob) just to
                # collect the keys would be wasted I/O.
                self.keys = list(txn.cursor().iternext(keys=True, values=False))
            try:
                with open(cache_file, "wb") as f:
                    pickle.dump(self.keys, f)
            except OSError as e:
                # The dataset directory may be read-only; the cache is only an
                # optimization, so continue with the in-memory key list.
                print('could not write cache file at %s: %s' % (cache_file, e))
            else:
                print('cache file generated at %s' % cache_file)
        else:
            self.keys = []
        # Fixed seed: the shuffled order is identical on every run.
        random.Random(0).shuffle(self.keys)
        self.transform = get_transform(self.opt, grayscale=False)
        if self._is_lsun():
            print("Seems like a LSUN dataset, so we will apply BGR->RGB conversion")

    def _is_lsun(self):
        """Return True when the dataroot name suggests an LSUN database."""
        return "lsun" in self.opt.dataroot.lower()

    def __getitem__(self, index):
        """Return the sample stored under the ``index``-th (shuffled) key."""
        return self.getitem_by_path(self.keys[index])

    def getitem_by_path(self, path):
        """Load, decode and transform the image stored under LMDB key ``path``.

        Args:
            path: LMDB key as bytes (as collected in ``__init__``).

        Returns:
            dict with ``"real_A"`` (the transformed image) and ``"path_A"``
            (the key decoded as UTF-8). If the key is missing or its value
            cannot be decoded, a randomly chosen other sample is returned.
        """
        with self.env.begin(write=False) as txn:
            imgbuf = txn.get(path)

        img = None
        if imgbuf is None:
            print(path, 'not found in lmdb')
        else:
            try:
                # np.frombuffer replaces the deprecated binary mode of np.fromstring.
                img = cv2.imdecode(np.frombuffer(imgbuf, dtype=np.uint8), cv2.IMREAD_COLOR)
            except cv2.error as e:
                print(path, e)
            else:
                if img is None:
                    # imdecode reports many malformed buffers by returning
                    # None instead of raising.
                    print(path, 'could not be decoded')
        if img is None:
            return self.__getitem__(random.randint(0, self.length - 1))

        # cv2.imdecode returns color images in BGR channel order.
        # NOTE(review): only LSUN data is converted to RGB; other datasets keep
        # BGR order — confirm this is intended.
        if self._is_lsun():
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        return {"real_A": self.transform(img), "path_A": path.decode("utf-8")}

    def set_phase(self, phase):
        """Forward the phase change to ``BaseDataset.set_phase``."""
        super().set_phase(phase)

    def __len__(self):
        """Return the LMDB entry count recorded at construction time."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.root + ')'