Spaces:
Runtime error
Runtime error
import torch | |
from torch import nn | |
import vietocr.model.backbone.vgg as vgg | |
from vietocr.model.backbone.resnet import Resnet50 | |
class CNN(nn.Module):
    """Thin wrapper that builds a CNN backbone chosen by name.

    The wrapped network is stored on ``self.model``; ``forward`` simply
    delegates to it.

    Args:
        backbone: One of ``'vgg11_bn'``, ``'vgg19_bn'`` or ``'resnet50'``.
        **kwargs: Forwarded unchanged to the selected backbone constructor.

    Raises:
        ValueError: If ``backbone`` is not one of the supported names.
    """

    # Name of the submodule whose parameters ``freeze`` leaves trainable.
    _TRAINABLE_SUBMODULE = 'last_conv_1x1'

    def __init__(self, backbone: str, **kwargs):
        super().__init__()

        if backbone == 'vgg11_bn':
            self.model = vgg.vgg11_bn(**kwargs)
        elif backbone == 'vgg19_bn':
            self.model = vgg.vgg19_bn(**kwargs)
        elif backbone == 'resnet50':
            self.model = Resnet50(**kwargs)
        else:
            # Fail at construction time instead of leaving ``self.model``
            # undefined and crashing later with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of "
                "'vgg11_bn', 'vgg19_bn', 'resnet50'"
            )

    def forward(self, x):
        """Return the output of the wrapped backbone applied to ``x``."""
        return self.model(x)

    def freeze(self) -> None:
        """Disable gradients for the parameters of ``self.model.features``.

        Parameters that live under a submodule named ``last_conv_1x1`` are
        left untouched.

        NOTE(review): requires the wrapped model to expose a ``features``
        submodule -- confirm this holds for every supported backbone.
        """
        for name, param in self.model.features.named_parameters():
            # ``named_parameters`` yields dotted paths such as
            # 'last_conv_1x1.weight', so compare path components rather than
            # the whole string (an exact match could never succeed).
            if self._TRAINABLE_SUBMODULE in name.split('.'):
                continue
            param.requires_grad = False

    def unfreeze(self) -> None:
        """Re-enable gradients for every parameter of ``self.model.features``.

        NOTE(review): requires the wrapped model to expose a ``features``
        submodule -- confirm this holds for every supported backbone.
        """
        for param in self.model.features.parameters():
            param.requires_grad = True