# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, Tuple

import cv2
import mmcv
import numpy as np
from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import avoid_cache_randomness, cache_randomness

from mmocr.registry import TRANSFORMS
from mmocr.utils import (bbox2poly, crop_polygon, is_poly_inside_rect,
                         poly2bbox, poly2shapely, poly_make_valid,
                         remove_pipeline_elements, rescale_polygon,
                         shapely2poly)
from .wrappers import ImgAugWrapper


@TRANSFORMS.register_module()
@avoid_cache_randomness
class RandomCrop(BaseTransform):
    """Randomly crop images and make sure to contain at least one intact
    instance.

    Required Keys:

    - img
    - gt_polygons
    - gt_bboxes
    - gt_bboxes_labels
    - gt_ignored
    - gt_texts (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_polygons
    - gt_bboxes
    - gt_bboxes_labels
    - gt_ignored
    - gt_texts (optional)

    Args:
        min_side_ratio (float): The ratio of the shortest edge of the cropped
            image to the original image size. Must lie in [0, 1].
            Defaults to 0.4.

    Raises:
        ValueError: If ``min_side_ratio`` is outside [0, 1].
    """

    def __init__(self, min_side_ratio: float = 0.4) -> None:
        if not 0. <= min_side_ratio <= 1.:
            raise ValueError('`min_side_ratio` should be in range [0, 1],')
        self.min_side_ratio = min_side_ratio

    def _sample_valid_start_end(self, valid_array: np.ndarray, min_len: int,
                                max_start_idx: int,
                                min_end_idx: int) -> Tuple[int, int]:
        """Sample a start and end idx on a given axis that contains at least
        one polygon. There should be at least one intact polygon bounded by
        max_start_idx and min_end_idx.

        Args:
            valid_array (ndarray): A 0-1 mask 1D array indicating valid
                regions on the axis. 0 indicates text regions which are
                not allowed to be sampled from.
            min_len (int): Minimum distance between two start and end points.
            max_start_idx (int): The maximum start index.
            min_end_idx (int): The minimum end index.

        Returns:
            tuple(int, int): Start and end index on a given axis, where
            0 <= start < max_start_idx and
            min_end_idx <= end < len(valid_array).
        """
        assert isinstance(min_len, int)
        assert len(valid_array) > min_len

        # --- sample the start index ---
        start_array = valid_array.copy()
        max_start_idx = min(len(start_array) - min_len, max_start_idx)
        # Nothing at or beyond ``max_start_idx`` may be chosen as a start;
        # index 0 is always allowed so that at least one region exists.
        start_array[max_start_idx:] = 0
        start_array[0] = 1
        # A negative step marks the first index of a run of ones, a positive
        # step marks the index right after the run (exclusive end).
        diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
        region_starts = np.where(diff_array < 0)[0]
        region_ends = np.where(diff_array > 0)[0]
        region_ind = np.random.randint(0, len(region_starts))
        start = np.random.randint(region_starts[region_ind],
                                  region_ends[region_ind])

        # --- sample the end index ---
        end_array = valid_array.copy()
        min_end_idx = max(start + min_len, min_end_idx)
        # Nothing before ``min_end_idx`` may be chosen as an end; the last
        # index is always allowed so that at least one region exists.
        end_array[:min_end_idx] = 0
        end_array[-1] = 1
        diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
        region_starts = np.where(diff_array < 0)[0]
        region_ends = np.where(diff_array > 0)[0]
        region_ind = np.random.randint(0, len(region_starts))
        # Note that end index will never be region_ends[region_ind]
        # (``randint`` excludes its upper bound) and therefore the end index
        # is always a valid index of ``valid_array``.
        end = np.random.randint(region_starts[region_ind],
                                region_ends[region_ind])
        return start, end

    def _sample_crop_box(self, img_size: Tuple[int, int],
                         results: Dict) -> np.ndarray:
        """Generate crop box which only contains intact polygon instances with
        the number >= 1.

        Args:
            img_size (tuple(int, int)): The image size (h, w). Extra trailing
                entries (e.g. the channel count) are ignored.
            results (dict): The results dict. ``results['gt_polygons']`` is
                read and must be non-empty.

        Returns:
            ndarray: Crop area in shape (4, ), ordered as (x1, y1, x2, y2).
        """
        assert isinstance(img_size, tuple)
        h, w = img_size[:2]

        # Crop box can be represented by any integer numbers in
        # range [0, w] and [0, h]
        x_valid_array = np.ones(w + 1, dtype=np.int32)
        y_valid_array = np.ones(h + 1, dtype=np.int32)

        polygons = results['gt_polygons']

        # Randomly select a polygon that must be inside
        # the cropped region
        kept_poly_idx = np.random.randint(0, len(polygons))
        for i, polygon in enumerate(polygons):
            polygon = polygon.reshape((-1, 2))

            clip_x = np.clip(polygon[:, 0], 0, w)
            clip_y = np.clip(polygon[:, 1], 0, h)
            min_x = np.floor(np.min(clip_x)).astype(np.int32)
            min_y = np.floor(np.min(clip_y)).astype(np.int32)
            max_x = np.ceil(np.max(clip_x)).astype(np.int32)
            max_y = np.ceil(np.max(clip_y)).astype(np.int32)

            # Crop borders must not cut through any polygon's extent.
            x_valid_array[min_x:max_x] = 0
            y_valid_array[min_y:max_y] = 0

            if i == kept_poly_idx:
                max_x_start = min_x
                min_x_end = max_x
                max_y_start = min_y
                min_y_end = max_y

        min_w = int(w * self.min_side_ratio)
        min_h = int(h * self.min_side_ratio)

        x1, x2 = self._sample_valid_start_end(x_valid_array, min_w,
                                              max_x_start, min_x_end)
        y1, y2 = self._sample_valid_start_end(y_valid_array, min_h,
                                              max_y_start, min_y_end)

        return np.array([x1, y1, x2, y2])

    def _crop_img(self, img: np.ndarray, bbox: np.ndarray) -> np.ndarray:
        """Crop image given a bbox region.

        Args:
            img (ndarray): Image with three dimensions (h, w, c).
            bbox (ndarray): Cropping region in shape (4, ), ordered as
                (x1, y1, x2, y2).

        Returns:
            ndarray: Cropped image.
        """
        assert img.ndim == 3
        h, w, _ = img.shape
        assert 0 <= bbox[1] < bbox[3] <= h
        assert 0 <= bbox[0] < bbox[2] <= w
        return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]

    def transform(self, results: Dict) -> Dict:
        """Applying random crop on results.

        Polygons that are not fully inside the crop are dropped together
        with their labels, ignore flags and texts; bboxes are recomputed
        from the surviving polygons. If there is no polygon at all, the
        results are returned untouched.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict: The transformed data.
        """
        if len(results['gt_polygons']) < 1:
            return results

        crop_box = self._sample_crop_box(results['img'].shape, results)
        img = self._crop_img(results['img'], crop_box)
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        crop_x = crop_box[0]
        crop_y = crop_box[1]
        crop_w = crop_box[2] - crop_box[0]
        crop_h = crop_box[3] - crop_box[1]
        labels = results['gt_bboxes_labels']
        valid_labels = []
        ignored = results['gt_ignored']
        valid_ignored = []
        if 'gt_texts' in results:
            valid_texts = []
            texts = results['gt_texts']

        polys = results['gt_polygons']
        valid_polys = []
        for idx, poly in enumerate(polys):
            # Shift the polygon into the crop's coordinate system.
            poly = poly.reshape(-1, 2)
            poly = (poly - (crop_x, crop_y)).flatten()
            if is_poly_inside_rect(poly, [0, 0, crop_w, crop_h]):
                valid_polys.append(poly)
                valid_labels.append(labels[idx])
                valid_ignored.append(ignored[idx])
                if 'gt_texts' in results:
                    valid_texts.append(texts[idx])
        results['gt_polygons'] = valid_polys
        results['gt_bboxes_labels'] = np.array(valid_labels, dtype=np.int64)
        results['gt_ignored'] = np.array(valid_ignored, dtype=bool)
        if 'gt_texts' in results:
            results['gt_texts'] = valid_texts
        valid_bboxes = [poly2bbox(poly) for poly in results['gt_polygons']]
        results['gt_bboxes'] = np.array(valid_bboxes).astype(
            np.float32).reshape(-1, 4)

        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(min_side_ratio = {self.min_side_ratio})'
        return repr_str


