KdaiP committed
Commit 9437b2f
1 Parent(s): b5618b7

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. .gitignore +131 -0
  3. README.md +67 -13
  4. app.py +165 -0
  5. deep_sort/configs/deep_sort.yaml +10 -0
  6. deep_sort/deep_sort/README.md +3 -0
  7. deep_sort/deep_sort/__init__.py +21 -0
  8. deep_sort/deep_sort/__pycache__/__init__.cpython-310.pyc +0 -0
  9. deep_sort/deep_sort/__pycache__/deep_sort.cpython-310.pyc +0 -0
  10. deep_sort/deep_sort/deep/__init__.py +0 -0
  11. deep_sort/deep_sort/deep/__pycache__/__init__.cpython-310.pyc +0 -0
  12. deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc +0 -0
  13. deep_sort/deep_sort/deep/__pycache__/model.cpython-310.pyc +0 -0
  14. deep_sort/deep_sort/deep/checkpoint/ckpt.t7 +3 -0
  15. deep_sort/deep_sort/deep/evaluate.py +15 -0
  16. deep_sort/deep_sort/deep/feature_extractor.py +65 -0
  17. deep_sort/deep_sort/deep/model.py +105 -0
  18. deep_sort/deep_sort/deep/original_model.py +106 -0
  19. deep_sort/deep_sort/deep/prepare_car.py +129 -0
  20. deep_sort/deep_sort/deep/prepare_person.py +108 -0
  21. deep_sort/deep_sort/deep/test.py +77 -0
  22. deep_sort/deep_sort/deep/train.jpg +0 -0
  23. deep_sort/deep_sort/deep/train.py +192 -0
  24. deep_sort/deep_sort/deep_sort.py +125 -0
  25. deep_sort/deep_sort/sort/__init__.py +0 -0
  26. deep_sort/deep_sort/sort/__pycache__/__init__.cpython-310.pyc +0 -0
  27. deep_sort/deep_sort/sort/__pycache__/detection.cpython-310.pyc +0 -0
  28. deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc +0 -0
  29. deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc +0 -0
  30. deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc +0 -0
  31. deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc +0 -0
  32. deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-310.pyc +0 -0
  33. deep_sort/deep_sort/sort/__pycache__/track.cpython-310.pyc +0 -0
  34. deep_sort/deep_sort/sort/__pycache__/tracker.cpython-310.pyc +0 -0
  35. deep_sort/deep_sort/sort/detection.py +49 -0
  36. deep_sort/deep_sort/sort/iou_matching.py +84 -0
  37. deep_sort/deep_sort/sort/kalman_filter.py +286 -0
  38. deep_sort/deep_sort/sort/linear_assignment.py +240 -0
  39. deep_sort/deep_sort/sort/nn_matching.py +207 -0
  40. deep_sort/deep_sort/sort/preprocessing.py +73 -0
  41. deep_sort/deep_sort/sort/track.py +199 -0
  42. deep_sort/deep_sort/sort/tracker.py +168 -0
  43. deep_sort/utils/__init__.py +0 -0
  44. deep_sort/utils/asserts.py +13 -0
  45. deep_sort/utils/draw.py +36 -0
  46. deep_sort/utils/evaluation.py +103 -0
  47. deep_sort/utils/io.py +133 -0
  48. deep_sort/utils/json_logger.py +383 -0
  49. deep_sort/utils/log.py +17 -0
  50. deep_sort/utils/parser.py +38 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ deep_sort/deep_sort/deep/checkpoint/ckpt.t7 filter=lfs diff=lfs merge=lfs -text
+ demo.png filter=lfs diff=lfs merge=lfs -text
+ test.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,131 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ openh264-1.8.0-win64.dll
README.md CHANGED
@@ -1,13 +1,67 @@
- ---
- title: Yolov8 Deepsort Tracking
- emoji: 👀
- colorFrom: red
- colorTo: green
- sdk: gradio
- sdk_version: 3.48.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <h1> yolov8-deepsort-tracking </h1>
+ </div>
+
+ ![demo image](https://github.com/KdaiP/yolov8-deepsort-tracking/blob/main/demo.png)
+
+ Pedestrian detection and tracking with OpenCV, YOLOv8, and DeepSORT. Other classes, such as vehicles, can be detected and tracked as well.
+
+ - Update 2023/10/17: simplified the code and removed unnecessary dependencies
+
+ - Update 2023/7/4: added a Gradio-based WebUI
+
+ ## Installation
+ Environment: Python >= 3.8
+
+ This project requires PyTorch. We recommend installing the build that matches your platform and CUDA environment manually from the [PyTorch website](https://pytorch.org/get-started/locally/).
+
+ For a detailed PyTorch installation walkthrough, see the [Conda Quickstart Guide for Ultralytics](https://docs.ultralytics.com/guides/conda-quickstart/).
+
+ After installing PyTorch, install the remaining dependencies with:
+
+ ```shell
+ $ pip install -r requirements.txt
+ ```
+
+
+ ## Configuration (non-WebUI)
+
+ In main.py, change the following line so that the input video path points to the video you want to process:
+
+ ```python
+ input_video_path = "test.mp4"
+ ```
+
+ By default, the model is the official Ultralytics YOLOv8n model:
+
+ ```python
+ model = "yolov8n.pt"
+ ```
+
+ On first use the model is downloaded automatically. If your connection is slow, you can download it from the [Ultralytics documentation](https://docs.ultralytics.com/tasks/detect/) and copy the model file into the program's directory.
+
+ ## Running (non-WebUI)
+
+ Run main.py.
+ When it finishes, the terminal prints the path of the output video.
+
+ ## Configuring and running the WebUI
+
+ **Make sure the dependencies above are installed first.**
+
+ Install the Gradio library:
+
+ ```shell
+ $ pip install gradio
+ ```
+
+ Run app.py. If the console shows the following message, it is running successfully:
+ ```shell
+ Running on local URL: http://127.0.0.1:6006
+ To create a public link, set `share=True` in `launch()`
+ ```
+
+ Open that URL in a browser to use the WebUI.
+
+ ![WebUI](https://github.com/KdaiP/yolov8-deepsort-tracking/blob/main/webui.png)
+
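> Note: main.py itself is not among the 50 files shown in this truncated view. As a rough sketch of the non-WebUI flow the README describes (mirroring the `processVideo` loop from app.py below; the names here are illustrative, not necessarily main.py's actual code):

```python
# Minimal sketch of the non-WebUI flow (assumes at least one detection per
# frame, like app.py's own loop does).
import cv2
import numpy as np
from ultralytics import YOLO
import deep_sort.deep_sort.deep_sort as ds

input_video_path = "test.mp4"           # video to process
model = YOLO("yolov8n.pt")              # downloaded automatically on first use
tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")

cap = cv2.VideoCapture(input_video_path)
while True:
    success, frame = cap.read()
    if not success:
        break
    detections, confidences = [], []
    for r in model(frame, stream=True):
        for box in r.boxes:
            if int(box.cls[0]) == 0:                            # class 0 = person
                detections.append(list(map(int, box.xywh[0])))  # center x, y, w, h
                confidences.append(float(box.conf[0]))
    tracks = tracker.update(np.array(detections), confidences, frame)
    for x1, y1, x2, y2, track_id in tracks:
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 255), 2)
cap.release()
```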
app.py ADDED
@@ -0,0 +1,165 @@
+ from ultralytics import YOLO
+ import cv2
+ import numpy as np
+ import tempfile
+ from pathlib import Path
+ import deep_sort.deep_sort.deep_sort as ds
+
+ import gradio as gr
+
+ # Official YOLOv8 models, smallest to largest; downloaded automatically on first use
+ model_list = ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt", "yolov8l.pt", "yolov8x.pt"]
+
+ def putTextWithBackground(
+     img,
+     text,
+     origin,
+     font=cv2.FONT_HERSHEY_SIMPLEX,
+     font_scale=1,
+     text_color=(255, 255, 255),
+     bg_color=(0, 0, 0),
+     thickness=1,
+ ):
+     """Draw text with a background rectangle.
+
+     :param img: input image.
+     :param text: text to draw.
+     :param origin: top-left corner of the text.
+     :param font: font type.
+     :param font_scale: font size.
+     :param text_color: text color.
+     :param bg_color: background color.
+     :param thickness: line thickness of the text.
+     """
+     # Measure the text
+     (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
+
+     # Draw the background rectangle
+     bottom_left = origin
+     top_right = (origin[0] + text_width, origin[1] - text_height - 5)  # subtract 5 for some margin
+     cv2.rectangle(img, bottom_left, top_right, bg_color, -1)
+
+     # Draw the text on top of the rectangle
+     text_origin = (origin[0], origin[1] - 5)  # subtract 5 from the top-left position for some margin
+     cv2.putText(
+         img,
+         text,
+         text_origin,
+         font,
+         font_scale,
+         text_color,
+         thickness,
+         lineType=cv2.LINE_AA,
+     )
+
+
+ # Video processing
+ def processVideo(inputPath, model):
+     """Process a video: detect and track pedestrians.
+
+     :param inputPath: path of the video file
+     :return: path of the output video
+     """
+     tracker = ds.DeepSort(
+         "deep_sort/deep_sort/deep/checkpoint/ckpt.t7"
+     )  # load the DeepSORT weights
+     model = YOLO(model)  # load the YOLO model file
+
+     # Open the video file
+     cap = cv2.VideoCapture(inputPath)
+     fps = cap.get(cv2.CAP_PROP_FPS)  # frame rate of the video
+     size = (
+         int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+         int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+     )  # frame size of the video
+     output_video = cv2.VideoWriter()  # initialize the video writer
+     outputPath = tempfile.mkdtemp()  # create a temporary directory for the output video
+
+     # The output is an XVID-encoded avi file.
+     # To use h264 encoding or save in another format, you may need openh264-1.8.0:
+     # download it from https://github.com/cisco/openh264/releases/tag/v1.8.0
+     # and put the dll in the current directory.
+     output_type = "avi"
+     if output_type == "avi":
+         fourcc = cv2.VideoWriter_fourcc(*"XVID")
+         video_save_path = Path(outputPath) / "output.avi"  # output video path
+     if output_type == "mp4":  # browsers can only play h264-encoded mp4 files
+         fourcc = cv2.VideoWriter_fourcc(*"h264")
+         video_save_path = Path(outputPath) / "output.mp4"
+
+     output_video.open(video_save_path.as_posix(), fourcc, fps, size, True)
+     # Read and process every frame
+     while True:
+         success, frame = cap.read()
+         if not success:
+             break
+
+         # Run object detection on the frame
+         results = model(frame, stream=True)
+
+         detections = []  # bounding box results
+         confarray = []  # confidence of each detection
+
+         # Read the detection results
+         # Reference: https://docs.ultralytics.com/modes/predict/#working-with-results
+         for r in results:
+             boxes = r.boxes
+             for box in boxes:
+                 # box.xywh holds center x, center y, width, height (the format
+                 # DeepSORT's update() expects); convert the tensor values to ints
+                 x, y, w, h = map(int, box.xywh[0])
+                 conf = round(float(box.conf[0]), 2)  # round conf to 2 decimals
+                 cls = int(box.cls[0])  # class label of the object
+
+                 if cls == detect_class:  # detect_class is set in the __main__ block below
+                     detections.append([x, y, w, h])
+                     confarray.append(conf)
+
+         # Track with DeepSORT
+         resultsTracker = tracker.update(np.array(detections), confarray, frame)
+         for x1, y1, x2, y2, Id in resultsTracker:
+             x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
+
+             # Draw the bounding box
+             cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
+             putTextWithBackground(
+                 frame,
+                 str(int(Id)),
+                 (max(-10, x1), max(40, y1)),
+                 font_scale=1.5,
+                 text_color=(255, 255, 255),
+                 bg_color=(255, 0, 255),
+             )
+
+         output_video.write(frame)  # write the processed frame to the video
+     output_video.release()  # release the writer
+     cap.release()  # release the capture
+     print(f"output dir is: {video_save_path.as_posix()}")
+     return video_save_path.as_posix(), video_save_path.as_posix()  # Gradio's video widget actually reads a file path
+
+
+ if __name__ == "__main__":
+     # Class of object to track
+     detect_class = 0
+
+     # Gradio reference: https://www.gradio.app/guides/blocks-and-event-listeners
+     with gr.Blocks() as demo:
+         with gr.Tab("Tracking"):
+             gr.Markdown(
+                 """
+                 # YoloV8 + deepsort
+                 Based on opencv + YoloV8 + deepsort
+                 """
+             )
+             with gr.Row():
+                 with gr.Column():
+                     input_video = gr.Video(label="Input video")
+                     model = gr.Dropdown(model_list, value="yolov8n.pt", label="Model")
+                 with gr.Column():
+                     output = gr.Video()
+                     output_path = gr.Textbox(label="Output path")
+             button = gr.Button("Process")
+
+             button.click(
+                 processVideo, inputs=[input_video, model], outputs=[output, output_path]
+             )
+
+     demo.launch(server_port=6006)
deep_sort/configs/deep_sort.yaml ADDED
@@ -0,0 +1,10 @@
+ DEEPSORT:
+   REID_CKPT: "deep_sort/deep_sort/deep/checkpoint/ckpt.t7"
+   MAX_DIST: 0.2
+   MIN_CONFIDENCE: 0.3
+   NMS_MAX_OVERLAP: 0.5
+   MAX_IOU_DISTANCE: 0.7
+   MAX_AGE: 70
+   N_INIT: 3
+   NN_BUDGET: 100
+
deep_sort/deep_sort/README.md ADDED
@@ -0,0 +1,3 @@
+ # Deep Sort
+
+ This is the implementation of Deep SORT with PyTorch.
deep_sort/deep_sort/__init__.py ADDED
@@ -0,0 +1,21 @@
+ from .deep_sort import DeepSort
+
+
+ __all__ = ['DeepSort', 'build_tracker']
+
+
+ def build_tracker(cfg, use_cuda):
+     return DeepSort(cfg.DEEPSORT.REID_CKPT,
+                 max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
+                 nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
+                 max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda)
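`build_tracker` expects an attribute-style config object holding the values from `deep_sort/configs/deep_sort.yaml` above. A minimal sketch of wiring the two together, assuming the `easydict` package for attribute access (the repo's `deep_sort/utils/parser.py` presumably provides an equivalent helper, which is not shown in this view):

```python
# Sketch: build a DeepSort tracker from the YAML config.
import yaml
from easydict import EasyDict  # assumed dependency; converts nested dicts
from deep_sort.deep_sort import build_tracker

with open("deep_sort/configs/deep_sort.yaml") as f:
    cfg = EasyDict(yaml.safe_load(f))   # enables cfg.DEEPSORT.REID_CKPT access

tracker = build_tracker(cfg, use_cuda=True)
```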
deep_sort/deep_sort/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (619 Bytes).
deep_sort/deep_sort/__pycache__/deep_sort.cpython-310.pyc ADDED
Binary file (4.16 kB).
deep_sort/deep_sort/deep/__init__.py ADDED
File without changes
deep_sort/deep_sort/deep/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes).
deep_sort/deep_sort/deep/__pycache__/feature_extractor.cpython-310.pyc ADDED
Binary file (2.58 kB).
deep_sort/deep_sort/deep/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.82 kB).
deep_sort/deep_sort/deep/checkpoint/ckpt.t7 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22628596f112dc7eb1fe7adfbfaf95bbc6ce8eb024205beafdc705232a646c29
+ size 46061055
deep_sort/deep_sort/deep/evaluate.py ADDED
@@ -0,0 +1,15 @@
+ import torch
+
+ features = torch.load("features.pth")
+ qf = features["qf"]
+ ql = features["ql"]
+ gf = features["gf"]
+ gl = features["gl"]
+
+ scores = qf.mm(gf.t())
+ res = scores.topk(5, dim=1)[1][:, 0]
+ top1correct = gl[res].eq(ql).sum().item()
+
+ print("Acc top1:{:.3f}".format(top1correct / ql.size(0)))
+
+
deep_sort/deep_sort/deep/feature_extractor.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+ import torchvision.transforms as transforms
+ import numpy as np
+ import cv2
+ import logging
+
+ from .model import Net
+
+ '''
+ Feature extractor:
+ Extracts the features inside each bounding box and produces a fixed-dimension
+ embedding that represents the box, to be used when computing similarity.
+
+ The model is trained in the traditional ReID fashion. The Extractor class takes
+ a list of images as input and returns the corresponding features.
+ '''
+
+ class Extractor(object):
+     def __init__(self, model_path, use_cuda=True):
+         self.net = Net(reid=True)
+         self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
+         state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict']
+         self.net.load_state_dict(state_dict)
+         logger = logging.getLogger("root.tracker")
+         logger.info("Loading weights from {}... Done!".format(model_path))
+         self.net.to(self.device)
+         self.size = (64, 128)
+         self.norm = transforms.Compose([
+             # RGB pixel values lie in [0, 255]; ToTensor divides by 255 to get [0, 1],
+             # then Normalize computes (x - mean)/std, mapping the data roughly to [-1, 1].
+             transforms.ToTensor(),
+             # mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] are computed on the ImageNet training set
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ])
+
+     def _preprocess(self, im_crops):
+         """
+         TODO:
+             1. to float with scale from 0 to 1
+             2. resize to (64, 128) as Market1501 dataset did
+             3. concatenate to a numpy array
+             3. to torch Tensor
+             4. normalize
+         """
+         def _resize(im, size):
+             return cv2.resize(im.astype(np.float32)/255., size)
+
+         im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
+         return im_batch
+
+     # __call__() lets a class instance be invoked like a plain function,
+     # i.e. "instance(args)" — similar to overloading the () operator.
+     def __call__(self, im_crops):
+         im_batch = self._preprocess(im_crops)
+         with torch.no_grad():
+             im_batch = im_batch.to(self.device)
+             features = self.net(im_batch)
+         return features.cpu().numpy()
+
+
+ if __name__ == '__main__':
+     img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
+     extr = Extractor("checkpoint/ckpt.t7")
+     feature = extr([img])  # the extractor expects a list of image crops
+     print(feature.shape)
+
deep_sort/deep_sort/deep/model.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class BasicBlock(nn.Module):
+     def __init__(self, c_in, c_out, is_downsample=False):
+         super(BasicBlock, self).__init__()
+         self.is_downsample = is_downsample
+         if is_downsample:
+             self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
+         else:
+             self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(c_out)
+         self.relu = nn.ReLU(True)
+         self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(c_out)
+         if is_downsample:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+                 nn.BatchNorm2d(c_out)
+             )
+         elif c_in != c_out:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+                 nn.BatchNorm2d(c_out)
+             )
+             self.is_downsample = True
+
+     def forward(self, x):
+         y = self.conv1(x)
+         y = self.bn1(y)
+         y = self.relu(y)
+         y = self.conv2(y)
+         y = self.bn2(y)
+         if self.is_downsample:
+             x = self.downsample(x)
+         return F.relu(x.add(y), True)
+
+ def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+     blocks = []
+     for i in range(repeat_times):
+         if i == 0:
+             blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+         else:
+             blocks += [BasicBlock(c_out, c_out), ]
+     return nn.Sequential(*blocks)
+
+ class Net(nn.Module):
+     def __init__(self, num_classes=751, reid=False):
+         super(Net, self).__init__()
+         # input: 3 x 128 x 64
+         self.conv = nn.Sequential(
+             nn.Conv2d(3, 64, 3, stride=1, padding=1),
+             nn.BatchNorm2d(64),
+             nn.ReLU(inplace=True),
+             # nn.Conv2d(32,32,3,stride=1,padding=1),
+             # nn.BatchNorm2d(32),
+             # nn.ReLU(inplace=True),
+             nn.MaxPool2d(3, 2, padding=1),
+         )
+         # 64 x 64 x 32
+         self.layer1 = make_layers(64, 64, 2, False)
+         # 64 x 64 x 32
+         self.layer2 = make_layers(64, 128, 2, True)
+         # 128 x 32 x 16
+         self.layer3 = make_layers(128, 256, 2, True)
+         # 256 x 16 x 8
+         self.layer4 = make_layers(256, 512, 2, True)
+         # 512 x 8 x 4
+         self.avgpool = nn.AvgPool2d((8, 4), 1)
+         # 512 x 1 x 1
+         self.reid = reid
+
+         self.classifier = nn.Sequential(
+             nn.Linear(512, 256),
+             nn.BatchNorm1d(256),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Linear(256, num_classes),
+         )
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)
+         # B x 512
+         if self.reid:
+             x = x.div(x.norm(p=2, dim=1, keepdim=True))
+             return x
+         # classifier
+         x = self.classifier(x)
+         return x
+
+
+ if __name__ == '__main__':
+     net = Net()
+     x = torch.randn(4, 3, 128, 64)
+     y = net(x)
+     import ipdb; ipdb.set_trace()
+
+
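A quick shape check of the ReID branch, following from the layer dimensions annotated above (a sketch using random weights, so no checkpoint is needed; run it from `deep_sort/deep_sort/deep/`):

```python
# Verify that Net(reid=True) maps a 128x64 crop to an L2-normalized 512-dim embedding.
import torch
from model import Net

net = Net(reid=True)
net.eval()                               # use BatchNorm running statistics
x = torch.randn(4, 3, 128, 64)           # batch of 4 crops, 3 x 128 x 64
with torch.no_grad():
    emb = net(x)
print(emb.shape)                         # torch.Size([4, 512])
print(emb.norm(p=2, dim=1))              # each row has unit L2 norm
```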
deep_sort/deep_sort/deep/original_model.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class BasicBlock(nn.Module):
+     def __init__(self, c_in, c_out, is_downsample=False):
+         super(BasicBlock, self).__init__()
+         self.is_downsample = is_downsample
+         if is_downsample:
+             self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
+         else:
+             self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(c_out)
+         self.relu = nn.ReLU(True)
+         self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(c_out)
+         if is_downsample:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+                 nn.BatchNorm2d(c_out)
+             )
+         elif c_in != c_out:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+                 nn.BatchNorm2d(c_out)
+             )
+             self.is_downsample = True
+
+     def forward(self, x):
+         y = self.conv1(x)
+         y = self.bn1(y)
+         y = self.relu(y)
+         y = self.conv2(y)
+         y = self.bn2(y)
+         if self.is_downsample:
+             x = self.downsample(x)
+         return F.relu(x.add(y), True)
+
+ def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+     blocks = []
+     for i in range(repeat_times):
+         if i == 0:
+             blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+         else:
+             blocks += [BasicBlock(c_out, c_out), ]
+     return nn.Sequential(*blocks)
+
+ class Net(nn.Module):
+     def __init__(self, num_classes=625, reid=False):
+         super(Net, self).__init__()
+         # input: 3 x 128 x 64
+         self.conv = nn.Sequential(
+             nn.Conv2d(3, 32, 3, stride=1, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ELU(inplace=True),
+             nn.Conv2d(32, 32, 3, stride=1, padding=1),
+             nn.BatchNorm2d(32),
+             nn.ELU(inplace=True),
+             nn.MaxPool2d(3, 2, padding=1),
+         )
+         # 32 x 64 x 32
+         self.layer1 = make_layers(32, 32, 2, False)
+         # 32 x 64 x 32
+         self.layer2 = make_layers(32, 64, 2, True)
+         # 64 x 32 x 16
+         self.layer3 = make_layers(64, 128, 2, True)
+         # 128 x 16 x 8
+         self.dense = nn.Sequential(
+             nn.Dropout(p=0.6),
+             nn.Linear(128*16*8, 128),
+             nn.BatchNorm1d(128),
+             nn.ELU(inplace=True)
+         )
+         # B x 128
+         self.reid = reid
+         self.batch_norm = nn.BatchNorm1d(128)
+         self.classifier = nn.Sequential(
+             nn.Linear(128, num_classes),
+         )
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+
+         x = x.view(x.size(0), -1)
+         if self.reid:
+             x = self.dense[0](x)
+             x = self.dense[1](x)
+             x = x.div(x.norm(p=2, dim=1, keepdim=True))
+             return x
+         x = self.dense(x)
+         # B x 128
+         # classifier
+         x = self.classifier(x)
+         return x
+
+
+ if __name__ == '__main__':
+     net = Net(reid=True)
+     x = torch.randn(4, 3, 128, 64)
+     y = net(x)
+     import ipdb; ipdb.set_trace()
+
+
deep_sort/deep_sort/deep/prepare_car.py ADDED
@@ -0,0 +1,129 @@
+ # -*- coding:utf8 -*-
+
+ import os
+ from PIL import Image
+ from shutil import copyfile, copytree, rmtree, move
+
+ PATH_DATASET = './car-dataset'  # folder to process
+ PATH_NEW_DATASET = './car-reid-dataset'  # folder for the processed result
+ PATH_ALL_IMAGES = PATH_NEW_DATASET + '/all_images'
+ PATH_TRAIN = PATH_NEW_DATASET + '/train'
+ PATH_TEST = PATH_NEW_DATASET + '/test'
+
+ # helper that creates a directory
+ def mymkdir(path):
+     path = path.strip()  # strip leading/trailing whitespace
+     path = path.rstrip("\\")  # strip a trailing backslash
+     isExists = os.path.exists(path)  # check whether the path exists
+     if not isExists:
+         os.makedirs(path)  # create the directory if it does not exist
+         print(path + ' created')
+         return True
+     else:
+         # the directory already exists: do not create it, just report
+         print(path + ' already exists')
+         return False
+
+ class BatchRename():
+     '''
+     Batch-rename the image files in a folder
+     '''
+
+     def __init__(self):
+         self.path = PATH_DATASET  # the folder whose files are renamed
+
+     # resize the images
+     def resize(self):
+         for aroot, dirs, files in os.walk(self.path):
+             # aroot walks all subdirectories of self.path (including self.path itself);
+             # dirs is the list of folders under the current directory.
+             filelist = files  # note: only the files directly under the current directory
+             # print('list', list)
+
+             # filelist = os.listdir(self.path)  # get the file paths
+             total_num = len(filelist)  # number of files
+
+             for item in filelist:
+                 if item.endswith('.jpg'):  # source images are jpg (for png or other formats, adjust the extension accordingly)
+                     src = os.path.join(os.path.abspath(aroot), item)
+
+                     # resize the image to 128 (width) x 256 (height)
+                     im = Image.open(src)
+                     out = im.resize((128, 256), Image.ANTIALIAS)  # resize with high quality (note: use Image.LANCZOS on Pillow >= 10, where ANTIALIAS was removed)
+                     out.save(src)  # save in place
+
+     def rename(self):
+
+         for aroot, dirs, files in os.walk(self.path):
+             # aroot walks all subdirectories of self.path (including self.path itself);
+             # dirs is the list of folders under the current directory.
+             filelist = files  # note: only the files directly under the current directory
+             # print('list', list)
+
+             # filelist = os.listdir(self.path)  # get the file paths
+             total_num = len(filelist)  # number of files
+
+             i = 1  # file numbering starts at 1
+             for item in filelist:
+                 if item.endswith('.jpg'):  # source images are jpg (for png or other formats, adjust the extension accordingly)
+                     src = os.path.join(os.path.abspath(aroot), item)
+
+                     # create an image directory named after the image
+                     dirname = str(item.split('_')[0])
+                     # create a directory per vehicle
+                     #new_dir = os.path.join(self.path, '..', 'bbox_all', dirname)
+                     new_dir = os.path.join(PATH_ALL_IMAGES, dirname)
+                     if not os.path.isdir(new_dir):
+                         mymkdir(new_dir)
+
+                     # number of images already in new_dir
+                     num_pic = len(os.listdir(new_dir))
+
+                     dst = os.path.join(os.path.abspath(new_dir),
+                                        dirname + 'C1T0001F' + str(num_pic + 1) + '.jpg')
+                     # the result is also jpg (png would work too); C1T0001F follows the filename
+                     # convention in mars.py: camera ID plus tracklet index
+                     # dst = os.path.join(os.path.abspath(self.path), '0000' + format(str(i), '0>3s') + '.jpg')
+                     # the line above would instead produce names like 0000000.jpg; customize as needed
+                     try:
+                         copyfile(src, dst)  # os.rename(src, dst)
+                         print('converting %s to %s ...' % (src, dst))
+                         i = i + 1
+                     except:
+                         continue
+             print('total %d to rename & converted %d jpgs' % (total_num, i))
+
+     def split(self):
+         #---------------------------------------
+         #train_test
+         images_path = PATH_ALL_IMAGES
+         train_save_path = PATH_TRAIN
+         test_save_path = PATH_TEST
+         if not os.path.isdir(train_save_path):
+             os.mkdir(train_save_path)
+             os.mkdir(test_save_path)
+
+         for _, dirs, _ in os.walk(images_path, topdown=True):
+             for i, dir in enumerate(dirs):
+                 for root, _, files in os.walk(images_path + '/' + dir, topdown=True):
+                     for j, file in enumerate(files):
+                         if(j == 0):  # test dataset: the first image of each vehicle
+                             print("index: %s folder: %s image: %s -> test set" % (i + 1, root, file))
+                             src_path = root + '/' + file
+                             dst_dir = test_save_path + '/' + dir
+                             if not os.path.isdir(dst_dir):
+                                 os.mkdir(dst_dir)
+                             dst_path = dst_dir + '/' + file
+                             move(src_path, dst_path)
+                         else:
+                             src_path = root + '/' + file
+                             dst_dir = train_save_path + '/' + dir
+                             if not os.path.isdir(dst_dir):
+                                 os.mkdir(dst_dir)
+                             dst_path = dst_dir + '/' + file
+                             move(src_path, dst_path)
+         rmtree(PATH_ALL_IMAGES)
+
+ if __name__ == '__main__':
+     demo = BatchRename()
+     demo.resize()
+     demo.rename()
+     demo.split()
+
+
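After the three steps above, the resulting layout (following the path constants at the top of the script) should be `./car-reid-dataset/train/<vehicle_id>/…` and `./car-reid-dataset/test/<vehicle_id>/…`, with the first image of each vehicle moved to the test split and the intermediate `all_images/` folder removed at the end.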
deep_sort/deep_sort/deep/prepare_person.py ADDED
@@ -0,0 +1,108 @@
+ import os
+ from shutil import copyfile
+
+ # You only need to change this line to your dataset download path
+ download_path = './Market-1501-v15.09.15'
+
+ if not os.path.isdir(download_path):
+     print('please change the download_path')
+
+ save_path = download_path + '/pytorch'
+ if not os.path.isdir(save_path):
+     os.mkdir(save_path)
+ #-----------------------------------------
+ #query
+ query_path = download_path + '/query'
+ query_save_path = download_path + '/pytorch/query'
+ if not os.path.isdir(query_save_path):
+     os.mkdir(query_save_path)
+
+ for root, dirs, files in os.walk(query_path, topdown=True):
+     for name in files:
+         if not name[-3:] == 'jpg':
+             continue
+         ID = name.split('_')
+         src_path = query_path + '/' + name
+         dst_path = query_save_path + '/' + ID[0]
+         if not os.path.isdir(dst_path):
+             os.mkdir(dst_path)
+         copyfile(src_path, dst_path + '/' + name)
+
+ #-----------------------------------------
+ #multi-query
+ query_path = download_path + '/gt_bbox'
+ # for dukemtmc-reid, we do not need multi-query
+ if os.path.isdir(query_path):
+     query_save_path = download_path + '/pytorch/multi-query'
+     if not os.path.isdir(query_save_path):
+         os.mkdir(query_save_path)
+
+     for root, dirs, files in os.walk(query_path, topdown=True):
+         for name in files:
+             if not name[-3:] == 'jpg':
+                 continue
+             ID = name.split('_')
+             src_path = query_path + '/' + name
+             dst_path = query_save_path + '/' + ID[0]
+             if not os.path.isdir(dst_path):
+                 os.mkdir(dst_path)
+             copyfile(src_path, dst_path + '/' + name)
+
+ #-----------------------------------------
+ #gallery
+ gallery_path = download_path + '/bounding_box_test'
+ gallery_save_path = download_path + '/pytorch/gallery'
+ if not os.path.isdir(gallery_save_path):
+     os.mkdir(gallery_save_path)
+
+ for root, dirs, files in os.walk(gallery_path, topdown=True):
+     for name in files:
+         if not name[-3:] == 'jpg':
+             continue
+         ID = name.split('_')
+         src_path = gallery_path + '/' + name
+         dst_path = gallery_save_path + '/' + ID[0]
+         if not os.path.isdir(dst_path):
+             os.mkdir(dst_path)
+         copyfile(src_path, dst_path + '/' + name)
+
+ #---------------------------------------
+ #train_all
+ train_path = download_path + '/bounding_box_train'
+ train_save_path = download_path + '/pytorch/train_all'
+ if not os.path.isdir(train_save_path):
+     os.mkdir(train_save_path)
+
+ for root, dirs, files in os.walk(train_path, topdown=True):
+     for name in files:
+         if not name[-3:] == 'jpg':
+             continue
+         ID = name.split('_')
+         src_path = train_path + '/' + name
+         dst_path = train_save_path + '/' + ID[0]
+         if not os.path.isdir(dst_path):
+             os.mkdir(dst_path)
+         copyfile(src_path, dst_path + '/' + name)
+
+
+ #---------------------------------------
+ #train_val
+ train_path = download_path + '/bounding_box_train'
+ train_save_path = download_path + '/pytorch/train'
+ val_save_path = download_path + '/pytorch/test'
+ if not os.path.isdir(train_save_path):
+     os.mkdir(train_save_path)
+     os.mkdir(val_save_path)
+
+ for root, dirs, files in os.walk(train_path, topdown=True):
+     for name in files:
+         if not name[-3:] == 'jpg':
+             continue
+         ID = name.split('_')
+         src_path = train_path + '/' + name
+         dst_path = train_save_path + '/' + ID[0]
+         if not os.path.isdir(dst_path):
+             os.mkdir(dst_path)
+             dst_path = val_save_path + '/' + ID[0]  # first image is used as val image
+             os.mkdir(dst_path)
+         copyfile(src_path, dst_path + '/' + name)
deep_sort/deep_sort/deep/test.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ import torch.backends.cudnn as cudnn
+ import torchvision
+
+ import argparse
+ import os
+
+ from model import Net
+
+ parser = argparse.ArgumentParser(description="Train on market1501")
+ parser.add_argument("--data-dir", default='data', type=str)
+ parser.add_argument("--no-cuda", action="store_true")
+ parser.add_argument("--gpu-id", default=0, type=int)
+ args = parser.parse_args()
+
+ # device
+ device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
+ if torch.cuda.is_available() and not args.no_cuda:
+     cudnn.benchmark = True
+
+ # data loader
+ root = args.data_dir
+ query_dir = os.path.join(root, "query")
+ gallery_dir = os.path.join(root, "gallery")
+ transform = torchvision.transforms.Compose([
+     torchvision.transforms.Resize((128, 64)),
+     torchvision.transforms.ToTensor(),
+     torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ queryloader = torch.utils.data.DataLoader(
+     torchvision.datasets.ImageFolder(query_dir, transform=transform),
+     batch_size=64, shuffle=False
+ )
+ galleryloader = torch.utils.data.DataLoader(
+     torchvision.datasets.ImageFolder(gallery_dir, transform=transform),
+     batch_size=64, shuffle=False
+ )
+
+ # net definition
+ net = Net(reid=True)
+ assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+ print('Loading from checkpoint/ckpt.t7')
+ checkpoint = torch.load("./checkpoint/ckpt.t7")
+ net_dict = checkpoint['net_dict']
+ net.load_state_dict(net_dict, strict=False)
+ net.eval()
+ net.to(device)
+
+ # compute features
+ query_features = torch.tensor([]).float()
+ query_labels = torch.tensor([]).long()
+ gallery_features = torch.tensor([]).float()
+ gallery_labels = torch.tensor([]).long()
+
+ with torch.no_grad():
+     for idx, (inputs, labels) in enumerate(queryloader):
+         inputs = inputs.to(device)
+         features = net(inputs).cpu()
+         query_features = torch.cat((query_features, features), dim=0)
+         query_labels = torch.cat((query_labels, labels))
+
+     for idx, (inputs, labels) in enumerate(galleryloader):
+         inputs = inputs.to(device)
+         features = net(inputs).cpu()
+         gallery_features = torch.cat((gallery_features, features), dim=0)
+         gallery_labels = torch.cat((gallery_labels, labels))
+
+ gallery_labels -= 2  # offset gallery labels (the gallery folder's junk/distractor classes sort first)
+
+ # save features
+ features = {
+     "qf": query_features,
+     "ql": query_labels,
+     "gf": gallery_features,
+     "gl": gallery_labels
+ }
+ torch.save(features, "features.pth")
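A typical invocation, based on the argparse block above (the path is illustrative): `python test.py --data-dir ./Market-1501-v15.09.15/pytorch`, where the data directory is expected to contain `query/` and `gallery/` ImageFolder-style subfolders, e.g. as produced by prepare_person.py. The script writes the query/gallery embeddings to `features.pth`, which evaluate.py above then reads to report top-1 accuracy.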
deep_sort/deep_sort/deep/train.jpg ADDED
deep_sort/deep_sort/deep/train.py ADDED
@@ -0,0 +1,192 @@
+ import argparse
+ import os
+ import time
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.backends.cudnn as cudnn
+ import torchvision
+
+ from model import Net
+
+ parser = argparse.ArgumentParser(description="Train on market1501")
+ parser.add_argument("--data-dir", default='data', type=str)
+ parser.add_argument("--no-cuda", action="store_true")
+ parser.add_argument("--gpu-id", default=0, type=int)
+ parser.add_argument("--lr", default=0.1, type=float)
+ parser.add_argument("--interval", '-i', default=20, type=int)
+ parser.add_argument('--resume', '-r', action='store_true')
+ args = parser.parse_args()
+
+ # device
+ device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
+ if torch.cuda.is_available() and not args.no_cuda:
+     cudnn.benchmark = True
+
+ # data loading
+ root = args.data_dir
+ train_dir = os.path.join(root, "train")
+ test_dir = os.path.join(root, "test")
+
+ transform_train = torchvision.transforms.Compose([
+     torchvision.transforms.RandomCrop((128, 64), padding=4),
+     torchvision.transforms.RandomHorizontalFlip(),
+     torchvision.transforms.ToTensor(),
+     torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ transform_test = torchvision.transforms.Compose([
+     torchvision.transforms.Resize((128, 64)),
+     torchvision.transforms.ToTensor(),
+     torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ trainloader = torch.utils.data.DataLoader(
+     torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
+     batch_size=64, shuffle=True
+ )
+ testloader = torch.utils.data.DataLoader(
+     torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
+     batch_size=64, shuffle=True
+ )
+ num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
+ print("num_classes = %s" % num_classes)
+
+ # net definition
+ start_epoch = 0
+ net = Net(num_classes=num_classes)
+ if args.resume:
+     assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+     print('Loading from checkpoint/ckpt.t7')
+     checkpoint = torch.load("./checkpoint/ckpt.t7")
+     # import ipdb; ipdb.set_trace()
+     net_dict = checkpoint['net_dict']
+     net.load_state_dict(net_dict)
+     best_acc = checkpoint['acc']
+     start_epoch = checkpoint['epoch']
+ net.to(device)
+
+ # loss and optimizer
+ criterion = torch.nn.CrossEntropyLoss()
+ optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
+ best_acc = 0.
+
+ # train function for each epoch
+ def train(epoch):
+     print("\nEpoch : %d" % (epoch + 1))
+     net.train()
+     training_loss = 0.
+     train_loss = 0.
+     correct = 0
+     total = 0
+     interval = args.interval
+     start = time.time()
+     for idx, (inputs, labels) in enumerate(trainloader):
+         # forward
+         inputs, labels = inputs.to(device), labels.to(device)
+         outputs = net(inputs)
+         loss = criterion(outputs, labels)
+
+         # backward
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         # accumulating
+         training_loss += loss.item()
+         train_loss += loss.item()
+         correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+         total += labels.size(0)
+
+         # print
+         if (idx + 1) % interval == 0:
+             end = time.time()
+             print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+                 100. * (idx + 1) / len(trainloader), end - start, training_loss / interval, correct, total, 100. * correct / total
+             ))
+             training_loss = 0.
+             start = time.time()
+
+     return train_loss / len(trainloader), 1. - correct / total
+
+ def test(epoch):
+     global best_acc
+     net.eval()
+     test_loss = 0.
+     correct = 0
+     total = 0
+     start = time.time()
+     with torch.no_grad():
+         for idx, (inputs, labels) in enumerate(testloader):
+             inputs, labels = inputs.to(device), labels.to(device)
+             outputs = net(inputs)
+             loss = criterion(outputs, labels)
+
+             test_loss += loss.item()
+             correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+             total += labels.size(0)
+
+         print("Testing ...")
+         end = time.time()
+         print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+             100. * (idx + 1) / len(testloader), end - start, test_loss / len(testloader), correct, total, 100. * correct / total
+         ))
+
+     # saving checkpoint
+     acc = 100. * correct / total
+     if acc > best_acc:
+         best_acc = acc
+         print("Saving parameters to checkpoint/ckpt.t7")
+         checkpoint = {
+             'net_dict': net.state_dict(),
+             'acc': acc,
+             'epoch': epoch,
+         }
+         if not os.path.isdir('checkpoint'):
+             os.mkdir('checkpoint')
+         torch.save(checkpoint, './checkpoint/ckpt.t7')
+
+     return test_loss / len(testloader), 1. - correct / total
+
+ # plot figure
+ x_epoch = []
+ record = {'train_loss': [], 'train_err': [], 'test_loss': [], 'test_err': []}
+ fig = plt.figure()
+ ax0 = fig.add_subplot(121, title="loss")
+ ax1 = fig.add_subplot(122, title="top1err")
+ def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
+     global record
+     record['train_loss'].append(train_loss)
+     record['train_err'].append(train_err)
+     record['test_loss'].append(test_loss)
+     record['test_err'].append(test_err)
+
+     x_epoch.append(epoch)
+     ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
+     ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
+     ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
+     ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
+     if epoch == 0:
+         ax0.legend()
+         ax1.legend()
+     fig.savefig("train.jpg")
+
+ # lr decay
+ def lr_decay():
+     global optimizer
+     for params in optimizer.param_groups:
+         params['lr'] *= 0.1
+         lr = params['lr']
+         print("Learning rate adjusted to {}".format(lr))
+
+ def main():
+     total_epoches = 40
+     for epoch in range(start_epoch, start_epoch + total_epoches):
+         train_loss, train_err = train(epoch)
+         test_loss, test_err = test(epoch)
+         draw_curve(epoch, train_loss, train_err, test_loss, test_err)
+         if (epoch + 1) % (total_epoches // 2) == 0:
+             lr_decay()
+
+
+ if __name__ == '__main__':
+     main()
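A typical run, using the arguments defined by the argparse block above: `python train.py --data-dir data --lr 0.1`, optionally with `--resume` to continue from `./checkpoint/ckpt.t7`. The best-accuracy checkpoint is saved to `./checkpoint/ckpt.t7`, and the loss/top-1-error curves are written to `train.jpg`.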
deep_sort/deep_sort/deep_sort.py ADDED
@@ -0,0 +1,125 @@
+ import numpy as np
+ import torch
+
+ from .deep.feature_extractor import Extractor
+ from .sort.nn_matching import NearestNeighborDistanceMetric
+ from .sort.preprocessing import non_max_suppression
+ from .sort.detection import Detection
+ from .sort.tracker import Tracker
+
+
+ __all__ = ['DeepSort']  # __all__ is the "whitelist" of the exposed interface
+
+
+ class DeepSort(object):
+     def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, max_age=70, n_init=3, nn_budget=100, use_cuda=True):
+         self.min_confidence = min_confidence  # confidence threshold for detections
+         self.nms_max_overlap = nms_max_overlap  # non-maximum suppression threshold; 1.0 disables suppression
+
+         self.extractor = Extractor(model_path, use_cuda=use_cuda)  # extracts the features of a batch of image crops
+
+         max_cosine_distance = max_dist  # maximum cosine distance for cascade matching; matches above this threshold are ignored
+         nn_budget = 100  # maximum number of appearance descriptors kept per identity; older ones are dropped when exceeded
+         # NearestNeighborDistanceMetric: for each target, returns the smallest distance
+         # (Euclidean or cosine) to any sample observed so far.
+         # A Tracker is constructed from this distance metric.
+         # The first argument is either 'cosine' or 'euclidean'.
+         self.metric = NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
+         self.tracker = Tracker(self.metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init)
+
+     def update(self, bbox_xywh, confidences, ori_img):
+         self.height, self.width = ori_img.shape[:2]
+         # generate detections
+         # crop the bbox regions out of the original image and compute their features
+         features = self._get_features(bbox_xywh, ori_img)
+         bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
+         # drop detections below min_confidence and build a list of Detection objects
+         detections = [Detection(bbox_tlwh[i], conf, features[i]) for i, conf in enumerate(confidences) if conf > self.min_confidence]
+
+         # run non-maximum suppression
+         boxes = np.array([d.tlwh for d in detections])
+         scores = np.array([d.confidence for d in detections])
+         indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
+         detections = [detections[i] for i in indices]
+
+         # update tracker
+         self.tracker.predict()  # propagate the track state distributions one time step forward
+         self.tracker.update(detections)  # run measurement update and track management
+
+         # output bbox identities
+         outputs = []
+         for track in self.tracker.tracks:
+             if not track.is_confirmed() or track.time_since_update > 1:
+                 continue
+             box = track.to_tlwh()
+             x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
+             track_id = track.track_id
+             outputs.append(np.array([x1, y1, x2, y2, track_id], dtype=np.int16))
+         if len(outputs) > 0:
+             outputs = np.stack(outputs, axis=0)
+         return outputs
+
+
+     """
+     TODO:
+         Convert bbox from xc_yc_w_h to xtl_ytl_w_h
+     Thanks JieChen91@github.com for reporting this bug!
+     """
+     # convert a bbox from [center x, center y, w, h] to [top-left x, top-left y, w, h]
+     @staticmethod
+     def _xywh_to_tlwh(bbox_xywh):
+         if isinstance(bbox_xywh, np.ndarray):
+             bbox_tlwh = bbox_xywh.copy()
+         elif isinstance(bbox_xywh, torch.Tensor):
+             bbox_tlwh = bbox_xywh.clone()
+         bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2] / 2.
+         bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3] / 2.
+         return bbox_tlwh
+
+     # convert a bbox from [center x, center y, w, h] to [x1, y1, x2, y2]
+     # (some datasets, e.g. pascal_voc, annotate boxes as [x, y, w, h])
+     """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
+     def _xywh_to_xyxy(self, bbox_xywh):
+         x, y, w, h = bbox_xywh
+         x1 = max(int(x - w / 2), 0)
+         x2 = min(int(x + w / 2), self.width - 1)
+         y1 = max(int(y - h / 2), 0)
+         y2 = min(int(y + h / 2), self.height - 1)
+         return x1, y1, x2, y2
+
+     def _tlwh_to_xyxy(self, bbox_tlwh):
+         """
+         Convert bbox from xtl_ytl_w_h to x1_y1_x2_y2.
+         Thanks JieChen91@github.com for reporting this bug!
+         """
+         x, y, w, h = bbox_tlwh
+         x1 = max(int(x), 0)
+         x2 = min(int(x + w), self.width - 1)
+         y1 = max(int(y), 0)
+         y2 = min(int(y + h), self.height - 1)
+         return x1, y1, x2, y2
+
+     def _xyxy_to_tlwh(self, bbox_xyxy):
+         x1, y1, x2, y2 = bbox_xyxy
+
+         t = x1
+         l = y1
+         w = int(x2 - x1)
+         h = int(y2 - y1)
+         return t, l, w, h
+
+     # extract features for the cropped regions
+     def _get_features(self, bbox_xywh, ori_img):
+         im_crops = []
+         for box in bbox_xywh:
+             x1, y1, x2, y2 = self._xywh_to_xyxy(box)
+             im = ori_img[y1:y2, x1:x2]  # cropped region
+             im_crops.append(im)
+         if im_crops:
+             features = self.extractor(im_crops)  # extract features for the crops
+         else:
+             features = np.array([])
+         return features
+
+
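A minimal sketch of driving `DeepSort.update` directly with pre-computed detections (the frame path and box values are illustrative; as `_xywh_to_tlwh` above implies, each row of `bbox_xywh` is center x, center y, width, height):

```python
import cv2
import numpy as np
from deep_sort.deep_sort.deep_sort import DeepSort

tracker = DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
frame = cv2.imread("demo.png")                  # any BGR frame
bbox_xywh = np.array([[320, 240, 60, 160]])     # one hypothetical person box
confidences = [0.9]

# Note: a new track is only output once confirmed (after n_init consecutive
# frames), so the first few calls may return an empty result.
outputs = tracker.update(bbox_xywh, confidences, frame)
for x1, y1, x2, y2, track_id in outputs:        # corner coordinates + track id
    print(track_id, (x1, y1, x2, y2))
```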
deep_sort/deep_sort/sort/__init__.py ADDED
File without changes
deep_sort/deep_sort/sort/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (168 Bytes).
deep_sort/deep_sort/sort/__pycache__/detection.cpython-310.pyc ADDED
Binary file (1.91 kB).
deep_sort/deep_sort/sort/__pycache__/iou_matching.cpython-310.pyc ADDED
Binary file (2.95 kB).
deep_sort/deep_sort/sort/__pycache__/kalman_filter.cpython-310.pyc ADDED
Binary file (7.95 kB).
deep_sort/deep_sort/sort/__pycache__/linear_assignment.cpython-310.pyc ADDED
Binary file (8.19 kB).
deep_sort/deep_sort/sort/__pycache__/nn_matching.cpython-310.pyc ADDED
Binary file (7.45 kB).
deep_sort/deep_sort/sort/__pycache__/preprocessing.cpython-310.pyc ADDED
Binary file (1.92 kB).
deep_sort/deep_sort/sort/__pycache__/track.cpython-310.pyc ADDED
Binary file (6.89 kB).
deep_sort/deep_sort/sort/__pycache__/tracker.cpython-310.pyc ADDED
Binary file (5.71 kB).
deep_sort/deep_sort/sort/detection.py ADDED
@@ -0,0 +1,49 @@
+ # vim: expandtab:ts=4:sw=4
+ import numpy as np
+
+
+ class Detection(object):
+     """
+     This class represents a bounding box detection in a single image.
+
+     Parameters
+     ----------
+     tlwh : array_like
+         Bounding box in format `(top left x, top left y, width, height)`.
+     confidence : float
+         Detector confidence score.
+     feature : array_like
+         A feature vector that describes the object contained in this image.
+
+     Attributes
+     ----------
+     tlwh : ndarray
+         Bounding box in format `(top left x, top left y, width, height)`.
+     confidence : ndarray
+         Detector confidence score.
+     feature : ndarray | NoneType
+         A feature vector that describes the object contained in this image.
+
+     """
+
+     def __init__(self, tlwh, confidence, feature):
+         self.tlwh = np.asarray(tlwh, dtype=np.float32)
+         self.confidence = float(confidence)
+         self.feature = np.asarray(feature, dtype=np.float32)
+
+     def to_tlbr(self):
+         """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
+         `(top left, bottom right)`.
+         """
+         ret = self.tlwh.copy()
+         ret[2:] += ret[:2]
+         return ret
+
+     def to_xyah(self):
+         """Convert bounding box to format `(center x, center y, aspect ratio,
+         height)`, where the aspect ratio is `width / height`.
+         """
+         ret = self.tlwh.copy()
+         ret[:2] += ret[2:] / 2
+         ret[2] /= ret[3]
+         return ret
deep_sort/deep_sort/sort/iou_matching.py ADDED
@@ -0,0 +1,84 @@
+ # vim: expandtab:ts=4:sw=4
+ from __future__ import absolute_import
+ import numpy as np
+ from . import linear_assignment
+
+ # compute the IOU of two boxes
+ def iou(bbox, candidates):
+     """Compute intersection over union.
+
+     Parameters
+     ----------
+     bbox : ndarray
+         A bounding box in format `(top left x, top left y, width, height)`.
+     candidates : ndarray
+         A matrix of candidate bounding boxes (one per row) in the same format
+         as `bbox`.
+
+     Returns
+     -------
+     ndarray
+         The intersection over union in [0, 1] between the `bbox` and each
+         candidate. A higher score means a larger fraction of the `bbox` is
+         occluded by the candidate.
+
+     """
+     bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:]
+     candidates_tl = candidates[:, :2]
+     candidates_br = candidates[:, :2] + candidates[:, 2:]
+
+     # np.c_ Translates slice objects to concatenation along the second axis.
+     tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis],
+                np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]]
+     br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis],
+                np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]]
+     wh = np.maximum(0., br - tl)
+
+     area_intersection = wh.prod(axis=1)
+     area_bbox = bbox[2:].prod()
+     area_candidates = candidates[:, 2:].prod(axis=1)
+     return area_intersection / (area_bbox + area_candidates - area_intersection)
+
+ # compute the IOU distance cost matrix between tracks and detections
+ def iou_cost(tracks, detections, track_indices=None,
+              detection_indices=None):
+     """An intersection over union distance metric.
+
+     Computes the IOU distance matrix between tracks and detections.
+
+     Parameters
+     ----------
+     tracks : List[deep_sort.track.Track]
+         A list of tracks.
+     detections : List[deep_sort.detection.Detection]
+         A list of detections.
+     track_indices : Optional[List[int]]
+         A list of indices to tracks that should be matched. Defaults to
+         all `tracks`.
+     detection_indices : Optional[List[int]]
+         A list of indices to detections that should be matched. Defaults
+         to all `detections`.
+
+     Returns
+     -------
+     ndarray
+         Returns a cost matrix of shape
+         len(track_indices), len(detection_indices) where entry (i, j) is
+         `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
+
+     """
+     if track_indices is None:
+         track_indices = np.arange(len(tracks))
+     if detection_indices is None:
+         detection_indices = np.arange(len(detections))
+
+     cost_matrix = np.zeros((len(track_indices), len(detection_indices)))
+     for row, track_idx in enumerate(track_indices):
+         if tracks[track_idx].time_since_update > 1:
+             cost_matrix[row, :] = linear_assignment.INFTY_COST
+             continue
+
+         bbox = tracks[track_idx].to_tlwh()
+         candidates = np.asarray([detections[i].tlwh for i in detection_indices])
+         cost_matrix[row, :] = 1. - iou(bbox, candidates)
+     return cost_matrix
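A small worked example for `iou()` (boxes in the tlwh format the docstring describes; run from the repo root so the package imports resolve):

```python
import numpy as np
from deep_sort.deep_sort.sort.iou_matching import iou

bbox = np.array([0., 0., 10., 10.])                  # top-left (0,0), 10x10
candidates = np.array([[5., 5., 10., 10.],           # overlaps in a 5x5 patch
                       [20., 20., 10., 10.]])        # disjoint
print(iou(bbox, candidates))   # intersection 25, union 175 -> [~0.1429, 0.0]
```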
deep_sort/deep_sort/sort/kalman_filter.py ADDED
@@ -0,0 +1,286 @@
+ # vim: expandtab:ts=4:sw=4
+ import numpy as np
+ import scipy.linalg
+
+
+ """
+ Table for the 0.95 quantile of the chi-square distribution with N degrees of
+ freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
+ function and used as Mahalanobis gating threshold.
+ """
+ chi2inv95 = {
+     1: 3.8415,
+     2: 5.9915,
+     3: 7.8147,
+     4: 9.4877,
+     5: 11.070,
+     6: 12.592,
+     7: 14.067,
+     8: 15.507,
+     9: 16.919}
+
+ '''
+ The Kalman filter works in two phases:
+ (1) predict the track's position at the next time step,
+ (2) update the predicted position based on the detection.
+ '''
+ class KalmanFilter(object):
+     """
+     A simple Kalman filter for tracking bounding boxes in image space.
+
+     The 8-dimensional state space
+
+         x, y, a, h, vx, vy, va, vh
+
+     contains the bounding box center position (x, y), aspect ratio a, height h,
+     and their respective velocities.
+
+     Object motion follows a constant velocity model. The bounding box location
+     (x, y, a, h) is taken as direct observation of the state space (linear
+     observation model).
+
+     For each track, a KalmanFilter predicts the state distribution. Each track
+     keeps its own mean and covariance as the filter's input.
+
+     """
+
+     def __init__(self):
+         ndim, dt = 4, 1.
+
+         # Create Kalman filter model matrices.
+         self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+         for i in range(ndim):
+             self._motion_mat[i, ndim + i] = dt
+         self._update_mat = np.eye(ndim, 2 * ndim)
+
+         # Motion and observation uncertainty are chosen relative to the current
+         # state estimate. These weights control the amount of uncertainty in
+         # the model. This is a bit hacky.
+         self._std_weight_position = 1. / 20
+         self._std_weight_velocity = 1. / 160
+
+     def initiate(self, measurement):
+         """Create track from unassociated measurement.
+
+         Parameters
+         ----------
+         measurement : ndarray
+             Bounding box coordinates (x, y, a, h) with center position (x, y),
+             aspect ratio a, and height h.
+
+         Returns
+         -------
+         (ndarray, ndarray)
+             Returns the mean vector (8 dimensional) and covariance matrix (8x8
+             dimensional) of the new track. Unobserved velocities are initialized
+             to 0 mean.
+
+         """
+
+
+         mean_pos = measurement
+         mean_vel = np.zeros_like(mean_pos)
+         # Translates slice objects to concatenation along the first axis
+         mean = np.r_[mean_pos, mean_vel]
+
+         # initialize the mean vector (8-dim) and covariance matrix (8x8) from the measurement
+         std = [
+             2 * self._std_weight_position * measurement[3],
+             2 * self._std_weight_position * measurement[3],
+             1e-2,
+             2 * self._std_weight_position * measurement[3],
+             10 * self._std_weight_velocity * measurement[3],
+             10 * self._std_weight_velocity * measurement[3],
+             1e-5,
+             10 * self._std_weight_velocity * measurement[3]]
+         covariance = np.diag(np.square(std))
+         return mean, covariance
+
+     def predict(self, mean, covariance):
+         """Run Kalman filter prediction step.
+
+         Parameters
+         ----------
+         mean : ndarray
+             The 8 dimensional mean vector of the object state at the previous
+             time step.
+         covariance : ndarray
+             The 8x8 dimensional covariance matrix of the object state at the
+             previous time step.
+
+         Returns
+         -------
+         (ndarray, ndarray)
+             Returns the mean vector and covariance matrix of the predicted
+             state. Unobserved velocities are initialized to 0 mean.
+
+         """
+         # The Kalman filter predicts from the target's mean and covariance at the previous time step.
+         std_pos = [
+             self._std_weight_position * mean[3],
+             self._std_weight_position * mean[3],
+             1e-2,
+             self._std_weight_position * mean[3]]
+         std_vel = [
+             self._std_weight_velocity * mean[3],
+             self._std_weight_velocity * mean[3],
+             1e-5,
+             self._std_weight_velocity * mean[3]]
+
+         # initialize the noise matrix Q; np.r_ concatenates the two arrays
+         # motion_cov is Q_k, the covariance matrix of the process noise W_k
+         motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+         # Update time state x' = Fx (1)
+         # x is the track's mean at time t-1; F is the state transition matrix; this predicts x' at time t
+         # self._motion_mat is F_k, the state transition model applied to x_{k-1}
+         mean = np.dot(self._motion_mat, mean)
+         # Calculate error covariance P' = FPF^T + Q (2)
+         # P is the track's covariance at t-1; Q is the system noise matrix, representing how reliable
+         # the whole system is, usually initialized to small values; this predicts P' at time t
+         # covariance is P_{k|k}, the (posterior) error covariance, measuring how precise the estimate is
+         covariance = np.linalg.multi_dot((
+             self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
+
+         return mean, covariance
+
+ def project(self, mean, covariance):
151
+ """Project state distribution to measurement space.
152
+ 投影状态分布到测量空间
153
+
154
+ Parameters
155
+ ----------
156
+ mean : ndarray
157
+ The state's mean vector (8 dimensional array).
158
+ covariance : ndarray
159
+ The state's covariance matrix (8x8 dimensional).
160
+
161
+ mean:ndarray,状态的平均向量(8维数组)。
162
+ covariance:ndarray,状态的协方差矩阵(8x8维)。
163
+
164
+ Returns
165
+ -------
166
+ (ndarray, ndarray)
167
+ Returns the projected mean and covariance matrix of the given state
168
+ estimate.
169
+
170
+ 返回(ndarray,ndarray),返回给定状态估计的投影平均值和协方差矩阵
171
+
172
+ """
173
+ # 在公式4中,R为检测器的噪声矩阵,它是一个4x4的对角矩阵,
174
+ # 对角线上的值分别为中心点两个坐标以及宽高的噪声,
175
+ # 以任意值初始化,一般设置宽高的噪声大于中心点的噪声,
176
+ # 该公式先将协方差矩阵P'映射到检测空间,然后再加上噪声矩阵R;
177
+ std = [
178
+ self._std_weight_position * mean[3],
179
+ self._std_weight_position * mean[3],
180
+ 1e-1,
181
+ self._std_weight_position * mean[3]]
182
+
183
+ # R为测量过程中噪声的协方差;初始化噪声矩阵R
184
+ innovation_cov = np.diag(np.square(std))
185
+
186
+ # 将均值向量映射到检测空间,即 Hx'
187
+ mean = np.dot(self._update_mat, mean)
188
+ # 将协方差矩阵映射到检测空间,即 HP'H^T
189
+ covariance = np.linalg.multi_dot((
190
+ self._update_mat, covariance, self._update_mat.T))
191
+ return mean, covariance + innovation_cov # 公式(4)
192
+
193
+ def update(self, mean, covariance, measurement):
194
+ """Run Kalman filter correction step.
195
+ 通过估计值和观测值估计最新结果
196
+
197
+ Parameters
198
+ ----------
199
+ mean : ndarray
200
+ The predicted state's mean vector (8 dimensional).
201
+ covariance : ndarray
202
+ The state's covariance matrix (8x8 dimensional).
203
+ measurement : ndarray
204
+ The 4 dimensional measurement vector (x, y, a, h), where (x, y)
205
+ is the center position, a the aspect ratio, and h the height of the
206
+ bounding box.
207
+
208
+ Returns
209
+ -------
210
+ (ndarray, ndarray)
211
+ Returns the measurement-corrected state distribution.
212
+
213
+ """
214
+ # 将均值和协方差映射到检测空间,得到 Hx'和S
215
+ projected_mean, projected_cov = self.project(mean, covariance)
216
+
217
+ # 矩阵分解
218
+ chol_factor, lower = scipy.linalg.cho_factor(
219
+ projected_cov, lower=True, check_finite=False)
220
+ # 计算卡尔曼增益K;相当于求解公式(5)
221
+ # 公式5计算卡尔曼增益K,卡尔曼增益用于估计误差的重要程度
222
+ # 求解卡尔曼滤波增益K 用到了cholesky矩阵分解加快求解;
223
+ # 公式5的右边有一个S的逆,如果S矩阵很大,S的逆求解消耗时间太大,
224
+ # 所以代码中把公式两边同时乘上S,右边的S*S的逆变成了单位矩阵,转化成AX=B形式求解。
225
+ kalman_gain = scipy.linalg.cho_solve(
226
+ (chol_factor, lower), np.dot(covariance, self._update_mat.T).T,
227
+ check_finite=False).T
228
+ # y = z - Hx' (3)
229
+ # 在公式3中,z为detection的均值向量,不包含速度变化值,即z=[cx, cy, r, h],
230
+ # H称为测量矩阵,它将track的均值向量x'映射到检测空间,该公式计算detection和track的均值误差
231
+ innovation = measurement - projected_mean
232
+
233
+ # 更新后的均值向量 x = x' + Ky (6)
234
+ new_mean = mean + np.dot(innovation, kalman_gain.T)
235
+ # 更新后的协方差矩阵 P = (I - KH)P' (7)
236
+ new_covariance = covariance - np.linalg.multi_dot((
237
+ kalman_gain, projected_cov, kalman_gain.T))
238
+ return new_mean, new_covariance
239
+
240
+ def gating_distance(self, mean, covariance, measurements,
241
+ only_position=False):
242
+ """Compute gating distance between state distribution and measurements.
243
+
244
+ A suitable distance threshold can be obtained from `chi2inv95`. If
245
+ `only_position` is False, the chi-square distribution has 4 degrees of
246
+ freedom, otherwise 2.
247
+
248
+ Parameters
249
+ ----------
250
+ mean : ndarray
251
+ Mean vector over the state distribution (8 dimensional).
252
+ 状态分布上的平均向量(8维)
253
+ covariance : ndarray
254
+ Covariance of the state distribution (8x8 dimensional).
255
+ 状态分布的协方差(8x8维)
256
+ measurements : ndarray
257
+ An Nx4 dimensional matrix of N measurements, each in
258
+ format (x, y, a, h) where (x, y) is the bounding box center
259
+ position, a the aspect ratio, and h the height.
260
+ N 个测量的 N×4维矩阵,每个矩阵的格式为(x,y,a,h),其中(x,y)是边界框中心位置,宽高比和h高度。
261
+ only_position : Optional[bool]
262
+ If True, distance computation is done with respect to the bounding
263
+ box center position only.
264
+ 如果为True,则只计算盒子中心位置
265
+
266
+ Returns
267
+ -------
268
+ ndarray
269
+ Returns an array of length N, where the i-th element contains the
270
+ squared Mahalanobis distance between (mean, covariance) and
271
+ `measurements[i]`.
272
+ 返回一个长度为N的数组,其中第i个元素包含(mean,covariance)和measurements [i]之间的平方Mahalanobis距离
273
+
274
+ """
275
+ mean, covariance = self.project(mean, covariance)
276
+ if only_position:
277
+ mean, covariance = mean[:2], covariance[:2, :2]
278
+ measurements = measurements[:, :2]
279
+
280
+ cholesky_factor = np.linalg.cholesky(covariance)
281
+ d = measurements - mean
282
+ z = scipy.linalg.solve_triangular(
283
+ cholesky_factor, d.T, lower=True, check_finite=False,
284
+ overwrite_b=True)
285
+ squared_maha = np.sum(z * z, axis=0)
286
+ return squared_maha
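A minimal usage sketch of the filter defined above (the measurement values here are made up; in the tracker they come from detections via `to_xyah()`, and the import path assumes the repository root is on PYTHONPATH):

    import numpy as np
    from deep_sort.deep_sort.sort.kalman_filter import KalmanFilter

    kf = KalmanFilter()
    # Start a track from a first (x, y, a, h) measurement.
    mean, cov = kf.initiate(np.array([320., 240., 0.5, 120.]))
    # One predict/update cycle with a new measurement of the same object.
    mean, cov = kf.predict(mean, cov)
    mean, cov = kf.update(mean, cov, np.array([324., 243., 0.5, 118.]))
    print(mean[:4])  # corrected (x, y, a, h)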
deep_sort/deep_sort/sort/linear_assignment.py ADDED
@@ -0,0 +1,240 @@
+ # vim: expandtab:ts=4:sw=4
+ from __future__ import absolute_import
+ import numpy as np
+ # The linear sum assignment problem is also known as minimum weight matching
+ # in bipartite graphs.
+ from scipy.optimize import linear_sum_assignment as linear_assignment
+ from . import kalman_filter
+
+
+ INFTY_COST = 1e+5
+
+ # min_cost_matching solves the linear assignment problem with the Hungarian
+ # algorithm. The distance metric passed in is either the gated cosine-distance
+ # cost or the IoU cost.
+ def min_cost_matching(
+         distance_metric, max_distance, tracks, detections, track_indices=None,
+         detection_indices=None):
+     """Solve linear assignment problem.
+
+     Parameters
+     ----------
+     distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
+         The distance metric is given a list of tracks and detections as well as
+         a list of N track indices and M detection indices. The metric should
+         return the NxM dimensional cost matrix, where element (i, j) is the
+         association cost between the i-th track in the given track indices and
+         the j-th detection in the given detection_indices.
+     max_distance : float
+         Gating threshold. Associations with cost larger than this value are
+         disregarded.
+     tracks : List[track.Track]
+         A list of predicted tracks at the current time step.
+     detections : List[detection.Detection]
+         A list of detections at the current time step.
+     track_indices : List[int]
+         List of track indices that maps rows in `cost_matrix` to tracks in
+         `tracks` (see description above).
+     detection_indices : List[int]
+         List of detection indices that maps columns in `cost_matrix` to
+         detections in `detections` (see description above).
+
+     Returns
+     -------
+     (List[(int, int)], List[int], List[int])
+         Returns a tuple with the following three entries:
+         * A list of matched track and detection indices.
+         * A list of unmatched track indices.
+         * A list of unmatched detection indices.
+
+     """
+     if track_indices is None:
+         track_indices = np.arange(len(tracks))
+     if detection_indices is None:
+         detection_indices = np.arange(len(detections))
+
+     if len(detection_indices) == 0 or len(track_indices) == 0:
+         return [], track_indices, detection_indices  # Nothing to match.
+
+     # Compute the cost matrix.
+     cost_matrix = distance_metric(
+         tracks, detections, track_indices, detection_indices)
+     cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
+
+     # Run the Hungarian algorithm; row indices map into tracks and column
+     # indices into detections.
+     row_indices, col_indices = linear_assignment(cost_matrix)
+
+     matches, unmatched_tracks, unmatched_detections = [], [], []
+     # Collect the unmatched detections.
+     for col, detection_idx in enumerate(detection_indices):
+         if col not in col_indices:
+             unmatched_detections.append(detection_idx)
+     # Collect the unmatched tracks.
+     for row, track_idx in enumerate(track_indices):
+         if row not in row_indices:
+             unmatched_tracks.append(track_idx)
+     # Walk the matched (track, detection) index pairs.
+     for row, col in zip(row_indices, col_indices):
+         track_idx = track_indices[row]
+         detection_idx = detection_indices[col]
+         # A pair whose cost exceeds max_distance is also treated as unmatched.
+         if cost_matrix[row, col] > max_distance:
+             unmatched_tracks.append(track_idx)
+             unmatched_detections.append(detection_idx)
+         else:
+             matches.append((track_idx, detection_idx))
+     return matches, unmatched_tracks, unmatched_detections
+
+
+ def matching_cascade(
+         distance_metric, max_distance, cascade_depth, tracks, detections,
+         track_indices=None, detection_indices=None):
+     """Run matching cascade.
+
+     Parameters
+     ----------
+     distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
+         The distance metric is given a list of tracks and detections as well as
+         a list of N track indices and M detection indices. The metric should
+         return the NxM dimensional cost matrix, where element (i, j) is the
+         association cost between the i-th track in the given track indices and
+         the j-th detection in the given detection indices.
+     max_distance : float
+         Gating threshold. Associations with cost larger than this value are
+         disregarded.
+     cascade_depth: int
+         The cascade depth, should be set to the maximum track age.
+     tracks : List[track.Track]
+         A list of predicted tracks at the current time step.
+     detections : List[detection.Detection]
+         A list of detections at the current time step.
+     track_indices : Optional[List[int]]
+         List of track indices that maps rows in `cost_matrix` to tracks in
+         `tracks` (see description above). Defaults to all tracks.
+     detection_indices : Optional[List[int]]
+         List of detection indices that maps columns in `cost_matrix` to
+         detections in `detections` (see description above). Defaults to all
+         detections.
+
+     Returns
+     -------
+     (List[(int, int)], List[int], List[int])
+         Returns a tuple with the following three entries:
+         * A list of matched track and detection indices.
+         * A list of unmatched track indices.
+         * A list of unmatched detection indices.
+
+     """
+     # Default the index lists to all tracks / all detections.
+     if track_indices is None:
+         track_indices = list(range(len(tracks)))
+     if detection_indices is None:
+         detection_indices = list(range(len(detections)))
+
+     # Initialize the match set M ← ∅ and the unmatched detection set U ← D.
+     unmatched_detections = detection_indices
+     matches = []
+     # Match tracks level by level, from most to least recently updated.
+     for level in range(cascade_depth):
+         if len(unmatched_detections) == 0:  # No detections left
+             break
+
+         # Step 6: select the tracks of the current level by age.
+         track_indices_l = [
+             k for k in track_indices
+             if tracks[k].time_since_update == 1 + level
+         ]
+         if len(track_indices_l) == 0:  # Nothing to match at this level
+             continue
+
+         # Step 7: run min_cost_matching on this level.
+         matches_l, _, unmatched_detections = \
+             min_cost_matching(
+                 distance_metric, max_distance, tracks, detections,
+                 track_indices_l, unmatched_detections)
+         matches += matches_l  # step 8
+     unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches))  # step 9
+     return matches, unmatched_tracks, unmatched_detections
+
+ '''
+ Gating the cost matrix: the appearance-based cost matrix (similarity between
+ track and detection features) is constrained by the distance between the
+ Kalman state distribution and the measurements. When one track could match
+ two detections with very similar appearance features, the association is
+ error-prone; computing each detection's Mahalanobis distance to the track
+ and limiting it with gating_threshold separates out the distant detection
+ and reduces wrong matches.
+ '''
+ def gate_cost_matrix(
+         kf, cost_matrix, tracks, detections, track_indices, detection_indices,
+         gated_cost=INFTY_COST, only_position=False):
+     """Invalidate infeasible entries in cost matrix based on the state
+     distributions obtained by Kalman filtering.
+
+     Parameters
+     ----------
+     kf : The Kalman filter.
+     cost_matrix : ndarray
+         The NxM dimensional cost matrix, where N is the number of track indices
+         and M is the number of detection indices, such that entry (i, j) is the
+         association cost between `tracks[track_indices[i]]` and
+         `detections[detection_indices[j]]`.
+     tracks : List[track.Track]
+         A list of predicted tracks at the current time step.
+     detections : List[detection.Detection]
+         A list of detections at the current time step.
+     track_indices : List[int]
+         List of track indices that maps rows in `cost_matrix` to tracks in
+         `tracks` (see description above).
+     detection_indices : List[int]
+         List of detection indices that maps columns in `cost_matrix` to
+         detections in `detections` (see description above).
+     gated_cost : Optional[float]
+         Entries in the cost matrix corresponding to infeasible associations are
+         set to this value. Defaults to a very large value.
+     only_position : Optional[bool]
+         If True, only the x, y position of the state distribution is considered
+         during gating. Defaults to False.
+
+     Returns
+     -------
+     ndarray
+         Returns the modified cost matrix.
+
+     """
+     gating_dim = 2 if only_position else 4  # measurement-space dimensionality
+     # The Mahalanobis distance measures how many standard deviations a
+     # detection lies from the mean track position, accounting for the
+     # uncertainty of the state estimate. Thresholding at the 95% confidence
+     # interval of the inverse chi-square distribution excludes unlikely
+     # associations; for a 4-dim measurement space the threshold is 9.4877.
+     gating_threshold = kalman_filter.chi2inv95[gating_dim]
+     measurements = np.asarray(
+         [detections[i].to_xyah() for i in detection_indices])
+     for row, track_idx in enumerate(track_indices):
+         track = tracks[track_idx]
+         # KalmanFilter.gating_distance computes the gating distance between
+         # the state distribution and the measurements.
+         gating_distance = kf.gating_distance(
+             track.mean, track.covariance, measurements, only_position)
+         cost_matrix[row, gating_distance > gating_threshold] = gated_cost
+     return cost_matrix
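A small self-contained illustration of the matching logic in min_cost_matching above, using a made-up 2x3 cost matrix (the values are arbitrary):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.1, 0.9, 0.8],
                     [0.7, 0.2, 0.6]])
    max_distance = 0.5
    # Clamp costs above the gate, as min_cost_matching does.
    cost[cost > max_distance] = max_distance + 1e-5
    rows, cols = linear_sum_assignment(cost)
    # Keep only pairs whose cost actually passed the gate.
    matches = [(r, c) for r, c in zip(rows, cols) if cost[r, c] <= max_distance]
    print(matches)  # [(0, 0), (1, 1)]; column 2 stays unmatched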
deep_sort/deep_sort/sort/nn_matching.py ADDED
@@ -0,0 +1,207 @@
+ # vim: expandtab:ts=4:sw=4
+ import numpy as np
+
+
+ def _pdist(a, b):
+     """Compute pair-wise squared distance between points in `a` and `b`.
+
+     Parameters
+     ----------
+     a : array_like
+         An NxM matrix of N samples of dimensionality M.
+     b : array_like
+         An LxM matrix of L samples of dimensionality M.
+
+     Returns
+     -------
+     ndarray
+         Returns a matrix of size len(a), len(b) such that element (i, j)
+         contains the squared distance between `a[i]` and `b[j]`.
+
+     Reference: https://blog.csdn.net/frankzd/article/details/80251042
+
+     """
+     a, b = np.asarray(a), np.asarray(b)
+     if len(a) == 0 or len(b) == 0:
+         return np.zeros((len(a), len(b)))
+     a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1)
+     r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :]
+     r2 = np.clip(r2, 0., float(np.inf))
+     return r2
+
+
+ def _cosine_distance(a, b, data_is_normalized=False):
+     """Compute pair-wise cosine distance between points in `a` and `b`.
+
+     Parameters
+     ----------
+     a : array_like
+         An NxM matrix of N samples of dimensionality M.
+     b : array_like
+         An LxM matrix of L samples of dimensionality M.
+     data_is_normalized : Optional[bool]
+         If True, assumes rows in a and b are unit length vectors.
+         Otherwise, a and b are explicitly normalized to length 1.
+
+     Returns
+     -------
+     ndarray
+         Returns a matrix of size len(a), len(b) such that element (i, j)
+         contains the cosine distance between `a[i]` and `b[j]`.
+
+     Reference: https://blog.csdn.net/u013749540/article/details/51813922
+
+     """
+     if not data_is_normalized:
+         # np.linalg.norm computes the vector norm (L2 by default).
+         a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True)
+         b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True)
+     return 1. - np.dot(a, b.T)  # cosine distance = 1 - cosine similarity
+
+
+ def _nn_euclidean_distance(x, y):
+     """ Helper function for nearest neighbor distance metric (Euclidean).
+
+     Parameters
+     ----------
+     x : ndarray
+         A matrix of N row-vectors (sample points).
+     y : ndarray
+         A matrix of M row-vectors (query points).
+
+     Returns
+     -------
+     ndarray
+         A vector of length M that contains for each entry in `y` the
+         smallest squared Euclidean distance to a sample in `x`.
+
+     """
+     distances = _pdist(x, y)
+     return np.maximum(0.0, distances.min(axis=0))
+
+
+ def _nn_cosine_distance(x, y):
+     """ Helper function for nearest neighbor distance metric (cosine).
+
+     Parameters
+     ----------
+     x : ndarray
+         A matrix of N row-vectors (sample points).
+     y : ndarray
+         A matrix of M row-vectors (query points).
+
+     Returns
+     -------
+     ndarray
+         A vector of length M that contains for each entry in `y` the
+         smallest cosine distance to a sample in `x`.
+
+     """
+     distances = _cosine_distance(x, y)
+     return distances.min(axis=0)
+
+
+ class NearestNeighborDistanceMetric(object):
+     """
+     A nearest neighbor distance metric that, for each target, returns
+     the closest distance to any sample that has been observed so far.
+
+     Parameters
+     ----------
+     metric : str
+         Either "euclidean" or "cosine".
+     matching_threshold: float
+         The matching threshold. Samples with larger distance are considered an
+         invalid match.
+     budget : Optional[int]
+         If not None, fix samples per class to at most this number. Removes
+         the oldest samples when the budget is reached.
+
+     Attributes
+     ----------
+     samples : Dict[int -> List[ndarray]]
+         A dictionary that maps from target identities to the list of samples
+         that have been observed so far.
+
+     """
+
+     def __init__(self, metric, matching_threshold, budget=None):
+         if metric == "euclidean":
+             self._metric = _nn_euclidean_distance
+         elif metric == "cosine":
+             self._metric = _nn_cosine_distance
+         else:
+             raise ValueError(
+                 "Invalid metric; must be either 'euclidean' or 'cosine'")
+         self.matching_threshold = matching_threshold
+         self.budget = budget  # caps the number of stored features per target
+         self.samples = {}
+
+     def partial_fit(self, features, targets, active_targets):
+         """Update the distance metric with new data.
+
+         Given features and their target ids, partial_fit maintains a feature
+         dictionary for the active targets.
+
+         Parameters
+         ----------
+         features : ndarray
+             An NxM matrix of N features of dimensionality M.
+         targets : ndarray
+             An integer array of associated target identities.
+         active_targets : List[int]
+             A list of targets that are currently present in the scene.
+
+         """
+         for feature, target in zip(features, targets):
+             # Append the new feature to the target's feature list
+             # (samples maps target id -> feature list).
+             self.samples.setdefault(target, []).append(feature)
+             if self.budget is not None:
+                 # Keep only the most recent `budget` features per target.
+                 self.samples[target] = self.samples[target][-self.budget:]
+
+         # Keep only the active targets; samples is a dict {id -> feature list}.
+         self.samples = {k: self.samples[k] for k in active_targets}
+
+     def distance(self, features, targets):
+         """Compute distance between features and targets.
+
+         Parameters
+         ----------
+         features : ndarray
+             An NxM matrix of N features of dimensionality M.
+         targets : List[int]
+             A list of targets to match the given `features` against.
+
+         Returns
+         -------
+         ndarray
+             Returns a cost matrix of shape len(targets), len(features), where
+             element (i, j) contains the closest distance between
+             `targets[i]` and `features[j]`.
+
+         """
+         cost_matrix = np.zeros((len(targets), len(features)))
+         for i, target in enumerate(targets):
+             cost_matrix[i, :] = self._metric(self.samples[target], features)
+         return cost_matrix
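A minimal usage sketch of the metric above (the 128-d random vectors stand in for ReID embeddings; the import path assumes the repository root is on PYTHONPATH):

    import numpy as np
    from deep_sort.deep_sort.sort.nn_matching import NearestNeighborDistanceMetric

    metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)
    feats = np.random.rand(4, 128).astype(np.float32)  # two samples each for ids 1 and 2
    metric.partial_fit(feats, np.array([1, 1, 2, 2]), active_targets=[1, 2])
    query = np.random.rand(3, 128).astype(np.float32)  # features of three new detections
    print(metric.distance(query, [1, 2]).shape)        # (2, 3) cost matrix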
deep_sort/deep_sort/sort/preprocessing.py ADDED
@@ -0,0 +1,73 @@
+ # vim: expandtab:ts=4:sw=4
+ import numpy as np
+ import cv2
+
+
+ def non_max_suppression(boxes, max_bbox_overlap, scores=None):
+     """Suppress overlapping detections.
+
+     Original code from [1]_ has been adapted to include confidence score.
+
+     .. [1] http://www.pyimagesearch.com/2015/02/16/
+            faster-non-maximum-suppression-python/
+
+     Examples
+     --------
+
+         >>> boxes = [d.roi for d in detections]
+         >>> scores = [d.confidence for d in detections]
+         >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores)
+         >>> detections = [detections[i] for i in indices]
+
+     Parameters
+     ----------
+     boxes : ndarray
+         Array of ROIs (x, y, width, height).
+     max_bbox_overlap : float
+         ROIs that overlap more than this value are suppressed.
+     scores : Optional[array_like]
+         Detector confidence score.
+
+     Returns
+     -------
+     List[int]
+         Returns indices of detections that have survived non-maxima suppression.
+
+     """
+     if len(boxes) == 0:
+         return []
+
+     boxes = boxes.astype(np.float32)
+     pick = []
+
+     x1 = boxes[:, 0]
+     y1 = boxes[:, 1]
+     x2 = boxes[:, 2] + boxes[:, 0]
+     y2 = boxes[:, 3] + boxes[:, 1]
+
+     area = (x2 - x1 + 1) * (y2 - y1 + 1)
+     if scores is not None:
+         idxs = np.argsort(scores)
+     else:
+         idxs = np.argsort(y2)
+
+     while len(idxs) > 0:
+         last = len(idxs) - 1
+         i = idxs[last]
+         pick.append(i)
+
+         xx1 = np.maximum(x1[i], x1[idxs[:last]])
+         yy1 = np.maximum(y1[i], y1[idxs[:last]])
+         xx2 = np.minimum(x2[i], x2[idxs[:last]])
+         yy2 = np.minimum(y2[i], y2[idxs[:last]])
+
+         w = np.maximum(0, xx2 - xx1 + 1)
+         h = np.maximum(0, yy2 - yy1 + 1)
+
+         overlap = (w * h) / area[idxs[:last]]  # IoU-style overlap ratio
+
+         idxs = np.delete(
+             idxs, np.concatenate(
+                 ([last], np.where(overlap > max_bbox_overlap)[0])))
+
+     return pick
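A quick sketch of the function above on made-up boxes (two heavily overlapping boxes plus one separate box, in (x, y, w, h) format):

    import numpy as np
    from deep_sort.deep_sort.sort.preprocessing import non_max_suppression

    boxes = np.array([[10, 10, 50, 50],
                      [12, 12, 50, 50],
                      [200, 200, 40, 40]])
    scores = np.array([0.9, 0.6, 0.8])
    keep = non_max_suppression(boxes, max_bbox_overlap=0.7, scores=scores)
    print(sorted(keep))  # [0, 2]: the low-score duplicate of box 0 is dropped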
deep_sort/deep_sort/sort/track.py ADDED
@@ -0,0 +1,199 @@
+ # vim: expandtab:ts=4:sw=4
+
+
+ class TrackState:
+     """
+     Enumeration type for the single target track state. Newly created tracks are
+     classified as `tentative` until enough evidence has been collected. Then,
+     the track state is changed to `confirmed`. Tracks that are no longer alive
+     are classified as `deleted` to mark them for removal from the set of active
+     tracks.
+
+     """
+
+     Tentative = 1
+     Confirmed = 2
+     Deleted = 3
+
+
+ class Track:
+     """
+     A single target track with state space `(x, y, a, h)` and associated
+     velocities, where `(x, y)` is the center of the bounding box, `a` is the
+     aspect ratio and `h` is the height.
+
+     Parameters
+     ----------
+     mean : ndarray
+         Mean vector of the initial state distribution.
+     covariance : ndarray
+         Covariance matrix of the initial state distribution.
+     track_id : int
+         A unique track identifier.
+     n_init : int
+         Number of consecutive detections before the track is confirmed. The
+         track state is set to `Deleted` if a miss occurs within the first
+         `n_init` frames.
+     max_age : int
+         The maximum number of consecutive misses before the track state is
+         set to `Deleted`; effectively the lifetime of an unmatched track.
+     feature : Optional[ndarray]
+         Feature vector of the detection this track originates from. If not None,
+         this feature is added to the `features` cache.
+
+     Attributes
+     ----------
+     mean : ndarray
+         Mean vector of the initial state distribution.
+     covariance : ndarray
+         Covariance matrix of the initial state distribution.
+     track_id : int
+         A unique track identifier.
+     hits : int
+         Total number of measurement updates.
+     age : int
+         Total number of frames since first occurrence.
+     time_since_update : int
+         Total number of frames since last measurement update.
+     state : TrackState
+         The current track state.
+     features : List[ndarray]
+         A cache of features. On each measurement update, the associated feature
+         vector is added to this list.
+
+     """
+
+     def __init__(self, mean, covariance, track_id, n_init, max_age,
+                  feature=None):
+         self.mean = mean
+         self.covariance = covariance
+         self.track_id = track_id
+         # hits counts successful matches; it is incremented on every update()
+         # and once it reaches n_init the track is set to Confirmed.
+         self.hits = 1
+         self.age = 1  # effectively duplicates time_since_update here
+         # Incremented on every predict(); reset to 0 on every update().
+         self.time_since_update = 0
+
+         self.state = TrackState.Tentative  # a new track starts as Tentative
+         # Each track accumulates features; every update appends the newest
+         # feature to this list.
+         self.features = []
+         if feature is not None:
+             self.features.append(feature)
+
+         self._n_init = n_init
+         self._max_age = max_age
+
+     def to_tlwh(self):
+         """Get current position in bounding box format `(top left x, top left y,
+         width, height)`.
+
+         Returns
+         -------
+         ndarray
+             The bounding box.
+
+         """
+         ret = self.mean[:4].copy()
+         ret[2] *= ret[3]
+         ret[:2] -= ret[2:] / 2
+         return ret
+
+     def to_tlbr(self):
+         """Get current position in bounding box format `(min x, min y, max x,
+         max y)`.
+
+         Returns
+         -------
+         ndarray
+             The bounding box.
+
+         """
+         ret = self.to_tlwh()
+         ret[2:] = ret[:2] + ret[2:]
+         return ret
+
+     def predict(self, kf):
+         """Propagate the state distribution to the current time step using a
+         Kalman filter prediction step.
+
+         Parameters
+         ----------
+         kf : kalman_filter.KalmanFilter
+             The Kalman filter.
+
+         """
+         self.mean, self.covariance = kf.predict(self.mean, self.covariance)
+         self.age += 1
+         self.time_since_update += 1
+
+     def update(self, kf, detection):
+         """Perform Kalman filter measurement update step and update the feature
+         cache.
+
+         Parameters
+         ----------
+         kf : kalman_filter.KalmanFilter
+             The Kalman filter.
+         detection : Detection
+             The associated detection.
+
+         """
+         self.mean, self.covariance = kf.update(
+             self.mean, self.covariance, detection.to_xyah())
+         self.features.append(detection.feature)
+
+         self.hits += 1
+         self.time_since_update = 0
+         # After n_init consecutive matches, a Tentative track becomes Confirmed.
+         if self.state == TrackState.Tentative and self.hits >= self._n_init:
+             self.state = TrackState.Confirmed
+
+     def mark_missed(self):
+         """Mark this track as missed (no association at the current time step).
+         """
+         # A Tentative track that misses an association is deleted immediately.
+         if self.state == TrackState.Tentative:
+             self.state = TrackState.Deleted
+         elif self.time_since_update > self._max_age:
+             # A track unmatched for more than max_age consecutive frames is
+             # deleted as well.
+             self.state = TrackState.Deleted
+
+     def is_tentative(self):
+         """Returns True if this track is tentative (unconfirmed).
+         """
+         return self.state == TrackState.Tentative
+
+     def is_confirmed(self):
+         """Returns True if this track is confirmed."""
+         return self.state == TrackState.Confirmed
+
+     def is_deleted(self):
+         """Returns True if this track is dead and should be deleted."""
+         return self.state == TrackState.Deleted
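A small sketch of the state machine above (KalmanFilter comes from the sibling module; the measurement values are made up):

    import numpy as np
    from deep_sort.deep_sort.sort.kalman_filter import KalmanFilter
    from deep_sort.deep_sort.sort.track import Track

    kf = KalmanFilter()
    mean, cov = kf.initiate(np.array([100., 100., 0.5, 80.]))
    track = Track(mean, cov, track_id=1, n_init=3, max_age=30)
    print(track.is_tentative())  # True: a fresh track is Tentative
    track.predict(kf)
    track.mark_missed()          # a Tentative track that misses is deleted
    print(track.is_deleted())    # True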
deep_sort/deep_sort/sort/tracker.py ADDED
@@ -0,0 +1,168 @@
+ # vim: expandtab:ts=4:sw=4
+ from __future__ import absolute_import
+ import numpy as np
+ from . import kalman_filter
+ from . import linear_assignment
+ from . import iou_matching
+ from .track import Track
+
+
+ class Tracker:
+     """
+     This is the multi-target tracker.
+
+     Parameters
+     ----------
+     metric : nn_matching.NearestNeighborDistanceMetric
+         A distance metric for measurement-to-track association.
+     max_age : int
+         Maximum number of consecutive misses before a track is deleted.
+     n_init : int
+         Number of consecutive detections before the track is confirmed. The
+         track state is set to `Deleted` if a miss occurs within the first
+         `n_init` frames.
+
+     Attributes
+     ----------
+     metric : nn_matching.NearestNeighborDistanceMetric
+         The distance metric used for measurement to track association.
+     max_age : int
+         Maximum number of consecutive misses before a track is deleted.
+     n_init : int
+         Number of frames that a track remains in initialization phase.
+     kf : kalman_filter.KalmanFilter
+         A Kalman filter to filter target trajectories in image space.
+     tracks : List[Track]
+         The list of active tracks at the current time step.
+
+     """
+
+     def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3):
+         self.metric = metric
+         self.max_iou_distance = max_iou_distance
+         self.max_age = max_age
+         self.n_init = n_init
+
+         self.kf = kalman_filter.KalmanFilter()  # the shared Kalman filter
+         self.tracks = []  # the list of active tracks
+         self._next_id = 1  # id assigned to the next new track
+
+     def predict(self):
+         """Propagate track state distributions one time step forward.
+
+         This function should be called once every time step, before `update`.
+         """
+         for track in self.tracks:
+             track.predict(self.kf)
+
+     def update(self, detections):
+         """Perform measurement update and track management.
+
+         Parameters
+         ----------
+         detections : List[deep_sort.detection.Detection]
+             A list of detections at the current time step.
+
+         """
+         # Run matching cascade.
+         matches, unmatched_tracks, unmatched_detections = \
+             self._match(detections)
+
+         # Update track set.
+
+         # 1. Update matched tracks with their associated detections.
+         for track_idx, detection_idx in matches:
+             self.tracks[track_idx].update(
+                 self.kf, detections[detection_idx])
+
+         # 2. Mark unmatched tracks as missed: a Tentative track is deleted
+         #    immediately, and a track that has gone unmatched for too long
+         #    is deleted as well.
+         for track_idx in unmatched_tracks:
+             self.tracks[track_idx].mark_missed()
+
+         # 3. Initiate a new track for every unmatched detection.
+         for detection_idx in unmatched_detections:
+             self._initiate_track(detections[detection_idx])
+
+         # Keep only tracks that are not Deleted (Confirmed and Tentative).
+         self.tracks = [t for t in self.tracks if not t.is_deleted()]
+
+         # Update distance metric.
+         active_targets = [t.track_id for t in self.tracks if t.is_confirmed()]
+         features, targets = [], []
+         for track in self.tracks:
+             # Only Confirmed tracks contribute features.
+             if not track.is_confirmed():
+                 continue
+             features += track.features
+             # Record the track id that each feature belongs to.
+             targets += [track.track_id for _ in track.features]
+             track.features = []
+         # Refresh the feature sets held by the distance metric.
+         self.metric.partial_fit(
+             np.asarray(features), np.asarray(targets), active_targets)
+
+     def _match(self, detections):
+
+         def gated_metric(tracks, dets, track_indices, detection_indices):
+             features = np.array([dets[i].feature for i in detection_indices])
+             targets = np.array([tracks[i].track_id for i in track_indices])
+
+             # Nearest-neighbor (cosine) cost matrix ...
+             cost_matrix = self.metric.distance(features, targets)
+             # ... gated by the Mahalanobis check against the Kalman state.
+             cost_matrix = linear_assignment.gate_cost_matrix(
+                 self.kf, cost_matrix, tracks, dets, track_indices,
+                 detection_indices)
+
+             return cost_matrix
+
+         # Split track set into confirmed and unconfirmed tracks.
+         confirmed_tracks = [
+             i for i, t in enumerate(self.tracks) if t.is_confirmed()]
+         unconfirmed_tracks = [
+             i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
+
+         # Associate confirmed tracks using appearance features:
+         # matching_cascade matches detections to confirmed tracks with the
+         # gated cost metric defined above.
+         matches_a, unmatched_tracks_a, unmatched_detections = \
+             linear_assignment.matching_cascade(
+                 gated_metric, self.metric.matching_threshold, self.max_age,
+                 self.tracks, detections, confirmed_tracks)
+
+         # Associate remaining tracks together with unconfirmed tracks using IOU:
+         # combine the unconfirmed tracks with the tracks that missed only in
+         # the last frame, and match those candidates by IoU.
+         iou_track_candidates = unconfirmed_tracks + [
+             k for k in unmatched_tracks_a if
+             self.tracks[k].time_since_update == 1]  # missed only last frame
+         unmatched_tracks_a = [
+             k for k in unmatched_tracks_a if
+             self.tracks[k].time_since_update != 1]  # missed for longer
+         # min_cost_matching solves the linear assignment problem with the
+         # Hungarian algorithm; here it is given iou_cost to associate what
+         # the cascade could not.
+         matches_b, unmatched_tracks_b, unmatched_detections = \
+             linear_assignment.min_cost_matching(
+                 iou_matching.iou_cost, self.max_iou_distance, self.tracks,
+                 detections, iou_track_candidates, unmatched_detections)
+
+         matches = matches_a + matches_b  # combine both match sets
+         unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
+         return matches, unmatched_tracks, unmatched_detections
+
+     def _initiate_track(self, detection):
+         mean, covariance = self.kf.initiate(detection.to_xyah())
+         self.tracks.append(Track(
+             mean, covariance, self._next_id, self.n_init, self.max_age,
+             detection.feature))
+         self._next_id += 1
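A minimal end-to-end sketch of one tracking step. This assumes the repository root is on PYTHONPATH and that `Detection` from the sibling detection.py takes (tlwh, confidence, feature), matching the `to_xyah()`/`feature` usage above; the detections here are fabricated:

    import numpy as np
    from deep_sort.deep_sort.sort.tracker import Tracker
    from deep_sort.deep_sort.sort.nn_matching import NearestNeighborDistanceMetric
    from deep_sort.deep_sort.sort.detection import Detection  # assumed (tlwh, confidence, feature)

    metric = NearestNeighborDistanceMetric("cosine", matching_threshold=0.2, budget=100)
    tracker = Tracker(metric, max_iou_distance=0.7, max_age=70, n_init=3)

    # One frame: two fake detections with random 128-d appearance features.
    detections = [Detection(np.array([10., 10., 40., 80.]), 0.9, np.random.rand(128)),
                  Detection(np.array([200., 50., 40., 80.]), 0.8, np.random.rand(128))]
    tracker.predict()
    tracker.update(detections)
    for t in tracker.tracks:
        print(t.track_id, t.to_tlwh(), t.is_confirmed())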
deep_sort/utils/__init__.py ADDED
File without changes
deep_sort/utils/asserts.py ADDED
@@ -0,0 +1,13 @@
+ from os import environ
+
+
+ def assert_in(file, files_to_check):
+     if file not in files_to_check:
+         raise AssertionError("{} does not exist in the list".format(str(file)))
+     return True
+
+
+ def assert_in_env(check_list: list):
+     for item in check_list:
+         assert_in(item, environ.keys())
+     return True
deep_sort/utils/draw.py ADDED
@@ -0,0 +1,36 @@
+ import numpy as np
+ import cv2
+
+ palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)
+
+
+ def compute_color_for_labels(label):
+     """
+     Simple function that adds fixed color depending on the class
+     """
+     color = [int((p * (label ** 2 - label + 1)) % 255) for p in palette]
+     return tuple(color)
+
+
+ def draw_boxes(img, bbox, identities=None, offset=(0, 0)):
+     for i, box in enumerate(bbox):
+         x1, y1, x2, y2 = [int(c) for c in box]
+         x1 += offset[0]
+         x2 += offset[0]
+         y1 += offset[1]
+         y2 += offset[1]
+         # box text and bar
+         id = int(identities[i]) if identities is not None else 0
+         color = compute_color_for_labels(id)
+         label = '{}{:d}'.format("", id)
+         t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0]
+         cv2.rectangle(img, (x1, y1), (x2, y2), color, 3)
+         cv2.rectangle(img, (x1, y1), (x1 + t_size[0] + 3, y1 + t_size[1] + 4), color, -1)
+         cv2.putText(img, label, (x1, y1 + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 2, [255, 255, 255], 2)
+     return img
+
+
+ if __name__ == '__main__':
+     for i in range(82):
+         print(compute_color_for_labels(i))
deep_sort/utils/evaluation.py ADDED
@@ -0,0 +1,103 @@
+ import os
+ import numpy as np
+ import copy
+ import motmetrics as mm
+ mm.lap.default_solver = 'lap'
+ from utils.io import read_results, unzip_objs
+
+
+ class Evaluator(object):
+
+     def __init__(self, data_root, seq_name, data_type):
+         self.data_root = data_root
+         self.seq_name = seq_name
+         self.data_type = data_type
+
+         self.load_annotations()
+         self.reset_accumulator()
+
+     def load_annotations(self):
+         assert self.data_type == 'mot'
+
+         gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
+         self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
+         self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
+
+     def reset_accumulator(self):
+         self.acc = mm.MOTAccumulator(auto_id=True)
+
+     def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False):
+         # results
+         trk_tlwhs = np.copy(trk_tlwhs)
+         trk_ids = np.copy(trk_ids)
+
+         # gts
+         gt_objs = self.gt_frame_dict.get(frame_id, [])
+         gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2]
+
+         # ignore boxes
+         ignore_objs = self.gt_ignore_frame_dict.get(frame_id, [])
+         ignore_tlwhs = unzip_objs(ignore_objs)[0]
+
+         # remove ignored results
+         keep = np.ones(len(trk_tlwhs), dtype=bool)
+         iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5)
+         if len(iou_distance) > 0:
+             match_is, match_js = mm.lap.linear_sum_assignment(iou_distance)
+             match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js])
+             match_ious = iou_distance[match_is, match_js]
+
+             match_js = np.asarray(match_js, dtype=int)
+             match_js = match_js[np.logical_not(np.isnan(match_ious))]
+             keep[match_js] = False
+             trk_tlwhs = trk_tlwhs[keep]
+             trk_ids = trk_ids[keep]
+
+         # get distance matrix
+         iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5)
+
+         # acc
+         self.acc.update(gt_ids, trk_ids, iou_distance)
+
+         if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'):
+             events = self.acc.last_mot_events  # only supported by https://github.com/longcw/py-motmetrics
+         else:
+             events = None
+         return events
+
+     def eval_file(self, filename):
+         self.reset_accumulator()
+
+         result_frame_dict = read_results(filename, self.data_type, is_gt=False)
+         frames = sorted(list(set(self.gt_frame_dict.keys()) | set(result_frame_dict.keys())))
+         for frame_id in frames:
+             trk_objs = result_frame_dict.get(frame_id, [])
+             trk_tlwhs, trk_ids = unzip_objs(trk_objs)[:2]
+             self.eval_frame(frame_id, trk_tlwhs, trk_ids, rtn_events=False)
+
+         return self.acc
+
+     @staticmethod
+     def get_summary(accs, names, metrics=('mota', 'num_switches', 'idp', 'idr', 'idf1', 'precision', 'recall')):
+         names = copy.deepcopy(names)
+         if metrics is None:
+             metrics = mm.metrics.motchallenge_metrics
+         metrics = copy.deepcopy(metrics)
+
+         mh = mm.metrics.create()
+         summary = mh.compute_many(
+             accs,
+             metrics=metrics,
+             names=names,
+             generate_overall=True
+         )
+
+         return summary
+
+     @staticmethod
+     def save_summary(summary, filename):
+         import pandas as pd
+         writer = pd.ExcelWriter(filename)
+         summary.to_excel(writer)
+         writer.save()
deep_sort/utils/io.py ADDED
@@ -0,0 +1,133 @@
+ import os
+ from typing import Dict
+ import numpy as np
+
+ # from utils.log import get_logger
+
+
+ def write_results(filename, results, data_type):
+     if data_type == 'mot':
+         save_format = '{frame},{id},{x1},{y1},{w},{h},-1,-1,-1,-1\n'
+     elif data_type == 'kitti':
+         save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n'
+     else:
+         raise ValueError(data_type)
+
+     with open(filename, 'w') as f:
+         for frame_id, tlwhs, track_ids in results:
+             if data_type == 'kitti':
+                 frame_id -= 1
+             for tlwh, track_id in zip(tlwhs, track_ids):
+                 if track_id < 0:
+                     continue
+                 x1, y1, w, h = tlwh
+                 x2, y2 = x1 + w, y1 + h
+                 line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h)
+                 f.write(line)
+
+
+ # def write_results(filename, results_dict: Dict, data_type: str):
+ #     if not filename:
+ #         return
+ #     path = os.path.dirname(filename)
+ #     if not os.path.exists(path):
+ #         os.makedirs(path)
+
+ #     if data_type in ('mot', 'mcmot', 'lab'):
+ #         save_format = '{frame},{id},{x1},{y1},{w},{h},1,-1,-1,-1\n'
+ #     elif data_type == 'kitti':
+ #         save_format = '{frame} {id} pedestrian -1 -1 -10 {x1} {y1} {x2} {y2} -1 -1 -1 -1000 -1000 -1000 -10 {score}\n'
+ #     else:
+ #         raise ValueError(data_type)
+
+ #     with open(filename, 'w') as f:
+ #         for frame_id, frame_data in results_dict.items():
+ #             if data_type == 'kitti':
+ #                 frame_id -= 1
+ #             for tlwh, track_id in frame_data:
+ #                 if track_id < 0:
+ #                     continue
+ #                 x1, y1, w, h = tlwh
+ #                 x2, y2 = x1 + w, y1 + h
+ #                 line = save_format.format(frame=frame_id, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2, w=w, h=h, score=1.0)
+ #                 f.write(line)
+ #     logger.info('Save results to {}'.format(filename))
+
+
+ def read_results(filename, data_type: str, is_gt=False, is_ignore=False):
+     if data_type in ('mot', 'lab'):
+         read_fun = read_mot_results
+     else:
+         raise ValueError('Unknown data type: {}'.format(data_type))
+
+     return read_fun(filename, is_gt, is_ignore)
+
+
+ """
+ labels={'ped', ...			% 1
+         'person_on_vhcl', ...	% 2
+         'car', ...				% 3
+         'bicycle', ...			% 4
+         'mbike', ...			% 5
+         'non_mot_vhcl', ...		% 6
+         'static_person', ...	% 7
+         'distractor', ...		% 8
+         'occluder', ...			% 9
+         'occluder_on_grnd', ...	% 10
+         'occluder_full', ...	% 11
+         'reflection', ...		% 12
+         'crowd' ...				% 13
+ };
+ """
+
+
+ def read_mot_results(filename, is_gt, is_ignore):
+     valid_labels = {1}
+     ignore_labels = {2, 7, 8, 12}
+     results_dict = dict()
+     if os.path.isfile(filename):
+         with open(filename, 'r') as f:
+             for line in f.readlines():
+                 linelist = line.split(',')
+                 if len(linelist) < 7:
+                     continue
+                 fid = int(linelist[0])
+                 if fid < 1:
+                     continue
+                 results_dict.setdefault(fid, list())
+
+                 if is_gt:
+                     if 'MOT16-' in filename or 'MOT17-' in filename:
+                         label = int(float(linelist[7]))
+                         mark = int(float(linelist[6]))
+                         if mark == 0 or label not in valid_labels:
+                             continue
+                     score = 1
+                 elif is_ignore:
+                     if 'MOT16-' in filename or 'MOT17-' in filename:
+                         label = int(float(linelist[7]))
+                         vis_ratio = float(linelist[8])
+                         if label not in ignore_labels and vis_ratio >= 0:
+                             continue
+                     else:
+                         continue
+                     score = 1
+                 else:
+                     score = float(linelist[6])
+
+                 tlwh = tuple(map(float, linelist[2:6]))
+                 target_id = int(linelist[1])
+
+                 results_dict[fid].append((tlwh, target_id, score))
+
+     return results_dict
+
+
+ def unzip_objs(objs):
+     if len(objs) > 0:
+         tlwhs, ids, scores = zip(*objs)
+     else:
+         tlwhs, ids, scores = [], [], []
+     tlwhs = np.asarray(tlwhs, dtype=float).reshape(-1, 4)
+
+     return tlwhs, ids, scores
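A quick round-trip sketch of the helpers above (one fabricated frame; `deep_sort.utils.io` as import path assumes the repository root is on PYTHONPATH):

    from deep_sort.utils.io import write_results, read_results

    # results: list of (frame_id, tlwhs, track_ids) tuples collected per frame.
    results = [(1, [(10.0, 20.0, 50.0, 80.0)], [7])]
    write_results('output.txt', results, data_type='mot')
    frames = read_results('output.txt', 'mot', is_gt=False)
    print(frames[1])  # [((10.0, 20.0, 50.0, 80.0), 7, -1.0)]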
deep_sort/utils/json_logger.py ADDED
@@ -0,0 +1,383 @@
1
+ """
2
+ References:
3
+ https://medium.com/analytics-vidhya/creating-a-custom-logging-mechanism-for-real-time-object-detection-using-tdd-4ca2cfcd0a2f
4
+ """
5
+ import json
6
+ from os import makedirs
7
+ from os.path import exists, join
8
+ from datetime import datetime
9
+
10
+
11
+ class JsonMeta(object):
12
+ HOURS = 3
13
+ MINUTES = 59
14
+ SECONDS = 59
15
+ PATH_TO_SAVE = 'LOGS'
16
+ DEFAULT_FILE_NAME = 'remaining'
17
+
18
+
19
+ class BaseJsonLogger(object):
20
+ """
21
+ This is the base class that returns __dict__ of its own
22
+ it also returns the dicts of objects in the attributes that are list instances
23
+
24
+ """
25
+
26
+ def dic(self):
27
+ # returns dicts of objects
28
+ out = {}
29
+ for k, v in self.__dict__.items():
30
+ if hasattr(v, 'dic'):
31
+ out[k] = v.dic()
32
+ elif isinstance(v, list):
33
+ out[k] = self.list(v)
34
+ else:
35
+ out[k] = v
36
+ return out
37
+
38
+ @staticmethod
39
+ def list(values):
40
+ # applies the dic method on items in the list
41
+ return [v.dic() if hasattr(v, 'dic') else v for v in values]
42
+
43
+
44
+ class Label(BaseJsonLogger):
45
+ """
46
+ For each bounding box there are various categories with confidences. Label class keeps track of that information.
47
+ """
48
+
49
+ def __init__(self, category: str, confidence: float):
50
+ self.category = category
51
+ self.confidence = confidence
52
+
53
+
54
+ class Bbox(BaseJsonLogger):
55
+ """
56
+ This module stores the information for each frame and use them in JsonParser
57
+ Attributes:
58
+ labels (list): List of label module.
59
+ top (int):
60
+ left (int):
61
+ width (int):
62
+ height (int):
63
+
64
+ Args:
65
+ bbox_id (float):
66
+ top (int):
67
+ left (int):
68
+ width (int):
69
+ height (int):
70
+
71
+ References:
72
+ Check Label module for better understanding.
73
+
74
+
75
+ """
76
+
77
+ def __init__(self, bbox_id, top, left, width, height):
78
+ self.labels = []
79
+ self.bbox_id = bbox_id
80
+ self.top = top
81
+ self.left = left
82
+ self.width = width
83
+ self.height = height
84
+
85
+ def add_label(self, category, confidence):
86
+ # adds category and confidence only if top_k is not exceeded.
87
+ self.labels.append(Label(category, confidence))
88
+
89
+ def labels_full(self, value):
90
+ return len(self.labels) == value
91
+
92
+
93
+ class Frame(BaseJsonLogger):
94
+ """
95
+ This module stores the information for each frame and use them in JsonParser
96
+ Attributes:
97
+ timestamp (float): The elapsed time of captured frame
98
+ frame_id (int): The frame number of the captured video
99
+ bboxes (list of Bbox objects): Stores the list of bbox objects.
100
+
101
+ References:
102
+ Check Bbox class for better information
103
+
104
+ Args:
105
+ timestamp (float):
106
+ frame_id (int):
107
+
108
+ """
109
+
110
+ def __init__(self, frame_id: int, timestamp: float = None):
111
+ self.frame_id = frame_id
112
+ self.timestamp = timestamp
113
+ self.bboxes = []
114
+
115
+ def add_bbox(self, bbox_id: int, top: int, left: int, width: int, height: int):
116
+ bboxes_ids = [bbox.bbox_id for bbox in self.bboxes]
117
+ if bbox_id not in bboxes_ids:
118
+ self.bboxes.append(Bbox(bbox_id, top, left, width, height))
119
+ else:
120
+ raise ValueError("Frame with id: {} already has a Bbox with id: {}".format(self.frame_id, bbox_id))
121
+
122
+ def add_label_to_bbox(self, bbox_id: int, category: str, confidence: float):
123
+ bboxes = {bbox.id: bbox for bbox in self.bboxes}
124
+ if bbox_id in bboxes.keys():
125
+ res = bboxes.get(bbox_id)
126
+ res.add_label(category, confidence)
127
+ else:
128
+ raise ValueError('the bbox with id: {} does not exists!'.format(bbox_id))
129
+
130
+
131
+ class BboxToJsonLogger(BaseJsonLogger):
132
+ """
133
+ ُ This module is designed to automate the task of logging jsons. An example json is used
134
+ to show the contents of json file shortly
135
+ Example:
136
+ {
137
+ "video_details": {
138
+ "frame_width": 1920,
139
+ "frame_height": 1080,
140
+ "frame_rate": 20,
141
+ "video_name": "/home/gpu/codes/MSD/pedestrian_2/project/public/camera1.avi"
142
+ },
143
+ "frames": [
144
+ {
145
+ "frame_id": 329,
146
+ "timestamp": 3365.1254
147
+ "bboxes": [
148
+ {
149
+ "labels": [
150
+ {
151
+ "category": "pedestrian",
152
+ "confidence": 0.9
153
+ }
154
+ ],
155
+ "bbox_id": 0,
156
+ "top": 1257,
157
+ "left": 138,
158
+ "width": 68,
159
+ "height": 109
160
+ }
161
+ ]
162
+ }],
163
+
164
+ Attributes:
165
+ frames (dict): It's a dictionary that maps each frame_id to json attributes.
166
+ video_details (dict): information about video file.
167
+ top_k_labels (int): shows the allowed number of labels
168
+ start_time (datetime object): we use it to automate the json output by time.
169
+
170
+ Args:
171
+ top_k_labels (int): shows the allowed number of labels
172
+
173
+ """
174
+
175
+ def __init__(self, top_k_labels: int = 1):
176
+ self.frames = {}
177
+ self.video_details = self.video_details = dict(frame_width=None, frame_height=None, frame_rate=None,
178
+ video_name=None)
179
+ self.top_k_labels = top_k_labels
180
+ self.start_time = datetime.now()
181
+
182
+ def set_top_k(self, value):
183
+ self.top_k_labels = value
184
+
185
+ def frame_exists(self, frame_id: int) -> bool:
186
+ """
187
+ Args:
188
+ frame_id (int):
189
+
190
+ Returns:
191
+ bool: true if frame_id is recognized
192
+ """
193
+ return frame_id in self.frames.keys()
194
+
195
+ def add_frame(self, frame_id: int, timestamp: float = None) -> None:
196
+ """
197
+ Args:
198
+ frame_id (int):
199
+ timestamp (float): opencv captured frame time property
200
+
201
+ Raises:
202
+ ValueError: if frame_id would not exist in class frames attribute
203
+
204
+ Returns:
205
+ None
206
+
207
+ """
208
+ if not self.frame_exists(frame_id):
209
+ self.frames[frame_id] = Frame(frame_id, timestamp)
210
+ else:
211
+ raise ValueError("Frame id: {} already exists".format(frame_id))
212
+
213
+ def bbox_exists(self, frame_id: int, bbox_id: int) -> bool:
214
+ """
215
+ Args:
216
+ frame_id:
217
+ bbox_id:
218
+
219
+ Returns:
220
+ bool: if bbox exists in frame bboxes list
221
+ """
222
+ bboxes = []
223
+ if self.frame_exists(frame_id=frame_id):
224
+ bboxes = [bbox.bbox_id for bbox in self.frames[frame_id].bboxes]
225
+ return bbox_id in bboxes
226
+
227
+ def find_bbox(self, frame_id: int, bbox_id: int):
228
+ """
229
+
230
+ Args:
231
+ frame_id:
232
+ bbox_id:
233
+
234
+ Returns:
235
+ bbox_id (int):
236
+
237
+ Raises:
238
+ ValueError: if bbox_id does not exist in the bbox list of specific frame.
239
+ """
240
+ if not self.bbox_exists(frame_id, bbox_id):
241
+ raise ValueError("frame with id: {} does not contain bbox with id: {}".format(frame_id, bbox_id))
242
+ bboxes = {bbox.bbox_id: bbox for bbox in self.frames[frame_id].bboxes}
243
+ return bboxes.get(bbox_id)
244
+
245
+ def add_bbox_to_frame(self, frame_id: int, bbox_id: int, top: int, left: int, width: int, height: int) -> None:
246
+ """
247
+
248
+ Args:
249
+ frame_id (int):
250
+ bbox_id (int):
251
+ top (int):
252
+ left (int):
253
+ width (int):
254
+ height (int):
255
+
256
+ Returns:
257
+ None
258
+
259
+ Raises:
260
+ ValueError: if bbox_id already exist in frame information with frame_id
261
+ ValueError: if frame_id does not exist in frames attribute
262
+ """
263
+ if self.frame_exists(frame_id):
264
+ frame = self.frames[frame_id]
265
+ if not self.bbox_exists(frame_id, bbox_id):
266
+ frame.add_bbox(bbox_id, top, left, width, height)
267
+ else:
268
+ raise ValueError(
269
+ "frame with frame_id: {} already contains the bbox with id: {} ".format(frame_id, bbox_id))
270
+ else:
271
+ raise ValueError("frame with frame_id: {} does not exist".format(frame_id))
+
+     def add_label_to_bbox(self, frame_id: int, bbox_id: int, category: str, confidence: float):
+         """
+         Args:
+             frame_id:
+             bbox_id:
+             category:
+             confidence: the confidence value returned from yolo detection
+
+         Returns:
+             None
+
+         Raises:
+             ValueError: if the label quota (top_k_labels) is exceeded.
+         """
+         bbox = self.find_bbox(frame_id, bbox_id)
+         if not bbox.labels_full(self.top_k_labels):
+             bbox.add_label(category, confidence)
+         else:
+             raise ValueError("labels in frame_id: {}, bbox_id: {} are full".format(frame_id, bbox_id))
+
+     def add_video_details(self, frame_width: int = None, frame_height: int = None, frame_rate: int = None,
+                           video_name: str = None):
+         self.video_details['frame_width'] = frame_width
+         self.video_details['frame_height'] = frame_height
+         self.video_details['frame_rate'] = frame_rate
+         self.video_details['video_name'] = video_name
+
+     def output(self):
+         output = {'video_details': self.video_details}
+         result = list(self.frames.values())
+         output['frames'] = [item.dic() for item in result]
+         return output
+
+     def json_output(self, output_name):
+         """
+         Args:
+             output_name:
+
+         Returns:
+             None
+
+         Notes:
+             It creates the json output file named `output_name`.
+         """
+         if not output_name.endswith('.json'):
+             output_name += '.json'
+         with open(output_name, 'w') as file:
+             json.dump(self.output(), file)
+
+     def set_start(self):
+         self.start_time = datetime.now()
+
+     def schedule_output_by_time(self, output_dir=JsonMeta.PATH_TO_SAVE, hours: int = 0, minutes: int = 0,
+                                 seconds: int = 60) -> None:
+         """
+         Notes:
+             Creates the output folder and then periodically stores the jsons at that address.
+
+         Args:
+             output_dir (str): the directory where output files will be stored
+             hours (int):
+             minutes (int):
+             seconds (int):
+
+         Returns:
+             None
+
+         """
+         end = datetime.now()
+         interval = 0
+         interval += abs(min([hours, JsonMeta.HOURS]) * 3600)
+         interval += abs(min([minutes, JsonMeta.MINUTES]) * 60)
+         interval += abs(min([seconds, JsonMeta.SECONDS]))
+         diff = (end - self.start_time).seconds
+
+         if diff > interval:
+             output_name = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '.json'
+             if not exists(output_dir):
+                 makedirs(output_dir)
+             output = join(output_dir, output_name)
+             self.json_output(output_name=output)
+             self.frames = {}
+             self.start_time = datetime.now()
+
+     def schedule_output_by_frames(self, frames_quota, frame_counter, output_dir=JsonMeta.PATH_TO_SAVE):
+         """
+         Saves the output once the number of processed frames reaches the frames quota (not implemented yet).
+         :param frames_quota:
+         :param frame_counter:
+         :param output_dir:
+         :return:
+         """
+         pass
+
+     def flush(self, output_dir):
+         """
+         Notes:
+             Use this function to output the json whenever needed,
+             e.g. when exiting the opencv while loop.
+
+         Args:
+             output_dir:
+
+         Returns:
+             None
+
+         """
+         filename = self.start_time.strftime('%Y-%m-%d %H-%M-%S') + '-remaining.json'
+         output = join(output_dir, filename)
+         self.json_output(output_name=output)
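A minimal usage sketch of this logger (illustrative only, not part of the commit; it assumes the enclosing class is the JsonParser defined earlier in json_logger.py):

    import os
    from deep_sort.utils.json_logger import JsonParser

    parser = JsonParser(top_k_labels=1)
    parser.add_video_details(frame_width=1920, frame_height=1080,
                             frame_rate=30, video_name='test.mp4')

    for frame_id in range(3):
        parser.add_frame(frame_id)
        # bbox coordinates follow the (top, left, width, height) convention shown above
        parser.add_bbox_to_frame(frame_id, bbox_id=0, top=1257, left=138, width=68, height=109)
        parser.add_label_to_bbox(frame_id, bbox_id=0, category='person', confidence=0.87)
        # writes (and resets) the accumulated frames once 60 seconds have elapsed
        parser.schedule_output_by_time(seconds=60)

    # unlike schedule_output_by_time, flush does not create the directory itself
    os.makedirs('outputs', exist_ok=True)
    parser.flush(output_dir='outputs')  # dump whatever is left, e.g. after the opencv loop exits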
deep_sort/utils/log.py ADDED
@@ -0,0 +1,17 @@
+ import logging
+
+
+ def get_logger(name='root'):
+     formatter = logging.Formatter(
+         # fmt='%(asctime)s [%(levelname)s]: %(filename)s(%(funcName)s:%(lineno)s) >> %(message)s')
+         fmt='%(asctime)s [%(levelname)s]: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+
+     handler = logging.StreamHandler()
+     handler.setFormatter(formatter)
+
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.INFO)
+     logger.addHandler(handler)
+     return logger
+
+
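A quick sketch of how this helper would be used (illustrative, not part of the commit; the printed line is what the formatter above would produce):

    from deep_sort.utils.log import get_logger

    logger = get_logger('deep_sort')
    logger.info('tracker initialized')
    # -> 2024-01-01 12:00:00 [INFO]: tracker initialized

Note that each call to get_logger with the same name attaches another StreamHandler to that logger, so calling it twice would print every message twice; callers should create the logger once and reuse it.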
deep_sort/utils/parser.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import yaml
+ from easydict import EasyDict as edict
+
+ class YamlParser(edict):
+     """
+     This is a yaml parser based on EasyDict.
+     """
+     def __init__(self, cfg_dict=None, config_file=None):
+         if cfg_dict is None:
+             cfg_dict = {}
+
+         if config_file is not None:
+             assert(os.path.isfile(config_file))
+             with open(config_file, 'r') as fo:
+                 cfg_dict.update(yaml.load(fo.read(), Loader=yaml.FullLoader))
+
+         super(YamlParser, self).__init__(cfg_dict)
+
+
+     def merge_from_file(self, config_file):
+         with open(config_file, 'r') as fo:
+             #self.update(yaml.load(fo.read()))
+             self.update(yaml.load(fo.read(), Loader=yaml.FullLoader))
+
+     def merge_from_dict(self, config_dict):
+         self.update(config_dict)
+
+
+ def get_config(config_file=None):
+     return YamlParser(config_file=config_file)
+
+
+ if __name__ == "__main__":
+     cfg = YamlParser(config_file="../configs/yolov3.yaml")
+     cfg.merge_from_file("../configs/deep_sort.yaml")
+
+     import ipdb; ipdb.set_trace()
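A minimal usage sketch (illustrative, not part of the commit), assuming deep_sort/configs/deep_sort.yaml in this repo defines a DEEPSORT section with a MAX_AGE key, as in the upstream deep_sort_pytorch config:

    from deep_sort.utils.parser import get_config

    cfg = get_config()
    cfg.merge_from_file('deep_sort/configs/deep_sort.yaml')
    # EasyDict allows attribute-style access to the parsed yaml keys
    print(cfg.DEEPSORT.MAX_AGE)

Because YamlParser subclasses EasyDict, nested yaml mappings become attribute-accessible objects, which is why the tracker code can read cfg.DEEPSORT.* directly instead of indexing nested dicts.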