from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights | |
import torch | |
from torch import nn | |
def create_effnetb2_instance(num_classes: int = 1000,
                             device: "torch.device | str" = "cpu",
                             *,
                             dropout: float = 0.3):
    """Build a pretrained EfficientNet-B2 with a new, trainable classifier head.

    Loads ``efficientnet_b2`` with ``EfficientNet_B2_Weights.DEFAULT`` and
    freezes every pretrained parameter (``requires_grad = False``). It then
    replaces ``model.classifier`` with ``Dropout -> Linear`` producing
    ``num_classes`` outputs, so only the new head is trainable.

    Args:
        num_classes: Number of output features of the new linear layer.
        device: Device the returned model is moved to (e.g. ``"cpu"``,
            ``"cuda"``, or a ``torch.device``).
        dropout: Dropout probability used in the new head. Defaults to 0.3,
            the value previously hard-coded here.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is the object
        returned by ``EfficientNet_B2_Weights.DEFAULT.transforms()``.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # Freeze all pretrained parameters. This includes the old head, which is
    # discarded below.
    for param in model.parameters():
        param.requires_grad = False

    # Read the feature width from the pretrained head's final Linear layer
    # instead of hard-coding 1408, so the new head always matches the backbone.
    in_features = model.classifier[-1].in_features

    # Freshly constructed layers have requires_grad=True by default, so this
    # replacement head is the only trainable part of the model.
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    # Move once, after the head swap, so backbone and head share one device.
    return model.to(device), transforms