File size: 3,725 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

from mmcv.transforms.base import BaseTransform
from mmdet.structures.mask import PolygonMasks, bitmap_to_polygon

from mmocr.registry import TRANSFORMS


@TRANSFORMS.register_module()
class MMDet2MMOCR(BaseTransform):
    """Rename MMDet-style annotation keys to their MMOCR counterparts.

    Required Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.bool) (optional)

    Added Keys:

    - gt_polygons (list[np.ndarray])
    - gt_ignored (np.ndarray)

    Removed Keys:

    - gt_masks (optional)
    - gt_ignore_flags (optional)
    """

    def transform(self, results: Dict) -> Dict:
        """Replace MMDet annotation entries with MMOCR ones, in place.

        ``gt_masks`` is popped and stored back as ``gt_polygons``;
        ``gt_ignore_flags`` is popped and stored back, unchanged, as
        ``gt_ignored``. Keys that are absent are simply skipped.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            (Dict): The same dict object, after the keys were replaced.
        """
        # gt_masks -> gt_polygons
        if 'gt_masks' in results:
            masks = results.pop('gt_masks')
            polygons = []
            if len(masks) > 0:
                if isinstance(masks[0], PolygonMasks):
                    # Polygon masks: take the first element of every entry
                    # stored in ``masks.masks``.
                    polygons = [entry[0] for entry in masks.masks]
                else:
                    # Anything else is handled as bitmap masks: trace the
                    # contours of each bitmap and flatten them.
                    for bitmap in masks.masks:
                        contours, _ = bitmap_to_polygon(bitmap)
                        for contour in contours:
                            flat = contour.reshape(-1)
                            # Contours with fewer than 6 flattened values
                            # are treated as invalid polygons and dropped.
                            if len(flat) >= 6:
                                polygons.append(flat)
            results['gt_polygons'] = polygons

        # gt_ignore_flags -> gt_ignored
        if 'gt_ignore_flags' in results:
            results['gt_ignored'] = results.pop('gt_ignore_flags')

        return results

    def __repr__(self) -> str:
        # The transform has no parameters, so the class name is enough.
        return self.__class__.__name__


@TRANSFORMS.register_module()
class MMOCR2MMDet(BaseTransform):
    """Convert transforms' data format from MMOCR to MMDet.

    Required Keys:

    - img_shape (only read when ``gt_polygons`` is present)
    - gt_polygons (List[ndarray]) (optional)
    - gt_ignored (np.bool) (optional)

    Added Keys:

    - gt_masks (PolygonMasks | BitmapMasks) (optional)
    - gt_ignore_flags (np.bool) (optional)

    Removed Keys:

    - gt_polygons (optional)
    - gt_ignored (optional)

    Args:
        poly2mask (bool): Whether to convert the polygon masks to bitmap
            masks. Defaults to False.
    """

    def __init__(self, poly2mask: bool = False) -> None:
        self.poly2mask = poly2mask

    def transform(self, results: Dict) -> Dict:
        """Convert MMOCR's data format to MMDet's data format.

        ``gt_polygons`` is popped and stored back as ``gt_masks`` (a
        ``PolygonMasks``, or its ``to_bitmap()`` result when
        ``self.poly2mask`` is set); ``gt_ignored`` is popped and stored
        back, unchanged, as ``gt_ignore_flags``. Absent keys are skipped.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            (Dict): The same dict object, after the keys were replaced.

        Raises:
            KeyError: If ``gt_polygons`` is present but ``img_shape`` is
                missing.
        """
        # gt_polygons -> gt_masks
        if 'gt_polygons' in results:
            # PolygonMasks takes one list of polygons per instance, so each
            # MMOCR polygon is wrapped in its own single-element list.
            instance_polygons = [
                [polygon] for polygon in results.pop('gt_polygons')
            ]
            # Forward only the first two entries (height, width) so that a
            # shape carrying an extra trailing channel entry does not break
            # the PolygonMasks(masks, height, width) call.
            height, width = results['img_shape'][:2]
            gt_masks = PolygonMasks(instance_polygons, height, width)

            if self.poly2mask:
                gt_masks = gt_masks.to_bitmap()

            results['gt_masks'] = gt_masks

        # gt_ignored -> gt_ignore_flags
        if 'gt_ignored' in results:
            results['gt_ignore_flags'] = results.pop('gt_ignored')

        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(poly2mask = {self.poly2mask})'
        return repr_str