Pizza_Steak_Sushi_Classifier / create_effnet.py
itzRahul's picture
Updated the Repository with necessary files
840b425
raw
history blame contribute delete
711 Bytes
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
import torch
from torch import nn
def create_effnetb2_instance(num_classes: int = 1000,
                             device: "torch.device | str" = "cpu",
                             dropout: float = 0.3):
    """Build a pretrained EfficientNet-B2 with a frozen backbone and a new head.

    The model is loaded with ``EfficientNet_B2_Weights.DEFAULT``. All of its
    pretrained parameters are frozen, and its classifier is replaced by a fresh
    ``Dropout -> Linear`` head whose parameters remain trainable.

    Args:
        num_classes: Number of output units of the new linear layer.
        device: Device the returned model is moved to (for example ``"cpu"``,
            ``"cuda"``, or a ``torch.device``).
        dropout: Dropout probability of the new head. Defaults to 0.3.

    Returns:
        A ``(model, transforms)`` tuple. ``transforms`` is the preprocessing
        pipeline returned by ``EfficientNet_B2_Weights.DEFAULT.transforms()``,
        intended for preparing inputs to ``model``.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # Freeze every pretrained parameter; only the head built below is trained.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the stock classifier's Linear layer
    # instead of hard-coding 1408, so the new head always matches the backbone.
    in_features = next(
        layer.in_features
        for layer in model.classifier.modules()
        if isinstance(layer, nn.Linear)
    )

    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    # Move once, after the head swap, so the backbone and head share a device.
    model.to(device)
    return model, transforms