deprem-ocr / ocr /ppocr /data /pubtab_dataset.py
Goodsea's picture
paddleocr
fc8c192
raw history blame
No virus
3.77 kB
import json
import os
import random
import numpy as np
from paddle.io import Dataset
from .imaug import create_operators, transform
class PubTabDataSet(Dataset):
    """Table-recognition dataset backed by a JSON-lines label file.

    Each line of the label file is parsed with ``json.loads``. The keys read
    from it are ``filename``, ``html["cells"]`` and ``html["structure"]``
    (plus ``html["structure"]["tokens"]`` when complex-table selection is on).
    The image at ``os.path.join(data_dir, filename)`` is read as raw bytes and
    the resulting dict is passed through the configured transform operators.

    Optional sampling controls read from ``config[mode]["loader"]``:

    * ``hard_select`` / ``hard_prob``: keep a sample only when its per-image
      probability from :meth:`load_hard_select_prob` is not below a uniform
      random draw.
    * ``table_select_type`` / ``table_select_prob``: keep a "complex" table
      (structure tokens containing ``colspan`` or ``rowspan``) only when
      ``table_select_prob`` is not below a uniform random draw.

    A rejected sample, or one whose loading raises, is replaced by the sample
    at a random index.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Read the label file and build the transform pipeline.

        Args:
            config: Full config mapping. ``config["Global"]``,
                ``config[mode]["dataset"]`` and ``config[mode]["loader"]``
                are read.
            mode: Section name in ``config``. Lines are shuffled only when
                ``mode.lower() == "train"`` and the loader's ``shuffle`` is set.
            logger: Logger used for initialization info and per-sample errors.
            seed: Value passed to ``random.seed`` before shuffling.
        """
        super(PubTabDataSet, self).__init__()
        self.logger = logger

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]
        loader_config = config[mode]["loader"]

        # NOTE: pop() removes the key from the caller's config dict.
        label_file_path = dataset_config.pop("label_file_path")

        self.data_dir = dataset_config["data_dir"]
        self.do_shuffle = loader_config["shuffle"]

        self.do_hard_select = False
        if "hard_select" in loader_config:
            self.do_hard_select = loader_config["hard_select"]
            self.hard_prob = loader_config["hard_prob"]
        if self.do_hard_select:
            self.img_select_prob = self.load_hard_select_prob()

        self.table_select_type = None
        if "table_select_type" in loader_config:
            self.table_select_type = loader_config["table_select_type"]
            self.table_select_prob = loader_config["table_select_prob"]

        self.seed = seed
        logger.info("Initialize indexs of datasets:%s" % label_file_path)
        # Lines are kept as bytes and decoded lazily in __getitem__.
        with open(label_file_path, "rb") as f:
            self.data_lines = f.readlines()
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if mode.lower() == "train":
            self.shuffle_data_random()
        self.ops = create_operators(dataset_config["transforms"], global_config)

        ratio_list = dataset_config.get("ratio_list", [1.0])
        # Accept a bare number as shorthand for a single-entry list.
        if isinstance(ratio_list, (float, int)):
            ratio_list = [ratio_list]
        self.need_reset = any(x < 1 for x in ratio_list)

    def load_hard_select_prob(self):
        """Return a mapping from image file name to its keep probability.

        ``__getitem__`` looks up ``self.img_select_prob[file_name]`` and drops
        the sample when that value is below a uniform random draw.

        This class provides no implementation. Enabling ``hard_select``
        requires overriding this hook.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            "hard_select is enabled but PubTabDataSet.load_hard_select_prob "
            "is not implemented; override it to return a {filename: prob} mapping"
        )

    def shuffle_data_random(self):
        """Shuffle ``self.data_lines`` in place when shuffling is enabled.

        Reseeds the global ``random`` module with ``self.seed`` first.
        """
        if self.do_shuffle:
            random.seed(self.seed)
            random.shuffle(self.data_lines)

    def __getitem__(self, idx):
        """Return the transformed sample for line ``idx``.

        If the sample is rejected by hard or complex-table selection, or any
        step raises (decoding, JSON parsing, missing image, transforms, ...),
        the error is logged where applicable. A random index drawn with
        ``np.random.randint`` is then tried instead, recursively.
        """
        # Bind before the try so the error log below can always format it,
        # even when indexing self.data_lines itself fails.
        data_line = None
        try:
            data_line = self.data_lines[idx]
            data_line = data_line.decode("utf-8").strip("\n")
            info = json.loads(data_line)
            file_name = info["filename"]

            # Both selection checks run regardless of each other so the
            # sequence of random draws matches the configured options exactly.
            select_flag = True
            if self.do_hard_select:
                prob = self.img_select_prob[file_name]
                if prob < random.uniform(0, 1):
                    select_flag = False

            if self.table_select_type:
                structure_str = "".join(info["html"]["structure"]["tokens"])
                is_complex = "colspan" in structure_str or "rowspan" in structure_str
                if is_complex and self.table_select_prob < random.uniform(0, 1):
                    select_flag = False

            if select_flag:
                cells = info["html"]["cells"].copy()
                structure = info["html"]["structure"].copy()
                img_path = os.path.join(self.data_dir, file_name)
                data = {"img_path": img_path, "cells": cells, "structure": structure}
                if not os.path.exists(img_path):
                    raise Exception("{} does not exist!".format(img_path))
                with open(img_path, "rb") as f:
                    data["image"] = f.read()
                outs = transform(data, self.ops)
            else:
                outs = None
        except Exception as e:
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(data_line, e)
            )
            outs = None
        if outs is None:
            # Rejected or broken sample: substitute a random one.
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        """Return the number of lines read from the label file."""
        return len(self.data_idx_order_list)