topdu's picture
openocr demo
29f689c
raw
history blame
1.24 kB
import torch
from torch import nn
class MTB(nn.Module):
    """Small convolutional stem that can flatten its output into a sequence.

    When ``cnn_num == 2`` the stem is two stages of
    ``Conv2d(kernel 3, stride 2, padding 1) -> ReLU -> BatchNorm2d`` with
    32 and then 64 output channels, and ``forward`` folds the resulting
    feature map into a 3-D tensor. For any other ``cnn_num`` the stem is an
    empty ``nn.Sequential`` and ``forward`` hands the input back unchanged.

    Args:
        cnn_num: Number of convolution stages; only the value 2 builds
            any layers.
        in_channels: Channel count of the images fed to the first
            convolution.
    """

    def __init__(self, cnn_num, in_channels):
        super().__init__()
        self.block = nn.Sequential()
        # NOTE(review): this records the *input* channel count as given. With
        # cnn_num == 2 the last dimension produced by forward() is
        # height * 64, not this value -- confirm what consumers expect.
        self.out_channels = in_channels
        self.cnn_num = cnn_num

        if cnn_num != 2:
            return

        # Sub-module names (conv_i / relu_i / bn_i) and their registration
        # order determine the state_dict keys, so they are kept stable.
        stage_in = in_channels
        for idx in range(cnn_num):
            stage_out = 32 * (2**idx)
            conv = nn.Conv2d(
                in_channels=stage_in,
                out_channels=stage_out,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.block.add_module(f'conv_{idx}', conv)
            self.block.add_module(f'relu_{idx}', nn.ReLU())
            self.block.add_module(f'bn_{idx}', nn.BatchNorm2d(stage_out))
            # The next stage consumes what this stage produced.
            stage_in = stage_out

    def forward(self, images):
        """Run the stem and, for ``cnn_num == 2``, flatten to a sequence.

        Args:
            images: Input tensor passed straight to the convolution stack.

        Returns:
            For ``cnn_num == 2``: a tensor whose axes 1 and 3 have been
            swapped and whose last two axes have then been merged, i.e.
            ``(b, w, h * c)`` in terms of the stem's ``(b, c, h, w)``
            output. Otherwise the output of the (empty) stem.
        """
        feats = self.block(images)
        if self.cnn_num != 2:
            return feats

        # (b, c, h, w) -> (b, w, h, c)
        feats = feats.permute(0, 3, 2, 1)
        batch, width, height, channels = feats.shape
        # Merge height and channels so each width position is one step.
        return torch.reshape(feats, (batch, width, height * channels))