# SSD-Detection / app.py
# Last commit by Kaori1707: "update app" (b7478bf)
import gradio as gr
import numpy as np
import torch
import cv2
import os
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor
# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
# NOTE(review): in this file `device` is only printed; it is never passed to
# the network or predictor below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")
# Switch cuDNN on and enable its benchmark (algorithm autotuning) mode.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
# Locations of the trained MobileNetV1-SSD weights and of the class-label file.
default_models = {
    "ssd": "weights/mb1-ssd-bestmodel.pth",
    "label_path": "weights/labels.txt"
}


def _read_class_names(path):
    """Return the lines of the label file at *path*, each stripped of whitespace.

    The position of a name in the returned list is the label id used to look
    it up (``detection`` indexes ``colors`` with the predicted label).
    """
    # A context manager closes the handle deterministically instead of leaving
    # it open until garbage collection.
    with open(path, encoding="utf-8") as label_file:
        return [line.strip() for line in label_file]


class_names = _read_class_names(default_models["label_path"])
# The detector is built for len(class_names) classes, one per label line.
net = create_mobilenetv1_ssd(len(class_names), is_test=True)
try:
    net.load(default_models["ssd"])
    predictor = create_mobilenetv1_ssd_predictor(net, candidate_size=200)
except Exception as err:
    # Best effort: keep the module importable so the UI can still start, but
    # report what actually failed (path + error) and define `predictor` so
    # later code sees None instead of hitting a NameError.
    print(f"Failed to load SSD weights from {default_models['ssd']!r}: {err}")
    predictor = None
# One random colour triple per class; every box of that class is drawn with it.
colors = [np.random.choice(range(256), size=3) for _ in range(len(class_names))]
def detection(image, *, top_k=10, prob_threshold=0.4):
    """Detect objects in *image* and draw a coloured rectangle around each one.

    Args:
        image: The picture delivered by a Gradio ``Image`` component, or
            ``None`` when no picture is available (e.g. the button is pressed
            before anything was uploaded). Presumably an HxWx3 uint8 numpy
            array — TODO confirm against the Gradio version in use.
        top_k: Forwarded as the second positional argument of
            ``predictor.predict`` (the maximum number of detections in
            pytorch-ssd — confirm).
        prob_threshold: Forwarded as the third positional argument of
            ``predictor.predict`` (the minimum detection score — confirm).

    Returns:
        A ``(image, message)`` tuple. The rectangles are drawn on *image*
        in place and the same array is returned; *message* reports the
        number of detections, e.g. ``"Found 3 objects"``. For a ``None``
        input the result is ``(None, "No image provided")``.

    Raises:
        RuntimeError: If the model could not be loaded at start-up
            (``predictor`` is ``None``).
    """
    if image is None:
        # Nothing to run on; answer without touching the model.
        return None, "No image provided"
    if predictor is None:
        raise RuntimeError("The SSD model is not loaded; check the start-up log.")

    boxes, labels, probs = predictor.predict(image, top_k, prob_threshold)
    for box, label_id in zip(boxes, labels):
        # Truncate the corner coordinates to ints and hand cv2 plain Python ints.
        x1, y1, x2, y2 = box.numpy().astype(np.int32).tolist()
        # Every box of the same class gets the same colour.
        color = tuple(int(channel) for channel in colors[label_id])
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=4)
    # Class names and scores are not drawn on the image; only the count is reported.
    return image, f"Found {len(probs)} objects"
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Object Detection"
# One single-input example row per file in ./examples. The names are sorted so
# the example order does not depend on the filesystem (os.listdir order is
# arbitrary), and hidden files such as .DS_Store or .gitkeep are skipped so
# Gradio is not asked to open them as images.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]
# Build the Gradio interface: an image tab (button + examples) and a webcam
# streaming tab, both wired to `detection`.
with gr.Blocks() as demo:
    # Page title used by the browser; the Markdown heading is shown at the top.
    demo.title = title
    gr.Markdown(description)
    with gr.Tabs():
        # Tab 1: upload an image, press the button, get the annotated image
        # and the "Found N objects" summary.
        with gr.TabItem("for Images"):
            with gr.Row():
                with gr.Column():
                    im = gr.Image(label="Input Image")
                    im_2 = gr.Image(label="Output Image")
                with gr.Column():
                    text = gr.Textbox(label="Number of objects")
                    btn1 = gr.Button(value="Who wears mask?")
                    btn1.click(detection, inputs=[im], outputs=[im_2, text])
            # Clickable example images that fill `im`. No `fn` is passed, so
            # NOTE(review): `outputs` presumably has no effect here — confirm.
            gr.Examples(examples=example_list,
                        inputs=[im],
                        outputs=[im_2])
        # TODO: a "for Videos" tab (two "Number of objects" textboxes in
        # separate columns) was sketched here but never implemented.
        # Tab 2: webcam input; `detection` re-runs whenever the frame changes.
        with gr.Tab("for streaming"):
            with gr.Row():
                input_video = gr.Image(source="webcam", streaming=True)
                with gr.Column():
                    output_video = gr.Image(label="Video")
                    text1 = gr.Textbox(label="Number of objects")
            # show_progress=False hides the progress indicator on each update,
            # presumably so it does not flash on every streamed frame.
            input_video.change(detection, inputs = [input_video], outputs=[output_video, text1], show_progress=False)
def main():
    """Start the Gradio server that serves the demo."""
    demo.launch()


if __name__ == "__main__":
    main()