fohy24
commited on
Commit
·
1a423d1
1
Parent(s):
0b2c223
retreive trained model from GCS
Browse files
app.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
from torchvision import models
|
4 |
-
import torch.nn.functional as F
|
5 |
from torchvision.transforms import v2
|
6 |
-
|
7 |
-
|
8 |
|
9 |
labels = ['Pastel',
|
10 |
'Yellow Belly',
|
@@ -53,7 +52,14 @@ def predict(img, confidence):
|
|
53 |
# If using GPU
|
54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
densenet.load_state_dict(checkpoint['model_state_dict'])
|
58 |
|
59 |
densenet.eval()
|
|
|
1 |
import torch
|
2 |
from torch import nn
|
3 |
from torchvision import models
|
|
|
4 |
from torchvision.transforms import v2
|
5 |
+
import os
|
6 |
+
import requests
|
7 |
|
8 |
labels = ['Pastel',
|
9 |
'Yellow Belly',
|
|
|
52 |
# If using GPU
|
53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 |
|
55 |
+
# Download model from GCS
|
56 |
+
model_path = os.environ.get('model_path')
|
57 |
+
response = requests.get(model_path)
|
58 |
+
|
59 |
+
with open('model.pt', 'wb') as f:
|
60 |
+
f.write(response.content)
|
61 |
+
|
62 |
+
checkpoint = torch.load('model.pt', map_location=device)
|
63 |
densenet.load_state_dict(checkpoint['model_state_dict'])
|
64 |
|
65 |
densenet.eval()
|