# Source: OpenOCR-Demo / openrec / preprocess / ce_label_encode.py
# Repository page header: uploaded by topdu, commit "openocr demo" (29f689c), 3.74 kB.
import re
import numpy as np
from tools.utils.logging import get_logger
class BaseRecLabelEncode(object):
    """Convert a text label into a list of character indices.

    The vocabulary is either the built-in ``0-9a-z`` set (when no dictionary
    path is given) or the lines of a UTF-8 dictionary file, one entry per
    line. Subclasses can extend the vocabulary by overriding
    :meth:`add_special_char`.

    Attributes set by ``__init__``:
        max_text_len: longest label (in characters) that ``encode`` accepts.
        beg_str / end_str: the literal strings ``'sos'`` and ``'eos'``.
        lower: whether ``encode`` lower-cases the label first.
        reverse: True when the dictionary path contains ``'arabic'``.
        character_str: the raw vocabulary (a string or a list of entries).
        character: the final vocabulary list, after ``add_special_char``.
        dict: mapping from vocabulary entry to its index in ``character``.
    """

    # Characters that stay grouped in their original order when a label is
    # reversed: ASCII letters and digits, space, ``: * . / % + -`` and the
    # Arabic-Indic digits U+0660..U+0669. The hyphen is escaped on purpose:
    # left unescaped between ``+`` and the first Arabic-Indic digit it forms
    # a range that also swallows the Arabic letters (U+0621..U+064A), which
    # turns ``label_reverse`` into a no-op for Arabic text.
    _LTR_CHAR_PATTERN = re.compile(r'[a-zA-Z0-9 :*./%+\-\u0660-\u0669]')

    def __init__(
        self,
        max_text_length,
        character_dict_path=None,
        use_space_char=False,
        lower=False,
    ):
        """Build the character vocabulary and its index lookup table.

        Args:
            max_text_length: maximum label length accepted by ``encode``.
            character_dict_path: path to a UTF-8 dictionary file with one
                vocabulary entry per line. When ``None`` the built-in
                ``0-9a-z`` vocabulary is used and lower-casing is forced on.
            use_space_char: append ``' '`` to the vocabulary. Only honoured
                when a dictionary file is given.
            lower: lower-case labels before encoding.
        """
        self.max_text_len = max_text_length
        self.beg_str = 'sos'
        self.end_str = 'eos'
        self.lower = lower
        self.reverse = False

        if character_dict_path is None:
            logger = get_logger()
            logger.warning(
                'The character_dict_path is None, model can only recognize number and lower letters'
            )
            self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
            dict_character = list(self.character_str)
            # The fallback vocabulary has no upper-case letters.
            self.lower = True
        else:
            self.character_str = []
            # Read as bytes and decode per line so that only the line
            # terminators are removed; other whitespace is part of the entry.
            with open(character_dict_path, 'rb') as fin:
                for raw_line in fin:
                    self.character_str.append(
                        raw_line.decode('utf-8').strip('\r\n'))
            if use_space_char:
                self.character_str.append(' ')
            dict_character = list(self.character_str)
            if 'arabic' in character_dict_path:
                self.reverse = True

        dict_character = self.add_special_char(dict_character)
        # With duplicate entries the last occurrence wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def label_reverse(self, text):
        """Reverse ``text`` while keeping left-to-right runs intact.

        Consecutive characters matching ``_LTR_CHAR_PATTERN`` form a run whose
        internal order is preserved; every other character is a token of its
        own. The token sequence is reversed and joined.

        Args:
            text: the string to reverse.

        Returns:
            The reordered string (``''`` for empty input).
        """
        tokens = []
        run = ''
        for c in text:
            if self._LTR_CHAR_PATTERN.search(c):
                run += c
                continue
            # A non-run character closes the pending run, if any.
            if run != '':
                tokens.append(run)
                run = ''
            tokens.append(c)
        if run != '':
            tokens.append(run)
        return ''.join(reversed(tokens))

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; returns the list as is."""
        return dict_character

    def encode(self, text):
        """Map one text label to its list of vocabulary indices.

        Args:
            text: the label string of a single sample.

        Returns:
            A list of ints, one per character found in the vocabulary, or
            ``None`` when the label is empty, longer than ``max_text_len``,
            or contains no character from the vocabulary.
        """
        if len(text) == 0 or len(text) > self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        # Characters missing from the vocabulary are dropped silently.
        text_list = [self.dict[char] for char in text if char in self.dict]
        if not text_list:
            return None
        return text_list
class CELabelEncode(BaseRecLabelEncode):
    """Encode the ``label`` of a sample dict into an index array.

    Calling an instance replaces ``data['label']`` with the encoded indices
    and stores their count under ``data['length']``.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        """Set up the vocabulary through the base class.

        Only the three named arguments are forwarded to the base
        initializer; anything passed through ``kwargs`` is accepted and
        ignored.
        """
        super().__init__(max_text_length, character_dict_path, use_space_char)

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        Returns:
            ``data`` with ``'length'`` (0-d array holding the number of
            encoded indices) and ``'label'`` (array of indices) set, or
            ``None`` when ``encode`` rejects the label.
        """
        indices = self.encode(data['label'])
        if indices is None:
            return None
        data['length'] = np.array(len(indices))
        data['label'] = np.array(indices)
        return data

    def add_special_char(self, dict_character):
        """Return the vocabulary unchanged; no special tokens are added."""
        return dict_character