fohy24 commited on
Commit
bddc3d6
·
1 Parent(s): 1a423d1

specifying pytorch version to fix model loading failure

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. requirements.txt +2 -2
app.py CHANGED
@@ -5,6 +5,7 @@ from torchvision.transforms import v2
5
  import os
6
  import requests
7
 
 
8
  labels = ['Pastel',
9
  'Yellow Belly',
10
  'Enchi',
@@ -53,12 +54,12 @@ def predict(img, confidence):
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
 
55
  # Download model from GCS
56
- model_path = os.environ.get('model_path')
57
  response = requests.get(model_path)
58
 
59
  with open('model.pt', 'wb') as f:
60
  f.write(response.content)
61
-
62
  checkpoint = torch.load('model.pt', map_location=device)
63
  densenet.load_state_dict(checkpoint['model_state_dict'])
64
 
 
5
  import os
6
  import requests
7
 
8
+
9
  labels = ['Pastel',
10
  'Yellow Belly',
11
  'Enchi',
 
54
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
 
56
  # Download model from GCS
57
+ model_path = os.getenv('model_path')
58
  response = requests.get(model_path)
59
 
60
  with open('model.pt', 'wb') as f:
61
  f.write(response.content)
62
+
63
  checkpoint = torch.load('model.pt', map_location=device)
64
  densenet.load_state_dict(checkpoint['model_state_dict'])
65
 
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- torch
2
- torchvision
 
1
+ torch==2.2.2
2
+ torchvision==0.17.2