# Source: MMOCR, mmocr/models/textdet/losses/textsnake_loss.py
# (file-view header: commit 2366e36 by tomofi, "Add application file", 7.88 kB)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmdet.core import BitmapMasks
from torch import nn
from mmocr.models.builder import LOSSES
from mmocr.utils import check_argument
@LOSSES.register_module()
class TextSnakeLoss(nn.Module):
    """The class for implementing TextSnake loss. This is partially adapted
    from https://github.com/princewang1994/TextSnake.pytorch.

    TextSnake: `A Flexible Representation for Detecting Text of Arbitrary
    Shapes <https://arxiv.org/abs/1807.01544>`_.

    Args:
        ohem_ratio (float): The negative/positive ratio in ohem.
    """

    # Number of hardest negatives kept by ``balanced_bce_loss`` when a
    # sample has no positive pixels (clamped to the number of map elements).
    _NEG_COUNT_WITHOUT_POSITIVE = 100

    def __init__(self, ohem_ratio=3.0):
        super().__init__()
        self.ohem_ratio = ohem_ratio

    def balanced_bce_loss(self, pred, gt, mask):
        """Binary cross-entropy with online hard negative mining (OHEM).

        Every positive pixel (``gt * mask``) contributes to the loss. Of the
        negative pixels (``(1 - gt) * mask``) only the ones with the largest
        BCE are kept: at most ``ohem_ratio`` times the number of positives,
        or ``_NEG_COUNT_WITHOUT_POSITIVE`` (clamped to the map size) when the
        map has no positives.

        Args:
            pred (Tensor): Predicted probabilities in [0, 1]; same shape as
                ``gt`` and ``mask``.
            gt (Tensor): Ground-truth map (expected to hold 0/1 values).
            mask (Tensor): Map selecting the pixels that take part in the
                loss (expected to hold 0/1 values).

        Returns:
            Tensor: Scalar loss: (sum of positive BCE + sum of the selected
            negative BCE) / (number of positives + number of selected
            negatives + 1e-5).
        """
        assert pred.shape == gt.shape == mask.shape

        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.float().sum())

        # Per-pixel BCE, shared by the positive and negative terms.
        loss = F.binary_cross_entropy(pred, gt.float(), reduction='none')
        negative_loss = (loss * negative.float()).view(-1)

        if positive_count > 0:
            positive_loss = torch.sum(loss * positive.float())
            negative_count = min(
                int(negative.float().sum()),
                int(positive_count * self.ohem_ratio))
        else:
            positive_loss = torch.tensor(0.0, device=pred.device)
            # torch.topk raises if k exceeds the number of elements, so the
            # fixed count must not exceed the map size.
            negative_count = min(self._NEG_COUNT_WITHOUT_POSITIVE,
                                 negative_loss.numel())

        hard_negative_loss, _ = torch.topk(negative_loss, negative_count)

        balance_loss = (positive_loss + torch.sum(hard_negative_loss)) / (
            float(positive_count + negative_count) + 1e-5)
        return balance_loss

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target spatial size
                :math:`(H, W)`. Each mask is zero-padded on the bottom and
                on the right to this size.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
            one kernel level and stacks that level over the batch.
        """
        assert check_argument.is_type_list(bitmasks, BitmapMasks)
        assert isinstance(target_sz, tuple)

        batch_size = len(bitmasks)
        num_masks = len(bitmasks[0])

        results = []
        for level_inx in range(num_masks):
            kernel = []
            for batch_inx in range(batch_size):
                mask = torch.from_numpy(bitmasks[batch_inx].masks[level_inx])
                # hxw
                mask_sz = mask.shape
                # F.pad order for the last two dims: left, right, top, bottom
                pad = [
                    0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]
                ]
                mask = F.pad(mask, pad, mode='constant', value=0)
                kernel.append(mask)
            kernel = torch.stack(kernel)
            results.append(kernel)

        return results

    def forward(self, pred_maps, downsample_ratio, gt_text_mask,
                gt_center_region_mask, gt_mask, gt_radius_map, gt_sin_map,
                gt_cos_map):
        """Compute the TextSnake losses.

        Only the first mask of every BitmapMasks (kernel level 0) is used.

        Args:
            pred_maps (Tensor): The prediction map of shape
                :math:`(N, 5, H, W)`, where each dimension is the map of
                "text_region", "center_region", "sin_map", "cos_map", and
                "radius_map" respectively.
            downsample_ratio (float): Downsample ratio.
            gt_text_mask (list[BitmapMasks]): Gold text masks.
            gt_center_region_mask (list[BitmapMasks]): Gold center region
                masks.
            gt_mask (list[BitmapMasks]): Gold general masks.
            gt_radius_map (list[BitmapMasks]): Gold radius maps.
            gt_sin_map (list[BitmapMasks]): Gold sin maps.
            gt_cos_map (list[BitmapMasks]): Gold cos maps.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_center``,
            ``loss_radius``, ``loss_sin`` and ``loss_cos``.
        """
        assert isinstance(downsample_ratio, float)
        assert check_argument.is_type_list(gt_text_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_center_region_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_mask, BitmapMasks)
        assert check_argument.is_type_list(gt_radius_map, BitmapMasks)
        assert check_argument.is_type_list(gt_sin_map, BitmapMasks)
        assert check_argument.is_type_list(gt_cos_map, BitmapMasks)

        pred_text_region = pred_maps[:, 0, :, :]
        pred_center_region = pred_maps[:, 1, :, :]
        pred_sin_map = pred_maps[:, 2, :, :]
        pred_cos_map = pred_maps[:, 3, :, :]
        pred_radius_map = pred_maps[:, 4, :, :]
        feature_sz = pred_maps.size()
        device = pred_maps.device

        # bitmask 2 tensor
        mapping = {
            'gt_text_mask': gt_text_mask,
            'gt_center_region_mask': gt_center_region_mask,
            'gt_mask': gt_mask,
            'gt_radius_map': gt_radius_map,
            'gt_sin_map': gt_sin_map,
            'gt_cos_map': gt_cos_map
        }
        keep_scale = abs(downsample_ratio - 1.0) < 1e-2
        gt = {}
        for key, value in mapping.items():
            if not keep_scale:
                value = [item.rescale(downsample_ratio) for item in value]
            tensors = self.bitmasks2tensor(value, feature_sz[2:])
            if not keep_scale and key == 'gt_radius_map':
                # The maps were spatially rescaled, so radius values are
                # scaled by the same ratio to match the new resolution.
                tensors = [item * downsample_ratio for item in tensors]
            gt[key] = [item.to(device) for item in tensors]

        # Normalise the predicted (sin, cos) pair to unit length.
        scale = torch.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
        pred_sin_map = pred_sin_map * scale
        pred_cos_map = pred_cos_map * scale

        loss_text = self.balanced_bce_loss(
            torch.sigmoid(pred_text_region), gt['gt_text_mask'][0],
            gt['gt_mask'][0])

        # Center-region BCE is averaged over valid text pixels only.
        text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0]).float()
        loss_center_map = F.binary_cross_entropy(
            torch.sigmoid(pred_center_region),
            gt['gt_center_region_mask'][0].float(),
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center = torch.sum(
                loss_center_map * text_mask) / torch.sum(text_mask)
        else:
            loss_center = torch.tensor(0.0, device=device)

        # Geometry losses are averaged over valid center-region pixels only.
        center_mask = (gt['gt_center_region_mask'][0] *
                       gt['gt_mask'][0]).float()
        if int(center_mask.sum()) > 0:
            map_sz = pred_radius_map.size()
            ones = torch.ones(map_sz, dtype=torch.float, device=device)
            # Radius is regressed as the ratio pred / gt, compared against 1.
            loss_radius = torch.sum(
                F.smooth_l1_loss(
                    pred_radius_map / (gt['gt_radius_map'][0] + 1e-2),
                    ones,
                    reduction='none') * center_mask) / torch.sum(center_mask)
            loss_sin = torch.sum(
                F.smooth_l1_loss(
                    pred_sin_map, gt['gt_sin_map'][0], reduction='none') *
                center_mask) / torch.sum(center_mask)
            loss_cos = torch.sum(
                F.smooth_l1_loss(
                    pred_cos_map, gt['gt_cos_map'][0], reduction='none') *
                center_mask) / torch.sum(center_mask)
        else:
            loss_radius = torch.tensor(0.0, device=device)
            loss_sin = torch.tensor(0.0, device=device)
            loss_cos = torch.tensor(0.0, device=device)

        results = dict(
            loss_text=loss_text,
            loss_center=loss_center,
            loss_radius=loss_radius,
            loss_sin=loss_sin,
            loss_cos=loss_cos)

        return results