2lu commited on
Commit
9941ed8
1 Parent(s): 5b8ecd8

threshold added

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -63,8 +63,9 @@ def nms(final_boxes, scores, IOU_threshold=0):
63
 
64
  return final_boxes[pick]
65
 
66
- def detect_obj(input_image):
67
  try:
 
68
  image = np.array(input_image)
69
  image = cv2.resize(image, (H, W))
70
  img = image
@@ -75,12 +76,9 @@ def detect_obj(input_image):
75
  output = m.run(['reshape'], {"input": image})
76
  output = np.squeeze(output, axis=0)
77
 
78
- THRESH=.25
79
-
80
-
81
  object_positions = np.concatenate(
82
- [np.stack(np.where(output[..., 0]>=THRESH), axis=-1),
83
- np.stack(np.where(output[..., 5]>=THRESH), axis=-1)], axis=0
84
  )
85
 
86
  selected_output = []
@@ -93,7 +91,7 @@ def detect_obj(input_image):
93
 
94
  for i,pos in enumerate(object_positions):
95
  for j in range(2):
96
- if selected_output[i][j*5]>THRESH:
97
  output_box = np.array(output[pos[0]][pos[1]][pos[2]][(j*5)+1:(j*5)+5], dtype=float)
98
 
99
  x_centre = (np.array(pos[1], dtype=float) + output_box[0])*32
@@ -116,7 +114,7 @@ def detect_obj(input_image):
116
 
117
  final_boxes = np.array(final_boxes)
118
 
119
- nms_output = nms(final_boxes, final_scores, 0.3)
120
 
121
  for i in nms_output:
122
  cv2.rectangle(
@@ -140,17 +138,32 @@ def detect_obj(input_image):
140
  return input_image
141
 
142
 
 
143
  with gr.Blocks(title="YOLOS Object Detection - ClassCat", css=".gradio-container {background:lightyellow;}") as demo:
144
  gr.HTML('<h1>Yolo Object Detection</h1>')
145
  gr.HTML("<h4>supported objects are [aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,diningtable,dog,horse,motorbike,person,pottedplant,sheep,sofa,train,tvmonitor]</h4>")
 
146
  with gr.Row():
147
  input_image = gr.Image(label="Input image", type="pil")
148
  output_image = gr.Image(label="Output image", type="pil")
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  send_btn = gr.Button("Detect")
 
151
  gr.Examples(['./samples/out_1.jpg'], inputs=input_image)
152
 
153
- send_btn.click(fn=detect_obj, inputs=[input_image], outputs=[output_image])
 
154
 
155
 
156
  demo.launch(debug=True)
 
63
 
64
  return final_boxes[pick]
65
 
66
+ def detect_obj(input_image, obj_threshold, bb_threshold):
67
  try:
68
+
69
  image = np.array(input_image)
70
  image = cv2.resize(image, (H, W))
71
  img = image
 
76
  output = m.run(['reshape'], {"input": image})
77
  output = np.squeeze(output, axis=0)
78
 
 
 
 
79
  object_positions = np.concatenate(
80
+ [np.stack(np.where(output[..., 0]>=obj_threshold), axis=-1),
81
+ np.stack(np.where(output[..., 5]>=obj_threshold), axis=-1)], axis=0
82
  )
83
 
84
  selected_output = []
 
91
 
92
  for i,pos in enumerate(object_positions):
93
  for j in range(2):
94
+ if selected_output[i][j*5]>obj_threshold:
95
  output_box = np.array(output[pos[0]][pos[1]][pos[2]][(j*5)+1:(j*5)+5], dtype=float)
96
 
97
  x_centre = (np.array(pos[1], dtype=float) + output_box[0])*32
 
114
 
115
  final_boxes = np.array(final_boxes)
116
 
117
+ nms_output = nms(final_boxes, final_scores, bb_threshold)
118
 
119
  for i in nms_output:
120
  cv2.rectangle(
 
138
  return input_image
139
 
140
 
141
+
142
  with gr.Blocks(title="YOLOS Object Detection - ClassCat", css=".gradio-container {background:lightyellow;}") as demo:
143
  gr.HTML('<h1>Yolo Object Detection</h1>')
144
  gr.HTML("<h4>supported objects are [aeroplane,bicycle,bird,boat,bottle,bus,car,cat,chair,cow,diningtable,dog,horse,motorbike,person,pottedplant,sheep,sofa,train,tvmonitor]</h4>")
145
+ gr.HTML("<br>")
146
  with gr.Row():
147
  input_image = gr.Image(label="Input image", type="pil")
148
  output_image = gr.Image(label="Output image", type="pil")
149
+ gr.HTML("<br>")
150
+ gr.HTML("<h4>object centre detection threshold means the object centre will be considered a new object if it's value is above threshold</h4>")
151
+ gr.HTML("<p>less means more objects</p>")
152
+ gr.HTML("<h4>bounding box threshold is IOU value threshold. If intersection/union area of two bounding boxes are greater than threshold value the one box will be suppressed</h4>")
153
+ gr.HTML("<p>more means more bounding boxes<p>")
154
+ gr.HTML("<br>")
155
+
156
+ obj_threshold = gr.Slider(0, 1.0, value=0.2, label=' object centre detection threshold')
157
+ gr.HTML("<br>")
158
+ bb_threshold = gr.Slider(0, 1.0, value=0.3, label=' bounding box draw threshold')
159
+ gr.HTML("<br>")
160
 
161
  send_btn = gr.Button("Detect")
162
+ gr.HTML("<br>")
163
  gr.Examples(['./samples/out_1.jpg'], inputs=input_image)
164
 
165
+ send_btn.click(fn=detect_obj, inputs=[input_image, obj_threshold, bb_threshold], outputs=[output_image])
166
+
167
 
168
 
169
  demo.launch(debug=True)