isLinXu commited on
Commit
699ee6d
1 Parent(s): a96c370

update app

Browse files
Files changed (2) hide show
  1. app.py +92 -0
  2. requirements.txt +19 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ os.system("pip install tensorflow")
4
+ os.system("pip install modelscope")
5
+ os.system("pip install thop")
6
+ os.system("pip install easydict ")
7
+
8
+ import gradio as gr
9
+ import PIL.Image as Image
10
+ import torch
11
+ from modelscope.pipelines import pipeline
12
+ from modelscope.utils.constant import Tasks
13
+ import cv2
14
+ import numpy as np
15
+ import random
16
+
17
+ import warnings
18
+
19
+ warnings.filterwarnings("ignore")
20
+
21
+ def object_detection(img_pil, confidence_threshold, device):
22
+ # 加载模型
23
+ p = pipeline(task='image-object-detection', model='damo/cv_tinynas_object-detection_damoyolo', device=device)
24
+
25
+ # 传入图片进行推理
26
+ result = p(img_pil)
27
+ # 读取图片
28
+ img_cv = cv2.cvtColor(np.asarray(img_pil), cv2.COLOR_RGB2BGR)
29
+ # 获取bbox和类别
30
+ scores = result['scores']
31
+ boxes = result['boxes']
32
+ labels = result['labels']
33
+ # 遍历每个bbox
34
+ for i in range(len(scores)):
35
+ # 只绘制置信度大于设定阈值的bbox
36
+ if scores[i] > confidence_threshold:
37
+ # 随机生成颜色
38
+ class_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
39
+ # 获取bbox坐标
40
+ x1, y1, x2, y2 = boxes[i]
41
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
42
+ # 绘制bbox
43
+ cv2.rectangle(img_cv, (x1, y1), (x2, y2), class_color, thickness=2)
44
+ # 绘制类别标签
45
+ label = f"{labels[i]}: {scores[i]:.2f}"
46
+ cv2.putText(img_cv, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, class_color, thickness=2)
47
+ img_pil = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
48
+ return img_pil
49
+
50
+
51
+ def download_test_image():
52
+ # Images
53
+ torch.hub.download_url_to_file(
54
+ 'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
55
+ 'bus.jpg')
56
+ torch.hub.download_url_to_file(
57
+ 'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
58
+ 'dogs.jpg')
59
+ torch.hub.download_url_to_file(
60
+ 'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
61
+ 'zidane.jpg')
62
+
63
+
64
+ if __name__ == '__main__':
65
+ download_test_image()
66
+ # 定义输入和输出
67
+ input_image = gr.inputs.Image(type='pil')
68
+ input_slide = gr.inputs.Slider(minimum=0, maximum=1, step=0.05, default=0.5, label="Confidence Threshold")
69
+ input_device = gr.inputs.Radio(["cpu", "cuda", "gpu"], default="cpu")
70
+ output_image = gr.outputs.Image(type='pil')
71
+
72
+ examples = [['bus.jpg', 0.45, "cpu"],
73
+ ['dogs.jpg', 0.45, "cpu"],
74
+ ['zidane.jpg', 0.45, "cpu"]]
75
+ title = "DAMO-YOLO web demo"
76
+ description = "<div align='center'><img src='https://raw.githubusercontent.com/tinyvision/DAMO-YOLO/master/assets/logo.png' width='800''/><div>" \
77
+ "<p style='text-align: center'><a href='https://github.com/tinyvision/DAMO-YOLO'>DAMO-YOLO</a> DAMO-YOLO DAMO-YOLO DAMO-YOLO:一种快速准确的目标检测方法,采用了一些新技术,包括 NAS 主干、高效的 RepGFPN、ZeroHead、AlignedOTA 和蒸馏增强。" \
78
+ "DAMO-YOLO: a fast and accurate object detection method with some new techs, including NAS backbones, efficient RepGFPN, ZeroHead, AlignedOTA, and distillation enhancement..</p>"
79
+ article = "<p style='text-align: center'><a href='https://github.com/tinyvision/DAMO-YOLO'>DAMO-YOLO</a></p>" \
80
+ "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
81
+
82
+ # 创建 Gradio 接口并运行
83
+ gr.Interface(
84
+ fn=object_detection,
85
+ inputs=[
86
+ input_image, input_slide, input_device
87
+ ],
88
+ outputs=output_image,
89
+ title=title,
90
+ examples=examples,
91
+ description=description, article=article
92
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2
2
+ ultralytics~=8.0.169
3
+ wget~=3.2
4
+ opencv-python~=4.6.0.66
5
+ numpy~=1.23.0
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ pyyaml~=6.0
9
+ wandb~=0.13.11
10
+ tqdm~=4.65.0
11
+ matplotlib~=3.7.1
12
+ pandas~=2.0.0
13
+ seaborn~=0.12.2
14
+ requests~=2.31.0
15
+ psutil~=5.9.4
16
+ thop~=0.1.1-2209072238
17
+ timm~=0.9.2
18
+ super-gradients~=3.2.0
19
+ openmim