import torch.nn as nn
from torchvision import models


def model(pretrained, requires_grad, num_classes=25):
    """Build a ResNet-50 with a fresh classification head.

    Args:
        pretrained: If truthy, load the ImageNet weights that the legacy
            ``pretrained=True`` flag selected (``IMAGENET1K_V1``); otherwise
            start from random initialisation.
        requires_grad: If truthy, all backbone parameters are trainable;
            if falsy, they are frozen (feature-extraction mode).
        num_classes: Number of output units of the new final ``fc`` layer.
            Defaults to 25, the value previously hard-coded here.

    Returns:
        A ``torchvision.models.ResNet`` whose ``fc`` layer has been replaced
        by a newly initialised ``nn.Linear`` that is always trainable.
    """
    # torchvision >= 0.13 exposes the ``weights=`` API and deprecates
    # ``pretrained=``; fall back to the old keyword on older releases.
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        net = models.resnet50(weights=weights, progress=True)
    else:
        net = models.resnet50(pretrained=pretrained, progress=True)

    # Freeze or unfreeze every backbone parameter in one call.
    net.requires_grad_(bool(requires_grad))

    # Replace the head *after* setting requires_grad: the new layer's
    # parameters are created with requires_grad=True, so the classifier
    # stays learnable even when the backbone is frozen. Reading
    # in_features avoids hard-coding ResNet-50's 2048-dim feature size.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net