LuisDarioHinojosa
first commit
69beb24
import torch
import torchvision
from torch import nn
def create_effnet_b2_instance(num_classes: int = 3, *, dropout: float = 0.3):
    """Build a pretrained EfficientNet-B2 with a frozen backbone and a new head.

    Loads torchvision's default pretrained EfficientNet-B2 weights, freezes
    every pretrained parameter, and replaces the classifier with a freshly
    initialised ``Dropout -> Linear`` stack producing ``num_classes`` outputs.
    Only the new head's parameters are trainable.

    Args:
        num_classes: Number of output classes for the new head. Must be >= 1.
        dropout: Dropout probability applied before the final linear layer.
            Keyword-only; defaults to 0.3, the value previously hard-coded.

    Returns:
        A ``(transforms, model)`` tuple: the preprocessing transform paired
        with the pretrained weights, and the modified model.

    Raises:
        ValueError: If ``num_classes`` < 1 or ``dropout`` is not in [0, 1].
    """
    # Validate before fetching weights so bad arguments fail fast instead of
    # after a (possibly slow) download.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if not 0.0 <= dropout <= 1.0:
        raise ValueError(f"dropout must be in [0, 1], got {dropout}")

    # Fetch the model's pretrained weights and the preprocessing transforms
    # that match them.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()

    # Build the model and load the pretrained weights.
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze all pretrained parameters; the head created below is new and
    # therefore trainable by default.
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the existing final Linear layer instead
    # of hard-coding it (it is 1408 for B2), so it cannot drift out of sync
    # with the backbone.
    in_features = model.classifier[-1].in_features

    # Replace the classifier with a head sized for our number of classes.
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return transforms, model