# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmocr.models.textdet.heads import BaseTextDetHead
from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample


@MODELS.register_module()
class TextSnakeHead(BaseTextDetHead):
    """Detection head for TextSnake.

    Implements the head described in "TextSnake: A Flexible Representation
    for Detecting Text of Arbitrary Shapes". The head is a single 1x1
    convolution that projects the incoming feature map to ``out_channels``
    prediction maps.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels. Defaults to 5.
        downsample_ratio (float): Downsample ratio. It is only stored on the
            instance here; ``forward`` does not read it. Defaults to 1.0.
        module_loss (dict): Configuration dictionary for loss type.
            Defaults to ``dict(type='TextSnakeModuleLoss')``.
        postprocessor (dict): Config of postprocessor for TextSnake.
            Defaults to ``dict(type='TextSnakePostprocessor',
            text_repr_type='poly')``.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to a ``Normal`` initializer (mean 0, std 0.01) that
            overrides the ``out_conv`` layer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 5,
        downsample_ratio: float = 1.0,
        module_loss: Dict = dict(type='TextSnakeModuleLoss'),
        postprocessor: Dict = dict(
            type='TextSnakePostprocessor', text_repr_type='poly'),
        init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
            type='Normal', override=dict(name='out_conv'), mean=0, std=0.01)
    ) -> None:
        super().__init__(
            module_loss=module_loss,
            postprocessor=postprocessor,
            init_cfg=init_cfg)

        assert isinstance(in_channels, int)
        assert isinstance(out_channels, int)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.downsample_ratio = downsample_ratio

        # 1x1 convolution with stride 1 and no padding: the spatial size of
        # the input is preserved and only the channel dimension changes.
        self.out_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[TextDetDataSample]] = None
                ) -> torch.Tensor:
        """Project the input features to the TextSnake prediction maps.

        Args:
            inputs (torch.Tensor): Shape :math:`(N, C_{in}, H, W)`, where
                :math:`C_{in}` is ``in_channels``. :math:`H` and :math:`W`
                should be the same as the input of backbone.
            data_samples (list[TextDetDataSample], optional): A list of data
                samples. Not used by this method. Defaults to None.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`, where
            :math:`C_{out}` is ``out_channels``. With the default of 5, the
            channels represent [0]: text score, [1]: center score, [2]: sin,
            [3]: cos, [4]: radius, respectively.
        """
        outputs = self.out_conv(inputs)
        return outputs