dnth commited on
Commit
62393c2
·
verified ·
1 Parent(s): 8fabc30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +414 -0
app.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import colorsys
3
+ import os
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+ import pandas as pd
9
+ from PIL import Image, ImageDraw
10
+
11
+
12
+ def resize_with_aspect_ratio(image, size, interpolation=Image.BILINEAR):
13
+ """Resizes an image while maintaining aspect ratio and pads it."""
14
+ original_width, original_height = image.size
15
+ ratio = min(size / original_width, size / original_height)
16
+ new_width = int(original_width * ratio)
17
+ new_height = int(original_height * ratio)
18
+ image = image.resize((new_width, new_height), interpolation)
19
+
20
+ # Create a new image with the desired size and paste the resized image onto it
21
+ new_image = Image.new("RGB", (size, size))
22
+ new_image.paste(image, ((size - new_width) // 2, (size - new_height) // 2))
23
+ return new_image, ratio, (size - new_width) // 2, (size - new_height) // 2
24
+
25
+
26
+ def generate_colors(num_classes):
27
+ """Generate a list of distinct colors for different classes."""
28
+ # Generate evenly spaced hues
29
+ hsv_tuples = [(x / num_classes, 0.8, 0.9) for x in range(num_classes)]
30
+
31
+ # Convert to RGB
32
+ colors = []
33
+ for hsv in hsv_tuples:
34
+ rgb = colorsys.hsv_to_rgb(*hsv)
35
+ # Convert to 0-255 range and to tuple
36
+ colors.append(tuple(int(255 * x) for x in rgb))
37
+
38
+ return colors
39
+
40
+
41
+ def draw(images, labels, boxes, scores, ratios, paddings, thrh=0.4, class_names=None):
42
+ result_images = []
43
+
44
+ # Generate colors for classes
45
+ num_classes = (
46
+ len(class_names) if class_names else 91
47
+ ) # Use length of class_names if available, otherwise default to COCO's 91 classes
48
+ colors = generate_colors(num_classes)
49
+
50
+ for i, im in enumerate(images):
51
+ draw = ImageDraw.Draw(im)
52
+ scr = scores[i]
53
+ lab = labels[i][scr > thrh]
54
+ box = boxes[i][scr > thrh]
55
+ scr = scr[scr > thrh]
56
+
57
+ ratio = ratios[i]
58
+ pad_w, pad_h = paddings[i]
59
+
60
+ for lbl, bb in zip(lab, box):
61
+ # Get color for this class
62
+ class_idx = int(lbl)
63
+ color = colors[class_idx % len(colors)]
64
+
65
+ # Convert RGB to hex for PIL
66
+ hex_color = "#{:02x}{:02x}{:02x}".format(*color)
67
+
68
+ # Adjust bounding boxes according to the resizing and padding
69
+ bb = [
70
+ (bb[0] - pad_w) / ratio,
71
+ (bb[1] - pad_h) / ratio,
72
+ (bb[2] - pad_w) / ratio,
73
+ (bb[3] - pad_h) / ratio,
74
+ ]
75
+
76
+ # Draw rectangle with class-specific color
77
+ draw.rectangle(bb, outline=hex_color, width=3)
78
+
79
+ # Use class name if available, otherwise use class index
80
+ if class_names and class_idx < len(class_names):
81
+ label_text = f"{class_names[class_idx]} {scr[lab == lbl][0]:.2f}"
82
+ else:
83
+ label_text = f"Class {class_idx} {scr[lab == lbl][0]:.2f}"
84
+
85
+ # Draw text background
86
+ text_size = draw.textbbox((0, 0), label_text, font=None)
87
+ text_width = text_size[2] - text_size[0]
88
+ text_height = text_size[3] - text_size[1]
89
+
90
+ # Draw text background rectangle
91
+ draw.rectangle(
92
+ [bb[0], bb[1] - text_height - 4, bb[0] + text_width + 4, bb[1]],
93
+ fill=hex_color,
94
+ )
95
+
96
+ # Draw text in white or black depending on color brightness
97
+ brightness = (color[0] * 299 + color[1] * 587 + color[2] * 114) / 1000
98
+ text_color = "black" if brightness > 128 else "white"
99
+
100
+ # Draw text
101
+ draw.text(
102
+ (bb[0] + 2, bb[1] - text_height - 2), text=label_text, fill=text_color
103
+ )
104
+
105
+ result_images.append(im)
106
+ return result_images
107
+
108
+
109
+ def load_model(model_path):
110
+ """
111
+ Load an ONNX model for inference.
112
+
113
+ Args:
114
+ model_path: Path to the ONNX model file
115
+
116
+ Returns:
117
+ tuple: (session, error_message)
118
+ """
119
+ providers = ["CPUExecutionProvider"]
120
+
121
+ try:
122
+ sess = ort.InferenceSession(model_path, providers=providers)
123
+ print(f"Using device: {ort.get_device()}")
124
+ return sess, None
125
+ except Exception as e:
126
+ return None, f"Error creating inference session: {e}"
127
+
128
+
129
+ def load_class_names(class_names_path):
130
+ """
131
+ Load class names from a text file.
132
+
133
+ Args:
134
+ class_names_path: Path to a text file with class names (one per line)
135
+
136
+ Returns:
137
+ list: Class names or None if loading failed
138
+ """
139
+ if not class_names_path or not os.path.exists(class_names_path):
140
+ return None
141
+
142
+ try:
143
+ with open(class_names_path, "r") as f:
144
+ class_names = [line.strip() for line in f.readlines()]
145
+ print(f"Loaded {len(class_names)} class names")
146
+ return class_names
147
+ except Exception as e:
148
+ print(f"Error loading class names: {e}")
149
+ return None
150
+
151
+
152
+ def prepare_image(image):
153
+ """
154
+ Prepare image for inference by converting to PIL and resizing.
155
+
156
+ Args:
157
+ image: Input image (PIL or numpy array)
158
+
159
+ Returns:
160
+ tuple: (resized_image, original_image, ratio, padding)
161
+ """
162
+ # Convert to PIL image if needed
163
+ if not isinstance(image, Image.Image):
164
+ image = Image.fromarray(image).convert("RGB")
165
+
166
+ # Resize image while preserving aspect ratio
167
+ resized_image, ratio, pad_w, pad_h = resize_with_aspect_ratio(image, 640)
168
+
169
+ return resized_image, image, ratio, (pad_w, pad_h)
170
+
171
+
172
+ def run_inference(session, image):
173
+ """
174
+ Run inference on the prepared image.
175
+
176
+ Args:
177
+ session: ONNX runtime session
178
+ image: Prepared PIL image
179
+
180
+ Returns:
181
+ tuple: (labels, boxes, scores)
182
+ """
183
+ # Get original image dimensions
184
+ orig_height, orig_width = image.size[1], image.size[0]
185
+ # Convert to int64 as expected by the model
186
+ orig_size = np.array([[orig_height, orig_width]], dtype=np.int64)
187
+
188
+ # Convert PIL image to numpy array and normalize to 0-1 range
189
+ im_data = np.array(image, dtype=np.float32) / 255.0
190
+ # Transpose from HWC to CHW format
191
+ im_data = im_data.transpose(2, 0, 1)
192
+ # Add batch dimension
193
+ im_data = np.expand_dims(im_data, axis=0)
194
+
195
+ output = session.run(
196
+ output_names=None,
197
+ input_feed={"images": im_data, "orig_target_sizes": orig_size},
198
+ )
199
+
200
+ return output # labels, boxes, scores
201
+
202
+
203
+ def count_objects(labels, scores, confidence_threshold, class_names):
204
+ """
205
+ Count detected objects by class.
206
+
207
+ Args:
208
+ labels: Detection labels
209
+ scores: Detection confidence scores
210
+ confidence_threshold: Minimum confidence threshold
211
+ class_names: List of class names
212
+
213
+ Returns:
214
+ dict: Counts of objects by class
215
+ """
216
+ object_counts = {}
217
+ for i, score_batch in enumerate(scores):
218
+ for j, score in enumerate(score_batch):
219
+ if score >= confidence_threshold:
220
+ label = labels[i][j]
221
+ class_name = (
222
+ class_names[label]
223
+ if class_names and label < len(class_names)
224
+ else f"Class {label}"
225
+ )
226
+ object_counts[class_name] = object_counts.get(class_name, 0) + 1
227
+
228
+ return object_counts
229
+
230
+
231
+ def create_status_message(object_counts):
232
+ """
233
+ Create a status message with object counts.
234
+
235
+ Args:
236
+ object_counts: Dictionary of object counts by class
237
+
238
+ Returns:
239
+ str: Formatted status message
240
+ """
241
+ status_message = "Detection completed successfully\n\nObjects detected:"
242
+ if object_counts:
243
+ for class_name, count in object_counts.items():
244
+ status_message += f"\n- {class_name}: {count}"
245
+ else:
246
+ status_message += "\n- No objects detected above confidence threshold"
247
+
248
+ return status_message
249
+
250
+
251
+ def create_bar_data(object_counts):
252
+ """
253
+ Create data for the bar plot visualization.
254
+
255
+ Args:
256
+ object_counts: Dictionary of object counts by class
257
+
258
+ Returns:
259
+ DataFrame: Data for bar plot
260
+ """
261
+ if object_counts:
262
+ # Sort by count in descending order
263
+ sorted_counts = sorted(object_counts.items(), key=lambda x: x[1], reverse=True)
264
+ class_names_list = [item[0] for item in sorted_counts]
265
+ counts_list = [item[1] for item in sorted_counts]
266
+ # Create a pandas DataFrame for the bar plot
267
+ return pd.DataFrame({"Class": class_names_list, "Count": counts_list})
268
+ else:
269
+ return pd.DataFrame({"Class": ["No objects detected"], "Count": [0]})
270
+
271
+
272
+ def predict(image, model_path, class_names_path, confidence_threshold):
273
+ """
274
+ Main prediction function that orchestrates the detection pipeline.
275
+
276
+ Args:
277
+ image: Input image
278
+ model_path: Path to ONNX model
279
+ class_names_path: Path to class names file
280
+ confidence_threshold: Detection confidence threshold
281
+
282
+ Returns:
283
+ tuple: (result_image, status_message, bar_data)
284
+ """
285
+ # Load model
286
+ session, error = load_model(model_path)
287
+ if error:
288
+ return None, error, None
289
+
290
+ # Load class names
291
+ class_names = load_class_names(class_names_path)
292
+
293
+ try:
294
+ # Prepare image
295
+ resized_image, original_image, ratio, padding = prepare_image(image)
296
+
297
+ # Run inference
298
+ labels, boxes, scores = run_inference(session, resized_image)
299
+
300
+ # Draw detections on the original image
301
+ result_images = draw(
302
+ [original_image],
303
+ labels,
304
+ boxes,
305
+ scores,
306
+ [ratio],
307
+ [padding],
308
+ thrh=confidence_threshold,
309
+ class_names=class_names,
310
+ )
311
+
312
+ # Count objects by class
313
+ object_counts = count_objects(labels, scores, confidence_threshold, class_names)
314
+
315
+ # Create status message
316
+ status_message = create_status_message(object_counts)
317
+
318
+ # Create bar plot data
319
+ bar_data = create_bar_data(object_counts)
320
+
321
+ return result_images[0], status_message, bar_data
322
+ except Exception as e:
323
+ return None, f"Error during inference: {e}", None
324
+
325
+
326
+ def build_interface(model_path, class_names_path):
327
+ """
328
+ Build the Gradio interface components.
329
+
330
+ Args:
331
+ model_path: Path to the ONNX model
332
+ class_names_path: Path to the class names file
333
+
334
+ Returns:
335
+ gr.Blocks: The Gradio demo interface
336
+ """
337
+ with gr.Blocks(title="Blood Cell Detection") as demo:
338
+ gr.Markdown("# Blood Cell Detection")
339
+ gr.Markdown("Upload an image to detect blood cells. The model can detect 3 types of blood cells: red blood cells, white blood cells and platelets.")
340
+
341
+ with gr.Row():
342
+ with gr.Column():
343
+ input_image = gr.Image(type="pil", label="Input Image")
344
+ confidence = gr.Slider(
345
+ minimum=0.1,
346
+ maximum=1.0,
347
+ value=0.4,
348
+ step=0.05,
349
+ label="Confidence Threshold",
350
+ )
351
+ submit_btn = gr.Button("Detect Objects")
352
+
353
+ with gr.Column():
354
+ output_image = gr.Image(type="pil", label="Detection Result")
355
+
356
+ with gr.Row():
357
+ output_message = gr.Textbox(label="Status")
358
+
359
+ count_plot = gr.BarPlot(
360
+ y="Class",
361
+ x="Count",
362
+ title="Object Counts",
363
+ tooltip=["Class", "Count"],
364
+ height=300,
365
+ orientation="h",
366
+ label_title="Object Counts",
367
+ )
368
+
369
+ # Set up the click event inside the Blocks context
370
+ submit_btn.click(
371
+ fn=predict,
372
+ inputs=[
373
+ input_image,
374
+ gr.State(model_path),
375
+ gr.State(class_names_path),
376
+ confidence,
377
+ ],
378
+ outputs=[output_image, output_message, count_plot],
379
+ )
380
+
381
+ return demo
382
+
383
+
384
+ def launch_demo(args):
385
+ """
386
+ Launch the Gradio demo with the specified arguments.
387
+
388
+ Args:
389
+ args: Command-line arguments
390
+ """
391
+ demo = build_interface(args.onnx, args.class_names)
392
+
393
+ # Launch the demo
394
+ demo.launch(share=args.share)
395
+
396
+
397
+ if __name__ == "__main__":
398
+ parser = argparse.ArgumentParser(
399
+ description="Gradio demo for object detection with ONNX Runtime"
400
+ )
401
+ parser.add_argument(
402
+ "--onnx", type=str, required=True, help="Path to the ONNX model file"
403
+ )
404
+ parser.add_argument(
405
+ "--class-names",
406
+ type=str,
407
+ help="Path to a text file with class names (one per line)",
408
+ )
409
+ parser.add_argument(
410
+ "--share", action="store_true", help="Create a shareable link for the demo"
411
+ )
412
+ args = parser.parse_args()
413
+
414
+ launch_demo(args)