lombardata
commited on
Commit
•
1f62428
1
Parent(s):
fb33916
Update app.py
Browse files
app.py
CHANGED
@@ -46,7 +46,13 @@ model.to(device)
|
|
46 |
def sigmoid(_outputs):
|
47 |
return 1.0 / (1.0 + np.exp(-_outputs))
|
48 |
|
49 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# Preprocess the image
|
51 |
processor = AutoImageProcessor.from_pretrained(checkpoint_name)
|
52 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
@@ -57,13 +63,26 @@ def predict(image, threshold=0.5):
|
|
57 |
logits = model_outputs.logits[0]
|
58 |
probabilities = torch.sigmoid(logits).cpu().numpy() # Convert to probabilities
|
59 |
|
60 |
-
# Create a dictionary of label scores
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
return filtered_results
|
67 |
|
68 |
# Define style
|
69 |
title = "Victor - DinoVd'eau image classification"
|
@@ -71,11 +90,14 @@ model_link = "https://huggingface.co/" + checkpoint_name
|
|
71 |
description = f"This application showcases the capability of artificial intelligence-based systems to identify objects within underwater images. To utilize it, you can either upload your own image or select one of the provided examples for analysis.\nFor predictions, we use this [open-source model]({model_link})"
|
72 |
|
73 |
iface = gr.Interface(
|
74 |
-
fn=
|
75 |
inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")],
|
76 |
-
outputs=
|
|
|
|
|
|
|
77 |
title=title,
|
78 |
-
description
|
79 |
examples=[["session_GOPR0106.JPG"],
|
80 |
["session_2021_08_30_Mayotte_10_image_00066.jpg"],
|
81 |
["session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG"],
|
|
|
46 |
def sigmoid(_outputs):
|
47 |
return 1.0 / (1.0 + np.exp(-_outputs))
|
48 |
|
49 |
+
def download_thresholds(repo_id, filename):
|
50 |
+
threshold_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
51 |
+
with open(threshold_path, 'r') as threshold_file:
|
52 |
+
thresholds = json.load(threshold_file)
|
53 |
+
return thresholds
|
54 |
+
|
55 |
+
def predict(image, slider_threshold=0.5, fixed_thresholds=None):
|
56 |
# Preprocess the image
|
57 |
processor = AutoImageProcessor.from_pretrained(checkpoint_name)
|
58 |
inputs = processor(images=image, return_tensors="pt").to(device)
|
|
|
63 |
logits = model_outputs.logits[0]
|
64 |
probabilities = torch.sigmoid(logits).cpu().numpy() # Convert to probabilities
|
65 |
|
66 |
+
# Create a dictionary of label scores based on the slider threshold
|
67 |
+
slider_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > slider_threshold}
|
68 |
+
|
69 |
+
# If fixed thresholds are provided, create a dictionary of label scores based on the fixed thresholds
|
70 |
+
fixed_threshold_results = None
|
71 |
+
if fixed_thresholds is not None:
|
72 |
+
fixed_threshold_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > fixed_thresholds[id2label[str(i)]]}
|
73 |
+
|
74 |
+
return slider_results, fixed_threshold_results
|
75 |
|
76 |
+
def predict_wrapper(image, slider_threshold):
|
77 |
+
# Download thresholds from the model repository
|
78 |
+
thresholds = download_thresholds(checkpoint_name, "threshold.json")
|
79 |
+
|
80 |
+
# Get predictions from the predict function using both the slider and fixed thresholds
|
81 |
+
slider_results, fixed_threshold_results = predict(image, slider_threshold, thresholds)
|
82 |
+
|
83 |
+
# Return both sets of predictions for Gradio outputs
|
84 |
+
return slider_results, fixed_threshold_results
|
85 |
|
|
|
86 |
|
87 |
# Define style
|
88 |
title = "Victor - DinoVd'eau image classification"
|
|
|
90 |
description = f"This application showcases the capability of artificial intelligence-based systems to identify objects within underwater images. To utilize it, you can either upload your own image or select one of the provided examples for analysis.\nFor predictions, we use this [open-source model]({model_link})"
|
91 |
|
92 |
iface = gr.Interface(
|
93 |
+
fn=predict_wrapper,
|
94 |
inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")],
|
95 |
+
outputs=[
|
96 |
+
gr.components.Label(label="Slider Threshold Predictions"),
|
97 |
+
gr.components.Label(label="Fixed Thresholds Predictions")
|
98 |
+
],
|
99 |
title=title,
|
100 |
+
description=description,
|
101 |
examples=[["session_GOPR0106.JPG"],
|
102 |
["session_2021_08_30_Mayotte_10_image_00066.jpg"],
|
103 |
["session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG"],
|