Spaces:
Sleeping
Sleeping
Eric P. Nusbaum
commited on
Commit
·
f61c335
1
Parent(s):
b5e07e6
Update Space
Browse files- app.py +106 -103
- requirements.txt +5 -4
app.py
CHANGED
@@ -1,149 +1,152 @@
|
|
1 |
import os
|
2 |
import numpy as np
|
|
|
3 |
import onnxruntime
|
4 |
from PIL import Image, ImageDraw, ImageFont
|
5 |
import gradio as gr
|
6 |
|
7 |
-
#
|
|
|
8 |
MODEL_PATH = os.path.join("onnx", "model.onnx")
|
9 |
LABELS_PATH = os.path.join("onnx", "labels.txt")
|
10 |
|
11 |
# Load labels
|
12 |
with open(LABELS_PATH, "r") as f:
|
13 |
-
LABELS =
|
14 |
|
15 |
-
# Initialize ONNX Runtime session
|
16 |
class Model:
|
17 |
def __init__(self, model_filepath):
|
18 |
-
# Initialize the InferenceSession
|
19 |
self.session = onnxruntime.InferenceSession(model_filepath)
|
20 |
-
|
21 |
-
# Ensure the model has exactly one input
|
22 |
-
assert len(self.session.get_inputs()) == 1, "Model should have exactly one input."
|
23 |
-
|
24 |
-
# Extract input details
|
25 |
self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
|
26 |
self.input_name = self.session.get_inputs()[0].name
|
27 |
-
self.input_type = {
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
# Extract output names
|
33 |
-
self.output_names = [output.name for output in self.session.get_outputs()]
|
34 |
-
|
35 |
-
# Default preprocessing flags
|
36 |
self.is_bgr = False
|
37 |
self.is_range255 = False
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
for key, value in metadata_map.items():
|
42 |
-
if key == 'Image.BitmapPixelFormat' and value == 'Bgr8':
|
43 |
self.is_bgr = True
|
44 |
-
elif key == 'Image.NominalPixelRange' and value == 'NominalRange_0_255':
|
45 |
self.is_range255 = True
|
46 |
|
47 |
-
def predict(self, image):
|
48 |
# Preprocess image
|
49 |
image_resized = image.resize(self.input_shape)
|
50 |
input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
|
51 |
input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
|
52 |
-
|
53 |
if self.is_bgr:
|
54 |
-
input_array = input_array[:, (2, 1, 0), :, :]
|
55 |
-
|
56 |
if not self.is_range255:
|
57 |
input_array = input_array / 255.0 # Normalize to [0,1]
|
58 |
-
|
59 |
-
# Prepare input tensor
|
60 |
-
input_tensor = input_array.astype(self.input_type)
|
61 |
-
|
62 |
# Run inference
|
63 |
-
outputs = self.session.run(self.output_names, {self.input_name:
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
boxes = outputs[0][0] # shape: [num_detections, 4]
|
69 |
-
labels = outputs[1][0].astype(int) # shape: [num_detections]
|
70 |
-
scores = outputs[2][0] # shape: [num_detections]
|
71 |
-
return boxes, labels, scores
|
72 |
-
else:
|
73 |
-
raise ValueError("Unexpected number of outputs from the model.")
|
74 |
-
|
75 |
-
# Load the model
|
76 |
-
model = Model(MODEL_PATH)
|
77 |
|
78 |
-
#
|
79 |
-
|
80 |
-
|
81 |
try:
|
82 |
-
|
|
|
83 |
except IOError:
|
|
|
84 |
font = ImageFont.load_default()
|
85 |
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
91 |
continue
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
ymin
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
return image
|
109 |
|
110 |
-
#
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
title = "JunkWaxHero: Object Detection for Junk Wax Baseball Cards"
|
133 |
-
description = """
|
134 |
-
Upload an image of a Junk Wax Baseball Card, and the model will identify the card by its set (1980-1999).
|
135 |
-
"""
|
136 |
|
|
|
137 |
iface = gr.Interface(
|
138 |
-
fn=
|
139 |
inputs=gr.Image(type="pil"),
|
140 |
-
outputs=
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
145 |
)
|
146 |
|
147 |
-
# Launch the interface
|
148 |
if __name__ == "__main__":
|
149 |
iface.launch()
|
|
|
1 |
import os
|
2 |
import numpy as np
|
3 |
+
import onnx
|
4 |
import onnxruntime
|
5 |
from PIL import Image, ImageDraw, ImageFont
|
6 |
import gradio as gr
|
7 |
|
8 |
+
# Constants
|
9 |
+
PROB_THRESHOLD = 0.5 # Minimum probability to show results
|
10 |
MODEL_PATH = os.path.join("onnx", "model.onnx")
|
11 |
LABELS_PATH = os.path.join("onnx", "labels.txt")
|
12 |
|
13 |
# Load labels
|
14 |
with open(LABELS_PATH, "r") as f:
|
15 |
+
LABELS = f.read().strip().split("\n")
|
16 |
|
|
|
17 |
class Model:
|
18 |
def __init__(self, model_filepath):
|
|
|
19 |
self.session = onnxruntime.InferenceSession(model_filepath)
|
20 |
+
assert len(self.session.get_inputs()) == 1
|
|
|
|
|
|
|
|
|
21 |
self.input_shape = self.session.get_inputs()[0].shape[2:] # (H, W)
|
22 |
self.input_name = self.session.get_inputs()[0].name
|
23 |
+
self.input_type = {'tensor(float)': np.float32, 'tensor(float16)': np.float16}.get(
|
24 |
+
self.session.get_inputs()[0].type, np.float32
|
25 |
+
)
|
26 |
+
self.output_names = [o.name for o in self.session.get_outputs()]
|
27 |
+
|
|
|
|
|
|
|
|
|
28 |
self.is_bgr = False
|
29 |
self.is_range255 = False
|
30 |
+
onnx_model = onnx.load(model_filepath)
|
31 |
+
for metadata in onnx_model.metadata_props:
|
32 |
+
if metadata.key == 'Image.BitmapPixelFormat' and metadata.value == 'Bgr8':
|
|
|
|
|
33 |
self.is_bgr = True
|
34 |
+
elif metadata.key == 'Image.NominalPixelRange' and metadata.value == 'NominalRange_0_255':
|
35 |
self.is_range255 = True
|
36 |
|
37 |
+
def predict(self, image: Image.Image):
|
38 |
# Preprocess image
|
39 |
image_resized = image.resize(self.input_shape)
|
40 |
input_array = np.array(image_resized, dtype=np.float32)[np.newaxis, :, :, :]
|
41 |
input_array = input_array.transpose((0, 3, 1, 2)) # (N, C, H, W)
|
|
|
42 |
if self.is_bgr:
|
43 |
+
input_array = input_array[:, (2, 1, 0), :, :]
|
|
|
44 |
if not self.is_range255:
|
45 |
input_array = input_array / 255.0 # Normalize to [0,1]
|
46 |
+
|
|
|
|
|
|
|
47 |
# Run inference
|
48 |
+
outputs = self.session.run(self.output_names, {self.input_name: input_array.astype(self.input_type)})
|
49 |
+
return {name: outputs[i] for i, name in enumerate(self.output_names)}
|
50 |
+
|
51 |
+
def draw_boxes(image: Image.Image, outputs: dict):
|
52 |
+
draw = ImageDraw.Draw(image, "RGBA") # Use RGBA for transparency
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
# Dynamic font size based on image dimensions
|
55 |
+
image_width, image_height = image.size
|
56 |
+
font_size = max(20, image_width // 50) # Increased minimum font size
|
57 |
try:
|
58 |
+
# Attempt to load a truetype font; adjust the path if necessary
|
59 |
+
font = ImageFont.truetype("arial.ttf", size=font_size)
|
60 |
except IOError:
|
61 |
+
# Fallback to default font if truetype font is not found
|
62 |
font = ImageFont.load_default()
|
63 |
|
64 |
+
boxes = outputs.get('detected_boxes', [])
|
65 |
+
classes = outputs.get('detected_classes', [])
|
66 |
+
scores = outputs.get('detected_scores', [])
|
67 |
+
|
68 |
+
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
69 |
+
if score < PROB_THRESHOLD:
|
70 |
continue
|
71 |
+
label = LABELS[int(cls)]
|
72 |
+
|
73 |
+
# Assuming box format: [ymin, xmin, ymax, xmax] normalized [0,1]
|
74 |
+
ymin, xmin, ymax, xmax = box
|
75 |
+
left = xmin * image_width
|
76 |
+
right = xmax * image_width
|
77 |
+
top = ymin * image_height
|
78 |
+
bottom = ymax * image_height
|
79 |
+
|
80 |
+
# Draw bounding box
|
81 |
+
draw.rectangle([left, top, right, bottom], outline="red", width=3)
|
82 |
+
|
83 |
+
# Prepare label text
|
84 |
+
text = f"{label}: {score:.2f}"
|
85 |
+
|
86 |
+
# Calculate text size using textbbox
|
87 |
+
text_bbox = draw.textbbox((0, 0), text, font=font)
|
88 |
+
text_width = text_bbox[2] - text_bbox[0]
|
89 |
+
text_height = text_bbox[3] - text_bbox[1]
|
90 |
+
|
91 |
+
# Calculate label background position
|
92 |
+
# Ensure the label box does not go above the image
|
93 |
+
label_top = max(top - text_height - 10, 0)
|
94 |
+
label_left = left
|
95 |
+
|
96 |
+
# Draw semi-transparent rectangle behind text
|
97 |
+
draw.rectangle(
|
98 |
+
[label_left, label_top, label_left + text_width + 10, label_top + text_height + 10],
|
99 |
+
fill=(255, 0, 0, 160) # Semi-transparent red
|
100 |
+
)
|
101 |
+
|
102 |
+
# Draw text
|
103 |
+
draw.text(
|
104 |
+
(label_left + 5, label_top + 5),
|
105 |
+
text,
|
106 |
+
fill="white",
|
107 |
+
font=font
|
108 |
+
)
|
109 |
|
110 |
return image
|
111 |
|
112 |
+
# Initialize model
|
113 |
+
model = Model(MODEL_PATH)
|
114 |
+
|
115 |
+
def detect_objects(image):
|
116 |
+
outputs = model.predict(image)
|
117 |
+
annotated_image = draw_boxes(image.copy(), outputs)
|
118 |
+
|
119 |
+
# Prepare detection summary
|
120 |
+
detections = []
|
121 |
+
boxes = outputs.get('detected_boxes', [])
|
122 |
+
classes = outputs.get('detected_classes', [])
|
123 |
+
scores = outputs.get('detected_scores', [])
|
124 |
+
|
125 |
+
for box, cls, score in zip(boxes[0], classes[0], scores[0]):
|
126 |
+
if score < PROB_THRESHOLD:
|
127 |
+
continue
|
128 |
+
label = LABELS[int(cls)]
|
129 |
+
detections.append(f"{label}: {score:.2f}")
|
130 |
+
|
131 |
+
detection_summary = "\n".join(detections) if detections else "No objects detected."
|
132 |
+
|
133 |
+
return annotated_image, detection_summary
|
|
|
|
|
|
|
|
|
134 |
|
135 |
+
# Gradio Interface
|
136 |
iface = gr.Interface(
|
137 |
+
fn=detect_objects,
|
138 |
inputs=gr.Image(type="pil"),
|
139 |
+
outputs=[
|
140 |
+
gr.Image(type="pil", label="Detected Objects"),
|
141 |
+
gr.Textbox(label="Detections")
|
142 |
+
],
|
143 |
+
title="Object Detection with ONNX Model",
|
144 |
+
description="Upload an image to detect objects using the ONNX model.",
|
145 |
+
examples=["examples/card1.jpg", "examples/card2.jpg", "examples/card3.jpg"],
|
146 |
+
theme="default", # You can choose other themes if desired
|
147 |
+
allow_flagging="never" # Disable flagging if not needed
|
148 |
+
# Removed 'layout' parameter
|
149 |
)
|
150 |
|
|
|
151 |
if __name__ == "__main__":
|
152 |
iface.launch()
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
-
gradio
|
2 |
-
|
3 |
-
onnxruntime
|
4 |
-
|
|
|
|
1 |
+
gradio==3.32.0
|
2 |
+
onnx==1.14.0
|
3 |
+
onnxruntime==1.15.1
|
4 |
+
Pillow>=10.0.0
|
5 |
+
numpy==1.25.0
|