OpenOCR-Demo / openrec /postprocess /abinet_postprocess.py
topdu's picture
openocr demo
29f689c
raw
history blame
1.26 kB
import torch
from .nrtr_postprocess import NRTRLabelDecode
class ABINetLabelDecode(NRTRLabelDecode):
    """Convert between text-label and text-index for ABINet outputs.

    The actual index-to-text conversion is delegated to
    ``NRTRLabelDecode.decode``. This class selects which ABINet output to
    read, converts it to NumPy, and prepends the ``'</s>'`` token to the
    character list.
    """

    def __init__(self,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        """Build the decoder.

        Args:
            character_dict_path: Forwarded positionally to ``NRTRLabelDecode``.
            use_space_char: Forwarded positionally to ``NRTRLabelDecode``.
            **kwargs: Accepted (presumably so a generic config can pass extra
                keys) but ignored; they are NOT forwarded to the parent.
        """
        super().__init__(character_dict_path, use_space_char)

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode model predictions and, optionally, the ground-truth labels.

        Args:
            preds: One of:
                * a dict of model outputs. The last element of
                  ``preds['align']`` is used when that key exists and is
                  non-empty; otherwise ``preds['vision']`` is used.
                * a ``torch.Tensor``.
                * an object already supporting ``argmax(axis=2)`` and
                  ``max(axis=2)`` (presumably a NumPy array; TODO confirm
                  against callers).
                Scores are reduced along axis 2, so a rank-3 array is
                assumed, likely (batch, seq_len, num_classes). Verify.
            batch: Optional batch. When given, ``batch[1]`` is decoded as the
                label indices. It must support ``.cpu().numpy()``, i.e. it is
                assumed to be a tensor.
            *args, **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The decoded predictions when ``batch`` is None, otherwise a
            ``(text, label)`` tuple. The element format is whatever
            ``NRTRLabelDecode.decode`` returns.
        """
        if isinstance(preds, dict):
            align = preds.get('align')
            # Presumably one entry per alignment/refinement step; the last
            # entry is taken. Fall back to the vision branch when the key is
            # missing or empty. The `is not None` / `len` check avoids
            # ambiguous truthiness if `align` is a multi-element tensor.
            if align is not None and len(align) > 0:
                preds = align[-1]
            else:
                preds = preds['vision']
        if isinstance(preds, torch.Tensor):
            # Drop autograd history and move to host before NumPy conversion.
            preds = preds.detach().cpu().numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        # Consecutive repeated indices are not collapsed for predictions.
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if batch is None:
            return text
        label = self.decode(batch[1].cpu().numpy())
        return text, label

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with ``'</s>'`` prepended.

        Overrides the parent's hook. Presumably the parent calls it while
        building its character table; confirm in ``NRTRLabelDecode``. The
        token ends up at index 0 of the returned list. The input list is not
        mutated.
        """
        return ['</s>'] + dict_character