m7mdal7aj commited on
Commit
7f2e710
1 Parent(s): 92894b3

added object detection to the space UI

Browse files
Files changed (1) hide show
  1. app.py +54 -1
app.py CHANGED
@@ -6,7 +6,7 @@ import scipy
6
  from PIL import Image
7
  import torch.nn as nn
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
9
- from My_Model.object_detection import ObjectDetector
10
 
11
  def load_caption_model(blip2=False, instructblip=True):
12
 
@@ -65,3 +65,56 @@ if st.button("Get Answer"):
65
  st.write(answer)
66
  else:
67
  st.write("Please upload an image and enter a question.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from PIL import Image
7
  import torch.nn as nn
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
9
+ from my_model.object_detection import ObjectDetector
10
 
11
  def load_caption_model(blip2=False, instructblip=True):
12
 
 
65
  st.write(answer)
66
  else:
67
  st.write("Please upload an image and enter a question.")
68
+
69
+
70
+
71
+
72
+
73
+
74
+ # Object Detection
75
+
76
+ # Object Detection UI in the sidebar
77
+ st.sidebar.title("Object Detection")
78
+ # Dropdown to select the model
79
+ detect_model = st.sidebar.selectbox("Choose a model for object detection:", ["detic", "yolov5"])
80
+ # Slider for threshold with default values based on the model
81
+ threshold = st.sidebar.slider("Select Detection Threshold", 0.1, 0.9, 0.2 if detect_model == "yolov5" else 0.4)
82
+ # Button to trigger object detection
83
+ detect_button = st.sidebar.button("Detect Objects")
84
+
85
+
86
+ def perform_object_detection(image, model_name, threshold):
87
+ """
88
+ Perform object detection on the given image using the specified model and threshold.
89
+
90
+ Args:
91
+ image (PIL.Image): The image on which to perform object detection.
92
+ model_name (str): The name of the object detection model to use.
93
+ threshold (float): The threshold for object detection.
94
+
95
+ Returns:
96
+ PIL.Image, str: The image with drawn bounding boxes and a string of detected objects.
97
+ """
98
+ # Initialize the ObjectDetector
99
+ detector = ObjectDetector()
100
+ # Load the specified model
101
+ detector.load_model(model_name)
102
+ # Perform object detection
103
+ processed_image, detected_objects = detector.detect_objects(image, threshold)
104
+ return processed_image, detected_objects
105
+
106
+ # Check if the 'Detect Objects' button was clicked
107
+ if detect_button:
108
+ if image is not None:
109
+ # Open the uploaded image
110
+ image = Image.open(image)
111
+ # Display the original image
112
+ st.image(image, use_column_width=True, caption="Original Image")
113
+ # Perform object detection
114
+ processed_image, detected_objects = perform_object_detection(image, detect_model, threshold)
115
+ # Display the image with detected objects
116
+ st.image(processed_image, use_column_width=True, caption="Image with Detected Objects")
117
+ # Display the detected objects
118
+ st.write(detected_objects)
119
+ else:
120
+ st.write("Please upload an image for object detection.")