Spaces:
Runtime error
Runtime error
File size: 7,914 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmdet.core import multi_apply
from torch import nn
from mmocr.models.builder import LOSSES
@LOSSES.register_module()
class FCELoss(nn.Module):
    """The class for implementing FCENet loss.

    FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
    Detection <https://arxiv.org/abs/2104.10442>`_

    Args:
        fourier_degree (int) : The maximum Fourier transform degree k.
        num_sample (int) : The sampling points number of regression
            loss. If it is too small, fcenet tends to be overfitting.
        ohem_ratio (float): the negative/positive ratio in OHEM.
    """

    def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
        super().__init__()
        self.fourier_degree = fourier_degree
        self.num_sample = num_sample
        self.ohem_ratio = ohem_ratio

    def forward(self, preds, _, p3_maps, p4_maps, p5_maps):
        """Compute FCENet loss.

        Args:
            preds (list[list[Tensor]]): The outer list is paired element-wise
                with the three target levels (p3, p4, p5); each inner list
                holds the classification prediction map (with shape
                :math:`(N, C, H, W)`) and the regression map (with shape
                :math:`(N, C, H, W)`) of that level.
            _: Unused positional argument, kept for interface compatibility.
            p3_maps (list[ndarray]): List of level 3 ground truth target map
                with shape :math:`(C, H, W)`, where
                ``C == 4 * fourier_degree + 5``.
            p4_maps (list[ndarray]): List of level 4 ground truth target map
                with shape :math:`(C, H, W)`.
            p5_maps (list[ndarray]): List of level 5 ground truth target map
                with shape :math:`(C, H, W)`.

        Returns:
            dict: A loss dict with ``loss_text``, ``loss_center``,
            ``loss_reg_x`` and ``loss_reg_y``, each summed over the levels.
        """
        assert isinstance(preds, list)
        assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
            'fourier degree not equal in FCEhead and FCEtarget'

        device = preds[0][0].device

        # Stack the per-image arrays of every level into one (N, C, H, W)
        # float tensor living on the same device as the predictions.
        gts = [
            torch.from_numpy(np.stack(maps)).float().to(device)
            for maps in (p3_maps, p4_maps, p5_maps)
        ]

        # ``forward_single`` returns four losses per level; ``multi_apply``
        # regroups them into four per-loss lists.
        tr_losses, tcl_losses, reg_x_losses, reg_y_losses = multi_apply(
            self.forward_single, preds, gts)

        # Start every sum from a float zero on ``device`` so the totals are
        # always tensors of a fixed dtype on the right device.
        zero = torch.tensor(0., device=device)
        results = dict(
            loss_text=sum(tr_losses, zero),
            loss_center=sum(tcl_losses, zero),
            loss_reg_x=sum(reg_x_losses, zero),
            loss_reg_y=sum(reg_y_losses, zero),
        )

        return results

    def forward_single(self, pred, gt):
        """Compute the four loss terms for a single level.

        Args:
            pred (list[Tensor]): ``pred[0]`` is the classification map whose
                first two channels are the text-region logits and the
                remaining channels the text-center logits; ``pred[1]`` is the
                regression map holding ``2 * fourier_degree + 1`` x
                coefficients followed by as many y coefficients. Both are
                channel-first 4D tensors.
            gt (Tensor): Channel-first 4D target. Channel 0 is the
                text-region mask, channel 1 the text-center mask, channel 2
                the train mask, followed by the x and then the y Fourier
                coefficient targets.

        Returns:
            tuple[Tensor]: ``(loss_tr, loss_tcl, loss_reg_x, loss_reg_y)``.
        """
        # Move channels last so every pixel becomes one row after ``view``.
        cls_pred = pred[0].permute(0, 2, 3, 1).contiguous()
        reg_pred = pred[1].permute(0, 2, 3, 1).contiguous()
        gt = gt.permute(0, 2, 3, 1).contiguous()

        k = 2 * self.fourier_degree + 1
        tr_pred = cls_pred[:, :, :, :2].view(-1, 2)
        tcl_pred = cls_pred[:, :, :, 2:].view(-1, 2)
        x_pred = reg_pred[:, :, :, 0:k].view(-1, k)
        y_pred = reg_pred[:, :, :, k:2 * k].view(-1, k)

        tr_mask = gt[:, :, :, :1].view(-1)
        tcl_mask = gt[:, :, :, 1:2].view(-1)
        train_mask = gt[:, :, :, 2:3].view(-1)
        x_map = gt[:, :, :, 3:3 + k].view(-1, k)
        y_map = gt[:, :, :, 3 + k:].view(-1, k)

        device = x_map.device

        # Pixels that are both text and trainable, and their complement.
        tr_train_mask = train_mask * tr_mask
        pos_sel = tr_train_mask.bool()
        neg_sel = (1 - tr_train_mask).bool()

        # tr loss
        loss_tr = self.ohem(tr_pred, tr_mask.long(), train_mask.long())

        loss_tcl = torch.tensor(0., device=device)
        loss_reg_x = torch.tensor(0., device=device)
        loss_reg_y = torch.tensor(0., device=device)

        # Without any positive pixel the tcl and regression terms stay zero.
        if tr_train_mask.sum().item() > 0:
            # tcl loss
            loss_tcl = F.cross_entropy(tcl_pred[pos_sel],
                                       tcl_mask[pos_sel].long())
            # Mean-reduced cross entropy over an empty selection is NaN, so
            # the negative term is only added when negatives exist.
            if neg_sel.any():
                loss_tcl_neg = F.cross_entropy(tcl_pred[neg_sel],
                                               tcl_mask[neg_sel].long())
                loss_tcl = loss_tcl + 0.5 * loss_tcl_neg

            # regression loss, weighted per pixel by the mean of the
            # text-region and text-center masks
            weight = (tr_mask[pos_sel].float() +
                      tcl_mask[pos_sel].float()) / 2
            weight = weight.contiguous().view(-1, 1)

            # Compare sampled contour points rather than raw coefficients.
            ft_x, ft_y = self.fourier2poly(x_map, y_map)
            ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)

            loss_reg_x = torch.mean(weight * F.smooth_l1_loss(
                ft_x_pre[pos_sel], ft_x[pos_sel], reduction='none'))
            loss_reg_y = torch.mean(weight * F.smooth_l1_loss(
                ft_y_pre[pos_sel], ft_y[pos_sel], reduction='none'))

        return loss_tr, loss_tcl, loss_reg_x, loss_reg_y

    def ohem(self, predict, target, train_mask):
        """Cross entropy that keeps only the hardest negative samples.

        Args:
            predict (Tensor): Two-class logits with one row per pixel.
            target (Tensor): Integer class label (0 or 1) per pixel.
            train_mask (Tensor): Integer mask; pixels with value 0 are
                excluded from the loss.

        Returns:
            Tensor: Sum of the positive losses and the kept negative losses,
            divided by the number of positives plus the negative budget.
        """
        device = train_mask.device
        pos = (target * train_mask).bool()
        neg = ((1 - target) * train_mask).bool()

        n_pos = pos.float().sum()

        # Per-pixel negative losses, needed by both branches below.
        loss_neg = F.cross_entropy(
            predict[neg], target[neg], reduction='none')

        if n_pos.item() > 0:
            loss_pos = F.cross_entropy(
                predict[pos], target[pos], reduction='sum')
            # Negative budget: at most ``ohem_ratio`` negatives per positive.
            n_neg = min(
                int(neg.float().sum().item()),
                int(self.ohem_ratio * n_pos.float()))
        else:
            loss_pos = torch.tensor(0., device=device)
            # Fixed negative budget used when there is no positive pixel.
            n_neg = 100

        # Keep only the ``n_neg`` largest negative losses.
        if len(loss_neg) > n_neg:
            loss_neg, _ = torch.topk(loss_neg, n_neg)

        return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()

    def fourier2poly(self, real_maps, imag_maps):
        """Transform Fourier coefficient maps to polygon maps.

        Each row is treated as complex coefficients ``real + i * imag`` for
        the degrees ``-fourier_degree .. fourier_degree`` and evaluated at
        ``num_sample`` evenly spaced angles; the real part of the result
        gives the x values and the imaginary part the y values.

        Args:
            real_maps (tensor): A map composed of the real parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)
            imag_maps (tensor):A map composed of the imag parts of the
                Fourier coefficients, whose shape is (-1, 2k+1)

        Returns:
            x_maps (tensor): A map composed of the x value of the polygon
                represented by n sample points (xn, yn), whose shape is (-1, n)
            y_maps (tensor): A map composed of the y value of the polygon
                represented by n sample points (xn, yn), whose shape is (-1, n)
        """
        device = real_maps.device

        k_vect = torch.arange(
            -self.fourier_degree,
            self.fourier_degree + 1,
            dtype=torch.float,
            device=device).view(-1, 1)
        i_vect = torch.arange(
            0, self.num_sample, dtype=torch.float, device=device).view(1, -1)

        # Angle 2 * pi * k * i / num_sample for every (degree, sample) pair.
        transform_matrix = 2 * np.pi / self.num_sample * torch.mm(
            k_vect, i_vect)
        cos_matrix = torch.cos(transform_matrix)
        sin_matrix = torch.sin(transform_matrix)

        x1 = torch.einsum('ak, kn-> an', real_maps, cos_matrix)
        x2 = torch.einsum('ak, kn-> an', imag_maps, sin_matrix)
        y1 = torch.einsum('ak, kn-> an', real_maps, sin_matrix)
        y2 = torch.einsum('ak, kn-> an', imag_maps, cos_matrix)

        x_maps = x1 - x2
        y_maps = y1 + y2

        return x_maps, y_maps
|