|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import os |
|
from paddle.io import Dataset |
|
import lmdb |
|
import cv2 |
|
import string |
|
import six |
|
from PIL import Image |
|
|
|
from .imaug import transform, create_operators |
|
|
|
|
|
class LMDBDataSet(Dataset):
    """Dataset that serves OCR samples stored in one or more LMDB databases.

    Walks ``data_dir`` for leaf directories, opens each one as a read-only
    LMDB environment, and builds a flat ``(lmdb_idx, sample_idx)`` index so
    that samples from all databases are addressable by a single integer.
    """

    def __init__(self, config, mode, logger, seed=None):
        """
        Args:
            config (dict): full config; reads ``config['Global']`` and
                ``config[mode]['dataset']`` / ``config[mode]['loader']``.
            mode (str): config section name, e.g. ``'Train'`` or ``'Eval'``.
            logger: logger used for initialization messages.
            seed: unused here; kept for interface compatibility with other
                dataset classes.
        """
        super(LMDBDataSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        batch_size = loader_config['batch_size_per_card']
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexs of datasets:%s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'], global_config)
        # Ops up to (but excluding) this index are sufficient to decode an
        # image; get_ext_data() uses this prefix to prepare extra samples.
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
                                                       1)

        ratio_list = dataset_config.get("ratio_list", [1.0])
        # If any sub-dataset is subsampled (ratio < 1) the index must be
        # rebuilt between epochs.
        self.need_reset = True in [x < 1 for x in ratio_list]

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory under ``data_dir`` as an LMDB database.

        A directory with no subdirectories is assumed to be an LMDB
        environment containing a ``num-samples`` key.

        Returns:
            dict: ``{idx: {"dirpath", "env", "txn", "num_samples"}}``.
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
            if not dirnames:  # leaf directory -> treated as an LMDB db
                env = lmdb.open(
                    dirpath,
                    max_readers=32,
                    readonly=True,
                    lock=False,
                    readahead=False,
                    meminit=False)
                txn = env.begin(write=False)
                num_samples = int(txn.get('num-samples'.encode()))
                lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
                    "txn":txn, "num_samples":num_samples}
                dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the flat index over all LMDB sets.

        Returns:
            np.ndarray: shape ``(total_sample_num, 2)``; column 0 is the
            LMDB set index, column 1 the 1-based sample index inside that
            set (LMDB OCR datasets store keys like ``label-000000001``
            starting at 1, hence the ``+= 1``).
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = 0
        for lno in range(lmdb_num):
            total_sample_num += self.lmdb_sets[lno]['num_samples']
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] \
                = list(range(tmp_sample_num))
            data_idx_order_list[beg_idx:end_idx, 1] += 1
            beg_idx = beg_idx + tmp_sample_num
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode a raw image buffer into a BGR ndarray.

        Args:
            value (bytes): raw encoded image bytes read from LMDB.

        Returns:
            np.ndarray or None: decoded image, or None when the buffer is
            empty or cannot be decoded.
        """
        if not value:
            return None
        # np.frombuffer never returns None, so no further null check is
        # needed before decoding.
        imgdata = np.frombuffer(value, dtype='uint8')
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def get_ext_data(self):
        """Collect extra samples required by augmentation ops.

        An op that declares ``ext_data_num`` (e.g. copy-paste style
        augmentations) needs that many additional decoded samples. Only
        the first ``ext_op_transform_idx`` ops are applied to them.

        Returns:
            list[dict]: decoded extra samples (possibly empty).
        """
        ext_data_num = 0
        for op in self.ops:
            if hasattr(op, 'ext_data_num'):
                ext_data_num = getattr(op, 'ext_data_num')
                break
        load_data_ops = self.ops[:self.ext_op_transform_idx]
        ext_data = []

        # NOTE(review): loops until enough valid samples are found; if the
        # whole dataset were corrupt this would not terminate.
        while len(ext_data) < ext_data_num:
            lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
                len(self))]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            sample_info = self.get_lmdb_sample_info(
                self.lmdb_sets[lmdb_idx]['txn'], file_idx)
            if sample_info is None:
                continue
            img, label = sample_info
            data = {'image': img, 'label': label}
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_lmdb_sample_info(self, txn, index):
        """Read one (image-bytes, label) pair from an LMDB transaction.

        Args:
            txn: open LMDB read transaction.
            index (int): 1-based sample index.

        Returns:
            tuple or None: ``(imgbuf, label)``; None when the label key is
            missing. ``imgbuf`` may itself be None if the image key is
            absent — callers handle that via the transform pipeline.
        """
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, idx):
        """Return the transformed sample at position ``idx``.

        A corrupt or missing sample is replaced by a randomly chosen one.
        The retry is iterative: the previous recursive retry
        (``return self.__getitem__(random)``) could overflow the stack on
        a long run of bad samples.
        """
        while True:
            lmdb_idx, file_idx = self.data_idx_order_list[idx]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            sample_info = self.get_lmdb_sample_info(
                self.lmdb_sets[lmdb_idx]['txn'], file_idx)
            if sample_info is not None:
                img, label = sample_info
                data = {'image': img, 'label': label}
                data['ext_data'] = self.get_ext_data()
                outs = transform(data, self.ops)
                if outs is not None:
                    return outs
            # Sample unusable: retry with a fresh random index.
            idx = np.random.randint(self.__len__())

    def __len__(self):
        """Total number of samples across all LMDB sets."""
        return self.data_idx_order_list.shape[0]
|
|
|
|
|
class LMDBDataSetSR(LMDBDataSet):
    """LMDB dataset variant for text super-resolution.

    Each sample stores a high-resolution / low-resolution image pair
    (``image_hr-%09d`` / ``image_lr-%09d``) plus a text label.
    """

    def buf2PIL(self, txn, key, type='RGB'):
        """Read an encoded image from LMDB and decode it to a PIL image.

        Args:
            txn: open LMDB read transaction.
            key (bytes): LMDB key of the encoded image.
            type (str): PIL conversion mode (kept as-is for backward
                compatibility despite shadowing the builtin).

        Returns:
            PIL.Image.Image: decoded image converted to ``type``.
        """
        imgbuf = txn.get(key)
        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        im = Image.open(buf).convert(type)
        return im

    def str_filt(self, str_, voc_type):
        """Filter a label string down to the characters of a vocabulary.

        Args:
            str_ (str): raw label text.
            voc_type (str): one of ``'digit'``, ``'lower'``, ``'upper'``,
                ``'all'`` (note: ``'upper'`` keeps both letter cases —
                digits + ascii_letters).

        Returns:
            str: label with out-of-vocabulary characters removed
            (lower-cased first when ``voc_type == 'lower'``).
        """
        alpha_dict = {
            'digit': string.digits,
            'lower': string.digits + string.ascii_lowercase,
            'upper': string.digits + string.ascii_letters,
            'all': string.digits + string.ascii_letters + string.punctuation
        }
        if voc_type == 'lower':
            str_ = str_.lower()
        for char in str_:
            if char not in alpha_dict[voc_type]:
                str_ = str_.replace(char, '')
        return str_

    def get_lmdb_sample_info(self, txn, index):
        """Read one (HR image, LR image, filtered label) triple.

        BUG FIX: the original guard was ``except IOError or
        len(word) > self.max_len:`` — ``IOError or …`` short-circuits to
        the (truthy) ``IOError`` class, so only IOError was ever caught
        and the max-length check was dead code. The two conditions are now
        handled separately.
        """
        self.voc_type = 'upper'
        self.max_len = 100
        self.test = False
        label_key = b'label-%09d' % index
        word = str(txn.get(label_key).decode())
        img_HR_key = b'image_hr-%09d' % index
        img_lr_key = b'image_lr-%09d' % index
        try:
            img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
            img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
        except IOError:
            # Undecodable image pair: fall through to the next sample.
            # NOTE(review): index + 1 may run past num_samples at the end
            # of a set — presumably rare; verify against dataset layout.
            return self[index + 1]
        if len(word) > self.max_len:
            # Overlong label: skip to the next sample as well.
            return self[index + 1]
        label_str = self.str_filt(word, self.voc_type)
        return img_HR, img_lr, label_str

    def __getitem__(self, idx):
        """Return the transformed (HR, LR, label) sample at ``idx``.

        Falls back to a random sample when lookup or transform fails.
        """
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
                                                file_idx)
        if sample_info is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        img_HR, img_lr, label_str = sample_info
        data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs
|
|