|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import os |
|
import json |
|
import random |
|
import traceback |
|
from paddle.io import Dataset |
|
from .imaug import transform, create_operators |
|
|
|
|
|
class SimpleDataSet(Dataset):
    """Map-style dataset backed by one or more delimited label files.

    Every label line is split on ``delimiter``; the first field is an image
    path relative to ``data_dir`` and the second field is the label. The
    first field may also be a JSON list of paths, in which case one entry is
    chosen at random each time the line is used.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Read the label files and build the transform pipeline.

        Args:
            config: Nested mapping. ``config['Global']`` is handed to
                ``create_operators``; ``config[mode]['dataset']`` supplies
                ``label_file_list``, ``data_dir``, ``transforms`` and the
                optional ``delimiter`` (default ``'\\t'``), ``ratio_list``
                (default ``1.0``) and ``ext_op_transform_idx`` (default
                ``2``); ``config[mode]['loader']`` supplies ``shuffle``.
            mode: Key used to index ``config``; its lower-cased form is
                stored and compared against ``"train"``.
            logger: Object providing ``info`` and ``error`` methods.
            seed: Value passed to ``random.seed`` before sampling/shuffling.

        Raises:
            AssertionError: If ``ratio_list`` is a sequence whose length
                differs from the number of label files.
        """
        super(SimpleDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.delimiter = dataset_config.get('delimiter', '\t')
        # NOTE(review): pop() removes the key from the caller's config, so
        # the same config object cannot build a second dataset. Kept as-is
        # because other code may rely on the key being gone.
        label_file_list = dataset_config.pop('label_file_list')
        if isinstance(label_file_list, str):
            # A single path given as a plain string must count as ONE data
            # source; len() of the raw string would count its characters.
            label_file_list = [label_file_list]
        data_source_num = len(label_file_list)

        ratio_list = dataset_config.get("ratio_list", 1.0)
        if isinstance(ratio_list, (float, int)):
            # A scalar ratio applies to every label file.
            ratio_list = [float(ratio_list)] * int(data_source_num)

        assert len(
            ratio_list
        ) == data_source_num, "The length of ratio_list should be the same as the file_list."

        self.data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']
        self.seed = seed
        logger.info("Initialize indexs of datasets:%s" % label_file_list)

        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if self.mode == "train" and self.do_shuffle:
            self.shuffle_data_random()

        self.ops = create_operators(dataset_config['transforms'], global_config)
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
                                                       2)
        # True when at least one source is sub-sampled (ratio below 1).
        self.need_reset = any(x < 1 for x in ratio_list)

    def get_image_info_list(self, file_list, ratio_list):
        """Load the raw label lines from every label file.

        Args:
            file_list: A label file path or a list of label file paths.
            ratio_list: Per-file sampling ratios, indexed like ``file_list``.

        Returns:
            A list of raw ``bytes`` lines (files are read in binary mode).
            In train mode, or when a file's ratio is below 1.0, that file
            contributes ``round(len(lines) * ratio)`` randomly sampled lines.

        Raises:
            ValueError: From ``random.sample`` when sampling is active and
                the requested count exceeds the number of lines (ratio > 1).
        """
        if isinstance(file_list, str):
            file_list = [file_list]

        data_lines = []
        for idx, label_path in enumerate(file_list):
            with open(label_path, "rb") as f:
                lines = f.readlines()
            ratio = ratio_list[idx]
            if self.mode == "train" or ratio < 1.0:
                # NOTE: this re-seeds the module-level ``random`` generator,
                # so it also affects later ``random`` calls made elsewhere.
                random.seed(self.seed)
                lines = random.sample(lines, round(len(lines) * ratio))
            data_lines.extend(lines)
        return data_lines

    def shuffle_data_random(self):
        """Seed the global ``random`` generator and shuffle the lines in place."""
        random.seed(self.seed)
        random.shuffle(self.data_lines)

    def _try_parse_filename_list(self, file_name):
        """Resolve a JSON-list file name field to one randomly chosen entry.

        A field that does not start with ``[`` is returned unchanged. A field
        that starts with ``[`` is parsed as JSON and a random element is
        returned; if parsing fails or the list is empty, the original string
        is returned (best effort).
        """
        if len(file_name) > 0 and file_name[0] == "[":
            try:
                info = json.loads(file_name)
                file_name = random.choice(info)
            except (ValueError, IndexError):
                # ValueError: not valid JSON (json.JSONDecodeError is a
                # subclass). IndexError: random.choice on an empty list.
                # Either way, fall back to treating the field as a path.
                pass
        return file_name

    def get_ext_data(self):
        """Collect extra, partially transformed samples.

        The number of samples is the ``ext_data_num`` attribute of the first
        op in ``self.ops`` that has one (0 if none does). Each sample is drawn
        at a random index and run through the first ``ext_op_transform_idx``
        ops only. Candidates are skipped when the image file does not exist,
        when ``transform`` returns ``None``, or when the result has a
        ``'polys'`` entry whose ``shape[1]`` is not 4.

        Returns:
            A list of ``ext_data_num`` transformed data dicts.
        """
        ext_data_num = 0
        for op in self.ops:
            if hasattr(op, 'ext_data_num'):
                ext_data_num = getattr(op, 'ext_data_num')
                break

        load_data_ops = self.ops[:self.ext_op_transform_idx]
        ext_data = []

        # NOTE(review): this loop retries until enough samples are accepted;
        # if no line in the dataset can ever be accepted it does not
        # terminate. Left unchanged because returning fewer samples could
        # break the op that consumes them - confirm desired behaviour.
        while len(ext_data) < ext_data_num:
            file_idx = self.data_idx_order_list[np.random.randint(len(self))]
            data_line = self.data_lines[file_idx]
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split(self.delimiter)
            file_name = self._try_parse_filename_list(substr[0])
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {'img_path': img_path, 'label': label}
            if not os.path.exists(img_path):
                continue
            with open(data['img_path'], 'rb') as f:
                data['image'] = f.read()
            data = transform(data, load_data_ops)

            if data is None:
                continue
            if 'polys' in data and data['polys'].shape[1] != 4:
                continue
            ext_data.append(data)
        return ext_data

    def __getitem__(self, idx):
        """Load, decode and transform the sample at position ``idx``.

        If anything goes wrong while parsing or transforming the line (the
        error is logged with its traceback), or if the transforms return
        ``None``, another sample is returned instead: a random one in train
        mode, otherwise the next index (wrapping around).

        NOTE(review): the fallback is recursive and unbounded, so a dataset
        in which no sample can be loaded ends in ``RecursionError``.
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode('utf-8')
            substr = data_line.strip("\n").split(self.delimiter)
            file_name = self._try_parse_filename_list(substr[0])
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {'img_path': img_path, 'label': label}
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(data['img_path'], 'rb') as f:
                data['image'] = f.read()
            data['ext_data'] = self.get_ext_data()
            outs = transform(data, self.ops)
        except Exception:
            # Deliberately broad so one bad sample never aborts loading, but
            # not a bare ``except``: KeyboardInterrupt / SystemExit must
            # still propagate so the process can be stopped.
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, traceback.format_exc()))
            outs = None

        if outs is None:
            if self.mode == "train":
                rnd_idx = np.random.randint(len(self))
            else:
                rnd_idx = (idx + 1) % len(self)
            return self.__getitem__(rnd_idx)
        return outs

    def __len__(self):
        """Return the number of label lines currently held."""
        return len(self.data_idx_order_list)
|
|