lombardata commited on
Commit
1f62428
1 Parent(s): fb33916

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -9
app.py CHANGED
@@ -46,7 +46,13 @@ model.to(device)
46
  def sigmoid(_outputs):
47
  return 1.0 / (1.0 + np.exp(-_outputs))
48
 
49
- def predict(image, threshold=0.5):
 
 
 
 
 
 
50
  # Preprocess the image
51
  processor = AutoImageProcessor.from_pretrained(checkpoint_name)
52
  inputs = processor(images=image, return_tensors="pt").to(device)
@@ -57,13 +63,26 @@ def predict(image, threshold=0.5):
57
  logits = model_outputs.logits[0]
58
  probabilities = torch.sigmoid(logits).cpu().numpy() # Convert to probabilities
59
 
60
- # Create a dictionary of label scores
61
- results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities)}
 
 
 
 
 
 
 
62
 
63
- # Filter out predictions below a certain threshold (e.g., 0.5)
64
- filtered_results = {label: prob for label, prob in results.items() if prob > threshold}
 
 
 
 
 
 
 
65
 
66
- return filtered_results
67
 
68
  # Define style
69
  title = "Victor - DinoVd'eau image classification"
@@ -71,11 +90,14 @@ model_link = "https://huggingface.co/" + checkpoint_name
71
  description = f"This application showcases the capability of artificial intelligence-based systems to identify objects within underwater images. To utilize it, you can either upload your own image or select one of the provided examples for analysis.\nFor predictions, we use this [open-source model]({model_link})"
72
 
73
  iface = gr.Interface(
74
- fn=predict,
75
  inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")],
76
- outputs=gr.components.Label(),
 
 
 
77
  title=title,
78
- description = description,
79
  examples=[["session_GOPR0106.JPG"],
80
  ["session_2021_08_30_Mayotte_10_image_00066.jpg"],
81
  ["session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG"],
 
46
  def sigmoid(_outputs):
47
  return 1.0 / (1.0 + np.exp(-_outputs))
48
 
49
+ def download_thresholds(repo_id, filename):
50
+ threshold_path = hf_hub_download(repo_id=repo_id, filename=filename)
51
+ with open(threshold_path, 'r') as threshold_file:
52
+ thresholds = json.load(threshold_file)
53
+ return thresholds
54
+
55
+ def predict(image, slider_threshold=0.5, fixed_thresholds=None):
56
  # Preprocess the image
57
  processor = AutoImageProcessor.from_pretrained(checkpoint_name)
58
  inputs = processor(images=image, return_tensors="pt").to(device)
 
63
  logits = model_outputs.logits[0]
64
  probabilities = torch.sigmoid(logits).cpu().numpy() # Convert to probabilities
65
 
66
+ # Create a dictionary of label scores based on the slider threshold
67
+ slider_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > slider_threshold}
68
+
69
+ # If fixed thresholds are provided, create a dictionary of label scores based on the fixed thresholds
70
+ fixed_threshold_results = None
71
+ if fixed_thresholds is not None:
72
+ fixed_threshold_results = {id2label[str(i)]: float(prob) for i, prob in enumerate(probabilities) if prob > fixed_thresholds[id2label[str(i)]]}
73
+
74
+ return slider_results, fixed_threshold_results
75
 
76
+ def predict_wrapper(image, slider_threshold):
77
+ # Download thresholds from the model repository
78
+ thresholds = download_thresholds(checkpoint_name, "threshold.json")
79
+
80
+ # Get predictions from the predict function using both the slider and fixed thresholds
81
+ slider_results, fixed_threshold_results = predict(image, slider_threshold, thresholds)
82
+
83
+ # Return both sets of predictions for Gradio outputs
84
+ return slider_results, fixed_threshold_results
85
 
 
86
 
87
  # Define style
88
  title = "Victor - DinoVd'eau image classification"
 
90
  description = f"This application showcases the capability of artificial intelligence-based systems to identify objects within underwater images. To utilize it, you can either upload your own image or select one of the provided examples for analysis.\nFor predictions, we use this [open-source model]({model_link})"
91
 
92
  iface = gr.Interface(
93
+ fn=predict_wrapper,
94
  inputs=[gr.components.Image(type="pil"), gr.components.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Threshold")],
95
+ outputs=[
96
+ gr.components.Label(label="Slider Threshold Predictions"),
97
+ gr.components.Label(label="Fixed Thresholds Predictions")
98
+ ],
99
  title=title,
100
+ description=description,
101
  examples=[["session_GOPR0106.JPG"],
102
  ["session_2021_08_30_Mayotte_10_image_00066.jpg"],
103
  ["session_2018_11_17_kite_Le_Morne_Manawa_G0065777.JPG"],