Spaces:
Runtime error
Runtime error
yourusername
commited on
Commit
β’
2a27594
1
Parent(s):
e23d478
:beers: cheers
Browse files- .gitattributes +2 -0
- .gitignore +160 -0
- README.md +1 -1
- app.py +120 -0
- coco.yaml +20 -0
- data/coco.yaml +20 -0
- data/dataset.yaml +11 -0
- data/voc.yaml +11 -0
- example_1.jpg +3 -0
- example_1.mp4 +3 -0
- example_2.jpg +3 -0
- example_2.mp4 +3 -0
- example_3.jpg +3 -0
- example_3.mp4 +3 -0
- inferer.py +238 -0
- packages.txt +1 -0
- pyproject.toml +7 -0
- requirements.txt +15 -0
- yolov6/core/engine.py +273 -0
- yolov6/core/evaler.py +256 -0
- yolov6/core/inferer.py +231 -0
- yolov6/data/data_augment.py +193 -0
- yolov6/data/data_load.py +113 -0
- yolov6/data/datasets.py +550 -0
- yolov6/data/vis_dataset.py +57 -0
- yolov6/data/voc2yolo.py +99 -0
- yolov6/layers/common.py +501 -0
- yolov6/layers/dbb_transforms.py +50 -0
- yolov6/models/efficientrep.py +102 -0
- yolov6/models/effidehead.py +211 -0
- yolov6/models/end2end.py +147 -0
- yolov6/models/loss.py +411 -0
- yolov6/models/reppan.py +108 -0
- yolov6/models/yolo.py +83 -0
- yolov6/solver/build.py +42 -0
- yolov6/utils/Arial.ttf +0 -0
- yolov6/utils/checkpoint.py +60 -0
- yolov6/utils/config.py +101 -0
- yolov6/utils/ema.py +59 -0
- yolov6/utils/envs.py +54 -0
- yolov6/utils/events.py +41 -0
- yolov6/utils/figure_iou.py +114 -0
- yolov6/utils/general.py +24 -0
- yolov6/utils/nms.py +106 -0
- yolov6/utils/torch_utils.py +110 -0
.gitattributes
CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: Yolov6
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: Yolov6
|
3 |
+
emoji: π₯ππ₯
|
4 |
colorFrom: blue
|
5 |
colorTo: purple
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import tempfile
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
from inferer import Inferer
|
10 |
+
|
11 |
+
pipeline = Inferer("nateraw/yolov6s")
|
12 |
+
|
13 |
+
|
14 |
+
def fn_image(image, conf_thres, iou_thres):
|
15 |
+
return pipeline(image, conf_thres, iou_thres)
|
16 |
+
|
17 |
+
|
18 |
+
def fn_video(video_file, conf_thres, iou_thres, start_sec, duration):
|
19 |
+
start_timestamp = time.strftime("%H:%M:%S", time.gmtime(start_sec))
|
20 |
+
end_timestamp = time.strftime("%H:%M:%S", time.gmtime(start_sec + duration))
|
21 |
+
|
22 |
+
suffix = Path(video_file).suffix
|
23 |
+
|
24 |
+
clip_temp_file = tempfile.NamedTemporaryFile(suffix=suffix)
|
25 |
+
subprocess.call(
|
26 |
+
f"ffmpeg -y -ss {start_timestamp} -i {video_file} -to {end_timestamp} -c copy {clip_temp_file.name}".split()
|
27 |
+
)
|
28 |
+
|
29 |
+
# Reader of clip file
|
30 |
+
cap = cv2.VideoCapture(clip_temp_file.name)
|
31 |
+
|
32 |
+
# This is an intermediary temp file where we'll write the video to
|
33 |
+
# Unfortunately, gradio doesn't play too nice with videos rn so we have to do some hackiness
|
34 |
+
# with ffmpeg at the end of the function here.
|
35 |
+
with tempfile.NamedTemporaryFile(suffix=".mp4") as temp_file:
|
36 |
+
out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*"MP4V"), 30, (1280, 720))
|
37 |
+
|
38 |
+
num_frames = 0
|
39 |
+
max_frames = duration * 30
|
40 |
+
while cap.isOpened():
|
41 |
+
try:
|
42 |
+
ret, frame = cap.read()
|
43 |
+
if not ret:
|
44 |
+
break
|
45 |
+
except Exception as e:
|
46 |
+
print(e)
|
47 |
+
continue
|
48 |
+
|
49 |
+
out.write(pipeline(frame, conf_thres, iou_thres))
|
50 |
+
num_frames += 1
|
51 |
+
print("Processed {} frames".format(num_frames))
|
52 |
+
if num_frames == max_frames:
|
53 |
+
break
|
54 |
+
|
55 |
+
out.release()
|
56 |
+
|
57 |
+
# Aforementioned hackiness
|
58 |
+
out_file = tempfile.NamedTemporaryFile(suffix="out.mp4", delete=False)
|
59 |
+
subprocess.run(f"ffmpeg -y -loglevel quiet -stats -i {temp_file.name} -c:v libx264 {out_file.name}".split())
|
60 |
+
|
61 |
+
return out_file.name
|
62 |
+
|
63 |
+
|
64 |
+
image_interface = gr.Interface(
|
65 |
+
fn=fn_image,
|
66 |
+
inputs=[
|
67 |
+
"image",
|
68 |
+
gr.Slider(0, 1, value=0.5, label="Confidence Threshold"),
|
69 |
+
gr.Slider(0, 1, value=0.5, label="IOU Threshold"),
|
70 |
+
],
|
71 |
+
outputs=gr.Image(type="file"),
|
72 |
+
examples=[["example_1.jpg", 0.5, 0.5], ["example_2.jpg", 0.25, 0.45], ["example_3.jpg", 0.25, 0.45]],
|
73 |
+
title="YOLOv6",
|
74 |
+
description=(
|
75 |
+
"Gradio demo for YOLOv6 for object detection on images. To use it, simply upload your image or click one of the"
|
76 |
+
" examples to load them. Read more at the links below."
|
77 |
+
),
|
78 |
+
article=(
|
79 |
+
"<div style='text-align: center;'><a href='https://github.com/meituan/YOLOv6' target='_blank'>Github Repo</a> |"
|
80 |
+
" <center><img src='https://visitor-badge.glitch.me/badge?page_id=nateraw_yolov6' alt='visitor"
|
81 |
+
" badge'></center></div>"
|
82 |
+
),
|
83 |
+
allow_flagging=False,
|
84 |
+
allow_screenshot=False,
|
85 |
+
)
|
86 |
+
|
87 |
+
video_interface = gr.Interface(
|
88 |
+
fn=fn_video,
|
89 |
+
inputs=[
|
90 |
+
gr.Video(type="file"),
|
91 |
+
gr.Slider(0, 1, value=0.25, label="Confidence Threshold"),
|
92 |
+
gr.Slider(0, 1, value=0.45, label="IOU Threshold"),
|
93 |
+
gr.Slider(0, 10, value=0, label="Start Second", step=1),
|
94 |
+
gr.Slider(0, 3, value=2, label="Duration", step=1),
|
95 |
+
],
|
96 |
+
outputs=gr.Video(type="file", format="mp4"),
|
97 |
+
examples=[
|
98 |
+
["example_1.mp4", 0.25, 0.45, 0, 2],
|
99 |
+
["example_2.mp4", 0.25, 0.45, 5, 3],
|
100 |
+
["example_3.mp4", 0.25, 0.45, 6, 3],
|
101 |
+
],
|
102 |
+
title="YOLOv6",
|
103 |
+
description=(
|
104 |
+
"Gradio demo for YOLOv6 for object detection on videos. To use it, simply upload your video or click one of the"
|
105 |
+
" examples to load them. Read more at the links below."
|
106 |
+
),
|
107 |
+
article=(
|
108 |
+
"<div style='text-align: center;'><a href='https://github.com/meituan/YOLOv6' target='_blank'>Github Repo</a> |"
|
109 |
+
" <center><img src='https://visitor-badge.glitch.me/badge?page_id=nateraw_yolov6' alt='visitor"
|
110 |
+
" badge'></center></div>"
|
111 |
+
),
|
112 |
+
allow_flagging=False,
|
113 |
+
allow_screenshot=False,
|
114 |
+
)
|
115 |
+
|
116 |
+
if __name__ == "__main__":
|
117 |
+
gr.TabbedInterface(
|
118 |
+
[video_interface, image_interface],
|
119 |
+
["Run on Videos!", "Run on Images!"],
|
120 |
+
).launch()
|
coco.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# COCO 2017 dataset http://cocodataset.org
|
2 |
+
train: ../coco/images/train2017 # 118287 images
|
3 |
+
val: ../coco/images/val2017 # 5000 images
|
4 |
+
test: ../coco/images/test2017
|
5 |
+
anno_path: ../coco/annotations/instances_val2017.json
|
6 |
+
# number of classes
|
7 |
+
nc: 80
|
8 |
+
# whether it is coco dataset, only coco dataset should be set to True.
|
9 |
+
is_coco: True
|
10 |
+
|
11 |
+
# class names
|
12 |
+
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
13 |
+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
14 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
15 |
+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
16 |
+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
17 |
+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
18 |
+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
19 |
+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
20 |
+
'hair drier', 'toothbrush' ]
|
data/coco.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# COCO 2017 dataset http://cocodataset.org
|
2 |
+
train: ../coco/images/train2017 # 118287 images
|
3 |
+
val: ../coco/images/val2017 # 5000 images
|
4 |
+
test: ../coco/images/test2017
|
5 |
+
anno_path: ../coco/annotations/instances_val2017.json
|
6 |
+
# number of classes
|
7 |
+
nc: 80
|
8 |
+
# whether it is coco dataset, only coco dataset should be set to True.
|
9 |
+
is_coco: True
|
10 |
+
|
11 |
+
# class names
|
12 |
+
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
|
13 |
+
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
14 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
|
15 |
+
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
|
16 |
+
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
17 |
+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
|
18 |
+
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
19 |
+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
|
20 |
+
'hair drier', 'toothbrush' ]
|
data/dataset.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Please insure that your custom_dataset are put in same parent dir with YOLOv6_DIR
|
2 |
+
train: ../custom_dataset/images/train # train images
|
3 |
+
val: ../custom_dataset/images/val # val images
|
4 |
+
test: ../custom_dataset/images/test # test images (optional)
|
5 |
+
|
6 |
+
# whether it is coco dataset, only coco dataset should be set to True.
|
7 |
+
is_coco: False
|
8 |
+
# Classes
|
9 |
+
nc: 20 # number of classes
|
10 |
+
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
11 |
+
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
|
data/voc.yaml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Please insure that your custom_dataset are put in same parent dir with YOLOv6_DIR
|
2 |
+
train: VOCdevkit/voc_07_12/images/train # train images
|
3 |
+
val: VOCdevkit/voc_07_12/images/val # val images
|
4 |
+
test: VOCdevkit/voc_07_12/images/val # test images (optional)
|
5 |
+
|
6 |
+
# whether it is coco dataset, only coco dataset should be set to True.
|
7 |
+
is_coco: False
|
8 |
+
# Classes
|
9 |
+
nc: 20 # number of classes
|
10 |
+
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
11 |
+
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
|
example_1.jpg
ADDED
Git LFS Details
|
example_1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:52e97530eb82cb036d6cd3dc6f141fbeaa15461b3346a11649a64bda9be7e828
|
3 |
+
size 3890679
|
example_2.jpg
ADDED
Git LFS Details
|
example_2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6223d01dcc060f8598d0a79da33baaa6d4049087d650224e25771a670aee0a6a
|
3 |
+
size 4137103
|
example_3.jpg
ADDED
Git LFS Details
|
example_3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3974097c49918132965c02a121ec45e525d53216f61ccdcdd4a5247a193468ff
|
3 |
+
size 4991487
|
inferer.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import math
|
4 |
+
import os.path as osp
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
from PIL import Image, ImageFont
|
11 |
+
|
12 |
+
from yolov6.data.data_augment import letterbox
|
13 |
+
from yolov6.layers.common import DetectBackend
|
14 |
+
from yolov6.utils.events import LOGGER, load_yaml
|
15 |
+
from yolov6.utils.nms import non_max_suppression
|
16 |
+
|
17 |
+
|
18 |
+
class Inferer:
|
19 |
+
def __init__(self, model_id, device="cpu", yaml="coco.yaml", img_size=640, half=False):
|
20 |
+
self.__dict__.update(locals())
|
21 |
+
|
22 |
+
# Init model
|
23 |
+
self.img_size = img_size
|
24 |
+
cuda = device != "cpu" and torch.cuda.is_available()
|
25 |
+
self.device = torch.device("cuda:0" if cuda else "cpu")
|
26 |
+
self.model = DetectBackend(hf_hub_download(model_id, "model.pt"), device=self.device)
|
27 |
+
self.stride = self.model.stride
|
28 |
+
self.class_names = load_yaml(yaml)["names"]
|
29 |
+
self.img_size = self.check_img_size(self.img_size, s=self.stride) # check image size
|
30 |
+
|
31 |
+
# Half precision
|
32 |
+
if half & (self.device.type != "cpu"):
|
33 |
+
self.model.model.half()
|
34 |
+
else:
|
35 |
+
self.model.model.float()
|
36 |
+
half = False
|
37 |
+
|
38 |
+
if self.device.type != "cpu":
|
39 |
+
self.model(
|
40 |
+
torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))
|
41 |
+
) # warmup
|
42 |
+
|
43 |
+
# Switch model to deploy status
|
44 |
+
self.model_switch(self.model, self.img_size)
|
45 |
+
|
46 |
+
def model_switch(self, model, img_size):
|
47 |
+
"""Model switch to deploy status"""
|
48 |
+
from yolov6.layers.common import RepVGGBlock
|
49 |
+
|
50 |
+
for layer in model.modules():
|
51 |
+
if isinstance(layer, RepVGGBlock):
|
52 |
+
layer.switch_to_deploy()
|
53 |
+
|
54 |
+
LOGGER.info("Switch model to deploy modality.")
|
55 |
+
|
56 |
+
def __call__(
|
57 |
+
self,
|
58 |
+
path_or_image,
|
59 |
+
conf_thres=0.25,
|
60 |
+
iou_thres=0.45,
|
61 |
+
classes=None,
|
62 |
+
agnostic_nms=False,
|
63 |
+
max_det=1000,
|
64 |
+
hide_labels=False,
|
65 |
+
hide_conf=False,
|
66 |
+
):
|
67 |
+
"""Model Inference and results visualization"""
|
68 |
+
|
69 |
+
img, img_src = self.precess_image(path_or_image, self.img_size, self.stride, self.half)
|
70 |
+
img = img.to(self.device)
|
71 |
+
if len(img.shape) == 3:
|
72 |
+
img = img[None]
|
73 |
+
# expand for batch dim
|
74 |
+
pred_results = self.model(img)
|
75 |
+
det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
|
76 |
+
|
77 |
+
gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
78 |
+
img_ori = img_src
|
79 |
+
|
80 |
+
# check image and font
|
81 |
+
assert (
|
82 |
+
img_ori.data.contiguous
|
83 |
+
), "Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im)."
|
84 |
+
self.font_check()
|
85 |
+
|
86 |
+
if len(det):
|
87 |
+
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
|
88 |
+
|
89 |
+
for *xyxy, conf, cls in reversed(det):
|
90 |
+
class_num = int(cls) # integer class
|
91 |
+
label = (
|
92 |
+
None
|
93 |
+
if hide_labels
|
94 |
+
else (self.class_names[class_num] if hide_conf else f"{self.class_names[class_num]} {conf:.2f}")
|
95 |
+
)
|
96 |
+
|
97 |
+
self.plot_box_and_label(
|
98 |
+
img_ori,
|
99 |
+
max(round(sum(img_ori.shape) / 2 * 0.003), 2),
|
100 |
+
xyxy,
|
101 |
+
label,
|
102 |
+
color=self.generate_colors(class_num, True),
|
103 |
+
)
|
104 |
+
|
105 |
+
img_src = np.asarray(img_ori)
|
106 |
+
|
107 |
+
return img_src
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
def precess_image(path_or_image, img_size, stride, half):
|
111 |
+
"""Process image before image inference."""
|
112 |
+
if isinstance(path_or_image, str):
|
113 |
+
try:
|
114 |
+
img_src = cv2.imread(path_or_image)
|
115 |
+
assert img_src is not None, f"Invalid image: {path_or_image}"
|
116 |
+
except Exception as e:
|
117 |
+
LOGGER.warning(e)
|
118 |
+
elif isinstance(path_or_image, np.ndarray):
|
119 |
+
img_src = path_or_image
|
120 |
+
elif isinstance(path_or_image, Image.Image):
|
121 |
+
img_src = np.array(path_or_image)
|
122 |
+
|
123 |
+
image = letterbox(img_src, img_size, stride=stride)[0]
|
124 |
+
|
125 |
+
# Convert
|
126 |
+
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
127 |
+
image = torch.from_numpy(np.ascontiguousarray(image))
|
128 |
+
image = image.half() if half else image.float() # uint8 to fp16/32
|
129 |
+
image /= 255 # 0 - 255 to 0.0 - 1.0
|
130 |
+
|
131 |
+
return image, img_src
|
132 |
+
|
133 |
+
@staticmethod
|
134 |
+
def rescale(ori_shape, boxes, target_shape):
|
135 |
+
"""Rescale the output to the original image shape"""
|
136 |
+
ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
|
137 |
+
padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
|
138 |
+
|
139 |
+
boxes[:, [0, 2]] -= padding[0]
|
140 |
+
boxes[:, [1, 3]] -= padding[1]
|
141 |
+
boxes[:, :4] /= ratio
|
142 |
+
|
143 |
+
boxes[:, 0].clamp_(0, target_shape[1]) # x1
|
144 |
+
boxes[:, 1].clamp_(0, target_shape[0]) # y1
|
145 |
+
boxes[:, 2].clamp_(0, target_shape[1]) # x2
|
146 |
+
boxes[:, 3].clamp_(0, target_shape[0]) # y2
|
147 |
+
|
148 |
+
return boxes
|
149 |
+
|
150 |
+
def check_img_size(self, img_size, s=32, floor=0):
|
151 |
+
"""Make sure image size is a multiple of stride s in each dimension, and return a new shape list of image."""
|
152 |
+
if isinstance(img_size, int): # integer i.e. img_size=640
|
153 |
+
new_size = max(self.make_divisible(img_size, int(s)), floor)
|
154 |
+
elif isinstance(img_size, list): # list i.e. img_size=[640, 480]
|
155 |
+
new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
|
156 |
+
else:
|
157 |
+
raise Exception(f"Unsupported type of img_size: {type(img_size)}")
|
158 |
+
|
159 |
+
if new_size != img_size:
|
160 |
+
print(f"WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}")
|
161 |
+
return new_size if isinstance(img_size, list) else [new_size] * 2
|
162 |
+
|
163 |
+
def make_divisible(self, x, divisor):
|
164 |
+
# Upward revision the value x to make it evenly divisible by the divisor.
|
165 |
+
return math.ceil(x / divisor) * divisor
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def plot_box_and_label(image, lw, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255)):
|
169 |
+
# Add one xyxy box to image with label
|
170 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
171 |
+
cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
|
172 |
+
if label:
|
173 |
+
tf = max(lw - 1, 1) # font thickness
|
174 |
+
w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0] # text width, height
|
175 |
+
outside = p1[1] - h - 3 >= 0 # label fits outside box
|
176 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
177 |
+
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
|
178 |
+
cv2.putText(
|
179 |
+
image,
|
180 |
+
label,
|
181 |
+
(p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
182 |
+
0,
|
183 |
+
lw / 3,
|
184 |
+
txt_color,
|
185 |
+
thickness=tf,
|
186 |
+
lineType=cv2.LINE_AA,
|
187 |
+
)
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def font_check(font="./yolov6/utils/Arial.ttf", size=10):
|
191 |
+
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
192 |
+
assert osp.exists(font), f"font path not exists: {font}"
|
193 |
+
try:
|
194 |
+
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
195 |
+
except Exception as e: # download if missing
|
196 |
+
return ImageFont.truetype(str(font), size)
|
197 |
+
|
198 |
+
@staticmethod
|
199 |
+
def box_convert(x):
|
200 |
+
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
|
201 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
202 |
+
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
203 |
+
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
204 |
+
y[:, 2] = x[:, 2] - x[:, 0] # width
|
205 |
+
y[:, 3] = x[:, 3] - x[:, 1] # height
|
206 |
+
return y
|
207 |
+
|
208 |
+
@staticmethod
|
209 |
+
def generate_colors(i, bgr=False):
|
210 |
+
hex = (
|
211 |
+
"FF3838",
|
212 |
+
"FF9D97",
|
213 |
+
"FF701F",
|
214 |
+
"FFB21D",
|
215 |
+
"CFD231",
|
216 |
+
"48F90A",
|
217 |
+
"92CC17",
|
218 |
+
"3DDB86",
|
219 |
+
"1A9334",
|
220 |
+
"00D4BB",
|
221 |
+
"2C99A8",
|
222 |
+
"00C2FF",
|
223 |
+
"344593",
|
224 |
+
"6473FF",
|
225 |
+
"0018EC",
|
226 |
+
"8438FF",
|
227 |
+
"520085",
|
228 |
+
"CB38FF",
|
229 |
+
"FF95C8",
|
230 |
+
"FF37C7",
|
231 |
+
)
|
232 |
+
palette = []
|
233 |
+
for iter in hex:
|
234 |
+
h = "#" + iter
|
235 |
+
palette.append(tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)))
|
236 |
+
num = len(palette)
|
237 |
+
color = palette[int(i) % num]
|
238 |
+
return (color[2], color[1], color[0]) if bgr else color
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
ffmpeg
|
pyproject.toml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.black]
|
2 |
+
line-length = 120
|
3 |
+
target_version = ['py37', 'py38', 'py39', 'py310']
|
4 |
+
preview = true
|
5 |
+
|
6 |
+
[tool.isort]
|
7 |
+
profile = "black"
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub
|
2 |
+
gradio
|
3 |
+
torch>=1.8.0
|
4 |
+
torchvision>=0.9.0
|
5 |
+
numpy>=1.18.5
|
6 |
+
opencv-python>=4.1.2
|
7 |
+
PyYAML>=5.3.1
|
8 |
+
scipy>=1.4.1
|
9 |
+
# tqdm>=4.41.0
|
10 |
+
# addict>=2.4.0
|
11 |
+
# tensorboard>=2.7.0
|
12 |
+
# pycocotools>=2.0
|
13 |
+
# onnx>=1.10.0 # ONNX export
|
14 |
+
# onnx-simplifier>=0.3.6 # ONNX simplifier
|
15 |
+
# thop # FLOPs computation
|
yolov6/core/engine.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from copy import deepcopy
|
6 |
+
import os.path as osp
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch.cuda import amp
|
13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
import tools.eval as eval
|
17 |
+
from yolov6.data.data_load import create_dataloader
|
18 |
+
from yolov6.models.yolo import build_model
|
19 |
+
from yolov6.models.loss import ComputeLoss
|
20 |
+
from yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog
|
21 |
+
from yolov6.utils.ema import ModelEMA, de_parallel
|
22 |
+
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
|
23 |
+
from yolov6.solver.build import build_optimizer, build_lr_scheduler
|
24 |
+
|
25 |
+
class Trainer:
|
26 |
+
def __init__(self, args, cfg, device):
|
27 |
+
self.args = args
|
28 |
+
self.cfg = cfg
|
29 |
+
self.device = device
|
30 |
+
|
31 |
+
if args.resume:
|
32 |
+
self.ckpt = torch.load(args.resume, map_location='cpu')
|
33 |
+
|
34 |
+
self.rank = args.rank
|
35 |
+
self.local_rank = args.local_rank
|
36 |
+
self.world_size = args.world_size
|
37 |
+
self.main_process = self.rank in [-1, 0]
|
38 |
+
self.save_dir = args.save_dir
|
39 |
+
# get data loader
|
40 |
+
self.data_dict = load_yaml(args.data_path)
|
41 |
+
self.num_classes = self.data_dict['nc']
|
42 |
+
self.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)
|
43 |
+
# get model and optimizer
|
44 |
+
model = self.get_model(args, cfg, self.num_classes, device)
|
45 |
+
self.optimizer = self.get_optimizer(args, cfg, model)
|
46 |
+
self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)
|
47 |
+
self.ema = ModelEMA(model) if self.main_process else None
|
48 |
+
# tensorboard
|
49 |
+
self.tblogger = SummaryWriter(self.save_dir) if self.main_process else None
|
50 |
+
self.start_epoch = 0
|
51 |
+
#resume
|
52 |
+
if hasattr(self, "ckpt"):
|
53 |
+
resume_state_dict = self.ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
54 |
+
model.load_state_dict(resume_state_dict, strict=True) # load
|
55 |
+
self.start_epoch = self.ckpt['epoch'] + 1
|
56 |
+
self.optimizer.load_state_dict(self.ckpt['optimizer'])
|
57 |
+
if self.main_process:
|
58 |
+
self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict())
|
59 |
+
self.ema.updates = self.ckpt['updates']
|
60 |
+
self.model = self.parallel_model(args, model, device)
|
61 |
+
self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']
|
62 |
+
|
63 |
+
self.max_epoch = args.epochs
|
64 |
+
self.max_stepnum = len(self.train_loader)
|
65 |
+
self.batch_size = args.batch_size
|
66 |
+
self.img_size = args.img_size
|
67 |
+
|
68 |
+
|
69 |
+
# Training Process
|
70 |
+
|
71 |
+
def train(self):
|
72 |
+
try:
|
73 |
+
self.train_before_loop()
|
74 |
+
for self.epoch in range(self.start_epoch, self.max_epoch):
|
75 |
+
self.train_in_loop()
|
76 |
+
|
77 |
+
except Exception as _:
|
78 |
+
LOGGER.error('ERROR in training loop or eval/save model.')
|
79 |
+
raise
|
80 |
+
finally:
|
81 |
+
self.train_after_loop()
|
82 |
+
|
83 |
+
# Training loop for each epoch
|
84 |
+
def train_in_loop(self):
|
85 |
+
try:
|
86 |
+
self.prepare_for_steps()
|
87 |
+
for self.step, self.batch_data in self.pbar:
|
88 |
+
self.train_in_steps()
|
89 |
+
self.print_details()
|
90 |
+
except Exception as _:
|
91 |
+
LOGGER.error('ERROR in training steps.')
|
92 |
+
raise
|
93 |
+
try:
|
94 |
+
self.eval_and_save()
|
95 |
+
except Exception as _:
|
96 |
+
LOGGER.error('ERROR in evaluate and save model.')
|
97 |
+
raise
|
98 |
+
|
99 |
+
# Training loop for batchdata
|
100 |
+
def train_in_steps(self):
|
101 |
+
images, targets = self.prepro_data(self.batch_data, self.device)
|
102 |
+
# forward
|
103 |
+
with amp.autocast(enabled=self.device != 'cpu'):
|
104 |
+
preds = self.model(images)
|
105 |
+
total_loss, loss_items = self.compute_loss(preds, targets)
|
106 |
+
if self.rank != -1:
|
107 |
+
total_loss *= self.world_size
|
108 |
+
# backward
|
109 |
+
self.scaler.scale(total_loss).backward()
|
110 |
+
self.loss_items = loss_items
|
111 |
+
self.update_optimizer()
|
112 |
+
|
113 |
+
def eval_and_save(self):
|
114 |
+
remaining_epochs = self.max_epoch - self.epoch
|
115 |
+
eval_interval = self.args.eval_interval if remaining_epochs > self.args.heavy_eval_range else 1
|
116 |
+
is_val_epoch = (not self.args.eval_final_only or (remaining_epochs == 1)) and (self.epoch % eval_interval == 0)
|
117 |
+
if self.main_process:
|
118 |
+
self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model
|
119 |
+
if is_val_epoch:
|
120 |
+
self.eval_model()
|
121 |
+
self.ap = self.evaluate_results[0] * 0.1 + self.evaluate_results[1] * 0.9
|
122 |
+
self.best_ap = max(self.ap, self.best_ap)
|
123 |
+
# save ckpt
|
124 |
+
ckpt = {
|
125 |
+
'model': deepcopy(de_parallel(self.model)).half(),
|
126 |
+
'ema': deepcopy(self.ema.ema).half(),
|
127 |
+
'updates': self.ema.updates,
|
128 |
+
'optimizer': self.optimizer.state_dict(),
|
129 |
+
'epoch': self.epoch,
|
130 |
+
}
|
131 |
+
|
132 |
+
save_ckpt_dir = osp.join(self.save_dir, 'weights')
|
133 |
+
save_checkpoint(ckpt, (is_val_epoch) and (self.ap == self.best_ap), save_ckpt_dir, model_name='last_ckpt')
|
134 |
+
del ckpt
|
135 |
+
# log for tensorboard
|
136 |
+
write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)
|
137 |
+
|
138 |
+
def eval_model(self):
|
139 |
+
results = eval.run(self.data_dict,
|
140 |
+
batch_size=self.batch_size // self.world_size * 2,
|
141 |
+
img_size=self.img_size,
|
142 |
+
model=self.ema.ema,
|
143 |
+
dataloader=self.val_loader,
|
144 |
+
save_dir=self.save_dir,
|
145 |
+
task='train')
|
146 |
+
|
147 |
+
LOGGER.info(f"Epoch: {self.epoch} | mAP@0.5: {results[0]} | mAP@0.50:0.95: {results[1]}")
|
148 |
+
self.evaluate_results = results[:2]
|
149 |
+
|
150 |
+
def train_before_loop(self):
|
151 |
+
LOGGER.info('Training start...')
|
152 |
+
self.start_time = time.time()
|
153 |
+
self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000)
|
154 |
+
self.scheduler.last_epoch = self.start_epoch - 1
|
155 |
+
self.last_opt_step = -1
|
156 |
+
self.scaler = amp.GradScaler(enabled=self.device != 'cpu')
|
157 |
+
|
158 |
+
self.best_ap, self.ap = 0.0, 0.0
|
159 |
+
self.evaluate_results = (0, 0) # AP50, AP50_95
|
160 |
+
self.compute_loss = ComputeLoss(iou_type=self.cfg.model.head.iou_type)
|
161 |
+
|
162 |
+
def prepare_for_steps(self):
|
163 |
+
if self.epoch > self.start_epoch:
|
164 |
+
self.scheduler.step()
|
165 |
+
self.model.train()
|
166 |
+
if self.rank != -1:
|
167 |
+
self.train_loader.sampler.set_epoch(self.epoch)
|
168 |
+
self.mean_loss = torch.zeros(4, device=self.device)
|
169 |
+
self.optimizer.zero_grad()
|
170 |
+
|
171 |
+
LOGGER.info(('\n' + '%10s' * 5) % ('Epoch', 'iou_loss', 'l1_loss', 'obj_loss', 'cls_loss'))
|
172 |
+
self.pbar = enumerate(self.train_loader)
|
173 |
+
if self.main_process:
|
174 |
+
self.pbar = tqdm(self.pbar, total=self.max_stepnum, ncols=NCOLS, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
175 |
+
|
176 |
+
# Print loss after each steps
|
177 |
+
def print_details(self):
|
178 |
+
if self.main_process:
|
179 |
+
self.mean_loss = (self.mean_loss * self.step + self.loss_items) / (self.step + 1)
|
180 |
+
self.pbar.set_description(('%10s' + '%10.4g' * 4) % (f'{self.epoch}/{self.max_epoch - 1}', \
|
181 |
+
*(self.mean_loss)))
|
182 |
+
|
183 |
+
# Empty cache if training finished
|
184 |
+
def train_after_loop(self):
|
185 |
+
if self.main_process:
|
186 |
+
LOGGER.info(f'\nTraining completed in {(time.time() - self.start_time) / 3600:.3f} hours.')
|
187 |
+
save_ckpt_dir = osp.join(self.save_dir, 'weights')
|
188 |
+
strip_optimizer(save_ckpt_dir, self.epoch) # strip optimizers for saved pt model
|
189 |
+
if self.device != 'cpu':
|
190 |
+
torch.cuda.empty_cache()
|
191 |
+
|
192 |
+
def update_optimizer(self):
|
193 |
+
curr_step = self.step + self.max_stepnum * self.epoch
|
194 |
+
self.accumulate = max(1, round(64 / self.batch_size))
|
195 |
+
if curr_step <= self.warmup_stepnum:
|
196 |
+
self.accumulate = max(1, np.interp(curr_step, [0, self.warmup_stepnum], [1, 64 / self.batch_size]).round())
|
197 |
+
for k, param in enumerate(self.optimizer.param_groups):
|
198 |
+
warmup_bias_lr = self.cfg.solver.warmup_bias_lr if k == 2 else 0.0
|
199 |
+
param['lr'] = np.interp(curr_step, [0, self.warmup_stepnum], [warmup_bias_lr, param['initial_lr'] * self.lf(self.epoch)])
|
200 |
+
if 'momentum' in param:
|
201 |
+
param['momentum'] = np.interp(curr_step, [0, self.warmup_stepnum], [self.cfg.solver.warmup_momentum, self.cfg.solver.momentum])
|
202 |
+
if curr_step - self.last_opt_step >= self.accumulate:
|
203 |
+
self.scaler.step(self.optimizer)
|
204 |
+
self.scaler.update()
|
205 |
+
self.optimizer.zero_grad()
|
206 |
+
if self.ema:
|
207 |
+
self.ema.update(self.model)
|
208 |
+
self.last_opt_step = curr_step
|
209 |
+
|
210 |
+
@staticmethod
|
211 |
+
def get_data_loader(args, cfg, data_dict):
|
212 |
+
train_path, val_path = data_dict['train'], data_dict['val']
|
213 |
+
# check data
|
214 |
+
nc = int(data_dict['nc'])
|
215 |
+
class_names = data_dict['names']
|
216 |
+
assert len(class_names) == nc, f'the length of class names does not match the number of classes defined'
|
217 |
+
grid_size = max(int(max(cfg.model.head.strides)), 32)
|
218 |
+
# create train dataloader
|
219 |
+
train_loader = create_dataloader(train_path, args.img_size, args.batch_size // args.world_size, grid_size,
|
220 |
+
hyp=dict(cfg.data_aug), augment=True, rect=False, rank=args.local_rank,
|
221 |
+
workers=args.workers, shuffle=True, check_images=args.check_images,
|
222 |
+
check_labels=args.check_labels, data_dict=data_dict, task='train')[0]
|
223 |
+
# create val dataloader
|
224 |
+
val_loader = None
|
225 |
+
if args.rank in [-1, 0]:
|
226 |
+
val_loader = create_dataloader(val_path, args.img_size, args.batch_size // args.world_size * 2, grid_size,
|
227 |
+
hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,
|
228 |
+
workers=args.workers, check_images=args.check_images,
|
229 |
+
check_labels=args.check_labels, data_dict=data_dict, task='val')[0]
|
230 |
+
|
231 |
+
return train_loader, val_loader
|
232 |
+
|
233 |
+
@staticmethod
|
234 |
+
def prepro_data(batch_data, device):
|
235 |
+
images = batch_data[0].to(device, non_blocking=True).float() / 255
|
236 |
+
targets = batch_data[1].to(device)
|
237 |
+
return images, targets
|
238 |
+
|
239 |
+
def get_model(self, args, cfg, nc, device):
|
240 |
+
model = build_model(cfg, nc, device)
|
241 |
+
weights = cfg.model.pretrained
|
242 |
+
if weights: # finetune if pretrained model is set
|
243 |
+
LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')
|
244 |
+
model = load_state_dict(weights, model, map_location=device)
|
245 |
+
LOGGER.info('Model: {}'.format(model))
|
246 |
+
return model
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def parallel_model(args, model, device):
|
250 |
+
# If DP mode
|
251 |
+
dp_mode = device.type != 'cpu' and args.rank == -1
|
252 |
+
if dp_mode and torch.cuda.device_count() > 1:
|
253 |
+
LOGGER.warning('WARNING: DP not recommended, use DDP instead.\n')
|
254 |
+
model = torch.nn.DataParallel(model)
|
255 |
+
|
256 |
+
# If DDP mode
|
257 |
+
ddp_mode = device.type != 'cpu' and args.rank != -1
|
258 |
+
if ddp_mode:
|
259 |
+
model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
|
260 |
+
|
261 |
+
return model
|
262 |
+
|
263 |
+
def get_optimizer(self, args, cfg, model):
|
264 |
+
accumulate = max(1, round(64 / args.batch_size))
|
265 |
+
cfg.solver.weight_decay *= args.batch_size * accumulate / 64
|
266 |
+
optimizer = build_optimizer(cfg, model)
|
267 |
+
return optimizer
|
268 |
+
|
269 |
+
@staticmethod
|
270 |
+
def get_lr_scheduler(args, cfg, optimizer):
|
271 |
+
epochs = args.epochs
|
272 |
+
lr_scheduler, lf = build_lr_scheduler(cfg, optimizer, epochs)
|
273 |
+
return lr_scheduler, lf
|
yolov6/core/evaler.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import os
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
import json
|
7 |
+
import torch
|
8 |
+
import yaml
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
from pycocotools.coco import COCO
|
12 |
+
from pycocotools.cocoeval import COCOeval
|
13 |
+
|
14 |
+
from yolov6.data.data_load import create_dataloader
|
15 |
+
from yolov6.utils.events import LOGGER, NCOLS
|
16 |
+
from yolov6.utils.nms import non_max_suppression
|
17 |
+
from yolov6.utils.checkpoint import load_checkpoint
|
18 |
+
from yolov6.utils.torch_utils import time_sync, get_model_info
|
19 |
+
|
20 |
+
'''
|
21 |
+
python tools/eval.py --task 'train'/'val'/'speed'
|
22 |
+
'''
|
23 |
+
|
24 |
+
|
25 |
+
class Evaler:
|
26 |
+
def __init__(self,
|
27 |
+
data,
|
28 |
+
batch_size=32,
|
29 |
+
img_size=640,
|
30 |
+
conf_thres=0.001,
|
31 |
+
iou_thres=0.65,
|
32 |
+
device='',
|
33 |
+
half=True,
|
34 |
+
save_dir=''):
|
35 |
+
self.data = data
|
36 |
+
self.batch_size = batch_size
|
37 |
+
self.img_size = img_size
|
38 |
+
self.conf_thres = conf_thres
|
39 |
+
self.iou_thres = iou_thres
|
40 |
+
self.device = device
|
41 |
+
self.half = half
|
42 |
+
self.save_dir = save_dir
|
43 |
+
|
44 |
+
def init_model(self, model, weights, task):
|
45 |
+
if task != 'train':
|
46 |
+
model = load_checkpoint(weights, map_location=self.device)
|
47 |
+
self.stride = int(model.stride.max())
|
48 |
+
if self.device.type != 'cpu':
|
49 |
+
model(torch.zeros(1, 3, self.img_size, self.img_size).to(self.device).type_as(next(model.parameters())))
|
50 |
+
# switch to deploy
|
51 |
+
from yolov6.layers.common import RepVGGBlock
|
52 |
+
for layer in model.modules():
|
53 |
+
if isinstance(layer, RepVGGBlock):
|
54 |
+
layer.switch_to_deploy()
|
55 |
+
LOGGER.info("Switch model to deploy modality.")
|
56 |
+
LOGGER.info("Model Summary: {}".format(get_model_info(model, self.img_size)))
|
57 |
+
model.half() if self.half else model.float()
|
58 |
+
return model
|
59 |
+
|
60 |
+
def init_data(self, dataloader, task):
|
61 |
+
'''Initialize dataloader.
|
62 |
+
Returns a dataloader for task val or speed.
|
63 |
+
'''
|
64 |
+
self.is_coco = self.data.get("is_coco", False)
|
65 |
+
self.ids = self.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
66 |
+
if task != 'train':
|
67 |
+
pad = 0.0 if task == 'speed' else 0.5
|
68 |
+
dataloader = create_dataloader(self.data[task if task in ('train', 'val', 'test') else 'val'],
|
69 |
+
self.img_size, self.batch_size, self.stride, check_labels=True, pad=pad, rect=True,
|
70 |
+
data_dict=self.data, task=task)[0]
|
71 |
+
return dataloader
|
72 |
+
|
73 |
+
def predict_model(self, model, dataloader, task):
|
74 |
+
'''Model prediction
|
75 |
+
Predicts the whole dataset and gets the prediced results and inference time.
|
76 |
+
'''
|
77 |
+
self.speed_result = torch.zeros(4, device=self.device)
|
78 |
+
pred_results = []
|
79 |
+
pbar = tqdm(dataloader, desc="Inferencing model in val datasets.", ncols=NCOLS)
|
80 |
+
for imgs, targets, paths, shapes in pbar:
|
81 |
+
# pre-process
|
82 |
+
t1 = time_sync()
|
83 |
+
imgs = imgs.to(self.device, non_blocking=True)
|
84 |
+
imgs = imgs.half() if self.half else imgs.float()
|
85 |
+
imgs /= 255
|
86 |
+
self.speed_result[1] += time_sync() - t1 # pre-process time
|
87 |
+
|
88 |
+
# Inference
|
89 |
+
t2 = time_sync()
|
90 |
+
outputs = model(imgs)
|
91 |
+
self.speed_result[2] += time_sync() - t2 # inference time
|
92 |
+
|
93 |
+
# post-process
|
94 |
+
t3 = time_sync()
|
95 |
+
outputs = non_max_suppression(outputs, self.conf_thres, self.iou_thres, multi_label=True)
|
96 |
+
self.speed_result[3] += time_sync() - t3 # post-process time
|
97 |
+
self.speed_result[0] += len(outputs)
|
98 |
+
|
99 |
+
# save result
|
100 |
+
pred_results.extend(self.convert_to_coco_format(outputs, imgs, paths, shapes, self.ids))
|
101 |
+
return pred_results
|
102 |
+
|
103 |
+
def eval_model(self, pred_results, model, dataloader, task):
|
104 |
+
'''Evaluate models
|
105 |
+
For task speed, this function only evaluates the speed of model and outputs inference time.
|
106 |
+
For task val, this function evaluates the speed and mAP by pycocotools, and returns
|
107 |
+
inference time and mAP value.
|
108 |
+
'''
|
109 |
+
LOGGER.info(f'\nEvaluating speed.')
|
110 |
+
self.eval_speed(task)
|
111 |
+
|
112 |
+
LOGGER.info(f'\nEvaluating mAP by pycocotools.')
|
113 |
+
if task != 'speed' and len(pred_results):
|
114 |
+
if 'anno_path' in self.data:
|
115 |
+
anno_json = self.data['anno_path']
|
116 |
+
else:
|
117 |
+
# generated coco format labels in dataset initialization
|
118 |
+
dataset_root = os.path.dirname(os.path.dirname(self.data['val']))
|
119 |
+
base_name = os.path.basename(self.data['val'])
|
120 |
+
anno_json = os.path.join(dataset_root, 'annotations', f'instances_{base_name}.json')
|
121 |
+
pred_json = os.path.join(self.save_dir, "predictions.json")
|
122 |
+
LOGGER.info(f'Saving {pred_json}...')
|
123 |
+
with open(pred_json, 'w') as f:
|
124 |
+
json.dump(pred_results, f)
|
125 |
+
|
126 |
+
anno = COCO(anno_json)
|
127 |
+
pred = anno.loadRes(pred_json)
|
128 |
+
cocoEval = COCOeval(anno, pred, 'bbox')
|
129 |
+
if self.is_coco:
|
130 |
+
imgIds = [int(os.path.basename(x).split(".")[0])
|
131 |
+
for x in dataloader.dataset.img_paths]
|
132 |
+
cocoEval.params.imgIds = imgIds
|
133 |
+
cocoEval.evaluate()
|
134 |
+
cocoEval.accumulate()
|
135 |
+
cocoEval.summarize()
|
136 |
+
map, map50 = cocoEval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
|
137 |
+
# Return results
|
138 |
+
model.float() # for training
|
139 |
+
if task != 'train':
|
140 |
+
LOGGER.info(f"Results saved to {self.save_dir}")
|
141 |
+
return (map50, map)
|
142 |
+
return (0.0, 0.0)
|
143 |
+
|
144 |
+
def eval_speed(self, task):
|
145 |
+
'''Evaluate model inference speed.'''
|
146 |
+
if task != 'train':
|
147 |
+
n_samples = self.speed_result[0].item()
|
148 |
+
pre_time, inf_time, nms_time = 1000 * self.speed_result[1:].cpu().numpy() / n_samples
|
149 |
+
for n, v in zip(["pre-process", "inference", "NMS"],[pre_time, inf_time, nms_time]):
|
150 |
+
LOGGER.info("Average {} time: {:.2f} ms".format(n, v))
|
151 |
+
|
152 |
+
def box_convert(self, x):
|
153 |
+
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
|
154 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
155 |
+
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
156 |
+
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
157 |
+
y[:, 2] = x[:, 2] - x[:, 0] # width
|
158 |
+
y[:, 3] = x[:, 3] - x[:, 1] # height
|
159 |
+
return y
|
160 |
+
|
161 |
+
def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
|
162 |
+
# Rescale coords (xyxy) from img1_shape to img0_shape
|
163 |
+
if ratio_pad is None: # calculate from img0_shape
|
164 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
165 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
166 |
+
else:
|
167 |
+
gain = ratio_pad[0][0]
|
168 |
+
pad = ratio_pad[1]
|
169 |
+
|
170 |
+
coords[:, [0, 2]] -= pad[0] # x padding
|
171 |
+
coords[:, [1, 3]] -= pad[1] # y padding
|
172 |
+
coords[:, :4] /= gain
|
173 |
+
if isinstance(coords, torch.Tensor): # faster individually
|
174 |
+
coords[:, 0].clamp_(0, img0_shape[1]) # x1
|
175 |
+
coords[:, 1].clamp_(0, img0_shape[0]) # y1
|
176 |
+
coords[:, 2].clamp_(0, img0_shape[1]) # x2
|
177 |
+
coords[:, 3].clamp_(0, img0_shape[0]) # y2
|
178 |
+
else: # np.array (faster grouped)
|
179 |
+
coords[:, [0, 2]] = coords[:, [0, 2]].clip(0, img0_shape[1]) # x1, x2
|
180 |
+
coords[:, [1, 3]] = coords[:, [1, 3]].clip(0, img0_shape[0]) # y1, y2
|
181 |
+
return coords
|
182 |
+
|
183 |
+
def convert_to_coco_format(self, outputs, imgs, paths, shapes, ids):
|
184 |
+
pred_results = []
|
185 |
+
for i, pred in enumerate(outputs):
|
186 |
+
if len(pred) == 0:
|
187 |
+
continue
|
188 |
+
path, shape = Path(paths[i]), shapes[i][0]
|
189 |
+
self.scale_coords(imgs[i].shape[1:], pred[:, :4], shape, shapes[i][1])
|
190 |
+
image_id = int(path.stem) if path.stem.isnumeric() else path.stem
|
191 |
+
bboxes = self.box_convert(pred[:, 0:4])
|
192 |
+
bboxes[:, :2] -= bboxes[:, 2:] / 2
|
193 |
+
cls = pred[:, 5]
|
194 |
+
scores = pred[:, 4]
|
195 |
+
for ind in range(pred.shape[0]):
|
196 |
+
category_id = ids[int(cls[ind])]
|
197 |
+
bbox = [round(x, 3) for x in bboxes[ind].tolist()]
|
198 |
+
score = round(scores[ind].item(), 5)
|
199 |
+
pred_data = {
|
200 |
+
"image_id": image_id,
|
201 |
+
"category_id": category_id,
|
202 |
+
"bbox": bbox,
|
203 |
+
"score": score
|
204 |
+
}
|
205 |
+
pred_results.append(pred_data)
|
206 |
+
return pred_results
|
207 |
+
|
208 |
+
@staticmethod
|
209 |
+
def check_task(task):
|
210 |
+
if task not in ['train','val','speed']:
|
211 |
+
raise Exception("task argument error: only support 'train' / 'val' / 'speed' task.")
|
212 |
+
|
213 |
+
@staticmethod
|
214 |
+
def reload_thres(conf_thres, iou_thres, task):
|
215 |
+
'''Sets conf and iou threshold for task val/speed'''
|
216 |
+
if task != 'train':
|
217 |
+
if task == 'val':
|
218 |
+
conf_thres = 0.001
|
219 |
+
if task == 'speed':
|
220 |
+
conf_thres = 0.25
|
221 |
+
iou_thres = 0.45
|
222 |
+
return conf_thres, iou_thres
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def reload_device(device, model, task):
|
226 |
+
# device = 'cpu' or '0' or '0,1,2,3'
|
227 |
+
if task == 'train':
|
228 |
+
device = next(model.parameters()).device
|
229 |
+
else:
|
230 |
+
if device == 'cpu':
|
231 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
232 |
+
elif device:
|
233 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
234 |
+
assert torch.cuda.is_available()
|
235 |
+
cuda = device != 'cpu' and torch.cuda.is_available()
|
236 |
+
device = torch.device('cuda:0' if cuda else 'cpu')
|
237 |
+
return device
|
238 |
+
|
239 |
+
@staticmethod
|
240 |
+
def reload_dataset(data):
|
241 |
+
with open(data, errors='ignore') as yaml_file:
|
242 |
+
data = yaml.safe_load(yaml_file)
|
243 |
+
val = data.get('val')
|
244 |
+
if not os.path.exists(val):
|
245 |
+
raise Exception('Dataset not found.')
|
246 |
+
return data
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
250 |
+
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
|
251 |
+
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
|
252 |
+
21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
|
253 |
+
41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
|
254 |
+
59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
|
255 |
+
80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
256 |
+
return x
|
yolov6/core/inferer.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import math
|
6 |
+
from tqdm import tqdm
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import torch
|
10 |
+
from PIL import ImageFont
|
11 |
+
|
12 |
+
from yolov6.utils.events import LOGGER, load_yaml
|
13 |
+
from yolov6.layers.common import DetectBackend
|
14 |
+
from yolov6.data.data_augment import letterbox
|
15 |
+
from yolov6.utils.nms import non_max_suppression
|
16 |
+
from yolov6.utils.torch_utils import get_model_info
|
17 |
+
|
18 |
+
|
19 |
+
class Inferer:
|
20 |
+
def __init__(self, source, weights, device, yaml, img_size, half):
|
21 |
+
import glob
|
22 |
+
from yolov6.data.datasets import IMG_FORMATS
|
23 |
+
|
24 |
+
self.__dict__.update(locals())
|
25 |
+
|
26 |
+
# Init model
|
27 |
+
self.device = device
|
28 |
+
self.img_size = img_size
|
29 |
+
cuda = self.device != 'cpu' and torch.cuda.is_available()
|
30 |
+
self.device = torch.device('cuda:0' if cuda else 'cpu')
|
31 |
+
self.model = DetectBackend(weights, device=self.device)
|
32 |
+
self.stride = self.model.stride
|
33 |
+
self.class_names = load_yaml(yaml)['names']
|
34 |
+
self.img_size = self.check_img_size(self.img_size, s=self.stride) # check image size
|
35 |
+
|
36 |
+
# Half precision
|
37 |
+
if half & (self.device.type != 'cpu'):
|
38 |
+
self.model.model.half()
|
39 |
+
else:
|
40 |
+
self.model.model.float()
|
41 |
+
half = False
|
42 |
+
|
43 |
+
if self.device.type != 'cpu':
|
44 |
+
self.model(torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))) # warmup
|
45 |
+
|
46 |
+
# Load data
|
47 |
+
if os.path.isdir(source):
|
48 |
+
img_paths = sorted(glob.glob(os.path.join(source, '*.*'))) # dir
|
49 |
+
elif os.path.isfile(source):
|
50 |
+
img_paths = [source] # files
|
51 |
+
else:
|
52 |
+
raise Exception(f'Invalid path: {source}')
|
53 |
+
self.img_paths = [img_path for img_path in img_paths if img_path.split('.')[-1].lower() in IMG_FORMATS]
|
54 |
+
|
55 |
+
# Switch model to deploy status
|
56 |
+
self.model_switch(self.model, self.img_size)
|
57 |
+
|
58 |
+
def model_switch(self, model, img_size):
|
59 |
+
''' Model switch to deploy status '''
|
60 |
+
from yolov6.layers.common import RepVGGBlock
|
61 |
+
for layer in model.modules():
|
62 |
+
if isinstance(layer, RepVGGBlock):
|
63 |
+
layer.switch_to_deploy()
|
64 |
+
|
65 |
+
LOGGER.info("Switch model to deploy modality.")
|
66 |
+
|
67 |
+
def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf):
|
68 |
+
''' Model Inference and results visualization '''
|
69 |
+
|
70 |
+
for img_path in tqdm(self.img_paths):
|
71 |
+
img, img_src = self.precess_image(img_path, self.img_size, self.stride, self.half)
|
72 |
+
img = img.to(self.device)
|
73 |
+
if len(img.shape) == 3:
|
74 |
+
img = img[None]
|
75 |
+
# expand for batch dim
|
76 |
+
pred_results = self.model(img)
|
77 |
+
det = non_max_suppression(pred_results, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)[0]
|
78 |
+
|
79 |
+
save_path = osp.join(save_dir, osp.basename(img_path)) # im.jpg
|
80 |
+
txt_path = osp.join(save_dir, 'labels', osp.splitext(osp.basename(img_path))[0])
|
81 |
+
|
82 |
+
gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
83 |
+
img_ori = img_src
|
84 |
+
|
85 |
+
# check image and font
|
86 |
+
assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im).'
|
87 |
+
self.font_check()
|
88 |
+
|
89 |
+
if len(det):
|
90 |
+
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()
|
91 |
+
|
92 |
+
for *xyxy, conf, cls in reversed(det):
|
93 |
+
if save_txt: # Write to file
|
94 |
+
xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
95 |
+
line = (cls, *xywh, conf)
|
96 |
+
with open(txt_path + '.txt', 'a') as f:
|
97 |
+
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
98 |
+
|
99 |
+
if save_img:
|
100 |
+
class_num = int(cls) # integer class
|
101 |
+
label = None if hide_labels else (self.class_names[class_num] if hide_conf else f'{self.class_names[class_num]} {conf:.2f}')
|
102 |
+
|
103 |
+
self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))
|
104 |
+
|
105 |
+
img_src = np.asarray(img_ori)
|
106 |
+
|
107 |
+
# Save results (image with detections)
|
108 |
+
if save_img:
|
109 |
+
cv2.imwrite(save_path, img_src)
|
110 |
+
|
111 |
+
@staticmethod
|
112 |
+
def precess_image(path, img_size, stride, half):
|
113 |
+
'''Process image before image inference.'''
|
114 |
+
try:
|
115 |
+
img_src = cv2.imread(path)
|
116 |
+
assert img_src is not None, f'Invalid image: {path}'
|
117 |
+
except Exception as e:
|
118 |
+
LOGGER.warning(e)
|
119 |
+
image = letterbox(img_src, img_size, stride=stride)[0]
|
120 |
+
|
121 |
+
# Convert
|
122 |
+
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
123 |
+
image = torch.from_numpy(np.ascontiguousarray(image))
|
124 |
+
image = image.half() if half else image.float() # uint8 to fp16/32
|
125 |
+
image /= 255 # 0 - 255 to 0.0 - 1.0
|
126 |
+
|
127 |
+
return image, img_src
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def rescale(ori_shape, boxes, target_shape):
|
131 |
+
'''Rescale the output to the original image shape'''
|
132 |
+
ratio = min(ori_shape[0] / target_shape[0], ori_shape[1] / target_shape[1])
|
133 |
+
padding = (ori_shape[1] - target_shape[1] * ratio) / 2, (ori_shape[0] - target_shape[0] * ratio) / 2
|
134 |
+
|
135 |
+
boxes[:, [0, 2]] -= padding[0]
|
136 |
+
boxes[:, [1, 3]] -= padding[1]
|
137 |
+
boxes[:, :4] /= ratio
|
138 |
+
|
139 |
+
boxes[:, 0].clamp_(0, target_shape[1]) # x1
|
140 |
+
boxes[:, 1].clamp_(0, target_shape[0]) # y1
|
141 |
+
boxes[:, 2].clamp_(0, target_shape[1]) # x2
|
142 |
+
boxes[:, 3].clamp_(0, target_shape[0]) # y2
|
143 |
+
|
144 |
+
return boxes
|
145 |
+
|
146 |
+
def check_img_size(self, img_size, s=32, floor=0):
|
147 |
+
"""Make sure image size is a multiple of stride s in each dimension, and return a new shape list of image."""
|
148 |
+
if isinstance(img_size, int): # integer i.e. img_size=640
|
149 |
+
new_size = max(self.make_divisible(img_size, int(s)), floor)
|
150 |
+
elif isinstance(img_size, list): # list i.e. img_size=[640, 480]
|
151 |
+
new_size = [max(self.make_divisible(x, int(s)), floor) for x in img_size]
|
152 |
+
else:
|
153 |
+
raise Exception(f"Unsupported type of img_size: {type(img_size)}")
|
154 |
+
|
155 |
+
if new_size != img_size:
|
156 |
+
print(f'WARNING: --img-size {img_size} must be multiple of max stride {s}, updating to {new_size}')
|
157 |
+
return new_size if isinstance(img_size,list) else [new_size]*2
|
158 |
+
|
159 |
+
def make_divisible(self, x, divisor):
|
160 |
+
# Upward revision the value x to make it evenly divisible by the divisor.
|
161 |
+
return math.ceil(x / divisor) * divisor
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def plot_box_and_label(image, lw, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
165 |
+
# Add one xyxy box to image with label
|
166 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
167 |
+
cv2.rectangle(image, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
|
168 |
+
if label:
|
169 |
+
tf = max(lw - 1, 1) # font thickness
|
170 |
+
w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0] # text width, height
|
171 |
+
outside = p1[1] - h - 3 >= 0 # label fits outside box
|
172 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
173 |
+
cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
|
174 |
+
cv2.putText(image, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), 0, lw / 3, txt_color,
|
175 |
+
thickness=tf, lineType=cv2.LINE_AA)
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def font_check(font='./yolov6/utils/Arial.ttf', size=10):
|
179 |
+
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
180 |
+
assert osp.exists(font), f'font path not exists: {font}'
|
181 |
+
try:
|
182 |
+
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
183 |
+
except Exception as e: # download if missing
|
184 |
+
return ImageFont.truetype(str(font), size)
|
185 |
+
|
186 |
+
@staticmethod
|
187 |
+
def box_convert(x):
|
188 |
+
# Convert boxes with shape [n, 4] from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right
|
189 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
190 |
+
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
|
191 |
+
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
|
192 |
+
y[:, 2] = x[:, 2] - x[:, 0] # width
|
193 |
+
y[:, 3] = x[:, 3] - x[:, 1] # height
|
194 |
+
return y
|
195 |
+
|
196 |
+
@staticmethod
|
197 |
+
def generate_colors(i, bgr=False):
|
198 |
+
hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
199 |
+
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
200 |
+
palette = []
|
201 |
+
for iter in hex:
|
202 |
+
h = '#' + iter
|
203 |
+
palette.append(tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)))
|
204 |
+
num = len(palette)
|
205 |
+
color = palette[int(i) % num]
|
206 |
+
return (color[2], color[1], color[0]) if bgr else color
|
207 |
+
|
208 |
+
|
209 |
+
class VideoInferer(Inferer):
|
210 |
+
|
211 |
+
def setup_source(self, source):
|
212 |
+
# Load data
|
213 |
+
if os.path.isfile(source):
|
214 |
+
self.vid_path = source
|
215 |
+
self.vid_name = '.'.join(os.path.basename(source).split('.')[:-1])
|
216 |
+
else:
|
217 |
+
raise Exception(f'Invalid path: {source}')
|
218 |
+
|
219 |
+
self.cap = cv2.VideoCapture(self.vid_path)
|
220 |
+
|
221 |
+
def iterator_length(self):
|
222 |
+
return int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
223 |
+
|
224 |
+
def img_iterator(self):
|
225 |
+
cur_fid = 0
|
226 |
+
ret, frame = self.cap.read()
|
227 |
+
|
228 |
+
while ret:
|
229 |
+
yield frame, f'{self.vid_name}_frame_{cur_fid:06}.jpg'
|
230 |
+
ret, frame = self.cap.read()
|
231 |
+
cur_fid += 1
|
yolov6/data/data_augment.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# This code is based on
|
4 |
+
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
|
5 |
+
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
|
14 |
+
# HSV color-space augmentation
|
15 |
+
if hgain or sgain or vgain:
|
16 |
+
r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
|
17 |
+
hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
|
18 |
+
dtype = im.dtype # uint8
|
19 |
+
|
20 |
+
x = np.arange(0, 256, dtype=r.dtype)
|
21 |
+
lut_hue = ((x * r[0]) % 180).astype(dtype)
|
22 |
+
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
|
23 |
+
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
|
24 |
+
|
25 |
+
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
26 |
+
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
|
27 |
+
|
28 |
+
|
29 |
+
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleup=True, stride=32):
|
30 |
+
# Resize and pad image while meeting stride-multiple constraints
|
31 |
+
shape = im.shape[:2] # current shape [height, width]
|
32 |
+
if isinstance(new_shape, int):
|
33 |
+
new_shape = (new_shape, new_shape)
|
34 |
+
|
35 |
+
# Scale ratio (new / old)
|
36 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
37 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
38 |
+
r = min(r, 1.0)
|
39 |
+
|
40 |
+
# Compute padding
|
41 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
42 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
43 |
+
|
44 |
+
if auto: # minimum rectangle
|
45 |
+
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
|
46 |
+
|
47 |
+
dw /= 2 # divide padding into 2 sides
|
48 |
+
dh /= 2
|
49 |
+
|
50 |
+
if shape[::-1] != new_unpad: # resize
|
51 |
+
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
52 |
+
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
53 |
+
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
54 |
+
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
|
55 |
+
return im, r, (dw, dh)
|
56 |
+
|
57 |
+
|
58 |
+
def mixup(im, labels, im2, labels2):
|
59 |
+
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
60 |
+
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
61 |
+
im = (im * r + im2 * (1 - r)).astype(np.uint8)
|
62 |
+
labels = np.concatenate((labels, labels2), 0)
|
63 |
+
return im, labels
|
64 |
+
|
65 |
+
|
66 |
+
def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
67 |
+
# Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
|
68 |
+
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
|
69 |
+
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
|
70 |
+
ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
|
71 |
+
return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
|
72 |
+
|
73 |
+
|
74 |
+
def random_affine(img, labels=(), degrees=10, translate=.1, scale=.1, shear=10,
|
75 |
+
new_shape=(640, 640)):
|
76 |
+
|
77 |
+
n = len(labels)
|
78 |
+
height, width = new_shape
|
79 |
+
|
80 |
+
M, s = get_transform_matrix(img.shape[:2], (height, width), degrees, scale, shear, translate)
|
81 |
+
if (M != np.eye(3)).any(): # image changed
|
82 |
+
img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
|
83 |
+
|
84 |
+
# Transform label coordinates
|
85 |
+
if n:
|
86 |
+
new = np.zeros((n, 4))
|
87 |
+
|
88 |
+
xy = np.ones((n * 4, 3))
|
89 |
+
xy[:, :2] = labels[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
|
90 |
+
xy = xy @ M.T # transform
|
91 |
+
xy = xy[:, :2].reshape(n, 8) # perspective rescale or affine
|
92 |
+
|
93 |
+
# create new boxes
|
94 |
+
x = xy[:, [0, 2, 4, 6]]
|
95 |
+
y = xy[:, [1, 3, 5, 7]]
|
96 |
+
new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
|
97 |
+
|
98 |
+
# clip
|
99 |
+
new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
|
100 |
+
new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
|
101 |
+
|
102 |
+
# filter candidates
|
103 |
+
i = box_candidates(box1=labels[:, 1:5].T * s, box2=new.T, area_thr=0.1)
|
104 |
+
labels = labels[i]
|
105 |
+
labels[:, 1:5] = new[i]
|
106 |
+
|
107 |
+
return img, labels
|
108 |
+
|
109 |
+
|
110 |
+
def get_transform_matrix(img_shape, new_shape, degrees, scale, shear, translate):
|
111 |
+
new_height, new_width = new_shape
|
112 |
+
# Center
|
113 |
+
C = np.eye(3)
|
114 |
+
C[0, 2] = -img_shape[1] / 2 # x translation (pixels)
|
115 |
+
C[1, 2] = -img_shape[0] / 2 # y translation (pixels)
|
116 |
+
|
117 |
+
# Rotation and Scale
|
118 |
+
R = np.eye(3)
|
119 |
+
a = random.uniform(-degrees, degrees)
|
120 |
+
# a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
|
121 |
+
s = random.uniform(1 - scale, 1 + scale)
|
122 |
+
# s = 2 ** random.uniform(-scale, scale)
|
123 |
+
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
|
124 |
+
|
125 |
+
# Shear
|
126 |
+
S = np.eye(3)
|
127 |
+
S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
|
128 |
+
S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
|
129 |
+
|
130 |
+
# Translation
|
131 |
+
T = np.eye(3)
|
132 |
+
T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * new_width # x translation (pixels)
|
133 |
+
T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * new_height # y transla ion (pixels)
|
134 |
+
|
135 |
+
# Combined rotation matrix
|
136 |
+
M = T @ S @ R @ C # order of operations (right to left) is IMPORTANT
|
137 |
+
return M, s
|
138 |
+
|
139 |
+
|
140 |
+
def mosaic_augmentation(img_size, imgs, hs, ws, labels, hyp):
|
141 |
+
|
142 |
+
assert len(imgs) == 4, "Mosaic augmentation of current version only supports 4 images."
|
143 |
+
|
144 |
+
labels4 = []
|
145 |
+
s = img_size
|
146 |
+
yc, xc = (int(random.uniform(s//2, 3*s//2)) for _ in range(2)) # mosaic center x, y
|
147 |
+
for i in range(len(imgs)):
|
148 |
+
# Load image
|
149 |
+
img, h, w = imgs[i], hs[i], ws[i]
|
150 |
+
# place img in img4
|
151 |
+
if i == 0: # top left
|
152 |
+
img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles
|
153 |
+
x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
|
154 |
+
x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
|
155 |
+
elif i == 1: # top right
|
156 |
+
x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
|
157 |
+
x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
|
158 |
+
elif i == 2: # bottom left
|
159 |
+
x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
|
160 |
+
x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
|
161 |
+
elif i == 3: # bottom right
|
162 |
+
x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
|
163 |
+
x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
|
164 |
+
|
165 |
+
img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
|
166 |
+
padw = x1a - x1b
|
167 |
+
padh = y1a - y1b
|
168 |
+
|
169 |
+
# Labels
|
170 |
+
labels_per_img = labels[i].copy()
|
171 |
+
if labels_per_img.size:
|
172 |
+
boxes = np.copy(labels_per_img[:, 1:])
|
173 |
+
boxes[:, 0] = w * (labels_per_img[:, 1] - labels_per_img[:, 3] / 2) + padw # top left x
|
174 |
+
boxes[:, 1] = h * (labels_per_img[:, 2] - labels_per_img[:, 4] / 2) + padh # top left y
|
175 |
+
boxes[:, 2] = w * (labels_per_img[:, 1] + labels_per_img[:, 3] / 2) + padw # bottom right x
|
176 |
+
boxes[:, 3] = h * (labels_per_img[:, 2] + labels_per_img[:, 4] / 2) + padh # bottom right y
|
177 |
+
labels_per_img[:, 1:] = boxes
|
178 |
+
|
179 |
+
labels4.append(labels_per_img)
|
180 |
+
|
181 |
+
# Concat/clip labels
|
182 |
+
labels4 = np.concatenate(labels4, 0)
|
183 |
+
for x in (labels4[:, 1:]):
|
184 |
+
np.clip(x, 0, 2 * s, out=x)
|
185 |
+
|
186 |
+
# Augment
|
187 |
+
img4, labels4 = random_affine(img4, labels4,
|
188 |
+
degrees=hyp['degrees'],
|
189 |
+
translate=hyp['translate'],
|
190 |
+
scale=hyp['scale'],
|
191 |
+
shear=hyp['shear'])
|
192 |
+
|
193 |
+
return img4, labels4
|
yolov6/data/data_load.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# This code is based on
|
4 |
+
# https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
|
5 |
+
|
6 |
+
import os
|
7 |
+
from torch.utils.data import dataloader, distributed
|
8 |
+
|
9 |
+
from .datasets import TrainValDataset
|
10 |
+
from yolov6.utils.events import LOGGER
|
11 |
+
from yolov6.utils.torch_utils import torch_distributed_zero_first
|
12 |
+
|
13 |
+
|
14 |
+
def create_dataloader(
|
15 |
+
path,
|
16 |
+
img_size,
|
17 |
+
batch_size,
|
18 |
+
stride,
|
19 |
+
hyp=None,
|
20 |
+
augment=False,
|
21 |
+
check_images=False,
|
22 |
+
check_labels=False,
|
23 |
+
pad=0.0,
|
24 |
+
rect=False,
|
25 |
+
rank=-1,
|
26 |
+
workers=8,
|
27 |
+
shuffle=False,
|
28 |
+
data_dict=None,
|
29 |
+
task="Train",
|
30 |
+
):
|
31 |
+
"""Create general dataloader.
|
32 |
+
|
33 |
+
Returns dataloader and dataset
|
34 |
+
"""
|
35 |
+
if rect and shuffle:
|
36 |
+
LOGGER.warning(
|
37 |
+
"WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
|
38 |
+
)
|
39 |
+
shuffle = False
|
40 |
+
with torch_distributed_zero_first(rank):
|
41 |
+
dataset = TrainValDataset(
|
42 |
+
path,
|
43 |
+
img_size,
|
44 |
+
batch_size,
|
45 |
+
augment=augment,
|
46 |
+
hyp=hyp,
|
47 |
+
rect=rect,
|
48 |
+
check_images=check_images,
|
49 |
+
check_labels=check_labels,
|
50 |
+
stride=int(stride),
|
51 |
+
pad=pad,
|
52 |
+
rank=rank,
|
53 |
+
data_dict=data_dict,
|
54 |
+
task=task,
|
55 |
+
)
|
56 |
+
|
57 |
+
batch_size = min(batch_size, len(dataset))
|
58 |
+
workers = min(
|
59 |
+
[
|
60 |
+
os.cpu_count() // int(os.getenv("WORLD_SIZE", 1)),
|
61 |
+
batch_size if batch_size > 1 else 0,
|
62 |
+
workers,
|
63 |
+
]
|
64 |
+
) # number of workers
|
65 |
+
sampler = (
|
66 |
+
None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
67 |
+
)
|
68 |
+
return (
|
69 |
+
TrainValDataLoader(
|
70 |
+
dataset,
|
71 |
+
batch_size=batch_size,
|
72 |
+
shuffle=shuffle and sampler is None,
|
73 |
+
num_workers=workers,
|
74 |
+
sampler=sampler,
|
75 |
+
pin_memory=True,
|
76 |
+
collate_fn=TrainValDataset.collate_fn,
|
77 |
+
),
|
78 |
+
dataset,
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
class TrainValDataLoader(dataloader.DataLoader):
|
83 |
+
"""Dataloader that reuses workers
|
84 |
+
|
85 |
+
Uses same syntax as vanilla DataLoader
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, *args, **kwargs):
|
89 |
+
super().__init__(*args, **kwargs)
|
90 |
+
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
|
91 |
+
self.iterator = super().__iter__()
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.batch_sampler.sampler)
|
95 |
+
|
96 |
+
def __iter__(self):
|
97 |
+
for i in range(len(self)):
|
98 |
+
yield next(self.iterator)
|
99 |
+
|
100 |
+
|
101 |
+
class _RepeatSampler:
|
102 |
+
"""Sampler that repeats forever
|
103 |
+
|
104 |
+
Args:
|
105 |
+
sampler (Sampler)
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(self, sampler):
|
109 |
+
self.sampler = sampler
|
110 |
+
|
111 |
+
def __iter__(self):
|
112 |
+
while True:
|
113 |
+
yield from iter(self.sampler)
|
yolov6/data/datasets.py
ADDED
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
import glob
|
5 |
+
import os
|
6 |
+
import os.path as osp
|
7 |
+
import random
|
8 |
+
import json
|
9 |
+
import time
|
10 |
+
import hashlib
|
11 |
+
|
12 |
+
from multiprocessing.pool import Pool
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from PIL import ExifTags, Image, ImageOps
|
18 |
+
from torch.utils.data import Dataset
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from .data_augment import (
|
22 |
+
augment_hsv,
|
23 |
+
letterbox,
|
24 |
+
mixup,
|
25 |
+
random_affine,
|
26 |
+
mosaic_augmentation,
|
27 |
+
)
|
28 |
+
from yolov6.utils.events import LOGGER
|
29 |
+
|
30 |
+
# Parameters
|
31 |
+
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
|
32 |
+
# Get orientation exif tag
|
33 |
+
for k, v in ExifTags.TAGS.items():
|
34 |
+
if v == "Orientation":
|
35 |
+
ORIENTATION = k
|
36 |
+
break
|
37 |
+
|
38 |
+
|
39 |
+
class TrainValDataset(Dataset):
|
40 |
+
# YOLOv6 train_loader/val_loader, loads images and labels for training and validation
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
img_dir,
|
44 |
+
img_size=640,
|
45 |
+
batch_size=16,
|
46 |
+
augment=False,
|
47 |
+
hyp=None,
|
48 |
+
rect=False,
|
49 |
+
check_images=False,
|
50 |
+
check_labels=False,
|
51 |
+
stride=32,
|
52 |
+
pad=0.0,
|
53 |
+
rank=-1,
|
54 |
+
data_dict=None,
|
55 |
+
task="train",
|
56 |
+
):
|
57 |
+
assert task.lower() in ("train", "val", "speed"), f"Not supported task: {task}"
|
58 |
+
t1 = time.time()
|
59 |
+
self.__dict__.update(locals())
|
60 |
+
self.main_process = self.rank in (-1, 0)
|
61 |
+
self.task = self.task.capitalize()
|
62 |
+
self.class_names = data_dict["names"]
|
63 |
+
self.img_paths, self.labels = self.get_imgs_labels(self.img_dir)
|
64 |
+
if self.rect:
|
65 |
+
shapes = [self.img_info[p]["shape"] for p in self.img_paths]
|
66 |
+
self.shapes = np.array(shapes, dtype=np.float64)
|
67 |
+
self.batch_indices = np.floor(
|
68 |
+
np.arange(len(shapes)) / self.batch_size
|
69 |
+
).astype(
|
70 |
+
np.int
|
71 |
+
) # batch indices of each image
|
72 |
+
self.sort_files_shapes()
|
73 |
+
t2 = time.time()
|
74 |
+
if self.main_process:
|
75 |
+
LOGGER.info(f"%.1fs for dataset initialization." % (t2 - t1))
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
"""Get the length of dataset"""
|
79 |
+
return len(self.img_paths)
|
80 |
+
|
81 |
+
def __getitem__(self, index):
|
82 |
+
"""Fetching a data sample for a given key.
|
83 |
+
This function applies mosaic and mixup augments during training.
|
84 |
+
During validation, letterbox augment is applied.
|
85 |
+
"""
|
86 |
+
# Mosaic Augmentation
|
87 |
+
if self.augment and random.random() < self.hyp["mosaic"]:
|
88 |
+
img, labels = self.get_mosaic(index)
|
89 |
+
shapes = None
|
90 |
+
|
91 |
+
# MixUp augmentation
|
92 |
+
if random.random() < self.hyp["mixup"]:
|
93 |
+
img_other, labels_other = self.get_mosaic(
|
94 |
+
random.randint(0, len(self.img_paths) - 1)
|
95 |
+
)
|
96 |
+
img, labels = mixup(img, labels, img_other, labels_other)
|
97 |
+
|
98 |
+
else:
|
99 |
+
# Load image
|
100 |
+
img, (h0, w0), (h, w) = self.load_image(index)
|
101 |
+
|
102 |
+
# Letterbox
|
103 |
+
shape = (
|
104 |
+
self.batch_shapes[self.batch_indices[index]]
|
105 |
+
if self.rect
|
106 |
+
else self.img_size
|
107 |
+
) # final letterboxed shape
|
108 |
+
img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
|
109 |
+
shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling
|
110 |
+
|
111 |
+
labels = self.labels[index].copy()
|
112 |
+
if labels.size:
|
113 |
+
w *= ratio
|
114 |
+
h *= ratio
|
115 |
+
# new boxes
|
116 |
+
boxes = np.copy(labels[:, 1:])
|
117 |
+
boxes[:, 0] = (
|
118 |
+
w * (labels[:, 1] - labels[:, 3] / 2) + pad[0]
|
119 |
+
) # top left x
|
120 |
+
boxes[:, 1] = (
|
121 |
+
h * (labels[:, 2] - labels[:, 4] / 2) + pad[1]
|
122 |
+
) # top left y
|
123 |
+
boxes[:, 2] = (
|
124 |
+
w * (labels[:, 1] + labels[:, 3] / 2) + pad[0]
|
125 |
+
) # bottom right x
|
126 |
+
boxes[:, 3] = (
|
127 |
+
h * (labels[:, 2] + labels[:, 4] / 2) + pad[1]
|
128 |
+
) # bottom right y
|
129 |
+
labels[:, 1:] = boxes
|
130 |
+
|
131 |
+
if self.augment:
|
132 |
+
img, labels = random_affine(
|
133 |
+
img,
|
134 |
+
labels,
|
135 |
+
degrees=self.hyp["degrees"],
|
136 |
+
translate=self.hyp["translate"],
|
137 |
+
scale=self.hyp["scale"],
|
138 |
+
shear=self.hyp["shear"],
|
139 |
+
new_shape=(self.img_size, self.img_size),
|
140 |
+
)
|
141 |
+
|
142 |
+
if len(labels):
|
143 |
+
h, w = img.shape[:2]
|
144 |
+
|
145 |
+
labels[:, [1, 3]] = labels[:, [1, 3]].clip(0, w - 1e-3) # x1, x2
|
146 |
+
labels[:, [2, 4]] = labels[:, [2, 4]].clip(0, h - 1e-3) # y1, y2
|
147 |
+
|
148 |
+
boxes = np.copy(labels[:, 1:])
|
149 |
+
boxes[:, 0] = ((labels[:, 1] + labels[:, 3]) / 2) / w # x center
|
150 |
+
boxes[:, 1] = ((labels[:, 2] + labels[:, 4]) / 2) / h # y center
|
151 |
+
boxes[:, 2] = (labels[:, 3] - labels[:, 1]) / w # width
|
152 |
+
boxes[:, 3] = (labels[:, 4] - labels[:, 2]) / h # height
|
153 |
+
labels[:, 1:] = boxes
|
154 |
+
|
155 |
+
if self.augment:
|
156 |
+
img, labels = self.general_augment(img, labels)
|
157 |
+
|
158 |
+
labels_out = torch.zeros((len(labels), 6))
|
159 |
+
if len(labels):
|
160 |
+
labels_out[:, 1:] = torch.from_numpy(labels)
|
161 |
+
|
162 |
+
# Convert
|
163 |
+
img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
164 |
+
img = np.ascontiguousarray(img)
|
165 |
+
|
166 |
+
return torch.from_numpy(img), labels_out, self.img_paths[index], shapes
|
167 |
+
|
168 |
+
def load_image(self, index):
|
169 |
+
"""Load image.
|
170 |
+
This function loads image by cv2, resize original image to target shape(img_size) with keeping ratio.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
Image, original shape of image, resized image shape
|
174 |
+
"""
|
175 |
+
path = self.img_paths[index]
|
176 |
+
im = cv2.imread(path)
|
177 |
+
assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}"
|
178 |
+
|
179 |
+
h0, w0 = im.shape[:2] # origin shape
|
180 |
+
r = self.img_size / max(h0, w0)
|
181 |
+
if r != 1:
|
182 |
+
im = cv2.resize(
|
183 |
+
im,
|
184 |
+
(int(w0 * r), int(h0 * r)),
|
185 |
+
interpolation=cv2.INTER_AREA
|
186 |
+
if r < 1 and not self.augment
|
187 |
+
else cv2.INTER_LINEAR,
|
188 |
+
)
|
189 |
+
return im, (h0, w0), im.shape[:2]
|
190 |
+
|
191 |
+
@staticmethod
|
192 |
+
def collate_fn(batch):
|
193 |
+
"""Merges a list of samples to form a mini-batch of Tensor(s)"""
|
194 |
+
img, label, path, shapes = zip(*batch)
|
195 |
+
for i, l in enumerate(label):
|
196 |
+
l[:, 0] = i # add target image index for build_targets()
|
197 |
+
return torch.stack(img, 0), torch.cat(label, 0), path, shapes
|
198 |
+
|
199 |
+
def get_imgs_labels(self, img_dir):
|
200 |
+
|
201 |
+
assert osp.exists(img_dir), f"{img_dir} is an invalid directory path!"
|
202 |
+
valid_img_record = osp.join(
|
203 |
+
osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
|
204 |
+
)
|
205 |
+
NUM_THREADS = min(8, os.cpu_count())
|
206 |
+
|
207 |
+
img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
|
208 |
+
img_paths = sorted(
|
209 |
+
p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
|
210 |
+
)
|
211 |
+
assert img_paths, f"No images found in {img_dir}."
|
212 |
+
|
213 |
+
img_hash = self.get_hash(img_paths)
|
214 |
+
if osp.exists(valid_img_record):
|
215 |
+
with open(valid_img_record, "r") as f:
|
216 |
+
cache_info = json.load(f)
|
217 |
+
if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
|
218 |
+
img_info = cache_info["information"]
|
219 |
+
else:
|
220 |
+
self.check_images = True
|
221 |
+
else:
|
222 |
+
self.check_images = True
|
223 |
+
|
224 |
+
# check images
|
225 |
+
if self.check_images and self.main_process:
|
226 |
+
img_info = {}
|
227 |
+
nc, msgs = 0, [] # number corrupt, messages
|
228 |
+
LOGGER.info(
|
229 |
+
f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
|
230 |
+
)
|
231 |
+
with Pool(NUM_THREADS) as pool:
|
232 |
+
pbar = tqdm(
|
233 |
+
pool.imap(TrainValDataset.check_image, img_paths),
|
234 |
+
total=len(img_paths),
|
235 |
+
)
|
236 |
+
for img_path, shape_per_img, nc_per_img, msg in pbar:
|
237 |
+
if nc_per_img == 0: # not corrupted
|
238 |
+
img_info[img_path] = {"shape": shape_per_img}
|
239 |
+
nc += nc_per_img
|
240 |
+
if msg:
|
241 |
+
msgs.append(msg)
|
242 |
+
pbar.desc = f"{nc} image(s) corrupted"
|
243 |
+
pbar.close()
|
244 |
+
if msgs:
|
245 |
+
LOGGER.info("\n".join(msgs))
|
246 |
+
|
247 |
+
cache_info = {"information": img_info, "image_hash": img_hash}
|
248 |
+
# save valid image paths.
|
249 |
+
with open(valid_img_record, "w") as f:
|
250 |
+
json.dump(cache_info, f)
|
251 |
+
|
252 |
+
# check and load anns
|
253 |
+
label_dir = osp.join(
|
254 |
+
osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
|
255 |
+
)
|
256 |
+
assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
|
257 |
+
|
258 |
+
img_paths = list(img_info.keys())
|
259 |
+
label_paths = sorted(
|
260 |
+
osp.join(label_dir, osp.splitext(osp.basename(p))[0] + ".txt")
|
261 |
+
for p in img_paths
|
262 |
+
)
|
263 |
+
label_hash = self.get_hash(label_paths)
|
264 |
+
if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
|
265 |
+
self.check_labels = True
|
266 |
+
|
267 |
+
if self.check_labels:
|
268 |
+
cache_info["label_hash"] = label_hash
|
269 |
+
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number corrupt, messages
|
270 |
+
LOGGER.info(
|
271 |
+
f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
|
272 |
+
)
|
273 |
+
with Pool(NUM_THREADS) as pool:
|
274 |
+
pbar = pool.imap(
|
275 |
+
TrainValDataset.check_label_files, zip(img_paths, label_paths)
|
276 |
+
)
|
277 |
+
pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
|
278 |
+
for (
|
279 |
+
img_path,
|
280 |
+
labels_per_file,
|
281 |
+
nc_per_file,
|
282 |
+
nm_per_file,
|
283 |
+
nf_per_file,
|
284 |
+
ne_per_file,
|
285 |
+
msg,
|
286 |
+
) in pbar:
|
287 |
+
if nc_per_file == 0:
|
288 |
+
img_info[img_path]["labels"] = labels_per_file
|
289 |
+
else:
|
290 |
+
img_info.pop(img_path)
|
291 |
+
nc += nc_per_file
|
292 |
+
nm += nm_per_file
|
293 |
+
nf += nf_per_file
|
294 |
+
ne += ne_per_file
|
295 |
+
if msg:
|
296 |
+
msgs.append(msg)
|
297 |
+
if self.main_process:
|
298 |
+
pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
|
299 |
+
if self.main_process:
|
300 |
+
pbar.close()
|
301 |
+
with open(valid_img_record, "w") as f:
|
302 |
+
json.dump(cache_info, f)
|
303 |
+
if msgs:
|
304 |
+
LOGGER.info("\n".join(msgs))
|
305 |
+
if nf == 0:
|
306 |
+
LOGGER.warning(
|
307 |
+
f"WARNING: No labels found in {osp.dirname(self.img_paths[0])}. "
|
308 |
+
)
|
309 |
+
|
310 |
+
if self.task.lower() == "val":
|
311 |
+
if self.data_dict.get("is_coco", False): # use original json file when evaluating on coco dataset.
|
312 |
+
assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
|
313 |
+
else:
|
314 |
+
assert (
|
315 |
+
self.class_names
|
316 |
+
), "Class names is required when converting labels to coco format for evaluating."
|
317 |
+
save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
|
318 |
+
if not osp.exists(save_dir):
|
319 |
+
os.mkdir(save_dir)
|
320 |
+
save_path = osp.join(
|
321 |
+
save_dir, "instances_" + osp.basename(img_dir) + ".json"
|
322 |
+
)
|
323 |
+
TrainValDataset.generate_coco_format_labels(
|
324 |
+
img_info, self.class_names, save_path
|
325 |
+
)
|
326 |
+
|
327 |
+
img_paths, labels = list(
|
328 |
+
zip(
|
329 |
+
*[
|
330 |
+
(
|
331 |
+
img_path,
|
332 |
+
np.array(info["labels"], dtype=np.float32)
|
333 |
+
if info["labels"]
|
334 |
+
else np.zeros((0, 5), dtype=np.float32),
|
335 |
+
)
|
336 |
+
for img_path, info in img_info.items()
|
337 |
+
]
|
338 |
+
)
|
339 |
+
)
|
340 |
+
self.img_info = img_info
|
341 |
+
LOGGER.info(
|
342 |
+
f"{self.task}: Final numbers of valid images: {len(img_paths)}/ labels: {len(labels)}. "
|
343 |
+
)
|
344 |
+
return img_paths, labels
|
345 |
+
|
346 |
+
def get_mosaic(self, index):
|
347 |
+
"""Gets images and labels after mosaic augments"""
|
348 |
+
indices = [index] + random.choices(
|
349 |
+
range(0, len(self.img_paths)), k=3
|
350 |
+
) # 3 additional image indices
|
351 |
+
random.shuffle(indices)
|
352 |
+
imgs, hs, ws, labels = [], [], [], []
|
353 |
+
for index in indices:
|
354 |
+
img, _, (h, w) = self.load_image(index)
|
355 |
+
labels_per_img = self.labels[index]
|
356 |
+
imgs.append(img)
|
357 |
+
hs.append(h)
|
358 |
+
ws.append(w)
|
359 |
+
labels.append(labels_per_img)
|
360 |
+
img, labels = mosaic_augmentation(self.img_size, imgs, hs, ws, labels, self.hyp)
|
361 |
+
return img, labels
|
362 |
+
|
363 |
+
def general_augment(self, img, labels):
|
364 |
+
"""Gets images and labels after general augment
|
365 |
+
This function applies hsv, random ud-flip and random lr-flips augments.
|
366 |
+
"""
|
367 |
+
nl = len(labels)
|
368 |
+
|
369 |
+
# HSV color-space
|
370 |
+
augment_hsv(
|
371 |
+
img,
|
372 |
+
hgain=self.hyp["hsv_h"],
|
373 |
+
sgain=self.hyp["hsv_s"],
|
374 |
+
vgain=self.hyp["hsv_v"],
|
375 |
+
)
|
376 |
+
|
377 |
+
# Flip up-down
|
378 |
+
if random.random() < self.hyp["flipud"]:
|
379 |
+
img = np.flipud(img)
|
380 |
+
if nl:
|
381 |
+
labels[:, 2] = 1 - labels[:, 2]
|
382 |
+
|
383 |
+
# Flip left-right
|
384 |
+
if random.random() < self.hyp["fliplr"]:
|
385 |
+
img = np.fliplr(img)
|
386 |
+
if nl:
|
387 |
+
labels[:, 1] = 1 - labels[:, 1]
|
388 |
+
|
389 |
+
return img, labels
|
390 |
+
|
391 |
+
def sort_files_shapes(self):
|
392 |
+
# Sort by aspect ratio
|
393 |
+
batch_num = self.batch_indices[-1] + 1
|
394 |
+
s = self.shapes # wh
|
395 |
+
ar = s[:, 1] / s[:, 0] # aspect ratio
|
396 |
+
irect = ar.argsort()
|
397 |
+
self.img_paths = [self.img_paths[i] for i in irect]
|
398 |
+
self.labels = [self.labels[i] for i in irect]
|
399 |
+
self.shapes = s[irect] # wh
|
400 |
+
ar = ar[irect]
|
401 |
+
|
402 |
+
# Set training image shapes
|
403 |
+
shapes = [[1, 1]] * batch_num
|
404 |
+
for i in range(batch_num):
|
405 |
+
ari = ar[self.batch_indices == i]
|
406 |
+
mini, maxi = ari.min(), ari.max()
|
407 |
+
if maxi < 1:
|
408 |
+
shapes[i] = [maxi, 1]
|
409 |
+
elif mini > 1:
|
410 |
+
shapes[i] = [1, 1 / mini]
|
411 |
+
self.batch_shapes = (
|
412 |
+
np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(
|
413 |
+
np.int
|
414 |
+
)
|
415 |
+
* self.stride
|
416 |
+
)
|
417 |
+
|
418 |
+
@staticmethod
|
419 |
+
def check_image(im_file):
|
420 |
+
# verify an image.
|
421 |
+
nc, msg = 0, ""
|
422 |
+
try:
|
423 |
+
im = Image.open(im_file)
|
424 |
+
im.verify() # PIL verify
|
425 |
+
shape = im.size # (width, height)
|
426 |
+
im_exif = im._getexif()
|
427 |
+
if im_exif and ORIENTATION in im_exif:
|
428 |
+
rotation = im_exif[ORIENTATION]
|
429 |
+
if rotation in (6, 8):
|
430 |
+
shape = (shape[1], shape[0])
|
431 |
+
|
432 |
+
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
433 |
+
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
434 |
+
if im.format.lower() in ("jpg", "jpeg"):
|
435 |
+
with open(im_file, "rb") as f:
|
436 |
+
f.seek(-2, 2)
|
437 |
+
if f.read() != b"\xff\xd9": # corrupt JPEG
|
438 |
+
ImageOps.exif_transpose(Image.open(im_file)).save(
|
439 |
+
im_file, "JPEG", subsampling=0, quality=100
|
440 |
+
)
|
441 |
+
msg += f"WARNING: {im_file}: corrupt JPEG restored and saved"
|
442 |
+
return im_file, shape, nc, msg
|
443 |
+
except Exception as e:
|
444 |
+
nc = 1
|
445 |
+
msg = f"WARNING: {im_file}: ignoring corrupt image: {e}"
|
446 |
+
return im_file, None, nc, msg
|
447 |
+
|
448 |
+
@staticmethod
|
449 |
+
def check_label_files(args):
|
450 |
+
img_path, lb_path = args
|
451 |
+
nm, nf, ne, nc, msg = 0, 0, 0, 0, "" # number (missing, found, empty, message
|
452 |
+
try:
|
453 |
+
if osp.exists(lb_path):
|
454 |
+
nf = 1 # label found
|
455 |
+
with open(lb_path, "r") as f:
|
456 |
+
labels = [
|
457 |
+
x.split() for x in f.read().strip().splitlines() if len(x)
|
458 |
+
]
|
459 |
+
labels = np.array(labels, dtype=np.float32)
|
460 |
+
if len(labels):
|
461 |
+
assert all(
|
462 |
+
len(l) == 5 for l in labels
|
463 |
+
), f"{lb_path}: wrong label format."
|
464 |
+
assert (
|
465 |
+
labels >= 0
|
466 |
+
).all(), f"{lb_path}: Label values error: all values in label file must > 0"
|
467 |
+
assert (
|
468 |
+
labels[:, 1:] <= 1
|
469 |
+
).all(), f"{lb_path}: Label values error: all coordinates must be normalized"
|
470 |
+
|
471 |
+
_, indices = np.unique(labels, axis=0, return_index=True)
|
472 |
+
if len(indices) < len(labels): # duplicate row check
|
473 |
+
labels = labels[indices] # remove duplicates
|
474 |
+
msg += f"WARNING: {lb_path}: {len(labels) - len(indices)} duplicate labels removed"
|
475 |
+
labels = labels.tolist()
|
476 |
+
else:
|
477 |
+
ne = 1 # label empty
|
478 |
+
labels = []
|
479 |
+
else:
|
480 |
+
nm = 1 # label missing
|
481 |
+
labels = []
|
482 |
+
|
483 |
+
return img_path, labels, nc, nm, nf, ne, msg
|
484 |
+
except Exception as e:
|
485 |
+
nc = 1
|
486 |
+
msg = f"WARNING: {lb_path}: ignoring invalid labels: {e}"
|
487 |
+
return img_path, None, nc, nm, nf, ne, msg
|
488 |
+
|
489 |
+
@staticmethod
|
490 |
+
def generate_coco_format_labels(img_info, class_names, save_path):
|
491 |
+
# for evaluation with pycocotools
|
492 |
+
dataset = {"categories": [], "annotations": [], "images": []}
|
493 |
+
for i, class_name in enumerate(class_names):
|
494 |
+
dataset["categories"].append(
|
495 |
+
{"id": i, "name": class_name, "supercategory": ""}
|
496 |
+
)
|
497 |
+
|
498 |
+
ann_id = 0
|
499 |
+
LOGGER.info(f"Convert to COCO format")
|
500 |
+
for i, (img_path, info) in enumerate(tqdm(img_info.items())):
|
501 |
+
labels = info["labels"] if info["labels"] else []
|
502 |
+
img_id = osp.splitext(osp.basename(img_path))[0]
|
503 |
+
img_id = int(img_id) if img_id.isnumeric() else img_id
|
504 |
+
img_w, img_h = info["shape"]
|
505 |
+
dataset["images"].append(
|
506 |
+
{
|
507 |
+
"file_name": os.path.basename(img_path),
|
508 |
+
"id": img_id,
|
509 |
+
"width": img_w,
|
510 |
+
"height": img_h,
|
511 |
+
}
|
512 |
+
)
|
513 |
+
if labels:
|
514 |
+
for label in labels:
|
515 |
+
c, x, y, w, h = label[:5]
|
516 |
+
# convert x,y,w,h to x1,y1,x2,y2
|
517 |
+
x1 = (x - w / 2) * img_w
|
518 |
+
y1 = (y - h / 2) * img_h
|
519 |
+
x2 = (x + w / 2) * img_w
|
520 |
+
y2 = (y + h / 2) * img_h
|
521 |
+
# cls_id starts from 0
|
522 |
+
cls_id = int(c)
|
523 |
+
w = max(0, x2 - x1)
|
524 |
+
h = max(0, y2 - y1)
|
525 |
+
dataset["annotations"].append(
|
526 |
+
{
|
527 |
+
"area": h * w,
|
528 |
+
"bbox": [x1, y1, w, h],
|
529 |
+
"category_id": cls_id,
|
530 |
+
"id": ann_id,
|
531 |
+
"image_id": img_id,
|
532 |
+
"iscrowd": 0,
|
533 |
+
# mask
|
534 |
+
"segmentation": [],
|
535 |
+
}
|
536 |
+
)
|
537 |
+
ann_id += 1
|
538 |
+
|
539 |
+
with open(save_path, "w") as f:
|
540 |
+
json.dump(dataset, f)
|
541 |
+
LOGGER.info(
|
542 |
+
f"Convert to COCO format finished. Resutls saved in {save_path}"
|
543 |
+
)
|
544 |
+
|
545 |
+
@staticmethod
|
546 |
+
def get_hash(paths):
|
547 |
+
"""Get the hash value of paths"""
|
548 |
+
assert isinstance(paths, list), "Only support list currently."
|
549 |
+
h = hashlib.md5("".join(paths).encode())
|
550 |
+
return h.hexdigest()
|
yolov6/data/vis_dataset.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Description: visualize yolo label image.
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
img_dir, label_dir, class_names = args.img_dir, args.label_dir, args.class_names
|
13 |
+
|
14 |
+
label_map = dict()
|
15 |
+
for class_id, classname in enumerate(class_names):
|
16 |
+
label_map[class_id] = classname
|
17 |
+
|
18 |
+
for file in os.listdir(img_dir):
|
19 |
+
if file.split('.')[-1] not in IMG_FORMATS:
|
20 |
+
print(f'[Warning]: Non-image file {file}')
|
21 |
+
continue
|
22 |
+
img_path = os.path.join(img_dir, file)
|
23 |
+
label_path = os.path.join(label_dir, file[: file.rindex('.')] + '.txt')
|
24 |
+
|
25 |
+
try:
|
26 |
+
img_data = cv2.imread(img_path)
|
27 |
+
height, width, _ = img_data.shape
|
28 |
+
color = [tuple(np.random.choice(range(256), size=3)) for i in class_names]
|
29 |
+
thickness = 2
|
30 |
+
|
31 |
+
with open(label_path, 'r') as f:
|
32 |
+
for bbox in f:
|
33 |
+
cls, x_c, y_c, w, h = [float(v) if i > 0 else int(v) for i, v in enumerate(bbox.split('\n')[0].split(' '))]
|
34 |
+
|
35 |
+
x_tl = int((x_c - w / 2) * width)
|
36 |
+
y_tl = int((y_c - h / 2) * height)
|
37 |
+
cv2.rectangle(img_data, (x_tl, y_tl), (x_tl + int(w * width), y_tl + int(h * height)), tuple([int(x) for x in color[cls]]), thickness)
|
38 |
+
cv2.putText(img_data, label_map[cls], (x_tl, y_tl - 10), cv2.FONT_HERSHEY_COMPLEX, 1, tuple([int(x) for x in color[cls]]), thickness)
|
39 |
+
|
40 |
+
cv2.imshow('image', img_data)
|
41 |
+
cv2.waitKey(0)
|
42 |
+
except Exception as e:
|
43 |
+
print(f'[Error]: {e} {img_path}')
|
44 |
+
print('======All Done!======')
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == '__main__':
|
48 |
+
parser = argparse.ArgumentParser()
|
49 |
+
parser.add_argument('--img_dir', default='VOCdevkit/voc_07_12/images')
|
50 |
+
parser.add_argument('--label_dir', default='VOCdevkit/voc_07_12/labels')
|
51 |
+
parser.add_argument('--class_names', default=['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
52 |
+
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'])
|
53 |
+
|
54 |
+
args = parser.parse_args()
|
55 |
+
print(args)
|
56 |
+
|
57 |
+
main(args)
|
yolov6/data/voc2yolo.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import xml.etree.ElementTree as ET
|
2 |
+
from tqdm import tqdm
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
# VOC dataset (refer https://github.com/ultralytics/yolov5/blob/master/data/VOC.yaml)
|
8 |
+
# VOC2007 trainval: 446MB, 5012 images
|
9 |
+
# VOC2007 test: 438MB, 4953 images
|
10 |
+
# VOC2012 trainval: 1.95GB, 17126 images
|
11 |
+
|
12 |
+
VOC_NAMES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
13 |
+
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
|
14 |
+
|
15 |
+
def convert_label(path, lb_path, year, image_id):
|
16 |
+
def convert_box(size, box):
|
17 |
+
dw, dh = 1. / size[0], 1. / size[1]
|
18 |
+
x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
|
19 |
+
return x * dw, y * dh, w * dw, h * dh
|
20 |
+
in_file = open(os.path.join(path, f'VOC{year}/Annotations/{image_id}.xml'))
|
21 |
+
out_file = open(lb_path, 'w')
|
22 |
+
tree = ET.parse(in_file)
|
23 |
+
root = tree.getroot()
|
24 |
+
size = root.find('size')
|
25 |
+
w = int(size.find('width').text)
|
26 |
+
h = int(size.find('height').text)
|
27 |
+
for obj in root.iter('object'):
|
28 |
+
cls = obj.find('name').text
|
29 |
+
if cls in VOC_NAMES and not int(obj.find('difficult').text) == 1:
|
30 |
+
xmlbox = obj.find('bndbox')
|
31 |
+
bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
|
32 |
+
cls_id = VOC_NAMES.index(cls) # class id
|
33 |
+
out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
|
34 |
+
|
35 |
+
|
36 |
+
def gen_voc07_12(voc_path):
|
37 |
+
'''
|
38 |
+
Generate voc07+12 setting dataset:
|
39 |
+
train: # train images 16551 images
|
40 |
+
- images/train2012
|
41 |
+
- images/train2007
|
42 |
+
- images/val2012
|
43 |
+
- images/val2007
|
44 |
+
val: # val images (relative to 'path') 4952 images
|
45 |
+
- images/test2007
|
46 |
+
'''
|
47 |
+
dataset_root = os.path.join(voc_path, 'voc_07_12')
|
48 |
+
if not os.path.exists(dataset_root):
|
49 |
+
os.makedirs(dataset_root)
|
50 |
+
|
51 |
+
dataset_settings = {'train': ['train2007', 'val2007', 'train2012', 'val2012'], 'val':['test2007']}
|
52 |
+
for item in ['images', 'labels']:
|
53 |
+
for data_type, data_list in dataset_settings.items():
|
54 |
+
for data_name in data_list:
|
55 |
+
ori_path = os.path.join(voc_path, item, data_name)
|
56 |
+
new_path = os.path.join(dataset_root, item, data_type)
|
57 |
+
if not os.path.exists(new_path):
|
58 |
+
os.makedirs(new_path)
|
59 |
+
|
60 |
+
print(f'[INFO]: Copying {ori_path} to {new_path}')
|
61 |
+
for file in os.listdir(ori_path):
|
62 |
+
shutil.copy(os.path.join(ori_path, file), new_path)
|
63 |
+
|
64 |
+
|
65 |
+
def main(args):
|
66 |
+
voc_path = args.voc_path
|
67 |
+
for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
|
68 |
+
imgs_path = os.path.join(voc_path, 'images', f'{image_set}')
|
69 |
+
lbs_path = os.path.join(voc_path, 'labels', f'{image_set}')
|
70 |
+
|
71 |
+
try:
|
72 |
+
with open(os.path.join(voc_path, f'VOC{year}/ImageSets/Main/{image_set}.txt'), 'r') as f:
|
73 |
+
image_ids = f.read().strip().split()
|
74 |
+
if not os.path.exists(imgs_path):
|
75 |
+
os.makedirs(imgs_path)
|
76 |
+
if not os.path.exists(lbs_path):
|
77 |
+
os.makedirs(lbs_path)
|
78 |
+
|
79 |
+
for id in tqdm(image_ids, desc=f'{image_set}{year}'):
|
80 |
+
f = os.path.join(voc_path, f'VOC{year}/JPEGImages/{id}.jpg') # old img path
|
81 |
+
lb_path = os.path.join(lbs_path, f'{id}.txt') # new label path
|
82 |
+
convert_label(voc_path, lb_path, year, id) # convert labels to YOLO format
|
83 |
+
if os.path.exists(f):
|
84 |
+
shutil.move(f, imgs_path) # move image
|
85 |
+
except Exception as e:
|
86 |
+
print(f'[Warning]: {e} {year}{image_set} convert fail!')
|
87 |
+
|
88 |
+
gen_voc07_12(voc_path)
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == '__main__':
|
93 |
+
parser = argparse.ArgumentParser()
|
94 |
+
parser.add_argument('--voc_path', default='VOCdevkit')
|
95 |
+
|
96 |
+
args = parser.parse_args()
|
97 |
+
print(args)
|
98 |
+
|
99 |
+
main(args)
|
yolov6/layers/common.py
ADDED
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
import warnings
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from yolov6.layers.dbb_transforms import *
|
12 |
+
|
13 |
+
|
14 |
+
class SiLU(nn.Module):
|
15 |
+
'''Activation of SiLU'''
|
16 |
+
@staticmethod
|
17 |
+
def forward(x):
|
18 |
+
return x * torch.sigmoid(x)
|
19 |
+
|
20 |
+
|
21 |
+
class Conv(nn.Module):
|
22 |
+
'''Normal Conv with SiLU activation'''
|
23 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
|
24 |
+
super().__init__()
|
25 |
+
padding = kernel_size // 2
|
26 |
+
self.conv = nn.Conv2d(
|
27 |
+
in_channels,
|
28 |
+
out_channels,
|
29 |
+
kernel_size=kernel_size,
|
30 |
+
stride=stride,
|
31 |
+
padding=padding,
|
32 |
+
groups=groups,
|
33 |
+
bias=bias,
|
34 |
+
)
|
35 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
36 |
+
self.act = nn.SiLU()
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
return self.act(self.bn(self.conv(x)))
|
40 |
+
|
41 |
+
def forward_fuse(self, x):
|
42 |
+
return self.act(self.conv(x))
|
43 |
+
|
44 |
+
|
45 |
+
class SimConv(nn.Module):
|
46 |
+
'''Normal Conv with ReLU activation'''
|
47 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1, bias=False):
|
48 |
+
super().__init__()
|
49 |
+
padding = kernel_size // 2
|
50 |
+
self.conv = nn.Conv2d(
|
51 |
+
in_channels,
|
52 |
+
out_channels,
|
53 |
+
kernel_size=kernel_size,
|
54 |
+
stride=stride,
|
55 |
+
padding=padding,
|
56 |
+
groups=groups,
|
57 |
+
bias=bias,
|
58 |
+
)
|
59 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
60 |
+
self.act = nn.ReLU()
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return self.act(self.bn(self.conv(x)))
|
64 |
+
|
65 |
+
def forward_fuse(self, x):
|
66 |
+
return self.act(self.conv(x))
|
67 |
+
|
68 |
+
|
69 |
+
class SimSPPF(nn.Module):
|
70 |
+
'''Simplified SPPF with ReLU activation'''
|
71 |
+
def __init__(self, in_channels, out_channels, kernel_size=5):
|
72 |
+
super().__init__()
|
73 |
+
c_ = in_channels // 2 # hidden channels
|
74 |
+
self.cv1 = SimConv(in_channels, c_, 1, 1)
|
75 |
+
self.cv2 = SimConv(c_ * 4, out_channels, 1, 1)
|
76 |
+
self.m = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
x = self.cv1(x)
|
80 |
+
with warnings.catch_warnings():
|
81 |
+
warnings.simplefilter('ignore')
|
82 |
+
y1 = self.m(x)
|
83 |
+
y2 = self.m(y1)
|
84 |
+
return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
|
85 |
+
|
86 |
+
|
87 |
+
class Transpose(nn.Module):
|
88 |
+
'''Normal Transpose, default for upsampling'''
|
89 |
+
def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
|
90 |
+
super().__init__()
|
91 |
+
self.upsample_transpose = torch.nn.ConvTranspose2d(
|
92 |
+
in_channels=in_channels,
|
93 |
+
out_channels=out_channels,
|
94 |
+
kernel_size=kernel_size,
|
95 |
+
stride=stride,
|
96 |
+
bias=True
|
97 |
+
)
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
return self.upsample_transpose(x)
|
101 |
+
|
102 |
+
|
103 |
+
class Concat(nn.Module):
|
104 |
+
def __init__(self, dimension=1):
|
105 |
+
super().__init__()
|
106 |
+
self.d = dimension
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
return torch.cat(x, self.d)
|
110 |
+
|
111 |
+
|
112 |
+
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
|
113 |
+
'''Basic cell for rep-style block, including conv and bn'''
|
114 |
+
result = nn.Sequential()
|
115 |
+
result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
|
116 |
+
kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
|
117 |
+
result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
|
118 |
+
return result
|
119 |
+
|
120 |
+
|
121 |
+
class RepBlock(nn.Module):
|
122 |
+
'''
|
123 |
+
RepBlock is a stage block with rep-style basic block
|
124 |
+
'''
|
125 |
+
def __init__(self, in_channels, out_channels, n=1):
|
126 |
+
super().__init__()
|
127 |
+
self.conv1 = RepVGGBlock(in_channels, out_channels)
|
128 |
+
self.block = nn.Sequential(*(RepVGGBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
|
129 |
+
|
130 |
+
def forward(self, x):
|
131 |
+
x = self.conv1(x)
|
132 |
+
if self.block is not None:
|
133 |
+
x = self.block(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class RepVGGBlock(nn.Module):
|
138 |
+
'''RepVGGBlock is a basic rep-style block, including training and deploy status
|
139 |
+
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
|
140 |
+
'''
|
141 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
142 |
+
stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
|
143 |
+
super(RepVGGBlock, self).__init__()
|
144 |
+
""" Initialization of the class.
|
145 |
+
Args:
|
146 |
+
in_channels (int): Number of channels in the input image
|
147 |
+
out_channels (int): Number of channels produced by the convolution
|
148 |
+
kernel_size (int or tuple): Size of the convolving kernel
|
149 |
+
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
150 |
+
padding (int or tuple, optional): Zero-padding added to both sides of
|
151 |
+
the input. Default: 1
|
152 |
+
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
|
153 |
+
groups (int, optional): Number of blocked connections from input
|
154 |
+
channels to output channels. Default: 1
|
155 |
+
padding_mode (string, optional): Default: 'zeros'
|
156 |
+
deploy: Whether to be deploy status or training status. Default: False
|
157 |
+
use_se: Whether to use se. Default: False
|
158 |
+
"""
|
159 |
+
self.deploy = deploy
|
160 |
+
self.groups = groups
|
161 |
+
self.in_channels = in_channels
|
162 |
+
self.out_channels = out_channels
|
163 |
+
|
164 |
+
assert kernel_size == 3
|
165 |
+
assert padding == 1
|
166 |
+
|
167 |
+
padding_11 = padding - kernel_size // 2
|
168 |
+
|
169 |
+
self.nonlinearity = nn.ReLU()
|
170 |
+
|
171 |
+
if use_se:
|
172 |
+
raise NotImplementedError("se block not supported yet")
|
173 |
+
else:
|
174 |
+
self.se = nn.Identity()
|
175 |
+
|
176 |
+
if deploy:
|
177 |
+
self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
|
178 |
+
padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
|
179 |
+
|
180 |
+
else:
|
181 |
+
self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
|
182 |
+
self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
|
183 |
+
self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
|
184 |
+
|
185 |
+
def forward(self, inputs):
|
186 |
+
'''Forward process'''
|
187 |
+
if hasattr(self, 'rbr_reparam'):
|
188 |
+
return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
|
189 |
+
|
190 |
+
if self.rbr_identity is None:
|
191 |
+
id_out = 0
|
192 |
+
else:
|
193 |
+
id_out = self.rbr_identity(inputs)
|
194 |
+
|
195 |
+
return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))
|
196 |
+
|
197 |
+
def get_equivalent_kernel_bias(self):
|
198 |
+
kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
|
199 |
+
kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
|
200 |
+
kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
|
201 |
+
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
202 |
+
|
203 |
+
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
204 |
+
if kernel1x1 is None:
|
205 |
+
return 0
|
206 |
+
else:
|
207 |
+
return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
|
208 |
+
|
209 |
+
def _fuse_bn_tensor(self, branch):
|
210 |
+
if branch is None:
|
211 |
+
return 0, 0
|
212 |
+
if isinstance(branch, nn.Sequential):
|
213 |
+
kernel = branch.conv.weight
|
214 |
+
running_mean = branch.bn.running_mean
|
215 |
+
running_var = branch.bn.running_var
|
216 |
+
gamma = branch.bn.weight
|
217 |
+
beta = branch.bn.bias
|
218 |
+
eps = branch.bn.eps
|
219 |
+
else:
|
220 |
+
assert isinstance(branch, nn.BatchNorm2d)
|
221 |
+
if not hasattr(self, 'id_tensor'):
|
222 |
+
input_dim = self.in_channels // self.groups
|
223 |
+
kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
|
224 |
+
for i in range(self.in_channels):
|
225 |
+
kernel_value[i, i % input_dim, 1, 1] = 1
|
226 |
+
self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
|
227 |
+
kernel = self.id_tensor
|
228 |
+
running_mean = branch.running_mean
|
229 |
+
running_var = branch.running_var
|
230 |
+
gamma = branch.weight
|
231 |
+
beta = branch.bias
|
232 |
+
eps = branch.eps
|
233 |
+
std = (running_var + eps).sqrt()
|
234 |
+
t = (gamma / std).reshape(-1, 1, 1, 1)
|
235 |
+
return kernel * t, beta - running_mean * gamma / std
|
236 |
+
|
237 |
+
def switch_to_deploy(self):
|
238 |
+
if hasattr(self, 'rbr_reparam'):
|
239 |
+
return
|
240 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
241 |
+
self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
|
242 |
+
kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
|
243 |
+
padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
|
244 |
+
self.rbr_reparam.weight.data = kernel
|
245 |
+
self.rbr_reparam.bias.data = bias
|
246 |
+
for para in self.parameters():
|
247 |
+
para.detach_()
|
248 |
+
self.__delattr__('rbr_dense')
|
249 |
+
self.__delattr__('rbr_1x1')
|
250 |
+
if hasattr(self, 'rbr_identity'):
|
251 |
+
self.__delattr__('rbr_identity')
|
252 |
+
if hasattr(self, 'id_tensor'):
|
253 |
+
self.__delattr__('id_tensor')
|
254 |
+
self.deploy = True
|
255 |
+
|
256 |
+
|
257 |
+
def conv_bn_v2(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
|
258 |
+
padding_mode='zeros'):
|
259 |
+
conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
|
260 |
+
stride=stride, padding=padding, dilation=dilation, groups=groups,
|
261 |
+
bias=False, padding_mode=padding_mode)
|
262 |
+
bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
|
263 |
+
se = nn.Sequential()
|
264 |
+
se.add_module('conv', conv_layer)
|
265 |
+
se.add_module('bn', bn_layer)
|
266 |
+
return se
|
267 |
+
|
268 |
+
|
269 |
+
class IdentityBasedConv1x1(nn.Conv2d):
|
270 |
+
|
271 |
+
def __init__(self, channels, groups=1):
|
272 |
+
super(IdentityBasedConv1x1, self).__init__(in_channels=channels, out_channels=channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=False)
|
273 |
+
|
274 |
+
assert channels % groups == 0
|
275 |
+
input_dim = channels // groups
|
276 |
+
id_value = np.zeros((channels, input_dim, 1, 1))
|
277 |
+
for i in range(channels):
|
278 |
+
id_value[i, i % input_dim, 0, 0] = 1
|
279 |
+
self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
|
280 |
+
nn.init.zeros_(self.weight)
|
281 |
+
|
282 |
+
def forward(self, input):
|
283 |
+
kernel = self.weight + self.id_tensor.to(self.weight.device)
|
284 |
+
result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
|
285 |
+
return result
|
286 |
+
|
287 |
+
def get_actual_kernel(self):
|
288 |
+
return self.weight + self.id_tensor.to(self.weight.device)
|
289 |
+
|
290 |
+
|
291 |
+
class BNAndPadLayer(nn.Module):
|
292 |
+
def __init__(self,
|
293 |
+
pad_pixels,
|
294 |
+
num_features,
|
295 |
+
eps=1e-5,
|
296 |
+
momentum=0.1,
|
297 |
+
affine=True,
|
298 |
+
track_running_stats=True):
|
299 |
+
super(BNAndPadLayer, self).__init__()
|
300 |
+
self.bn = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
|
301 |
+
self.pad_pixels = pad_pixels
|
302 |
+
|
303 |
+
def forward(self, input):
|
304 |
+
output = self.bn(input)
|
305 |
+
if self.pad_pixels > 0:
|
306 |
+
if self.bn.affine:
|
307 |
+
pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(self.bn.running_var + self.bn.eps)
|
308 |
+
else:
|
309 |
+
pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
|
310 |
+
output = F.pad(output, [self.pad_pixels] * 4)
|
311 |
+
pad_values = pad_values.view(1, -1, 1, 1)
|
312 |
+
output[:, :, 0:self.pad_pixels, :] = pad_values
|
313 |
+
output[:, :, -self.pad_pixels:, :] = pad_values
|
314 |
+
output[:, :, :, 0:self.pad_pixels] = pad_values
|
315 |
+
output[:, :, :, -self.pad_pixels:] = pad_values
|
316 |
+
return output
|
317 |
+
|
318 |
+
@property
|
319 |
+
def bn_weight(self):
|
320 |
+
return self.bn.weight
|
321 |
+
|
322 |
+
@property
|
323 |
+
def bn_bias(self):
|
324 |
+
return self.bn.bias
|
325 |
+
|
326 |
+
@property
|
327 |
+
def running_mean(self):
|
328 |
+
return self.bn.running_mean
|
329 |
+
|
330 |
+
@property
|
331 |
+
def running_var(self):
|
332 |
+
return self.bn.running_var
|
333 |
+
|
334 |
+
@property
|
335 |
+
def eps(self):
|
336 |
+
return self.bn.eps
|
337 |
+
|
338 |
+
|
339 |
+
class DBBBlock(nn.Module):
|
340 |
+
'''
|
341 |
+
RepBlock is a stage block with rep-style basic block
|
342 |
+
'''
|
343 |
+
def __init__(self, in_channels, out_channels, n=1):
|
344 |
+
super().__init__()
|
345 |
+
self.conv1 = DiverseBranchBlock(in_channels, out_channels)
|
346 |
+
self.block = nn.Sequential(*(DiverseBranchBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None
|
347 |
+
|
348 |
+
def forward(self, x):
|
349 |
+
x = self.conv1(x)
|
350 |
+
if self.block is not None:
|
351 |
+
x = self.block(x)
|
352 |
+
return x
|
353 |
+
|
354 |
+
|
355 |
+
class DiverseBranchBlock(nn.Module):
|
356 |
+
|
357 |
+
def __init__(self, in_channels, out_channels, kernel_size=3,
|
358 |
+
stride=1, padding=1, dilation=1, groups=1,
|
359 |
+
internal_channels_1x1_3x3=None,
|
360 |
+
deploy=False, nonlinear=nn.ReLU(), single_init=False):
|
361 |
+
super(DiverseBranchBlock, self).__init__()
|
362 |
+
self.deploy = deploy
|
363 |
+
|
364 |
+
if nonlinear is None:
|
365 |
+
self.nonlinear = nn.Identity()
|
366 |
+
else:
|
367 |
+
self.nonlinear = nonlinear
|
368 |
+
|
369 |
+
self.kernel_size = kernel_size
|
370 |
+
self.out_channels = out_channels
|
371 |
+
self.groups = groups
|
372 |
+
assert padding == kernel_size // 2
|
373 |
+
|
374 |
+
if deploy:
|
375 |
+
self.dbb_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
|
376 |
+
padding=padding, dilation=dilation, groups=groups, bias=True)
|
377 |
+
|
378 |
+
else:
|
379 |
+
|
380 |
+
self.dbb_origin = conv_bn_v2(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
381 |
+
|
382 |
+
self.dbb_avg = nn.Sequential()
|
383 |
+
if groups < out_channels:
|
384 |
+
self.dbb_avg.add_module('conv',
|
385 |
+
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
|
386 |
+
stride=1, padding=0, groups=groups, bias=False))
|
387 |
+
self.dbb_avg.add_module('bn', BNAndPadLayer(pad_pixels=padding, num_features=out_channels))
|
388 |
+
self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
|
389 |
+
self.dbb_1x1 = conv_bn_v2(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
|
390 |
+
padding=0, groups=groups)
|
391 |
+
else:
|
392 |
+
self.dbb_avg.add_module('avg', nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=padding))
|
393 |
+
|
394 |
+
self.dbb_avg.add_module('avgbn', nn.BatchNorm2d(out_channels))
|
395 |
+
|
396 |
+
if internal_channels_1x1_3x3 is None:
|
397 |
+
internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels # For mobilenet, it is better to have 2X internal channels
|
398 |
+
|
399 |
+
self.dbb_1x1_kxk = nn.Sequential()
|
400 |
+
if internal_channels_1x1_3x3 == in_channels:
|
401 |
+
self.dbb_1x1_kxk.add_module('idconv1', IdentityBasedConv1x1(channels=in_channels, groups=groups))
|
402 |
+
else:
|
403 |
+
self.dbb_1x1_kxk.add_module('conv1', nn.Conv2d(in_channels=in_channels, out_channels=internal_channels_1x1_3x3,
|
404 |
+
kernel_size=1, stride=1, padding=0, groups=groups, bias=False))
|
405 |
+
self.dbb_1x1_kxk.add_module('bn1', BNAndPadLayer(pad_pixels=padding, num_features=internal_channels_1x1_3x3, affine=True))
|
406 |
+
self.dbb_1x1_kxk.add_module('conv2', nn.Conv2d(in_channels=internal_channels_1x1_3x3, out_channels=out_channels,
|
407 |
+
kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=False))
|
408 |
+
self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))
|
409 |
+
|
410 |
+
# The experiments reported in the paper used the default initialization of bn.weight (all as 1). But changing the initialization may be useful in some cases.
|
411 |
+
if single_init:
|
412 |
+
# Initialize the bn.weight of dbb_origin as 1 and others as 0. This is not the default setting.
|
413 |
+
self.single_init()
|
414 |
+
|
415 |
+
def get_equivalent_kernel_bias(self):
|
416 |
+
k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight, self.dbb_origin.bn)
|
417 |
+
|
418 |
+
if hasattr(self, 'dbb_1x1'):
|
419 |
+
k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight, self.dbb_1x1.bn)
|
420 |
+
k_1x1 = transVI_multiscale(k_1x1, self.kernel_size)
|
421 |
+
else:
|
422 |
+
k_1x1, b_1x1 = 0, 0
|
423 |
+
|
424 |
+
if hasattr(self.dbb_1x1_kxk, 'idconv1'):
|
425 |
+
k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
|
426 |
+
else:
|
427 |
+
k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
|
428 |
+
k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
|
429 |
+
k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
|
430 |
+
k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first, b_1x1_kxk_first, k_1x1_kxk_second, b_1x1_kxk_second, groups=self.groups)
|
431 |
+
|
432 |
+
k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
|
433 |
+
k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device), self.dbb_avg.avgbn)
|
434 |
+
if hasattr(self.dbb_avg, 'conv'):
|
435 |
+
k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight, self.dbb_avg.bn)
|
436 |
+
k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first, b_1x1_avg_first, k_1x1_avg_second, b_1x1_avg_second, groups=self.groups)
|
437 |
+
else:
|
438 |
+
k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
|
439 |
+
|
440 |
+
return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged), (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))
|
441 |
+
|
442 |
+
def switch_to_deploy(self):
|
443 |
+
if hasattr(self, 'dbb_reparam'):
|
444 |
+
return
|
445 |
+
kernel, bias = self.get_equivalent_kernel_bias()
|
446 |
+
self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels, out_channels=self.dbb_origin.conv.out_channels,
|
447 |
+
kernel_size=self.dbb_origin.conv.kernel_size, stride=self.dbb_origin.conv.stride,
|
448 |
+
padding=self.dbb_origin.conv.padding, dilation=self.dbb_origin.conv.dilation, groups=self.dbb_origin.conv.groups, bias=True)
|
449 |
+
self.dbb_reparam.weight.data = kernel
|
450 |
+
self.dbb_reparam.bias.data = bias
|
451 |
+
for para in self.parameters():
|
452 |
+
para.detach_()
|
453 |
+
self.__delattr__('dbb_origin')
|
454 |
+
self.__delattr__('dbb_avg')
|
455 |
+
if hasattr(self, 'dbb_1x1'):
|
456 |
+
self.__delattr__('dbb_1x1')
|
457 |
+
self.__delattr__('dbb_1x1_kxk')
|
458 |
+
|
459 |
+
def forward(self, inputs):
|
460 |
+
|
461 |
+
if hasattr(self, 'dbb_reparam'):
|
462 |
+
return self.nonlinear(self.dbb_reparam(inputs))
|
463 |
+
|
464 |
+
out = self.dbb_origin(inputs)
|
465 |
+
if hasattr(self, 'dbb_1x1'):
|
466 |
+
out += self.dbb_1x1(inputs)
|
467 |
+
out += self.dbb_avg(inputs)
|
468 |
+
out += self.dbb_1x1_kxk(inputs)
|
469 |
+
return self.nonlinear(out)
|
470 |
+
|
471 |
+
def init_gamma(self, gamma_value):
|
472 |
+
if hasattr(self, "dbb_origin"):
|
473 |
+
torch.nn.init.constant_(self.dbb_origin.bn.weight, gamma_value)
|
474 |
+
if hasattr(self, "dbb_1x1"):
|
475 |
+
torch.nn.init.constant_(self.dbb_1x1.bn.weight, gamma_value)
|
476 |
+
if hasattr(self, "dbb_avg"):
|
477 |
+
torch.nn.init.constant_(self.dbb_avg.avgbn.weight, gamma_value)
|
478 |
+
if hasattr(self, "dbb_1x1_kxk"):
|
479 |
+
torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight, gamma_value)
|
480 |
+
|
481 |
+
def single_init(self):
|
482 |
+
self.init_gamma(0.0)
|
483 |
+
if hasattr(self, "dbb_origin"):
|
484 |
+
torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
|
485 |
+
|
486 |
+
|
487 |
+
class DetectBackend(nn.Module):
|
488 |
+
def __init__(self, weights='yolov6s.pt', device=None, dnn=True):
|
489 |
+
|
490 |
+
super().__init__()
|
491 |
+
assert isinstance(weights, str) and Path(weights).suffix == '.pt', f'{Path(weights).suffix} format is not supported.'
|
492 |
+
from yolov6.utils.checkpoint import load_checkpoint
|
493 |
+
model = load_checkpoint(weights, map_location=device)
|
494 |
+
stride = int(model.stride.max())
|
495 |
+
self.__dict__.update(locals()) # assign all variables to self
|
496 |
+
|
497 |
+
def forward(self, im, val=False):
|
498 |
+
y = self.model(im)
|
499 |
+
if isinstance(y, np.ndarray):
|
500 |
+
y = torch.tensor(y, device=self.device)
|
501 |
+
return y
|
yolov6/layers/dbb_transforms.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
def transI_fusebn(kernel, bn):
|
7 |
+
gamma = bn.weight
|
8 |
+
std = (bn.running_var + bn.eps).sqrt()
|
9 |
+
return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std
|
10 |
+
|
11 |
+
|
12 |
+
def transII_addbranch(kernels, biases):
|
13 |
+
return sum(kernels), sum(biases)
|
14 |
+
|
15 |
+
|
16 |
+
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
|
17 |
+
if groups == 1:
|
18 |
+
k = F.conv2d(k2, k1.permute(1, 0, 2, 3)) #
|
19 |
+
b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
|
20 |
+
else:
|
21 |
+
k_slices = []
|
22 |
+
b_slices = []
|
23 |
+
k1_T = k1.permute(1, 0, 2, 3)
|
24 |
+
k1_group_width = k1.size(0) // groups
|
25 |
+
k2_group_width = k2.size(0) // groups
|
26 |
+
for g in range(groups):
|
27 |
+
k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
|
28 |
+
k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
|
29 |
+
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
|
30 |
+
b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
|
31 |
+
k, b_hat = transIV_depthconcat(k_slices, b_slices)
|
32 |
+
return k, b_hat + b2
|
33 |
+
|
34 |
+
|
35 |
+
def transIV_depthconcat(kernels, biases):
|
36 |
+
return torch.cat(kernels, dim=0), torch.cat(biases)
|
37 |
+
|
38 |
+
|
39 |
+
def transV_avg(channels, kernel_size, groups):
|
40 |
+
input_dim = channels // groups
|
41 |
+
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
|
42 |
+
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
|
43 |
+
return k
|
44 |
+
|
45 |
+
|
46 |
+
# This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
|
47 |
+
def transVI_multiscale(kernel, target_kernel_size):
|
48 |
+
H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
|
49 |
+
W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
|
50 |
+
return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
|
yolov6/models/efficientrep.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from yolov6.layers.common import RepVGGBlock, RepBlock, SimSPPF
|
3 |
+
|
4 |
+
|
5 |
+
class EfficientRep(nn.Module):
|
6 |
+
'''EfficientRep Backbone
|
7 |
+
EfficientRep is handcrafted by hardware-aware neural network design.
|
8 |
+
With rep-style struct, EfficientRep is friendly to high-computation hardware(e.g. GPU).
|
9 |
+
'''
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
in_channels=3,
|
14 |
+
channels_list=None,
|
15 |
+
num_repeats=None,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
assert channels_list is not None
|
20 |
+
assert num_repeats is not None
|
21 |
+
|
22 |
+
self.stem = RepVGGBlock(
|
23 |
+
in_channels=in_channels,
|
24 |
+
out_channels=channels_list[0],
|
25 |
+
kernel_size=3,
|
26 |
+
stride=2
|
27 |
+
)
|
28 |
+
|
29 |
+
self.ERBlock_2 = nn.Sequential(
|
30 |
+
RepVGGBlock(
|
31 |
+
in_channels=channels_list[0],
|
32 |
+
out_channels=channels_list[1],
|
33 |
+
kernel_size=3,
|
34 |
+
stride=2
|
35 |
+
),
|
36 |
+
RepBlock(
|
37 |
+
in_channels=channels_list[1],
|
38 |
+
out_channels=channels_list[1],
|
39 |
+
n=num_repeats[1]
|
40 |
+
)
|
41 |
+
)
|
42 |
+
|
43 |
+
self.ERBlock_3 = nn.Sequential(
|
44 |
+
RepVGGBlock(
|
45 |
+
in_channels=channels_list[1],
|
46 |
+
out_channels=channels_list[2],
|
47 |
+
kernel_size=3,
|
48 |
+
stride=2
|
49 |
+
),
|
50 |
+
RepBlock(
|
51 |
+
in_channels=channels_list[2],
|
52 |
+
out_channels=channels_list[2],
|
53 |
+
n=num_repeats[2]
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
self.ERBlock_4 = nn.Sequential(
|
58 |
+
RepVGGBlock(
|
59 |
+
in_channels=channels_list[2],
|
60 |
+
out_channels=channels_list[3],
|
61 |
+
kernel_size=3,
|
62 |
+
stride=2
|
63 |
+
),
|
64 |
+
RepBlock(
|
65 |
+
in_channels=channels_list[3],
|
66 |
+
out_channels=channels_list[3],
|
67 |
+
n=num_repeats[3]
|
68 |
+
)
|
69 |
+
)
|
70 |
+
|
71 |
+
self.ERBlock_5 = nn.Sequential(
|
72 |
+
RepVGGBlock(
|
73 |
+
in_channels=channels_list[3],
|
74 |
+
out_channels=channels_list[4],
|
75 |
+
kernel_size=3,
|
76 |
+
stride=2,
|
77 |
+
),
|
78 |
+
RepBlock(
|
79 |
+
in_channels=channels_list[4],
|
80 |
+
out_channels=channels_list[4],
|
81 |
+
n=num_repeats[4]
|
82 |
+
),
|
83 |
+
SimSPPF(
|
84 |
+
in_channels=channels_list[4],
|
85 |
+
out_channels=channels_list[4],
|
86 |
+
kernel_size=5
|
87 |
+
)
|
88 |
+
)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
|
92 |
+
outputs = []
|
93 |
+
x = self.stem(x)
|
94 |
+
x = self.ERBlock_2(x)
|
95 |
+
x = self.ERBlock_3(x)
|
96 |
+
outputs.append(x)
|
97 |
+
x = self.ERBlock_4(x)
|
98 |
+
outputs.append(x)
|
99 |
+
x = self.ERBlock_5(x)
|
100 |
+
outputs.append(x)
|
101 |
+
|
102 |
+
return tuple(outputs)
|
yolov6/models/effidehead.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
from yolov6.layers.common import *
|
5 |
+
|
6 |
+
|
7 |
+
class Detect(nn.Module):
|
8 |
+
'''Efficient Decoupled Head
|
9 |
+
With hardware-aware degisn, the decoupled head is optimized with
|
10 |
+
hybridchannels methods.
|
11 |
+
'''
|
12 |
+
def __init__(self, num_classes=80, anchors=1, num_layers=3, inplace=True, head_layers=None): # detection layer
|
13 |
+
super().__init__()
|
14 |
+
assert head_layers is not None
|
15 |
+
self.nc = num_classes # number of classes
|
16 |
+
self.no = num_classes + 5 # number of outputs per anchor
|
17 |
+
self.nl = num_layers # number of detection layers
|
18 |
+
if isinstance(anchors, (list, tuple)):
|
19 |
+
self.na = len(anchors[0]) // 2
|
20 |
+
else:
|
21 |
+
self.na = anchors
|
22 |
+
self.anchors = anchors
|
23 |
+
self.grid = [torch.zeros(1)] * num_layers
|
24 |
+
self.prior_prob = 1e-2
|
25 |
+
self.inplace = inplace
|
26 |
+
stride = [8, 16, 32] # strides computed during build
|
27 |
+
self.stride = torch.tensor(stride)
|
28 |
+
|
29 |
+
# Init decouple head
|
30 |
+
self.cls_convs = nn.ModuleList()
|
31 |
+
self.reg_convs = nn.ModuleList()
|
32 |
+
self.cls_preds = nn.ModuleList()
|
33 |
+
self.reg_preds = nn.ModuleList()
|
34 |
+
self.obj_preds = nn.ModuleList()
|
35 |
+
self.stems = nn.ModuleList()
|
36 |
+
|
37 |
+
# Efficient decoupled head layers
|
38 |
+
for i in range(num_layers):
|
39 |
+
idx = i*6
|
40 |
+
self.stems.append(head_layers[idx])
|
41 |
+
self.cls_convs.append(head_layers[idx+1])
|
42 |
+
self.reg_convs.append(head_layers[idx+2])
|
43 |
+
self.cls_preds.append(head_layers[idx+3])
|
44 |
+
self.reg_preds.append(head_layers[idx+4])
|
45 |
+
self.obj_preds.append(head_layers[idx+5])
|
46 |
+
|
47 |
+
def initialize_biases(self):
|
48 |
+
for conv in self.cls_preds:
|
49 |
+
b = conv.bias.view(self.na, -1)
|
50 |
+
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
|
51 |
+
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
52 |
+
for conv in self.obj_preds:
|
53 |
+
b = conv.bias.view(self.na, -1)
|
54 |
+
b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))
|
55 |
+
conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
z = []
|
59 |
+
for i in range(self.nl):
|
60 |
+
x[i] = self.stems[i](x[i])
|
61 |
+
cls_x = x[i]
|
62 |
+
reg_x = x[i]
|
63 |
+
cls_feat = self.cls_convs[i](cls_x)
|
64 |
+
cls_output = self.cls_preds[i](cls_feat)
|
65 |
+
reg_feat = self.reg_convs[i](reg_x)
|
66 |
+
reg_output = self.reg_preds[i](reg_feat)
|
67 |
+
obj_output = self.obj_preds[i](reg_feat)
|
68 |
+
if self.training:
|
69 |
+
x[i] = torch.cat([reg_output, obj_output, cls_output], 1)
|
70 |
+
bs, _, ny, nx = x[i].shape
|
71 |
+
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
72 |
+
else:
|
73 |
+
y = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
|
74 |
+
bs, _, ny, nx = y.shape
|
75 |
+
y = y.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
|
76 |
+
if self.grid[i].shape[2:4] != y.shape[2:4]:
|
77 |
+
d = self.stride.device
|
78 |
+
yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)])
|
79 |
+
self.grid[i] = torch.stack((xv, yv), 2).view(1, self.na, ny, nx, 2).float()
|
80 |
+
if self.inplace:
|
81 |
+
y[..., 0:2] = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
|
82 |
+
y[..., 2:4] = torch.exp(y[..., 2:4]) * self.stride[i] # wh
|
83 |
+
else:
|
84 |
+
xy = (y[..., 0:2] + self.grid[i]) * self.stride[i] # xy
|
85 |
+
wh = torch.exp(y[..., 2:4]) * self.stride[i] # wh
|
86 |
+
y = torch.cat((xy, wh, y[..., 4:]), -1)
|
87 |
+
z.append(y.view(bs, -1, self.no))
|
88 |
+
return x if self.training else torch.cat(z, 1)
|
89 |
+
|
90 |
+
|
91 |
+
def build_effidehead_layer(channels_list, num_anchors, num_classes):
|
92 |
+
head_layers = nn.Sequential(
|
93 |
+
# stem0
|
94 |
+
Conv(
|
95 |
+
in_channels=channels_list[6],
|
96 |
+
out_channels=channels_list[6],
|
97 |
+
kernel_size=1,
|
98 |
+
stride=1
|
99 |
+
),
|
100 |
+
# cls_conv0
|
101 |
+
Conv(
|
102 |
+
in_channels=channels_list[6],
|
103 |
+
out_channels=channels_list[6],
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1
|
106 |
+
),
|
107 |
+
# reg_conv0
|
108 |
+
Conv(
|
109 |
+
in_channels=channels_list[6],
|
110 |
+
out_channels=channels_list[6],
|
111 |
+
kernel_size=3,
|
112 |
+
stride=1
|
113 |
+
),
|
114 |
+
# cls_pred0
|
115 |
+
nn.Conv2d(
|
116 |
+
in_channels=channels_list[6],
|
117 |
+
out_channels=num_classes * num_anchors,
|
118 |
+
kernel_size=1
|
119 |
+
),
|
120 |
+
# reg_pred0
|
121 |
+
nn.Conv2d(
|
122 |
+
in_channels=channels_list[6],
|
123 |
+
out_channels=4 * num_anchors,
|
124 |
+
kernel_size=1
|
125 |
+
),
|
126 |
+
# obj_pred0
|
127 |
+
nn.Conv2d(
|
128 |
+
in_channels=channels_list[6],
|
129 |
+
out_channels=1 * num_anchors,
|
130 |
+
kernel_size=1
|
131 |
+
),
|
132 |
+
# stem1
|
133 |
+
Conv(
|
134 |
+
in_channels=channels_list[8],
|
135 |
+
out_channels=channels_list[8],
|
136 |
+
kernel_size=1,
|
137 |
+
stride=1
|
138 |
+
),
|
139 |
+
# cls_conv1
|
140 |
+
Conv(
|
141 |
+
in_channels=channels_list[8],
|
142 |
+
out_channels=channels_list[8],
|
143 |
+
kernel_size=3,
|
144 |
+
stride=1
|
145 |
+
),
|
146 |
+
# reg_conv1
|
147 |
+
Conv(
|
148 |
+
in_channels=channels_list[8],
|
149 |
+
out_channels=channels_list[8],
|
150 |
+
kernel_size=3,
|
151 |
+
stride=1
|
152 |
+
),
|
153 |
+
# cls_pred1
|
154 |
+
nn.Conv2d(
|
155 |
+
in_channels=channels_list[8],
|
156 |
+
out_channels=num_classes * num_anchors,
|
157 |
+
kernel_size=1
|
158 |
+
),
|
159 |
+
# reg_pred1
|
160 |
+
nn.Conv2d(
|
161 |
+
in_channels=channels_list[8],
|
162 |
+
out_channels=4 * num_anchors,
|
163 |
+
kernel_size=1
|
164 |
+
),
|
165 |
+
# obj_pred1
|
166 |
+
nn.Conv2d(
|
167 |
+
in_channels=channels_list[8],
|
168 |
+
out_channels=1 * num_anchors,
|
169 |
+
kernel_size=1
|
170 |
+
),
|
171 |
+
# stem2
|
172 |
+
Conv(
|
173 |
+
in_channels=channels_list[10],
|
174 |
+
out_channels=channels_list[10],
|
175 |
+
kernel_size=1,
|
176 |
+
stride=1
|
177 |
+
),
|
178 |
+
# cls_conv2
|
179 |
+
Conv(
|
180 |
+
in_channels=channels_list[10],
|
181 |
+
out_channels=channels_list[10],
|
182 |
+
kernel_size=3,
|
183 |
+
stride=1
|
184 |
+
),
|
185 |
+
# reg_conv2
|
186 |
+
Conv(
|
187 |
+
in_channels=channels_list[10],
|
188 |
+
out_channels=channels_list[10],
|
189 |
+
kernel_size=3,
|
190 |
+
stride=1
|
191 |
+
),
|
192 |
+
# cls_pred2
|
193 |
+
nn.Conv2d(
|
194 |
+
in_channels=channels_list[10],
|
195 |
+
out_channels=num_classes * num_anchors,
|
196 |
+
kernel_size=1
|
197 |
+
),
|
198 |
+
# reg_pred2
|
199 |
+
nn.Conv2d(
|
200 |
+
in_channels=channels_list[10],
|
201 |
+
out_channels=4 * num_anchors,
|
202 |
+
kernel_size=1
|
203 |
+
),
|
204 |
+
# obj_pred2
|
205 |
+
nn.Conv2d(
|
206 |
+
in_channels=channels_list[10],
|
207 |
+
out_channels=1 * num_anchors,
|
208 |
+
kernel_size=1
|
209 |
+
)
|
210 |
+
)
|
211 |
+
return head_layers
|
yolov6/models/end2end.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import random
|
4 |
+
|
5 |
+
class ORT_NMS(torch.autograd.Function):
|
6 |
+
|
7 |
+
@staticmethod
|
8 |
+
def forward(ctx,
|
9 |
+
boxes,
|
10 |
+
scores,
|
11 |
+
max_output_boxes_per_class=torch.tensor([100]),
|
12 |
+
iou_threshold=torch.tensor([0.45]),
|
13 |
+
score_threshold=torch.tensor([0.25])):
|
14 |
+
device = boxes.device
|
15 |
+
batch = scores.shape[0]
|
16 |
+
num_det = random.randint(0, 100)
|
17 |
+
batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
|
18 |
+
idxs = torch.arange(100, 100 + num_det).to(device)
|
19 |
+
zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
|
20 |
+
selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
|
21 |
+
selected_indices = selected_indices.to(torch.int64)
|
22 |
+
return selected_indices
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
|
26 |
+
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
|
27 |
+
|
28 |
+
class TRT_NMS(torch.autograd.Function):
|
29 |
+
@staticmethod
|
30 |
+
def forward(
|
31 |
+
ctx,
|
32 |
+
boxes,
|
33 |
+
scores,
|
34 |
+
background_class=-1,
|
35 |
+
box_coding=1,
|
36 |
+
iou_threshold=0.45,
|
37 |
+
max_output_boxes=100,
|
38 |
+
plugin_version="1",
|
39 |
+
score_activation=0,
|
40 |
+
score_threshold=0.25,
|
41 |
+
):
|
42 |
+
batch_size, num_boxes, num_classes = scores.shape
|
43 |
+
num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
|
44 |
+
det_boxes = torch.randn(batch_size, max_output_boxes, 4)
|
45 |
+
det_scores = torch.randn(batch_size, max_output_boxes)
|
46 |
+
det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
|
47 |
+
|
48 |
+
return num_det, det_boxes, det_scores, det_classes
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def symbolic(g,
|
52 |
+
boxes,
|
53 |
+
scores,
|
54 |
+
background_class=-1,
|
55 |
+
box_coding=1,
|
56 |
+
iou_threshold=0.45,
|
57 |
+
max_output_boxes=100,
|
58 |
+
plugin_version="1",
|
59 |
+
score_activation=0,
|
60 |
+
score_threshold=0.25):
|
61 |
+
out = g.op("TRT::EfficientNMS_TRT",
|
62 |
+
boxes,
|
63 |
+
scores,
|
64 |
+
background_class_i=background_class,
|
65 |
+
box_coding_i=box_coding,
|
66 |
+
iou_threshold_f=iou_threshold,
|
67 |
+
max_output_boxes_i=max_output_boxes,
|
68 |
+
plugin_version_s=plugin_version,
|
69 |
+
score_activation_i=score_activation,
|
70 |
+
score_threshold_f=score_threshold,
|
71 |
+
outputs=4)
|
72 |
+
nums, boxes, scores, classes = out
|
73 |
+
return nums,boxes,scores,classes
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
class ONNX_ORT(nn.Module):
|
78 |
+
|
79 |
+
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
|
80 |
+
super().__init__()
|
81 |
+
self.device = device if device else torch.device("cpu")
|
82 |
+
self.max_obj = torch.tensor([max_obj]).to(device)
|
83 |
+
self.iou_threshold = torch.tensor([iou_thres]).to(device)
|
84 |
+
self.score_threshold = torch.tensor([score_thres]).to(device)
|
85 |
+
self.max_wh = max_wh
|
86 |
+
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
|
87 |
+
dtype=torch.float32,
|
88 |
+
device=self.device)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
box = x[:, :, :4]
|
92 |
+
conf = x[:, :, 4:5]
|
93 |
+
score = x[:, :, 5:]
|
94 |
+
score *= conf
|
95 |
+
box @= self.convert_matrix
|
96 |
+
objScore, objCls = score.max(2, keepdim=True)
|
97 |
+
dis = objCls.float() * self.max_wh
|
98 |
+
nmsbox = box + dis
|
99 |
+
objScore1 = objScore.transpose(1, 2).contiguous()
|
100 |
+
selected_indices = ORT_NMS.apply(nmsbox, objScore1, self.max_obj, self.iou_threshold, self.score_threshold)
|
101 |
+
X, Y = selected_indices[:, 0], selected_indices[:, 2]
|
102 |
+
resBoxes = box[X, Y, :]
|
103 |
+
resClasses = objCls[X, Y, :].float()
|
104 |
+
resScores = objScore[X, Y, :]
|
105 |
+
X = X.unsqueeze(1).float()
|
106 |
+
return torch.cat([X, resBoxes, resClasses, resScores], 1)
|
107 |
+
|
108 |
+
class ONNX_TRT(nn.Module):
|
109 |
+
|
110 |
+
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None):
|
111 |
+
super().__init__()
|
112 |
+
assert max_wh is None
|
113 |
+
self.device = device if device else torch.device('cpu')
|
114 |
+
self.background_class = -1,
|
115 |
+
self.box_coding = 1,
|
116 |
+
self.iou_threshold = iou_thres
|
117 |
+
self.max_obj = max_obj
|
118 |
+
self.plugin_version = '1'
|
119 |
+
self.score_activation = 0
|
120 |
+
self.score_threshold = score_thres
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
box = x[:, :, :4]
|
124 |
+
conf = x[:, :, 4:5]
|
125 |
+
score = x[:, :, 5:]
|
126 |
+
score *= conf
|
127 |
+
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(box, score, self.background_class, self.box_coding,
|
128 |
+
self.iou_threshold, self.max_obj,
|
129 |
+
self.plugin_version, self.score_activation,
|
130 |
+
self.score_threshold)
|
131 |
+
return num_det, det_boxes, det_scores, det_classes
|
132 |
+
|
133 |
+
|
134 |
+
class End2End(nn.Module):
|
135 |
+
|
136 |
+
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
|
137 |
+
super().__init__()
|
138 |
+
device = device if device else torch.device('cpu')
|
139 |
+
self.model = model.to(device)
|
140 |
+
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
|
141 |
+
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
|
142 |
+
self.end2end.eval()
|
143 |
+
|
144 |
+
def forward(self, x):
|
145 |
+
x = self.model(x)
|
146 |
+
x = self.end2end(x)
|
147 |
+
return x
|
yolov6/models/loss.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
# The code is based on
|
5 |
+
# https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/yolo_head.py
|
6 |
+
# Copyright (c) Megvii, Inc. and its affiliates.
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import numpy as np
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from yolov6.utils.figure_iou import IOUloss, pairwise_bbox_iou
|
13 |
+
|
14 |
+
|
15 |
+
class ComputeLoss:
|
16 |
+
'''Loss computation func.
|
17 |
+
This func contains SimOTA and siou loss.
|
18 |
+
'''
|
19 |
+
def __init__(self,
|
20 |
+
reg_weight=5.0,
|
21 |
+
iou_weight=3.0,
|
22 |
+
cls_weight=1.0,
|
23 |
+
center_radius=2.5,
|
24 |
+
eps=1e-7,
|
25 |
+
in_channels=[256, 512, 1024],
|
26 |
+
strides=[8, 16, 32],
|
27 |
+
n_anchors=1,
|
28 |
+
iou_type='ciou'
|
29 |
+
):
|
30 |
+
|
31 |
+
self.reg_weight = reg_weight
|
32 |
+
self.iou_weight = iou_weight
|
33 |
+
self.cls_weight = cls_weight
|
34 |
+
|
35 |
+
self.center_radius = center_radius
|
36 |
+
self.eps = eps
|
37 |
+
self.n_anchors = n_anchors
|
38 |
+
self.strides = strides
|
39 |
+
self.grids = [torch.zeros(1)] * len(in_channels)
|
40 |
+
|
41 |
+
# Define criteria
|
42 |
+
self.l1_loss = nn.L1Loss(reduction="none")
|
43 |
+
self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
|
44 |
+
self.iou_loss = IOUloss(iou_type=iou_type, reduction="none")
|
45 |
+
|
46 |
+
def __call__(
|
47 |
+
self,
|
48 |
+
outputs,
|
49 |
+
targets
|
50 |
+
):
|
51 |
+
dtype = outputs[0].type()
|
52 |
+
device = targets.device
|
53 |
+
loss_cls, loss_obj, loss_iou, loss_l1 = torch.zeros(1, device=device), torch.zeros(1, device=device), \
|
54 |
+
torch.zeros(1, device=device), torch.zeros(1, device=device)
|
55 |
+
num_classes = outputs[0].shape[-1] - 5
|
56 |
+
|
57 |
+
outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides = self.get_outputs_and_grids(
|
58 |
+
outputs, self.strides, dtype, device)
|
59 |
+
|
60 |
+
total_num_anchors = outputs.shape[1]
|
61 |
+
bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
|
62 |
+
bbox_preds_org = outputs_origin[:, :, :4] # [batch, n_anchors_all, 4]
|
63 |
+
obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
|
64 |
+
cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
|
65 |
+
|
66 |
+
# targets
|
67 |
+
batch_size = bbox_preds.shape[0]
|
68 |
+
targets_list = np.zeros((batch_size, 1, 5)).tolist()
|
69 |
+
for i, item in enumerate(targets.cpu().numpy().tolist()):
|
70 |
+
targets_list[int(item[0])].append(item[1:])
|
71 |
+
max_len = max((len(l) for l in targets_list))
|
72 |
+
|
73 |
+
targets = torch.from_numpy(np.array(list(map(lambda l:l + [[-1,0,0,0,0]]*(max_len - len(l)), targets_list)))[:,1:,:]).to(targets.device)
|
74 |
+
num_targets_list = (targets.sum(dim=2) > 0).sum(dim=1) # number of objects
|
75 |
+
|
76 |
+
num_fg, num_gts = 0, 0
|
77 |
+
cls_targets, reg_targets, l1_targets, obj_targets, fg_masks = [], [], [], [], []
|
78 |
+
|
79 |
+
for batch_idx in range(batch_size):
|
80 |
+
num_gt = int(num_targets_list[batch_idx])
|
81 |
+
num_gts += num_gt
|
82 |
+
if num_gt == 0:
|
83 |
+
cls_target = outputs.new_zeros((0, num_classes))
|
84 |
+
reg_target = outputs.new_zeros((0, 4))
|
85 |
+
l1_target = outputs.new_zeros((0, 4))
|
86 |
+
obj_target = outputs.new_zeros((total_num_anchors, 1))
|
87 |
+
fg_mask = outputs.new_zeros(total_num_anchors).bool()
|
88 |
+
else:
|
89 |
+
|
90 |
+
gt_bboxes_per_image = targets[batch_idx, :num_gt, 1:5].mul_(gt_bboxes_scale)
|
91 |
+
gt_classes = targets[batch_idx, :num_gt, 0]
|
92 |
+
bboxes_preds_per_image = bbox_preds[batch_idx]
|
93 |
+
cls_preds_per_image = cls_preds[batch_idx]
|
94 |
+
obj_preds_per_image = obj_preds[batch_idx]
|
95 |
+
|
96 |
+
try:
|
97 |
+
(
|
98 |
+
gt_matched_classes,
|
99 |
+
fg_mask,
|
100 |
+
pred_ious_this_matching,
|
101 |
+
matched_gt_inds,
|
102 |
+
num_fg_img,
|
103 |
+
) = self.get_assignments(
|
104 |
+
batch_idx,
|
105 |
+
num_gt,
|
106 |
+
total_num_anchors,
|
107 |
+
gt_bboxes_per_image,
|
108 |
+
gt_classes,
|
109 |
+
bboxes_preds_per_image,
|
110 |
+
cls_preds_per_image,
|
111 |
+
obj_preds_per_image,
|
112 |
+
expanded_strides,
|
113 |
+
xy_shifts,
|
114 |
+
num_classes
|
115 |
+
)
|
116 |
+
|
117 |
+
except RuntimeError:
|
118 |
+
print(
|
119 |
+
"OOM RuntimeError is raised due to the huge memory cost during label assignment. \
|
120 |
+
CPU mode is applied in this batch. If you want to avoid this issue, \
|
121 |
+
try to reduce the batch size or image size."
|
122 |
+
)
|
123 |
+
torch.cuda.empty_cache()
|
124 |
+
print("------------CPU Mode for This Batch-------------")
|
125 |
+
|
126 |
+
_gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
|
127 |
+
_gt_classes = gt_classes.cpu().float()
|
128 |
+
_bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
|
129 |
+
_cls_preds_per_image = cls_preds_per_image.cpu().float()
|
130 |
+
_obj_preds_per_image = obj_preds_per_image.cpu().float()
|
131 |
+
|
132 |
+
_expanded_strides = expanded_strides.cpu().float()
|
133 |
+
_xy_shifts = xy_shifts.cpu()
|
134 |
+
|
135 |
+
(
|
136 |
+
gt_matched_classes,
|
137 |
+
fg_mask,
|
138 |
+
pred_ious_this_matching,
|
139 |
+
matched_gt_inds,
|
140 |
+
num_fg_img,
|
141 |
+
) = self.get_assignments(
|
142 |
+
batch_idx,
|
143 |
+
num_gt,
|
144 |
+
total_num_anchors,
|
145 |
+
_gt_bboxes_per_image,
|
146 |
+
_gt_classes,
|
147 |
+
_bboxes_preds_per_image,
|
148 |
+
_cls_preds_per_image,
|
149 |
+
_obj_preds_per_image,
|
150 |
+
_expanded_strides,
|
151 |
+
_xy_shifts,
|
152 |
+
num_classes
|
153 |
+
)
|
154 |
+
|
155 |
+
gt_matched_classes = gt_matched_classes.cuda()
|
156 |
+
fg_mask = fg_mask.cuda()
|
157 |
+
pred_ious_this_matching = pred_ious_this_matching.cuda()
|
158 |
+
matched_gt_inds = matched_gt_inds.cuda()
|
159 |
+
|
160 |
+
torch.cuda.empty_cache()
|
161 |
+
num_fg += num_fg_img
|
162 |
+
if num_fg_img > 0:
|
163 |
+
cls_target = F.one_hot(
|
164 |
+
gt_matched_classes.to(torch.int64), num_classes
|
165 |
+
) * pred_ious_this_matching.unsqueeze(-1)
|
166 |
+
obj_target = fg_mask.unsqueeze(-1)
|
167 |
+
reg_target = gt_bboxes_per_image[matched_gt_inds]
|
168 |
+
|
169 |
+
l1_target = self.get_l1_target(
|
170 |
+
outputs.new_zeros((num_fg_img, 4)),
|
171 |
+
gt_bboxes_per_image[matched_gt_inds],
|
172 |
+
expanded_strides[0][fg_mask],
|
173 |
+
xy_shifts=xy_shifts[0][fg_mask],
|
174 |
+
)
|
175 |
+
|
176 |
+
cls_targets.append(cls_target)
|
177 |
+
reg_targets.append(reg_target)
|
178 |
+
obj_targets.append(obj_target)
|
179 |
+
l1_targets.append(l1_target)
|
180 |
+
fg_masks.append(fg_mask)
|
181 |
+
|
182 |
+
cls_targets = torch.cat(cls_targets, 0)
|
183 |
+
reg_targets = torch.cat(reg_targets, 0)
|
184 |
+
obj_targets = torch.cat(obj_targets, 0)
|
185 |
+
l1_targets = torch.cat(l1_targets, 0)
|
186 |
+
fg_masks = torch.cat(fg_masks, 0)
|
187 |
+
|
188 |
+
num_fg = max(num_fg, 1)
|
189 |
+
# loss
|
190 |
+
loss_iou += (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
|
191 |
+
loss_l1 += (self.l1_loss(bbox_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
|
192 |
+
|
193 |
+
loss_obj += (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets*1.0)).sum() / num_fg
|
194 |
+
loss_cls += (self.bcewithlog_loss(cls_preds.view(-1, num_classes)[fg_masks], cls_targets)).sum() / num_fg
|
195 |
+
|
196 |
+
total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls
|
197 |
+
return total_losses, torch.cat((self.reg_weight * loss_iou, loss_l1, loss_obj, loss_cls)).detach()
|
198 |
+
|
199 |
+
def decode_output(self, output, k, stride, dtype, device):
|
200 |
+
grid = self.grids[k].to(device)
|
201 |
+
batch_size = output.shape[0]
|
202 |
+
hsize, wsize = output.shape[2:4]
|
203 |
+
if grid.shape[2:4] != output.shape[2:4]:
|
204 |
+
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
|
205 |
+
grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype).to(device)
|
206 |
+
self.grids[k] = grid
|
207 |
+
|
208 |
+
output = output.reshape(batch_size, self.n_anchors * hsize * wsize, -1)
|
209 |
+
output_origin = output.clone()
|
210 |
+
grid = grid.view(1, -1, 2)
|
211 |
+
|
212 |
+
output[..., :2] = (output[..., :2] + grid) * stride
|
213 |
+
output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
|
214 |
+
|
215 |
+
return output, output_origin, grid, hsize, wsize
|
216 |
+
|
217 |
+
def get_outputs_and_grids(self, outputs, strides, dtype, device):
|
218 |
+
xy_shifts = []
|
219 |
+
expanded_strides = []
|
220 |
+
outputs_new = []
|
221 |
+
outputs_origin = []
|
222 |
+
|
223 |
+
for k, output in enumerate(outputs):
|
224 |
+
output, output_origin, grid, feat_h, feat_w = self.decode_output(
|
225 |
+
output, k, strides[k], dtype, device)
|
226 |
+
|
227 |
+
xy_shift = grid
|
228 |
+
expanded_stride = torch.full((1, grid.shape[1], 1), strides[k], dtype=grid.dtype, device=grid.device)
|
229 |
+
|
230 |
+
xy_shifts.append(xy_shift)
|
231 |
+
expanded_strides.append(expanded_stride)
|
232 |
+
outputs_new.append(output)
|
233 |
+
outputs_origin.append(output_origin)
|
234 |
+
|
235 |
+
xy_shifts = torch.cat(xy_shifts, 1) # [1, n_anchors_all, 2]
|
236 |
+
expanded_strides = torch.cat(expanded_strides, 1) # [1, n_anchors_all, 1]
|
237 |
+
outputs_origin = torch.cat(outputs_origin, 1)
|
238 |
+
outputs = torch.cat(outputs_new, 1)
|
239 |
+
|
240 |
+
feat_h *= strides[-1]
|
241 |
+
feat_w *= strides[-1]
|
242 |
+
gt_bboxes_scale = torch.Tensor([[feat_w, feat_h, feat_w, feat_h]]).type_as(outputs)
|
243 |
+
|
244 |
+
return outputs, outputs_origin, gt_bboxes_scale, xy_shifts, expanded_strides
|
245 |
+
|
246 |
+
def get_l1_target(self, l1_target, gt, stride, xy_shifts, eps=1e-8):
|
247 |
+
|
248 |
+
l1_target[:, 0:2] = gt[:, 0:2] / stride - xy_shifts
|
249 |
+
l1_target[:, 2:4] = torch.log(gt[:, 2:4] / stride + eps)
|
250 |
+
return l1_target
|
251 |
+
|
252 |
+
@torch.no_grad()
|
253 |
+
def get_assignments(
|
254 |
+
self,
|
255 |
+
batch_idx,
|
256 |
+
num_gt,
|
257 |
+
total_num_anchors,
|
258 |
+
gt_bboxes_per_image,
|
259 |
+
gt_classes,
|
260 |
+
bboxes_preds_per_image,
|
261 |
+
cls_preds_per_image,
|
262 |
+
obj_preds_per_image,
|
263 |
+
expanded_strides,
|
264 |
+
xy_shifts,
|
265 |
+
num_classes
|
266 |
+
):
|
267 |
+
|
268 |
+
fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
|
269 |
+
gt_bboxes_per_image,
|
270 |
+
expanded_strides,
|
271 |
+
xy_shifts,
|
272 |
+
total_num_anchors,
|
273 |
+
num_gt,
|
274 |
+
)
|
275 |
+
|
276 |
+
bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
|
277 |
+
cls_preds_ = cls_preds_per_image[fg_mask]
|
278 |
+
obj_preds_ = obj_preds_per_image[fg_mask]
|
279 |
+
num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
|
280 |
+
|
281 |
+
# cost
|
282 |
+
pair_wise_ious = pairwise_bbox_iou(gt_bboxes_per_image, bboxes_preds_per_image, box_format='xywh')
|
283 |
+
pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
|
284 |
+
|
285 |
+
gt_cls_per_image = (
|
286 |
+
F.one_hot(gt_classes.to(torch.int64), num_classes)
|
287 |
+
.float()
|
288 |
+
.unsqueeze(1)
|
289 |
+
.repeat(1, num_in_boxes_anchor, 1)
|
290 |
+
)
|
291 |
+
|
292 |
+
with torch.cuda.amp.autocast(enabled=False):
|
293 |
+
cls_preds_ = (
|
294 |
+
cls_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
|
295 |
+
* obj_preds_.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
|
296 |
+
)
|
297 |
+
pair_wise_cls_loss = F.binary_cross_entropy(
|
298 |
+
cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
|
299 |
+
).sum(-1)
|
300 |
+
del cls_preds_, obj_preds_
|
301 |
+
|
302 |
+
cost = (
|
303 |
+
self.cls_weight * pair_wise_cls_loss
|
304 |
+
+ self.iou_weight * pair_wise_ious_loss
|
305 |
+
+ 100000.0 * (~is_in_boxes_and_center)
|
306 |
+
)
|
307 |
+
|
308 |
+
(
|
309 |
+
num_fg,
|
310 |
+
gt_matched_classes,
|
311 |
+
pred_ious_this_matching,
|
312 |
+
matched_gt_inds,
|
313 |
+
) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
|
314 |
+
|
315 |
+
del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
|
316 |
+
|
317 |
+
return (
|
318 |
+
gt_matched_classes,
|
319 |
+
fg_mask,
|
320 |
+
pred_ious_this_matching,
|
321 |
+
matched_gt_inds,
|
322 |
+
num_fg,
|
323 |
+
)
|
324 |
+
|
325 |
+
def get_in_boxes_info(
|
326 |
+
self,
|
327 |
+
gt_bboxes_per_image,
|
328 |
+
expanded_strides,
|
329 |
+
xy_shifts,
|
330 |
+
total_num_anchors,
|
331 |
+
num_gt,
|
332 |
+
):
|
333 |
+
expanded_strides_per_image = expanded_strides[0]
|
334 |
+
xy_shifts_per_image = xy_shifts[0] * expanded_strides_per_image
|
335 |
+
xy_centers_per_image = (
|
336 |
+
(xy_shifts_per_image + 0.5 * expanded_strides_per_image)
|
337 |
+
.unsqueeze(0)
|
338 |
+
.repeat(num_gt, 1, 1)
|
339 |
+
) # [n_anchor, 2] -> [n_gt, n_anchor, 2]
|
340 |
+
|
341 |
+
gt_bboxes_per_image_lt = (
|
342 |
+
(gt_bboxes_per_image[:, 0:2] - 0.5 * gt_bboxes_per_image[:, 2:4])
|
343 |
+
.unsqueeze(1)
|
344 |
+
.repeat(1, total_num_anchors, 1)
|
345 |
+
)
|
346 |
+
gt_bboxes_per_image_rb = (
|
347 |
+
(gt_bboxes_per_image[:, 0:2] + 0.5 * gt_bboxes_per_image[:, 2:4])
|
348 |
+
.unsqueeze(1)
|
349 |
+
.repeat(1, total_num_anchors, 1)
|
350 |
+
) # [n_gt, 2] -> [n_gt, n_anchor, 2]
|
351 |
+
|
352 |
+
b_lt = xy_centers_per_image - gt_bboxes_per_image_lt
|
353 |
+
b_rb = gt_bboxes_per_image_rb - xy_centers_per_image
|
354 |
+
bbox_deltas = torch.cat([b_lt, b_rb], 2)
|
355 |
+
|
356 |
+
is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
|
357 |
+
is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
|
358 |
+
|
359 |
+
# in fixed center
|
360 |
+
gt_bboxes_per_image_lt = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
|
361 |
+
1, total_num_anchors, 1
|
362 |
+
) - self.center_radius * expanded_strides_per_image.unsqueeze(0)
|
363 |
+
gt_bboxes_per_image_rb = (gt_bboxes_per_image[:, 0:2]).unsqueeze(1).repeat(
|
364 |
+
1, total_num_anchors, 1
|
365 |
+
) + self.center_radius * expanded_strides_per_image.unsqueeze(0)
|
366 |
+
|
367 |
+
c_lt = xy_centers_per_image - gt_bboxes_per_image_lt
|
368 |
+
c_rb = gt_bboxes_per_image_rb - xy_centers_per_image
|
369 |
+
center_deltas = torch.cat([c_lt, c_rb], 2)
|
370 |
+
is_in_centers = center_deltas.min(dim=-1).values > 0.0
|
371 |
+
is_in_centers_all = is_in_centers.sum(dim=0) > 0
|
372 |
+
|
373 |
+
# in boxes and in centers
|
374 |
+
is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
|
375 |
+
|
376 |
+
is_in_boxes_and_center = (
|
377 |
+
is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
|
378 |
+
)
|
379 |
+
return is_in_boxes_anchor, is_in_boxes_and_center
|
380 |
+
|
381 |
+
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
|
382 |
+
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
|
383 |
+
ious_in_boxes_matrix = pair_wise_ious
|
384 |
+
n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
|
385 |
+
topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
|
386 |
+
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
|
387 |
+
dynamic_ks = dynamic_ks.tolist()
|
388 |
+
|
389 |
+
for gt_idx in range(num_gt):
|
390 |
+
_, pos_idx = torch.topk(
|
391 |
+
cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
|
392 |
+
)
|
393 |
+
matching_matrix[gt_idx][pos_idx] = 1
|
394 |
+
del topk_ious, dynamic_ks, pos_idx
|
395 |
+
|
396 |
+
anchor_matching_gt = matching_matrix.sum(0)
|
397 |
+
if (anchor_matching_gt > 1).sum() > 0:
|
398 |
+
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
399 |
+
matching_matrix[:, anchor_matching_gt > 1] *= 0
|
400 |
+
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
|
401 |
+
fg_mask_inboxes = matching_matrix.sum(0) > 0
|
402 |
+
num_fg = fg_mask_inboxes.sum().item()
|
403 |
+
fg_mask[fg_mask.clone()] = fg_mask_inboxes
|
404 |
+
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
405 |
+
gt_matched_classes = gt_classes[matched_gt_inds]
|
406 |
+
|
407 |
+
pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
|
408 |
+
fg_mask_inboxes
|
409 |
+
]
|
410 |
+
|
411 |
+
return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
|
yolov6/models/reppan.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from yolov6.layers.common import RepBlock, SimConv, Transpose
|
4 |
+
|
5 |
+
|
6 |
+
class RepPANNeck(nn.Module):
|
7 |
+
"""RepPANNeck Module
|
8 |
+
EfficientRep is the default backbone of this model.
|
9 |
+
RepPANNeck has the balance of feature fusion ability and hardware efficiency.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
channels_list=None,
|
15 |
+
num_repeats=None
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
assert channels_list is not None
|
20 |
+
assert num_repeats is not None
|
21 |
+
|
22 |
+
self.Rep_p4 = RepBlock(
|
23 |
+
in_channels=channels_list[3] + channels_list[5],
|
24 |
+
out_channels=channels_list[5],
|
25 |
+
n=num_repeats[5],
|
26 |
+
)
|
27 |
+
|
28 |
+
self.Rep_p3 = RepBlock(
|
29 |
+
in_channels=channels_list[2] + channels_list[6],
|
30 |
+
out_channels=channels_list[6],
|
31 |
+
n=num_repeats[6]
|
32 |
+
)
|
33 |
+
|
34 |
+
self.Rep_n3 = RepBlock(
|
35 |
+
in_channels=channels_list[6] + channels_list[7],
|
36 |
+
out_channels=channels_list[8],
|
37 |
+
n=num_repeats[7],
|
38 |
+
)
|
39 |
+
|
40 |
+
self.Rep_n4 = RepBlock(
|
41 |
+
in_channels=channels_list[5] + channels_list[9],
|
42 |
+
out_channels=channels_list[10],
|
43 |
+
n=num_repeats[8]
|
44 |
+
)
|
45 |
+
|
46 |
+
self.reduce_layer0 = SimConv(
|
47 |
+
in_channels=channels_list[4],
|
48 |
+
out_channels=channels_list[5],
|
49 |
+
kernel_size=1,
|
50 |
+
stride=1
|
51 |
+
)
|
52 |
+
|
53 |
+
self.upsample0 = Transpose(
|
54 |
+
in_channels=channels_list[5],
|
55 |
+
out_channels=channels_list[5],
|
56 |
+
)
|
57 |
+
|
58 |
+
self.reduce_layer1 = SimConv(
|
59 |
+
in_channels=channels_list[5],
|
60 |
+
out_channels=channels_list[6],
|
61 |
+
kernel_size=1,
|
62 |
+
stride=1
|
63 |
+
)
|
64 |
+
|
65 |
+
self.upsample1 = Transpose(
|
66 |
+
in_channels=channels_list[6],
|
67 |
+
out_channels=channels_list[6]
|
68 |
+
)
|
69 |
+
|
70 |
+
self.downsample2 = SimConv(
|
71 |
+
in_channels=channels_list[6],
|
72 |
+
out_channels=channels_list[7],
|
73 |
+
kernel_size=3,
|
74 |
+
stride=2
|
75 |
+
)
|
76 |
+
|
77 |
+
self.downsample1 = SimConv(
|
78 |
+
in_channels=channels_list[8],
|
79 |
+
out_channels=channels_list[9],
|
80 |
+
kernel_size=3,
|
81 |
+
stride=2
|
82 |
+
)
|
83 |
+
|
84 |
+
def forward(self, input):
|
85 |
+
|
86 |
+
(x2, x1, x0) = input
|
87 |
+
|
88 |
+
fpn_out0 = self.reduce_layer0(x0)
|
89 |
+
upsample_feat0 = self.upsample0(fpn_out0)
|
90 |
+
f_concat_layer0 = torch.cat([upsample_feat0, x1], 1)
|
91 |
+
f_out0 = self.Rep_p4(f_concat_layer0)
|
92 |
+
|
93 |
+
fpn_out1 = self.reduce_layer1(f_out0)
|
94 |
+
upsample_feat1 = self.upsample1(fpn_out1)
|
95 |
+
f_concat_layer1 = torch.cat([upsample_feat1, x2], 1)
|
96 |
+
pan_out2 = self.Rep_p3(f_concat_layer1)
|
97 |
+
|
98 |
+
down_feat1 = self.downsample2(pan_out2)
|
99 |
+
p_concat_layer1 = torch.cat([down_feat1, fpn_out1], 1)
|
100 |
+
pan_out1 = self.Rep_n3(p_concat_layer1)
|
101 |
+
|
102 |
+
down_feat0 = self.downsample1(pan_out1)
|
103 |
+
p_concat_layer2 = torch.cat([down_feat0, fpn_out0], 1)
|
104 |
+
pan_out0 = self.Rep_n4(p_concat_layer2)
|
105 |
+
|
106 |
+
outputs = [pan_out2, pan_out1, pan_out0]
|
107 |
+
|
108 |
+
return outputs
|
yolov6/models/yolo.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import math
|
4 |
+
import torch.nn as nn
|
5 |
+
from yolov6.layers.common import *
|
6 |
+
from yolov6.utils.torch_utils import initialize_weights
|
7 |
+
from yolov6.models.efficientrep import EfficientRep
|
8 |
+
from yolov6.models.reppan import RepPANNeck
|
9 |
+
from yolov6.models.effidehead import Detect, build_effidehead_layer
|
10 |
+
|
11 |
+
|
12 |
+
class Model(nn.Module):
|
13 |
+
'''YOLOv6 model with backbone, neck and head.
|
14 |
+
The default parts are EfficientRep Backbone, Rep-PAN and
|
15 |
+
Efficient Decoupled Head.
|
16 |
+
'''
|
17 |
+
def __init__(self, config, channels=3, num_classes=None, anchors=None): # model, input channels, number of classes
|
18 |
+
super().__init__()
|
19 |
+
# Build network
|
20 |
+
num_layers = config.model.head.num_layers
|
21 |
+
self.backbone, self.neck, self.detect = build_network(config, channels, num_classes, anchors, num_layers)
|
22 |
+
|
23 |
+
# Init Detect head
|
24 |
+
begin_indices = config.model.head.begin_indices
|
25 |
+
out_indices_head = config.model.head.out_indices
|
26 |
+
self.stride = self.detect.stride
|
27 |
+
self.detect.i = begin_indices
|
28 |
+
self.detect.f = out_indices_head
|
29 |
+
self.detect.initialize_biases()
|
30 |
+
|
31 |
+
# Init weights
|
32 |
+
initialize_weights(self)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = self.backbone(x)
|
36 |
+
x = self.neck(x)
|
37 |
+
x = self.detect(x)
|
38 |
+
return x
|
39 |
+
|
40 |
+
def _apply(self, fn):
|
41 |
+
self = super()._apply(fn)
|
42 |
+
self.detect.stride = fn(self.detect.stride)
|
43 |
+
self.detect.grid = list(map(fn, self.detect.grid))
|
44 |
+
return self
|
45 |
+
|
46 |
+
|
47 |
+
def make_divisible(x, divisor):
|
48 |
+
# Upward revision the value x to make it evenly divisible by the divisor.
|
49 |
+
return math.ceil(x / divisor) * divisor
|
50 |
+
|
51 |
+
|
52 |
+
def build_network(config, channels, num_classes, anchors, num_layers):
|
53 |
+
depth_mul = config.model.depth_multiple
|
54 |
+
width_mul = config.model.width_multiple
|
55 |
+
num_repeat_backbone = config.model.backbone.num_repeats
|
56 |
+
channels_list_backbone = config.model.backbone.out_channels
|
57 |
+
num_repeat_neck = config.model.neck.num_repeats
|
58 |
+
channels_list_neck = config.model.neck.out_channels
|
59 |
+
num_anchors = config.model.head.anchors
|
60 |
+
num_repeat = [(max(round(i * depth_mul), 1) if i > 1 else i) for i in (num_repeat_backbone + num_repeat_neck)]
|
61 |
+
channels_list = [make_divisible(i * width_mul, 8) for i in (channels_list_backbone + channels_list_neck)]
|
62 |
+
|
63 |
+
backbone = EfficientRep(
|
64 |
+
in_channels=channels,
|
65 |
+
channels_list=channels_list,
|
66 |
+
num_repeats=num_repeat
|
67 |
+
)
|
68 |
+
|
69 |
+
neck = RepPANNeck(
|
70 |
+
channels_list=channels_list,
|
71 |
+
num_repeats=num_repeat
|
72 |
+
)
|
73 |
+
|
74 |
+
head_layers = build_effidehead_layer(channels_list, num_anchors, num_classes)
|
75 |
+
|
76 |
+
head = Detect(num_classes, anchors, num_layers, head_layers=head_layers)
|
77 |
+
|
78 |
+
return backbone, neck, head
|
79 |
+
|
80 |
+
|
81 |
+
def build_model(cfg, num_classes, device):
|
82 |
+
model = Model(cfg, channels=3, num_classes=num_classes, anchors=cfg.model.head.anchors).to(device)
|
83 |
+
return model
|
yolov6/solver/build.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import os
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
|
10 |
+
def build_optimizer(cfg, model):
|
11 |
+
""" Build optimizer from cfg file."""
|
12 |
+
g_bnw, g_w, g_b = [], [], []
|
13 |
+
for v in model.modules():
|
14 |
+
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
|
15 |
+
g_b.append(v.bias)
|
16 |
+
if isinstance(v, nn.BatchNorm2d):
|
17 |
+
g_bnw.append(v.weight)
|
18 |
+
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
|
19 |
+
g_w.append(v.weight)
|
20 |
+
|
21 |
+
assert cfg.solver.optim == 'SGD' or 'Adam', 'ERROR: unknown optimizer, use SGD defaulted'
|
22 |
+
if cfg.solver.optim == 'SGD':
|
23 |
+
optimizer = torch.optim.SGD(g_bnw, lr=cfg.solver.lr0, momentum=cfg.solver.momentum, nesterov=True)
|
24 |
+
elif cfg.solver.optim == 'Adam':
|
25 |
+
optimizer = torch.optim.Adam(g_bnw, lr=cfg.solver.lr0, betas=(cfg.solver.momentum, 0.999))
|
26 |
+
|
27 |
+
optimizer.add_param_group({'params': g_w, 'weight_decay': cfg.solver.weight_decay})
|
28 |
+
optimizer.add_param_group({'params': g_b})
|
29 |
+
|
30 |
+
del g_bnw, g_w, g_b
|
31 |
+
return optimizer
|
32 |
+
|
33 |
+
|
34 |
+
def build_lr_scheduler(cfg, optimizer, epochs):
|
35 |
+
"""Build learning rate scheduler from cfg file."""
|
36 |
+
if cfg.solver.lr_scheduler == 'Cosine':
|
37 |
+
lf = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (cfg.solver.lrf - 1) + 1
|
38 |
+
else:
|
39 |
+
LOGGER.error('unknown lr scheduler, use Cosine defaulted')
|
40 |
+
|
41 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
|
42 |
+
return scheduler, lf
|
yolov6/utils/Arial.ttf
ADDED
Binary file (773 kB). View file
|
|
yolov6/utils/checkpoint.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import torch
|
6 |
+
import os.path as osp
|
7 |
+
from yolov6.utils.events import LOGGER
|
8 |
+
from yolov6.utils.torch_utils import fuse_model
|
9 |
+
|
10 |
+
|
11 |
+
def load_state_dict(weights, model, map_location=None):
|
12 |
+
"""Load weights from checkpoint file, only assign weights those layers' name and shape are match."""
|
13 |
+
ckpt = torch.load(weights, map_location=map_location)
|
14 |
+
state_dict = ckpt['model'].float().state_dict()
|
15 |
+
model_state_dict = model.state_dict()
|
16 |
+
state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
|
17 |
+
model.load_state_dict(state_dict, strict=False)
|
18 |
+
del ckpt, state_dict, model_state_dict
|
19 |
+
return model
|
20 |
+
|
21 |
+
|
22 |
+
def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
|
23 |
+
"""Load model from checkpoint file."""
|
24 |
+
LOGGER.info("Loading checkpoint from {}".format(weights))
|
25 |
+
ckpt = torch.load(weights, map_location=map_location) # load
|
26 |
+
model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
|
27 |
+
if fuse:
|
28 |
+
LOGGER.info("\nFusing model...")
|
29 |
+
model = fuse_model(model).eval()
|
30 |
+
else:
|
31 |
+
model = model.eval()
|
32 |
+
return model
|
33 |
+
|
34 |
+
|
35 |
+
def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
|
36 |
+
""" Save checkpoint to the disk."""
|
37 |
+
if not osp.exists(save_dir):
|
38 |
+
os.makedirs(save_dir)
|
39 |
+
filename = osp.join(save_dir, model_name + '.pt')
|
40 |
+
torch.save(ckpt, filename)
|
41 |
+
if is_best:
|
42 |
+
best_filename = osp.join(save_dir, 'best_ckpt.pt')
|
43 |
+
shutil.copyfile(filename, best_filename)
|
44 |
+
|
45 |
+
|
46 |
+
def strip_optimizer(ckpt_dir, epoch):
|
47 |
+
for s in ['best', 'last']:
|
48 |
+
ckpt_path = osp.join(ckpt_dir, '{}_ckpt.pt'.format(s))
|
49 |
+
if not osp.exists(ckpt_path):
|
50 |
+
continue
|
51 |
+
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
|
52 |
+
if ckpt.get('ema'):
|
53 |
+
ckpt['model'] = ckpt['ema'] # replace model with ema
|
54 |
+
for k in ['optimizer', 'ema', 'updates']: # keys
|
55 |
+
ckpt[k] = None
|
56 |
+
ckpt['epoch'] = epoch
|
57 |
+
ckpt['model'].half() # to FP16
|
58 |
+
for p in ckpt['model'].parameters():
|
59 |
+
p.requires_grad = False
|
60 |
+
torch.save(ckpt, ckpt_path)
|
yolov6/utils/config.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# The code is based on
|
4 |
+
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
|
5 |
+
# Copyright (c) OpenMMLab.
|
6 |
+
|
7 |
+
import os.path as osp
|
8 |
+
import shutil
|
9 |
+
import sys
|
10 |
+
import tempfile
|
11 |
+
from importlib import import_module
|
12 |
+
from addict import Dict
|
13 |
+
|
14 |
+
|
15 |
+
class ConfigDict(Dict):
|
16 |
+
|
17 |
+
def __missing__(self, name):
|
18 |
+
raise KeyError(name)
|
19 |
+
|
20 |
+
def __getattr__(self, name):
|
21 |
+
try:
|
22 |
+
value = super(ConfigDict, self).__getattr__(name)
|
23 |
+
except KeyError:
|
24 |
+
ex = AttributeError("'{}' object has no attribute '{}'".format(
|
25 |
+
self.__class__.__name__, name))
|
26 |
+
except Exception as e:
|
27 |
+
ex = e
|
28 |
+
else:
|
29 |
+
return value
|
30 |
+
raise ex
|
31 |
+
|
32 |
+
|
33 |
+
class Config(object):
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def _file2dict(filename):
|
37 |
+
filename = str(filename)
|
38 |
+
if filename.endswith('.py'):
|
39 |
+
with tempfile.TemporaryDirectory() as temp_config_dir:
|
40 |
+
shutil.copyfile(filename,
|
41 |
+
osp.join(temp_config_dir, '_tempconfig.py'))
|
42 |
+
sys.path.insert(0, temp_config_dir)
|
43 |
+
mod = import_module('_tempconfig')
|
44 |
+
sys.path.pop(0)
|
45 |
+
cfg_dict = {
|
46 |
+
name: value
|
47 |
+
for name, value in mod.__dict__.items()
|
48 |
+
if not name.startswith('__')
|
49 |
+
}
|
50 |
+
# delete imported module
|
51 |
+
del sys.modules['_tempconfig']
|
52 |
+
else:
|
53 |
+
raise IOError('Only .py type are supported now!')
|
54 |
+
cfg_text = filename + '\n'
|
55 |
+
with open(filename, 'r') as f:
|
56 |
+
cfg_text += f.read()
|
57 |
+
|
58 |
+
return cfg_dict, cfg_text
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def fromfile(filename):
|
62 |
+
cfg_dict, cfg_text = Config._file2dict(filename)
|
63 |
+
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
|
64 |
+
|
65 |
+
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
|
66 |
+
if cfg_dict is None:
|
67 |
+
cfg_dict = dict()
|
68 |
+
elif not isinstance(cfg_dict, dict):
|
69 |
+
raise TypeError('cfg_dict must be a dict, but got {}'.format(
|
70 |
+
type(cfg_dict)))
|
71 |
+
|
72 |
+
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
|
73 |
+
super(Config, self).__setattr__('_filename', filename)
|
74 |
+
if cfg_text:
|
75 |
+
text = cfg_text
|
76 |
+
elif filename:
|
77 |
+
with open(filename, 'r') as f:
|
78 |
+
text = f.read()
|
79 |
+
else:
|
80 |
+
text = ''
|
81 |
+
super(Config, self).__setattr__('_text', text)
|
82 |
+
|
83 |
+
@property
|
84 |
+
def filename(self):
|
85 |
+
return self._filename
|
86 |
+
|
87 |
+
@property
|
88 |
+
def text(self):
|
89 |
+
return self._text
|
90 |
+
|
91 |
+
def __repr__(self):
|
92 |
+
return 'Config (path: {}): {}'.format(self.filename,
|
93 |
+
self._cfg_dict.__repr__())
|
94 |
+
|
95 |
+
def __getattr__(self, name):
|
96 |
+
return getattr(self._cfg_dict, name)
|
97 |
+
|
98 |
+
def __setattr__(self, name, value):
|
99 |
+
if isinstance(value, dict):
|
100 |
+
value = ConfigDict(value)
|
101 |
+
self._cfg_dict.__setattr__(name, value)
|
yolov6/utils/ema.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# The code is based on
|
4 |
+
# https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py
|
5 |
+
import math
|
6 |
+
from copy import deepcopy
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
|
11 |
+
class ModelEMA:
|
12 |
+
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
13 |
+
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
14 |
+
This is intended to allow functionality like
|
15 |
+
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
16 |
+
A smoothed version of the weights is necessary for some training schemes to perform well.
|
17 |
+
This class is sensitive where it is initialized in the sequence of model init,
|
18 |
+
GPU assignment and distributed training wrappers.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, model, decay=0.9999, updates=0):
|
22 |
+
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA
|
23 |
+
self.updates = updates
|
24 |
+
self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
|
25 |
+
for param in self.ema.parameters():
|
26 |
+
param.requires_grad_(False)
|
27 |
+
|
28 |
+
def update(self, model):
|
29 |
+
with torch.no_grad():
|
30 |
+
self.updates += 1
|
31 |
+
decay = self.decay(self.updates)
|
32 |
+
|
33 |
+
state_dict = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict
|
34 |
+
for k, item in self.ema.state_dict().items():
|
35 |
+
if item.dtype.is_floating_point:
|
36 |
+
item *= decay
|
37 |
+
item += (1 - decay) * state_dict[k].detach()
|
38 |
+
|
39 |
+
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
40 |
+
copy_attr(self.ema, model, include, exclude)
|
41 |
+
|
42 |
+
|
43 |
+
def copy_attr(a, b, include=(), exclude=()):
|
44 |
+
"""Copy attributes from one instance and set them to another instance."""
|
45 |
+
for k, item in b.__dict__.items():
|
46 |
+
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
47 |
+
continue
|
48 |
+
else:
|
49 |
+
setattr(a, k, item)
|
50 |
+
|
51 |
+
|
52 |
+
def is_parallel(model):
|
53 |
+
# Return True if model's type is DP or DDP, else False.
|
54 |
+
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
55 |
+
|
56 |
+
|
57 |
+
def de_parallel(model):
|
58 |
+
# De-parallelize a model. Return single-GPU model if model's type is DP or DDP.
|
59 |
+
return model.module if is_parallel(model) else model
|
yolov6/utils/envs.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.backends.cudnn as cudnn
|
9 |
+
from yolov6.utils.events import LOGGER
|
10 |
+
|
11 |
+
|
12 |
+
def get_envs():
|
13 |
+
"""Get PyTorch needed environments from system envirionments."""
|
14 |
+
local_rank = int(os.getenv('LOCAL_RANK', -1))
|
15 |
+
rank = int(os.getenv('RANK', -1))
|
16 |
+
world_size = int(os.getenv('WORLD_SIZE', 1))
|
17 |
+
return local_rank, rank, world_size
|
18 |
+
|
19 |
+
|
20 |
+
def select_device(device):
|
21 |
+
"""Set devices' information to the program.
|
22 |
+
Args:
|
23 |
+
device: a string, like 'cpu' or '1,2,3,4'
|
24 |
+
Returns:
|
25 |
+
torch.device
|
26 |
+
"""
|
27 |
+
if device == 'cpu':
|
28 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
29 |
+
LOGGER.info('Using CPU for training... ')
|
30 |
+
elif device:
|
31 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
32 |
+
assert torch.cuda.is_available()
|
33 |
+
nd = len(device.strip().split(','))
|
34 |
+
LOGGER.info(f'Using {nd} GPU for training... ')
|
35 |
+
cuda = device != 'cpu' and torch.cuda.is_available()
|
36 |
+
device = torch.device('cuda:0' if cuda else 'cpu')
|
37 |
+
return device
|
38 |
+
|
39 |
+
|
40 |
+
def set_random_seed(seed, deterministic=False):
|
41 |
+
""" Set random state to random libray, numpy, torch and cudnn.
|
42 |
+
Args:
|
43 |
+
seed: int value.
|
44 |
+
deterministic: bool value.
|
45 |
+
"""
|
46 |
+
random.seed(seed)
|
47 |
+
np.random.seed(seed)
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
if deterministic:
|
50 |
+
cudnn.deterministic = True
|
51 |
+
cudnn.benchmark = False
|
52 |
+
else:
|
53 |
+
cudnn.deterministic = False
|
54 |
+
cudnn.benchmark = True
|
yolov6/utils/events.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
import yaml
|
5 |
+
import logging
|
6 |
+
import shutil
|
7 |
+
|
8 |
+
|
9 |
+
def set_logging(name=None):
|
10 |
+
rank = int(os.getenv('RANK', -1))
|
11 |
+
logging.basicConfig(format="%(message)s", level=logging.INFO if (rank in (-1, 0)) else logging.WARNING)
|
12 |
+
return logging.getLogger(name)
|
13 |
+
|
14 |
+
|
15 |
+
LOGGER = set_logging(__name__)
|
16 |
+
NCOLS = shutil.get_terminal_size().columns
|
17 |
+
|
18 |
+
|
19 |
+
def load_yaml(file_path):
|
20 |
+
"""Load data from yaml file."""
|
21 |
+
if isinstance(file_path, str):
|
22 |
+
with open(file_path, errors='ignore') as f:
|
23 |
+
data_dict = yaml.safe_load(f)
|
24 |
+
return data_dict
|
25 |
+
|
26 |
+
|
27 |
+
def save_yaml(data_dict, save_path):
|
28 |
+
"""Save data to yaml file"""
|
29 |
+
with open(save_path, 'w') as f:
|
30 |
+
yaml.safe_dump(data_dict, f, sort_keys=False)
|
31 |
+
|
32 |
+
|
33 |
+
def write_tblog(tblogger, epoch, results, losses):
|
34 |
+
"""Display mAP and loss information to log."""
|
35 |
+
tblogger.add_scalar("val/mAP@0.5", results[0], epoch + 1)
|
36 |
+
tblogger.add_scalar("val/mAP@0.50:0.95", results[1], epoch + 1)
|
37 |
+
|
38 |
+
tblogger.add_scalar("train/iou_loss", losses[0], epoch + 1)
|
39 |
+
tblogger.add_scalar("train/l1_loss", losses[1], epoch + 1)
|
40 |
+
tblogger.add_scalar("train/obj_loss", losses[2], epoch + 1)
|
41 |
+
tblogger.add_scalar("train/cls_loss", losses[3], epoch + 1)
|
yolov6/utils/figure_iou.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class IOUloss:
|
8 |
+
""" Calculate IoU loss.
|
9 |
+
"""
|
10 |
+
def __init__(self, box_format='xywh', iou_type='ciou', reduction='none', eps=1e-7):
|
11 |
+
""" Setting of the class.
|
12 |
+
Args:
|
13 |
+
box_format: (string), must be one of 'xywh' or 'xyxy'.
|
14 |
+
iou_type: (string), can be one of 'ciou', 'diou', 'giou' or 'siou'
|
15 |
+
reduction: (string), specifies the reduction to apply to the output, must be one of 'none', 'mean','sum'.
|
16 |
+
eps: (float), a value to avoid divide by zero error.
|
17 |
+
"""
|
18 |
+
self.box_format = box_format
|
19 |
+
self.iou_type = iou_type.lower()
|
20 |
+
self.reduction = reduction
|
21 |
+
self.eps = eps
|
22 |
+
|
23 |
+
def __call__(self, box1, box2):
|
24 |
+
""" calculate iou. box1 and box2 are torch tensor with shape [M, 4] and [Nm 4].
|
25 |
+
"""
|
26 |
+
box2 = box2.T
|
27 |
+
if self.box_format == 'xyxy':
|
28 |
+
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
29 |
+
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
30 |
+
elif self.box_format == 'xywh':
|
31 |
+
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
|
32 |
+
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
|
33 |
+
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
|
34 |
+
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
35 |
+
|
36 |
+
# Intersection area
|
37 |
+
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
38 |
+
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
39 |
+
|
40 |
+
# Union Area
|
41 |
+
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + self.eps
|
42 |
+
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + self.eps
|
43 |
+
union = w1 * h1 + w2 * h2 - inter + self.eps
|
44 |
+
iou = inter / union
|
45 |
+
|
46 |
+
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex width
|
47 |
+
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
48 |
+
if self.iou_type == 'giou':
|
49 |
+
c_area = cw * ch + self.eps # convex area
|
50 |
+
iou = iou - (c_area - union) / c_area
|
51 |
+
elif self.iou_type in ['diou', 'ciou']:
|
52 |
+
c2 = cw ** 2 + ch ** 2 + self.eps # convex diagonal squared
|
53 |
+
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
|
54 |
+
(b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
|
55 |
+
if self.iou_type == 'diou':
|
56 |
+
iou = iou - rho2 / c2
|
57 |
+
elif self.iou_type == 'ciou':
|
58 |
+
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
59 |
+
with torch.no_grad():
|
60 |
+
alpha = v / (v - iou + (1 + self.eps))
|
61 |
+
iou = iou - (rho2 / c2 + v * alpha)
|
62 |
+
elif self.iou_type == 'siou':
|
63 |
+
# SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
|
64 |
+
s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
|
65 |
+
s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
|
66 |
+
sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
|
67 |
+
sin_alpha_1 = torch.abs(s_cw) / sigma
|
68 |
+
sin_alpha_2 = torch.abs(s_ch) / sigma
|
69 |
+
threshold = pow(2, 0.5) / 2
|
70 |
+
sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
|
71 |
+
angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
|
72 |
+
rho_x = (s_cw / cw) ** 2
|
73 |
+
rho_y = (s_ch / ch) ** 2
|
74 |
+
gamma = angle_cost - 2
|
75 |
+
distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
|
76 |
+
omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
|
77 |
+
omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
|
78 |
+
shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
|
79 |
+
iou = iou - 0.5 * (distance_cost + shape_cost)
|
80 |
+
loss = 1.0 - iou
|
81 |
+
|
82 |
+
if self.reduction == 'sum':
|
83 |
+
loss = loss.sum()
|
84 |
+
elif self.reduction == 'mean':
|
85 |
+
loss = loss.mean()
|
86 |
+
|
87 |
+
return loss
|
88 |
+
|
89 |
+
|
90 |
+
def pairwise_bbox_iou(box1, box2, box_format='xywh'):
|
91 |
+
"""Calculate iou.
|
92 |
+
This code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/boxes.py
|
93 |
+
"""
|
94 |
+
if box_format == 'xyxy':
|
95 |
+
lt = torch.max(box1[:, None, :2], box2[:, :2])
|
96 |
+
rb = torch.min(box1[:, None, 2:], box2[:, 2:])
|
97 |
+
area_1 = torch.prod(box1[:, 2:] - box1[:, :2], 1)
|
98 |
+
area_2 = torch.prod(box2[:, 2:] - box2[:, :2], 1)
|
99 |
+
|
100 |
+
elif box_format == 'xywh':
|
101 |
+
lt = torch.max(
|
102 |
+
(box1[:, None, :2] - box1[:, None, 2:] / 2),
|
103 |
+
(box2[:, :2] - box2[:, 2:] / 2),
|
104 |
+
)
|
105 |
+
rb = torch.min(
|
106 |
+
(box1[:, None, :2] + box1[:, None, 2:] / 2),
|
107 |
+
(box2[:, :2] + box2[:, 2:] / 2),
|
108 |
+
)
|
109 |
+
|
110 |
+
area_1 = torch.prod(box1[:, 2:], 1)
|
111 |
+
area_2 = torch.prod(box2[:, 2:], 1)
|
112 |
+
valid = (lt < rb).type(lt.type()).prod(dim=2)
|
113 |
+
inter = torch.prod(rb - lt, 2) * valid
|
114 |
+
return inter / (area_1[:, None] + area_2 - inter)
|
yolov6/utils/general.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
def increment_name(path):
|
8 |
+
"increase save directory's id"
|
9 |
+
path = Path(path)
|
10 |
+
sep = ''
|
11 |
+
if path.exists():
|
12 |
+
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
13 |
+
for n in range(1, 9999):
|
14 |
+
p = f'{path}{sep}{n}{suffix}'
|
15 |
+
if not os.path.exists(p):
|
16 |
+
break
|
17 |
+
path = Path(p)
|
18 |
+
return path
|
19 |
+
|
20 |
+
|
21 |
+
def find_latest_checkpoint(search_dir='.'):
|
22 |
+
# Find the most recent saved checkpoint in search_dir
|
23 |
+
checkpoint_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
24 |
+
return max(checkpoint_list, key=os.path.getctime) if checkpoint_list else ''
|
yolov6/utils/nms.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# The code is based on
|
4 |
+
# https://github.com/ultralytics/yolov5/blob/master/utils/general.py
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import torchvision
|
12 |
+
|
13 |
+
|
14 |
+
# Settings
|
15 |
+
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
16 |
+
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
17 |
+
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
18 |
+
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads
|
19 |
+
|
20 |
+
|
21 |
+
def xywh2xyxy(x):
|
22 |
+
# Convert boxes with shape [n, 4] from [x, y, w, h] to [x1, y1, x2, y2] where x1y1 is top-left, x2y2=bottom-right
|
23 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
24 |
+
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
25 |
+
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
26 |
+
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
|
27 |
+
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
|
28 |
+
return y
|
29 |
+
|
30 |
+
|
31 |
+
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, max_det=300):
|
32 |
+
"""Runs Non-Maximum Suppression (NMS) on inference results.
|
33 |
+
This code is borrowed from: https://github.com/ultralytics/yolov5/blob/47233e1698b89fc437a4fb9463c815e9171be955/utils/general.py#L775
|
34 |
+
Args:
|
35 |
+
prediction: (tensor), with shape [N, 5 + num_classes], N is the number of bboxes.
|
36 |
+
conf_thres: (float) confidence threshold.
|
37 |
+
iou_thres: (float) iou threshold.
|
38 |
+
classes: (None or list[int]), if a list is provided, nms only keep the classes you provide.
|
39 |
+
agnostic: (bool), when it is set to True, we do class-independent nms, otherwise, different class would do nms respectively.
|
40 |
+
multi_label: (bool), when it is set to True, one box can have multi labels, otherwise, one box only huave one label.
|
41 |
+
max_det:(int), max number of output bboxes.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
list of detections, echo item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
|
45 |
+
"""
|
46 |
+
|
47 |
+
num_classes = prediction.shape[2] - 5 # number of classes
|
48 |
+
pred_candidates = prediction[..., 4] > conf_thres # candidates
|
49 |
+
|
50 |
+
# Check the parameters.
|
51 |
+
assert 0 <= conf_thres <= 1, f'conf_thresh must be in 0.0 to 1.0, however {conf_thres} is provided.'
|
52 |
+
assert 0 <= iou_thres <= 1, f'iou_thres must be in 0.0 to 1.0, however {iou_thres} is provided.'
|
53 |
+
|
54 |
+
# Function settings.
|
55 |
+
max_wh = 4096 # maximum box width and height
|
56 |
+
max_nms = 30000 # maximum number of boxes put into torchvision.ops.nms()
|
57 |
+
time_limit = 10.0 # quit the function when nms cost time exceed the limit time.
|
58 |
+
multi_label &= num_classes > 1 # multiple labels per box
|
59 |
+
|
60 |
+
tik = time.time()
|
61 |
+
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
|
62 |
+
for img_idx, x in enumerate(prediction): # image index, image inference
|
63 |
+
x = x[pred_candidates[img_idx]] # confidence
|
64 |
+
|
65 |
+
# If no box remains, skip the next process.
|
66 |
+
if not x.shape[0]:
|
67 |
+
continue
|
68 |
+
|
69 |
+
# confidence multiply the objectness
|
70 |
+
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
|
71 |
+
|
72 |
+
# (center x, center y, width, height) to (x1, y1, x2, y2)
|
73 |
+
box = xywh2xyxy(x[:, :4])
|
74 |
+
|
75 |
+
# Detections matrix's shape is (n,6), each row represents (xyxy, conf, cls)
|
76 |
+
if multi_label:
|
77 |
+
box_idx, class_idx = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
|
78 |
+
x = torch.cat((box[box_idx], x[box_idx, class_idx + 5, None], class_idx[:, None].float()), 1)
|
79 |
+
else: # Only keep the class with highest scores.
|
80 |
+
conf, class_idx = x[:, 5:].max(1, keepdim=True)
|
81 |
+
x = torch.cat((box, conf, class_idx.float()), 1)[conf.view(-1) > conf_thres]
|
82 |
+
|
83 |
+
# Filter by class, only keep boxes whose category is in classes.
|
84 |
+
if classes is not None:
|
85 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
86 |
+
|
87 |
+
# Check shape
|
88 |
+
num_box = x.shape[0] # number of boxes
|
89 |
+
if not num_box: # no boxes kept.
|
90 |
+
continue
|
91 |
+
elif num_box > max_nms: # excess max boxes' number.
|
92 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
93 |
+
|
94 |
+
# Batched NMS
|
95 |
+
class_offset = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
96 |
+
boxes, scores = x[:, :4] + class_offset, x[:, 4] # boxes (offset by class), scores
|
97 |
+
keep_box_idx = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
98 |
+
if keep_box_idx.shape[0] > max_det: # limit detections
|
99 |
+
keep_box_idx = keep_box_idx[:max_det]
|
100 |
+
|
101 |
+
output[img_idx] = x[keep_box_idx]
|
102 |
+
if (time.time() - tik) > time_limit:
|
103 |
+
print(f'WARNING: NMS cost time exceed the limited {time_limit}s.')
|
104 |
+
break # time limit exceeded
|
105 |
+
|
106 |
+
return output
|
yolov6/utils/torch_utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
|
4 |
+
import time
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from copy import deepcopy
|
7 |
+
import torch
|
8 |
+
import torch.distributed as dist
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from yolov6.utils.events import LOGGER
|
12 |
+
|
13 |
+
try:
|
14 |
+
import thop # for FLOPs computation
|
15 |
+
except ImportError:
|
16 |
+
thop = None
|
17 |
+
|
18 |
+
|
19 |
+
@contextmanager
|
20 |
+
def torch_distributed_zero_first(local_rank: int):
|
21 |
+
"""
|
22 |
+
Decorator to make all processes in distributed training wait for each local_master to do something.
|
23 |
+
"""
|
24 |
+
if local_rank not in [-1, 0]:
|
25 |
+
dist.barrier(device_ids=[local_rank])
|
26 |
+
yield
|
27 |
+
if local_rank == 0:
|
28 |
+
dist.barrier(device_ids=[0])
|
29 |
+
|
30 |
+
|
31 |
+
def time_sync():
|
32 |
+
# Waits for all kernels in all streams on a CUDA device to complete if cuda is available.
|
33 |
+
if torch.cuda.is_available():
|
34 |
+
torch.cuda.synchronize()
|
35 |
+
return time.time()
|
36 |
+
|
37 |
+
|
38 |
+
def initialize_weights(model):
|
39 |
+
for m in model.modules():
|
40 |
+
t = type(m)
|
41 |
+
if t is nn.Conv2d:
|
42 |
+
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
43 |
+
elif t is nn.BatchNorm2d:
|
44 |
+
m.eps = 1e-3
|
45 |
+
m.momentum = 0.03
|
46 |
+
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
47 |
+
m.inplace = True
|
48 |
+
|
49 |
+
|
50 |
+
def fuse_conv_and_bn(conv, bn):
|
51 |
+
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
52 |
+
fusedconv = (
|
53 |
+
nn.Conv2d(
|
54 |
+
conv.in_channels,
|
55 |
+
conv.out_channels,
|
56 |
+
kernel_size=conv.kernel_size,
|
57 |
+
stride=conv.stride,
|
58 |
+
padding=conv.padding,
|
59 |
+
groups=conv.groups,
|
60 |
+
bias=True,
|
61 |
+
)
|
62 |
+
.requires_grad_(False)
|
63 |
+
.to(conv.weight.device)
|
64 |
+
)
|
65 |
+
|
66 |
+
# prepare filters
|
67 |
+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
68 |
+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
69 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
70 |
+
|
71 |
+
# prepare spatial bias
|
72 |
+
b_conv = (
|
73 |
+
torch.zeros(conv.weight.size(0), device=conv.weight.device)
|
74 |
+
if conv.bias is None
|
75 |
+
else conv.bias
|
76 |
+
)
|
77 |
+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
|
78 |
+
torch.sqrt(bn.running_var + bn.eps)
|
79 |
+
)
|
80 |
+
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
81 |
+
|
82 |
+
return fusedconv
|
83 |
+
|
84 |
+
|
85 |
+
def fuse_model(model):
|
86 |
+
from yolov6.layers.common import Conv
|
87 |
+
|
88 |
+
for m in model.modules():
|
89 |
+
if type(m) is Conv and hasattr(m, "bn"):
|
90 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
91 |
+
delattr(m, "bn") # remove batchnorm
|
92 |
+
m.forward = m.forward_fuse # update forward
|
93 |
+
return model
|
94 |
+
|
95 |
+
|
96 |
+
def get_model_info(model, img_size=640):
|
97 |
+
"""Get model Params and GFlops.
|
98 |
+
Code base on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/model_utils.py
|
99 |
+
"""
|
100 |
+
from thop import profile
|
101 |
+
stride = 32
|
102 |
+
img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
|
103 |
+
|
104 |
+
flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
|
105 |
+
params /= 1e6
|
106 |
+
flops /= 1e9
|
107 |
+
img_size = img_size if isinstance(img_size, list) else [img_size, img_size]
|
108 |
+
flops *= img_size[0] * img_size[1] / stride / stride * 2 # Gflops
|
109 |
+
info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
|
110 |
+
return info
|