JohnJoelMota committed on
Commit 5df6d63 · verified · 1 Parent(s): fba3adc
Files changed (1):
  1. app.py +340 -111
app.py CHANGED
@@ -1,17 +1,38 @@
  import torch
  import torchvision
- from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
  from PIL import Image
  import numpy as np
  import matplotlib.pyplot as plt
  import gradio as gr
  import os
  import sys
-
- # Load the pre-trained model once
- model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
- model.eval()
-
  # COCO class names
  COCO_INSTANCE_CATEGORY_NAMES = [
      '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
@@ -26,118 +47,326 @@ COCO_INSTANCE_CATEGORY_NAMES = [
      'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
      'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
      'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
- ]

- # Gradio-compatible detection function
- def detect_objects(image, threshold=0.5):
-     if image is None:
-         print("Image is None, returning empty output", file=sys.stderr)
-         # Create a blank image as output
-         blank_img = Image.new('RGB', (400, 400), color='white')
-         plt.figure(figsize=(10, 10))
-         plt.imshow(blank_img)
-         plt.text(0.5, 0.5, "No image provided",
-                  horizontalalignment='center', verticalalignment='center',
-                  transform=plt.gca().transAxes, fontsize=20)
-         plt.axis('off')
-         output_path = "blank_output.png"
-         plt.savefig(output_path)
-         plt.close()
-         return output_path

      try:
-         print(f"Processing image of type {type(image)} and threshold {threshold}", file=sys.stderr)
-         # Make sure threshold is a valid number
-         if threshold is None:
-             threshold = 0.5
-             print("Threshold was None, using default 0.5", file=sys.stderr)
-
-         # Convert threshold to float if it's not already
-         threshold = float(threshold)
-
-         transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
-         image_tensor = transform(image).unsqueeze(0)
-
-         with torch.no_grad():
-             prediction = model(image_tensor)[0]
-
-         boxes = prediction['boxes'].cpu().numpy()
-         labels = prediction['labels'].cpu().numpy()
-         scores = prediction['scores'].cpu().numpy()
-
-         image_np = np.array(image)
-         plt.figure(figsize=(10, 10))
-         plt.imshow(image_np)
-         ax = plt.gca()
-
-         for box, label, score in zip(boxes, labels, scores):
-             # Explicit debug prints to trace the comparison issue
-             print(f"Score: {score}, Threshold: {threshold}, Type: {type(score)}/{type(threshold)}", file=sys.stderr)
-
-             if score >= threshold:
-                 x1, y1, x2, y2 = box
-                 ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
-                                            fill=False, color='red', linewidth=2))
-                 class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
-                 ax.text(x1, y1, f'{class_name}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5),
-                         fontsize=12, color='black')
-
-         plt.axis('off')
-         plt.tight_layout()
-
-         # Save the figure to return
-         output_path = "output.png"
-         plt.savefig(output_path)
-         plt.close()
-         return output_path
      except Exception as e:
-         print(f"Error in detect_objects: {e}", file=sys.stderr)
          import traceback
          traceback.print_exc(file=sys.stderr)
-
-         # Create an error image
-         error_img = Image.new('RGB', (400, 400), color='white')
-         plt.figure(figsize=(10, 10))
-         plt.imshow(error_img)
-         plt.text(0.5, 0.5, f"Error: {str(e)}",
-                  horizontalalignment='center', verticalalignment='center',
-                  transform=plt.gca().transAxes, fontsize=12, wrap=True)
-         plt.axis('off')
-         error_path = "error_output.png"
-         plt.savefig(error_path)
-         plt.close()
-         return error_path
-
- # Create direct file paths for examples
- # These exact filenames match what's visible in your repository
- examples = [
-     os.path.join("/home/user/app", "TEST_IMG_1.jpg"),
-     os.path.join("/home/user/app", "TEST_IMG_2.JPG"),
-     os.path.join("/home/user/app", "TEST_IMG_3.jpg"),
-     os.path.join("/home/user/app", "TEST_IMG_4.jpg")
- ]

- # Create Gradio interface
- # Important: For Gradio examples, we need to create a list of lists
- example_list = [[path] for path in examples if os.path.exists(path)]
-
- print(f"Found {len(example_list)} valid examples: {example_list}", file=sys.stderr)
-
- # Create Gradio interface with a simplified approach
- interface = gr.Interface(
-     fn=detect_objects,
-     inputs=[
-         gr.Image(type="pil", label="Input Image"),
-         gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Confidence Threshold")
-     ],
-     outputs=gr.Image(type="filepath", label="Detected Objects"),
-     title="Faster R-CNN Object Detection",
-     description="Upload an image to detect objects using a pretrained Faster R-CNN model.",
-     examples=example_list,
-     cache_examples=False  # Disable caching to avoid potential issues
- )

- # Launch with specific configuration for Hugging Face
  if __name__ == "__main__":
-     # Launch with debug mode enabled
-     interface.launch(debug=True)

@@ -1,17 +1,38 @@
  import torch
  import torchvision
+ from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, MaskRCNN_ResNet50_FPN_Weights
+ from torchvision.transforms import functional as F
  from PIL import Image
  import numpy as np
  import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
  import gradio as gr
  import os
  import sys
+ import random
+ from typing import Tuple, List, Dict, Any, Optional
+
+
+ # Load models only once
+ def load_models():
+     print("Loading detection models...", file=sys.stderr)
+     # Model 1: Faster R-CNN
+     model1 = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+     model1.eval()
+
+     # Model 2: RetinaNet
+     model2 = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT)
+     model2.eval()
+
+     # Segmentation model
+     seg_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
+     seg_model.eval()
+
+     return model1, model2, seg_model
+
+ # Global models
+ MODEL1, MODEL2, SEG_MODEL = load_models()
+
  # COCO class names
  COCO_INSTANCE_CATEGORY_NAMES = [
      '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',

@@ -26,118 +47,326 @@ COCO_INSTANCE_CATEGORY_NAMES = [
      'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
      'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
      'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+ ]

+ def get_prediction(model, image, threshold=0.5):
+     """Run a detection model and keep only results scoring above the threshold."""
+     # Note: both detection models reuse the Faster R-CNN preprocessing transforms here
+     transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+     image_tensor = transform(image).unsqueeze(0)
+
+     with torch.no_grad():
+         prediction = model(image_tensor)[0]
+
+     boxes = prediction['boxes'].cpu().numpy()
+     labels = prediction['labels'].cpu().numpy()
+     scores = prediction['scores'].cpu().numpy()
+
+     # Filter by threshold
+     keep = scores >= threshold
+     boxes = boxes[keep]
+     labels = labels[keep]
+     scores = scores[keep]
+
+     return boxes, labels, scores
+
+ def get_segmentation_prediction(model, image, threshold=0.5):
+     """Run Mask R-CNN and keep boxes, labels, scores, and masks above the threshold."""
+     transform = MaskRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+     image_tensor = transform(image).unsqueeze(0)
+
+     with torch.no_grad():
+         prediction = model(image_tensor)[0]
+
+     boxes = prediction['boxes'].cpu().numpy()
+     labels = prediction['labels'].cpu().numpy()
+     scores = prediction['scores'].cpu().numpy()
+     masks = prediction['masks'].cpu().numpy()
+
+     # Filter by threshold
+     keep = scores >= threshold
+     boxes = boxes[keep]
+     labels = labels[keep]
+     scores = scores[keep]
+     masks = masks[keep]
+
+     return boxes, labels, scores, masks
+
+ def visualize_detection(image, boxes, labels, scores, title="Detection Results"):
+     """Draw labeled bounding boxes on the image and save the figure to disk."""
+     image_np = np.array(image)
+     plt.figure(figsize=(10, 10))
+     plt.imshow(image_np)
+     ax = plt.gca()
+
+     for box, label, score in zip(boxes, labels, scores):
+         x1, y1, x2, y2 = box
+         ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
+                                    fill=False, color='red', linewidth=2))
+         class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
+         ax.text(x1, y1, f'{class_name}: {score:.2f}',
+                 bbox=dict(facecolor='yellow', alpha=0.5),
+                 fontsize=12, color='black')
+
+     plt.title(title)
+     plt.axis('off')
+     plt.tight_layout()
+
+     # Save the figure
+     output_path = f"{title.replace(' ', '_').lower()}.png"
+     plt.savefig(output_path)
+     plt.close()
+     return output_path
+
+ def visualize_segmentation(image, boxes, labels, scores, masks, title="Segmentation Results"):
+     """Overlay instance masks and labeled boxes on the image and save the figure."""
+     image_np = np.array(image)
+     plt.figure(figsize=(10, 10))
+     plt.imshow(image_np)
+     ax = plt.gca()
+
+     # Random colors for masks
+     colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
+
+     for box, label, score, mask, color in zip(boxes, labels, scores, masks, colors):
+         # Draw bounding box
+         x1, y1, x2, y2 = box
+         rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
+                                  linewidth=2, edgecolor='r', facecolor='none')
+         ax.add_patch(rect)
+
+         # Add text
+         class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
+         ax.text(x1, y1 - 10, f'{class_name}: {score:.2f}',
+                 bbox=dict(facecolor='yellow', alpha=0.5),
+                 fontsize=12, color='black')
+
+         # Draw mask
+         mask_image = mask[0, :, :]  # First channel
+         mask_overlay = np.zeros_like(image_np, dtype=np.uint8)
+         for c in range(3):
+             mask_overlay[:, :, c] = np.where(mask_image > 0.5,
+                                              int(color[c] * 255), 0)
+
+         # Blend the colored mask into the image with transparency
+         alpha = 0.5
+         mask_bool = mask_image > 0.5
+         for c in range(3):
+             image_np[:, :, c] = np.where(
+                 mask_bool,
+                 image_np[:, :, c] * (1 - alpha) + mask_overlay[:, :, c] * alpha,
+                 image_np[:, :, c]
+             )
+
+     # Re-draw the image now that the masks have been blended in
+     plt.imshow(image_np)
+     plt.title(title)
+     plt.axis('off')
+     plt.tight_layout()
+
+     # Save the figure
+     output_path = f"{title.replace(' ', '_').lower()}.png"
+     plt.savefig(output_path)
+     plt.close()
+     return output_path
+
+ def calculate_metrics(boxes, labels, scores):
+     """Calculate simple metrics for model comparison"""
+     # In a real app, you'd use proper metrics like mAP
+     # For simplicity, we'll use:
+     # 1. Number of detections
+     # 2. Average confidence score
+     # 3. Number of unique classes detected
+
+     num_detections = len(boxes)
+     avg_confidence = np.mean(scores) if len(scores) > 0 else 0
+     unique_classes = len(set(labels))
+
+     return {
+         "num_detections": num_detections,
+         "avg_confidence": avg_confidence,
+         "unique_classes": unique_classes,
+         "total_score": num_detections * avg_confidence + unique_classes  # Simple combined metric
+     }
+
+ def process_game(image, task_type, user_prediction, confidence_threshold=0.5):
+     """Main game function that processes the image based on the selected task type"""
+     if image is None:
+         return {
+             "status": "error",
+             "message": "Please upload an image to continue."
+         }, None, None, None, None
+
+     try:
+         if task_type == "Object Detection":
+             # Model 1: Faster R-CNN
+             boxes1, labels1, scores1 = get_prediction(MODEL1, image, confidence_threshold)
+             result1 = visualize_detection(image, boxes1, labels1, scores1, "Faster R-CNN Results")
+             metrics1 = calculate_metrics(boxes1, labels1, scores1)
+
+             # Model 2: RetinaNet
+             boxes2, labels2, scores2 = get_prediction(MODEL2, image, confidence_threshold)
+             result2 = visualize_detection(image, boxes2, labels2, scores2, "RetinaNet Results")
+             metrics2 = calculate_metrics(boxes2, labels2, scores2)
+
+             # Determine winner
+             score1 = metrics1["total_score"]
+             score2 = metrics2["total_score"]
+
+             if score1 > score2:
+                 winner = "Model 1 (Faster R-CNN)"
+                 winning_score = score1
+                 losing_score = score2
+             elif score2 > score1:
+                 winner = "Model 2 (RetinaNet)"
+                 winning_score = score2
+                 losing_score = score1
+             else:
+                 winner = "Tie"
+                 winning_score = score1
+                 losing_score = score2
+
+             user_correct = (user_prediction == "Model 1" and winner == "Model 1 (Faster R-CNN)") or \
+                            (user_prediction == "Model 2" and winner == "Model 2 (RetinaNet)") or \
+                            (user_prediction == "Tie" and winner == "Tie")
+
+             result_message = f"Winner: {winner} (Score: {winning_score:.2f} vs {losing_score:.2f})\n"
+             result_message += f"Your prediction: {user_prediction} - {'Correct!' if user_correct else 'Incorrect!'}\n\n"
+             result_message += f"Model 1 detected {metrics1['num_detections']} objects with {metrics1['unique_classes']} unique classes.\n"
+             result_message += f"Model 2 detected {metrics2['num_detections']} objects with {metrics2['unique_classes']} unique classes."
+
+             return {"status": "success", "message": result_message}, result1, result2, None, None
+
+         elif task_type == "Instance Segmentation":
+             # Only using one model for segmentation for now
+             boxes, labels, scores, masks = get_segmentation_prediction(SEG_MODEL, image, confidence_threshold)
+             seg_result = visualize_segmentation(image, boxes, labels, scores, masks, "Mask R-CNN Results")
+
+             # Also get detection results for comparison
+             boxes1, labels1, scores1 = get_prediction(MODEL1, image, confidence_threshold)
+             det_result = visualize_detection(image, boxes1, labels1, scores1, "Detection Results")
+
+             metrics_seg = calculate_metrics(boxes, labels, scores)
+
+             result_message = f"Segmentation detected {metrics_seg['num_detections']} objects with {metrics_seg['unique_classes']} unique classes.\n"
+             result_message += "The segmentation model provides pixel-level masks for each detected object."
+
+             return {"status": "success", "message": result_message}, None, None, det_result, seg_result
+
+         else:
+             return {"status": "error", "message": "Invalid task type selected."}, None, None, None, None
      except Exception as e:
+         print(f"Error in process_game: {e}", file=sys.stderr)
          import traceback
          traceback.print_exc(file=sys.stderr)
+         return {"status": "error", "message": f"Error processing image: {str(e)}"}, None, None, None, None

+ def create_ui():
+     """Create the Gradio UI for the game"""
+     with gr.Blocks(title="Object Detection Game") as app:
+         gr.Markdown("# 🎮 Computer Vision Model Comparison Game")
+         gr.Markdown("Upload an image, choose a task, and predict which model will perform better!")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 # Input components
+                 input_image = gr.Image(type="pil", label="Upload Image")
+                 task_type = gr.Radio(
+                     ["Object Detection", "Instance Segmentation"],
+                     label="Select Task",
+                     value="Object Detection"
+                 )
+
+                 with gr.Row():
+                     with gr.Column(scale=1, visible=True) as detection_options:
+                         user_prediction = gr.Radio(
+                             ["Model 1", "Model 2", "Tie"],
+                             label="Which model will perform better?",
+                             value="Model 1"
+                         )
+
+                 confidence = gr.Slider(
+                     minimum=0.0, maximum=1.0, value=0.5, step=0.05,
+                     label="Confidence Threshold"
+                 )
+
+                 submit_btn = gr.Button("Run Comparison", variant="primary")
+
+             with gr.Column(scale=1):
+                 # Output components
+                 result_msg = gr.JSON(label="Results")
+
+                 # Detection results
+                 with gr.Row(visible=True) as detection_results:
+                     model1_output = gr.Image(type="filepath", label="Model 1 (Faster R-CNN)")
+                     model2_output = gr.Image(type="filepath", label="Model 2 (RetinaNet)")
+
+                 # Segmentation results
+                 with gr.Row(visible=False) as segmentation_results:
+                     detection_output = gr.Image(type="filepath", label="Detection")
+                     segmentation_output = gr.Image(type="filepath", label="Segmentation")
+
+         # Example images
+         examples = [
+             os.path.join("/home/user/app", "TEST_IMG_1.jpg"),
+             os.path.join("/home/user/app", "TEST_IMG_2.JPG"),
+             os.path.join("/home/user/app", "TEST_IMG_3.jpg"),
+             os.path.join("/home/user/app", "TEST_IMG_4.jpg")
+         ]
+
+         # Filter to valid examples
+         example_list = [ex for ex in examples if os.path.exists(ex)]
+
+         if example_list:
+             gr.Examples(
+                 examples=example_list,
+                 inputs=input_image,
+                 label="Example Images"
+             )
+
+         # Event handlers: toggle the prediction options and result rows to match the task
+         def update_task_visibility(task):
+             if task == "Object Detection":
+                 return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
+             else:
+                 return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+
+         task_type.change(
+             fn=update_task_visibility,
+             inputs=task_type,
+             outputs=[detection_options, detection_results, segmentation_results]
+         )
+
+         # Submit button click event
+         submit_btn.click(
+             fn=process_game,
+             inputs=[input_image, task_type, user_prediction, confidence],
+             outputs=[result_msg, model1_output, model2_output, detection_output, segmentation_output]
+         )
+
+         # Add markdown information about the models
+         gr.Markdown("""
+         ## About the Models
+
+         ### Detection Models:
+         - **Model 1** is Faster R-CNN with a ResNet50 backbone, a two-stage detector that is accurate but relatively slow.
+         - **Model 2** is RetinaNet with a ResNet50 backbone, a one-stage detector designed for a better speed-accuracy trade-off.
+
+         ### Instance Segmentation:
+         - The segmentation model is Mask R-CNN with a ResNet50 backbone, which provides pixel-level masks in addition to bounding boxes.
+
+         ### How is the winner determined?
+         The winner is determined by a combined score of:
+         1. Number of objects detected
+         2. Average confidence score
+         3. Number of unique classes detected
+
+         Can you predict which model will perform better on your image?
+         """)
+
+     return app

+ # Launch the app
  if __name__ == "__main__":
+     app = create_ui()
+     app.launch(debug=True)
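
For reference, the winner heuristic introduced in calculate_metrics reduces to num_detections * avg_confidence + unique_classes, which equals the sum of the kept confidence scores plus the number of distinct classes. A minimal standalone sketch of how the comparison in process_game plays out, using illustrative values rather than output from a real run:

    import numpy as np

    def total_score(scores, labels):
        # Mirrors calculate_metrics: detections * mean confidence + unique classes
        n = len(scores)
        avg = float(np.mean(scores)) if n > 0 else 0.0
        return n * avg + len(set(labels))

    # Hypothetical detections: model 1 keeps 3 boxes across 2 classes,
    # model 2 keeps 2 higher-confidence boxes of a single class.
    s1 = total_score([0.9, 0.8, 0.6], [1, 1, 3])  # 2.3 + 2 = 4.30
    s2 = total_score([0.95, 0.9], [1, 1])         # 1.85 + 1 = 2.85
    print("Model 1" if s1 > s2 else "Model 2" if s2 > s1 else "Tie")  # Model 1

Note that under this heuristic a model can win simply by detecting more boxes, since the score sums confidences rather than penalizing false positives; the in-code comment already flags mAP as the proper metric for a real evaluation.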