fohy24 committed
Commit · aadf03a · 1 Parent(s): 8406ea0
download model before calling predict()
app.py CHANGED

@@ -28,6 +28,13 @@ labels = ['Pastel',
     'Super Pastel']
 num_labels = len(labels)
 
+# If using GPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+hf_token = os.getenv('HF_token')
+model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v9_epoch8.pt", token=hf_token)
+checkpoint = torch.load(model_path, map_location=device)
+
 def predict(img, confidence):
 
     new_layers = nn.Sequential(
@@ -46,17 +53,9 @@ def predict(img, confidence):
         v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
 
-
     efficientnet = models.efficientnet_v2_l(weights='EfficientNet_V2_L_Weights.DEFAULT')
     efficientnet.classifier = new_layers
 
-    # If using GPU
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    hf_token = os.getenv('HF_token')
-    model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v9_epoch8.pt", token=hf_token)
-
-    checkpoint = torch.load(model_path, map_location=device)
     efficientnet.load_state_dict(checkpoint['model_state_dict'])
 
     efficientnet.eval()
@@ -64,7 +63,6 @@ def predict(img, confidence):
     input_img = transform(img)
     input_img = input_img.unsqueeze(0)
 
-
     with torch.no_grad():
         output = efficientnet(input_img)
 
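The hoisted lines depend on imports (os, torch, and huggingface_hub.hf_hub_download) that presumably already sit at the top of app.py, outside the hunks shown above. A minimal, self-contained sketch of just the start-up path this commit creates, with those imports spelled out; it mirrors the added lines rather than quoting the rest of the file:

import os

import torch
from huggingface_hub import hf_hub_download

# Module scope: this runs once when app.py is imported, i.e. at app start-up,
# not on every call to predict().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Read the access token from the environment instead of hard-coding it.
hf_token = os.getenv('HF_token')

# hf_hub_download fetches the file into the local Hugging Face cache and
# returns its path; repeated calls with the same repo_id/filename are served
# from that cache.
model_path = hf_hub_download(repo_id="fohy24/morphmarket_model",
                             filename="model_v9_epoch8.pt",
                             token=hf_token)

# map_location lets the same code load the checkpoint on CPU-only hardware.
checkpoint = torch.load(model_path, map_location=device)

With this layout a cold start pays for the hub download and torch.load once; each later predict() call only rebuilds the EfficientNet and copies checkpoint['model_state_dict'] into it, so the checkpoint is no longer fetched or deserialized inside the handler.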