File size: 3,656 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from typing import Optional, Sequence, Tuple, Union

import cv2
import numpy as np
import torch
from mmengine.logging import MMLogger
from shapely.geometry import Polygon

from mmocr.utils.polygon_utils import offset_polygon
from .base import BaseTextDetModuleLoss


class SegBasedModuleLoss(BaseTextDetModuleLoss):
    """Shared base for module losses of segmentation-based text detectors.

    Supplies two helpers used when building training targets: rasterizing
    shrunk text kernels, and building a mask that zeroes out ignored text
    regions.
    """

    def _generate_kernels(
        self,
        img_size: Tuple[int, int],
        text_polys: Sequence[np.ndarray],
        shrink_ratio: float,
        max_shrink_dist: Union[float, int] = sys.maxsize,
        ignore_flags: Optional[torch.Tensor] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Rasterize shrunk text polygons into a single instance kernel map.

        Every polygon that is not already flagged as ignored goes through
        three steps:

        1. Its vertices are cast to int32.
        2. It is shrunk inward with ``offset_polygon``.
        3. It is filled into the kernel map with the value ``index + 1``.

        Pixels covered by no kernel stay 0. The inward offset is
        ``int(area * (1 - shrink_ratio * shrink_ratio) / (perimeter + 0.001)
        + 0.5)``, capped at ``max_shrink_dist``.

        Args:
            img_size (tuple(int, int)): Output size as (height, width).
            text_polys (Sequence[np.ndarray]): Text polygons. Each one is
                reshaped into (-1, 2) vertex pairs.
            shrink_ratio (float or int): Ratio from which each polygon's
                shrink distance is derived.
            max_shrink_dist (float or int): Upper bound on the shrink
                distance. Defaults to ``sys.maxsize``.
            ignore_flags (torch.BoolTensor, optional): Per-polygon ignore
                flags. Flagged polygons are skipped. Polygons whose shrunk
                result is empty, or cannot be reshaped to vertex pairs, are
                flagged in place. Defaults to None.

        Returns:
            tuple(ndarray, ndarray): A float32 kernel map of shape
            (height, width), and the very ``ignore_flags`` object that was
            passed in (None when none was given).
        """
        assert isinstance(img_size, tuple)
        assert isinstance(shrink_ratio, (float, int))

        logger: MMLogger = MMLogger.get_current_instance()

        height, width = img_size
        kernel_map = np.zeros((height, width), dtype=np.float32)
        # Loop-invariant part of the shrink-distance formula.
        area_scale = 1 - shrink_ratio * shrink_ratio

        def _mark_ignored(idx):
            # Only record the failure when the caller supplied flags.
            if ignore_flags is not None:
                ignore_flags[idx] = True

        for idx, raw_poly in enumerate(text_polys):
            already_ignored = ignore_flags is not None and ignore_flags[idx]
            if already_ignored:
                continue

            # Vertices are truncated to integers before measuring/shrinking.
            vertices = raw_poly.reshape(-1, 2).astype(np.int32)
            shape = Polygon(vertices)
            # The 0.001 guards against a zero perimeter.
            raw_offset = int(shape.area * area_scale / (shape.length + 0.001)
                             + 0.5)
            offset = min(raw_offset, max_shrink_dist)

            shrunk = offset_polygon(vertices, -offset)
            if not len(shrunk):
                _mark_ignored(idx)
                continue

            try:
                shrunk = shrunk.reshape(-1, 2)
            except Exception as err:
                logger.info(f'{shrunk} with error {err}')
                _mark_ignored(idx)
            else:
                cv2.fillPoly(kernel_map, [shrunk.astype(np.int32)], idx + 1)

        return kernel_map, ignore_flags

    def _generate_effective_mask(self, mask_size: Tuple[int, int],
                                 ignored_polygons: Sequence[np.ndarray]
                                 ) -> np.ndarray:
        """Build a uint8 mask that is 1 everywhere except ignored regions.

        Args:
            mask_size (tuple(int, int)): Shape of the mask to create.
            ignored_polygons (Sequence[ndarray]): Polygons whose interiors
                are set to 0. Each one is cast to int32 and reshaped to
                (1, -1, 2) before filling.

        Returns:
            ndarray: The uint8 effective mask of shape ``mask_size``.
        """
        effective = np.ones(mask_size, dtype=np.uint8)

        for ignored in ignored_polygons:
            pts = ignored.astype(np.int32).reshape(1, -1, 2)
            cv2.fillPoly(effective, pts, 0)

        return effective