"""Gradio demo that runs a YOLO-style ONNX detector and draws its boxes."""
import logging

import cv2
import gradio as gr
import numpy as np
import onnxruntime as rt
from PIL import Image

logger = logging.getLogger(__name__)

# Size (height, width) the image is resized to before it is fed to the model.
H, W = 224, 224

# Layout of one grid-cell prediction as consumed by the decoding code below:
# two candidate boxes of five values each (confidence followed by four box
# values), then the per-class scores starting at index 10.
BOXES_PER_CELL = 2
BOX_STRIDE = 5
CLASS_OFFSET = 10
# Factor that converts a grid coordinate to input pixels.
# NOTE(review): presumably 224 / 7 grid cells -- confirm against the model.
CELL_SIZE = 32

classes = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]

providers = ['CPUExecutionProvider']
m = rt.InferenceSession("./model/yolo_efficient.onnx", providers=providers)


def nms(final_boxes, scores, IOU_threshold=0):
    """Greedy non-maximum suppression on labelled pixel boxes.

    Boxes are visited from the highest score down; a box is dropped when its
    intersection-over-union with an already kept box is strictly greater
    than ``IOU_threshold``.

    Args:
        final_boxes: sequence of rows ``[x_min, y_min, x_max, y_max, label]``.
            Every column except the last must be convertible with ``int``.
        scores: one confidence value per row of ``final_boxes``.
        IOU_threshold: maximum IoU tolerated between two kept boxes.

    Returns:
        ``np.array(final_boxes)`` restricted to the kept rows, ordered from
        the highest score to the lowest. An empty input yields an empty
        array instead of raising.
    """
    final_boxes = np.array(final_boxes)
    scores = np.array(scores)
    if final_boxes.size == 0:
        return final_boxes

    # Mixed int/str rows become a string array, so the coordinates have to
    # be parsed back to integers before any arithmetic.
    coords = np.array(
        [[int(value) for value in row] for row in final_boxes[..., :-1]]
    )
    x1, y1, x2, y2 = coords[:, 0], coords[:, 1], coords[:, 2], coords[:, 3]
    # Inclusive pixel convention (+1) for areas, matching the intersection
    # computation below; clamped so malformed boxes cannot go negative.
    area = np.maximum(0, x2 - x1 + 1) * np.maximum(0, y2 - y1 + 1)

    order = np.argsort(scores)  # ascending: the best candidate is last
    keep = []
    while order.size > 0:
        best = order[-1]
        keep.append(best)
        rest = order[:-1]

        inter_w = np.maximum(
            0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1
        )
        inter_h = np.maximum(
            0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1
        )
        inter = inter_w * inter_h
        union = area[best] + area[rest] - inter
        # Degenerate pairs (union == 0) are treated as non-overlapping
        # rather than dividing by zero.
        iou = np.divide(
            inter,
            union,
            out=np.zeros(inter.shape, dtype=float),
            where=union > 0,
        )
        order = rest[iou <= IOU_threshold]

    return final_boxes[keep]


def _decode_boxes(output, obj_threshold):
    """Convert raw grid predictions into labelled pixel boxes and scores.

    Args:
        output: model output indexed as ``output[batch, a, b, channel]``.
            NOTE(review): assumes a 4-D array -- confirm against the model.
        obj_threshold: a candidate box is used only when its confidence is
            strictly greater than this value.

    Returns:
        ``(boxes, scores)`` where ``boxes`` is a list of
        ``[x_min, y_min, x_max, y_max, label]`` rows and ``scores`` holds the
        matching confidences. Each (cell, box slot) pair appears at most
        once.
    """
    boxes = []
    scores = []
    for slot in range(BOXES_PER_CELL):
        base = slot * BOX_STRIDE
        confident = np.argwhere(output[..., base] > obj_threshold)
        # Axis 1 is used as the x grid index and axis 2 as the y grid index.
        # NOTE(review): confirm this matches the layout the model was
        # trained with.
        for batch_idx, grid_x, grid_y in confident:
            cell = output[batch_idx, grid_x, grid_y]
            box = np.array(cell[base + 1:base + BOX_STRIDE], dtype=float)

            x_centre = (float(grid_x) + box[0]) * CELL_SIZE
            y_centre = (float(grid_y) + box[1]) * CELL_SIZE
            half_w = abs(W * box[2]) / 2
            half_h = abs(H * box[3]) / 2

            # Coordinates are truncated to ints and clamped at the top/left
            # image border only.
            x_min = max(0, int(x_centre - half_w))
            y_min = max(0, int(y_centre - half_h))
            x_max = max(0, int(x_centre + half_w))
            y_max = max(0, int(y_centre + half_h))

            label = str(classes[int(np.argmax(cell[CLASS_OFFSET:]))])
            boxes.append([x_min, y_min, x_max, y_max, label])
            scores.append(cell[base])
    return boxes, scores


