import numpy as np
from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode
class ARLabelEncode(BaseRecLabelEncode):
    """Convert between text labels and autoregressive (AR) index sequences.

    An encoded label is laid out as ``[BOS, c_1, ..., c_n, EOS, PAD, ...]``
    and padded up to ``max_text_len + 2`` entries (room for the BOS and EOS
    markers around the characters).

    ``add_special_char`` places the markers in the character table as
    ``[EOS] + characters + [BOS, PAD]``. Assuming the base class assigns
    indices by position in that list (TODO confirm in BaseRecLabelEncode),
    EOS is index 0, BOS is ``len(characters) + 1`` and PAD is
    ``len(characters) + 2``.
    """

    # Special sequence markers. They must be distinct from each other (and
    # from every real character): they are used as keys of the token->index
    # table ``self.dict``, so equal markers would collapse onto one key and
    # BOS / EOS / PAD would all encode to the same id.
    # NOTE(review): any decoder that mirrors this table should use the same
    # marker strings -- confirm against the matching postprocess class.
    BOS = '<s>'
    EOS = '</s>'
    PAD = '<pad>'

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        """Build the encoder.

        Args:
            max_text_length: Maximum label length, forwarded to the base
                class.
            character_dict_path: Optional character dictionary path,
                forwarded to the base class.
            use_space_char: Flag forwarded unchanged to the base class.
            **kwargs: Accepted so extra config keys do not raise; ignored.
        """
        super().__init__(max_text_length, character_dict_path,
                         use_space_char)

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        Args:
            data: Dict holding the raw text label under ``'label'``.

        Returns:
            ``None`` when ``self.encode`` rejects the text. Otherwise the
            same dict, updated as follows:

            - ``'length'`` is set to the number of encoded characters, not
              counting BOS/EOS.
            - ``'label'`` is replaced by an integer array:
              ``[BOS, ids..., EOS]`` followed by PAD ids up to
              ``max_text_len + 2`` entries.
        """
        ids = self.encode(data['label'])
        if ids is None:
            return None
        data['length'] = np.array(len(ids))
        seq = [self.dict[self.BOS]] + ids + [self.dict[self.EOS]]
        # Fixed-length target: max_text_len characters plus BOS and EOS.
        # (A negative count yields no padding rather than truncation.)
        seq = seq + [self.dict[self.PAD]] * (self.max_text_len + 2 - len(seq))
        data['label'] = np.array(seq)
        return data

    def add_special_char(self, dict_character):
        """Return the character list with the special markers added.

        EOS is prepended; BOS and PAD are appended, in that order.

        Args:
            dict_character: List of real characters.

        Returns:
            A new list ``[EOS] + dict_character + [BOS, PAD]``.
        """
        return [self.EOS] + dict_character + [self.BOS, self.PAD]