"""Gradio demo: run a MobileNetV1-SSD detector on an image and draw the boxes."""

import gradio as gr
import numpy as np
import torch
import cv2
import os
from random import randint

from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd, create_mobilenetv1_ssd_predictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

default_models = {
    "ssd": "weights/mb1-ssd-bestmodel.pth",
    "label_path": "weights/labels.txt"
}

# One class name per line of the labels file; the line count sizes the network
# head, so blank lines are kept (as empty names) rather than filtered out.
with open(default_models["label_path"]) as label_file:
    class_names = [name.strip() for name in label_file]

net = create_mobilenetv1_ssd(len(class_names), is_test=True)

# `predictor` stays None when the weights cannot be loaded, so the UI can still
# start and `detection` can report the problem instead of hitting a NameError.
predictor = None
try:
    net.load(default_models["ssd"])
    predictor = create_mobilenetv1_ssd_predictor(net, candidate_size=200)
except Exception as exc:
    print(f"Failed to load the SSD model from {default_models['ssd']}: {exc!r}")

# One random colour per class, used for that class's bounding boxes.
colors = [np.random.choice(range(256), size=3) for _ in class_names]


def detection(image):
    """Detect objects in ``image`` and draw their bounding boxes on it.

    Args:
        image: The array handed over by the ``gr.Image`` input, or ``None``
            when the button is clicked before an image is provided.

    Returns:
        A ``(image, summary)`` tuple: the same image object with rectangles
        drawn in place, and a text line with the number of detections.
        ``(None, message)`` is returned when no image was supplied.

    Raises:
        RuntimeError: If the model could not be loaded at start-up.
    """
    if predictor is None:
        raise RuntimeError(
            f"Detection model is not loaded; check {default_models['ssd']}."
        )
    if image is None:
        return None, "No input image provided."

    boxes, labels, probs = predictor.predict(image, 10, 0.4)
    for box, label in zip(boxes, labels):
        # int() truncates toward zero, matching the former integer-dtype cast
        # (which relied on the `np.int` alias that NumPy has since removed).
        x1, y1, x2, y2 = (int(value) for value in box.tolist())
        # OpenCV needs plain Python ints for the colour, not NumPy scalars.
        color = tuple(int(channel) for channel in colors[int(label)])
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness=4)

    return image, f"Found {len(probs)} objects"


title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Object Detection"
# os.listdir returns entries in arbitrary order; sort for a stable gallery.
example_list = [[f"examples/{example}"] for example in sorted(os.listdir("examples"))]

with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            im = gr.Image(label="Input Image")
            im_2 = gr.Image(label="Output Image")

        with gr.Column():
            text = gr.Textbox(label="Number of objects")
            btn1 = gr.Button(value="Who wears mask?")
            btn1.click(detection, inputs=[im], outputs=[im_2, text])

    gr.Examples(examples=example_list, inputs=[im], outputs=[im_2])

if __name__ == "__main__":
    demo.launch()