|
from torch import nn |
|
import torch |
|
from .RecSVTR import Block |
|
|
|
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` (a.k.a. SiLU)."""

    def __init__(self):
        # Bug fix: the original spelled this ``__int__``, so the custom
        # constructor never ran (nn.Module.__init__ only ran via inheritance).
        super(Swish, self).__init__()

    def forward(self, x):
        """Apply the element-wise Swish activation to ``x``."""
        return x * torch.sigmoid(x)
|
|
|
class Im2Im(nn.Module):
    """Identity neck: passes the 4-D feature map through unchanged.

    Exists so every neck variant exposes the same ``out_channels``
    attribute to downstream code.
    """

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        # Nothing changes the channel count, so output width == input width.
        self.out_channels = in_channels

    def forward(self, x):
        # Image in, image out — no transformation whatsoever.
        return x
|
|
|
class Im2Seq(nn.Module):
    """Flatten a feature map ``(B, C, H, W)`` into a sequence ``(B, H*W, C)``."""

    def __init__(self, in_channels, **kwargs):
        super().__init__()
        # Channels become the per-step feature size of the sequence.
        self.out_channels = in_channels

    def forward(self, x):
        batch, channels, height, width = x.shape
        # Collapse the spatial dims, then move channels last so each of the
        # H*W positions becomes one time step of a C-dim sequence.
        seq = x.reshape(batch, channels, height * width)
        seq = seq.permute((0, 2, 1))
        return seq
|
|
|
class EncoderWithRNN(nn.Module):
    """Two-layer bidirectional LSTM sequence encoder.

    Consumes ``(B, T, in_channels)`` and produces ``(B, T, 2 * hidden_size)``
    (forward and backward states concatenated).
    """

    def __init__(self, in_channels, **kwargs):
        super(EncoderWithRNN, self).__init__()
        hidden_size = kwargs.get('hidden_size', 256)
        # Bidirectional output doubles the hidden width.
        self.out_channels = hidden_size * 2
        self.lstm = nn.LSTM(
            in_channels,
            hidden_size,
            bidirectional=True,
            num_layers=2,
            batch_first=True)

    def forward(self, x):
        # Re-pack weights into one contiguous chunk (keeps cuDNN happy after
        # the module has been moved across devices or replicated).
        self.lstm.flatten_parameters()
        out, _final_state = self.lstm(x)
        return out
|
|
|
class SequenceEncoder(nn.Module):
    """Neck that flattens a CNN feature map and optionally encodes it.

    ``encoder_type`` selects the behaviour:
        'reshape' — only flatten ``(B, C, H, W)`` -> ``(B, H*W, C)``;
        'rnn'     — flatten, then a bidirectional LSTM;
        'svtr'    — run SVTR mixing on the 4-D map first, flatten afterwards.
    """

    def __init__(self, in_channels, encoder_type='rnn', **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        self.encoder_type = encoder_type
        if encoder_type == 'reshape':
            # Nothing to learn: the flatten above is the whole neck.
            self.only_reshape = True
        else:
            support_encoder_dict = {
                'reshape': Im2Seq,
                'rnn': EncoderWithRNN,
                'svtr': EncoderWithSVTR,
            }
            assert encoder_type in support_encoder_dict, '{} must in {}'.format(
                encoder_type, support_encoder_dict.keys())
            # Build the chosen encoder on top of the reshaped channel width.
            self.encoder = support_encoder_dict[encoder_type](
                self.encoder_reshape.out_channels, **kwargs)
            self.out_channels = self.encoder.out_channels
            self.only_reshape = False

    def forward(self, x):
        if self.encoder_type == 'svtr':
            # SVTR works directly on the 4-D map; flatten its output.
            return self.encoder_reshape(self.encoder(x))
        out = self.encoder_reshape(x)
        if not self.only_reshape:
            out = self.encoder(out)
        return out
|
|
|
class ConvBNLayer(nn.Module):
    """Conv2d -> BatchNorm2d -> activation.

    Bug fix: the ``act`` argument used to be silently ignored — a ``Swish()``
    was always instantiated regardless of what the caller passed. It is now
    honoured: pass the string ``'swish'`` (as the in-file callers do), an
    ``nn.Module`` subclass (e.g. the default ``nn.GELU``), or an already
    constructed module instance.

    Args:
        in_channels: input feature channels.
        out_channels: output feature channels.
        kernel_size: conv kernel size (default 3).
        stride: conv stride (default 1).
        padding: conv padding (default 0).
        bias_attr: whether the conv has a bias term (default False).
        groups: grouped-convolution group count (default 1).
        act: activation — 'swish', an ``nn.Module`` class, or an instance.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 bias_attr=False,
                 groups=1,
                 act=nn.GELU):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias_attr)
        self.norm = nn.BatchNorm2d(out_channels)
        if isinstance(act, str):
            # All callers in this file pass act='swish'.
            if act.lower() != 'swish':
                raise ValueError('unsupported activation: {}'.format(act))
            self.act = Swish()
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            # An nn.Module subclass, e.g. the default nn.GELU.
            self.act = act()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out
|
|
|
|
|
class EncoderWithSVTR(nn.Module):
    """SVTR neck: local conv mixing followed by global Transformer blocks.

    Takes a 4-D feature map ``(B, in_channels, H, W)`` and returns a 4-D map
    with ``dims`` channels (spatial size unchanged by the 3x3 convs thanks to
    ``padding=1``).

    Args:
        in_channels: channels of the incoming feature map.
        dims: output channel width.
        depth: number of SVTR mixer ``Block``s.
        hidden_dims: channel width inside the mixer.
        use_guide: if True, detach the input so the mixer branch sends no
            gradient back into the backbone.
        num_heads / qkv_bias / mlp_ratio / drop_rate / attn_drop_rate /
        drop_path / qk_scale: forwarded to each ``Block``.
    """

    def __init__(
            self,
            in_channels,
            dims=64,
            depth=2,
            hidden_dims=120,
            use_guide=False,
            num_heads=8,
            qkv_bias=True,
            mlp_ratio=2.0,
            drop_rate=0.1,
            attn_drop_rate=0.1,
            drop_path=0.,
            qk_scale=None):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        # Local feature mixing before the global Transformer blocks.
        self.conv1 = ConvBNLayer(
            in_channels, in_channels // 8, padding=1, act='swish')
        self.conv2 = ConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1, act='swish')

        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer='swish',
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                epsilon=1e-05,
                prenorm=False) for i in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(
            hidden_dims, in_channels, kernel_size=1, act='swish')

        # conv4 fuses the residual with the mixed features, hence 2x channels.
        self.conv4 = ConvBNLayer(
            2 * in_channels, in_channels // 8, padding=1, act='swish')

        self.conv1x1 = ConvBNLayer(
            in_channels // 8, dims, kernel_size=1, act='swish')
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Standard conv/linear/norm weight init, applied recursively."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.ConvTranspose2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        # Bug fix: the original set ``z.stop_gradient = True`` — a Paddle
        # idiom that is a no-op on torch tensors, so use_guide never actually
        # blocked gradients. detach() gives the intended behaviour.
        if self.use_guide:
            z = x.clone().detach()
        else:
            z = x

        # Residual kept for the late fusion below.
        h = z

        # Local mixing.
        z = self.conv1(z)
        z = self.conv2(z)

        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        B, C, H, W = z.shape
        z = z.flatten(2).permute(0, 2, 1)

        # Global mixing via the SVTR blocks.
        for blk in self.svtr_block:
            z = blk(z)

        z = self.norm(z)

        # Back to a 4-D map, then fuse with the residual.
        z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))

        return z
|
|
|
if __name__ == "__main__":
    # Quick smoke check: build the SVTR encoder and dump its module layout.
    encoder = EncoderWithSVTR(56)
    print(encoder)