JohnJoelMota committed
Commit f826d0c · verified · 1 Parent(s): 5df6d63

Update app.py

Files changed (1):
  1. app.py +235 -348
app.py CHANGED
@@ -1,372 +1,259 @@
- import torch
- import torchvision
- from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights, MaskRCNN_ResNet50_FPN_Weights
- from torchvision.transforms import functional as F
- from PIL import Image
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib.patches as patches
- import gradio as gr
- import os
  import sys
- import random
- from typing import Tuple, List, Dict, Any, Optional
-
- # Load models only once
- def load_models():
-     print("Loading detection models...", file=sys.stderr)
-     # Model 1: Faster R-CNN
-     model1 = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
-     model1.eval()
-
-     # Model 2: RetinaNet
-     model2 = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT)
-     model2.eval()
-
-     # Segmentation model
-     seg_model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT)
-     seg_model.eval()
-
-     return model1, model2, seg_model
-
- # Global models
- MODEL1, MODEL2, SEG_MODEL = load_models()
-
- # COCO class names
- COCO_INSTANCE_CATEGORY_NAMES = [
-     '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
-     'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
-     'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-     'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
-     'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
-     'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
-     'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
-     'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
-     'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
-     'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
-     'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
-     'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
  ]
-
- def get_prediction(model, image, threshold=0.5):
-     """Get prediction from model"""
-     transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
-     image_tensor = transform(image).unsqueeze(0)
-
-     with torch.no_grad():
-         prediction = model(image_tensor)[0]
-
-     boxes = prediction['boxes'].cpu().numpy()
-     labels = prediction['labels'].cpu().numpy()
-     scores = prediction['scores'].cpu().numpy()
-
-     # Filter by threshold
-     keep = scores >= threshold
-     boxes = boxes[keep]
-     labels = labels[keep]
-     scores = scores[keep]
-
-     return boxes, labels, scores
-
- def get_segmentation_prediction(model, image, threshold=0.5):
-     """Get segmentation prediction"""
-     transform = MaskRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
-     image_tensor = transform(image).unsqueeze(0)
-
-     with torch.no_grad():
-         prediction = model(image_tensor)[0]
-
-     boxes = prediction['boxes'].cpu().numpy()
-     labels = prediction['labels'].cpu().numpy()
-     scores = prediction['scores'].cpu().numpy()
-     masks = prediction['masks'].cpu().numpy()
-
-     # Filter by threshold
-     keep = scores >= threshold
-     boxes = boxes[keep]
-     labels = labels[keep]
-     scores = scores[keep]
-     masks = masks[keep]
-
-     return boxes, labels, scores, masks
-
- def visualize_detection(image, boxes, labels, scores, title="Detection Results"):
-     """Visualize detection results"""
-     image_np = np.array(image)
-     plt.figure(figsize=(10, 10))
-     plt.imshow(image_np)
-     ax = plt.gca()
-
-     for box, label, score in zip(boxes, labels, scores):
-         x1, y1, x2, y2 = box
-         ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
-                                    fill=False, color='red', linewidth=2))
-         class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
-         ax.text(x1, y1, f'{class_name}: {score:.2f}',
-                 bbox=dict(facecolor='yellow', alpha=0.5),
-                 fontsize=12, color='black')
-
-     plt.title(title)
-     plt.axis('off')
-     plt.tight_layout()
-
-     # Save the figure
-     output_path = f"{title.replace(' ', '_').lower()}.png"
-     plt.savefig(output_path)
-     plt.close()
-     return output_path
-
- def visualize_segmentation(image, boxes, labels, scores, masks, title="Segmentation Results"):
-     """Visualize segmentation results"""
-     image_np = np.array(image)
-     plt.figure(figsize=(10, 10))
-     plt.imshow(image_np)
-     ax = plt.gca()
-
-     # Random colors for masks
-     colors = plt.cm.rainbow(np.linspace(0, 1, len(masks)))
-
-     for box, label, score, mask, color in zip(boxes, labels, scores, masks, colors):
-         # Draw bounding box
-         x1, y1, x2, y2 = box
-         rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
-                                  linewidth=2, edgecolor='r', facecolor='none')
-         ax.add_patch(rect)
-
-         # Add text
-         class_name = COCO_INSTANCE_CATEGORY_NAMES[label]
-         ax.text(x1, y1-10, f'{class_name}: {score:.2f}',
-                 bbox=dict(facecolor='yellow', alpha=0.5),
-                 fontsize=12, color='black')
-
-         # Draw mask
-         mask_image = mask[0, :, :]  # First channel
-         mask_overlay = np.zeros_like(image_np, dtype=np.uint8)
-         for c in range(3):
-             mask_overlay[:, :, c] = np.where(mask_image > 0.5,
-                                              int(color[c] * 255), 0)
-
-         # Add mask with transparency
-         alpha = 0.5
-         mask_bool = mask_image > 0.5
-         for c in range(3):
-             image_np[:, :, c] = np.where(
-                 mask_bool,
-                 image_np[:, :, c] * (1-alpha) + mask_overlay[:, :, c] * alpha,
-                 image_np[:, :, c]
-             )
-
-     plt.imshow(image_np)
-     plt.title(title)
-     plt.axis('off')
-     plt.tight_layout()
-
-     # Save the figure
-     output_path = f"{title.replace(' ', '_').lower()}.png"
-     plt.savefig(output_path)
-     plt.close()
-     return output_path
-
- def calculate_metrics(boxes, labels, scores):
-     """Calculate simple metrics for model comparison"""
-     # In a real app, you'd use proper metrics like mAP
-     # For simplicity, we'll use:
-     # 1. Number of detections
-     # 2. Average confidence score
-     # 3. Number of unique classes detected
-
-     num_detections = len(boxes)
-     avg_confidence = np.mean(scores) if len(scores) > 0 else 0
-     unique_classes = len(set(labels))
-
-     return {
-         "num_detections": num_detections,
-         "avg_confidence": avg_confidence,
-         "unique_classes": unique_classes,
-         "total_score": num_detections * avg_confidence + unique_classes  # Simple combined metric
-     }
-
- def process_game(image, task_type, user_prediction, confidence_threshold=0.5):
-     """Main game function that processes the image based on selected task type"""
-     if image is None:
-         return {
-             "status": "error",
-             "message": "Please upload an image to continue."
-         }, None, None, None, None
-
-     try:
-         if task_type == "Object Detection":
-             # Model 1: Faster R-CNN
-             boxes1, labels1, scores1 = get_prediction(MODEL1, image, confidence_threshold)
-             result1 = visualize_detection(image, boxes1, labels1, scores1, "Faster R-CNN Results")
-             metrics1 = calculate_metrics(boxes1, labels1, scores1)
-
-             # Model 2: RetinaNet
-             boxes2, labels2, scores2 = get_prediction(MODEL2, image, confidence_threshold)
-             result2 = visualize_detection(image, boxes2, labels2, scores2, "RetinaNet Results")
-             metrics2 = calculate_metrics(boxes2, labels2, scores2)
-
-             # Determine winner
-             score1 = metrics1["total_score"]
-             score2 = metrics2["total_score"]
-
-             if score1 > score2:
-                 winner = "Model 1 (Faster R-CNN)"
-                 winning_score = score1
-                 losing_score = score2
-             elif score2 > score1:
-                 winner = "Model 2 (RetinaNet)"
-                 winning_score = score2
-                 losing_score = score1
-             else:
-                 winner = "Tie"
-                 winning_score = score1
-                 losing_score = score2
-
-             user_correct = (user_prediction == "Model 1" and winner == "Model 1 (Faster R-CNN)") or \
-                            (user_prediction == "Model 2" and winner == "Model 2 (RetinaNet)") or \
-                            (user_prediction == "Tie" and winner == "Tie")
-
-             result_message = f"Winner: {winner} (Score: {winning_score:.2f} vs {losing_score:.2f})\n"
-             result_message += f"Your prediction: {user_prediction} - {'Correct!' if user_correct else 'Incorrect!'}\n\n"
-             result_message += f"Model 1 detected {metrics1['num_detections']} objects with {metrics1['unique_classes']} unique classes.\n"
-             result_message += f"Model 2 detected {metrics2['num_detections']} objects with {metrics2['unique_classes']} unique classes."
-
-             return {"status": "success", "message": result_message}, result1, result2, None, None
-
-         elif task_type == "Instance Segmentation":
-             # Only using one model for segmentation for now
-             boxes, labels, scores, masks = get_segmentation_prediction(SEG_MODEL, image, confidence_threshold)
-             seg_result = visualize_segmentation(image, boxes, labels, scores, masks, "Mask R-CNN Results")
-
-             # Also get detection results for comparison
-             boxes1, labels1, scores1 = get_prediction(MODEL1, image, confidence_threshold)
-             det_result = visualize_detection(image, boxes1, labels1, scores1, "Detection Results")
-
-             metrics_seg = calculate_metrics(boxes, labels, scores)
-             metrics_det = calculate_metrics(boxes1, labels1, scores1)
-
-             result_message = f"Segmentation detected {metrics_seg['num_detections']} objects with {metrics_seg['unique_classes']} unique classes.\n"
-             result_message += "The segmentation model provides pixel-level masks for each detected object."
-
-             return {"status": "success", "message": result_message}, None, None, det_result, seg_result
-
-         else:
-             return {"status": "error", "message": "Invalid task type selected."}, None, None, None, None
-
      except Exception as e:
-         print(f"Error in process_game: {e}", file=sys.stderr)
-         import traceback
          traceback.print_exc(file=sys.stderr)
-         return {"status": "error", "message": f"Error processing image: {str(e)}"}, None, None, None, None
-
- def create_ui():
-     """Create the Gradio UI for the game"""
-     with gr.Blocks(title="Object Detection Game") as app:
-         gr.Markdown("# 🎮 Computer Vision Model Comparison Game")
-         gr.Markdown("Upload an image, choose a task, and predict which model will perform better!")
-
-         with gr.Row():
-             with gr.Column(scale=1):
-                 # Input components
-                 input_image = gr.Image(type="pil", label="Upload Image")
-                 task_type = gr.Radio(
-                     ["Object Detection", "Instance Segmentation"],
-                     label="Select Task",
-                     value="Object Detection"
-                 )
-
-                 with gr.Row():
-                     with gr.Column(scale=1, visible=True) as detection_options:
-                         user_prediction = gr.Radio(
-                             ["Model 1", "Model 2", "Tie"],
-                             label="Which model will perform better?",
-                             value="Model 1"
-                         )
-
-                 confidence = gr.Slider(
-                     minimum=0.0, maximum=1.0, value=0.5, step=0.05,
-                     label="Confidence Threshold"
-                 )
-
-                 submit_btn = gr.Button("Run Comparison", variant="primary")
-
-             with gr.Column(scale=1):
-                 # Output components
-                 result_msg = gr.JSON(label="Results")
-
-                 # Detection results
-                 with gr.Row(visible=True) as detection_results:
-                     model1_output = gr.Image(type="filepath", label="Model 1 (Faster R-CNN)")
-                     model2_output = gr.Image(type="filepath", label="Model 2 (RetinaNet)")
-
-                 # Segmentation results
-                 with gr.Row(visible=False) as segmentation_results:
-                     detection_output = gr.Image(type="filepath", label="Detection")
-                     segmentation_output = gr.Image(type="filepath", label="Segmentation")
-
-         # Example images
-         examples = [
-             os.path.join("/home/user/app", "TEST_IMG_1.jpg"),
-             os.path.join("/home/user/app", "TEST_IMG_2.JPG"),
-             os.path.join("/home/user/app", "TEST_IMG_3.jpg"),
-             os.path.join("/home/user/app", "TEST_IMG_4.jpg")
-         ]
-
-         # Filter to valid examples
-         example_list = [ex for ex in examples if os.path.exists(ex)]
-
-         if example_list:
-             gr.Examples(
-                 examples=example_list,
-                 inputs=input_image,
-                 label="Example Images"
-             )
-
-         # Event handlers
-         def update_task_visibility(task):
-             if task == "Object Detection":
-                 return gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
-             else:
-                 return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
-
-         task_type.change(
-             fn=update_task_visibility,
-             inputs=task_type,
-             outputs=[detection_options, segmentation_results, detection_results, segmentation_results]
          )
-
-         # Submit button click event
-         submit_btn.click(
-             fn=process_game,
-             inputs=[input_image, task_type, user_prediction, confidence],
-             outputs=[result_msg, model1_output, model2_output, detection_output, segmentation_output]
          )
-
-         # Add markdown information about the models
-         gr.Markdown("""
-         ## About the Models
-
-         ### Detection Models:
-         - **Model 1** is Faster R-CNN with ResNet50 backbone, a two-stage detector that's accurate but relatively slower.
-         - **Model 2** is RetinaNet with ResNet50 backbone, a one-stage detector that's designed for better speed-accuracy trade-off.
-
-         ### Instance Segmentation:
-         - The segmentation model is Mask R-CNN with ResNet50 backbone, which provides pixel-level masks in addition to bounding boxes.
-
-         ### How is the winner determined?
-         The winner is determined based on a combined score of:
-         1. Number of objects detected
-         2. Average confidence score
-         3. Number of unique classes detected
-
-         Can you predict which model will perform better on your image?
-         """)
-
-     return app
-
- # Launch the app
  if __name__ == "__main__":
-     app = create_ui()
-     app.launch(debug=True)
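For reference, the total_score heuristic removed above combines its three ingredients as num_detections * avg_confidence + unique_classes; for example, 5 detections at a mean confidence of 0.8 spanning 3 unique classes score 5 * 0.8 + 3 = 7.0, so a model firing on many low-confidence boxes can outscore one with a few high-confidence hits. The rewrite below drops this scoring entirely and simply shows both annotated outputs for the player to judge.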
 
+ import torch
+ import torchvision
+ from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ import gradio as gr
+ import os
  import sys
+ import uuid  # For unique filenames
+ import traceback  # For detailed error logging
+
+ # --- Model Loading ---
+ # Model A
+ model_A = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+ model_A.eval()
+
+ # Model B (same architecture, will use a different threshold in practice)
+ model_B = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
+ model_B.eval()
+
+ # --- COCO Class Names ---
+ COCO_INSTANCE_CATEGORY_NAMES = [
+     '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+     'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
+     'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+     'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
+     'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+     'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+     'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+     'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
+     'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
+     'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
+     'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
+     'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
  ]
+
+ # --- Helper Functions ---
+ def get_font(size=15):
+     """Attempts to load Arial font, falls back to PIL default."""
+     try:
+         return ImageFont.truetype("arial.ttf", size)
+     except IOError:
+         return ImageFont.load_default()
+
+ def run_detection_on_image(image_pil, threshold, model_instance, model_name_str="Model"):
+     """
+     Runs object detection on a PIL image and returns the path to the annotated image.
+     Uses PIL for all drawing operations.
+     """
+     if image_pil is None:
+         print(f"{model_name_str}: Image is None, returning placeholder.", file=sys.stderr)
+         placeholder_img = Image.new('RGB', (400, 300), color='lightgray')
+         draw = ImageDraw.Draw(placeholder_img)
+         font = get_font(15)
+         text = f"{model_name_str}:\nNo image provided."
+         try:
+             bbox = draw.textbbox((0, 0), text, font=font, align="center")
+             text_width, text_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
+         except AttributeError:  # Fallback for older Pillow
+             text_width = draw.textlength(text.split('\n')[0], font=font)
+             text_height = (font.getmetrics()[0] + font.getmetrics()[1]) * text.count('\n') + font.getmetrics()[0]
+         draw.text(((400 - text_width) / 2, (300 - text_height) / 2), text, fill="black", font=font, align="center")
+         output_filename = f"placeholder_{model_name_str.lower().replace(' ', '_')}_{uuid.uuid4()}.png"
+         placeholder_img.save(output_filename)
+         return output_filename
+
+     try:
+         print(f"{model_name_str}: Processing with threshold {threshold:.2f}", file=sys.stderr)
+         image_rgb = image_pil.convert("RGB")  # Ensure 3-channel RGB
+
+         transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
+         image_tensor = transform(image_rgb).unsqueeze(0)
+
+         with torch.no_grad():
+             prediction = model_instance(image_tensor)[0]
+
+         boxes, labels, scores = prediction['boxes'].cpu().numpy(), prediction['labels'].cpu().numpy(), prediction['scores'].cpu().numpy()
+
+         annotated_image = image_rgb.copy()
+         draw = ImageDraw.Draw(annotated_image)
+         label_font = get_font(12)
+         detections_made = False
+
+         for box, label_id, score in zip(boxes, labels, scores):
+             if score >= threshold:
+                 detections_made = True
+                 x1, y1, x2, y2 = box
+                 draw.rectangle([(x1, y1), (x2, y2)], outline='red', width=3)
+                 class_name = COCO_INSTANCE_CATEGORY_NAMES[label_id]
+                 text_label = f'{class_name}: {score:.2f}'
+
+                 try:  # Get text size
+                     tb_box = draw.textbbox((0, 0), text_label, font=label_font)
+                 except AttributeError:
+                     tb_box = (0, 0, draw.textlength(text_label, font=label_font), label_font.getmetrics()[0])
+
+                 text_w, text_h = tb_box[2] - tb_box[0], tb_box[3] - tb_box[1]
+                 bg_y1 = y1 - text_h - 4 if y1 - text_h - 4 > 0 else y1 + 2
+                 draw.rectangle([x1, bg_y1, x1 + text_w + 4, bg_y1 + text_h + 4], fill='yellow')
+                 draw.text((x1 + 2, bg_y1 + 2), text_label, fill='black', font=label_font)
+
+         if not detections_made:
+             no_detection_font = get_font(max(15, min(annotated_image.width, annotated_image.height) // 20))  # Scaled font
+             no_detection_text = f"{model_name_str}:\nNo objects detected\n(Threshold: {threshold:.2f})"
+             try:
+                 bbox_nd = draw.textbbox((0, 0), no_detection_text, font=no_detection_font, align="center")
+             except AttributeError:
+                 bbox_nd = (0, 0, draw.textlength(no_detection_text.split('\n')[0], font=no_detection_font), (no_detection_font.getmetrics()[0] + no_detection_font.getmetrics()[1]) * no_detection_text.count('\n') + no_detection_font.getmetrics()[0])
+             text_w_nd, text_h_nd = bbox_nd[2] - bbox_nd[0], bbox_nd[3] - bbox_nd[1]
+             draw.text(((annotated_image.width - text_w_nd) / 2, (annotated_image.height - text_h_nd) / 2),
+                       no_detection_text, fill="blue", font=no_detection_font, align="center", stroke_width=1, stroke_fill="white")
+
+         output_filename = f"detection_{model_name_str.lower().replace(' ', '_')}_{uuid.uuid4()}.png"
+         annotated_image.save(output_filename)
+         return output_filename
+
      except Exception as e:
+         print(f"ERROR in {model_name_str} run_detection_on_image: {e}", file=sys.stderr)
          traceback.print_exc(file=sys.stderr)
+         error_img = Image.new('RGB', (400, 300), color='lightpink')
+         draw = ImageDraw.Draw(error_img)
+         font = get_font(15)
+         text = f"{model_name_str} Error:\n{str(e)[:100]}"  # Limit error message length
+         try:
+             bbox_err = draw.textbbox((0, 0), text, font=font, align="center")
+         except AttributeError:
+             bbox_err = (0, 0, draw.textlength(text.split('\n')[0], font=font), (font.getmetrics()[0] + font.getmetrics()[1]) * text.count('\n') + font.getmetrics()[0])
+         text_w_err, text_h_err = bbox_err[2] - bbox_err[0], bbox_err[3] - bbox_err[1]
+         draw.text(((400 - text_w_err) / 2, (300 - text_h_err) / 2), text, fill="black", font=font, align="center")
+         error_filename = f"error_{model_name_str.lower().replace(' ', '_')}_{uuid.uuid4()}.png"
+         error_img.save(error_filename)
+         return error_filename
+
+ # --- Prepare Example Images ---
+ example_files_src = ["TEST_IMG_1.jpg", "TEST_IMG_2.JPG", "TEST_IMG_3.jpg", "TEST_IMG_4.jpg"]
+ app_root = os.getcwd()  # Assumes the script runs from the app root
+ example_paths_final = [os.path.join(app_root, f) for f in example_files_src]
+ valid_examples_list = [p for p in example_paths_final if os.path.exists(p)]
+
+ if not valid_examples_list:
+     print("Warning: No example images found at app root. Creating dummy examples.", file=sys.stderr)
+     try:
+         for i in range(1, 3):
+             dummy_fname = f"dummy_example_{i}.png"
+             if not os.path.exists(os.path.join(app_root, dummy_fname)):
+                 img = Image.new('RGB', (300, 200), color=('darkred' if i == 1 else 'darkgreen'))
+                 draw = ImageDraw.Draw(img)
+                 font = get_font(25)
+                 draw.text((10, 10), f"Dummy Example {i}", font=font, fill="white")
+                 img.save(os.path.join(app_root, dummy_fname))
+         valid_examples_list = [os.path.join(app_root, f"dummy_example_{i}.png") for i in range(1, 3) if os.path.exists(os.path.join(app_root, f"dummy_example_{i}.png"))]
+         print(f"Created/using dummy examples: {valid_examples_list}", file=sys.stderr)
+     except Exception as e:
+         print(f"Failed to create dummy examples: {e}", file=sys.stderr)
+         valid_examples_list = []
+
+ print(f"Final list of examples to use: {valid_examples_list}", file=sys.stderr)
+
+ # --- Gradio UI Definition ---
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
+     gr.Markdown("# 🖼️ Object Detection Game: Model vs. Model 🏆")
+     gr.Markdown("Can you guess which model configuration will perform better on your image?")
+
+     # --- Output Display Area (initially hidden) ---
+     with gr.Row(visible=False) as results_feedback_row:
+         user_guess_feedback_display = gr.Markdown("")
+     with gr.Row(visible=False) as results_images_row:
+         output_img_model_A = gr.Image(label="Model A Output", type="filepath", interactive=False)
+         output_img_model_B = gr.Image(label="Model B Output", type="filepath", interactive=False)
+
+     # --- Input and Controls Area ---
+     with gr.Row():
+         image_uploader = gr.Image(type="pil", label="🖼️ Upload Your Image Here")
+
+         with gr.Column(scale=1):  # Control panel
+             task_type_selector = gr.Radio(
+                 ["Detect Objects", "Segment Objects (Coming Soon!)"],
+                 label="🎯 1. Select Task:",
+                 value="Detect Objects"
+             )
+
+             with gr.Group(visible=True) as detection_controls_group:
+                 gr.Markdown("--- \n ### ⚔️ Detection Challenge Details:")
+                 gr.Markdown("Both **Model A** and **Model B** are Faster R-CNN (ResNet50 FPN).")
+                 gr.Markdown("- **Model A**: You control its confidence threshold.")
+                 gr.Markdown("- **Model B**: Its threshold is `Model A Threshold - 0.15` (minimum 0.05).")
+
+                 model_A_threshold_slider = gr.Slider(
+                     minimum=0.1, maximum=0.95, value=0.5, step=0.05,
+                     label="⚙️ 2. Confidence Threshold for Model A"
+                 )
+                 user_model_preference_guess = gr.Radio(
+                     ["Model A will be better", "Model B will be better", "They will be similar"],
+                     label="🤔 3. Your Guess:",
+                     value="Model A will be better"
+                 )
+                 run_game_button = gr.Button("🚀 Run Detection & Reveal Results!", variant="primary")
+
+             with gr.Group(visible=False) as segmentation_controls_group:
+                 gr.Markdown("--- \n ### 🚧 Segmentation Challenge (Coming Soon!)")
+                 gr.Markdown("This feature is under active development. Please choose 'Detect Objects' for now.")
+
+     if valid_examples_list:
+         gr.Examples(
+             examples=[[ex_path] for ex_path in valid_examples_list],
+             inputs=[image_uploader],
+             label="✨ Click an Example Image to Load",
+             # cache_examples=True  # Set to True if examples are static and processing is heavy
          )
+
+     # --- Event Handlers ---
+     def handle_task_selection(selected_task):
+         """Updates visibility of control groups and hides results when the task changes."""
+         show_detection = (selected_task == "Detect Objects")
+         return (
+             gr.update(visible=show_detection),      # detection_controls_group
+             gr.update(visible=not show_detection),  # segmentation_controls_group
+             gr.update(visible=False),               # results_feedback_row
+             gr.update(visible=False)                # results_images_row
          )
+
+     task_type_selector.change(
+         fn=handle_task_selection,
+         inputs=task_type_selector,
+         outputs=[detection_controls_group, segmentation_controls_group, results_feedback_row, results_images_row]
+     )
+
+     def execute_detection_game(image_pil_data, chosen_task, user_guess_str, threshold_for_A):
+         """Main game logic: processes the image with both models and returns results."""
+         if image_pil_data is None:
+             msg = "⚠️ **Oops! Please upload an image first.**"
+             return gr.update(value=msg), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None)
+
+         if chosen_task != "Detect Objects":
+             msg = f"⚠️ **Hold on!** '{chosen_task}' is not quite ready. Please select 'Detect Objects' to play."
+             return gr.update(value=msg), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None)
+
+         threshold_for_B = max(0.05, threshold_for_A - 0.15)  # Ensure threshold_B is not too low or negative
+
+         print(f"Player guessed: {user_guess_str}", file=sys.stderr)
+         print(f"Model A using threshold: {threshold_for_A:.2f}", file=sys.stderr)
+         print(f"Model B using threshold: {threshold_for_B:.2f}", file=sys.stderr)
+
+         output_path_A = run_detection_on_image(image_pil_data, threshold_for_A, model_A, "Model A")
+         output_path_B = run_detection_on_image(image_pil_data, threshold_for_B, model_B, "Model B")
+
+         feedback_text = (f"💬 You guessed: **{user_guess_str}**.\n\n"
+                          f"See the results! Model A (Threshold: {threshold_for_A:.2f}) vs. Model B (Threshold: {threshold_for_B:.2f})")
+
+         return (
+             gr.update(value=feedback_text),  # For user_guess_feedback_display
+             gr.update(visible=True),         # Make results_feedback_row visible
+             gr.update(visible=True),         # Make results_images_row visible
+             gr.update(value=output_path_A),  # Set image for output_img_model_A
+             gr.update(value=output_path_B)   # Set image for output_img_model_B
+         )
+
+     run_game_button.click(
+         fn=execute_detection_game,
+         inputs=[image_uploader, task_type_selector, user_model_preference_guess, model_A_threshold_slider],
+         outputs=[user_guess_feedback_display, results_feedback_row, results_images_row, output_img_model_A, output_img_model_B]
+     )

  if __name__ == "__main__":
+     demo.launch(debug=True)  # debug=True is helpful for development
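Since model_A and model_B load identical Faster R-CNN weights and differ only in the confidence threshold applied after inference, a single shared instance could serve both calls and roughly halve model memory. A minimal sketch of that variant, using the same torchvision API as app.py (the names shared_model and detect_with_threshold are illustrative, not part of the commit):

import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

# One shared instance: score thresholding happens after inference, so a second
# copy of identical weights buys nothing. (Sketch only; not part of app.py.)
shared_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
)
shared_model.eval()

def detect_with_threshold(image_pil, threshold):
    """Run the shared detector and keep detections scoring at or above threshold."""
    transform = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.transforms()
    image_tensor = transform(image_pil.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        pred = shared_model(image_tensor)[0]
    keep = pred['scores'] >= threshold
    return pred['boxes'][keep], pred['labels'][keep], pred['scores'][keep]

# "Model A" and "Model B" would then differ only in the threshold they pass:
#   boxes_A, labels_A, scores_A = detect_with_threshold(img, threshold_for_A)
#   boxes_B, labels_B, scores_B = detect_with_threshold(img, max(0.05, threshold_for_A - 0.15))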