FoodVision_Big / model.py
DiabeticOwl's picture
Matching requirements in main and fixing bug.
1956b65
raw
history blame
1.61 kB
import torch
from pathlib import Path
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
from typing import Optional, Tuple
# Device string used throughout this module: 'cuda' when torch reports that
# CUDA is available, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# A RuntimeError appeared in HuggingFace when the application tried to load the
# weights of the model. The following link is to the source of the fix now
# implemented.
# https://github.com/pytorch/vision/issues/7744#issuecomment-1757321451
def get_state_dict(self, *args, **kwargs):
    """
    Load the state dict found at ``self.url``.

    Override intended to fix a bug while loading the state_dict
    from the internet: the ``check_hash`` keyword is discarded before
    the call is delegated to ``load_state_dict_from_url``.

    Args:
        *args: Positional arguments forwarded unchanged.
        **kwargs: Keyword arguments forwarded after ``check_hash`` has been
            removed (if it was supplied at all).

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # Pop with a default so callers that never pass ``check_hash`` do not
    # trigger a KeyError.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)
# Replace WeightsEnum.get_state_dict with the module-level get_state_dict
# override so weight loading goes through it.
WeightsEnum.get_state_dict = get_state_dict
def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: Optional[bool] = False
) -> Tuple[nn.Module, nn.Module]:
    """
    Build an EfficientNet-B2 model with a replacement classifier head.

    The model is created from ``EfficientNet_B2_Weights.DEFAULT`` and its
    ``classifier`` is swapped for a dropout + linear head that outputs
    ``num_classes`` values.

    Args:
        num_classes: Number of output features of the new linear layer.
        seed: Seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``. When ``None``, no seeding is done.
        load_st_dict: When truthy, the state dict stored in ``model.pth``
            (resolved against the current working directory) is loaded
            into the model.

    Returns:
        A tuple ``(model, transforms)``: the model moved to ``DEVICE`` and
        the transforms object produced by the default weights.
    """
    # The seeding functions reject None, so only seed when a value is given.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # New head: dropout followed by a linear layer sized for num_classes.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes, bias=True)
    ).to(DEVICE)

    if load_st_dict:
        st_dict = Path('model.pth')
        model.load_state_dict(torch.load(st_dict, map_location=DEVICE))

        # Turn off gradient tracking for every parameter of the loaded model.
        # NOTE(review): the reviewed copy of this file had lost its
        # indentation; this loop is assumed to belong to the load branch
        # (otherwise the default path would return a model with nothing
        # trainable) -- confirm against the original layout.
        for param in model.parameters():
            param.requires_grad = False

    return model.to(DEVICE), transforms