# Spaces: Runtime error
# Gradio YOLOv5 Det v0.1
# 创建人:曾逸夫
# 创建时间:2022-04-03
# email:zyfiy1314@163.com
# https://gitee.com/CV_Lab/gradio_yolov5_det
import argparse | |
import csv | |
import sys | |
import gradio as gr | |
import torch | |
import yaml | |
from PIL import Image | |
# Directory of the launched script, taken from the first import-path entry.
# (Currently unused in this file.)
ROOT_PATH = sys.path[0]

# torch.hub repository that model_loading() pulls YOLOv5 models from.
model_path = "ultralytics/yolov5"

# Name / device of the model most recently loaded by yolo_det(); both start
# empty and are compared against each request to decide whether to reload.
model_name_tmp, device_tmp = "", ""
def parse_args(known=False):
    """Build the command-line parser and parse ``sys.argv``.

    Args:
        known: when True, use ``parse_known_args`` and ignore unrecognised
            arguments; otherwise argparse exits on unknown arguments.

    Returns:
        argparse.Namespace with ``model_name``, ``model_cfg``, ``cls_name``,
        ``nms_conf``, ``nms_iou``, ``label_dnt_show`` and ``device``.
        ``label_dnt_show`` defaults to True; passing the flag stores False.
    """
    cli = argparse.ArgumentParser(description="Gradio YOLOv5 Det v0.1")

    # (flags, keyword options) in the order they appear in --help.
    option_specs = (
        (("--model_name", "-mn"),
         dict(default="yolov5s", type=str, help="model name")),
        (("--model_cfg", "-mc"),
         dict(default="./model_config/model_name_p5_all.yaml", type=str,
              help="model config")),
        (("--cls_name", "-cls"),
         dict(default="./cls_name/cls_name.yaml", type=str, help="cls name")),
        (("--nms_conf", "-conf"),
         dict(default=0.5, type=float,
              help="model NMS confidence threshold")),
        (("--nms_iou", "-iou"),
         dict(default=0.45, type=float, help="model NMS IoU threshold")),
        (("--label_dnt_show", "-lds"),
         dict(action="store_false", default=True, help="label show")),
        (("--device", "-dev"),
         dict(default="cpu", type=str,
              help="cuda or cpu, hugging face only cpu")),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    if known:
        namespace, _unrecognised = cli.parse_known_args()
        return namespace
    return cli.parse_args()
# Model loading
def model_loading(model_name, device):
    """Fetch and build a YOLOv5 model via ``torch.hub``.

    Loads entry point ``model_name`` from the ``model_path`` hub repository
    and forwards ``device`` to it. ``force_reload=True`` asks torch.hub to
    download the repository afresh on every call instead of reusing its local
    cache, so each call needs network access.

    Returns:
        Whatever ``torch.hub.load`` returns for that entry point.
    """
    hub_kwargs = {"force_reload": True, "device": device}
    return torch.hub.load(model_path, model_name, **hub_kwargs)
# Detection info
def export_json(results, model, img_size):
    """Convert YOLOv5 detection results into JSON-serialisable dicts.

    Args:
        results: YOLOv5 detections object. Reads ``results.xyxyn`` (one
            tensor per image; rows assumed to be
            ``[x0, y0, x1, y1, confidence, class]`` with normalised
            coordinates) and ``results.t[1]`` (presumably the per-image
            inference time in milliseconds).
        model: the loaded model; class names come from ``model.model.names``.
        img_size: ``(width, height)`` of the input image, copied into every
            detection dict.

    Returns:
        One list per image, each holding one dict per detection with keys
        ``id``, ``class``, ``class_name``, ``normalized_box`` (``x0``/``y0``/
        ``x1``/``y1`` rounded to 6 places), ``confidence`` (2 places),
        ``fps``, ``width`` and ``height``. ``fps`` is ``None`` when the
        recorded inference time is not positive.
    """
    names = model.model.names
    inference_ms = float(results.t[1])
    # Coarse timers can report 0 ms; avoid ZeroDivisionError in that case.
    fps = round(1000 / inference_ms, 2) if inference_ms > 0 else None
    width, height = img_size[0], img_size[1]

    per_image = []
    for result in results.xyxyn:
        detections = []
        # Convert each image's tensor to nested Python lists once, instead of
        # re-slicing and converting the same row four times per detection.
        for det_id, row in enumerate(result.tolist()):
            x0, y0, x1, y1, confidence, cls = row[:6]
            cls_id = int(cls)
            detections.append(
                {
                    "id": det_id,
                    "class": cls_id,
                    "class_name": names[cls_id],
                    "normalized_box": {
                        "x0": round(x0, 6),
                        "y0": round(y0, 6),
                        "x1": round(x1, 6),
                        "y1": round(y1, 6),
                    },
                    "confidence": round(confidence, 2),
                    "fps": fps,
                    "width": width,
                    "height": height,
                }
            )
        per_image.append(detections)
    return per_image
# YOLOv5 image detection callback
def yolo_det(img, device, model_name, conf, iou, label_opt, model_cls):
    """Run YOLOv5 detection on one image (Gradio ``Interface`` callback).

    The global ``model`` is reloaded only when the requested model name or
    device differs from the pair recorded in ``model_name_tmp`` /
    ``device_tmp``.

    Args:
        img: input image (the Gradio input is declared with ``type="pil"``);
            ``None`` when the user submitted no image.
        device: device string forwarded to ``model_loading``.
        model_name: model name forwarded to ``model_loading``.
        conf: NMS confidence threshold.
        iou: NMS IoU threshold.
        label_opt: whether rendered boxes are drawn with labels.
        model_cls: selected class indices, assigned to ``model.classes``.

    Returns:
        ``(det_img, det_json)``: the rendered image and the list of detection
        dicts for this image, or ``(None, None)`` if ``img`` is None.
    """
    global model, model_name_tmp, device_tmp

    if img is None:
        return None, None

    # Track name and device together so that a change to either triggers
    # exactly one reload, and both trackers are always up to date afterwards.
    if model_name != model_name_tmp or device != device_tmp:
        model_name_tmp, device_tmp = model_name, device
        model = model_loading(model_name, device)

    # ----------- model parameters -----------
    model.conf = conf  # NMS confidence threshold
    model.iou = iou  # NMS IoU threshold
    model.max_det = 1000  # maximum number of detections
    model.classes = model_cls  # classes to keep

    results = model(img)
    results.render(labels=label_opt)  # draw boxes onto the stored images
    # Recent YOLOv5 versions expose the rendered images as ``ims``; older
    # ones used ``imgs``. Support both.
    rendered = getattr(results, "ims", None)
    if rendered is None:
        rendered = results.imgs
    det_img = Image.fromarray(rendered[0])
    det_json = export_json(results, model, img.size)[0]
    return det_img, det_json
# YAML file parsing
def yaml_parse(file_path):
    """Read a UTF-8 YAML file and return the object ``yaml.safe_load`` builds.

    The file handle is closed deterministically by the ``with`` block rather
    than being left for the garbage collector.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f.read())
def main(args):
    """Load the initial model and launch the Gradio detection interface.

    Args:
        args: namespace from ``parse_args`` providing ``model_name``,
            ``model_cfg``, ``cls_name``, ``nms_conf``, ``nms_iou``,
            ``label_dnt_show`` and ``device``.
    """
    global model, model_name_tmp, device_tmp

    slider_step = 0.05  # step for the threshold sliders
    nms_conf = args.nms_conf
    nms_iou = args.nms_iou
    label_opt = args.label_dnt_show
    model_name = args.model_name
    model_cfg = args.model_cfg
    cls_name = args.cls_name
    device = args.device

    # Load the initial model and record what was loaded, so the first
    # yolo_det call does not reload (and, with force_reload, re-download) it.
    model = model_loading(model_name, device)
    model_name_tmp = model_name
    device_tmp = device

    # Model names and class names offered in the UI, read from YAML configs.
    model_names = yaml_parse(model_cfg).get("model_names")
    model_cls_name = yaml_parse(cls_name).get("model_cls_name")

    # ------------------- input components -------------------
    inputs_img = gr.inputs.Image(type="pil", label="原始图片")
    # "cpu" is always offered; a different device given on the command line
    # is added so the dropdown's default is one of its choices.
    device_choices = ["cpu"] if device == "cpu" else ["cpu", device]
    inputs_device = gr.inputs.Dropdown(
        choices=device_choices, default=device, type="value", label="设备"
    )
    inputs_model = gr.inputs.Dropdown(
        choices=model_names, default=model_name, type="value", label="模型"
    )
    input_conf = gr.inputs.Slider(
        0, 1, step=slider_step, default=nms_conf, label="置信度阈值"
    )
    inputs_iou = gr.inputs.Slider(
        0, 1, step=slider_step, default=nms_iou, label="IoU 阈值"
    )
    inputs_label = gr.inputs.Checkbox(default=label_opt, label="标签显示")
    # type="index" makes the callback receive selected class indices.
    inputs_clsName = gr.inputs.CheckboxGroup(
        choices=model_cls_name, default=model_cls_name, type="index", label="类别"
    )

    # Order must match yolo_det's positional parameters.
    inputs = [
        inputs_img,  # input image
        inputs_device,  # device
        inputs_model,  # model
        input_conf,  # confidence threshold
        inputs_iou,  # IoU threshold
        inputs_label,  # show labels
        inputs_clsName,  # classes
    ]

    # ------------------- output components -------------------
    outputs = gr.outputs.Image(type="pil", label="检测图片")
    outputs02 = gr.outputs.JSON(label="检测信息")

    title = "基于Gradio的YOLOv5通用目标检测系统"
    description = "<div align='center'>可自定义目标检测模型、安装简单、使用方便</div>"

    gr.close_all()
    gr.Interface(
        fn=yolo_det,
        inputs=inputs,
        outputs=[outputs, outputs02],
        title=title,
        description=description,
        theme="seafoam",
        flagging_dir="run",  # flagging output directory
    ).launch(
        inbrowser=True,  # open the default browser automatically
        show_tips=True,  # show Gradio tips
        favicon_path="./icon/logo.ico",
    )
if __name__ == "__main__":
    # Script entry point: parse CLI options, then build and launch the app.
    main(parse_args())