JohnJoelMota commited on
Commit
8d47979
·
verified ·
1 Parent(s): b7582d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -264
app.py CHANGED
@@ -11,19 +11,11 @@ import os
11
  import sys
12
  import io
13
 
14
- # Set up logging
15
- import logging
16
- logging.basicConfig(level=logging.INFO,
17
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
18
- stream=sys.stderr)
19
- logger = logging.getLogger(__name__)
20
 
21
- # Load models once at startup
22
- logger.info("Loading Faster R-CNN model...")
23
- rcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
24
- rcnn_model.eval()
25
-
26
- logger.info("Loading DETR model...")
27
  detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
28
  detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
29
 
@@ -43,295 +35,168 @@ COCO_INSTANCE_CATEGORY_NAMES = [
43
  'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
44
  ]
45
 
46
- def faster_rcnn_detection(image, threshold=0.5):
47
- """Detect objects using Faster R-CNN model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  if image is None:
49
- logger.error("No image provided to Faster R-CNN detector")
50
- return create_error_image("No image provided")
51
-
 
 
 
 
 
 
 
 
52
  try:
53
- logger.info(f"Processing image with Faster R-CNN (threshold: {threshold})")
54
-
55
- # Convert threshold to float
56
- threshold = float(threshold)
57
-
58
- # Apply transforms required by the model
59
  transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
60
  image_tensor = transform(image).unsqueeze(0)
61
-
62
- # Run detection
63
  with torch.no_grad():
64
- prediction = rcnn_model(image_tensor)[0]
65
-
66
- # Extract results
67
  boxes = prediction['boxes'].cpu().numpy()
68
  labels = prediction['labels'].cpu().numpy()
69
  scores = prediction['scores'].cpu().numpy()
70
-
71
- # Create visualization
72
  image_np = np.array(image)
73
  plt.figure(figsize=(10, 10))
74
  plt.imshow(image_np)
75
  ax = plt.gca()
76
-
77
- # Draw bounding boxes
78
  for box, label, score in zip(boxes, labels, scores):
79
  if score >= threshold:
80
  x1, y1, x2, y2 = box
81
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
82
- fill=False, color='red', linewidth=2))
83
  class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
84
- ax.text(x1, y1, f'{class_name}: {score:.2f}',
85
- bbox=dict(facecolor='yellow', alpha=0.5),
86
- fontsize=12, color='black')
87
-
88
- plt.title("Faster R-CNN Detection")
89
  plt.axis('off')
90
  plt.tight_layout()
91
-
92
- # Save result to image
93
- output = io.BytesIO()
94
- plt.savefig(output, format='png')
95
  plt.close()
96
- output.seek(0)
97
-
98
- return Image.open(output)
99
-
100
  except Exception as e:
101
- logger.error(f"Error in Faster R-CNN detection: {e}")
102
- import traceback
103
- traceback.print_exc(file=sys.stderr)
104
- return create_error_image(f"Faster R-CNN error: {str(e)}")
 
 
 
 
 
 
105
 
106
- def detr_detection(image, threshold=0.5):
107
- """Detect objects using DETR model"""
108
  if image is None:
109
- logger.error("No image provided to DETR detector")
110
- return create_error_image("No image provided")
111
-
 
 
 
 
 
 
 
 
 
112
  try:
113
- logger.info(f"Processing image with DETR (threshold: {threshold})")
114
-
115
- # Convert threshold to float
116
- threshold = float(threshold)
117
-
118
- # Process image and run model
119
  inputs = detr_processor(images=image, return_tensors="pt")
120
  outputs = detr_model(**inputs)
121
-
122
- # Post-process results
123
  target_sizes = torch.tensor([image.size[::-1]])
124
- results = detr_processor.post_process_object_detection(
125
- outputs, target_sizes=target_sizes, threshold=threshold)[0]
126
-
127
- # Create visualization
128
  fig, ax = plt.subplots(1, figsize=(10, 10))
129
  ax.imshow(image)
130
-
131
- # Draw bounding boxes
132
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
133
  xmin, ymin, xmax, ymax = box.tolist()
134
- ax.add_patch(patches.Rectangle(
135
- (xmin, ymin), xmax - xmin, ymax - ymin,
136
- linewidth=2, edgecolor='blue', facecolor='none'))
137
- ax.text(xmin, ymin, f"{detr_model.config.id2label[label.item()]}: {score:.2f}",
138
- bbox=dict(facecolor='cyan', alpha=0.5), fontsize=12)
139
-
140
- plt.title("DETR Detection")
141
  plt.axis('off')
142
- plt.tight_layout()
143
-
144
- # Save result to image
145
- output = io.BytesIO()
146
- plt.savefig(output, format='png')
147
  plt.close(fig)
148
- output.seek(0)
149
-
150
- return Image.open(output)
151
-
152
  except Exception as e:
153
- logger.error(f"Error in DETR detection: {e}")
154
- import traceback
155
- traceback.print_exc(file=sys.stderr)
156
- return create_error_image(f"DETR error: {str(e)}")
157
-
158
- def compare_detections(image, threshold=0.5):
159
- """Run both models and return side-by-side comparison"""
160
- if image is None:
161
- logger.error("No image provided for comparison")
162
- return create_error_image("No image provided")
163
-
164
- try:
165
- logger.info(f"Comparing both models with threshold: {threshold}")
166
-
167
- # Run both models
168
- rcnn_result = faster_rcnn_detection(image, threshold)
169
- detr_result = detr_detection(image, threshold)
170
-
171
- # Create side-by-side comparison
172
- fig, axes = plt.subplots(1, 2, figsize=(20, 10))
173
-
174
- axes[0].imshow(np.array(rcnn_result))
175
- axes[0].set_title("Faster R-CNN Detection", fontsize=16)
176
- axes[0].axis('off')
177
-
178
- axes[1].imshow(np.array(detr_result))
179
- axes[1].set_title("DETR Detection", fontsize=16)
180
- axes[1].axis('off')
181
-
182
- plt.tight_layout()
183
-
184
- # Save comparison to image
185
- output = io.BytesIO()
186
- plt.savefig(output, format='png', dpi=120)
187
  plt.close(fig)
188
- output.seek(0)
189
-
190
- return Image.open(output)
191
-
192
- except Exception as e:
193
- logger.error(f"Error in comparison: {e}")
194
- import traceback
195
- traceback.print_exc(file=sys.stderr)
196
- return create_error_image(f"Comparison error: {str(e)}")
197
 
198
- def create_error_image(error_text):
199
- """Create an image with error message"""
200
- error_img = Image.new('RGB', (800, 400), color='white')
201
- fig, ax = plt.subplots(figsize=(8, 4))
202
- ax.imshow(error_img)
203
- ax.text(0.5, 0.5, f"Error: {error_text}",
204
- horizontalalignment='center', verticalalignment='center',
205
- transform=ax.transAxes, fontsize=14, wrap=True)
206
- ax.axis('off')
207
-
208
- # Save to buffer
209
- buf = io.BytesIO()
210
- plt.savefig(buf, format='png')
211
- plt.close(fig)
212
- buf.seek(0)
213
-
214
- return Image.open(buf)
215
 
216
- def detect_objects(image, model_choice, threshold=0.5):
217
- """Main detection function that routes to the appropriate model"""
218
- if image is None:
219
- return create_error_image("No image provided")
220
-
221
- if model_choice == "Faster R-CNN":
222
- return faster_rcnn_detection(image, threshold)
223
- elif model_choice == "DETR":
224
- return detr_detection(image, threshold)
225
- elif model_choice == "Compare Both":
226
- return compare_detections(image, threshold)
227
- else:
228
- return create_error_image(f"Unknown model choice: {model_choice}")
229
 
230
- def model_info(model_choice):
231
- """Provide information about the selected model"""
232
- if model_choice == "Faster R-CNN":
233
- return """
234
- **Faster R-CNN** is a two-stage object detection model that first proposes regions of interest and then classifies them.
235
-
236
- **Strengths:**
237
- - Generally high accuracy
238
- - Good for detecting objects of various sizes
239
- - Well-established architecture with many pretrained variants
240
-
241
- **Suited for:**
242
- - General object detection tasks
243
- - Scenes with multiple objects of different scales
244
- - When detection accuracy is more important than speed
245
- """
246
- elif model_choice == "DETR":
247
- return """
248
- **DETR (DEtection TRansformer)** is an end-to-end object detection model using transformers.
249
-
250
- **Strengths:**
251
- - Clean, end-to-end architecture without manual anchors or NMS
252
- - Strong spatial reasoning via self-attention
253
- - Good at dealing with occlusion
254
-
255
- **Suited for:**
256
- - Scenes with overlapping objects
257
- - When you need global context understanding
258
- - Modern transformer-based approach to detection
259
- """
260
- elif model_choice == "Compare Both":
261
- return """
262
- **Comparison Mode** runs both Faster R-CNN and DETR side by side to compare their detection results.
263
-
264
- This is useful for:
265
- - Understanding the strengths of each model
266
- - Seeing how detection approaches differ
267
- - Choosing the right model for your specific use case
268
- """
269
- return ""
270
 
271
- # Create Gradio interface
272
- with gr.Blocks(title="Object Detection Model Comparison") as demo:
273
- gr.Markdown("""
274
- # Object Detection Model Comparison
275
-
276
- Upload an image and choose between two state-of-the-art object detection models:
277
- - **Faster R-CNN**: A classic two-stage detector
278
- - **DETR**: A modern transformer-based detector
279
-
280
- Adjust the confidence threshold to control detection sensitivity.
281
- """)
282
-
283
- with gr.Row():
284
- with gr.Column(scale=1):
285
- # Input controls
286
- input_image = gr.Image(type="pil", label="Input Image")
287
- model_dropdown = gr.Dropdown(
288
- choices=["Faster R-CNN", "DETR", "Compare Both"],
289
- value="Compare Both",
290
- label="Detection Model"
291
- )
292
- threshold_slider = gr.Slider(
293
- minimum=0.0, maximum=1.0, value=0.5, step=0.05,
294
- label="Confidence Threshold"
295
- )
296
- detect_button = gr.Button("Detect Objects", variant="primary")
297
-
298
- # Model info box
299
- model_info_box = gr.Markdown()
300
-
301
- with gr.Column(scale=2):
302
- # Output image
303
- output_image = gr.Image(label="Detection Results")
304
-
305
- # Connect components
306
- detect_button.click(
307
- detect_objects,
308
- inputs=[input_image, model_dropdown, threshold_slider],
309
- outputs=output_image
310
- )
311
-
312
- model_dropdown.change(
313
- model_info,
314
- inputs=model_dropdown,
315
- outputs=model_info_box
316
- )
317
-
318
- # Add examples
319
- examples_dir = "/home/user/app"
320
- examples = [
321
- [os.path.join(examples_dir, "TEST_IMG_1.jpg"), "Compare Both", 0.5],
322
- [os.path.join(examples_dir, "TEST_IMG_2.JPG"), "Compare Both", 0.5],
323
- [os.path.join(examples_dir, "TEST_IMG_3.jpg"), "Compare Both", 0.5],
324
- [os.path.join(examples_dir, "TEST_IMG_4.jpg"), "Compare Both", 0.5]
325
- ]
326
-
327
- gr.Examples(
328
- examples=examples,
329
- inputs=[input_image, model_dropdown, threshold_slider],
330
- outputs=output_image,
331
- fn=detect_objects,
332
- cache_examples=False
333
- )
334
 
335
- # Launch the app
336
  if __name__ == "__main__":
337
- demo.launch(debug=True)
 
11
  import sys
12
  import io
13
 
14
+ # Load Faster R-CNN model
15
+ frcnn_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
16
+ frcnn_model.eval()
 
 
 
17
 
18
+ # Load DETR model and processor
 
 
 
 
 
19
  detr_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
20
  detr_model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
21
 
 
35
  'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
36
  ]
37
 
38
+ def recommend_model(image):
39
+ """Provide a basic model recommendation based on image characteristics."""
40
+ if image is None:
41
+ return "Please upload an image to get a recommendation."
42
+ try:
43
+ img_array = np.array(image)
44
+ height, width = img_array.shape[:2]
45
+ pixel_variance = np.var(img_array)
46
+ # Basic heuristic: DETR is better for high-resolution, complex images; Faster R-CNN for smaller, simpler ones
47
+ if height * width > 1000 * 1000 or pixel_variance > 1000:
48
+ return "DETR is recommended for high-resolution or complex images."
49
+ else:
50
+ return "Faster R-CNN is recommended for smaller or simpler images."
51
+ except Exception as e:
52
+ return f"Error in recommendation: {str(e)}"
53
+
54
+ def detect_objects_frcnn(image, threshold=0.5):
55
+ """Run Faster R-CNN detection."""
56
  if image is None:
57
+ blank_img = Image.new('RGB', (400, 400), color='white')
58
+ plt.figure(figsize=(10, 10))
59
+ plt.imshow(blank_img)
60
+ plt.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
61
+ transform=plt.gca().transAxes, fontsize=20)
62
+ plt.axis('off')
63
+ output_path = "frcnn_blank_output.png"
64
+ plt.savefig(output_path)
65
+ plt.close()
66
+ return output_path
67
+
68
  try:
69
+ threshold = float(threshold) if threshold is not None else 0.5
 
 
 
 
 
70
  transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
71
  image_tensor = transform(image).unsqueeze(0)
72
+
 
73
  with torch.no_grad():
74
+ prediction = frcnn_model(image_tensor)[0]
75
+
 
76
  boxes = prediction['boxes'].cpu().numpy()
77
  labels = prediction['labels'].cpu().numpy()
78
  scores = prediction['scores'].cpu().numpy()
79
+
 
80
  image_np = np.array(image)
81
  plt.figure(figsize=(10, 10))
82
  plt.imshow(image_np)
83
  ax = plt.gca()
84
+
 
85
  for box, label, score in zip(boxes, labels, scores):
86
  if score >= threshold:
87
  x1, y1, x2, y2 = box
88
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, color='red', linewidth=2))
 
89
  class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
90
+ ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=12, color='black')
91
+
 
 
 
92
  plt.axis('off')
93
  plt.tight_layout()
94
+ output_path = "frcnn_output.png"
95
+ plt.savefig(output_path)
 
 
96
  plt.close()
97
+ return output_path
 
 
 
98
  except Exception as e:
99
+ error_img = Image.new('RGB', (400, 400), color='white')
100
+ plt.figure(figsize=(10, 10))
101
+ plt.imshow(error_img)
102
+ plt.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
103
+ transform=plt.gca().transAxes, fontsize=12, wrap=True)
104
+ plt.axis('off')
105
+ error_path = "frcnn_error_output.png"
106
+ plt.savefig(error_path)
107
+ plt.close()
108
+ return error_path
109
 
110
+ def detect_objects_detr(image, threshold=0.9):
111
+ """Run DETR detection."""
112
  if image is None:
113
+ blank_img = Image.new('RGB', (400, 400), color='white')
114
+ fig, ax = plt.subplots(1, figsize=(10, 10))
115
+ ax.imshow(blank_img)
116
+ ax.text(0.5, 0.5, "No image provided", horizontalalignment='center', verticalalignment='center',
117
+ transform=ax.transAxes, fontsize=20)
118
+ plt.axis('off')
119
+ buf = io.BytesIO()
120
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
121
+ plt.close(fig)
122
+ buf.seek(0)
123
+ return Image.open(buf)
124
+
125
  try:
 
 
 
 
 
 
126
  inputs = detr_processor(images=image, return_tensors="pt")
127
  outputs = detr_model(**inputs)
 
 
128
  target_sizes = torch.tensor([image.size[::-1]])
129
+ results = detr_processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
130
+
 
 
131
  fig, ax = plt.subplots(1, figsize=(10, 10))
132
  ax.imshow(image)
133
+
 
134
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
135
  xmin, ymin, xmax, ymax = box.tolist()
136
+ ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='red', facecolor='none'))
137
+ ax.text(xmin, ymin, f"{detr_model.config.id2label[label.item()]}: {round(score.item(), 2)}",
138
+ bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
139
+
 
 
 
140
  plt.axis('off')
141
+ buf = io.BytesIO()
142
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
 
 
 
143
  plt.close(fig)
144
+ buf.seek(0)
145
+ return Image.open(buf)
 
 
146
  except Exception as e:
147
+ error_img = Image.new('RGB', (400, 400), color='white')
148
+ fig, ax = plt.subplots(1, figsize=(10, 10))
149
+ ax.imshow(error_img)
150
+ ax.text(0.5, 0.5, f"Error: {str(e)}", horizontalalignment='center', verticalalignment='center',
151
+ transform=ax.transAxes, fontsize=12, wrap=True)
152
+ plt.axis('off')
153
+ buf = io.BytesIO()
154
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  plt.close(fig)
156
+ buf.seek(0)
157
+ return Image.open(buf)
 
 
 
 
 
 
 
158
 
159
+ def run_detection(image, model_choice, frcnn_threshold=0.5, detr_threshold=0.9):
160
+ """Run detection based on model choice and return results with recommendation."""
161
+ recommendation = recommend_model(image)
162
+ frcnn_result = None
163
+ detr_result = None
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ if model_choice in ["Faster R-CNN", "Both"]:
166
+ frcnn_result = detect_objects_frcnn(image, frcnn_threshold)
167
+ if model_choice in ["DETR", "Both"]:
168
+ detr_result = detect_objects_detr(image, detr_threshold)
 
 
 
 
 
 
 
 
 
169
 
170
+ return recommendation, frcnn_result, detr_result
171
+
172
+ # Example image paths
173
+ examples = [
174
+ os.path.join("/home/user/app", "TEST_IMG_1.jpg"),
175
+ os.path.join("/home/user/app", "TEST_IMG_2.JPG"),
176
+ os.path.join("/home/user/app", "TEST_IMG_3.jpg"),
177
+ os.path.join("/home/user/app", "TEST_IMG_4.jpg")
178
+ ]
179
+ example_list = [[path] for path in examples if os.path.exists(path)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
+ # Gradio interface
182
+ interface = gr.Interface(
183
+ fn=run_detection,
184
+ inputs=[
185
+ gr.Image(type="pil", label="Input Image"),
186
+ gr.Dropdown(choices=["Faster R-CNN", "DETR", "Both"], label="Model Choice", value="Both"),
187
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Faster R-CNN Confidence Threshold"),
188
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="DETR Confidence Threshold")
189
+ ],
190
+ outputs=[
191
+ gr.Textbox(label="Model Recommendation"),
192
+ gr.Image(type="filepath", label="Faster R-CNN Result"),
193
+ gr.Image(type="pil", label="DETR Result")
194
+ ],
195
+ title="Object Detection: Faster R-CNN vs DETR",
196
+ description="Upload an image, select a model (or both), and view object detection results. A recommendation is provided based on image characteristics.",
197
+ examples=example_list,
198
+ cache_examples=False
199
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
 
201
  if __name__ == "__main__":
202
+ interface.launch(debug=True)