OpenOCR-Demo / openrec /preprocess /cppd_label_encode.py
topdu's picture
openocr demo
29f689c
raw
history blame
6.05 kB
import random
import numpy as np
from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
class CPPDLabelEncode(BaseRecLabelEncode):
"""Convert between text-label and text-index."""
def __init__(
self,
max_text_length,
character_dict_path=None,
use_space_char=False,
ch=False,
# ch_7000=7000,
ignore_index=100,
use_sos=False,
pos_len=False,
**kwargs):
self.use_sos = use_sos
super(CPPDLabelEncode,
self).__init__(max_text_length, character_dict_path,
use_space_char)
self.ch = ch
self.ignore_index = ignore_index
self.pos_len = pos_len
def __call__(self, data):
text = data['label']
if self.ch:
text, text_node_index, text_node_num = self.encodech(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
# text.insert(0, 0)
if self.pos_len:
text_pos_node = [i_ for i_ in range(len(text), -1, -1)
] + [100] * (self.max_text_len - len(text))
else:
text_pos_node = [1] * (len(text) + 1) + [0] * (
self.max_text_len - len(text))
text.append(0)
text + [0] * (self.max_text_len - len(text))
text = text + [self.ignore_index
] * (self.max_text_len + 1 - len(text))
data['label'] = np.array(text)
data['label_node'] = np.array(text_node_num + text_pos_node)
data['label_index'] = np.array(text_node_index)
# data['label_ctc'] = np.array(ctc_text)
return data
else:
text, text_char_node, ch_order = self.encode(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
# text.insert(0, 0)
if self.pos_len:
text_pos_node = [i_ for i_ in range(len(text), -1, -1)
] + [100] * (self.max_text_len - len(text))
else:
text_pos_node = [1] * (len(text) + 1) + [0] * (
self.max_text_len - len(text))
text.append(0)
text = text + [self.ignore_index
] * (self.max_text_len + 1 - len(text))
data['label'] = np.array(text)
data['label_node'] = np.array(text_char_node + text_pos_node)
data['label_order'] = np.array(ch_order)
return data
def add_special_char(self, dict_character):
if self.use_sos:
dict_character = ['<s>', '</s>'] + dict_character
else:
dict_character = ['</s>'] + dict_character
self.num_character = len(dict_character)
return dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if len(text) == 0:
return None, None, None
if self.lower:
text = text.lower()
text_node = [0 for _ in range(self.num_character)]
text_node[0] = 1
text_list = []
ch_order = []
order = 1
for char in text:
if char not in self.dict:
continue
text_list.append(self.dict[char])
text_node[self.dict[char]] += 1
ch_order.append(
[self.dict[char], text_node[self.dict[char]], order])
order += 1
no_ch_order = []
for char in self.character:
if char not in text:
no_ch_order.append([self.dict[char], 1, 0])
random.shuffle(no_ch_order)
ch_order = ch_order + no_ch_order
ch_order = ch_order[:self.max_text_len + 1]
if len(text_list) == 0 or len(text_list) > self.max_text_len:
return None, None, None
return text_list, text_node, ch_order.sort()
def encodech(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if len(text) == 0:
return None, None, None
if self.lower:
text = text.lower()
text_node_dict = {}
text_node_dict.update({0: 1})
character_index = [_ for _ in range(self.num_character)]
text_list = []
for char in text:
if char not in self.dict:
continue
i_c = self.dict[char]
text_list.append(i_c)
if i_c in text_node_dict.keys():
text_node_dict[i_c] += 1
else:
text_node_dict.update({i_c: 1})
for ic in list(text_node_dict.keys()):
character_index.remove(ic)
none_char_index = random.sample(character_index,
37 - len(list(text_node_dict.keys())))
for ic in none_char_index:
text_node_dict[ic] = 0
text_node_index = sorted(text_node_dict)
text_node_num = [text_node_dict[k] for k in text_node_index]
if len(text_list) == 0 or len(text_list) > self.max_text_len:
return None, None, None
return text_list, text_node_index, text_node_num