Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import cv2 | |
import numpy as np | |
import pyclipper | |
from mmdet.core import BitmapMasks | |
from mmdet.datasets.builder import PIPELINES | |
from shapely.geometry import Polygon | |
from . import BaseTextDetTargets | |
class DBNetTargets(BaseTextDetTargets): | |
"""Generate gt shrunk text, gt threshold map, and their effective region | |
masks to learn DBNet: Real-time Scene Text Detection with Differentiable | |
Binarization [https://arxiv.org/abs/1911.08947]. This was partially adapted | |
from https://github.com/MhLiao/DB. | |
Args: | |
shrink_ratio (float): The area shrunk ratio between text | |
kernels and their text masks. | |
thr_min (float): The minimum value of the threshold map. | |
thr_max (float): The maximum value of the threshold map. | |
min_short_size (int): The minimum size of polygon below which | |
the polygon is invalid. | |
""" | |
def __init__(self, | |
shrink_ratio=0.4, | |
thr_min=0.3, | |
thr_max=0.7, | |
min_short_size=8): | |
super().__init__() | |
self.shrink_ratio = shrink_ratio | |
self.thr_min = thr_min | |
self.thr_max = thr_max | |
self.min_short_size = min_short_size | |
def find_invalid(self, results): | |
"""Find invalid polygons. | |
Args: | |
results (dict): The dict containing gt_mask. | |
Returns: | |
ignore_tags (list[bool]): The indicators for ignoring polygons. | |
""" | |
texts = results['gt_masks'].masks | |
ignore_tags = [False] * len(texts) | |
for idx, text in enumerate(texts): | |
if self.invalid_polygon(text[0]): | |
ignore_tags[idx] = True | |
return ignore_tags | |
def invalid_polygon(self, poly): | |
"""Judge the input polygon is invalid or not. It is invalid if its area | |
smaller than 1 or the shorter side of its minimum bounding box smaller | |
than min_short_size. | |
Args: | |
poly (ndarray): The polygon boundary point sequence. | |
Returns: | |
True/False (bool): Whether the polygon is invalid. | |
""" | |
area = self.polygon_area(poly) | |
if abs(area) < 1: | |
return True | |
short_size = min(self.polygon_size(poly)) | |
if short_size < self.min_short_size: | |
return True | |
return False | |
def ignore_texts(self, results, ignore_tags): | |
"""Ignore gt masks and gt_labels while padding gt_masks_ignore in | |
results given ignore_tags. | |
Args: | |
results (dict): Result for one image. | |
ignore_tags (list[int]): Indicate whether to ignore its | |
corresponding ground truth text. | |
Returns: | |
results (dict): Results after filtering. | |
""" | |
flag_len = len(ignore_tags) | |
assert flag_len == len(results['gt_masks'].masks) | |
assert flag_len == len(results['gt_labels']) | |
results['gt_masks_ignore'].masks += [ | |
mask for i, mask in enumerate(results['gt_masks'].masks) | |
if ignore_tags[i] | |
] | |
results['gt_masks'].masks = [ | |
mask for i, mask in enumerate(results['gt_masks'].masks) | |
if not ignore_tags[i] | |
] | |
results['gt_labels'] = np.array([ | |
mask for i, mask in enumerate(results['gt_labels']) | |
if not ignore_tags[i] | |
]) | |
new_ignore_tags = [ignore for ignore in ignore_tags if not ignore] | |
return results, new_ignore_tags | |
def generate_thr_map(self, img_size, polygons): | |
"""Generate threshold map. | |
Args: | |
img_size (tuple(int)): The image size (h,w) | |
polygons (list(ndarray)): The polygon list. | |
Returns: | |
thr_map (ndarray): The generated threshold map. | |
thr_mask (ndarray): The effective mask of threshold map. | |
""" | |
thr_map = np.zeros(img_size, dtype=np.float32) | |
thr_mask = np.zeros(img_size, dtype=np.uint8) | |
for polygon in polygons: | |
self.draw_border_map(polygon[0], thr_map, mask=thr_mask) | |
thr_map = thr_map * (self.thr_max - self.thr_min) + self.thr_min | |
return thr_map, thr_mask | |
def draw_border_map(self, polygon, canvas, mask): | |
"""Generate threshold map for one polygon. | |
Args: | |
polygon(ndarray): The polygon boundary ndarray. | |
canvas(ndarray): The generated threshold map. | |
mask(ndarray): The generated threshold mask. | |
""" | |
polygon = polygon.reshape(-1, 2) | |
assert polygon.ndim == 2 | |
assert polygon.shape[1] == 2 | |
polygon_shape = Polygon(polygon) | |
distance = ( | |
polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / | |
polygon_shape.length) | |
subject = [tuple(p) for p in polygon] | |
padding = pyclipper.PyclipperOffset() | |
padding.AddPath(subject, pyclipper.JT_ROUND, | |
pyclipper.ET_CLOSEDPOLYGON) | |
padded_polygon = padding.Execute(distance) | |
if len(padded_polygon) > 0: | |
padded_polygon = np.array(padded_polygon[0]) | |
else: | |
print(f'padding {polygon} with {distance} gets {padded_polygon}') | |
padded_polygon = polygon.copy().astype(np.int32) | |
x_min = padded_polygon[:, 0].min() | |
x_max = padded_polygon[:, 0].max() | |
y_min = padded_polygon[:, 1].min() | |
y_max = padded_polygon[:, 1].max() | |
width = x_max - x_min + 1 | |
height = y_max - y_min + 1 | |
polygon[:, 0] = polygon[:, 0] - x_min | |
polygon[:, 1] = polygon[:, 1] - y_min | |
xs = np.broadcast_to( | |
np.linspace(0, width - 1, num=width).reshape(1, width), | |
(height, width)) | |
ys = np.broadcast_to( | |
np.linspace(0, height - 1, num=height).reshape(height, 1), | |
(height, width)) | |
distance_map = np.zeros((polygon.shape[0], height, width), | |
dtype=np.float32) | |
for i in range(polygon.shape[0]): | |
j = (i + 1) % polygon.shape[0] | |
absolute_distance = self.point2line(xs, ys, polygon[i], polygon[j]) | |
distance_map[i] = np.clip(absolute_distance / distance, 0, 1) | |
distance_map = distance_map.min(axis=0) | |
x_min_valid = min(max(0, x_min), canvas.shape[1] - 1) | |
x_max_valid = min(max(0, x_max), canvas.shape[1] - 1) | |
y_min_valid = min(max(0, y_min), canvas.shape[0] - 1) | |
y_max_valid = min(max(0, y_max), canvas.shape[0] - 1) | |
if x_min_valid - x_min >= width or y_min_valid - y_min >= height: | |
return | |
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) | |
canvas[y_min_valid:y_max_valid + 1, | |
x_min_valid:x_max_valid + 1] = np.fmax( | |
1 - distance_map[y_min_valid - y_min:y_max_valid - y_max + | |
height, x_min_valid - x_min:x_max_valid - | |
x_max + width], | |
canvas[y_min_valid:y_max_valid + 1, | |
x_min_valid:x_max_valid + 1]) | |
def generate_targets(self, results): | |
"""Generate the gt targets for DBNet. | |
Args: | |
results (dict): The input result dictionary. | |
Returns: | |
results (dict): The output result dictionary. | |
""" | |
assert isinstance(results, dict) | |
if 'bbox_fields' in results: | |
results['bbox_fields'].clear() | |
ignore_tags = self.find_invalid(results) | |
results, ignore_tags = self.ignore_texts(results, ignore_tags) | |
h, w, _ = results['img_shape'] | |
polygons = results['gt_masks'].masks | |
# generate gt_shrink_kernel | |
gt_shrink, ignore_tags = self.generate_kernels((h, w), | |
polygons, | |
self.shrink_ratio, | |
ignore_tags=ignore_tags) | |
results, ignore_tags = self.ignore_texts(results, ignore_tags) | |
# genenrate gt_shrink_mask | |
polygons_ignore = results['gt_masks_ignore'].masks | |
gt_shrink_mask = self.generate_effective_mask((h, w), polygons_ignore) | |
# generate gt_threshold and gt_threshold_mask | |
polygons = results['gt_masks'].masks | |
gt_thr, gt_thr_mask = self.generate_thr_map((h, w), polygons) | |
results['mask_fields'].clear() # rm gt_masks encoded by polygons | |
results.pop('gt_labels', None) | |
results.pop('gt_masks', None) | |
results.pop('gt_bboxes', None) | |
results.pop('gt_bboxes_ignore', None) | |
mapping = { | |
'gt_shrink': gt_shrink, | |
'gt_shrink_mask': gt_shrink_mask, | |
'gt_thr': gt_thr, | |
'gt_thr_mask': gt_thr_mask | |
} | |
for key, value in mapping.items(): | |
value = value if isinstance(value, list) else [value] | |
results[key] = BitmapMasks(value, h, w) | |
results['mask_fields'].append(key) | |
return results | |