@TRANSFORMS.register_module()
class RandomRotate(BaseTransform):
    """Randomly rotate the image, boxes, and polygons. For recognition task,
    only the image will be rotated. If set ``use_canvas`` as True, the shape
    of rotated image might be modified based on the rotated angle size,
    otherwise, the image will keep the shape before rotation.

    Required Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_polygons (optional)

    Modified Keys:

    - img
    - img_shape (optional)
    - gt_bboxes (optional)
    - gt_polygons (optional)

    Added Keys:

    - rotated_angle (only written when ``use_canvas`` is True)

    Args:
        max_angle (int): The maximum rotation angle (can be bigger than 180 or
            a negative). Defaults to 10.
        pad_with_fixed_color (bool): The flag for whether to pad rotated
            image with fixed value. Defaults to False.
        pad_value (tuple[int, int, int]): The color value for padding rotated
            image. Defaults to (0, 0, 0).
        use_canvas (bool): Whether to create a canvas for rotated image.
            Defaults to False. If set true, the image shape may be modified.

    Note:
        ``pad_with_fixed_color`` and ``pad_value`` are only read by the
        ``use_canvas=True`` code path.

    Raises:
        TypeError: If ``max_angle`` is not an int, ``pad_with_fixed_color``
            is not a bool, or ``pad_value`` is not a list/tuple.
        ValueError: If ``pad_value`` does not contain exactly three integers.
    """

    def __init__(
        self,
        max_angle: int = 10,
        pad_with_fixed_color: bool = False,
        pad_value: Tuple[int, int, int] = (0, 0, 0),
        use_canvas: bool = False,
    ) -> None:
        if not isinstance(max_angle, int):
            raise TypeError('`max_angle` should be an integer'
                            f', but got {type(max_angle)} instead')
        if not isinstance(pad_with_fixed_color, bool):
            raise TypeError('`pad_with_fixed_color` should be a bool, '
                            f'but got {type(pad_with_fixed_color)} instead')
        if not isinstance(pad_value, (list, tuple)):
            raise TypeError('`pad_value` should be a list or tuple, '
                            f'but got {type(pad_value)} instead')
        if len(pad_value) != 3:
            raise ValueError('`pad_value` should contain three integers')
        if not isinstance(pad_value[0], int) or not isinstance(
                pad_value[1], int) or not isinstance(pad_value[2], int):
            raise ValueError('`pad_value` should contain three integers')

        self.max_angle = max_angle
        self.pad_with_fixed_color = pad_with_fixed_color
        self.pad_value = pad_value
        self.use_canvas = use_canvas

    @cache_randomness
    def _sample_angle(self, max_angle: int) -> float:
        """Sampling a random angle for rotation.

        Args:
            max_angle (int): Maximum rotation angle

        Returns:
            float: The random angle used for rotation, drawn uniformly from
            [-max_angle, max_angle).
        """
        angle = np.random.random_sample() * 2 * max_angle - max_angle
        return angle

    @staticmethod
    def _cal_canvas_size(ori_size: Tuple[int, int],
                         degree: float) -> Tuple[int, int]:
        """Calculate the canvas size.

        Args:
            ori_size (Tuple[int, int]): The original image size (height, width)
            degree (float): The rotation angle in degrees

        Returns:
            Tuple[int, int]: The size of the canvas as (height, width)
        """
        assert isinstance(ori_size, tuple)
        angle = degree * math.pi / 180.0
        h, w = ori_size[:2]

        cos = math.cos(angle)
        sin = math.sin(angle)
        canvas_h = int(w * math.fabs(sin) + h * math.fabs(cos))
        canvas_w = int(w * math.fabs(cos) + h * math.fabs(sin))

        canvas_size = (canvas_h, canvas_w)
        return canvas_size

    @staticmethod
    def _rotate_points(center: Tuple[float, float],
                       points: np.array,
                       theta: float,
                       center_shift: Tuple[int, int] = (0, 0)) -> np.array:
        """Rotating a set of points according to the given theta.

        Note:
            ``points`` is modified in place and the very same array is
            returned.

        Args:
            center (Tuple[float, float]): The coordinate of the canvas center
            points (np.array): A set of points needed to be rotated, stored
                flat as (x0, y0, x1, y1, ...)
            theta (float): Rotation angle in degrees
            center_shift (Tuple[int, int]): The shifting offset of the center
                coordinate

        Returns:
            np.array: The rotated coordinates of the input points
        """
        (center_x, center_y) = center
        # The y axis is flipped before rotating and flipped back afterwards.
        center_y = -center_y
        x, y = points[::2], points[1::2]
        y = -y

        theta = theta / 180 * math.pi
        cos = math.cos(theta)
        sin = math.sin(theta)

        x = (x - center_x)
        y = (y - center_y)

        _x = center_x + x * cos - y * sin + center_shift[0]
        _y = -(center_y + x * sin + y * cos) + center_shift[1]

        points[::2], points[1::2] = _x, _y
        return points

    def _rotate_img(self, results: Dict) -> Tuple[int, int]:
        """Rotating the input image based on the given angle. This function
        will modify the shape of the image if the canvas is larger or smaller
        than the original one.

        When ``pad_with_fixed_color`` is False, the area uncovered by the
        rotated image is filled with a randomly picked patch of the original
        image, resized to the canvas size.

        Args:
            results (dict): Result dict containing the data to transform.
                ``results['rotated_angle']`` must already be set.

        Returns:
            Tuple[int, int]: The shifting offset of the center point, as
            (x_shift, y_shift).

        Raises:
            ValueError: If ``results`` holds no ``img``.
        """
        if results.get('img', None) is None:
            raise ValueError('`img` is not found in results')

        img = results['img']
        h = img.shape[0]
        w = img.shape[1]
        angle = results['rotated_angle']
        rotation_matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)

        canvas_size = self._cal_canvas_size((h, w), angle)
        center_shift = (int(
            (canvas_size[1] - w) / 2), int((canvas_size[0] - h) / 2))
        # Move the rotated content to the centre of the (resized) canvas.
        rotation_matrix[0, 2] += center_shift[0]
        rotation_matrix[1, 2] += center_shift[1]
        # cv2 expects the output size as (width, height).
        dsize = (canvas_size[1], canvas_size[0])

        if self.pad_with_fixed_color:
            rotated_img = cv2.warpAffine(
                img,
                rotation_matrix,
                dsize,
                flags=cv2.INTER_NEAREST,
                borderValue=self.pad_value)
        else:
            mask = np.zeros_like(img)
            # Clamp to at least 1 so that very small images (fewer than 9
            # pixels on a side) neither hand ``randint`` an empty range nor
            # produce an empty patch that cannot be resized. For larger
            # images the sampled values are unchanged.
            h_ind = np.random.randint(0, max(h * 7 // 8, 1))
            w_ind = np.random.randint(0, max(w * 7 // 8, 1))
            img_cut = img[h_ind:(h_ind + max(h // 9, 1)),
                          w_ind:(w_ind + max(w // 9, 1))]
            img_cut = mmcv.imresize(img_cut, dsize)

            # ``mask`` becomes 1 exactly where the warp had to pad, i.e.
            # outside the rotated image.
            mask = cv2.warpAffine(
                mask, rotation_matrix, dsize, borderValue=[1, 1, 1])
            rotated_img = cv2.warpAffine(
                img, rotation_matrix, dsize, borderValue=[0, 0, 0])
            rotated_img = rotated_img + img_cut * mask

        results['img'] = rotated_img
        return center_shift

    def _rotate_bboxes(self, results: Dict, center_shift: Tuple[int,
                                                                int]) -> None:
        """Rotating the bounding boxes based on the given angle.

        Each box is converted to a polygon, rotated, and converted back to
        an axis-aligned box. Nothing happens if ``gt_bboxes`` is absent.

        Args:
            results (dict): Result dict containing the data to transform.
            center_shift (Tuple[int, int]): The shifting offset of the
                center point
        """
        if results.get('gt_bboxes', None) is not None:
            height, width = results['img_shape']
            box_list = []
            for box in results['gt_bboxes']:
                rotated_box = self._rotate_points((width / 2, height / 2),
                                                  bbox2poly(box),
                                                  results['rotated_angle'],
                                                  center_shift)
                rotated_box = poly2bbox(rotated_box)
                box_list.append(rotated_box)

            results['gt_bboxes'] = np.array(
                box_list, dtype=np.float32).reshape(-1, 4)

    def _rotate_polygons(self, results: Dict,
                         center_shift: Tuple[int, int]) -> None:
        """Rotating the polygons based on the given angle.

        Nothing happens if ``gt_polygons`` is absent. The polygon arrays
        are rotated in place (see ``_rotate_points``).

        Args:
            results (dict): Result dict containing the data to transform.
            center_shift (Tuple[int, int]): The shifting offset of the
                center point
        """
        if results.get('gt_polygons', None) is not None:
            height, width = results['img_shape']
            polygon_list = []
            for poly in results['gt_polygons']:
                rotated_poly = self._rotate_points(
                    (width / 2, height / 2), poly, results['rotated_angle'],
                    center_shift)
                polygon_list.append(rotated_poly)
            results['gt_polygons'] = polygon_list

    def transform(self, results: Dict) -> Dict:
        """Applying random rotate on results.

        With ``use_canvas=True`` the rotation is done by this class and
        ``rotated_angle`` / ``img_shape`` are written to ``results``.
        Otherwise the work is delegated to an ``ImgAugWrapper`` running an
        ``Affine`` augmenter with nearest-neighbour interpolation.

        Args:
            results (Dict): Result dict containing the data to transform.

        Returns:
            dict: The transformed data
        """
        # TODO rotate char_quads & char_rects for SegOCR
        if self.use_canvas:
            results['rotated_angle'] = self._sample_angle(self.max_angle)
            # rotate image
            center_shift = self._rotate_img(results)
            # rotate gt_bboxes
            self._rotate_bboxes(results, center_shift)
            # rotate gt_polygons
            self._rotate_polygons(results, center_shift)

            results['img_shape'] = (results['img'].shape[0],
                                    results['img'].shape[1])
        else:
            args = [
                dict(
                    cls='Affine',
                    rotate=[-self.max_angle, self.max_angle],
                    backend='cv2',
                    order=0)  # order=0 -> cv2.INTER_NEAREST
            ]
            imgaug_transform = ImgAugWrapper(args)
            results = imgaug_transform(results)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(max_angle = {self.max_angle}'
        repr_str += f', pad_with_fixed_color = {self.pad_with_fixed_color}'
        repr_str += f', pad_value = {self.pad_value}'
        repr_str += f', use_canvas = {self.use_canvas})'
        return repr_str


@TRANSFORMS.register_module()
class Resize(MMCV_Resize):
    """Resize image & bboxes & polygons.

    This transform resizes the input image according to ``scale`` or
    ``scale_factor``. Bboxes and polygons are then resized with the same
    scale factor. if ``scale`` and ``scale_factor`` are both set, it will use
    ``scale`` to resize.

    Required Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_polygons

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_polygons

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio

    Args:
        scale (int or tuple): Image scales for resizing. Defaults to None.
        scale_factor (float or tuple[float, float]): Scale factors for
            resizing. It's either a factor applicable to both dimensions or
            in the form of (scale_w, scale_h). Defaults to None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image. Defaults to False.
        clip_object_border (bool): Whether to clip the objects outside the
            border of the image. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generates slightly different results. Defaults
            to 'cv2'.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def _resize_img(self, results: dict) -> None:
        """Resize images with ``results['scale']``.

        If no image is provided, only resize ``results['img_shape']`` and
        record ``scale``, ``scale_factor`` and ``keep_ratio``.
        """
        if results.get('img', None) is not None:
            return super()._resize_img(results)
        h, w = results['img_shape']
        if self.keep_ratio:
            new_w, new_h = mmcv.rescale_size((w, h),
                                             results['scale'],
                                             return_scale=False)
        else:
            new_w, new_h = results['scale']
        w_scale = new_w / w
        h_scale = new_h / h
        results['img_shape'] = (new_h, new_w)
        results['scale'] = (new_w, new_h)
        results['scale_factor'] = (w_scale, h_scale)
        results['keep_ratio'] = self.keep_ratio

    def _resize_bboxes(self, results: dict) -> None:
        """Resize bounding boxes and cast them to ``float32``."""
        super()._resize_bboxes(results)
        if results.get('gt_bboxes', None) is not None:
            results['gt_bboxes'] = results['gt_bboxes'].astype(np.float32)

    def _resize_polygons(self, results: dict) -> None:
        """Resize polygons with ``results['scale_factor']``.

        When ``clip_object_border`` is set, each polygon is cropped to the
        image rectangle. A polygon for which ``crop_polygon`` returns
        ``None`` is replaced by an all-zero polygon of the same shape so the
        number of polygons stays unchanged.
        """
        if results.get('gt_polygons', None) is not None:
            polygons = results['gt_polygons']
            polygons_resize = []
            for idx, polygon in enumerate(polygons):
                polygon = rescale_polygon(polygon, results['scale_factor'])
                if self.clip_object_border:
                    crop_bbox = np.array([
                        0, 0, results['img_shape'][1], results['img_shape'][0]
                    ])
                    polygon = crop_polygon(polygon, crop_bbox)
                if polygon is not None:
                    polygons_resize.append(polygon.astype(np.float32))
                else:
                    polygons_resize.append(
                        np.zeros_like(polygons[idx], dtype=np.float32))
            results['gt_polygons'] = polygons_resize

    def transform(self, results: dict) -> dict:
        """Transform function to resize images, bounding boxes and polygons.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img', 'gt_bboxes', 'gt_polygons',
            'scale', 'scale_factor', 'height', 'width', and 'keep_ratio' keys
            are updated in result dict.
        """
        results = super().transform(results)
        self._resize_polygons(results)
        return results

    def __repr__(self):
        # Only the very last field closes the parenthesis opened after the
        # class name, so that the representation is balanced.
        repr_str = self.__class__.__name__
        repr_str += f'(scale={self.scale}, '
        repr_str += f'scale_factor={self.scale_factor}, '
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'clip_object_border={self.clip_object_border}, '
        repr_str += f'backend={self.backend}, '
        repr_str += f'interpolation={self.interpolation})'
        return repr_str


@TRANSFORMS.register_module()
class RemoveIgnored(BaseTransform):
    """Removed ignored elements from the pipeline.

    Required Keys:

    - gt_ignored
    - gt_polygons (optional)
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_texts (optional)

    Modified Keys:

    - gt_ignored
    - gt_polygons (optional)
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_texts (optional)
    """

    def transform(self, results: Dict) -> Dict:
        """Drop every instance flagged in ``results['gt_ignored']``.

        Args:
            results (dict): Result dict containing the data to transform.

        Returns:
            Optional[dict]: The filtered data, or None when every instance
            is ignored (which includes the case of no instance at all).
        """
        remove_inds = np.where(results['gt_ignored'])[0]
        if len(remove_inds) == len(results['gt_ignored']):
            return None
        return remove_pipeline_elements(results, remove_inds)


@TRANSFORMS.register_module()
class FixInvalidPolygon(BaseTransform):
    """Fix invalid polygons in the dataset.

    Required Keys:

    - gt_polygons
    - gt_ignored (optional)
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_texts (optional)

    Modified Keys:

    - gt_polygons
    - gt_ignored (optional)
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_texts (optional)

    Args:
        mode (str): The mode of fixing invalid polygons. Options are 'fix' and
            'ignore'.
            For the 'fix' mode, the transform will try to fix
            the invalid polygons to a valid one by eliminating the
            self-intersection or converting the bboxes to polygons. If
            it can't be fixed by any means (e.g. the polygon contains less
            than 3 points or it's actually a line/point), the annotation will
            be removed.
            For the 'ignore' mode, the invalid polygons
            will be set to "ignored" during training.
            Defaults to 'fix'.
        min_poly_points (int): Minimum number of the coordinate points in a
            polygon. Must be at least 3. Defaults to 4.
        fix_from_bbox (bool): Whether to convert the bboxes to polygons when
            the polygon is invalid and not directly fixable. Defaults to True.
    """

    def __init__(self,
                 mode: str = 'fix',
                 min_poly_points: int = 4,
                 fix_from_bbox: bool = True) -> None:
        super().__init__()
        self.mode = mode
        assert min_poly_points >= 3, (
            'min_poly_points must be greater than or equal to 3.')
        self.min_poly_points = min_poly_points
        self.fix_from_bbox = fix_from_bbox
        assert self.mode in [
            'fix', 'ignore'
        ], f"Supported modes are 'fix' and 'ignore', but got {self.mode}"

    def transform(self, results: Dict) -> Dict:
        """Fix invalid polygons.

        Args:
            results (dict): Result dict containing the data to transform.
                In 'ignore' mode ``results['gt_ignored']`` is read and
                updated in place.

        Returns:
            Optional[dict]: The transformed data. If all the polygons are
            unfixable, return None. None is also returned when
            ``gt_polygons`` is present but empty, because the number of
            removed polygons then equals the number of polygons.
        """
        if results.get('gt_polygons', None) is not None:
            remove_inds = []
            for idx, polygon in enumerate(results['gt_polygons']):
                if self.mode == 'ignore':
                    if results['gt_ignored'][idx]:
                        continue
                    # The length checks come first so that the shapely
                    # conversion is only attempted on well-formed input.
                    if not (len(polygon) >= self.min_poly_points * 2
                            and len(polygon) % 2
                            == 0) or not poly2shapely(polygon).is_valid:
                        results['gt_ignored'][idx] = True
                else:
                    # If "polygon" contains less than 3 points
                    if len(polygon) < 6:
                        remove_inds.append(idx)
                        continue
                    try:
                        shapely_polygon = poly2shapely(polygon)
                        if shapely_polygon.is_valid and len(
                                polygon) >= self.min_poly_points * 2:
                            continue
                        results['gt_polygons'][idx] = shapely2poly(
                            poly_make_valid(shapely_polygon))
                        # If an empty polygon is generated, it's still a bad
                        # fix
                        if len(results['gt_polygons'][idx]) == 0:
                            raise ValueError
                    # It's hard to fix, e.g. the "polygon" is a line or
                    # a point
                    except Exception:
                        if self.fix_from_bbox and 'gt_bboxes' in results:
                            bbox = results['gt_bboxes'][idx]
                            bbox_polygon = bbox2poly(bbox)
                            results['gt_polygons'][idx] = bbox_polygon
                            shapely_polygon = poly2shapely(bbox_polygon)
                            if (not shapely_polygon.is_valid
                                    or shapely_polygon.is_empty):
                                remove_inds.append(idx)
                        else:
                            remove_inds.append(idx)

            if len(remove_inds) == len(results['gt_polygons']):
                return None

            results = remove_pipeline_elements(results, remove_inds)

        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(mode = "{self.mode}", '
        repr_str += f'min_poly_points = {self.min_poly_points}, '
        repr_str += f'fix_from_bbox = {self.fix_from_bbox})'
        return repr_str