Spaces:
Running
Running
File size: 12,225 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
from mmdet.models.utils import multi_apply
from shapely.geometry import Polygon
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import offset_polygon
from mmocr.utils.typing_utils import ArrayLike
from .seg_based_module_loss import SegBasedModuleLoss
@MODELS.register_module()
class DBModuleLoss(SegBasedModuleLoss):
    r"""The class for implementing DBNet loss.

    This is partially adapted from https://github.com/MhLiao/DB.

    Args:
        loss_prob (dict): The loss config for probability map. Defaults to
            dict(type='MaskedBalancedBCEWithLogitsLoss').
        loss_thr (dict): The loss config for threshold map. Defaults to
            dict(type='MaskedSmoothL1Loss', beta=0).
        loss_db (dict): The loss config for binary map. Defaults to
            dict(type='MaskedDiceLoss').
        weight_prob (float): The weight of probability map loss.
            Denoted as :math:`\alpha` in paper. Defaults to 5.
        weight_thr (float): The weight of threshold map loss.
            Denoted as :math:`\beta` in paper. Defaults to 10.
        shrink_ratio (float): The ratio of shrunk text region. Defaults to 0.4.
        thr_min (float): The minimum threshold map value. Defaults to 0.3.
        thr_max (float): The maximum threshold map value. Defaults to 0.7.
        min_sidelength (int or float): The minimum sidelength of the
            minimum rotated rectangle around any text region. Defaults to 8.
    """

    def __init__(self,
                 loss_prob: Dict = dict(
                     type='MaskedBalancedBCEWithLogitsLoss'),
                 loss_thr: Dict = dict(type='MaskedSmoothL1Loss', beta=0),
                 loss_db: Dict = dict(type='MaskedDiceLoss'),
                 weight_prob: float = 5.,
                 weight_thr: float = 10.,
                 shrink_ratio: float = 0.4,
                 thr_min: float = 0.3,
                 thr_max: float = 0.7,
                 min_sidelength: Union[int, float] = 8) -> None:
        super().__init__()
        # Each sub-loss is instantiated from its config via the registry.
        self.loss_prob = MODELS.build(loss_prob)
        self.loss_thr = MODELS.build(loss_thr)
        self.loss_db = MODELS.build(loss_db)
        self.weight_prob = weight_prob
        self.weight_thr = weight_thr
        self.shrink_ratio = shrink_ratio
        self.thr_min = thr_min
        self.thr_max = thr_max
        self.min_sidelength = min_sidelength

    def forward(self, preds: Tuple[Tensor, Tensor, Tensor],
                data_samples: Sequence[TextDetDataSample]) -> Dict:
        """Compute DBNet loss.

        Args:
            preds (tuple(tensor)): Raw predictions from model, containing
                ``prob_logits``, ``thr_map`` and ``binary_map``.
                Each is a tensor of shape :math:`(N, H, W)`.
            data_samples (list[TextDetDataSample]): The data samples.

        Returns:
            results(dict): The dict for dbnet losses with loss_prob, \
                loss_db and loss_thr. ``loss_prob`` and ``loss_thr`` are
                scaled by ``weight_prob`` and ``weight_thr``; ``loss_db`` is
                returned unweighted.
        """
        prob_logits, thr_map, binary_map = preds
        gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets(
            data_samples)
        # Targets are built on the CPU (from numpy); move them next to the
        # predictions they are compared with.
        gt_shrinks = gt_shrinks.to(prob_logits.device)
        gt_shrink_masks = gt_shrink_masks.to(prob_logits.device)
        gt_thrs = gt_thrs.to(thr_map.device)
        gt_thr_masks = gt_thr_masks.to(thr_map.device)
        loss_prob = self.loss_prob(prob_logits, gt_shrinks, gt_shrink_masks)
        loss_thr = self.loss_thr(thr_map, gt_thrs, gt_thr_masks)
        # The binary map is supervised with the same shrunk-text target and
        # mask as the probability map.
        loss_db = self.loss_db(binary_map, gt_shrinks, gt_shrink_masks)
        results = dict(
            loss_prob=self.weight_prob * loss_prob,
            loss_thr=self.weight_thr * loss_thr,
            loss_db=loss_db)
        return results

    def _is_poly_invalid(self, poly: np.ndarray) -> bool:
        """Check if the input polygon is invalid or not.

        It is invalid if it has fewer than 3 vertices, its area is smaller
        than 1, or the shorter side of its minimum bounding box is smaller
        than ``min_sidelength``.

        Args:
            poly (ndarray): The polygon, either flattened or of shape (K, 2).

        Returns:
            bool: Whether the polygon is invalid.
        """
        poly = np.asarray(poly).reshape(-1, 2)
        # Fewer than 3 vertices cannot enclose any area, and shapely refuses
        # to build a polygon from them.
        if poly.shape[0] < 3:
            return True
        area = Polygon(poly).area
        if abs(area) < 1:
            return True
        # cv2.minAreaRect only accepts int32 / float32 point sets, so cast
        # explicitly instead of depending on the annotation's dtype.
        rect_size = cv2.minAreaRect(poly.astype(np.float32))[1]
        len_shortest_side = min(rect_size)
        return len_shortest_side < self.min_sidelength

    def _generate_thr_map(self, img_size: Tuple[int, int],
                          polygons: ArrayLike
                          ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate threshold map.

        Args:
            img_size (tuple(int)): The image size (h, w)
            polygons (Sequence[ndarray]): 2-d array, representing all the
                polygons of the text region.

        Returns:
            tuple:

            - thr_map (ndarray): The generated threshold map (float32),
              rescaled from [0, 1] to [thr_min, thr_max].
            - thr_mask (ndarray): The effective mask of threshold map
              (uint8, 1 inside any expanded polygon).
        """
        thr_map = np.zeros(img_size, dtype=np.float32)
        thr_mask = np.zeros(img_size, dtype=np.uint8)
        for polygon in polygons:
            self._draw_border_map(polygon, thr_map, mask=thr_mask)
        # Map the drawn [0, 1] border values into [thr_min, thr_max].
        thr_map = thr_map * (self.thr_max - self.thr_min) + self.thr_min
        return thr_map, thr_mask

    def _draw_border_map(self, polygon: np.ndarray, canvas: np.ndarray,
                         mask: np.ndarray) -> None:
        """Generate threshold map for one polygon.

        The polygon is expanded by ``area * (1 - shrink_ratio^2) / perimeter``.
        Inside the bounding box of the expanded polygon, each pixel receives
        ``1 - clip(d / distance, 0, 1)``, where ``d`` is the smallest
        pixel-to-edge distance from :meth:`_dist_points2line`. This value is
        merged into ``canvas`` with an element-wise maximum. The expanded
        polygon is also filled with 1 in ``mask``. ``canvas`` and ``mask``
        are modified in place; the input ``polygon`` is left untouched.

        Args:
            polygon (np.ndarray): The polygon, flattened or of shape (K, 2).
            canvas (np.ndarray): The generated threshold map (2-d). Updated
                in place.
            mask (np.ndarray): The generated threshold mask (2-d). Updated
                in place.
        """
        # Work on a private copy. The vertices are shifted into the bounding
        # box's local frame below, and ``reshape`` may return a view of the
        # caller's array, which would otherwise be silently modified.
        polygon = polygon.reshape(-1, 2).copy()
        polygon_obj = Polygon(polygon)
        distance = (
            polygon_obj.area * (1 - np.power(self.shrink_ratio, 2)) /
            polygon_obj.length)
        expanded_polygon = offset_polygon(polygon, distance)
        if len(expanded_polygon) == 0:
            # offset_polygon produced no points; fall back to the original
            # outline.
            print(f'Padding {polygon} with {distance} gets {expanded_polygon}')
            expanded_polygon = polygon.copy().astype(np.int32)
        else:
            expanded_polygon = expanded_polygon.reshape(-1, 2).astype(np.int32)

        x_min = expanded_polygon[:, 0].min()
        x_max = expanded_polygon[:, 0].max()
        y_min = expanded_polygon[:, 1].min()
        y_max = expanded_polygon[:, 1].max()

        # Nothing to draw if the expanded bounding box lies completely outside
        # the canvas. This case must be rejected explicitly on all four
        # sides: clamping a box that is right of / below the canvas would map
        # it onto the border column/row and write an unrelated value there.
        if (x_max < 0 or y_max < 0 or x_min > canvas.shape[1] - 1
                or y_min > canvas.shape[0] - 1):
            return

        width = x_max - x_min + 1
        height = y_max - y_min + 1

        # Shift the vertices into the bounding box's local frame.
        polygon[:, 0] = polygon[:, 0] - x_min
        polygon[:, 1] = polygon[:, 1] - y_min

        # Pixel coordinate grids of shape (height, width).
        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width),
            (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1),
            (height, width))

        # Normalized (clipped to [0, 1]) distance of every pixel to every
        # edge; keep the per-pixel minimum over all edges.
        distance_map = np.zeros((polygon.shape[0], height, width),
                                dtype=np.float32)
        for i in range(polygon.shape[0]):
            j = (i + 1) % polygon.shape[0]
            absolute_distance = self._dist_points2line(xs, ys, polygon[i],
                                                       polygon[j])
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # Clamp the bounding box to the canvas. The slices into distance_map
        # below pick the matching (possibly cropped) local window.
        x_min_valid = min(max(0, x_min), canvas.shape[1] - 1)
        x_max_valid = min(max(0, x_max), canvas.shape[1] - 1)
        y_min_valid = min(max(0, y_min), canvas.shape[0] - 1)
        y_max_valid = min(max(0, y_max), canvas.shape[0] - 1)
        cv2.fillPoly(mask, [expanded_polygon.astype(np.int32)], 1.0)
        canvas[y_min_valid:y_max_valid + 1,
               x_min_valid:x_max_valid + 1] = np.fmax(
                   1 - distance_map[y_min_valid - y_min:y_max_valid - y_max +
                                    height, x_min_valid - x_min:x_max_valid -
                                    x_max + width],
                   canvas[y_min_valid:y_max_valid + 1,
                          x_min_valid:x_max_valid + 1])

    def get_targets(self, data_samples: List[TextDetDataSample]) -> Tuple:
        """Generate loss targets from data samples.

        Args:
            data_samples (list(TextDetDataSample)): Ground truth data samples.

        Returns:
            tuple: A tuple of four tensors as DBNet targets (shrink map,
            shrink mask, threshold map, threshold mask), each obtained by
            concatenating the per-sample targets along dim 0.
        """
        gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = multi_apply(
            self._get_target_single, data_samples)
        gt_shrinks = torch.cat(gt_shrinks)
        gt_shrink_masks = torch.cat(gt_shrink_masks)
        gt_thrs = torch.cat(gt_thrs)
        gt_thr_masks = torch.cat(gt_thr_masks)
        return gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks

    def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple:
        """Generate loss target from a data sample.

        Args:
            data_sample (TextDetDataSample): The data sample.

        Returns:
            tuple: A tuple of four float tensors as the targets of one
            prediction (shrink map, shrink mask, threshold map, threshold
            mask), each with a leading dimension of size 1 added.
        """
        gt_instances = data_sample.gt_instances
        # NOTE(review): ``ignore_flags`` is the same object as
        # ``gt_instances.ignored`` (no copy is made), so flagging invalid
        # polygons here also marks them as ignored in the data sample.
        ignore_flags = gt_instances.ignored
        for idx, polygon in enumerate(gt_instances.polygons):
            if self._is_poly_invalid(polygon):
                ignore_flags[idx] = True
        gt_shrink, ignore_flags = self._generate_kernels(
            data_sample.img_shape,
            gt_instances.polygons,
            self.shrink_ratio,
            ignore_flags=ignore_flags)

        # Get boolean mask where Trues indicate text instance pixels
        gt_shrink = gt_shrink > 0

        # Ignored instances are excluded via the effective mask; only the
        # remaining ones contribute to the threshold map.
        gt_shrink_mask = self._generate_effective_mask(
            data_sample.img_shape, gt_instances[ignore_flags].polygons)
        gt_thr, gt_thr_mask = self._generate_thr_map(
            data_sample.img_shape, gt_instances[~ignore_flags].polygons)

        # to_tensor
        gt_shrink = torch.from_numpy(gt_shrink).unsqueeze(0).float()
        gt_shrink_mask = torch.from_numpy(gt_shrink_mask).unsqueeze(0).float()
        gt_thr = torch.from_numpy(gt_thr).unsqueeze(0).float()
        gt_thr_mask = torch.from_numpy(gt_thr_mask).unsqueeze(0).float()
        return gt_shrink, gt_shrink_mask, gt_thr, gt_thr_mask

    @staticmethod
    def _dist_points2line(xs: np.ndarray, ys: np.ndarray, pt1: np.ndarray,
                          pt2: np.ndarray) -> np.ndarray:
        """Compute distances from points to a line. This is adapted from
        https://github.com/MhLiao/DB.

        The computation is element-wise, so ``xs`` and ``ys`` may have any
        (matching) shape; :meth:`_draw_border_map` passes ``(H, W)`` grids.
        Following the reference implementation, the perpendicular distance
        to the line through ``pt1``/``pt2`` is used when the angle at the
        query point is obtuse. Otherwise the distance to the nearer endpoint
        is used.

        Args:
            xs (ndarray): The x coordinates of the query points.
            ys (ndarray): The y coordinates of the query points, same shape
                as ``xs``.
            pt1 (ndarray): The first point on the line of size :math:`(2, )`.
            pt2 (ndarray): The second point on the line of size :math:`(2, )`.

        Returns:
            ndarray: The distances, with the same shape as ``xs``.
        """
        # suppose a triangle with three edge abc with c=point_1 point_2
        # a^2
        a_square = np.square(xs - pt1[0]) + np.square(ys - pt1[1])
        # b^2
        b_square = np.square(xs - pt2[0]) + np.square(ys - pt2[1])
        # c^2
        c_square = np.square(pt1[0] - pt2[0]) + np.square(pt1[1] - pt2[1])
        # -cosC=(c^2-a^2-b^2)/2(ab)
        neg_cos_c = (
            (c_square - a_square - b_square) /
            (np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square)))
        # clip -cosC value to [-1, 1]
        neg_cos_c = np.clip(neg_cos_c, -1.0, 1.0)
        # sinC^2=1-cosC^2
        square_sin = 1 - np.square(neg_cos_c)
        square_sin = np.nan_to_num(square_sin)
        # distance=a*b*sinC/c=a*h/c=2*area/c
        result = np.sqrt(a_square * b_square * square_sin /
                         (np.finfo(np.float32).eps + c_square))
        # set result to minimum edge if C<pi/2
        result[neg_cos_c < 0] = np.sqrt(np.fmin(a_square,
                                                b_square))[neg_cos_c < 0]
        return result
|