fohy24
commited on
Commit
·
bda76e2
1
Parent(s):
6fd8eba
prediction function returns dictionary
Browse files
app.py
CHANGED
@@ -74,15 +74,10 @@ def predict(img):
|
|
74 |
with torch.no_grad():
|
75 |
output = densenet(input_img)
|
76 |
|
77 |
-
predicted_probs = torch.sigmoid(output).to('cpu')
|
78 |
-
|
79 |
-
columns=labels).T.sort_values(by=['predictions'], ascending=False)
|
80 |
|
81 |
-
|
82 |
-
prediction_probs = prediction.query('predictions > 0.5').reset_index(names='morphs')['morphs'].to_list()
|
83 |
-
prediction_confidence = prediction.query('predictions > 0.5')['predictions'].to_list()
|
84 |
-
|
85 |
-
return prediction_probs, prediction_confidence
|
86 |
|
87 |
|
88 |
|
|
|
74 |
with torch.no_grad():
|
75 |
output = densenet(input_img)
|
76 |
|
77 |
+
predicted_probs = torch.sigmoid(output).to('cpu').flatten().tolist()
|
78 |
+
prediction_dict = {labels[i]: predicted_probs[i] for i in range(len(labels)) if predicted_probs[i] > 0.5}
|
|
|
79 |
|
80 |
+
return prediction_dict
|
|
|
|
|
|
|
|
|
81 |
|
82 |
|
83 |
|