# FoodVision_Big / model.py
# Author: DiabeticOwl
# Commit: cb1d8c8 ("Initial Build.")
import torch
from pathlib import Path
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
from typing import Optional, Tuple
# Target device for the model: the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: bool = False,
    st_dict_path: Path = Path('model.pth'),
) -> Tuple[nn.Module, nn.Module]:
    """Build an EfficientNet-B2 with a new classification head.

    The backbone is initialised from torchvision's default pretrained
    EfficientNet-B2 weights, unless a local state dict is loaded. In that case
    the pretrained weights are skipped, because the (strict) state-dict load
    overwrites every parameter and buffer anyway.

    Args:
        num_classes: Number of output units of the new ``nn.Linear`` head.
        seed: Seed passed to ``torch.manual_seed`` before the model is built,
            so the new head's initial weights are reproducible. ``None``
            skips seeding.
        load_st_dict: If true, load the full model state dict from
            ``st_dict_path`` and freeze every parameter (inference use).
            Otherwise only the backbone (``model.features``) is frozen, so
            the freshly initialised head can be trained.
        st_dict_path: Location of the saved state dict (``str`` or ``Path``).
            Defaults to ``model.pth``, resolved against the current working
            directory.

    Returns:
        ``(model, transforms)``: the model moved to ``DEVICE``, and the
        preprocessing transform that matches the EfficientNet-B2 weights.

    Raises:
        FileNotFoundError: If ``load_st_dict`` is true and the file is missing.
        RuntimeError: From ``load_state_dict`` if the checkpoint's keys or
            shapes do not match (e.g. a different ``num_classes``).
    """
    if seed is not None:
        # torch.manual_seed seeds every device (CPU and CUDA), so a separate
        # torch.cuda.manual_seed call is unnecessary.
        torch.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    # The preprocessing comes from the weights enum itself, so it is valid
    # even when the pretrained tensors are not loaded below.
    transforms = weights.transforms()
    # A loaded checkpoint replaces every parameter and buffer, so fetching and
    # loading the pretrained weights first would be wasted work.
    model = efficientnet_b2(weights=None if load_st_dict else weights)

    # Reuse the backbone's output width rather than hard-coding it (1408 for B2).
    in_features = model.classifier[-1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )

    if load_st_dict:
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints; on torch >= 1.13 consider passing weights_only=True.
        state_dict = torch.load(Path(st_dict_path), map_location=DEVICE)
        model.load_state_dict(state_dict)
        frozen_params = model.parameters()
    else:
        # Keep the new head trainable; freezing it too would make the
        # randomly initialised classifier impossible to fit.
        frozen_params = model.features.parameters()

    for param in frozen_params:
        param.requires_grad = False

    return model.to(DEVICE), transforms