threshold added
app.py CHANGED
@@ -63,8 +63,9 @@ def nms(final_boxes, scores, IOU_threshold=0):

     return final_boxes[pick]

-def detect_obj(input_image):
+def detect_obj(input_image, obj_threshold, bb_threshold):
     try:
+
         image = np.array(input_image)
         image = cv2.resize(image, (H, W))
         img = image
@@ -75,12 +76,9 @@ def detect_obj(input_image):
         output = m.run(['reshape'], {"input": image})
         output = np.squeeze(output, axis=0)

-        THRESH=.25
-
-
         object_positions = np.concatenate(
-            [np.stack(np.where(output[..., 0]>=THRESH), axis=-1),
-             np.stack(np.where(output[..., 5]>=THRESH), axis=-1)], axis=0
+            [np.stack(np.where(output[..., 0]>=obj_threshold), axis=-1),
+             np.stack(np.where(output[..., 5]>=obj_threshold), axis=-1)], axis=0
         )

         selected_output = []
@@ -93,7 +91,7 @@ def detect_obj(input_image):

         for i,pos in enumerate(object_positions):
             for j in range(2):
-                if selected_output[i][j*5]>THRESH:
+                if selected_output[i][j*5]>obj_threshold:
                     output_box = np.array(output[pos[0]][pos[1]][pos[2]][(j*5)+1:(j*5)+5], dtype=float)

                     x_centre = (np.array(pos[1], dtype=float) + output_box[0])*32
@@ -116,7 +114,7 @@ def detect_obj(input_image):

         final_boxes = np.array(final_boxes)

-        nms_output = nms(final_boxes, final_scores,
+        nms_output = nms(final_boxes, final_scores, bb_threshold)

         for i in nms_output:
             cv2.rectangle(
@@ -140,17 +138,32 @@ def detect_obj(input_image):
     return input_image


+
 with gr.Blocks(title="YOLOS Object Detection - ClassCat", css=".gradio-container {background:lightyellow;}") as demo:
     gr.HTML('<h1>Yolo Object Detection</h1>')
     gr.HTML("<h4>supported objects are [aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,diningtable,dog,horse,motorbike,person,pottedplant,sheep,sofa,train,tvmonitor]</h4>")
+    gr.HTML("<br>")
     with gr.Row():
         input_image = gr.Image(label="Input image", type="pil")
         output_image = gr.Image(label="Output image", type="pil")
+    gr.HTML("<br>")
+    gr.HTML("<h4>object centre detection threshold: a detected centre is considered a new object if its value is above this threshold</h4>")
+    gr.HTML("<p>less means more objects</p>")
+    gr.HTML("<h4>bounding box threshold is the IOU threshold: if the intersection/union of two bounding boxes is greater than this value, one of the boxes will be suppressed</h4>")
+    gr.HTML("<p>more means more bounding boxes</p>")
+    gr.HTML("<br>")
+
+    obj_threshold = gr.Slider(0, 1.0, value=0.2, label='object centre detection threshold')
+    gr.HTML("<br>")
+    bb_threshold = gr.Slider(0, 1.0, value=0.3, label='bounding box draw threshold')
+    gr.HTML("<br>")

     send_btn = gr.Button("Detect")
+    gr.HTML("<br>")
     gr.Examples(['./samples/out_1.jpg'], inputs=input_image)

-    send_btn.click(fn=detect_obj, inputs=[input_image], outputs=[output_image])
+    send_btn.click(fn=detect_obj, inputs=[input_image, obj_threshold, bb_threshold], outputs=[output_image])
+


 demo.launch(debug=True)
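The hard-coded THRESH=.25 removed in the second hunk controlled which grid cells are kept as object centres; the new obj_threshold slider now feeds that same selection. The following is a minimal, self-contained sketch of the selection and centre decode, assuming a YOLO-style output of shape (1, grid_h, grid_w, 10) with two boxes per cell (objectness at channels 0 and 5, box offsets at 1-4 and 6-9) and a 32-pixel stride; those assumptions are read off the indexing and the *32 scaling in the diff, not confirmed elsewhere in the repository.

import numpy as np

rng = np.random.default_rng(0)
output = rng.random((1, 7, 7, 10))      # stand-in for the squeezed model output
obj_threshold = 0.8

# (batch, row, col) indices of every cell whose first or second box is confident enough
object_positions = np.concatenate(
    [np.stack(np.where(output[..., 0] >= obj_threshold), axis=-1),
     np.stack(np.where(output[..., 5] >= obj_threshold), axis=-1)], axis=0
)

for pos in object_positions:
    for j in range(2):                                   # two candidate boxes per cell
        score = output[pos[0], pos[1], pos[2], j * 5]
        if score > obj_threshold:
            tx, ty, tw, th = output[pos[0], pos[1], pos[2], j * 5 + 1:j * 5 + 5]
            x_centre = (float(pos[1]) + tx) * 32         # cell index + predicted offset, scaled by stride
            y_centre = (float(pos[2]) + ty) * 32         # tw, th feed the width/height decode, not shown in this hunk
            print(f"cell ({pos[1]}, {pos[2]}), box {j}: centre ~ ({x_centre:.1f}, {y_centre:.1f}), score {score:.2f}")

Lowering obj_threshold admits more cells into object_positions, which is why the page text says "less means more objects".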
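The nms function that now receives bb_threshold is not included in this diff; only its signature, nms(final_boxes, scores, IOU_threshold=0), appears in the first hunk header, and the adjacent context line shows it returning final_boxes[pick]. A conventional greedy NMS consistent with that signature could look like the sketch below; the (x1, y1, x2, y2) corner format is an assumption, and this is not necessarily the app's actual implementation.

import numpy as np

def nms(final_boxes, scores, IOU_threshold=0):
    # Greedy non-maximum suppression: repeatedly pick the highest-scoring box and
    # drop every remaining box whose IoU with it exceeds IOU_threshold.
    x1, y1, x2, y2 = final_boxes[:, 0], final_boxes[:, 1], final_boxes[:, 2], final_boxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = np.argsort(scores)[::-1]              # highest score first
    pick = []
    while order.size > 0:
        i = order[0]
        pick.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])    # intersection with the remaining boxes
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0, xx2 - xx1) * np.maximum(0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        order = order[1:][iou <= IOU_threshold]   # keep boxes that overlap less than the threshold
    return final_boxes[pick]

Because a box is only discarded when its IoU with an already-picked box exceeds the threshold, raising bb_threshold suppresses less and keeps more boxes, which matches the "more means more bounding boxes" note in the UI.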
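On the Gradio side, the slider values are passed to the callback positionally, in the order they appear in the inputs list, so [input_image, obj_threshold, bb_threshold] has to line up with the new detect_obj(input_image, obj_threshold, bb_threshold) signature. The stripped-down sketch below shows just that wiring; the callback is a placeholder that echoes the image back, not the app's detector.

import gradio as gr

def detect_obj(input_image, obj_threshold, bb_threshold):
    # placeholder: the real app resizes the image, runs the model and draws boxes
    print(f"obj_threshold={obj_threshold}, bb_threshold={bb_threshold}")
    return input_image

with gr.Blocks() as demo:
    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image", type="pil")
    obj_threshold = gr.Slider(0, 1.0, value=0.2, label="object centre detection threshold")
    bb_threshold = gr.Slider(0, 1.0, value=0.3, label="bounding box draw threshold")
    send_btn = gr.Button("Detect")
    send_btn.click(fn=detect_obj, inputs=[input_image, obj_threshold, bb_threshold],
                   outputs=[output_image])

demo.launch()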