fohy24 commited on
Commit
d79e14c
·
1 Parent(s): c18a023

using HF hub to download model

Browse files
Files changed (1) hide show
  1. app.py +4 -17
app.py CHANGED
@@ -3,8 +3,7 @@ from torch import nn
3
  from torchvision import models
4
  from torchvision.transforms import v2
5
  import os
6
- import requests
7
- import time
8
 
9
 
10
  labels = ['Pastel',
@@ -54,22 +53,10 @@ def predict(img, confidence):
54
  # If using GPU
55
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
 
57
- # Download model from GCS
58
- model_path = os.getenv('MODEL_PATH')
59
- print(model_path)
60
- response = requests.get(model_path)
61
- time.sleep(20)
62
 
63
- # # Define the cache directory path
64
- # cache_dir = os.path.expanduser("~/.cache/torch/hub/checkpoints")
65
- # model_file_name = os.path.basename(model_path)
66
- # model_file_path = os.path.join(cache_dir, model_file_name)
67
- # print(model_file_path)
68
-
69
- with open('model.pt', 'wb') as f:
70
- f.write(response.content)
71
-
72
- checkpoint = torch.load('model.pt', map_location=device)
73
  densenet.load_state_dict(checkpoint['model_state_dict'])
74
 
75
  densenet.eval()
 
3
  from torchvision import models
4
  from torchvision.transforms import v2
5
  import os
6
+ from huggingface_hub import hf_hub_download
 
7
 
8
 
9
  labels = ['Pastel',
 
53
  # If using GPU
54
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
 
56
+ hf_token = os.getenv('HF_token')
57
+ model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v8_epoch9.pt", token=hf_token)
 
 
 
58
 
59
+ checkpoint = torch.load(model_path, map_location=device)
 
 
 
 
 
 
 
 
 
60
  densenet.load_state_dict(checkpoint['model_state_dict'])
61
 
62
  densenet.eval()