Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import cv2 | |
import numpy as np | |
import torch | |
import mmocr.utils as utils | |
from mmocr.models.builder import CONVERTORS | |
from .base import BaseConvertor | |
class SegConvertor(BaseConvertor): | |
"""Convert between text, index and tensor for segmentation based pipeline. | |
Args: | |
dict_type (str): Type of dict, should be either 'DICT36' or 'DICT90'. | |
dict_file (None|str): Character dict file path. If not none, the | |
file is of higher priority than dict_type. | |
dict_list (None|list[str]): Character list. If not none, the list | |
is of higher priority than dict_type, but lower than dict_file. | |
with_unknown (bool): If True, add `UKN` token to class. | |
lower (bool): If True, convert original string to lower case. | |
""" | |
def __init__(self, | |
dict_type='DICT36', | |
dict_file=None, | |
dict_list=None, | |
with_unknown=True, | |
lower=False, | |
**kwargs): | |
super().__init__(dict_type, dict_file, dict_list) | |
assert isinstance(with_unknown, bool) | |
assert isinstance(lower, bool) | |
self.with_unknown = with_unknown | |
self.lower = lower | |
self.update_dict() | |
def update_dict(self): | |
# background | |
self.idx2char.insert(0, '<BG>') | |
# unknown | |
self.unknown_idx = None | |
if self.with_unknown: | |
self.idx2char.append('<UKN>') | |
self.unknown_idx = len(self.idx2char) - 1 | |
# update char2idx | |
self.char2idx = {} | |
for idx, char in enumerate(self.idx2char): | |
self.char2idx[char] = idx | |
def tensor2str(self, output, img_metas=None): | |
"""Convert model output tensor to string labels. | |
Args: | |
output (tensor): Model outputs with size: N * C * H * W | |
img_metas (list[dict]): Each dict contains one image info. | |
Returns: | |
texts (list[str]): Decoded text labels. | |
scores (list[list[float]]): Decoded chars scores. | |
""" | |
assert utils.is_type_list(img_metas, dict) | |
assert len(img_metas) == output.size(0) | |
texts, scores = [], [] | |
for b in range(output.size(0)): | |
seg_pred = output[b].detach() | |
valid_width = int( | |
output.size(-1) * img_metas[b]['valid_ratio'] + 1) | |
seg_res = torch.argmax( | |
seg_pred[:, :, :valid_width], | |
dim=0).cpu().numpy().astype(np.int32) | |
seg_thr = np.where(seg_res == 0, 0, 255).astype(np.uint8) | |
_, labels, stats, centroids = cv2.connectedComponentsWithStats( | |
seg_thr) | |
component_num = stats.shape[0] | |
all_res = [] | |
for i in range(component_num): | |
temp_loc = (labels == i) | |
temp_value = seg_res[temp_loc] | |
temp_center = centroids[i] | |
temp_max_num = 0 | |
temp_max_cls = -1 | |
temp_total_num = 0 | |
for c in range(len(self.idx2char)): | |
c_num = np.sum(temp_value == c) | |
temp_total_num += c_num | |
if c_num > temp_max_num: | |
temp_max_num = c_num | |
temp_max_cls = c | |
if temp_max_cls == 0: | |
continue | |
temp_max_score = 1.0 * temp_max_num / temp_total_num | |
all_res.append( | |
[temp_max_cls, temp_center, temp_max_num, temp_max_score]) | |
all_res = sorted(all_res, key=lambda s: s[1][0]) | |
chars, char_scores = [], [] | |
for res in all_res: | |
temp_area = res[2] | |
if temp_area < 20: | |
continue | |
temp_char_index = res[0] | |
if temp_char_index >= len(self.idx2char): | |
temp_char = '' | |
elif temp_char_index <= 0: | |
temp_char = '' | |
elif temp_char_index == self.unknown_idx: | |
temp_char = '' | |
else: | |
temp_char = self.idx2char[temp_char_index] | |
chars.append(temp_char) | |
char_scores.append(res[3]) | |
text = ''.join(chars) | |
texts.append(text) | |
scores.append(char_scores) | |
return texts, scores | |