Spaces:
Runtime error
Runtime error
update
Browse files- .gitattributes +1 -0
- README.md +75 -13
- app.py +145 -78
- main.py +71 -52
- requirements.txt +1 -0
- webui.png +0 -0
.gitattributes
CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
36 |
deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
|
37 |
demo.png filter=lfs diff=lfs merge=lfs -text
|
38 |
test.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
36 |
deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
|
37 |
demo.png filter=lfs diff=lfs merge=lfs -text
|
38 |
test.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
webui.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,75 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<h1> yolov8-deepsort-tracking </h1>
|
3 |
+
|
4 |
+
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/KdaiP/yolov8-deepsort-tracking)
|
5 |
+
</div>
|
6 |
+
|
7 |
+
![示例图片](./demo.png)
|
8 |
+
|
9 |
+
opencv+yolov8+deepsort的行人检测与跟踪。当然,也可以识别车辆等其他类别。
|
10 |
+
|
11 |
+
# 更新历史
|
12 |
+
|
13 |
+
2024/2/11更新:清理代码,完善注释。WebUI新增识别目标选择、进度条显示、终止推理、示例等功能。
|
14 |
+
|
15 |
+
2023/10/17更新:简化代码,删除不必要的依赖。解决webui上传视频不会清空tracker ID的问题。
|
16 |
+
|
17 |
+
2023/7/4更新:加入了一个基于Gradio的WebUI界面
|
18 |
+
|
19 |
+
## 安装
|
20 |
+
环境:Python>=3.8
|
21 |
+
|
22 |
+
本项目需要pytorch,建议手动在[pytorch官网](https://pytorch.org/get-started/locally/)根据自己的平台和CUDA环境安装对应的版本。
|
23 |
+
|
24 |
+
pytorch的详细安装教程可以参照[Conda Quickstart Guide for Ultralytics](https://docs.ultralytics.com/guides/conda-quickstart/)
|
25 |
+
|
26 |
+
安装完pytorch后,需要通过以下命令来安装其他依赖:
|
27 |
+
|
28 |
+
```shell
|
29 |
+
$ pip install -r requirements.txt
|
30 |
+
```
|
31 |
+
|
32 |
+
如果需要使用GUI,需要通过以下命令安装tqdm进度条和Gradio库:
|
33 |
+
|
34 |
+
```shell
|
35 |
+
$ pip install tqdm gradio
|
36 |
+
```
|
37 |
+
|
38 |
+
|
39 |
+
## 配置(非WebUI)
|
40 |
+
|
41 |
+
在main.py中修改以下代码,将输入视频路径换成你要处理的视频的路径:
|
42 |
+
|
43 |
+
```python
|
44 |
+
input_path = "test.mp4"
|
45 |
+
```
|
46 |
+
|
47 |
+
模型默认使用Ultralytics官方的YOLOv8n模型:
|
48 |
+
|
49 |
+
```python
|
50 |
+
model = YOLO("yolov8n.pt")
|
51 |
+
```
|
52 |
+
|
53 |
+
第一次使用会自动从官网下载模型,如果网速过慢,可以在[ultralytics的官方文档](https://docs.ultralytics.com/tasks/detect/)下载模型,然后将模型文件拷贝到程序所在目录下。
|
54 |
+
|
55 |
+
## 运行(非WebUI)
|
56 |
+
|
57 |
+
运行main.py
|
58 |
+
|
59 |
+
推理完成后,终端会显示输出视频所在的路径。
|
60 |
+
|
61 |
+
## WebUI界面的配置和运行
|
62 |
+
|
63 |
+
demo: [Huggingface demo](https://huggingface.co/spaces/KdaiP/yolov8-deepsort-tracking)
|
64 |
+
|
65 |
+
|
66 |
+
运行app.py,如果控制台出现以下消息代表成功运行:
|
67 |
+
```shell
|
68 |
+
Running on local URL: http://127.0.0.1:6006
|
69 |
+
To create a public link, set `share=True` in `launch()`
|
70 |
+
```
|
71 |
+
|
72 |
+
浏览器打开该URL即可使用WebUI界面
|
73 |
+
|
74 |
+
![WebUI](./webui.png)
|
75 |
+
|
app.py
CHANGED
@@ -3,13 +3,47 @@ import cv2
|
|
3 |
import numpy as np
|
4 |
import tempfile
|
5 |
from pathlib import Path
|
|
|
|
|
6 |
import deep_sort.deep_sort.deep_sort as ds
|
7 |
|
8 |
import gradio as gr
|
9 |
|
10 |
-
#
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
def putTextWithBackground(
|
14 |
img,
|
15 |
text,
|
@@ -53,113 +87,146 @@ def putTextWithBackground(
|
|
53 |
)
|
54 |
|
55 |
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
"""
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
|
71 |
-
size = (
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
output_video = cv2.VideoWriter() # 初始化视频写入
|
76 |
-
outputPath = tempfile.mkdtemp() # 创建输出视频的临时文件夹的路径
|
77 |
-
|
78 |
-
# 输出格式为XVID格式的avi文件
|
79 |
# 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
|
80 |
# 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
|
81 |
# 下载完成后将dll文件放在当前文件夹内
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
video_save_path = Path(outputPath) / "output.avi" # 创建输出视频路径
|
86 |
-
if output_type == "mp4": # 浏览器只支持播放h264编码的mp4视频文件
|
87 |
-
fourcc = cv2.VideoWriter_fourcc(*"h264")
|
88 |
-
video_save_path = Path(outputPath) / "output.mp4"
|
89 |
-
|
90 |
-
output_video.open(video_save_path.as_posix(), fourcc, fps, size, True)
|
91 |
# 对每一帧图片进行读取和处理
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
if not (success):
|
95 |
break
|
96 |
|
97 |
-
#
|
98 |
results = model(frame, stream=True)
|
99 |
|
100 |
-
|
101 |
-
confarray =
|
102 |
-
|
103 |
-
# 读取目标检测推理结果
|
104 |
-
# 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
|
105 |
-
for r in results:
|
106 |
-
boxes = r.boxes
|
107 |
-
for box in boxes:
|
108 |
-
x1, y1, x2, y2 = map(int, box.xywh[0]) # 提取矩形框左上和右下的点,并将tensor类型转为整型
|
109 |
-
conf = round(float(box.conf[0]), 2) # 对conf四舍五入到2位小数
|
110 |
-
cls = int(box.cls[0]) # 获取物体类别标签
|
111 |
|
112 |
-
|
113 |
-
detections = np.vstack((detections,np.array([x1,y1,x2,y2])))
|
114 |
-
confarray.append(conf)
|
115 |
-
|
116 |
-
# 使用deepsort进行跟踪
|
117 |
resultsTracker = tracker.update(detections, confarray, frame)
|
|
|
118 |
for x1, y1, x2, y2, Id in resultsTracker:
|
119 |
-
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
|
120 |
|
121 |
-
# 绘制bounding box
|
122 |
cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
|
123 |
-
putTextWithBackground(
|
124 |
-
frame,
|
125 |
-
str(int(Id)),
|
126 |
-
(max(-10, x1), max(40, y1)),
|
127 |
-
font_scale=1.5,
|
128 |
-
text_color=(255, 255, 255),
|
129 |
-
bg_color=(255, 0, 255),
|
130 |
-
)
|
131 |
|
132 |
-
output_video.write(frame) #
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
137 |
|
138 |
|
139 |
if __name__ == "__main__":
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
# Gradio参考文档:https://www.gradio.app/guides/blocks-and-event-listeners
|
144 |
with gr.Blocks() as demo:
|
145 |
with gr.Tab("Tracking"):
|
|
|
146 |
gr.Markdown(
|
147 |
"""
|
148 |
-
#
|
149 |
基于opencv + YoloV8 + deepsort
|
150 |
"""
|
151 |
)
|
|
|
152 |
with gr.Row():
|
|
|
153 |
with gr.Column():
|
154 |
-
|
155 |
-
model = gr.Dropdown(model_list, value=
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
with gr.Column():
|
157 |
-
output = gr.Video()
|
158 |
-
output_path = gr.Textbox(label="Output path")
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
demo.launch()
|
|
|
3 |
import numpy as np
|
4 |
import tempfile
|
5 |
from pathlib import Path
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
|
8 |
import deep_sort.deep_sort.deep_sort as ds
|
9 |
|
10 |
import gradio as gr
|
11 |
|
12 |
+
# 控制处理流程是否终止
|
13 |
+
should_continue = True
|
14 |
+
|
15 |
+
def get_detectable_classes(model_file):
|
16 |
+
"""获取给定模型文件可以检测的类别。
|
17 |
+
|
18 |
+
参数:
|
19 |
+
- model_file: 模型文件名。
|
20 |
|
21 |
+
返回:
|
22 |
+
- class_names: 可检测的类别名称。
|
23 |
+
"""
|
24 |
+
model = YOLO(model_file)
|
25 |
+
class_names = list(model.names.values()) # 直接获取类别名称列表
|
26 |
+
del model # 删除模型实例释放资源
|
27 |
+
return class_names
|
28 |
+
|
29 |
+
# 用于终止视频处理
|
30 |
+
def stop_processing():
|
31 |
+
global should_continue
|
32 |
+
should_continue = False # 更改变量来停止处理
|
33 |
+
return "尝试终止处理..."
|
34 |
+
|
35 |
+
# 用于开始视频处理
|
36 |
+
# gr.Progress(track_tqdm=True)用于捕获tqdm进度条,从而在GUI上显示进度
|
37 |
+
def start_processing(input_path, output_path, detect_class, model, progress=gr.Progress(track_tqdm=True)):
|
38 |
+
global should_continue
|
39 |
+
should_continue = True
|
40 |
+
|
41 |
+
detect_class = int(detect_class)
|
42 |
+
model = YOLO(model)
|
43 |
+
tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
|
44 |
+
output_video_path = detect_and_track(input_path, output_path, detect_class, model, tracker)
|
45 |
+
return output_video_path, output_video_path
|
46 |
+
|
47 |
def putTextWithBackground(
|
48 |
img,
|
49 |
text,
|
|
|
87 |
)
|
88 |
|
89 |
|
90 |
+
def extract_detections(results, detect_class):
|
91 |
+
"""
|
92 |
+
从模型结果中提取和处理检测信息。
|
93 |
+
- results: YoloV8模型预测结果,包含检测到的物体的位置、类别和置信度等信息。
|
94 |
+
- detect_class: 需要提取的目标类别的索引。
|
95 |
+
参考: https://docs.ultralytics.com/modes/predict/#working-with-results
|
96 |
"""
|
97 |
+
|
98 |
+
# 初始化一个空的二维numpy数组,用于存放检测到的目标的位置信息
|
99 |
+
# 如果视频中没有需要提取的目标类别,如果不初始化,会导致tracker报错
|
100 |
+
detections = np.empty((0, 4))
|
101 |
+
|
102 |
+
confarray = [] # 初始化一个空列表,用于存放检测到的目标的置信度。
|
103 |
+
|
104 |
+
# 遍历检测结果
|
105 |
+
# 参考:https://docs.ultralytics.com/modes/predict/#working-with-results
|
106 |
+
for r in results:
|
107 |
+
for box in r.boxes:
|
108 |
+
# 如果检测到的目标类别与指定的目标类别相匹配,提取目标的位置信息和置信度
|
109 |
+
if box.cls[0].int() == detect_class:
|
110 |
+
x1, y1, x2, y2 = box.xywh[0].int().tolist() # 提取目标的位置信息,并从tensor转换为整数列表。
|
111 |
+
conf = round(box.conf[0].item(), 2) # 提取目标的置信度,从tensor中取出浮点数结果,并四舍五入到小数点后两位。
|
112 |
+
detections = np.vstack((detections, np.array([x1, y1, x2, y2]))) # 将目标的位置信息添加到detections数组中。
|
113 |
+
confarray.append(conf) # 将目标的置信度添加到confarray列表中。
|
114 |
+
return detections, confarray # 返回提取出的位置信息和置信度。
|
115 |
|
116 |
+
# 视频处理
|
117 |
+
def detect_and_track(input_path: str, output_path: str, detect_class: int, model, tracker) -> Path:
|
118 |
+
"""
|
119 |
+
处理视频,检测并跟踪目标。
|
120 |
+
- input_path: 输入视频文件的路径。
|
121 |
+
- output_path: 处理后视频保存的路径。
|
122 |
+
- detect_class: 需要检测和跟踪的目标类别的索引。
|
123 |
+
- model: 用于目标检测的模型。
|
124 |
+
- tracker: 用于目标跟踪的模型。
|
125 |
+
"""
|
126 |
+
global should_continue
|
127 |
+
cap = cv2.VideoCapture(input_path) # 使用OpenCV打开视频文件。
|
128 |
+
if not cap.isOpened(): # 检查视频文件是否成功打开。
|
129 |
+
print(f"Error opening video file {input_path}")
|
130 |
+
return None
|
131 |
+
|
132 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取视频总帧数
|
133 |
fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
|
134 |
+
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) # 获取视频的分辨率(宽度和高度)。
|
135 |
+
output_video_path = Path(output_path) / "output.avi" # 设置输出视频的保存路径。
|
136 |
+
|
137 |
+
# 设置视频编码格式为XVID格式的avi文件
|
|
|
|
|
|
|
|
|
138 |
# 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
|
139 |
# 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
|
140 |
# 下载完成后将dll文件放在当前文件夹内
|
141 |
+
fourcc = cv2.VideoWriter_fourcc(*"XVID")
|
142 |
+
output_video = cv2.VideoWriter(output_video_path.as_posix(), fourcc, fps, size, isColor=True) # 创建一个VideoWriter对象用于写视频。
|
143 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
# 对每一帧图片进行读取和处理
|
145 |
+
# 使用tqdm显示处理进度。
|
146 |
+
for _ in tqdm(range(total_frames)):
|
147 |
+
# 如果全局变量should_continue为False(通常由于GUI上按下Stop按钮),则终止目标检测和跟踪,返回已处理的视频部分
|
148 |
+
if not should_continue:
|
149 |
+
print('stopping process')
|
150 |
+
break
|
151 |
+
|
152 |
+
success, frame = cap.read() # 逐帧读取视频。
|
153 |
+
|
154 |
+
# 如果读取失败(或者视频已处理完毕),则跳出循环。
|
155 |
if not (success):
|
156 |
break
|
157 |
|
158 |
+
# 使用YoloV8模型对当前帧进行目标检测。
|
159 |
results = model(frame, stream=True)
|
160 |
|
161 |
+
# 从预测结果中提取检测信息。
|
162 |
+
detections, confarray = extract_detections(results, detect_class)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
+
# 使用deepsort模型对检测到的目标进行跟踪。
|
|
|
|
|
|
|
|
|
165 |
resultsTracker = tracker.update(detections, confarray, frame)
|
166 |
+
|
167 |
for x1, y1, x2, y2, Id in resultsTracker:
|
168 |
+
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2]) # 将位置信息转换为整数。
|
169 |
|
170 |
+
# 绘制bounding box和文本
|
171 |
cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
|
172 |
+
putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
+
output_video.write(frame) # 将处理后的帧写入到输出视频文件中。
|
175 |
+
|
176 |
+
output_video.release() # 释放VideoWriter对象。
|
177 |
+
cap.release() # 释放视频文件。
|
178 |
+
|
179 |
+
print(f'output dir is: {output_video_path}')
|
180 |
+
return output_video_path
|
181 |
|
182 |
|
183 |
if __name__ == "__main__":
|
184 |
+
|
185 |
+
# YoloV8官方模型列表,从左往右由小到大,第一次使用会自动下载
|
186 |
+
model_list = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt", "yolov8l.pt", "yolov8x.pt"]
|
187 |
+
|
188 |
+
# 获取YoloV8模型可以检测的所有类别,默认调用model_list中第一个模型
|
189 |
+
detect_classes = get_detectable_classes(model_list[0])
|
190 |
+
|
191 |
+
# gradio界面的输入示例,包含一个测试视频文件路径、一个随机生成的输出目录、检测的类别、使用的模型
|
192 |
+
examples = [["test.mp4", tempfile.mkdtemp(), detect_classes[0], model_list[0]],]
|
193 |
+
|
194 |
+
# 使用Gradio的Blocks创建一个GUI界面
|
195 |
# Gradio参考文档:https://www.gradio.app/guides/blocks-and-event-listeners
|
196 |
with gr.Blocks() as demo:
|
197 |
with gr.Tab("Tracking"):
|
198 |
+
# 使用Markdown显示文本信息,介绍界面的功能
|
199 |
gr.Markdown(
|
200 |
"""
|
201 |
+
# 目标检测与跟踪
|
202 |
基于opencv + YoloV8 + deepsort
|
203 |
"""
|
204 |
)
|
205 |
+
# 行容器,水平排列元素
|
206 |
with gr.Row():
|
207 |
+
# 列容器,垂直排列元素
|
208 |
with gr.Column():
|
209 |
+
input_path = gr.Video(label="Input video") # 视频输入控件,用于上传视频文件
|
210 |
+
model = gr.Dropdown(model_list, value=0, label="Model") # 下拉菜单控件,用于选择模型
|
211 |
+
detect_class = gr.Dropdown(detect_classes, value=0, label="Class", type='index') # 下拉菜单控件,用于选择要检测的目标类别
|
212 |
+
output_dir = gr.Textbox(label="Output dir", value=tempfile.mkdtemp()) # 文本框控件,用于指定输出视频的保存路径,默认为一个临时生成的目录
|
213 |
+
with gr.Row():
|
214 |
+
# 创建两个按钮控件,分别用于开始处理和停止处理
|
215 |
+
start_button = gr.Button("Process")
|
216 |
+
stop_button = gr.Button("Stop")
|
217 |
with gr.Column():
|
218 |
+
output = gr.Video() # 视频显示控件,展示处理后的输出视频
|
219 |
+
output_path = gr.Textbox(label="Output path") # 文本框控件,用于显示输出视频的文件路径
|
220 |
+
|
221 |
+
# 添加示例到GUI中,允许用户选择预定义的输入进行快速测试
|
222 |
+
gr.Examples(examples,label="Examples",
|
223 |
+
inputs=[input_path, output_dir, detect_class, model],
|
224 |
+
outputs=[output, output_path],
|
225 |
+
fn=start_processing, # 指定处理示例时调用的函数
|
226 |
+
cache_examples=False) # 禁用示例缓存
|
227 |
+
|
228 |
+
# 将按钮与处理函数绑定
|
229 |
+
start_button.click(start_processing, inputs=[input_path, output_dir, detect_class, model], outputs=[output, output_path])
|
230 |
+
stop_button.click(stop_processing)
|
231 |
|
232 |
demo.launch()
|
main.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
from ultralytics import YOLO
|
2 |
-
import cv2
|
3 |
-
import numpy as np
|
4 |
import tempfile
|
5 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
6 |
import deep_sort.deep_sort.deep_sort as ds
|
7 |
|
8 |
def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, text_color=(255, 255, 255), bg_color=(0, 0, 0), thickness=1):
|
@@ -28,91 +29,109 @@ def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font
|
|
28 |
# 在矩形上绘制文本
|
29 |
text_origin = (origin[0], origin[1] - 5) # 从左上角的位置减去5来留出一些边距
|
30 |
cv2.putText(img, text, text_origin, font, font_scale, text_color, thickness, lineType=cv2.LINE_AA)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# 视频处理
|
33 |
-
def
|
34 |
-
"""
|
35 |
-
|
36 |
-
|
37 |
-
:
|
|
|
|
|
|
|
38 |
"""
|
39 |
-
#
|
40 |
-
|
|
|
|
|
|
|
41 |
fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
|
42 |
-
size = (
|
43 |
-
|
44 |
-
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
|
45 |
-
) # 获取视频的大小
|
46 |
-
output_video = cv2.VideoWriter() # 初始化视频写入
|
47 |
|
48 |
-
#
|
49 |
# 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
|
50 |
# 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
|
51 |
# 下载完成后将dll文件放在当前文件夹内
|
52 |
fourcc = cv2.VideoWriter_fourcc(*"XVID")
|
53 |
-
|
54 |
-
|
55 |
-
output_video.open(video_save_path.as_posix(), fourcc, fps, size, isColor=True)
|
56 |
|
57 |
# 对每一帧图片进行读取和处理
|
58 |
while True:
|
59 |
-
success, frame = cap.read()
|
|
|
|
|
60 |
if not (success):
|
61 |
break
|
62 |
|
63 |
-
#
|
64 |
results = model(frame, stream=True)
|
65 |
|
66 |
-
|
67 |
-
confarray =
|
68 |
|
69 |
-
#
|
70 |
-
# 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
|
71 |
-
for r in results:
|
72 |
-
boxes = r.boxes
|
73 |
-
for box in boxes:
|
74 |
-
x1, y1, x2, y2 = map(int, box.xywh[0]) # 提取矩形框左上和右下的点,并将tensor类型转为整型
|
75 |
-
conf = round(float(box.conf[0]), 2) # 对conf四舍五入到2位小数
|
76 |
-
cls = int(box.cls[0]) # 获取物体类别标签
|
77 |
-
|
78 |
-
if cls == detect_class:
|
79 |
-
detections = np.vstack((detections,np.array([x1,y1,x2,y2])))
|
80 |
-
confarray.append(conf)
|
81 |
-
|
82 |
-
# 使用deepsort进行跟踪
|
83 |
resultsTracker = tracker.update(detections, confarray, frame)
|
|
|
84 |
for x1, y1, x2, y2, Id in resultsTracker:
|
85 |
-
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
|
86 |
|
87 |
-
# 绘制bounding box
|
88 |
cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
|
89 |
putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
|
90 |
|
91 |
-
output_video.write(frame) #
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
97 |
|
98 |
if __name__ == "__main__":
|
99 |
-
#
|
100 |
######
|
101 |
-
|
102 |
######
|
103 |
|
104 |
# 输出文件夹,默认为系统的临时文件夹路径
|
105 |
-
|
106 |
|
107 |
# 加载yoloV8模型权重
|
108 |
model = YOLO("yolov8n.pt")
|
109 |
|
110 |
-
#
|
111 |
# yoloV8官方模型的第一个类别为'person'
|
112 |
detect_class = 0
|
113 |
-
print(f"detecting {model.names[detect_class]}")
|
114 |
|
115 |
-
# 加载
|
116 |
tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
|
117 |
|
118 |
-
|
|
|
|
|
|
|
|
|
1 |
import tempfile
|
2 |
from pathlib import Path
|
3 |
+
import numpy as np
|
4 |
+
import cv2 # opencv-python
|
5 |
+
from ultralytics import YOLO
|
6 |
+
|
7 |
import deep_sort.deep_sort.deep_sort as ds
|
8 |
|
9 |
def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, text_color=(255, 255, 255), bg_color=(0, 0, 0), thickness=1):
|
|
|
29 |
# 在矩形上绘制文本
|
30 |
text_origin = (origin[0], origin[1] - 5) # 从左上角的位置减去5来留出一些边距
|
31 |
cv2.putText(img, text, text_origin, font, font_scale, text_color, thickness, lineType=cv2.LINE_AA)
|
32 |
+
|
33 |
+
def extract_detections(results, detect_class):
|
34 |
+
"""
|
35 |
+
从模型结果中提取和处理检测信息。
|
36 |
+
- results: YoloV8模型预测结果,包含检测到的物体的位置、类别和置信度等信息。
|
37 |
+
- detect_class: 需要提取的目标类别的索引。
|
38 |
+
参考: https://docs.ultralytics.com/modes/predict/#working-with-results
|
39 |
+
"""
|
40 |
+
|
41 |
+
# 初始化一个空的二维numpy数组,用于存放检测到的目标的位置信息
|
42 |
+
# 如果视频中没有需要提取的目标类别,如果不初始化,会导致tracker报错
|
43 |
+
detections = np.empty((0, 4))
|
44 |
+
|
45 |
+
confarray = [] # 初始化一个空列表,用于存放检测到的目标的置信度。
|
46 |
+
|
47 |
+
# 遍历检测结果
|
48 |
+
# 参考:https://docs.ultralytics.com/modes/predict/#working-with-results
|
49 |
+
for r in results:
|
50 |
+
for box in r.boxes:
|
51 |
+
# 如果检测到的目标类别与指定的目标类别相匹配,提取目标的位置信息和置信度
|
52 |
+
if box.cls[0].int() == detect_class:
|
53 |
+
x1, y1, x2, y2 = box.xywh[0].int().tolist() # 提取目标的位置信息,并从tensor转换为整数列表。
|
54 |
+
conf = round(box.conf[0].item(), 2) # 提取目标的置信度,从tensor中取出浮点数结果,并四舍五入到小数点后两位。
|
55 |
+
detections = np.vstack((detections, np.array([x1, y1, x2, y2]))) # 将目标的位置信息添加到detections数组中。
|
56 |
+
confarray.append(conf) # 将目标的置信度添加到confarray列表中。
|
57 |
+
return detections, confarray # 返回提取出的位置信息和置信度。
|
58 |
|
59 |
# 视频处理
|
60 |
+
def detect_and_track(input_path: str, output_path: str, detect_class: int, model, tracker) -> Path:
|
61 |
+
"""
|
62 |
+
处理视频,检测并跟踪目标。
|
63 |
+
- input_path: 输入视频文件的路径。
|
64 |
+
- output_path: 处理后视频保存的路径。
|
65 |
+
- detect_class: 需要检测和跟踪的目标类别的索引。
|
66 |
+
- model: 用于目标检测的模型。
|
67 |
+
- tracker: 用于目标跟踪的模型。
|
68 |
"""
|
69 |
+
cap = cv2.VideoCapture(input_path) # 使用OpenCV打开视频文件。
|
70 |
+
if not cap.isOpened(): # 检查视频文件是否成功打开。
|
71 |
+
print(f"Error opening video file {input_path}")
|
72 |
+
return None
|
73 |
+
|
74 |
fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
|
75 |
+
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) # 获取视频的分辨率(宽度和高度)。
|
76 |
+
output_video_path = Path(output_path) / "output.avi" # 设置输出视频的保存路径。
|
|
|
|
|
|
|
77 |
|
78 |
+
# 设置视频编码格式为XVID格式的avi文件
|
79 |
# 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
|
80 |
# 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
|
81 |
# 下载完成后将dll文件放在当前文件夹内
|
82 |
fourcc = cv2.VideoWriter_fourcc(*"XVID")
|
83 |
+
output_video = cv2.VideoWriter(output_video_path.as_posix(), fourcc, fps, size, isColor=True) # 创建一个VideoWriter对象用于写视频。
|
|
|
|
|
84 |
|
85 |
# 对每一帧图片进行读取和处理
|
86 |
while True:
|
87 |
+
success, frame = cap.read() # 逐帧读取视频。
|
88 |
+
|
89 |
+
# 如果读取失败(或者视频已处理完毕),则跳出循环。
|
90 |
if not (success):
|
91 |
break
|
92 |
|
93 |
+
# 使用YoloV8模型对当前帧进行目标检测。
|
94 |
results = model(frame, stream=True)
|
95 |
|
96 |
+
# 从预测结果中提取检测信息。
|
97 |
+
detections, confarray = extract_detections(results, detect_class)
|
98 |
|
99 |
+
# 使用deepsort模型对检测到的目标进行跟踪。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
resultsTracker = tracker.update(detections, confarray, frame)
|
101 |
+
|
102 |
for x1, y1, x2, y2, Id in resultsTracker:
|
103 |
+
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2]) # 将位置信息转换为整数。
|
104 |
|
105 |
+
# 绘制bounding box和文本
|
106 |
cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
|
107 |
putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
|
108 |
|
109 |
+
output_video.write(frame) # 将处理后的帧写入到输出视频文件中。
|
110 |
+
|
111 |
+
output_video.release() # 释放VideoWriter对象。
|
112 |
+
cap.release() # 释放视频文件。
|
113 |
+
|
114 |
+
print(f'output dir is: {output_video_path}')
|
115 |
+
return output_video_path
|
116 |
|
117 |
if __name__ == "__main__":
|
118 |
+
# 指定输入视频的路径。
|
119 |
######
|
120 |
+
input_path = "test.mp4"
|
121 |
######
|
122 |
|
123 |
# 输出文件夹,默认为系统的临时文件夹路径
|
124 |
+
output_path = tempfile.mkdtemp() # 创建一个临时目录用于存放输出视频。
|
125 |
|
126 |
# 加载yoloV8模型权重
|
127 |
model = YOLO("yolov8n.pt")
|
128 |
|
129 |
+
# 设置需要检测和跟踪的目标类别
|
130 |
# yoloV8官方模型的第一个类别为'person'
|
131 |
detect_class = 0
|
132 |
+
print(f"detecting {model.names[detect_class]}") # model.names返回模型所支持的所有物体类别
|
133 |
|
134 |
+
# 加载DeepSort模型
|
135 |
tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
|
136 |
|
137 |
+
detect_and_track(input_path, output_path, detect_class, model, tracker)
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ torch
|
|
8 |
matplotlib
|
9 |
|
10 |
# WebUI ---------------------------------------
|
|
|
11 |
# gradio
|
|
|
8 |
matplotlib
|
9 |
|
10 |
# WebUI ---------------------------------------
|
11 |
+
# tqdm
|
12 |
# gradio
|
webui.png
CHANGED
![]() |
![]() |
Git LFS Details
|