Spaces:
Runtime error
Runtime error
""" a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ | |
import fire | |
import os | |
import lmdb | |
import cv2 | |
import numpy as np | |
def checkImageIsValid(imageBin): | |
if imageBin is None: | |
return False | |
imageBuf = np.frombuffer(imageBin, dtype=np.uint8) | |
img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) | |
imgH, imgW = img.shape[0], img.shape[1] | |
if imgH * imgW == 0: | |
return False | |
return True | |
def writeCache(env, cache): | |
with env.begin(write=True) as txn: | |
for k, v in cache.items(): | |
txn.put(k, v) | |
def createDataset(inputPath, gtFile, outputPath, checkValid=True): | |
""" | |
Create LMDB dataset for training and evaluation. | |
ARGS: | |
inputPath : input folder path where starts imagePath | |
outputPath : LMDB output path | |
gtFile : list of image path and label | |
checkValid : if true, check the validity of every image | |
""" | |
os.makedirs(outputPath, exist_ok=True) | |
env = lmdb.open(outputPath, map_size=1099511627776) | |
cache = {} | |
cnt = 1 | |
with open(gtFile, 'r', encoding='utf-8') as data: | |
datalist = data.readlines() | |
nSamples = len(datalist) | |
for i in range(nSamples): | |
imagePath, label = datalist[i].strip('\n').split('\t') | |
imagePath = os.path.join(inputPath, imagePath) | |
# # only use alphanumeric data | |
# if re.search('[^a-zA-Z0-9]', label): | |
# continue | |
if not os.path.exists(imagePath): | |
print('%s does not exist' % imagePath) | |
continue | |
with open(imagePath, 'rb') as f: | |
imageBin = f.read() | |
if checkValid: | |
try: | |
if not checkImageIsValid(imageBin): | |
print('%s is not a valid image' % imagePath) | |
continue | |
except: | |
print('error occured', i) | |
with open(outputPath + '/error_image_log.txt', 'a') as log: | |
log.write('%s-th image data occured error\n' % str(i)) | |
continue | |
imageKey = 'image-%09d'.encode() % cnt | |
labelKey = 'label-%09d'.encode() % cnt | |
cache[imageKey] = imageBin | |
cache[labelKey] = label.encode() | |
if cnt % 1000 == 0: | |
writeCache(env, cache) | |
cache = {} | |
print('Written %d / %d' % (cnt, nSamples)) | |
cnt += 1 | |
nSamples = cnt-1 | |
cache['num-samples'.encode()] = str(nSamples).encode() | |
writeCache(env, cache) | |
print('Created dataset with %d samples' % nSamples) | |
if __name__ == '__main__': | |
fire.Fire(createDataset) | |