# Provenance: duplicated from hantech/VietOCR (revision 33c0fae, 1.61 kB).
import torch
from torch import nn
from torchvision import models
from einops import rearrange
from torchvision.models._utils import IntermediateLayerGetter
class Vgg(nn.Module):
    """VGG-based feature extractor that emits a sequence of feature vectors.

    The convolutional ``features`` stack of a torchvision VGG model is reused,
    with every ``MaxPool2d`` layer swapped for an ``AvgPool2d`` whose kernel
    size and stride come from ``ks`` and ``ss``. The resulting feature map
    goes through dropout and a 1x1 convolution (512 -> ``hidden`` channels)
    and is then flattened into a sequence.

    Args:
        name: Backbone to build; one of ``'vgg11_bn'`` or ``'vgg19_bn'``.
        ss: Strides for the replacement pooling layers, indexed in the order
            the pooling layers appear in ``features``.
        ks: Kernel sizes for the replacement pooling layers, indexed the same
            way as ``ss``.
        hidden: Number of output channels of the final 1x1 convolution.
        pretrained: If true, build the backbone with ``weights='DEFAULT'``;
            otherwise build it with ``weights=None`` (no weights are loaded).
        dropout: Dropout probability applied to the backbone feature map.

    Raises:
        ValueError: If ``name`` is not a supported backbone.
    """

    _SUPPORTED_NAMES = ('vgg11_bn', 'vgg19_bn')

    def __init__(self, name, ss, ks, hidden, pretrained=True, dropout=0.5):
        super().__init__()

        # Fail early with a clear message instead of hitting an unbound
        # local variable further down.
        if name not in self._SUPPORTED_NAMES:
            raise ValueError(
                f"Unsupported VGG backbone {name!r}; "
                f"expected one of {self._SUPPORTED_NAMES}"
            )

        # Honour `pretrained`: only request weights when asked to.
        weights = 'DEFAULT' if pretrained else None
        cnn = getattr(models, name)(weights=weights)

        # Replace each max-pool with an average-pool configured by the caller.
        pool_idx = 0
        for i, layer in enumerate(cnn.features):
            if isinstance(layer, torch.nn.MaxPool2d):
                cnn.features[i] = torch.nn.AvgPool2d(
                    kernel_size=ks[pool_idx],
                    stride=ss[pool_idx],
                    padding=0,
                )
                pool_idx += 1

        self.features = cnn.features
        self.dropout = nn.Dropout(dropout)
        self.last_conv_1x1 = nn.Conv2d(512, hidden, 1)

    def forward(self, x):
        """Encode an image batch as a sequence of feature vectors.

        Shape:
            - x: (N, C, H, W)
            - output: (W' * H', N, hidden), where (H', W') is the spatial
              size of the feature map after ``features``. Positions are
              ordered width-major: index ``w * H' + h`` holds the feature at
              row ``h``, column ``w``.
        """
        conv = self.features(x)
        conv = self.dropout(conv)
        conv = self.last_conv_1x1(conv)

        # (N, D, H', W') -> (N, D, W', H') so that flattening walks each
        # column top-to-bottom before moving to the next column.
        conv = conv.transpose(-1, -2)
        # (N, D, W', H') -> (N, D, W' * H')
        conv = conv.flatten(2)
        # (N, D, T) -> (T, N, D): sequence-first layout.
        conv = conv.permute(-1, 0, 1)
        return conv
def vgg11_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` module using the ``'vgg11_bn'`` backbone.

    All arguments are forwarded unchanged to the ``Vgg`` constructor.
    """
    return Vgg(
        name='vgg11_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )
def vgg19_bn(ss, ks, hidden, pretrained=True, dropout=0.5):
    """Build a ``Vgg`` module using the ``'vgg19_bn'`` backbone.

    All arguments are forwarded unchanged to the ``Vgg`` constructor.
    """
    return Vgg(
        name='vgg19_bn',
        ss=ss,
        ks=ks,
        hidden=hidden,
        pretrained=pretrained,
        dropout=dropout,
    )