Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict | |
from mmcv.transforms.base import BaseTransform | |
from mmdet.structures.mask import PolygonMasks, bitmap_to_polygon | |
from mmocr.registry import TRANSFORMS | |
class MMDet2MMOCR(BaseTransform): | |
"""Convert transforms's data format from MMDet to MMOCR. | |
Required Keys: | |
- gt_masks (PolygonMasks | BitmapMasks) (optional) | |
- gt_ignore_flags (np.bool) (optional) | |
Added Keys: | |
- gt_polygons (list[np.ndarray]) | |
- gt_ignored (np.ndarray) | |
""" | |
def transform(self, results: Dict) -> Dict: | |
"""Convert MMDet's data format to MMOCR's data format. | |
Args: | |
results (Dict): Result dict containing the data to transform. | |
Returns: | |
(Dict): The transformed data. | |
""" | |
# gt_masks -> gt_polygons | |
if 'gt_masks' in results.keys(): | |
gt_polygons = [] | |
gt_masks = results.pop('gt_masks') | |
if len(gt_masks) > 0: | |
# PolygonMasks | |
if isinstance(gt_masks[0], PolygonMasks): | |
gt_polygons = [mask[0] for mask in gt_masks.masks] | |
# BitmapMasks | |
else: | |
polygons = [] | |
for mask in gt_masks.masks: | |
contours, _ = bitmap_to_polygon(mask) | |
polygons += [ | |
contour.reshape(-1) for contour in contours | |
] | |
# filter invalid polygons | |
gt_polygons = [] | |
for polygon in polygons: | |
if len(polygon) < 6: | |
continue | |
gt_polygons.append(polygon) | |
results['gt_polygons'] = gt_polygons | |
# gt_ignore_flags -> gt_ignored | |
if 'gt_ignore_flags' in results.keys(): | |
gt_ignored = results.pop('gt_ignore_flags') | |
results['gt_ignored'] = gt_ignored | |
return results | |
def __repr__(self) -> str: | |
repr_str = self.__class__.__name__ | |
return repr_str | |
class MMOCR2MMDet(BaseTransform): | |
"""Convert transforms's data format from MMOCR to MMDet. | |
Required Keys: | |
- img_shape | |
- gt_polygons (List[ndarray]) (optional) | |
- gt_ignored (np.bool) (optional) | |
Added Keys: | |
- gt_masks (PolygonMasks | BitmapMasks) (optional) | |
- gt_ignore_flags (np.bool) (optional) | |
Args: | |
poly2mask (bool): Whether to convert mask to bitmap. Default: True. | |
""" | |
def __init__(self, poly2mask: bool = False) -> None: | |
self.poly2mask = poly2mask | |
def transform(self, results: Dict) -> Dict: | |
"""Convert MMOCR's data format to MMDet's data format. | |
Args: | |
results (Dict): Result dict containing the data to transform. | |
Returns: | |
(Dict): The transformed data. | |
""" | |
# gt_polygons -> gt_masks | |
if 'gt_polygons' in results.keys(): | |
gt_polygons = results.pop('gt_polygons') | |
gt_polygons = [[gt_polygon] for gt_polygon in gt_polygons] | |
gt_masks = PolygonMasks(gt_polygons, *results['img_shape']) | |
if self.poly2mask: | |
gt_masks = gt_masks.to_bitmap() | |
results['gt_masks'] = gt_masks | |
# gt_ignore_flags -> gt_ignored | |
if 'gt_ignored' in results.keys(): | |
gt_ignored = results.pop('gt_ignored') | |
results['gt_ignore_flags'] = gt_ignored | |
return results | |
def __repr__(self) -> str: | |
repr_str = self.__class__.__name__ | |
repr_str += f'(poly2mask = {self.poly2mask})' | |
return repr_str | |