OMG_Seg / seg /datasets /pipeliens /transforms.py
HarborYuan's picture
add omg code
b34d1d6
raw
history blame
No virus
5.07 kB
import numpy as np
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import autocast_box_type
from mmcv.image.geometric import _scale_size
from mmcv.transforms import Resize as MMCV_Resize
from mmdet.datasets.transforms import Resize as MMDET_Resize
@TRANSFORMS.register_module()
class ResizeImage(MMCV_Resize):
    """Resize the image only.

    This transform resizes the input image according to ``scale`` or
    ``scale_factor``. Unlike the full ``Resize`` transforms, only the
    inherited ``_resize_img`` step is executed: annotations such as
    ``gt_bboxes``, ``gt_masks`` and ``gt_seg_map`` are left untouched and no
    homography matrix is recorded by this class.

    If ``scale`` and ``scale_factor`` are both set, ``scale`` is used.

    Required Keys:

    - img

    Modified / Added Keys:

    - scale (written directly by :meth:`transform`)
    - any key written by the inherited ``_resize_img`` (in MMCV this is
      expected to be ``img``, ``img_shape``, ``scale_factor`` and
      ``keep_ratio`` -- confirm against the installed MMCV version)

    Args:
        scale (int or tuple): Images scales for resizing. Defaults to None
        scale_factor (float or tuple[float]): Scale factors for resizing.
            Defaults to None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Defaults to False.
        clip_object_border (bool): Accepted by the parent class and reported
            in ``repr``. This transform never resizes boxes itself, so the
            flag is presumably without effect here. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Resize ``results['img']`` and record the target scale.

        Args:
            results (dict): Result dict from loading pipeline. Must contain
                ``img``.

        Returns:
            dict: The same dict, with ``scale`` set and the image-related
            keys updated by the inherited ``_resize_img``. Bounding boxes,
            masks and segmentation maps are not modified.
        """
        if self.scale:
            results['scale'] = self.scale
        else:
            # No explicit target size: derive it from the current image size.
            # ``shape[:2]`` is (h, w); ``_scale_size`` is fed (w, h).
            img_shape = results['img'].shape[:2]
            results['scale'] = _scale_size(img_shape[::-1], self.scale_factor)
        self._resize_img(results)
        return results

    def __repr__(self) -> str:
        # A single opening parenthesis after the class name and a single
        # closing one at the very end keep the representation balanced.
        return (f'{self.__class__.__name__}('
                f'scale={self.scale}, '
                f'scale_factor={self.scale_factor}, '
                f'keep_ratio={self.keep_ratio}, '
                f'clip_object_border={self.clip_object_border}, '
                f'backend={self.backend}, '
                f'interpolation={self.interpolation})')
@TRANSFORMS.register_module()
class ResizeSAM(MMDET_Resize):
    """Resize transform that also rescales ``gt_point_coords``.

    Performs the image, bbox, mask and seg-map resize steps inherited from
    the MMDetection ``Resize`` transform, then rescales the optional
    ``gt_point_coords`` entry with the same scale factor, and finally records
    the homography matrix.

    Constructor arguments are those of the parent ``Resize`` class.
    """

    def _resize_point_coords(self, results: dict) -> None:
        """Scale ``results['gt_point_coords']`` and clip it to the image.

        Does nothing when the key is missing or ``None``. Must run after
        ``_resize_img`` because it reads ``results['scale_factor']`` and
        ``results['img_shape']``.

        Args:
            results (dict): Result dict; ``gt_point_coords`` is replaced by a
                new, rescaled array when present.
        """
        points = results.get('gt_point_coords', None)
        if points is None:
            return

        # Multiplication creates a new array, so the caller's original
        # coordinates are not modified in place.
        points = points * results['scale_factor']
        results['gt_point_coords'] = points

        # Index 0 of the last axis is clipped against img_shape[1] and
        # index 1 against img_shape[0], i.e. points are treated as (x, y)
        # while img_shape is (height, width). The upper bound is inclusive.
        points[..., 0] = np.clip(points[..., 0], 0, results['img_shape'][1])
        points[..., 1] = np.clip(points[..., 1], 0, results['img_shape'][0])

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Resize image, bounding boxes, masks, seg map and point coords.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same dict with ``scale`` set, the image and its
            annotations (including ``gt_point_coords`` when present) resized,
            and the homography matrix recorded by the parent class helper.
        """
        if self.scale:
            results['scale'] = self.scale
        else:
            # No explicit target size: derive it from the current image size.
            # ``shape[:2]`` is (h, w); ``_scale_size`` is fed (w, h).
            img_shape = results['img'].shape[:2]
            results['scale'] = _scale_size(img_shape[::-1], self.scale_factor)

        # The image is resized first; the point-coordinate step below reads
        # the ``scale_factor`` and ``img_shape`` entries from ``results``.
        self._resize_img(results)
        self._resize_bboxes(results)
        self._resize_masks(results)
        self._resize_seg(results)
        self._resize_point_coords(results)
        self._record_homography_matrix(results)
        return results

    def __repr__(self) -> str:
        # A single opening parenthesis after the class name and a single
        # closing one at the very end keep the representation balanced.
        return (f'{self.__class__.__name__}('
                f'scale={self.scale}, '
                f'scale_factor={self.scale_factor}, '
                f'keep_ratio={self.keep_ratio}, '
                f'clip_object_border={self.clip_object_border}, '
                f'backend={self.backend}, '
                f'interpolation={self.interpolation})')