fohy24 commited on
Commit
1a423d1
·
1 Parent(s): 0b2c223

retreive trained model from GCS

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import torch
2
  from torch import nn
3
  from torchvision import models
4
- import torch.nn.functional as F
5
  from torchvision.transforms import v2
6
-
7
- from PIL import Image
8
 
9
  labels = ['Pastel',
10
  'Yellow Belly',
@@ -53,7 +52,14 @@ def predict(img, confidence):
53
  # If using GPU
54
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
 
56
- checkpoint = torch.load(f'model/model_v8_epoch9.pt', map_location=device)
 
 
 
 
 
 
 
57
  densenet.load_state_dict(checkpoint['model_state_dict'])
58
 
59
  densenet.eval()
 
1
  import torch
2
  from torch import nn
3
  from torchvision import models
 
4
  from torchvision.transforms import v2
5
+ import os
6
+ import requests
7
 
8
  labels = ['Pastel',
9
  'Yellow Belly',
 
52
  # If using GPU
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
 
55
+ # Download model from GCS
56
+ model_path = os.environ.get('model_path')
57
+ response = requests.get(model_path)
58
+
59
+ with open('model.pt', 'wb') as f:
60
+ f.write(response.content)
61
+
62
+ checkpoint = torch.load('model.pt', map_location=device)
63
  densenet.load_state_dict(checkpoint['model_state_dict'])
64
 
65
  densenet.eval()