hantech's picture
Duplicate from hantech/VietOCR
33c0fae
raw
history blame
857 Bytes
import torch
from torch import nn
import vietocr.model.backbone.vgg as vgg
from vietocr.model.backbone.resnet import Resnet50
class CNN(nn.Module):
    """Wrapper that picks a convolutional backbone by name and delegates to it.

    Args:
        backbone: Name of the backbone to build. Supported values are
            ``'vgg11_bn'``, ``'vgg19_bn'`` and ``'resnet50'``.
        **kwargs: Forwarded unchanged to the selected backbone constructor.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self, backbone, **kwargs):
        super().__init__()

        if backbone == 'vgg11_bn':
            self.model = vgg.vgg11_bn(**kwargs)
        elif backbone == 'vgg19_bn':
            self.model = vgg.vgg19_bn(**kwargs)
        elif backbone == 'resnet50':
            self.model = Resnet50(**kwargs)
        else:
            # Fail at construction time instead of leaving ``self.model``
            # unset, which would only surface later as an AttributeError.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of "
                "'vgg11_bn', 'vgg19_bn', 'resnet50'"
            )

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)

    def freeze(self):
        """Disable gradients for the parameters under ``self.model.features``.

        Parameters that belong to a submodule (or are themselves) named
        ``last_conv_1x1`` are left untouched so they stay trainable.
        """
        # NOTE(review): assumes the wrapped backbone exposes a ``features``
        # attribute -- confirm this holds for every supported backbone.
        for name, param in self.model.features.named_parameters():
            # ``named_parameters`` yields dotted paths such as
            # ``last_conv_1x1.weight``; compare path components rather than
            # the whole string, otherwise the exemption can never match.
            if 'last_conv_1x1' not in name.split('.'):
                param.requires_grad = False

    def unfreeze(self):
        """Re-enable gradients for every parameter under ``self.model.features``."""
        for param in self.model.features.parameters():
            param.requires_grad = True