import random

import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode


class CPPDLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index."""

    def __init__(
            self,
            max_text_length,
            character_dict_path=None,
            use_space_char=False,
            ch=False,  # use encodech(): emit label_index/label_node instead of label_order
            # ch_7000=7000,
            ignore_index=100,  # padding value for the label sequence
            use_sos=False,
            pos_len=False,  # encode positions as descending indices rather than a 0/1 mask
            **kwargs):
        self.use_sos = use_sos
        super(CPPDLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        self.ch = ch
        self.ignore_index = ignore_index
        self.pos_len = pos_len
    def __call__(self, data):
        text = data['label']
        if self.ch:
            text, text_node_index, text_node_num = self.encodech(text)
            if text is None:
                return None
            if len(text) > self.max_text_len:
                return None
            data['length'] = np.array(len(text))
            # text.insert(0, 0)

            if self.pos_len:
                text_pos_node = list(range(len(text), -1, -1)) + [100] * (
                    self.max_text_len - len(text))
            else:
                text_pos_node = [1] * (len(text) + 1) + [0] * (
                    self.max_text_len - len(text))

            text.append(0)  # append the terminal token (index 0)
            text = text + [self.ignore_index] * (
                self.max_text_len + 1 - len(text))
            data['label'] = np.array(text)
            data['label_node'] = np.array(text_node_num + text_pos_node)
            data['label_index'] = np.array(text_node_index)
            # data['label_ctc'] = np.array(ctc_text)
            return data
        else:
            text, text_char_node, ch_order = self.encode(text)
            if text is None:
                return None
            if len(text) > self.max_text_len:
                return None
            data['length'] = np.array(len(text))
            # text.insert(0, 0)

            if self.pos_len:
                text_pos_node = list(range(len(text), -1, -1)) + [100] * (
                    self.max_text_len - len(text))
            else:
                text_pos_node = [1] * (len(text) + 1) + [0] * (
                    self.max_text_len - len(text))

            text.append(0)  # append the terminal token (index 0)
            text = text + [self.ignore_index] * (
                self.max_text_len + 1 - len(text))
            data['label'] = np.array(text)
            data['label_node'] = np.array(text_char_node + text_pos_node)
            data['label_order'] = np.array(ch_order)
            return data
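
    # Targets emitted by __call__ (as built above): 'label' holds
    # max_text_len + 1 entries (character indices, a terminal 0, then
    # ignore_index padding); 'label_node' concatenates the character-count
    # vector with max_text_len + 1 position nodes; the default branch adds
    # 'label_order' with [char_index, occurrence, position] triples, while the
    # ch branch adds 'label_index' and the matching node counts from encodech.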
    def add_special_char(self, dict_character):
        if self.use_sos:
            dict_character = ['<s>', '</s>'] + dict_character
        else:
            dict_character = ['</s>'] + dict_character
        self.num_character = len(dict_character)
        return dict_character
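
    # Resulting vocabulary layout: '</s>' sits at index 0 (with use_sos=True,
    # '<s>' takes index 0 and '</s>' index 1), and num_character counts these
    # special tokens together with the base character set.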
    def encode(self, text):
        """Convert a text label into CPPD targets.

        input:
            text: text label of one image.
        output:
            text_list: character indices of the label.
            text_node: occurrence count for every character in the dictionary
                (index 0, the special token, is preset to 1).
            ch_order: [char_index, occurrence, position] triples for characters
                in the label, plus shuffled entries for absent characters,
                truncated to max_text_len + 1 and sorted.
        """
        if len(text) == 0:
            return None, None, None
        if self.lower:
            text = text.lower()
        text_node = [0 for _ in range(self.num_character)]
        text_node[0] = 1
        text_list = []
        ch_order = []
        order = 1
        for char in text:
            if char not in self.dict:
                continue
            text_list.append(self.dict[char])
            text_node[self.dict[char]] += 1
            ch_order.append(
                [self.dict[char], text_node[self.dict[char]], order])
            order += 1
        no_ch_order = []
        for char in self.character:
            if char not in text:
                no_ch_order.append([self.dict[char], 1, 0])
        random.shuffle(no_ch_order)
        ch_order = ch_order + no_ch_order
        ch_order = ch_order[:self.max_text_len + 1]
        if len(text_list) == 0 or len(text_list) > self.max_text_len:
            return None, None, None
        ch_order.sort()  # sort entries by character index before returning
        return text_list, text_node, ch_order
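
    # Illustration, assuming character_dict_path=None falls back to a built-in
    # lowercase alphanumeric set (PaddleOCR-style), with '</s>' placed at
    # index 0 by add_special_char: encode('all') returns
    #   text_list = [dict['a'], dict['l'], dict['l']],
    #   text_node with 1 at index 0 and dict['a'], and 2 at dict['l'],
    #   ch_order holding [dict['a'], 1, 1], [dict['l'], 1, 2], [dict['l'], 2, 3]
    #   together with shuffled [index, 1, 0] entries for absent characters,
    #   truncated to max_text_len + 1 and sorted.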
    def encodech(self, text):
        """Convert a text label into CPPD targets (character-node variant).

        input:
            text: text label of one image.
        output:
            text_list: character indices of the label.
            text_node_index: sorted indices of the characters whose counts are
                emitted (present characters plus randomly sampled absent ones).
            text_node_num: occurrence count for each index in text_node_index.
        """
        if len(text) == 0:
            return None, None, None
        if self.lower:
            text = text.lower()
        text_node_dict = {0: 1}
        character_index = list(range(self.num_character))
        text_list = []
        for char in text:
            if char not in self.dict:
                continue
            i_c = self.dict[char]
            text_list.append(i_c)
            if i_c in text_node_dict:
                text_node_dict[i_c] += 1
            else:
                text_node_dict[i_c] = 1
        for ic in list(text_node_dict.keys()):
            character_index.remove(ic)
        # pad the counted node set with randomly sampled absent characters so
        # that exactly 37 character nodes are emitted
        none_char_index = random.sample(
            character_index, 37 - len(list(text_node_dict.keys())))
        for ic in none_char_index:
            text_node_dict[ic] = 0
        text_node_index = sorted(text_node_dict)
        text_node_num = [text_node_dict[k] for k in text_node_index]
        if len(text_list) == 0 or len(text_list) > self.max_text_len:
            return None, None, None
        return text_list, text_node_index, text_node_num
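

# Minimal usage sketch, not part of the original module. It assumes that
# BaseRecLabelEncode, given character_dict_path=None, falls back to a built-in
# lowercase alphanumeric character set (PaddleOCR-style); adjust the arguments
# to your own dictionary and pipeline.
if __name__ == '__main__':
    encoder = CPPDLabelEncode(max_text_length=25)
    sample = encoder({'label': 'hello'})
    if sample is not None:
        # Under the assumptions above: 'label' has 26 entries, 'label_node'
        # has num_character + 26 entries, 'label_order' is (26, 3), and
        # 'length' is a 0-d array holding 5.
        for key in ('length', 'label', 'label_node', 'label_order'):
            print(key, np.shape(sample[key]))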