Spaces:
Running
Running
import torch | |
from pathlib import Path | |
from torch import nn | |
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights | |
from torchvision.models._api import WeightsEnum | |
from torch.hub import load_state_dict_from_url | |
from typing import Optional, Tuple | |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
# Loading the pretrained weights raised a RuntimeError when this application
# ran on Hugging Face. The override below implements the fix described in the
# following comment on the torchvision issue tracker:
# https://github.com/pytorch/vision/issues/7744#issuecomment-1757321451
def get_state_dict(self, *args, **kwargs):
    """
    Fetch the state dict behind ``self.url``, ignoring ``check_hash``.

    Override intended to fix a bug while loading the state_dict
    from the internet: any ``check_hash`` keyword is discarded before
    the call is delegated to ``load_state_dict_from_url``.

    Args:
        self: Object exposing a ``url`` attribute that points at the
            serialized weights.
        *args: Positional arguments forwarded unchanged to
            ``load_state_dict_from_url``.
        **kwargs: Keyword arguments forwarded to
            ``load_state_dict_from_url``, minus ``check_hash``.

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # Supplying a default keeps callers that never pass ``check_hash`` from
    # hitting a KeyError here.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)
# Monkey-patch: replace WeightsEnum.get_state_dict with the override defined
# above so that weight loading goes through it.
WeightsEnum.get_state_dict = get_state_dict
def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: Optional[bool] = False
) -> Tuple[nn.Module, nn.Module]:
    """
    Build an EfficientNet-B2 with a fresh classification head.

    The backbone is created with ``EfficientNet_B2_Weights.DEFAULT`` and its
    classifier is replaced by ``Dropout(p=0.3)`` followed by a ``Linear``
    layer with ``num_classes`` outputs.

    Args:
        num_classes: Number of output units of the new classifier head.
        seed: Seed passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` before the model is built. Seeding is
            skipped when ``None``.
        load_st_dict: When truthy, load the weights stored in ``model.pth``
            (resolved relative to the current working directory) and disable
            gradients for every parameter.

    Returns:
        A ``(model, transforms)`` tuple: the model moved to ``DEVICE`` and
        the preprocessing transforms returned by ``weights.transforms()``.
    """
    # The seeding functions need an integer, so an explicit ``None`` means
    # "leave the RNG state alone" instead of raising.
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # Read the feature width from the head being replaced (1408 for B2)
    # instead of hard-coding it.
    in_features = model.classifier[-1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )

    if load_st_dict:
        st_dict = Path('model.pth')
        model.load_state_dict(torch.load(st_dict, map_location=DEVICE))
        # NOTE(review): the freeze is treated as part of the checkpoint
        # branch; the indentation of the recovered source was ambiguous, so
        # confirm it should not also apply when no checkpoint is loaded.
        model.requires_grad_(False)

    # A single move at the end covers the new classifier as well.
    return model.to(DEVICE), transforms