fohy24 commited on
Commit
aadf03a
·
1 Parent(s): 8406ea0

downlaod model before calling predict()

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -28,6 +28,13 @@ labels = ['Pastel',
28
  'Super Pastel']
29
  num_labels = len(labels)
30
 
 
 
 
 
 
 
 
31
  def predict(img, confidence):
32
 
33
  new_layers = nn.Sequential(
@@ -46,17 +53,9 @@ def predict(img, confidence):
46
  v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
47
  ])
48
 
49
-
50
  efficientnet = models.efficientnet_v2_l(weights='EfficientNet_V2_L_Weights.DEFAULT')
51
  efficientnet.classifier = new_layers
52
 
53
- # If using GPU
54
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
-
56
- hf_token = os.getenv('HF_token')
57
- model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v9_epoch8.pt", token=hf_token)
58
-
59
- checkpoint = torch.load(model_path, map_location=device)
60
  efficientnet.load_state_dict(checkpoint['model_state_dict'])
61
 
62
  efficientnet.eval()
@@ -64,7 +63,6 @@ def predict(img, confidence):
64
  input_img = transform(img)
65
  input_img = input_img.unsqueeze(0)
66
 
67
-
68
  with torch.no_grad():
69
  output = efficientnet(input_img)
70
 
 
28
  'Super Pastel']
29
  num_labels = len(labels)
30
 
31
+ # If using GPU
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+
34
+ hf_token = os.getenv('HF_token')
35
+ model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v9_epoch8.pt", token=hf_token)
36
+ checkpoint = torch.load(model_path, map_location=device)
37
+
38
  def predict(img, confidence):
39
 
40
  new_layers = nn.Sequential(
 
53
  v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54
  ])
55
 
 
56
  efficientnet = models.efficientnet_v2_l(weights='EfficientNet_V2_L_Weights.DEFAULT')
57
  efficientnet.classifier = new_layers
58
 
 
 
 
 
 
 
 
59
  efficientnet.load_state_dict(checkpoint['model_state_dict'])
60
 
61
  efficientnet.eval()
 
63
  input_img = transform(img)
64
  input_img = input_img.unsqueeze(0)
65
 
 
66
  with torch.no_grad():
67
  output = efficientnet(input_img)
68