deprem-ocr / ocr /ppocr /data /simple_dataset.py
Goodsea's picture
paddleocr
fc8c192
raw history blame
No virus
5.45 kB
import json
import os
import random
import traceback
import numpy as np
from paddle.io import Dataset
from .imaug import create_operators, transform
class SimpleDataSet(Dataset):
    """Dataset backed by one or more text annotation files.

    Every annotation line is expected to look like
    ``<file name><delimiter><label>``. The file name is joined onto
    ``data_dir`` to locate the image, whose raw bytes are handed to the
    configured transform pipeline together with the label.

    The file-name field may also be a JSON list of file names (several
    images sharing one label); one entry is then picked at random each time
    the line is read.
    """

    # Upper bound on how many samples ``__getitem__`` tries before giving up.
    # The previous recursive retry was implicitly capped by the interpreter's
    # recursion limit; this makes the cap explicit and the failure readable.
    _MAX_SAMPLE_RETRIES = 1000

    def __init__(self, config, mode, logger, seed=None):
        """Read the annotation files and build the transform pipeline.

        Args:
            config: Nested mapping. Reads ``config["Global"]``,
                ``config[mode]["dataset"]`` (keys ``label_file_list``,
                ``data_dir``, ``transforms`` and optional ``delimiter``,
                ``ratio_list``, ``ext_op_transform_idx``) and
                ``config[mode]["loader"]["shuffle"]``.
            mode: Section name used to index ``config`` exactly as given;
                its lower-cased form is stored in ``self.mode``.
            logger: Object providing ``info`` and ``error`` methods.
            seed: Value passed to ``random.seed`` before sampling and
                shuffling the annotation lines.
        """
        super(SimpleDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]
        loader_config = config[mode]["loader"]

        self.delimiter = dataset_config.get("delimiter", "\t")

        # Plain read instead of ``pop`` so the caller's config is left intact
        # and can be reused to build another dataset.
        label_file_list = dataset_config["label_file_list"]
        # Normalise a single path to a one-element list *before* counting the
        # sources; ``len()`` of a string would be its character count.
        if isinstance(label_file_list, str):
            label_file_list = [label_file_list]
        data_source_num = len(label_file_list)

        ratio_list = dataset_config.get("ratio_list", 1.0)
        if isinstance(ratio_list, (float, int)):
            # A scalar ratio applies to every annotation file.
            ratio_list = [float(ratio_list)] * data_source_num
        assert (
            len(ratio_list) == data_source_num
        ), "The length of ratio_list should be the same as the file_list."

        self.data_dir = dataset_config["data_dir"]
        self.do_shuffle = loader_config["shuffle"]
        self.seed = seed

        logger.info("Initialize indexs of datasets:%s" % label_file_list)
        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if self.mode == "train" and self.do_shuffle:
            self.shuffle_data_random()

        self.ops = create_operators(dataset_config["transforms"], global_config)
        # Number of leading operators applied to the extra samples gathered
        # by ``get_ext_data``.
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
        # True when at least one annotation file is only partially sampled.
        self.need_reset = any(ratio < 1 for ratio in ratio_list)

    def get_image_info_list(self, file_list, ratio_list):
        """Return the raw (bytes) annotation lines of all label files.

        Args:
            file_list: A path or a list of paths to annotation files.
            ratio_list: Per-file sampling ratio, indexed like ``file_list``.

        Returns:
            A list of ``bytes`` lines. For each file, when the mode is
            ``"train"`` or the file's ratio is below 1.0, the lines are a
            seeded random sample of size ``round(len(lines) * ratio)``;
            otherwise they are kept in file order.
        """
        if isinstance(file_list, str):
            file_list = [file_list]

        data_lines = []
        for idx, label_path in enumerate(file_list):
            with open(label_path, "rb") as f:
                lines = f.readlines()
            ratio = ratio_list[idx]
            if self.mode == "train" or ratio < 1.0:
                # Re-seed the global RNG so the selection is reproducible.
                random.seed(self.seed)
                lines = random.sample(lines, round(len(lines) * ratio))
            data_lines.extend(lines)
        return data_lines

    def shuffle_data_random(self):
        """Shuffle ``self.data_lines`` in place after re-seeding ``random``."""
        random.seed(self.seed)
        random.shuffle(self.data_lines)
        return

    def _try_parse_filename_list(self, file_name):
        """Resolve a file-name field that may be a JSON list of names.

        If ``file_name`` starts with ``"["`` it is parsed as JSON and one
        entry is chosen at random. If it is not valid JSON or the list is
        empty, the field is returned unchanged (best effort).
        """
        # multiple images -> one gt label
        if len(file_name) > 0 and file_name[0] == "[":
            try:
                info = json.loads(file_name)
                file_name = random.choice(info)
            except (ValueError, IndexError):
                # Invalid JSON or an empty list: treat the field as a
                # literal file name.
                pass
        return file_name

    def _build_sample(self, text_line):
        """Split a decoded annotation line into an image path and a label.

        Returns a dict with keys ``"img_path"`` and ``"label"``. Raises
        ``IndexError`` when the line has no delimiter-separated label field.
        """
        # Strip "\r" as well so CRLF annotation files do not leave a stray
        # carriage return at the end of the label.
        fields = text_line.strip("\r\n").split(self.delimiter)
        file_name = self._try_parse_filename_list(fields[0])
        label = fields[1]
        img_path = os.path.join(self.data_dir, file_name)
        return {"img_path": img_path, "label": label}

    def get_ext_data(self):
        """Collect the extra samples requested by the transform pipeline.

        The number of samples is the ``ext_data_num`` attribute of the first
        operator that has one (0 if none does). Randomly chosen annotation
        lines are loaded and passed through the first
        ``ext_op_transform_idx`` operators. A candidate is skipped when its
        image file is missing, when the transform returns ``None``, or when
        its ``"polys"`` entry has a second dimension other than 4.

        Returns:
            A list of transformed sample dicts.
        """
        ext_data_num = next(
            (op.ext_data_num for op in self.ops if hasattr(op, "ext_data_num")),
            0,
        )
        load_data_ops = self.ops[: self.ext_op_transform_idx]

        ext_data = []
        # NOTE(review): this loop has no upper bound; if no line can ever
        # yield usable data while ext_data_num > 0 it will not terminate.
        while len(ext_data) < ext_data_num:
            file_idx = self.data_idx_order_list[np.random.randint(self.__len__())]
            data_line = self.data_lines[file_idx].decode("utf-8")
            data = self._build_sample(data_line)
            if not os.path.exists(data["img_path"]):
                continue
            with open(data["img_path"], "rb") as f:
                data["image"] = f.read()
            data = transform(data, load_data_ops)
            if data is None:
                continue
            if "polys" in data and data["polys"].shape[1] != 4:
                continue
            ext_data.append(data)
        return ext_data

    def _load_sample(self, idx):
        """Load and transform the sample at ``idx``.

        Returns the transform output, or ``None`` when loading or
        transforming failed (the failure is logged). An out-of-range ``idx``
        still raises ``IndexError`` because the lookups happen outside the
        ``try`` block.
        """
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode("utf-8")
            data = self._build_sample(data_line)
            img_path = data["img_path"]
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(img_path, "rb") as f:
                data["image"] = f.read()
            data["ext_data"] = self.get_ext_data()
            return transform(data, self.ops)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, traceback.format_exc()
                )
            )
            return None

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``, substituting on failure.

        When the requested sample cannot be loaded or its transform returns
        ``None``, another index is tried: a random one in ``"train"`` mode,
        the next one (wrapping around) otherwise.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If no valid sample is found within
                ``_MAX_SAMPLE_RETRIES`` attempts.
        """
        current_idx = idx
        for _ in range(self._MAX_SAMPLE_RETRIES):
            outs = self._load_sample(current_idx)
            if outs is not None:
                return outs
            if self.mode == "train":
                current_idx = np.random.randint(self.__len__())
            else:
                # Outside training the replacement index is deterministic so
                # that repeated evaluation runs return the same samples.
                current_idx = (current_idx + 1) % self.__len__()
        raise RuntimeError(
            "No valid sample found after {} attempts starting from index {}".format(
                self._MAX_SAMPLE_RETRIES, idx
            )
        )

    def __len__(self):
        """Return the number of annotation lines held by the dataset."""
        return len(self.data_idx_order_list)