# Spaces: Runtime error (web-page status banner captured with this file; not part of the module)
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.detectors import MaskRCNN | |
from mmocr.core import seg2boundary | |
from mmocr.models.builder import DETECTORS | |
from .text_detector_mixin import TextDetectorMixin | |
# NOTE(review): `DETECTORS` is imported by this file but this class is not
# registered with it in the visible source - confirm whether a
# `@DETECTORS.register_module()` decorator was dropped.
class OCRMaskRCNN(TextDetectorMixin, MaskRCNN):
    """Mask RCNN tailored for OCR.

    Args:
        backbone: Forwarded unchanged to ``MaskRCNN.__init__``.
        rpn_head: Forwarded unchanged to ``MaskRCNN.__init__``.
        roi_head: Forwarded unchanged to ``MaskRCNN.__init__``.
        train_cfg: Forwarded unchanged to ``MaskRCNN.__init__``.
        test_cfg: Forwarded unchanged to ``MaskRCNN.__init__``.
        neck: Forwarded unchanged to ``MaskRCNN.__init__``.
        pretrained: Forwarded unchanged to ``MaskRCNN.__init__``.
        text_repr_type (str): Boundary representation handed to
            ``seg2boundary``. Must be ``'quad'`` or ``'poly'``.
        show_score (bool): Forwarded to ``TextDetectorMixin.__init__``.
        init_cfg: Forwarded unchanged to ``MaskRCNN.__init__``.
    """

    def __init__(self,
                 backbone,
                 rpn_head,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None,
                 text_repr_type='quad',
                 show_score=False,
                 init_cfg=None):
        # The two bases take different constructor arguments, so each one is
        # initialised explicitly rather than through a single super() call.
        TextDetectorMixin.__init__(self, show_score)
        MaskRCNN.__init__(
            self,
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        assert text_repr_type in ['quad', 'poly'], (
            "text_repr_type must be 'quad' or 'poly'")
        self.text_repr_type = text_repr_type

    def get_boundary(self, results):
        """Convert segmentation into text boundaries.

        Args:
            results (tuple): The result tuple for one image. The first
                element holds the detections, where the last value of each
                detection is used as its score; the second element holds the
                segmentations. Only index 0 of each element is read
                (presumably the single text class - confirm against the
                detector's output format).

        Returns:
            dict: A result dict containing 'boundary_result', the list of
            boundaries for which ``seg2boundary`` returned a value.
        """
        assert isinstance(results, tuple)

        segmentations = results[1][0]
        boundaries = []
        for idx, seg in enumerate(segmentations):
            # Detections are indexed lazily, per segmentation, so nothing in
            # results[0] is touched when there are no segmentations.
            score = results[0][0][idx][-1]
            boundary = seg2boundary(seg, self.text_repr_type, score)
            # seg2boundary may yield None; such instances are skipped.
            if boundary is not None:
                boundaries.append(boundary)

        return dict(boundary_result=boundaries)

    def simple_test(self, img, img_metas, proposals=None, rescale=False):
        """Run the parent detector and convert its output to boundaries.

        Args:
            img: Forwarded unchanged to the parent ``simple_test``.
            img_metas: Forwarded unchanged to the parent ``simple_test``.
            proposals: Forwarded unchanged to the parent ``simple_test``.
            rescale (bool): Forwarded unchanged to the parent
                ``simple_test``.

        Returns:
            list: The value returned by ``get_boundary``, wrapped in a list
            unless it already is one.
        """
        results = super().simple_test(img, img_metas, proposals, rescale)

        # NOTE(review): only the first entry of the parent's results is
        # converted - confirm callers never pass more than one image.
        boundaries = self.get_boundary(results[0])
        if not isinstance(boundaries, list):
            boundaries = [boundaries]
        return boundaries