MAERec-Gradio / mmocr /models /textdet /module_losses /seg_based_module_loss.py
Mountchicken's picture
Upload 704 files
9bf4bd7
raw
history blame
3.66 kB
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from typing import Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
from mmengine.logging import MMLogger
from shapely.geometry import Polygon
from mmocr.utils.polygon_utils import offset_polygon
from .base import BaseTextDetModuleLoss
class SegBasedModuleLoss(BaseTextDetModuleLoss):
    """Base class for the module loss of segmentation-based text detection
    algorithms with some handy utilities."""

    def _generate_kernels(
        self,
        img_size: Tuple[int, int],
        text_polys: Sequence[np.ndarray],
        shrink_ratio: Union[float, int],
        max_shrink_dist: Union[float, int] = sys.maxsize,
        ignore_flags: Optional[torch.Tensor] = None
    ) -> Tuple[np.ndarray, Optional[torch.Tensor]]:
        """Generate text instance kernels according to a shrink ratio.

        Each polygon that is not flagged as ignored is shrunk inwards by
        ``min(int(area * (1 - shrink_ratio**2) / (perimeter + 0.001) + 0.5),
        max_shrink_dist)`` and rasterized into the kernel map with the value
        ``index + 1``, where ``index`` is the polygon's position in
        ``text_polys``. Pixels not covered by any kernel stay 0.

        Args:
            img_size (tuple(int, int)): The image size of (height, width).
            text_polys (Sequence[np.ndarray]): Text polygons. Each one is
                reshaped to ``(-1, 2)`` point coordinates and cast to int32
                before use.
            shrink_ratio (float or int): The shrink ratio of kernel.
            max_shrink_dist (float or int): The maximum shrinking distance.
                Defaults to ``sys.maxsize``.
            ignore_flags (torch.BoolTensor, optional): Indicate whether the
                corresponding text polygon is ignored. When given, it is
                updated in place: polygons with fewer than 3 vertices, or
                whose shrunk result is empty or cannot be reshaped to
                points, are flagged as ignored. Defaults to None.

        Returns:
            tuple(ndarray, torch.BoolTensor or None): The float32 text
            instance kernels of shape (height, width) and the (possibly
            updated) ``ignore_flags`` object that was passed in.
        """
        assert isinstance(img_size, tuple)
        assert isinstance(shrink_ratio, (float, int))
        logger: MMLogger = MMLogger.get_current_instance()

        h, w = img_size
        text_kernel = np.zeros((h, w), dtype=np.float32)
        for text_ind, poly in enumerate(text_polys):
            if ignore_flags is not None and ignore_flags[text_ind]:
                continue
            poly = poly.reshape(-1, 2).astype(np.int32)
            # shapely raises ValueError when building a polygon from fewer
            # than 3 vertices; treat such instances like the other unusable
            # ones below instead of aborting the whole batch.
            if len(poly) < 3:
                if ignore_flags is not None:
                    ignore_flags[text_ind] = True
                continue
            poly_obj = Polygon(poly)
            area = poly_obj.area
            peri = poly_obj.length
            # The small epsilon guards against division by a zero perimeter.
            distance = min(
                int(area * (1 - shrink_ratio * shrink_ratio) / (peri + 0.001) +
                    0.5), max_shrink_dist)
            # A negative offset shrinks the polygon inwards.
            shrunk_poly = offset_polygon(poly, -distance)

            if len(shrunk_poly) == 0:
                # The polygon vanished after shrinking.
                if ignore_flags is not None:
                    ignore_flags[text_ind] = True
                continue

            try:
                shrunk_poly = shrunk_poly.reshape(-1, 2)
            except Exception as e:
                # Best effort: log the malformed result and skip the
                # instance rather than failing the whole batch.
                logger.info(f'{shrunk_poly} with error {e}')
                if ignore_flags is not None:
                    ignore_flags[text_ind] = True
                continue

            # Value ``text_ind + 1`` keeps 0 reserved for the background.
            cv2.fillPoly(text_kernel, [shrunk_poly.astype(np.int32)],
                         text_ind + 1)
        return text_kernel, ignore_flags

    def _generate_effective_mask(self, mask_size: Tuple[int, int],
                                 ignored_polygons: Sequence[np.ndarray]
                                 ) -> np.ndarray:
        """Generate effective mask by setting the invalid regions to 0 and 1
        otherwise.

        Args:
            mask_size (tuple(int, int)): The mask size.
            ignored_polygons (Sequence[ndarray]): Polygons covering the
                ignored text regions. Each one is cast to int32 and reshaped
                to ``(1, -1, 2)`` point coordinates before being filled.

        Returns:
            mask (ndarray): The uint8 effective mask of shape ``mask_size``,
            0 inside the ignored polygons and 1 elsewhere.
        """
        mask = np.ones(mask_size, dtype=np.uint8)

        for poly in ignored_polygons:
            instance = poly.astype(np.int32).reshape(1, -1, 2)
            cv2.fillPoly(mask, instance, 0)

        return mask