# CRNN model definition (CNN backbone + bidirectional LSTM head).
import torch.nn as nn

from strhub.models.modules import BidirectionalLSTM
class CRNN(nn.Module):
    """Convolutional Recurrent Neural Network (CRNN) for sequence recognition.

    A VGG-style CNN backbone reduces the input image to a feature map of
    height 1; `forward` then reads that map column-by-column as a sequence
    and feeds it to a two-layer bidirectional LSTM head (the second layer's
    output size is `nclass`).

    Args:
        img_h: Input image height. Must be a multiple of 16 so the four
            height-halving stages plus the final 2x2 conv reduce it to 1.
        nc: Number of input channels.
        nclass: Output size of the final LSTM layer (classes per timestep).
        nh: Hidden size of the LSTM layers.
        leaky_relu: If True, use LeakyReLU(0.2) after each conv instead of ReLU.

    Raises:
        ValueError: If ``img_h`` is not a multiple of 16.
    """

    def __init__(self, img_h, nc, nclass, nh, leaky_relu=False):
        super().__init__()
        # Raise (not assert) so the check survives `python -O`.
        if img_h % 16 != 0:
            raise ValueError('img_h has to be a multiple of 16')

        # Per-stage conv hyperparameters: kernel, padding, stride, out-channels.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def conv_relu(i, batch_normalization=False):
            # Append conv -> (optional BatchNorm) -> activation for stage `i`.
            # The conv bias is dropped when BN follows: BN's shift subsumes it.
            n_in = nc if i == 0 else nm[i - 1]
            n_out = nm[i]
            # NOTE: submodule names (conv0, batchnorm0, relu0, ...) must stay
            # exactly as-is to keep state_dict keys compatible.
            cnn.add_module(f'conv{i}',
                           nn.Conv2d(n_in, n_out, ks[i], ss[i], ps[i],
                                     bias=not batch_normalization))
            if batch_normalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(n_out))
            if leaky_relu:
                cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{i}', nn.ReLU(True))

        # Shape comments assume a nc x 32 x 128 input (c x h x w).
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x16x64
        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x8x32
        conv_relu(2, True)
        conv_relu(3)
        # Stride (2, 1): halve height only, preserving horizontal resolution
        # so the output sequence stays long.
        cnn.add_module('pooling2',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        conv_relu(4, True)
        conv_relu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        conv_relu(6, True)  # 512x1x16

        self.cnn = cnn
        # Two stacked bidirectional LSTMs; the second maps to `nclass`.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        """Run the CNN backbone, then the RNN head over the column sequence.

        Args:
            input: Image batch of shape [b, nc, img_h, w].

        Returns:
            The RNN head's per-timestep output for the [b, w', c] column
            sequence (last LSTM layer has output size `nclass`).

        Raises:
            ValueError: If the conv feature map does not collapse to height 1.
        """
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        if h != 1:
            raise ValueError('the height of conv must be 1')
        conv = conv.squeeze(2)       # [b, c, w]
        conv = conv.transpose(1, 2)  # [b, w, c] — one timestep per column
        return self.rnn(conv)