Mountchicken's picture
Upload 704 files
9bf4bd7
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from mmdet.models.utils import multi_apply
from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
@MODELS.register_module()
class FCEHead(BaseTextDetHead):
"""The class for implementing FCENet head.
FCENet(CVPR2021): `Fourier Contour Embedding for Arbitrary-shaped Text
Detection <https://arxiv.org/abs/2104.10442>`_
Args:
in_channels (int): The number of input channels.
fourier_degree (int) : The maximum Fourier transform degree k. Defaults
to 5.
module_loss (dict): Config of loss for FCENet. Defaults to
``dict(type='FCEModuleLoss', num_sample=50)``.
postprocessor (dict): Config of postprocessor for FCENet.
init_cfg (dict, optional): Initialization configs.
"""
def __init__(
self,
in_channels: int,
fourier_degree: int = 5,
module_loss: Dict = dict(type='FCEModuleLoss', num_sample=50),
postprocessor: Dict = dict(
type='FCEPostprocessor',
text_repr_type='poly',
num_reconstr_points=50,
alpha=1.0,
beta=2.0,
score_thr=0.3),
init_cfg: Optional[Dict] = dict(
type='Normal',
mean=0,
std=0.01,
override=[dict(name='out_conv_cls'),
dict(name='out_conv_reg')])
) -> None:
module_loss['fourier_degree'] = fourier_degree
postprocessor['fourier_degree'] = fourier_degree
super().__init__(
module_loss=module_loss,
postprocessor=postprocessor,
init_cfg=init_cfg)
assert isinstance(in_channels, int)
assert isinstance(fourier_degree, int)
self.in_channels = in_channels
self.fourier_degree = fourier_degree
self.out_channels_cls = 4
self.out_channels_reg = (2 * self.fourier_degree + 1) * 2
self.out_conv_cls = nn.Conv2d(
self.in_channels,
self.out_channels_cls,
kernel_size=3,
stride=1,
padding=1)
self.out_conv_reg = nn.Conv2d(
self.in_channels,
self.out_channels_reg,
kernel_size=3,
stride=1,
padding=1)
def forward(self,
inputs: List[torch.Tensor],
data_samples: Optional[List[TextDetDataSample]] = None
) -> Dict:
"""
Args:
inputs (List[Tensor]): Each tensor has the shape of :math:`(N, C_i,
H_i, W_i)`.
data_samples (list[TextDetDataSample], optional): A list of data
samples. Defaults to None.
Returns:
list[dict]: A list of dict with keys of ``cls_res``, ``reg_res``
corresponds to the classification result and regression result
computed from the input tensor with the same index. They have
the shapes of :math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N,
C_{out,i}, H_i, W_i)`.
"""
cls_res, reg_res = multi_apply(self.forward_single, inputs)
level_num = len(cls_res)
preds = [
dict(cls_res=cls_res[i], reg_res=reg_res[i])
for i in range(level_num)
]
return preds
def forward_single(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function for a single feature level.
Args:
x (Tensor): The input tensor with the shape of :math:`(N, C_i,
H_i, W_i)`.
Returns:
Tensor: The classification and regression result with the shape of
:math:`(N, C_{cls,i}, H_i, W_i)` and :math:`(N, C_{out,i}, H_i,
W_i)`.
"""
cls_predict = self.out_conv_cls(x)
reg_predict = self.out_conv_reg(x)
return cls_predict, reg_predict