# Copyright (c) OpenMMLab. All rights reserved.
import torch

import mmocr.utils as utils
from mmocr.models.builder import CONVERTORS

from .attn import AttnConvertor


@CONVERTORS.register_module()
class ABIConvertor(AttnConvertor):
"""Convert between text, index and tensor for encoder-decoder based | |
pipeline. Modified from AttnConvertor to get closer to ABINet's original | |
implementation. | |
Args: | |
dict_type (str): Type of dict, should be one of {'DICT36', 'DICT90'}. | |
dict_file (None|str): Character dict file path. If not none, | |
higher priority than dict_type. | |
dict_list (None|list[str]): Character list. If not none, higher | |
priority than dict_type, but lower than dict_file. | |
with_unknown (bool): If True, add `UKN` token to class. | |
max_seq_len (int): Maximum sequence length of label. | |
lower (bool): If True, convert original string to lower case. | |
start_end_same (bool): Whether use the same index for | |
start and end token or not. Default: True. | |
""" | |

    def str2tensor(self, strings):
        """
        Convert text-string into tensor. Different from
        :obj:`mmocr.models.textrecog.convertors.AttnConvertor`, the targets
        field returns target indexes no longer than max_seq_len (EOS token
        included).

        Args:
            strings (list[str]): For instance, ['hello', 'world'].

        Returns:
            dict: A dict with two tensors.

            - | targets (list[Tensor]): [torch.Tensor([1,2,3,3,4,8]),
                torch.Tensor([5,4,6,3,7,8])]
            - | padded_targets (Tensor): Tensor of shape
                (bsz, max_seq_len).
        """
        assert utils.is_type_list(strings, str)

        tensors, padded_targets = [], []
        indexes = self.str2idx(strings)
        for index in indexes:
            # keep at most max_seq_len - 1 characters and append the EOS index
            tensor = torch.LongTensor(index[:self.max_seq_len - 1] +
                                      [self.end_idx])
            tensors.append(tensor)
            # target tensor for loss: prepend the start token to the
            # truncated character sequence
            src_target = torch.LongTensor(tensor.size(0) + 1).fill_(0)
            src_target[0] = self.start_idx
            src_target[1:] = tensor
            # pad with padding_idx up to max_seq_len, or truncate if the
            # BOS-prefixed sequence already exceeds max_seq_len
            padded_target = (torch.ones(self.max_seq_len) *
                             self.padding_idx).long()
            char_num = src_target.size(0)
            if char_num > self.max_seq_len:
                padded_target = src_target[:self.max_seq_len]
            else:
                padded_target[:char_num] = src_target
            padded_targets.append(padded_target)
        padded_targets = torch.stack(padded_targets, 0).long()

        return {'targets': tensors, 'padded_targets': padded_targets}
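

if __name__ == '__main__':
    # Minimal usage sketch, added for illustration only: it assumes MMOCR
    # 0.x, where AttnConvertor provides str2idx(), start_idx, end_idx and
    # padding_idx. The argument values below are hypothetical examples.
    convertor = ABIConvertor(
        dict_type='DICT36', with_unknown=False, max_seq_len=26, lower=True)
    out = convertor.str2tensor(['hello', 'world'])
    print(out['targets'][0])            # 1-D LongTensor ending in end_idx
    print(out['padded_targets'].shape)  # torch.Size([2, 26])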