FoodVision_Big / model.py
DiabeticOwl's picture
Matching requirements in main and fixing bug.
1956b65
raw
history blame contribute delete
No virus
1.61 kB
import torch
from pathlib import Path
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
from typing import Optional, Tuple
# Pick the compute device once at import time: use the GPU when CUDA is
# usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# A RuntimeError appeared in HuggingFace when the application tried to load the
# weights of the model. The following link is to the source of the fix now
# implemented.
# https://github.com/pytorch/vision/issues/7744#issuecomment-1757321451
def get_state_dict(self, *args, **kwargs):
    """
    Override intended to fix a bug while loading the state_dict
    from the internet.

    Discards the ``check_hash`` keyword, when the caller supplies it, and
    forwards every remaining argument to ``load_state_dict_from_url``
    together with ``self.url``.

    Args:
        self: Object exposing a ``url`` attribute that points at the weights.
        *args: Positional arguments passed through unchanged.
        **kwargs: Keyword arguments passed through, minus ``check_hash``.

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # Use a default so callers that never pass ``check_hash`` do not hit a
    # KeyError; the keyword is dropped either way.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)
# Monkey-patch: replace the get_state_dict attribute on torchvision's
# WeightsEnum class with the module-level override defined above, so weight
# loading goes through it.
WeightsEnum.get_state_dict = get_state_dict
def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: Optional[bool] = False
) -> Tuple[nn.Module, nn.Module]:
    """
    Build an EfficientNet-B2 whose classifier head outputs `num_classes`.

    The network is created with ``EfficientNet_B2_Weights.DEFAULT`` and its
    ``classifier`` attribute is replaced by a dropout + linear head.

    Args:
        num_classes: Number of output features of the new linear layer.
        seed: Value passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` before the model is created.
            ``None`` skips seeding.
        load_st_dict: When truthy, load a state dict from ``model.pth``
            (a relative path, so resolved against the current working
            directory) and disable gradients on every parameter.

    Returns:
        A ``(model, transforms)`` tuple: the model moved to ``DEVICE`` and
        the object returned by ``weights.transforms()``.
    """
    # The signature advertises Optional[int], but torch.manual_seed cannot
    # accept None, so only seed when a value was actually given.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # Swap the stock head for one sized to this task; the new linear layer
    # takes 1408 input features.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes, bias=True)
    ).to(DEVICE)

    if load_st_dict:
        st_dict = Path('model.pth')
        model.load_state_dict(torch.load(st_dict, map_location=DEVICE))

        # NOTE(review): the freeze is assumed to belong to the
        # "load trained weights" path; an unconditional freeze would leave
        # the fresh head untrainable when load_st_dict is False. Confirm
        # against the original file's indentation.
        for param in model.parameters():
            param.requires_grad = False

    return model.to(DEVICE), transforms