import gradio as gr
import onnxruntime as rt
import cv2
import numpy as np
from PIL import Image

# Model input resolution and the 20 Pascal VOC class labels.
H, W = 224, 224
classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
           'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

# CPU-only ONNX Runtime session for the exported detector.
providers = ['CPUExecutionProvider']
m = rt.InferenceSession("./model/yolo_efficient.onnx", providers=providers)
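
# Optional sanity check (assumption: the graph exposes an input named "input" and an
# output named "reshape", as used in detect_obj below); uncomment to confirm against
# the names/shapes the session actually reports.
# for node in m.get_inputs() + m.get_outputs():
#     print(node.name, node.shape)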

def nms(final_boxes, scores, IOU_threshold=0):
    """Greedy non-maximum suppression over [x_min, y_min, x_max, y_max, label] boxes."""
    scores = np.array(scores)
    final_boxes = np.array(final_boxes)
    # Drop the label column and cast the coordinates back to int.
    boxes = np.array([list(map(int, b)) for b in final_boxes[..., :-1]])
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = (x2 - x1) * (y2 - y1)
    # Process boxes from lowest to highest score; the last index is the current best.
    order = np.argsort(scores)
    pick = []
    while len(order) > 0:
        last = len(order) - 1
        i = order[last]
        pick.append(i)
        suppress = [last]
        for pos in range(last):
            j = order[pos]
            # Intersection of the current best box i with the lower-scoring box j.
            xx1 = max(x1[i], x1[j])
            yy1 = max(y1[i], y1[j])
            xx2 = min(x2[i], x2[j])
            yy2 = min(y2[i], y2[j])
            w = max(0, xx2 - xx1 + 1)
            h = max(0, yy2 - yy1 + 1)
            # Overlap measured relative to the area of the lower-scoring box.
            overlap = float(w * h) / area[j]
            if overlap > IOU_threshold:
                suppress.append(pos)
        order = np.delete(order, suppress)
    return final_boxes[pick]
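
# Illustrative call (made-up values, not model output), showing the expected input layout:
#   boxes  = [[10, 10, 100, 100, 'dog'], [12, 12, 98, 99, 'dog'], [150, 40, 200, 90, 'cat']]
#   scores = [0.9, 0.6, 0.8]
#   nms(boxes, scores, IOU_threshold=0.3)  # keeps the 0.9 'dog' box and the 'cat' box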

def detect_obj(input_image):
    try:
        # Resize to the model resolution and keep a copy for drawing.
        image = np.array(input_image)
        image = cv2.resize(image, (W, H))  # cv2.resize expects (width, height)
        img = image
        image = image.astype(np.float32)
        image = np.expand_dims(image, axis=0)
        print(image.shape)

        # Run the model; the "reshape" output is a grid whose last axis holds two
        # anchors of [objectness, x, y, w, h] followed by the 20 class scores.
        output = m.run(['reshape'], {"input": image})
        output = np.squeeze(output, axis=0)
        print(output.shape)

        # Grid cells where either anchor's objectness exceeds the threshold.
        THRESH = .25
        object_positions = np.concatenate(
            [np.stack(np.where(output[..., 0] >= THRESH), axis=-1),
             np.stack(np.where(output[..., 5] >= THRESH), axis=-1)], axis=0
        )
        selected_output = np.array(
            [output[idx[0]][idx[1]][idx[2]] for idx in object_positions]
        )

        # Decode each anchor's box from cell-relative offsets to pixel coordinates
        # (stride 32 on a 224x224 input), clamping negative coordinates to zero.
        final_boxes = []
        final_scores = []
        for i, pos in enumerate(object_positions):
            for j in range(2):
                if selected_output[i][j * 5] > THRESH:
                    output_box = np.array(output[pos[0]][pos[1]][pos[2]][(j * 5) + 1:(j * 5) + 5], dtype=float)
                    x_centre = (np.array(pos[1], dtype=float) + output_box[0]) * 32
                    y_centre = (np.array(pos[2], dtype=float) + output_box[1]) * 32
                    x_width, y_height = abs(W * output_box[2]), abs(H * output_box[3])
                    x_min, y_min = int(x_centre - (x_width / 2)), int(y_centre - (y_height / 2))
                    x_max, y_max = int(x_centre + (x_width / 2)), int(y_centre + (y_height / 2))
                    x_min, y_min = max(x_min, 0), max(y_min, 0)
                    x_max, y_max = max(x_max, 0), max(y_max, 0)
                    final_boxes.append(
                        [x_min, y_min, x_max, y_max,
                         str(classes[np.argmax(selected_output[..., 10:], axis=-1)[i]])]
                    )
                    final_scores.append(selected_output[i][j * 5])

        # Suppress overlapping detections, then draw the survivors on the copy.
        final_boxes = np.array(final_boxes)
        nms_output = nms(final_boxes, final_scores, 0.3)
        print(nms_output)
        for box in nms_output:
            cv2.rectangle(
                img,
                (int(box[0]), int(box[1])),
                (int(box[2]), int(box[3])), (255, 0, 0)
            )
            cv2.putText(
                img,
                box[-1],
                (int(box[0]), int(box[1]) + 15),
                cv2.FONT_HERSHEY_PLAIN, 1, (255, 0, 0), 1
            )
        output_pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
        return output_pil_img
    except Exception:
        # Fall back to the unmodified input if inference or decoding fails.
        return input_image
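
# The visible portion of the file ends here; gradio is imported above but no UI is defined.
# A minimal sketch of the assumed wiring, a single-image demo around detect_obj (the title
# string is illustrative, not from the original upload):
demo = gr.Interface(
    fn=detect_obj,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="YOLO-style object detection (ONNX)",
)
demo.launch()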