# Mountchicken's picture
# Upload 704 files
# 9bf4bd7
# raw
# history blame
# 12.2 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
from mmdet.models.utils import multi_apply
from shapely.geometry import Polygon
from torch import Tensor
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from mmocr.utils import offset_polygon
from mmocr.utils.typing_utils import ArrayLike
from .seg_based_module_loss import SegBasedModuleLoss
@MODELS.register_module()
class DBModuleLoss(SegBasedModuleLoss):
    r"""The class for implementing DBNet loss.

    This is partially adapted from https://github.com/MhLiao/DB.

    Args:
        loss_prob (dict): The loss config for probability map. Defaults to
            dict(type='MaskedBalancedBCEWithLogitsLoss').
        loss_thr (dict): The loss config for threshold map. Defaults to
            dict(type='MaskedSmoothL1Loss', beta=0).
        loss_db (dict): The loss config for binary map. Defaults to
            dict(type='MaskedDiceLoss').
        weight_prob (float): The weight of probability map loss.
            Denoted as :math:`\alpha` in paper. Defaults to 5.
        weight_thr (float): The weight of threshold map loss.
            Denoted as :math:`\beta` in paper. Defaults to 10.
        shrink_ratio (float): The ratio of shrunk text region. Defaults to 0.4.
        thr_min (float): The minimum threshold map value. Defaults to 0.3.
        thr_max (float): The maximum threshold map value. Defaults to 0.7.
        min_sidelength (int or float): The minimum sidelength of the
            minimum rotated rectangle around any text region. Defaults to 8.
    """

    def __init__(self,
                 loss_prob: Dict = dict(
                     type='MaskedBalancedBCEWithLogitsLoss'),
                 loss_thr: Dict = dict(type='MaskedSmoothL1Loss', beta=0),
                 loss_db: Dict = dict(type='MaskedDiceLoss'),
                 weight_prob: float = 5.,
                 weight_thr: float = 10.,
                 shrink_ratio: float = 0.4,
                 thr_min: float = 0.3,
                 thr_max: float = 0.7,
                 min_sidelength: Union[int, float] = 8) -> None:
        super().__init__()
        self.loss_prob = MODELS.build(loss_prob)
        self.loss_thr = MODELS.build(loss_thr)
        self.loss_db = MODELS.build(loss_db)
        self.weight_prob = weight_prob
        self.weight_thr = weight_thr
        self.shrink_ratio = shrink_ratio
        self.thr_min = thr_min
        self.thr_max = thr_max
        self.min_sidelength = min_sidelength

    def forward(self, preds: Tuple[Tensor, Tensor, Tensor],
                data_samples: Sequence[TextDetDataSample]) -> Dict:
        """Compute DBNet loss.

        Args:
            preds (tuple(tensor)): Raw predictions from model, containing
                ``prob_logits``, ``thr_map`` and ``binary_map``.
                Each is a tensor of shape :math:`(N, H, W)`.
            data_samples (list[TextDetDataSample]): The data samples.

        Returns:
            results(dict): The dict for dbnet losses with loss_prob, \
                loss_db and loss_thr. ``loss_prob`` and ``loss_thr`` are
                scaled by ``weight_prob`` and ``weight_thr``; ``loss_db`` is
                unweighted.
        """
        prob_logits, thr_map, binary_map = preds
        gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = self.get_targets(
            data_samples)
        # Targets are built on the CPU; move them next to the predictions.
        gt_shrinks = gt_shrinks.to(prob_logits.device)
        gt_shrink_masks = gt_shrink_masks.to(prob_logits.device)
        gt_thrs = gt_thrs.to(thr_map.device)
        gt_thr_masks = gt_thr_masks.to(thr_map.device)
        loss_prob = self.loss_prob(prob_logits, gt_shrinks, gt_shrink_masks)
        loss_thr = self.loss_thr(thr_map, gt_thrs, gt_thr_masks)
        # The binary map is supervised with the same shrunk kernels as the
        # probability map.
        loss_db = self.loss_db(binary_map, gt_shrinks, gt_shrink_masks)
        results = dict(
            loss_prob=self.weight_prob * loss_prob,
            loss_thr=self.weight_thr * loss_thr,
            loss_db=loss_db)
        return results

    def _is_poly_invalid(self, poly: np.ndarray) -> bool:
        """Check if the input polygon is invalid or not. It is invalid if its
        area is smaller than 1 or the shorter side of its minimum bounding box
        is smaller than min_sidelength.

        Args:
            poly (ndarray): The polygon, reshapeable to ``(K, 2)`` points.

        Returns:
            bool: Whether the polygon is invalid.
        """
        poly = poly.reshape(-1, 2)
        if abs(Polygon(poly).area) < 1:
            return True
        # ``cv2.minAreaRect`` only accepts int32 / float32 point arrays, so
        # cast explicitly (e.g. for float64 polygons). The area test above
        # still uses the original coordinates.
        rect_size = cv2.minAreaRect(poly.astype(np.float32))[1]
        return min(rect_size) < self.min_sidelength

    def _generate_thr_map(self, img_size: Tuple[int, int],
                          polygons: ArrayLike
                          ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate threshold map.

        Args:
            img_size (tuple(int)): The image size (h, w)
            polygons (Sequence[ndarray]): 2-d array, representing all the
                polygons of the text region.

        Returns:
            tuple:

            - thr_map (ndarray): The generated threshold map (float32), with
              values rescaled from [0, 1] into [thr_min, thr_max].
            - thr_mask (ndarray): The effective mask of threshold map (uint8),
              1 inside every dilated polygon.
        """
        thr_map = np.zeros(img_size, dtype=np.float32)
        thr_mask = np.zeros(img_size, dtype=np.uint8)
        for polygon in polygons:
            self._draw_border_map(polygon, thr_map, mask=thr_mask)
        thr_map = thr_map * (self.thr_max - self.thr_min) + self.thr_min
        return thr_map, thr_mask

    def _draw_border_map(self, polygon: np.ndarray, canvas: np.ndarray,
                         mask: np.ndarray) -> None:
        """Generate threshold map for one polygon.

        The polygon is dilated by ``D = A * (1 - r^2) / L`` (A: area,
        L: perimeter, r: ``shrink_ratio``). Within the bounding box of the
        dilated polygon, each pixel gets ``1 - min(d / D, 1)``, where ``d`` is
        its distance to the nearest polygon edge as computed by
        :meth:`_dist_points2line`. ``canvas`` keeps the element-wise maximum
        of this value and its current content, and the dilated polygon is
        filled with 1 in ``mask``.

        ``canvas`` and ``mask`` are updated in place; ``polygon`` is not
        modified.

        Args:
            polygon (np.ndarray): The polygon, reshapeable to ``(K, 2)`` points.
            canvas (np.ndarray): The generated threshold map, indexed as
                ``canvas[y, x]``.
            mask (np.ndarray): The generated threshold mask, same size as
                ``canvas``.
        """
        # ``reshape`` usually returns a view, and the polygon is shifted in
        # place further down; copy so the caller's array is left untouched.
        polygon = polygon.reshape(-1, 2).copy()
        polygon_obj = Polygon(polygon)
        distance = (
            polygon_obj.area * (1 - np.power(self.shrink_ratio, 2)) /
            polygon_obj.length)
        expanded_polygon = offset_polygon(polygon, distance)
        if len(expanded_polygon) == 0:
            # Dilation produced nothing; fall back to the undilated polygon.
            print(f'Padding {polygon} with {distance} gets {expanded_polygon}')
            expanded_polygon = polygon.astype(np.int32)
        else:
            expanded_polygon = expanded_polygon.reshape(-1, 2).astype(np.int32)

        x_min = expanded_polygon[:, 0].min()
        x_max = expanded_polygon[:, 0].max()
        y_min = expanded_polygon[:, 1].min()
        y_max = expanded_polygon[:, 1].max()
        canvas_h, canvas_w = canvas.shape[:2]
        # Skip polygons whose dilated bounding box does not overlap the
        # canvas at all (on any side). The clipped slice arithmetic below
        # assumes an overlap; without it, a box lying fully right of / below
        # the canvas would select a 1-pixel strip at the canvas border filled
        # with values from an unrelated part of ``distance_map``.
        if (x_max < 0 or y_max < 0 or x_min > canvas_w - 1
                or y_min > canvas_h - 1):
            return

        width = x_max - x_min + 1
        height = y_max - y_min + 1
        # Express the (copied) polygon in the local frame of the dilated box.
        polygon[:, 0] = polygon[:, 0] - x_min
        polygon[:, 1] = polygon[:, 1] - y_min
        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width),
            (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1),
            (height, width))

        # Normalized distance from every pixel of the box to each edge
        # (vertex i -> vertex i + 1, wrapping around); keep the nearest edge.
        distance_map = np.zeros((polygon.shape[0], height, width),
                                dtype=np.float32)
        for i in range(polygon.shape[0]):
            j = (i + 1) % polygon.shape[0]
            absolute_distance = self._dist_points2line(xs, ys, polygon[i],
                                                       polygon[j])
            distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # Clip the box to the canvas and paste the matching part of the map.
        x_min_valid = min(max(0, x_min), canvas_w - 1)
        x_max_valid = min(max(0, x_max), canvas_w - 1)
        y_min_valid = min(max(0, y_min), canvas_h - 1)
        y_max_valid = min(max(0, y_max), canvas_h - 1)
        cv2.fillPoly(mask, [expanded_polygon.astype(np.int32)], 1.0)
        patch = distance_map[y_min_valid - y_min:y_max_valid - y_min + 1,
                             x_min_valid - x_min:x_max_valid - x_min + 1]
        region = canvas[y_min_valid:y_max_valid + 1,
                        x_min_valid:x_max_valid + 1]
        canvas[y_min_valid:y_max_valid + 1,
               x_min_valid:x_max_valid + 1] = np.fmax(1 - patch, region)

    def get_targets(self, data_samples: List[TextDetDataSample]) -> Tuple:
        """Generate loss targets from data samples.

        Args:
            data_samples (list(TextDetDataSample)): Ground truth data samples.

        Returns:
            tuple: A tuple of four tensors as DBNet targets, each the
            concatenation (along dim 0) of the per-sample targets produced by
            :meth:`_get_target_single`.
        """
        gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks = multi_apply(
            self._get_target_single, data_samples)
        gt_shrinks = torch.cat(gt_shrinks)
        gt_shrink_masks = torch.cat(gt_shrink_masks)
        gt_thrs = torch.cat(gt_thrs)
        gt_thr_masks = torch.cat(gt_thr_masks)
        return gt_shrinks, gt_shrink_masks, gt_thrs, gt_thr_masks

    def _get_target_single(self, data_sample: TextDetDataSample) -> Tuple:
        """Generate loss target from a data sample.

        Polygons that are flagged as ignored, or that :meth:`_is_poly_invalid`
        rejects, are passed to ``_generate_effective_mask``. The remaining
        polygons are used for the threshold map. The data sample itself is
        not modified.

        Args:
            data_sample (TextDetDataSample): The data sample.

        Returns:
            tuple: ``(gt_shrink, gt_shrink_mask, gt_thr, gt_thr_mask)``, each
            a float tensor with a new leading dimension of size 1.
        """
        gt_instances = data_sample.gt_instances
        # Work on a copy: the flags are updated below (and possibly by
        # ``_generate_kernels``), which must not leak back into the sample.
        ignore_flags = gt_instances.ignored
        if isinstance(ignore_flags, torch.Tensor):
            ignore_flags = ignore_flags.clone()
        else:
            ignore_flags = np.array(ignore_flags, dtype=bool)
        for idx, polygon in enumerate(gt_instances.polygons):
            if self._is_poly_invalid(polygon):
                ignore_flags[idx] = True
        gt_shrink, ignore_flags = self._generate_kernels(
            data_sample.img_shape,
            gt_instances.polygons,
            self.shrink_ratio,
            ignore_flags=ignore_flags)

        # Get boolean mask where Trues indicate text instance pixels
        gt_shrink = gt_shrink > 0

        gt_shrink_mask = self._generate_effective_mask(
            data_sample.img_shape, gt_instances[ignore_flags].polygons)
        gt_thr, gt_thr_mask = self._generate_thr_map(
            data_sample.img_shape, gt_instances[~ignore_flags].polygons)

        # to_tensor
        gt_shrink = torch.from_numpy(gt_shrink).unsqueeze(0).float()
        gt_shrink_mask = torch.from_numpy(gt_shrink_mask).unsqueeze(0).float()
        gt_thr = torch.from_numpy(gt_thr).unsqueeze(0).float()
        gt_thr_mask = torch.from_numpy(gt_thr_mask).unsqueeze(0).float()
        return gt_shrink, gt_shrink_mask, gt_thr, gt_thr_mask

    @staticmethod
    def _dist_points2line(xs: np.ndarray, ys: np.ndarray, pt1: np.ndarray,
                          pt2: np.ndarray) -> np.ndarray:
        """Compute distances from points to a line. This is adapted from
        https://github.com/MhLiao/DB.

        Args:
            xs (ndarray): The x coordinates of the query points (any shape;
                called with ``(H, W)`` grids in this class).
            ys (ndarray): The y coordinates, broadcastable with ``xs``.
            pt1 (ndarray): The first point on the line of size :math:`(2, )`.
            pt2 (ndarray): The second point on the line of size :math:`(2, )`.

        Returns:
            ndarray: The distances, with the broadcast shape of ``xs`` and
            ``ys``. Where the angle at the query point is acute, this is the
            distance to the nearer endpoint instead of the perpendicular
            distance (same rule as the reference implementation).
        """
        # suppose a triangle with three edge abc with c=point_1 point_2
        # a^2
        a_square = np.square(xs - pt1[0]) + np.square(ys - pt1[1])
        # b^2
        b_square = np.square(xs - pt2[0]) + np.square(ys - pt2[1])
        # c^2
        c_square = np.square(pt1[0] - pt2[0]) + np.square(pt1[1] - pt2[1])
        # -cosC=(c^2-a^2-b^2)/2(ab)
        neg_cos_c = (
            (c_square - a_square - b_square) /
            (np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square)))
        # clip -cosC value to [-1, 1]
        neg_cos_c = np.clip(neg_cos_c, -1.0, 1.0)
        # sinC^2=1-cosC^2
        square_sin = 1 - np.square(neg_cos_c)
        square_sin = np.nan_to_num(square_sin)
        # distance=a*b*sinC/c=a*h/c=2*area/c
        result = np.sqrt(a_square * b_square * square_sin /
                         (np.finfo(np.float32).eps + c_square))
        # set result to minimum edge if C<pi/2 (i.e. -cosC < 0)
        result[neg_cos_c < 0] = np.sqrt(np.fmin(a_square,
                                                b_square))[neg_cos_c < 0]
        return result