# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from mmdet.models.utils import multi_apply

from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample


@MODELS.register_module()
class FCEHead(BaseTextDetHead):
    """The class for implementing FCENet head.

    FCENet (CVPR 2021): Fourier Contour Embedding for Arbitrary-shaped Text
    Detection.

    The head applies two parallel 3x3 convolutions to every input feature
    level: one producing 4 classification channels and one producing
    ``2 * (2 * fourier_degree + 1)`` regression channels.

    Args:
        in_channels (int): The number of input channels.
        fourier_degree (int): The maximum Fourier transform degree k.
            Defaults to 5.
        module_loss (dict): Config of loss for FCENet. Defaults to
            ``dict(type='FCEModuleLoss', num_sample=50)``. The dict is
            copied before ``fourier_degree`` is injected, so the caller's
            object is left untouched.
        postprocessor (dict): Config of postprocessor for FCENet. Copied
            in the same way as ``module_loss``.
        init_cfg (dict, optional): Initialization configs.
    """

    def __init__(
        self,
        in_channels: int,
        fourier_degree: int = 5,
        module_loss: Dict = dict(type='FCEModuleLoss', num_sample=50),
        postprocessor: Dict = dict(
            type='FCEPostprocessor',
            text_repr_type='poly',
            num_reconstr_points=50,
            alpha=1.0,
            beta=2.0,
            score_thr=0.3),
        init_cfg: Optional[Dict] = dict(
            type='Normal',
            mean=0,
            std=0.01,
            override=[dict(name='out_conv_cls'),
                      dict(name='out_conv_reg')])
    ) -> None:
        # Work on copies: the defaults above are created once at function
        # definition time and shared by every call, and a user-supplied
        # config should not be modified behind the caller's back.
        module_loss = module_loss.copy()
        postprocessor = postprocessor.copy()
        module_loss['fourier_degree'] = fourier_degree
        postprocessor['fourier_degree'] = fourier_degree

        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)

        assert isinstance(in_channels, int)
        assert isinstance(fourier_degree, int)

        self.in_channels = in_channels
        self.fourier_degree = fourier_degree
        self.out_channels_cls = 4
        self.out_channels_reg = (2 * self.fourier_degree + 1) * 2

        # kernel_size=3 / stride=1 / padding=1 keeps the spatial size of the
        # input feature map unchanged.
        self.out_conv_cls = nn.Conv2d(
            self.in_channels,
            self.out_channels_cls,
            kernel_size=3,
            stride=1,
            padding=1)
        self.out_conv_reg = nn.Conv2d(
            self.in_channels,
            self.out_channels_reg,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self,
                inputs: List[torch.Tensor],
                data_samples: Optional[List[TextDetDataSample]] = None
                ) -> List[Dict]:
        """Run the head on every feature level.

        Args:
            inputs (List[Tensor]): Each tensor has the shape of
                :math:`(N, C_i, H_i, W_i)`.
            data_samples (list[TextDetDataSample], optional): A list of data
                samples. Not used by this method. Defaults to None.

        Returns:
            list[dict]: A list of dict with keys of ``cls_res``, ``reg_res``
            corresponds to the classification result and regression result
            computed from the input tensor with the same index. They have
            the shapes of :math:`(N, C_{cls,i}, H_i, W_i)` and
            :math:`(N, C_{out,i}, H_i, W_i)`.
        """
        cls_res, reg_res = multi_apply(self.forward_single, inputs)
        # Pair up the per-level outputs; one dict per feature level.
        return [
            dict(cls_res=cls_level, reg_res=reg_level)
            for cls_level, reg_level in zip(cls_res, reg_res)
        ]

    def forward_single(self,
                       x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function for a single feature level.

        Args:
            x (Tensor): The input tensor with the shape of
                :math:`(N, C_i, H_i, W_i)`.

        Returns:
            tuple[Tensor, Tensor]: The classification and regression result
            with the shape of :math:`(N, C_{cls,i}, H_i, W_i)` and
            :math:`(N, C_{out,i}, H_i, W_i)`.
        """
        cls_predict = self.out_conv_cls(x)
        reg_predict = self.out_conv_reg(x)
        return cls_predict, reg_predict