DiabeticOwl commited on
Commit
1956b65
1 Parent(s): 39a9c1b

Matching requirements in main and fixing bug.

Browse files

This commit to fix the RuntimeError that shows in HuggingFace through a patch found in a GitHub issue on the PyTorch Vision repository and matching some of the requirements with local.

Files changed (2) hide show
  1. model.py +21 -2
  2. requirements.txt +3 -3
model.py CHANGED
@@ -3,10 +3,29 @@ import torch
3
  from pathlib import Path
4
  from torch import nn
5
  from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
 
 
6
  from typing import Optional, Tuple
7
 
8
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def create_effnetb2_model(
12
  num_classes: int,
@@ -18,7 +37,7 @@ def create_effnetb2_model(
18
  weights = EfficientNet_B2_Weights.DEFAULT
19
  transforms = weights.transforms()
20
  model = efficientnet_b2(weights=weights)
21
-
22
  model.classifier = nn.Sequential(
23
  nn.Dropout(p=0.3, inplace=True),
24
  nn.Linear(in_features=1408, out_features=num_classes, bias=True)
@@ -28,5 +47,5 @@ def create_effnetb2_model(
28
  model.load_state_dict(torch.load(st_dict, map_location=DEVICE))
29
  for param in model.parameters():
30
  param.requires_grad = False
31
-
32
  return model.to(DEVICE), transforms
 
3
  from pathlib import Path
4
  from torch import nn
5
  from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights
6
+ from torchvision.models._api import WeightsEnum
7
+ from torch.hub import load_state_dict_from_url
8
  from typing import Optional, Tuple
9
 
10
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
11
 
12
+ # A RuntimeError appeared in HuggingFace when the application tried to load the
13
+ # weights of the model. The following link is to the source of the fix now
14
+ # implemented.
15
+ # https://github.com/pytorch/vision/issues/7744#issuecomment-1757321451
16
+
17
+
18
+ def get_state_dict(self, *args, **kwargs):
19
+ """
20
+ Override intented to fix a bug while loading the state_dict
21
+ from the internet.
22
+ """
23
+ kwargs.pop("check_hash")
24
+ return load_state_dict_from_url(self.url, *args, **kwargs)
25
+
26
+
27
+ WeightsEnum.get_state_dict = get_state_dict
28
+
29
 
30
  def create_effnetb2_model(
31
  num_classes: int,
 
37
  weights = EfficientNet_B2_Weights.DEFAULT
38
  transforms = weights.transforms()
39
  model = efficientnet_b2(weights=weights)
40
+
41
  model.classifier = nn.Sequential(
42
  nn.Dropout(p=0.3, inplace=True),
43
  nn.Linear(in_features=1408, out_features=num_classes, bias=True)
 
47
  model.load_state_dict(torch.load(st_dict, map_location=DEVICE))
48
  for param in model.parameters():
49
  param.requires_grad = False
50
+
51
  return model.to(DEVICE), transforms
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- torch>=2.0.0
2
- torchvision>=0.15.0
3
- gradio>=3.33.1
 
1
+ torch>=2.1.0
2
+ torchvision>=0.16.0
3
+ gradio==3.33.1