Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import warnings | |
from typing import Dict, List | |
import cv2 | |
import lmdb | |
import mmengine | |
import numpy as np | |
from mmocr.registry import DATA_DUMPERS | |
from .base import BaseDumper | |
# NOTE(review): ``DATA_DUMPERS`` is imported at file level but this class is
# not registered with it here -- confirm whether a
# ``@DATA_DUMPERS.register_module()`` decorator was lost.
class TextRecogLMDBDumper(BaseDumper):
    """Text recognition LMDB format dataset dumper.

    Every valid sample is stored as two records, ``image-%09d`` (raw image
    file bytes) and ``label-%09d`` (encoded text), indexed from 1. A final
    ``num-samples`` record holds the number of samples actually written.

    Args:
        task (str): Task type. Options are 'textdet', 'textrecog',
            'textspotter', and 'kie'. Only 'textrecog' is accepted by this
            dumper. It is usually set automatically and users do not need
            to set it manually in config file in most cases.
        split (str): It's the partition of the datasets. Options are
            'train', 'val' or 'test'. It is usually set automatically and
            users do not need to set it manually in config file in most
            cases.
        data_root (str): The root directory of the image and annotation.
            It is usually set automatically and users do not need to set
            it manually in config file in most cases.
        batch_size (int): Number of files written to the cache each time.
            Defaults to 1000.
        encoding (str): Label encoding method. Defaults to 'utf-8'.
        lmdb_map_size (int): Maximum size database may grow to. Defaults to
            1099511627776.
        verify (bool): Whether to check the validity of every image.
            Defaults to True.
    """

    def __init__(self,
                 task: str,
                 split: str,
                 data_root: str,
                 batch_size: int = 1000,
                 encoding: str = 'utf-8',
                 lmdb_map_size: int = 1099511627776,
                 verify: bool = True) -> None:
        assert task == 'textrecog', \
            f'TextRecogLMDBDumper only works with textrecog, but got {task}'
        super().__init__(task=task, split=split, data_root=data_root)
        self.batch_size = batch_size
        self.encoding = encoding
        self.lmdb_map_size = lmdb_map_size
        self.verify = verify

    def check_image_is_valid(self, imageBin) -> bool:
        """Check whether raw image bytes decode to a non-empty image.

        Args:
            imageBin (bytes, optional): Raw content of an image file.

        Returns:
            bool: False if the buffer is None or empty, cannot be decoded,
            or decodes to an image with zero area; True otherwise.
        """
        if imageBin is None or len(imageBin) == 0:
            return False
        imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        # imdecode signals undecodable data by returning None rather than
        # raising, so it must be checked before touching ``img.shape``.
        if img is None:
            return False
        imgH, imgW = img.shape[0], img.shape[1]
        return imgH * imgW != 0

    def write_cache(self, env, cache) -> None:
        """Write a batch of records to the LMDB environment.

        Args:
            env: An opened LMDB environment.
            cache (list[tuple[bytes, bytes]]): ``(key, value)`` pairs to
                store. Existing keys are overwritten.
        """
        with env.begin(write=True) as txn:
            cursor = txn.cursor()
            cursor.putmulti(cache, dupdata=False, overwrite=True)

    def parser_pack_instance(self, instance: Dict):
        """Parse a packed MMOCR format textrecog instance.

        Args:
            instance (Dict): A packed MMOCR format textrecog instance.
                For example,

                {
                    "instances": [
                        {
                            "text": "Hello"
                        }
                    ],
                    "img_path": "img1.jpg"
                }

        Returns:
            tuple(str, str): The image path and the text of the instance.
        """
        assert isinstance(instance,
                          dict), 'Element of data_list must be a dict'
        assert 'img_path' in instance and 'instances' in instance, \
            'Element of data_list must have the following keys: ' \
            f'img_path and instances, but got {instance.keys()}'
        assert isinstance(instance['instances'], list) and len(
            instance['instances']) == 1, \
            'instances must be a list with exactly one element'
        assert 'text' in instance['instances'][0], \
            'The element of instances must have a text key'

        img_path = instance['img_path']
        text = instance['instances'][0]['text']
        return img_path, text

    def dump(self, data: Dict) -> None:
        """Dump data to LMDB format.

        The database is created at
        ``{data_root}/{task}_{split}.lmdb``. Samples whose image file is
        missing (or, when ``verify`` is set, invalid) are skipped with a
        warning and do not consume an index.

        Args:
            data (Dict): Dataset dict; must contain a ``data_list`` key
                holding packed textrecog instances.

        Raises:
            ValueError: If ``data`` has no ``data_list`` key.
        """
        # Validate the input before touching the filesystem so that a bad
        # call does not leave an empty database behind.
        if 'data_list' not in data:
            raise ValueError('Dump data must have data_list key')
        data_list = data['data_list']

        # create lmdb env
        output_dirname = f'{self.task}_{self.split}.lmdb'
        output = osp.join(self.data_root, output_dirname)
        mmengine.mkdir_or_exist(output)
        env = lmdb.open(output, map_size=self.lmdb_map_size)
        try:
            cache = []
            # index start from 1
            cnt = 1
            n_samples = len(data_list)
            for d in data_list:
                # convert both images and labels to lmdb
                img_name, text = self.parser_pack_instance(d)
                img_path = osp.join(self.data_root, img_name)
                if not osp.exists(img_path):
                    warnings.warn('%s does not exist' % img_path)
                    continue
                with open(img_path, 'rb') as f:
                    image_bin = f.read()
                if self.verify and not self.check_image_is_valid(image_bin):
                    warnings.warn('%s is not a valid image' % img_path)
                    continue
                # Format first, then encode: %-formatting on already
                # encoded bytes only works for ASCII-compatible encodings.
                image_key = ('image-%09d' % cnt).encode(self.encoding)
                label_key = ('label-%09d' % cnt).encode(self.encoding)
                cache.append((image_key, image_bin))
                cache.append((label_key, text.encode(self.encoding)))

                if cnt % self.batch_size == 0:
                    self.write_cache(env, cache)
                    cache = []
                    print('Written %d / %d' % (cnt, n_samples))
                cnt += 1

            # Only successfully stored samples are counted.
            n_samples = cnt - 1
            cache.append(('num-samples'.encode(self.encoding),
                          str(n_samples).encode(self.encoding)))
            self.write_cache(env, cache)
        finally:
            # Release the LMDB handle even if a sample fails mid-way.
            env.close()
        print('Created lmdb dataset with %d samples' % n_samples)