fohy24 committed
Commit d79e14c · 1 Parent(s): c18a023
using HF hub to download model
app.py
CHANGED
@@ -3,8 +3,7 @@ from torch import nn
 from torchvision import models
 from torchvision.transforms import v2
 import os
-import requests
-import time
+from huggingface_hub import hf_hub_download
 
 
 labels = ['Pastel',
@@ -54,22 +53,10 @@ def predict(img, confidence):
 # If using GPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-
-model_path =
-print(model_path)
-response = requests.get(model_path)
-time.sleep(20)
+hf_token = os.getenv('HF_token')
+model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v8_epoch9.pt", token=hf_token)
 
-
-# cache_dir = os.path.expanduser("~/.cache/torch/hub/checkpoints")
-# model_file_name = os.path.basename(model_path)
-# model_file_path = os.path.join(cache_dir, model_file_name)
-# print(model_file_path)
-
-with open('model.pt', 'wb') as f:
-    f.write(response.content)
-
-checkpoint = torch.load('model.pt', map_location=device)
+checkpoint = torch.load(model_path, map_location=device)
 densenet.load_state_dict(checkpoint['model_state_dict'])
 
 densenet.eval()