tomofi's picture
Add application file
2366e36
raw
history blame
No virus
2 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmocr.models.builder import BACKBONES
@BACKBONES.register_module()
class ShallowCNN(BaseModule):
    """Shallow CNN feature extractor used as the backbone of SATRN.

    SATRN: `On Recognizing Texts of Arbitrary Shapes with 2D Self-Attention
    <https://arxiv.org/pdf/1910.04396.pdf>`_.

    The network stacks two 3x3 convolution modules (stride 1, padding 1,
    no bias, configured with BN and ReLU), each followed by the same
    2x2 / stride-2 max pooling layer.

    Args:
        input_channels (int): Number of channels of input image tensor
            :math:`D_i`.
        hidden_dim (int): Number of output channels :math:`D_m`. The first
            convolution produces ``hidden_dim // 2`` channels.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 input_channels=1,
                 hidden_dim=512,
                 init_cfg=[
                     dict(type='Kaiming', layer='Conv2d'),
                     dict(type='Uniform', layer='BatchNorm2d')
                 ]):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(input_channels, int)
        assert isinstance(hidden_dim, int)

        def conv_bn_relu(in_channels, out_channels):
            # Both stages share the same hyper-parameters; fresh cfg dicts
            # are created on every call so nothing is shared between them.
            return ConvModule(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='ReLU'))

        # Channel width between the two convolution stages.
        mid_channels = hidden_dim // 2

        # Keep the attribute creation order (conv1, conv2, pool) so that the
        # registered submodule order stays the same.
        self.conv1 = conv_bn_relu(input_channels, mid_channels)
        self.conv2 = conv_bn_relu(mid_channels, hidden_dim)
        # A single parameter-free pooling layer is reused after each stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Run conv1 -> pool -> conv2 -> pool on the input.

        Args:
            x (Tensor): Input image feature :math:`(N, D_i, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, D_m, H/4, W/4)`.
        """
        stage1 = self.pool(self.conv1(x))
        return self.pool(self.conv2(stage1))