def detect_obj(input_image, obj_threshold, bb_threshold):
    """Run the detector on an image and return it with the boxes drawn.

    Args:
        input_image: image supplied by the Gradio ``Image`` component
            (``type="pil"``); may be ``None`` when nothing was uploaded.
        obj_threshold: minimum box confidence for a detection to be used.
        bb_threshold: IoU threshold forwarded to :func:`nms`.

    Returns:
        A PIL RGB image of size ``W`` x ``H`` with the surviving boxes and
        their class labels drawn on it. The untouched ``input_image`` is
        returned when there is no input, when nothing is detected, or when
        any step fails (the failure is logged rather than hidden).
    """
    if input_image is None:
        return input_image

    try:
        # cv2.resize expects the target size as (width, height).
        frame = cv2.resize(np.array(input_image), (W, H))
        batch = np.expand_dims(frame.astype(np.float32), axis=0)

        # m.run returns a list with one array; drop that list dimension.
        output = np.squeeze(m.run(['reshape'], {"input": batch}), axis=0)

        boxes, scores = _decode_boxes(output, obj_threshold)
        if not boxes:
            return input_image

        for x_min, y_min, x_max, y_max, label in nms(boxes, scores, bb_threshold):
            top_left = (int(x_min), int(y_min))
            cv2.rectangle(
                frame, top_left, (int(x_max), int(y_max)), (255, 0, 0)
            )
            cv2.putText(
                frame,
                str(label),
                (top_left[0], top_left[1] + 15),
                cv2.FONT_HERSHEY_PLAIN,
                1,
                (255, 0, 0),
                1,
            )

        return Image.fromarray(np.uint8(frame)).convert('RGB')
    except Exception:
        # Best effort for the UI: keep showing the input, but leave a trace.
        logger.exception(
            "Object detection failed; returning the input image unchanged"
        )
        return input_image


# NOTE(review): the text passed to gr.HTML below is reproduced exactly as it
# appears in this file; any surrounding HTML markup seems to be missing and
# should be restored from the original source if available.
with gr.Blocks(
    title="YOLOS Object Detection - ClassCat",
    css=".gradio-container {background:lightyellow;}",
) as demo:
    gr.HTML('\n\nYolo Object Detection\n\n')
    gr.HTML(
        "\n\nsupported objects are [aeroplane,bicycle,bird,boat,bottle,bus,car,"
        "cat,chair,cow,diningtable,dog,horse,motorbike,person,pottedplant,"
        "sheep,sofa,train,tvmonitor]\n\n"
    )
    gr.HTML("\n")

    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image", type="pil")

    gr.HTML("\n")
    gr.HTML(
        "\n\nobject centre detection threshold means the object centre will "
        "be considered a new object if it's value is above threshold\n\n"
    )
    gr.HTML("\n\nless means more objects\n\n")
    gr.HTML(
        "\n\nbounding box threshold is IOU value threshold. If "
        "intersection/union area of two bounding boxes are greater than "
        "threshold value the one box will be suppressed\n\n"
    )
    gr.HTML("\n\nmore means more bounding boxes\n\n")
    gr.HTML("\n")

    obj_threshold = gr.Slider(
        0, 1.0, value=0.2, label=' object centre detection threshold'
    )
    gr.HTML("\n")
    bb_threshold = gr.Slider(
        0, 1.0, value=0.3, label=' bounding box draw threshold'
    )
    gr.HTML("\n")

    send_btn = gr.Button("Detect")
    gr.HTML("\n")

    gr.Examples(['./samples/out_1.jpg'], inputs=input_image)

    send_btn.click(
        fn=detect_obj,
        inputs=[input_image, obj_threshold, bb_threshold],
        outputs=[output_image],
    )

demo.launch(debug=True)