# damo-yolo-webui / app.py
# isLinXu
# update app
# 699ee6d
import os
import subprocess
import sys

# Runtime dependencies that the hosting environment may not preinstall.
# One pip invocation per package, so a failure on one (e.g. no tensorflow
# wheel for this platform) does not prevent the others from installing.
_RUNTIME_PACKAGES = ("tensorflow", "modelscope", "thop", "easydict")
for _package in _RUNTIME_PACKAGES:
    # `sys.executable -m pip` targets the interpreter actually running this
    # app, unlike a bare `pip` on PATH, which may belong to another Python.
    # Best-effort, like the original os.system calls: the return code is
    # ignored, so a failed install surfaces later as an ImportError.
    subprocess.call([sys.executable, "-m", "pip", "install", _package])

# These imports must come after the installs above (modelscope may only
# exist once pip has run).
import gradio as gr
import PIL.Image as Image
import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
import cv2
import numpy as np
import random
import warnings

# Silence all library warnings in the demo's console output.
warnings.filterwarnings("ignore")
# ModelScope model id of the DAMO-YOLO detector used by this demo.
_MODEL_ID = 'damo/cv_tinynas_object-detection_damoyolo'

# Constructed pipelines keyed by device string. Building a pipeline loads the
# model, so it is done once per device instead of once per request.
_PIPELINE_CACHE = {}


def _get_pipeline(device):
    """Return the detection pipeline for *device*, creating it on first use."""
    detector = _PIPELINE_CACHE.get(device)
    if detector is None:
        detector = pipeline(task='image-object-detection', model=_MODEL_ID, device=device)
        _PIPELINE_CACHE[device] = detector
    return detector


def _color_for_label(label):
    """Return a BGR color tuple derived deterministically from *label*.

    A private ``random.Random`` seeded with the label's string form gives every
    box of the same class the same color (across calls and restarts) without
    disturbing the global ``random`` state.
    """
    rng = random.Random(str(label))
    return (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))


def object_detection(img_pil, confidence_threshold, device):
    """Run DAMO-YOLO on an image and return a copy annotated with detections.

    Args:
        img_pil: PIL image to analyse, or None when nothing was uploaded.
        confidence_threshold: only detections whose score is strictly greater
            than this value are drawn.
        device: device string forwarded to the ModelScope pipeline
            (the UI offers "cpu", "cuda" and "gpu").

    Returns:
        A new RGB PIL image with one rectangle and "label: score" caption per
        kept detection, or None if *img_pil* is None.
    """
    if img_pil is None:
        return None

    # Normalize to 3-channel RGB so the RGB->BGR conversion below cannot fail
    # on RGBA / grayscale inputs.
    img_rgb = img_pil.convert('RGB')
    result = _get_pipeline(device)(img_rgb)

    # OpenCV drawing works on BGR arrays.
    img_cv = cv2.cvtColor(np.asarray(img_rgb), cv2.COLOR_RGB2BGR)

    # The pipeline result is read as parallel sequences under these keys.
    for box, score, label in zip(result['boxes'], result['scores'], result['labels']):
        if score <= confidence_threshold:
            continue
        color = _color_for_label(label)
        # Box corners, truncated to integer pixel coordinates.
        x1, y1, x2, y2 = (int(v) for v in box)
        cv2.rectangle(img_cv, (x1, y1), (x2, y2), color, thickness=2)
        # Caption sits just above the box; clamp so boxes touching the top
        # edge still get a visible caption.
        text_y = max(y1 - 10, 15)
        caption = f"{label}: {score:.2f}"
        cv2.putText(img_cv, caption, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness=2)

    return Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
# (url, local filename) pairs for the example images shown in the demo.
_TEST_IMAGES = (
    ('https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
     'bus.jpg'),
    ('https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
     'dogs.jpg'),
    ('https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
     'zidane.jpg'),
)


def download_test_image():
    """Download the demo's example images into the current working directory.

    Files that already exist are left untouched, so restarting the app does
    not re-download them and startup works offline once they are present.
    """
    for url, filename in _TEST_IMAGES:
        if os.path.exists(filename):
            continue
        torch.hub.download_url_to_file(url, filename)
if __name__ == '__main__':
    # Fetch the example images referenced by `examples` below.
    download_test_image()

    # Define inputs and outputs. Top-level components (gr.Image, gr.Slider,
    # gr.Radio) replace the gr.inputs / gr.outputs namespaces, which are
    # deprecated in Gradio 3.x and removed in Gradio 4; `value=` replaces the
    # deprecated `default=` keyword.
    input_image = gr.Image(type='pil')
    input_slide = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.5, label="Confidence Threshold")
    input_device = gr.Radio(["cpu", "cuda", "gpu"], value="cpu", label="Device")
    output_image = gr.Image(type='pil')

    # Each row is (image path, confidence threshold, device), matching the
    # order of the `inputs` list passed to gr.Interface.
    examples = [['bus.jpg', 0.45, "cpu"],
                ['dogs.jpg', 0.45, "cpu"],
                ['zidane.jpg', 0.45, "cpu"]]

    title = "DAMO-YOLO web demo"
    description = (
        "<div align='center'><img src='https://raw.githubusercontent.com/tinyvision/DAMO-YOLO/master/assets/logo.png' width='800'/></div>"
        "<p style='text-align: center'><a href='https://github.com/tinyvision/DAMO-YOLO'>DAMO-YOLO</a>:一种快速准确的目标检测方法,采用了一些新技术,包括 NAS 主干、高效的 RepGFPN、ZeroHead、AlignedOTA 和蒸馏增强。"
        "DAMO-YOLO: a fast and accurate object detection method with some new techs, including NAS backbones, efficient RepGFPN, ZeroHead, AlignedOTA, and distillation enhancement.</p>"
    )
    article = (
        "<p style='text-align: center'><a href='https://github.com/tinyvision/DAMO-YOLO'>DAMO-YOLO</a></p>"
        "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio built by gatilin</a></p>"
    )

    # Create the Gradio interface and start the server.
    gr.Interface(
        fn=object_detection,
        inputs=[input_image, input_slide, input_device],
        outputs=output_image,
        title=title,
        examples=examples,
        description=description,
        article=article,
    ).launch()