import torch

from pathlib import Path
from torch import nn
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
from typing import Optional, Tuple

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# A RuntimeError appeared in HuggingFace when the application tried to load the
# weights of the model. The following link is to the source of the fix now
# implemented.
# https://github.com/pytorch/vision/issues/7744#issuecomment-1757321451
def get_state_dict(self, *args, **kwargs):
    """
    Override intended to fix a bug while loading the state_dict from the
    internet.

    Drops the ``check_hash`` keyword (if present) and forwards every other
    argument to ``load_state_dict_from_url`` together with ``self.url``.
    """
    # Use a default so callers that never pass ``check_hash`` do not trigger
    # a KeyError; this override is installed on every WeightsEnum.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)


WeightsEnum.get_state_dict = get_state_dict


def create_effnetb2_model(
    num_classes: int,
    seed: Optional[int] = 42,
    load_st_dict: Optional[bool] = False
) -> Tuple[nn.Module, nn.Module]:
    """
    Build an EfficientNet-B2 with a new classification head.

    The model is created with ``EfficientNet_B2_Weights.DEFAULT``, its
    ``classifier`` is replaced by a dropout + linear head with
    ``num_classes`` outputs, and all parameters are frozen before returning.

    Args:
        num_classes: Number of output features of the new linear head.
        seed: Value passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` before the model is built. When
            ``None``, no seeding is performed.
        load_st_dict: When truthy, load a state dict from ``model.pth``
            (resolved relative to the current working directory) into the
            model, mapped to ``DEVICE``.

    Returns:
        A tuple ``(model, transforms)`` where ``model`` has been moved to
        ``DEVICE`` and ``transforms`` is the result of
        ``weights.transforms()`` for the default weights.
    """
    # ``torch.manual_seed(None)`` raises a TypeError, so only seed when a
    # value was actually supplied (the annotation allows ``None``).
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    weights = EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = efficientnet_b2(weights=weights)

    # Replace the stock head; the whole model is moved to DEVICE on return,
    # so the head does not need its own ``.to(DEVICE)``.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=1408, out_features=num_classes, bias=True)
    )

    if load_st_dict:
        st_dict = Path('model.pth')
        model.load_state_dict(torch.load(st_dict, map_location=DEVICE))

    # NOTE(review): the original layout was lost, so it is unclear whether
    # freezing was meant to be conditional on ``load_st_dict``; here every
    # parameter (including the new head) is frozen unconditionally - confirm.
    for param in model.parameters():
        param.requires_grad = False

    return model.to(DEVICE), transforms