fohy24
commited on
Commit
·
bddc3d6
1
Parent(s):
1a423d1
specifying pytorch version to fix model loading failure
Browse files- app.py +3 -2
- requirements.txt +2 -2
app.py
CHANGED
|
@@ -5,6 +5,7 @@ from torchvision.transforms import v2
|
|
| 5 |
import os
|
| 6 |
import requests
|
| 7 |
|
|
|
|
| 8 |
labels = ['Pastel',
|
| 9 |
'Yellow Belly',
|
| 10 |
'Enchi',
|
|
@@ -53,12 +54,12 @@ def predict(img, confidence):
|
|
| 53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 54 |
|
| 55 |
# Download model from GCS
|
| 56 |
-
model_path = os.
|
| 57 |
response = requests.get(model_path)
|
| 58 |
|
| 59 |
with open('model.pt', 'wb') as f:
|
| 60 |
f.write(response.content)
|
| 61 |
-
|
| 62 |
checkpoint = torch.load('model.pt', map_location=device)
|
| 63 |
densenet.load_state_dict(checkpoint['model_state_dict'])
|
| 64 |
|
|
|
|
| 5 |
import os
|
| 6 |
import requests
|
| 7 |
|
| 8 |
+
|
| 9 |
labels = ['Pastel',
|
| 10 |
'Yellow Belly',
|
| 11 |
'Enchi',
|
|
|
|
| 54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 55 |
|
| 56 |
# Download model from GCS
|
| 57 |
+
model_path = os.getenv('model_path')
|
| 58 |
response = requests.get(model_path)
|
| 59 |
|
| 60 |
with open('model.pt', 'wb') as f:
|
| 61 |
f.write(response.content)
|
| 62 |
+
|
| 63 |
checkpoint = torch.load('model.pt', map_location=device)
|
| 64 |
densenet.load_state_dict(checkpoint['model_state_dict'])
|
| 65 |
|
requirements.txt
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
torch
|
| 2 |
-
torchvision
|
|
|
|
| 1 |
+
torch==2.2.2
|
| 2 |
+
torchvision==0.17.2
|