Spaces:
Running
Running
| import re | |
| import numpy as np | |
| from tools.utils.logging import get_logger | |
class BaseRecLabelEncode(object):
    """Convert a text label into a list of character indices.

    The character set comes either from a dictionary file (one character
    per line) or, when no file is given, from a built-in set of digits and
    lower-case ASCII letters.

    Attributes set by ``__init__``:
        max_text_len: longest label (in characters) that ``encode`` accepts.
        beg_str / end_str: the strings ``'sos'`` and ``'eos'``.
        lower: whether ``encode`` lower-cases the label first.
        reverse: ``True`` when the dictionary path contains ``'arabic'``.
        character: list of characters after ``add_special_char``.
        dict: mapping from character to its index in ``character``.
    """

    # Characters that `label_reverse` keeps grouped in their original order:
    # ASCII letters/digits, space, ``: * . / % + -`` and the Arabic-Indic
    # digits U+0660..U+0669. The hyphen is escaped on purpose; unescaped it
    # would form a range from '+' up to an Arabic digit and swallow Arabic
    # letters (and much else) into the "keep order" set.
    _ORDER_PRESERVED = re.compile(r'[a-zA-Z0-9 :*./%+\-\u0660-\u0669]')

    def __init__(
        self,
        max_text_length,
        character_dict_path=None,
        use_space_char=False,
        lower=False,
    ):
        """Build the character <-> index tables.

        Args:
            max_text_length: maximum label length accepted by ``encode``.
            character_dict_path: path to a UTF-8 file with one character per
                line. If ``None``, digits plus lower-case letters are used
                and ``lower`` is forced to ``True``.
            use_space_char: append ``' '`` to the characters read from the
                dictionary file (ignored when no file is given).
            lower: lower-case labels before encoding.
        """
        self.max_text_len = max_text_length
        self.beg_str = 'sos'
        self.end_str = 'eos'
        self.lower = lower
        self.reverse = False
        if character_dict_path is None:
            logger = get_logger()
            logger.warning(
                'The character_dict_path is None, model can only recognize number and lower letters'
            )
            self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
            dict_character = list(self.character_str)
            # The built-in set has no upper-case letters, so fold case.
            self.lower = True
        else:
            self.character_str = []
            # Read as bytes and decode explicitly so the result does not
            # depend on the platform default encoding or newline handling.
            with open(character_dict_path, 'rb') as fin:
                for raw_line in fin:
                    # Strip only line terminators; a line holding a single
                    # space must survive as a valid character.
                    self.character_str.append(
                        raw_line.decode('utf-8').strip('\r\n'))
            if use_space_char:
                self.character_str.append(' ')
            dict_character = list(self.character_str)
            if 'arabic' in character_dict_path:
                self.reverse = True
        dict_character = self.add_special_char(dict_character)
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def label_reverse(self, text):
        """Reverse ``text`` while keeping runs of Latin/digit chars intact.

        Consecutive characters matching ``_ORDER_PRESERVED`` are grouped into
        one chunk; every other character is its own chunk. The chunks are
        then concatenated in reverse order.
        """
        chunks = []
        run = ''
        for c in text:
            if self._ORDER_PRESERVED.search(c):
                run += c
                continue
            if run != '':
                chunks.append(run)
            chunks.append(c)
            run = ''
        if run != '':
            chunks.append(run)
        return ''.join(chunks[::-1])

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; returns input as is."""
        return dict_character

    def encode(self, text):
        """Convert one text label into a list of character indices.

        Args:
            text: the label string of a single sample.

        Returns:
            A list with the dictionary index of every character of ``text``
            that is present in ``self.dict`` (unknown characters are
            skipped), or ``None`` when ``text`` is ``None``/empty, longer
            than ``max_text_len``, or contains no known character.
        """
        if not text or len(text) > self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        # Characters missing from the dictionary are dropped silently.
        text_list = [self.dict[char] for char in text if char in self.dict]
        if len(text_list) == 0:
            return None
        return text_list
class CELabelEncode(BaseRecLabelEncode):
    """Label encoder that stores the index sequence and its length.

    Calling an instance on a sample dict replaces ``data['label']`` with a
    NumPy array of character indices and adds ``data['length']``.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        """Forward the dictionary settings to the base encoder.

        Extra keyword arguments are accepted and ignored.
        """
        super(CELabelEncode, self).__init__(
            max_text_length,
            character_dict_path=character_dict_path,
            use_space_char=use_space_char,
        )

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        Returns ``None`` when the base ``encode`` rejects the label,
        otherwise the same ``data`` dict with ``'length'`` and ``'label'``
        set to NumPy arrays.
        """
        indices = self.encode(data['label'])
        if indices is None:
            return None
        data['length'] = np.array(len(indices))
        data['label'] = np.array(indices)
        return data

    def add_special_char(self, dict_character):
        """Return the character list unchanged (no special tokens)."""
        return dict_character