fohy24
commited on
Commit
·
bddc3d6
1
Parent(s):
1a423d1
specifying pytorch version to fix model loading failure
Browse files- app.py +3 -2
- requirements.txt +2 -2
app.py
CHANGED
@@ -5,6 +5,7 @@ from torchvision.transforms import v2
|
|
5 |
import os
|
6 |
import requests
|
7 |
|
|
|
8 |
labels = ['Pastel',
|
9 |
'Yellow Belly',
|
10 |
'Enchi',
|
@@ -53,12 +54,12 @@ def predict(img, confidence):
|
|
53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 |
|
55 |
# Download model from GCS
|
56 |
-
model_path = os.
|
57 |
response = requests.get(model_path)
|
58 |
|
59 |
with open('model.pt', 'wb') as f:
|
60 |
f.write(response.content)
|
61 |
-
|
62 |
checkpoint = torch.load('model.pt', map_location=device)
|
63 |
densenet.load_state_dict(checkpoint['model_state_dict'])
|
64 |
|
|
|
5 |
import os
|
6 |
import requests
|
7 |
|
8 |
+
|
9 |
labels = ['Pastel',
|
10 |
'Yellow Belly',
|
11 |
'Enchi',
|
|
|
54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
|
56 |
# Download model from GCS
|
57 |
+
model_path = os.getenv('model_path')
|
58 |
response = requests.get(model_path)
|
59 |
|
60 |
with open('model.pt', 'wb') as f:
|
61 |
f.write(response.content)
|
62 |
+
|
63 |
checkpoint = torch.load('model.pt', map_location=device)
|
64 |
densenet.load_state_dict(checkpoint['model_state_dict'])
|
65 |
|
requirements.txt
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
torch
|
2 |
-
torchvision
|
|
|
1 |
+
torch==2.2.2
|
2 |
+
torchvision==0.17.2
|