KdaiP commited on
Commit
3dcbe08
1 Parent(s): 72c39aa
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. README.md +75 -13
  3. app.py +145 -78
  4. main.py +71 -52
  5. requirements.txt +1 -0
  6. webui.png +0 -0
.gitattributes CHANGED
@@ -36,3 +36,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
37
  demo.png filter=lfs diff=lfs merge=lfs -text
38
  test.mp4 filter=lfs diff=lfs merge=lfs -text
 
 
36
  deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
37
  demo.png filter=lfs diff=lfs merge=lfs -text
38
  test.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ webui.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,75 @@
1
- ---
2
- title: Yolov8 Deepsort Tracking
3
- emoji: 👀
4
- colorFrom: red
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 3.48.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1> yolov8-deepsort-tracking </h1>
3
+
4
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/KdaiP/yolov8-deepsort-tracking)
5
+ </div>
6
+
7
+ ![示例图片](./demo.png)
8
+
9
+ opencv+yolov8+deepsort的行人检测与跟踪。当然,也可以识别车辆等其他类别。
10
+
11
+ # 更新历史
12
+
13
+ 2024/2/11更新:清理代码,完善注释。WebUI新增识别目标选择、进度条显示、终止推理、示例等功能。
14
+
15
+ 2023/10/17更新:简化代码,删除不必要的依赖。解决webui上传视频不会清空tracker ID的问题。
16
+
17
+ 2023/7/4更新:加入了一个基于Gradio的WebUI界面
18
+
19
+ ## 安装
20
+ 环境:Python>=3.8
21
+
22
+ 本项目需要pytorch,建议手动在[pytorch官网](https://pytorch.org/get-started/locally/)根据自己的平台和CUDA环境安装对应的版本。
23
+
24
+ pytorch的详细安装教程可以参照[Conda Quickstart Guide for Ultralytics](https://docs.ultralytics.com/guides/conda-quickstart/)
25
+
26
+ 安装完pytorch后,需要通过以下命令来安装其他依赖:
27
+
28
+ ```shell
29
+ $ pip install -r requirements.txt
30
+ ```
31
+
32
+ 如果需要使用GUI,需要通过以下命令安装tqdm进度条和Gradio库:
33
+
34
+ ```shell
35
+ $ pip install tqdm gradio
36
+ ```
37
+
38
+
39
+ ## 配置(非WebUI)
40
+
41
+ 在main.py中修改以下代码,将输入视频路径换成你要处理的视频的路径:
42
+
43
+ ```python
44
+ input_path = "test.mp4"
45
+ ```
46
+
47
+ 模型默认使用Ultralytics官方的YOLOv8n模型:
48
+
49
+ ```python
50
+ model = YOLO("yolov8n.pt")
51
+ ```
52
+
53
+ 第一次使用会自动从官网下载模型,如果网速过慢,可以在[ultralytics的官方文档](https://docs.ultralytics.com/tasks/detect/)下载模型,然后将模型文件拷贝到程序所在目录下。
54
+
55
+ ## 运行(非WebUI)
56
+
57
+ 运行main.py
58
+
59
+ 推理完成后,终端会显示输出视频所在的路径。
60
+
61
+ ## WebUI界面的配置和运行
62
+
63
+ demo: [Huggingface demo](https://huggingface.co/spaces/KdaiP/yolov8-deepsort-tracking)
64
+
65
+
66
+ 运行app.py,如果控制台出现以下消息代表成功运行:
67
+ ```shell
68
+ Running on local URL: http://127.0.0.1:6006
69
+ To create a public link, set `share=True` in `launch()`
70
+ ```
71
+
72
+ 浏览器打开该URL即可使用WebUI界面
73
+
74
+ ![WebUI](./webui.png)
75
+
app.py CHANGED
@@ -3,13 +3,47 @@ import cv2
3
  import numpy as np
4
  import tempfile
5
  from pathlib import Path
 
 
6
  import deep_sort.deep_sort.deep_sort as ds
7
 
8
  import gradio as gr
9
 
10
- # YoloV8官方模型,从左往右由小到大,第一次使用会自动下载
11
- model_list = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt", "yolov8l.pt", "yolov8x.pt"]
 
 
 
 
 
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def putTextWithBackground(
14
  img,
15
  text,
@@ -53,113 +87,146 @@ def putTextWithBackground(
53
  )
54
 
55
 
56
- # 视频处理
57
- def processVideo(inputPath, model):
58
- """处理视频,检测并跟踪行人。
59
-
60
- :param inputPath: 视频文件路径
61
- :return: 输出视频的路径
62
  """
63
- tracker = ds.DeepSort(
64
- "deep_sort/deep_sort/deep/checkpoint/ckpt.t7"
65
- ) # 加载deepsort权重文件
66
- model = YOLO(model) # 加载YOLO模型文件
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # 读取视频文件
69
- cap = cv2.VideoCapture(inputPath)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
71
- size = (
72
- int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
73
- int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
74
- ) # 获取视频的大小
75
- output_video = cv2.VideoWriter() # 初始化视频写入
76
- outputPath = tempfile.mkdtemp() # 创建输出视频的临时文件夹的路径
77
-
78
- # 输出格式为XVID格式的avi文件
79
  # 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
80
  # 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
81
  # 下载完成后将dll文件放在当前文件夹内
82
- output_type = "avi"
83
- if output_type == "avi":
84
- fourcc = cv2.VideoWriter_fourcc(*"XVID")
85
- video_save_path = Path(outputPath) / "output.avi" # 创建输出视频路径
86
- if output_type == "mp4": # 浏览器只支持播放h264编码的mp4视频文件
87
- fourcc = cv2.VideoWriter_fourcc(*"h264")
88
- video_save_path = Path(outputPath) / "output.mp4"
89
-
90
- output_video.open(video_save_path.as_posix(), fourcc, fps, size, True)
91
  # 对每一帧图片进行读取和处理
92
- while True:
93
- success, frame = cap.read()
 
 
 
 
 
 
 
 
94
  if not (success):
95
  break
96
 
97
- # 获取每一帧的目标检测推理结果
98
  results = model(frame, stream=True)
99
 
100
- detections = np.empty((0, 4)) # 存放bounding box结果
101
- confarray = [] # 存放每个检测结果的置信度
102
-
103
- # 读取目标检测推理结果
104
- # 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
105
- for r in results:
106
- boxes = r.boxes
107
- for box in boxes:
108
- x1, y1, x2, y2 = map(int, box.xywh[0]) # 提取矩形框左上和右下的点,并将tensor类型转为整型
109
- conf = round(float(box.conf[0]), 2) # 对conf四舍五入到2位小数
110
- cls = int(box.cls[0]) # 获取物体类别标签
111
 
112
- if cls == detect_class:
113
- detections = np.vstack((detections,np.array([x1,y1,x2,y2])))
114
- confarray.append(conf)
115
-
116
- # 使用deepsort进行跟踪
117
  resultsTracker = tracker.update(detections, confarray, frame)
 
118
  for x1, y1, x2, y2, Id in resultsTracker:
119
- x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
120
 
121
- # 绘制bounding box
122
  cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
123
- putTextWithBackground(
124
- frame,
125
- str(int(Id)),
126
- (max(-10, x1), max(40, y1)),
127
- font_scale=1.5,
128
- text_color=(255, 255, 255),
129
- bg_color=(255, 0, 255),
130
- )
131
 
132
- output_video.write(frame) # 将处理后的图像写入视频
133
- output_video.release() # 释放
134
- cap.release() # 释放
135
- print(f"output dir is: {video_save_path.as_posix()}")
136
- return video_save_path.as_posix(), video_save_path.as_posix() # Gradio的视频控件实际读取的是文件路径
 
 
137
 
138
 
139
  if __name__ == "__main__":
140
- # 需要跟踪的物体类别
141
- detect_class = 0
142
-
 
 
 
 
 
 
 
 
143
  # Gradio参考文档:https://www.gradio.app/guides/blocks-and-event-listeners
144
  with gr.Blocks() as demo:
145
  with gr.Tab("Tracking"):
 
146
  gr.Markdown(
147
  """
148
- # YoloV8 + deepsort
149
  基于opencv + YoloV8 + deepsort
150
  """
151
  )
 
152
  with gr.Row():
 
153
  with gr.Column():
154
- input_video = gr.Video(label="Input video")
155
- model = gr.Dropdown(model_list, value="yolov8n.pt", label="Model")
 
 
 
 
 
 
156
  with gr.Column():
157
- output = gr.Video()
158
- output_path = gr.Textbox(label="Output path")
159
- button = gr.Button("Process")
160
-
161
- button.click(
162
- processVideo, inputs=[input_video, model], outputs=[output, output_path]
163
- )
 
 
 
 
 
 
164
 
165
  demo.launch()
 
3
  import numpy as np
4
  import tempfile
5
  from pathlib import Path
6
+ from tqdm.auto import tqdm
7
+
8
  import deep_sort.deep_sort.deep_sort as ds
9
 
10
  import gradio as gr
11
 
12
+ # 控制处理流程是否终止
13
+ should_continue = True
14
+
15
+ def get_detectable_classes(model_file):
16
+ """获取给定模型文件可以检测的类别。
17
+
18
+ 参数:
19
+ - model_file: 模型文件名。
20
 
21
+ 返回:
22
+ - class_names: 可检测的类别名称。
23
+ """
24
+ model = YOLO(model_file)
25
+ class_names = list(model.names.values()) # 直接获取类别名称列表
26
+ del model # 删除模型实例释放资源
27
+ return class_names
28
+
29
+ # 用于终止视频处理
30
+ def stop_processing():
31
+ global should_continue
32
+ should_continue = False # 更改变量来停止处理
33
+ return "尝试终止处理..."
34
+
35
+ # 用于开始视频处理
36
+ # gr.Progress(track_tqdm=True)用于捕获tqdm进度条,从而在GUI上显示进度
37
+ def start_processing(input_path, output_path, detect_class, model, progress=gr.Progress(track_tqdm=True)):
38
+ global should_continue
39
+ should_continue = True
40
+
41
+ detect_class = int(detect_class)
42
+ model = YOLO(model)
43
+ tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
44
+ output_video_path = detect_and_track(input_path, output_path, detect_class, model, tracker)
45
+ return output_video_path, output_video_path
46
+
47
  def putTextWithBackground(
48
  img,
49
  text,
 
87
  )
88
 
89
 
90
+ def extract_detections(results, detect_class):
91
+ """
92
+ 从模型结果中提取和处理检测信息。
93
+ - results: YoloV8模型预测结果,包含检测到的物体的位置、类别和置信度等信息。
94
+ - detect_class: 需要提取的目标类别的索引。
95
+ 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
96
  """
97
+
98
+ # 初始化一个空的二维numpy数组,用于存放检测到的目标的位置信息
99
+ # 如果视频中没有需要提取的目标类别,如果不初始化,会导致tracker报错
100
+ detections = np.empty((0, 4))
101
+
102
+ confarray = [] # 初始化一个空列表,用于存放检测到的目标的置信度。
103
+
104
+ # 遍历检测结果
105
+ # 参考:https://docs.ultralytics.com/modes/predict/#working-with-results
106
+ for r in results:
107
+ for box in r.boxes:
108
+ # 如果检测到的目标类别与指定的目标类别相匹配,提取目标的位置信息和置信度
109
+ if box.cls[0].int() == detect_class:
110
+ x1, y1, x2, y2 = box.xywh[0].int().tolist() # 提取目标的位置信息,并从tensor转换为整数列表。
111
+ conf = round(box.conf[0].item(), 2) # 提取目标的置信度,从tensor中取出浮点数结果,并四舍五入到小数点后两位。
112
+ detections = np.vstack((detections, np.array([x1, y1, x2, y2]))) # 将目标的位置信息添加到detections数组中。
113
+ confarray.append(conf) # 将目标的置信度添加到confarray列表中。
114
+ return detections, confarray # 返回提取出的位置信息和置信度。
115
 
116
+ # 视频处理
117
+ def detect_and_track(input_path: str, output_path: str, detect_class: int, model, tracker) -> Path:
118
+ """
119
+ 处理视频,检测并跟踪目标。
120
+ - input_path: 输入视频文件的路径。
121
+ - output_path: 处理后视频保存的路径。
122
+ - detect_class: 需要检测和跟踪的目标类别的索引。
123
+ - model: 用于目标检测的模型。
124
+ - tracker: 用于目标跟踪的模型。
125
+ """
126
+ global should_continue
127
+ cap = cv2.VideoCapture(input_path) # 使用OpenCV打开视频文件。
128
+ if not cap.isOpened(): # 检查视频文件是否成功打开。
129
+ print(f"Error opening video file {input_path}")
130
+ return None
131
+
132
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 获取视频总帧数
133
  fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
134
+ size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) # 获取视频的分辨率(宽度和高度)。
135
+ output_video_path = Path(output_path) / "output.avi" # 设置输出视频的保存路径。
136
+
137
+ # 设置视频编码格式为XVID格式的avi文件
 
 
 
 
138
  # 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
139
  # 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
140
  # 下载完成后将dll文件放在当前文件夹内
141
+ fourcc = cv2.VideoWriter_fourcc(*"XVID")
142
+ output_video = cv2.VideoWriter(output_video_path.as_posix(), fourcc, fps, size, isColor=True) # 创建一个VideoWriter对象用于写视频。
143
+
 
 
 
 
 
 
144
  # 对每一帧图片进行读取和处理
145
+ # 使用tqdm显示处理进度。
146
+ for _ in tqdm(range(total_frames)):
147
+ # 如果全局变量should_continue为False(通常由于GUI上按下Stop按钮),则终止目标检测和跟踪,返回已处理的视频部分
148
+ if not should_continue:
149
+ print('stopping process')
150
+ break
151
+
152
+ success, frame = cap.read() # 逐帧读取视频。
153
+
154
+ # 如果读取失败(或者视频已处理完毕),则跳出循环。
155
  if not (success):
156
  break
157
 
158
+ # 使用YoloV8模型对当前帧进行目标检测。
159
  results = model(frame, stream=True)
160
 
161
+ # 从预测结果中提取检测信息。
162
+ detections, confarray = extract_detections(results, detect_class)
 
 
 
 
 
 
 
 
 
163
 
164
+ # 使用deepsort模型对检测到的目标进行跟踪。
 
 
 
 
165
  resultsTracker = tracker.update(detections, confarray, frame)
166
+
167
  for x1, y1, x2, y2, Id in resultsTracker:
168
+ x1, y1, x2, y2 = map(int, [x1, y1, x2, y2]) # 将位置信息转换为整数。
169
 
170
+ # 绘制bounding box和文本
171
  cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
172
+ putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
 
 
 
 
 
 
 
173
 
174
+ output_video.write(frame) # 将处理后的帧写入到输出视频文件中。
175
+
176
+ output_video.release() # 释放VideoWriter对象。
177
+ cap.release() # 释放视频文件。
178
+
179
+ print(f'output dir is: {output_video_path}')
180
+ return output_video_path
181
 
182
 
183
  if __name__ == "__main__":
184
+
185
+ # YoloV8官方模型列表,从左往右由小到大,第一次使用会自动下载
186
+ model_list = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt", "yolov8l.pt", "yolov8x.pt"]
187
+
188
+ # 获取YoloV8模型可以检测的所有类别,默认调用model_list中第一个模型
189
+ detect_classes = get_detectable_classes(model_list[0])
190
+
191
+ # gradio界面的输入示例,包含一个测试视频文件路径、一个随机生成的输出目录、检测的类别、使用的模型
192
+ examples = [["test.mp4", tempfile.mkdtemp(), detect_classes[0], model_list[0]],]
193
+
194
+ # 使用Gradio的Blocks创建一个GUI界面
195
  # Gradio参考文档:https://www.gradio.app/guides/blocks-and-event-listeners
196
  with gr.Blocks() as demo:
197
  with gr.Tab("Tracking"):
198
+ # 使用Markdown显示文本信息,介绍界面的功能
199
  gr.Markdown(
200
  """
201
+ # 目标检测与跟踪
202
  基于opencv + YoloV8 + deepsort
203
  """
204
  )
205
+ # 行容器,水平排列元素
206
  with gr.Row():
207
+ # 列容器,垂直排列元素
208
  with gr.Column():
209
+ input_path = gr.Video(label="Input video") # 视频输入控件,用于上传视频文件
210
+ model = gr.Dropdown(model_list, value=0, label="Model") # 下拉菜单控件,用于选择模型
211
+ detect_class = gr.Dropdown(detect_classes, value=0, label="Class", type='index') # 下拉菜单控件,用于选择要检测的目标类别
212
+ output_dir = gr.Textbox(label="Output dir", value=tempfile.mkdtemp()) # 文本框控件,用于指定输出视频的保存路径,默认为一个临时生成的目录
213
+ with gr.Row():
214
+ # 创建两个按钮控件,分别用于开始处理和停止处理
215
+ start_button = gr.Button("Process")
216
+ stop_button = gr.Button("Stop")
217
  with gr.Column():
218
+ output = gr.Video() # 视频显示控件,展示处理后的输出视频
219
+ output_path = gr.Textbox(label="Output path") # 文本框控件,用于显示输出视频的文件路径
220
+
221
+ # 添加示例到GUI中,允许用户选择预定义的输入进行快速测试
222
+ gr.Examples(examples,label="Examples",
223
+ inputs=[input_path, output_dir, detect_class, model],
224
+ outputs=[output, output_path],
225
+ fn=start_processing, # 指定处理示例时调用的函数
226
+ cache_examples=False) # 禁用示例缓存
227
+
228
+ # 将按钮与处理函数绑定
229
+ start_button.click(start_processing, inputs=[input_path, output_dir, detect_class, model], outputs=[output, output_path])
230
+ stop_button.click(stop_processing)
231
 
232
  demo.launch()
main.py CHANGED
@@ -1,8 +1,9 @@
1
- from ultralytics import YOLO
2
- import cv2
3
- import numpy as np
4
  import tempfile
5
  from pathlib import Path
 
 
 
 
6
  import deep_sort.deep_sort.deep_sort as ds
7
 
8
  def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, text_color=(255, 255, 255), bg_color=(0, 0, 0), thickness=1):
@@ -28,91 +29,109 @@ def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font
28
  # 在矩形上绘制文本
29
  text_origin = (origin[0], origin[1] - 5) # 从左上角的位置减去5来留出一些边距
30
  cv2.putText(img, text, text_origin, font, font_scale, text_color, thickness, lineType=cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # 视频处理
33
- def processVideo(inputPath: str) -> Path:
34
- """处理视频,检测并跟踪行人。
35
-
36
- :param inputPath: 视频文件路径
37
- :return: 输出视频的路径
 
 
 
38
  """
39
- # 读取视频文件
40
- cap = cv2.VideoCapture(inputPath)
 
 
 
41
  fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
42
- size = (
43
- int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
44
- int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
45
- ) # 获取视频的大小
46
- output_video = cv2.VideoWriter() # 初始化视频写入
47
 
48
- # 输出格式为XVID格式的avi文件
49
  # 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
50
  # 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
51
  # 下载完成后将dll文件放在当前文件夹内
52
  fourcc = cv2.VideoWriter_fourcc(*"XVID")
53
- video_save_path = Path(outputPath) / "output.avi" # 创建输出视频路径
54
-
55
- output_video.open(video_save_path.as_posix(), fourcc, fps, size, isColor=True)
56
 
57
  # 对每一帧图片进行读取和处理
58
  while True:
59
- success, frame = cap.read()
 
 
60
  if not (success):
61
  break
62
 
63
- # 获取每一帧的目标检测推理结果
64
  results = model(frame, stream=True)
65
 
66
- detections = np.empty((0, 4)) # 存放bounding box结果
67
- confarray = [] # 存放每个检测结果的置��度
68
 
69
- # 读取目标检测推理结果
70
- # 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
71
- for r in results:
72
- boxes = r.boxes
73
- for box in boxes:
74
- x1, y1, x2, y2 = map(int, box.xywh[0]) # 提取矩形框左上和右下的点,并将tensor类型转为整型
75
- conf = round(float(box.conf[0]), 2) # 对conf四舍五入到2位小数
76
- cls = int(box.cls[0]) # 获取物体类别标签
77
-
78
- if cls == detect_class:
79
- detections = np.vstack((detections,np.array([x1,y1,x2,y2])))
80
- confarray.append(conf)
81
-
82
- # 使用deepsort进行跟踪
83
  resultsTracker = tracker.update(detections, confarray, frame)
 
84
  for x1, y1, x2, y2, Id in resultsTracker:
85
- x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
86
 
87
- # 绘制bounding box
88
  cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
89
  putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
90
 
91
- output_video.write(frame) # 将处理后的图像写入视频
92
- output_video.release() # 释放
93
- cap.release() # 释放
94
- print(f'output dir is: {video_save_path}')
95
- return video_save_path
96
-
 
97
 
98
  if __name__ == "__main__":
99
- # 在这里填入视频文件路径
100
  ######
101
- input_video_path = "test.mp4"
102
  ######
103
 
104
  # 输出文件夹,默认为系统的临时文件夹路径
105
- outputPath = tempfile.mkdtemp() # 创建临时文件夹用于存储输出视频
106
 
107
  # 加载yoloV8模型权重
108
  model = YOLO("yolov8n.pt")
109
 
110
- # 需要跟踪的物体类别,model.names返回模型所支持的所有物体类别
111
  # yoloV8官方模型的第一个类别为'person'
112
  detect_class = 0
113
- print(f"detecting {model.names[detect_class]}")
114
 
115
- # 加载deepsort模型权重
116
  tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
117
 
118
- processVideo(input_video_path)
 
 
 
 
1
  import tempfile
2
  from pathlib import Path
3
+ import numpy as np
4
+ import cv2 # opencv-python
5
+ from ultralytics import YOLO
6
+
7
  import deep_sort.deep_sort.deep_sort as ds
8
 
9
  def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, text_color=(255, 255, 255), bg_color=(0, 0, 0), thickness=1):
 
29
  # 在矩形上绘制文本
30
  text_origin = (origin[0], origin[1] - 5) # 从左上角的位置减去5来留出一些边距
31
  cv2.putText(img, text, text_origin, font, font_scale, text_color, thickness, lineType=cv2.LINE_AA)
32
+
33
+ def extract_detections(results, detect_class):
34
+ """
35
+ 从模型结果中提取和处理检测信息。
36
+ - results: YoloV8模型预测结果,包含检测到的物体的位置、类别和置信度等信息。
37
+ - detect_class: 需要提取的目标类别的索引。
38
+ 参考: https://docs.ultralytics.com/modes/predict/#working-with-results
39
+ """
40
+
41
+ # 初始化一个空的二维numpy数组,用于存放检测到的目标的位置信息
42
+ # 如果视频中没有需要提取的目标类别,如果不初始化,会导致tracker报错
43
+ detections = np.empty((0, 4))
44
+
45
+ confarray = [] # 初始化一个空列表,用于存放检测到的目标的置信度。
46
+
47
+ # 遍历检测结果
48
+ # 参考:https://docs.ultralytics.com/modes/predict/#working-with-results
49
+ for r in results:
50
+ for box in r.boxes:
51
+ # 如果检测到的目标类别与指定的目标类别相匹配,提取目标的位置信息和置信度
52
+ if box.cls[0].int() == detect_class:
53
+ x1, y1, x2, y2 = box.xywh[0].int().tolist() # 提取目标的位置信息,并从tensor转换为整数列表。
54
+ conf = round(box.conf[0].item(), 2) # 提取目标的置信度,从tensor中取出浮点数结果,并四舍五入到小数点后两位。
55
+ detections = np.vstack((detections, np.array([x1, y1, x2, y2]))) # 将目标的位置信息添加到detections数组中。
56
+ confarray.append(conf) # 将目标的置信度添加到confarray列表中。
57
+ return detections, confarray # 返回提取出的位置信息和置信度。
58
 
59
  # 视频处理
60
+ def detect_and_track(input_path: str, output_path: str, detect_class: int, model, tracker) -> Path:
61
+ """
62
+ 处理视频,检测并跟踪目标。
63
+ - input_path: 输入视频文件的路径。
64
+ - output_path: 处理后视频保存的路径。
65
+ - detect_class: 需要检测和跟踪的目标类别的索引。
66
+ - model: 用于目标检测的模型。
67
+ - tracker: 用于目标跟踪的模型。
68
  """
69
+ cap = cv2.VideoCapture(input_path) # 使用OpenCV打开视频文件。
70
+ if not cap.isOpened(): # 检查视频文件是否成功打开。
71
+ print(f"Error opening video file {input_path}")
72
+ return None
73
+
74
  fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
75
+ size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) # 获取视频的分辨率(宽度和高度)。
76
+ output_video_path = Path(output_path) / "output.avi" # 设置输出视频的保存路径。
 
 
 
77
 
78
+ # 设置视频编码格式为XVID格式的avi文件
79
  # 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
80
  # 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
81
  # 下载完成后将dll文件放在当前文件夹内
82
  fourcc = cv2.VideoWriter_fourcc(*"XVID")
83
+ output_video = cv2.VideoWriter(output_video_path.as_posix(), fourcc, fps, size, isColor=True) # 创建一个VideoWriter对象用于写视频。
 
 
84
 
85
  # 对每一帧图片进行读取和处理
86
  while True:
87
+ success, frame = cap.read() # 逐帧读取视频。
88
+
89
+ # 如果读取失败(或者视频已处理完毕),则跳出循环。
90
  if not (success):
91
  break
92
 
93
+ # 使用YoloV8模型对当前帧进行目标检测。
94
  results = model(frame, stream=True)
95
 
96
+ # 从预测结果中提取检测信息。
97
+ detections, confarray = extract_detections(results, detect_class)
98
 
99
+ # 使用deepsort模型对检测到的目标进行跟踪。
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  resultsTracker = tracker.update(detections, confarray, frame)
101
+
102
  for x1, y1, x2, y2, Id in resultsTracker:
103
+ x1, y1, x2, y2 = map(int, [x1, y1, x2, y2]) # 将位置信息转换为整数。
104
 
105
+ # 绘制bounding box和文本
106
  cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
107
  putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
108
 
109
+ output_video.write(frame) # 将处理后的帧写入到输出视频文件中。
110
+
111
+ output_video.release() # 释放VideoWriter对象。
112
+ cap.release() # 释放视频文件。
113
+
114
+ print(f'output dir is: {output_video_path}')
115
+ return output_video_path
116
 
117
  if __name__ == "__main__":
118
+ # 指定输入视频的路径。
119
  ######
120
+ input_path = "test.mp4"
121
  ######
122
 
123
  # 输出文件夹,默认为系统的临时文件夹路径
124
+ output_path = tempfile.mkdtemp() # 创建一个临时目录用于存放输出视频。
125
 
126
  # 加载yoloV8模型权重
127
  model = YOLO("yolov8n.pt")
128
 
129
+ # 设置需要检测和跟踪的目标类别
130
  # yoloV8官方模型的第一个类别为'person'
131
  detect_class = 0
132
+ print(f"detecting {model.names[detect_class]}") # model.names返回模型所支持的所有物体类别
133
 
134
+ # 加载DeepSort模型
135
  tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
136
 
137
+ detect_and_track(input_path, output_path, detect_class, model, tracker)
requirements.txt CHANGED
@@ -8,4 +8,5 @@ torch
8
  matplotlib
9
 
10
  # WebUI ---------------------------------------
 
11
  # gradio
 
8
  matplotlib
9
 
10
  # WebUI ---------------------------------------
11
+ # tqdm
12
  # gradio
webui.png CHANGED

Git LFS Details

  • SHA256: c14e7d199c011d1cf6ba0897854cc983b4c0816071515debd21355b606d1a26e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.15 MB