Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved.
import sys | |
from typing import Optional, Sequence, Tuple, Union | |
import cv2 | |
import numpy as np | |
import torch | |
from mmengine.logging import MMLogger | |
from shapely.geometry import Polygon | |
from mmocr.utils.polygon_utils import offset_polygon | |
from .base import BaseTextDetModuleLoss | |
class SegBasedModuleLoss(BaseTextDetModuleLoss):
    """Common base for module losses of segmentation-based text detectors.

    Bundles two target-building helpers: rasterizing shrunk text kernels and
    building the effective (non-ignored) region mask.
    """

    def _generate_kernels(
        self,
        img_size: Tuple[int, int],
        text_polys: Sequence[np.ndarray],
        shrink_ratio: float,
        max_shrink_dist: Union[float, int] = sys.maxsize,
        ignore_flags: Optional[torch.Tensor] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Rasterize a shrunk kernel for every non-ignored text polygon.

        For each polygon the offset distance is
        ``int(area * (1 - shrink_ratio ** 2) / (perimeter + 0.001) + 0.5)``,
        capped by ``max_shrink_dist``. The polygon is offset inwards by that
        distance and painted into the kernel map with the value
        ``polygon index + 1``. Polygons whose offset result is empty or
        cannot be reshaped to ``(-1, 2)`` are skipped and, when
        ``ignore_flags`` is given, flagged as ignored.

        Args:
            img_size (tuple(int, int)): The image size of (height, width).
            text_polys (Sequence[np.ndarray]): Text polygons; each one is
                reshaped to ``(-1, 2)`` and cast to int32 before use.
            shrink_ratio (float or int): The shrink ratio of kernel.
            max_shrink_dist (float or int): Upper bound of the shrinking
                distance. Defaults to ``sys.maxsize``.
            ignore_flags (torch.BoolTensor, optional): Per-polygon flags;
                polygons flagged true are skipped. The object is modified
                in place. Defaults to None.

        Returns:
            tuple(ndarray, ndarray): The float32 kernel map of shape
            (height, width), and the very ``ignore_flags`` object that was
            passed in (possibly updated; ``None`` if none was given).
        """
        assert isinstance(img_size, tuple)
        assert isinstance(shrink_ratio, (float, int))

        logger: MMLogger = MMLogger.get_current_instance()

        def mark_ignored(index: int) -> None:
            # Flags are optional; only record the drop when the caller
            # supplied a container for them.
            if ignore_flags is not None:
                ignore_flags[index] = True

        height, width = img_size
        kernel_map = np.zeros((height, width), dtype=np.float32)

        for index, raw_poly in enumerate(text_polys):
            already_ignored = (
                ignore_flags is not None and ignore_flags[index])
            if already_ignored:
                continue

            points = raw_poly.reshape(-1, 2).astype(np.int32)
            shape = Polygon(points)
            area, perimeter = shape.area, shape.length

            # The small epsilon keeps the division defined for a
            # zero-length outline; adding 0.5 before int() rounds the value.
            rounded = int(area * (1 - shrink_ratio * shrink_ratio) /
                          (perimeter + 0.001) + 0.5)
            shrink_dist = min(rounded, max_shrink_dist)
            shrunk = offset_polygon(points, -shrink_dist)

            # Nothing is left after offsetting: drop this instance.
            if len(shrunk) == 0:
                mark_ignored(index)
                continue

            try:
                contour = shrunk.reshape(-1, 2)
            except Exception as err:
                logger.info(f'{shrunk} with error {err}')
                mark_ignored(index)
                continue

            # Instance ids start at 1 so that 0 stays the background value.
            cv2.fillPoly(kernel_map, [contour.astype(np.int32)], index + 1)

        return kernel_map, ignore_flags

    def _generate_effective_mask(self, mask_size: Tuple[int, int],
                                 ignored_polygons: Sequence[np.ndarray]
                                 ) -> np.ndarray:
        """Build a mask that is 0 inside ignored polygons and 1 elsewhere.

        Args:
            mask_size (tuple(int, int)): The mask size.
            ignored_polygons (Sequence[ndarray]): The ignored polygons of
                the text region; each one is cast to int32 and reshaped to
                ``(1, -1, 2)`` before being filled.

        Returns:
            mask (ndarray): The uint8 effective mask of shape ``mask_size``.
        """
        effective = np.ones(mask_size, dtype=np.uint8)
        # Fill one polygon per call, exactly as many calls as polygons.
        for ignored in ignored_polygons:
            contour = ignored.astype(np.int32).reshape(1, -1, 2)
            cv2.fillPoly(effective, contour, 0)
        return effective