fohy24
commited on
Commit
·
826a99e
1
Parent(s):
d7a718b
switched model to effnetv2 L
Browse files
app.py
CHANGED
@@ -31,14 +31,14 @@ num_labels = len(labels)
|
|
31 |
def predict(img, confidence):
|
32 |
|
33 |
new_layers = nn.Sequential(
|
34 |
-
nn.
|
35 |
-
nn.BatchNorm1d(
|
36 |
-
nn.ReLU(),
|
37 |
-
nn.Dropout(0.5),
|
38 |
-
nn.
|
39 |
)
|
40 |
|
41 |
-
IMAGE_SIZE =
|
42 |
transform = v2.Compose([
|
43 |
v2.ToImage(),
|
44 |
v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
@@ -47,26 +47,26 @@ def predict(img, confidence):
|
|
47 |
])
|
48 |
|
49 |
|
50 |
-
|
51 |
-
|
52 |
|
53 |
# If using GPU
|
54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
|
56 |
hf_token = os.getenv('HF_token')
|
57 |
-
model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="
|
58 |
|
59 |
checkpoint = torch.load(model_path, map_location=device)
|
60 |
-
|
61 |
|
62 |
-
|
63 |
|
64 |
input_img = transform(img)
|
65 |
input_img = input_img.unsqueeze(0)
|
66 |
|
67 |
|
68 |
with torch.no_grad():
|
69 |
-
output =
|
70 |
|
71 |
predicted_probs = torch.sigmoid(output).to('cpu').flatten().tolist()
|
72 |
prediction_dict = {labels[i]: predicted_probs[i] for i in range(len(labels)) if predicted_probs[i] > confidence}
|
|
|
31 |
def predict(img, confidence):
|
32 |
|
33 |
new_layers = nn.Sequential(
|
34 |
+
nn.LazyLinear(1280),
|
35 |
+
nn.BatchNorm1d(1280),
|
36 |
+
nn.ReLU(),
|
37 |
+
nn.Dropout(0.5),
|
38 |
+
nn.LazyLinear(num_labels)
|
39 |
)
|
40 |
|
41 |
+
IMAGE_SIZE = 480
|
42 |
transform = v2.Compose([
|
43 |
v2.ToImage(),
|
44 |
v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
|
|
47 |
])
|
48 |
|
49 |
|
50 |
+
efficientnet = models.efficientnet_v2_l(weights='EfficientNet_V2_L_Weights.DEFAULT')
|
51 |
+
efficientnet.classifier = new_layers
|
52 |
|
53 |
# If using GPU
|
54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
55 |
|
56 |
hf_token = os.getenv('HF_token')
|
57 |
+
model_path = hf_hub_download(repo_id="fohy24/morphmarket_model", filename="model_v9_epoch8.pt", token=hf_token)
|
58 |
|
59 |
checkpoint = torch.load(model_path, map_location=device)
|
60 |
+
efficientnet.load_state_dict(checkpoint['model_state_dict'])
|
61 |
|
62 |
+
efficientnet.eval()
|
63 |
|
64 |
input_img = transform(img)
|
65 |
input_img = input_img.unsqueeze(0)
|
66 |
|
67 |
|
68 |
with torch.no_grad():
|
69 |
+
output = efficientnet(input_img)
|
70 |
|
71 |
predicted_probs = torch.sigmoid(output).to('cpu').flatten().tolist()
|
72 |
prediction_dict = {labels[i]: predicted_probs[i] for i in range(len(labels)) if predicted_probs[i] > confidence}
|