# food_vision_mini / model.py
# Author: Your Name
# Last commit: 4f12f52 ("error resolved")
import torch
from torchvision.models import EfficientNet_B2_Weights, efficientnet_b2
from torch import nn
def create_effnet_b2(num_classes: int = 3, seed: int = 42):
    """Build a frozen pretrained EfficientNet-B2 with a new classification head.

    Loads EfficientNet-B2 with ``EfficientNet_B2_Weights.DEFAULT``, freezes all
    of its parameters, and replaces ``model.classifier`` with a fresh
    ``Dropout -> Linear`` head producing ``num_classes`` logits. The new head is
    created after freezing, so its parameters are the only trainable ones.

    Args:
        num_classes: Number of output features of the new ``Linear`` head.
        seed: Value passed to ``torch.manual_seed`` immediately before the head
            is constructed, making its initial weights reproducible. Note that
            this reseeds torch's global RNG as a side effect.

    Returns:
        A tuple ``(model, transforms)`` where ``transforms`` is the
        preprocessing pipeline returned by ``weights.transforms()``.
    """
    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    # `weights` is keyword-only in torchvision >= 0.13; passing it positionally
    # goes through a deprecation shim that emits a warning.
    model = efficientnet_b2(weights=weights)

    # Freeze every pretrained parameter; the replacement head below is new and
    # therefore trainable by default.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the existing final Linear layer rather
    # than hard-coding it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features

    # Seed right before creating the head so its initialisation is repeatable.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )
    return model, transforms