Spaces: luxmorocco
Commit 4efbc62 • 1 Parent(s): 09723cd
"Upload 86 files", committed by luxmorocco
This view is limited to 50 files because it contains too many changes. See raw diff.
- yolo-world-with-efficientvit-sam/.DS_Store +0 -0
- yolo-world-with-efficientvit-sam/.gitignore +171 -0
- yolo-world-with-efficientvit-sam/LICENSE +201 -0
- yolo-world-with-efficientvit-sam/Makefile +18 -0
- yolo-world-with-efficientvit-sam/README.md +68 -0
- yolo-world-with-efficientvit-sam/app.py +132 -0
- yolo-world-with-efficientvit-sam/efficientvit/__init__.py +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/__init__.py +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py +7 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py +6 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py +30 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py +84 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py +223 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py +7 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py +1603 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py +378 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py +94 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py +141 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py +6 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py +297 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py +121 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py +12 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc +0 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py +73 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py +50 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py +47 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py +68 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py +48 -0
- yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py +37 -0
yolo-world-with-efficientvit-sam/.DS_Store
ADDED
Binary file (6.15 kB)
yolo-world-with-efficientvit-sam/.gitignore
ADDED
@@ -0,0 +1,171 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Model Weights
*.pth
*.pt

# Yolo-World
work_dirs
src

# Etc
.DS_Store

yolo-world-with-efficientvit-sam/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

yolo-world-with-efficientvit-sam/Makefile
ADDED
@@ -0,0 +1,18 @@
EFFICIENTVIT_SAM_URL := "https://huggingface.co/han-cai/efficientvit-sam/resolve/main"
EFFICIENTVIT_SAM_MODEL := "xl1.pt"


define download
	@if [ ! -f $(2) ]; then \
		echo "Download $(2)..."; \
		wget "$(1)/$(2)"; \
	fi
endef


setup:
	pip install -r requirements.txt


model:
	$(call download,$(EFFICIENTVIT_SAM_URL),$(EFFICIENTVIT_SAM_MODEL))

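For reference, the `model` target above downloads the EfficientViT-SAM checkpoint `xl1.pt` only when it is not already on disk. The following is a minimal Python sketch of the same conditional download, assuming only the URL and filename taken from the Makefile; the function name and use of `urllib` are illustrative, not part of the repository.

```python
# Illustrative equivalent of `make model`: fetch xl1.pt only if it is missing.
import os
import urllib.request

EFFICIENTVIT_SAM_URL = "https://huggingface.co/han-cai/efficientvit-sam/resolve/main"
EFFICIENTVIT_SAM_MODEL = "xl1.pt"


def download_weights(url: str = EFFICIENTVIT_SAM_URL, filename: str = EFFICIENTVIT_SAM_MODEL) -> None:
    # Hypothetical helper; mirrors the Makefile's "download only if absent" behavior.
    if not os.path.isfile(filename):
        print(f"Download {filename}...")
        urllib.request.urlretrieve(f"{url}/{filename}", filename)


if __name__ == "__main__":
    download_weights()
```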
yolo-world-with-efficientvit-sam/README.md
ADDED
@@ -0,0 +1,68 @@
# YOLO-World + EfficientViT SAM

🤗 [HuggingFace Space](https://huggingface.co/spaces/curt-park/yolo-world-with-efficientvit-sam)

![example_0](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/326bde19-d535-4be5-829e-782fce0c1d00)

## Prerequisites
This project is developed and tested on Python3.10.

```bash
# Create and activate a python 3.10 environment.
conda create -n yolo-world-with-efficientvit-sam python=3.10 -y
conda activate yolo-world-with-efficientvit-sam
# Setup packages.
make setup
```

## How to Run
```bash
python app.py
```

Open http://127.0.0.1:7860/ on your web browser.

![example_1](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/9388e4ee-6f71-4428-b17c-d218fd059949)

## Core Components

### YOLO-World
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
which outperforms many state-of-the-art methods in terms of both accuracy and speed.
![image](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/8a4a17bd-918d-478a-8451-f58e4a2dce79)
<img width="1024" src="https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/fce57405-e18d-45f3-bea8-fc3971faf975">

### EfficientViT SAM
[EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated segment anything models.
Thanks to the lightweight and hardware-efficient core building block,
it delivers 48.9× measured TensorRT speedup on A100 GPU over SAM-ViT-H without sacrificing performance.

<img width="1024" src="https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/9eec003f-47c9-43a5-86b0-82d6689e1bf9">
<img width="1024" src="https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/d79973bb-0d80-4b64-a175-252de56d0d09">

## Powered By
```
@misc{zhang2024efficientvitsam,
      title={EfficientViT-SAM: Accelerated Segment Anything Model Without Performance Loss},
      author={Zhuoyang Zhang and Han Cai and Song Han},
      year={2024},
      eprint={2402.05008},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@article{cheng2024yolow,
  title={YOLO-World: Real-Time Open-Vocabulary Object Detection},
  author={Cheng, Tianheng and Song, Lin and Ge, Yixiao and Liu, Wenyu and Wang, Xinggang and Shan, Ying},
  journal={arXiv preprint arXiv:2401.17270},
  year={2024}
}

@article{cai2022efficientvit,
  title={Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition},
  author={Cai, Han and Gan, Chuang and Han, Song},
  journal={arXiv preprint arXiv:2205.14756},
  year={2022}
}
```

yolo-world-with-efficientvit-sam/app.py
ADDED
@@ -0,0 +1,132 @@
"""Fast text to segmentation with yolo-world and efficient-vit sam."""
import os

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from inference.models import YOLOWorld

from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
from efficientvit.sam_model_zoo import create_sam_model


# Download model weights.
os.system("make model")

# Load models.
yolo_world = YOLOWorld(model_id="yolo_world/l")
#yolo_world = YOLOWorld("/Users/tounsi/Desktop/DOCTORIA/Doctoria\ Full\ Software/Doctoria\ CXR/Doctoria\ CXR\ Thoracic\ Abnormalities/YOLOv8/CXR\ YOLOv8l.pt")
device = "cuda" if torch.cuda.is_available() else "cpu"
sam = EfficientViTSamPredictor(
    create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
)


# Load annotators.
BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()


def detect(
    image: np.ndarray,
    query: str,
    confidence_threshold: float,
    nms_threshold: float,
) -> np.ndarray:
    # Preparation.
    categories = [category.strip() for category in query.split(",")]
    yolo_world.set_classes(categories)
    print("categories:", categories)

    # Object detection.
    results = yolo_world.infer(image, confidence=confidence_threshold)
    detections = sv.Detections.from_inference(results).with_nms(
        class_agnostic=True, threshold=nms_threshold
    )
    print("detected:", detections)

    # Segmentation.
    sam.set_image(image, image_format="RGB")
    masks = []
    for xyxy in detections.xyxy:
        mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
        masks.append(mask.squeeze())
    detections.mask = np.array(masks)
    print("masks shaped as", detections.mask.shape)

    # Annotation.
    output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    labels = [
        f"{categories[class_id]}: {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]
    output_image = MASK_ANNOTATOR.annotate(output_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)


app = gr.Interface(
    fn=detect,
    inputs=[
        gr.Image(type="numpy", label="input image"),
        gr.Text(info="you can input multiple words with comma (,)"),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.3,
            step=0.01,
            interactive=True,
            label="Confidence Threshold",
        ),
        gr.Slider(
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
            interactive=True,
            label="NMS Threshold",
        ),
    ],
    outputs=gr.Image(type="numpy", label="output image"),
    allow_flagging="never",
    title="Fast Text to Segmentation with YOLO-World + EfficientViT SAM",
    description="""
## Core components
### YOLO-World
[YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
which outperforms many state-of-the-art methods in terms of both accuracy and speed.

### EfficientViT SAM
[EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated segment anything models.
Thanks to the lightweight and hardware-efficient core building block,
it delivers 48.9× measured TensorRT speedup on A100 GPU over SAM-ViT-H without sacrificing performance.

## Demo especially powered by
Roboflow's [inference](https://github.com/roboflow/inference) and [supervision](https://github.com/roboflow/supervision).

## Example images came from
[Segment Anything Demo](https://segment-anything.com/demo) and [Unsplash](https://unsplash.com/).
    """,
    examples=[
        [
            os.path.join(os.path.dirname(__file__), "examples/livingroom.jpg"),
            "table, lamp, dog, sofa, plant, clock, carpet, frame on the wall",
            0.05,
            0.5
        ],
        [
            os.path.join(os.path.dirname(__file__), "examples/cat_and_dogs.jpg"),
            "cat, dog",
            0.2,
            0.5
        ],
    ],
)


app.launch(server_name="0.0.0.0")

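The `detect` function above is the whole pipeline: YOLO-World produces boxes for the free-text categories, each box prompts EfficientViT-SAM for a mask, and supervision draws the result. A minimal sketch of driving it from a script is shown below; the image path and thresholds are examples only, and note that importing `app` as written also executes the Gradio launch at the bottom of the file, so this is purely illustrative.

```python
# Hypothetical driver for detect(); assumes app.py is importable and a local test image exists.
import cv2

from app import detect  # importing app also runs the module-level setup and launch

image_bgr = cv2.imread("examples/livingroom.jpg")          # placeholder image path
image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)      # detect() expects RGB input

annotated = detect(
    image=image_rgb,
    query="table, lamp, dog, sofa",
    confidence_threshold=0.05,
    nms_threshold=0.5,
)
cv2.imwrite("annotated.jpg", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
```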
yolo-world-with-efficientvit-sam/efficientvit/__init__.py
ADDED
File without changes
yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (188 Bytes)
yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc
ADDED
Binary file (2.28 kB)
yolo-world-with-efficientvit-sam/efficientvit/apps/__init__.py
ADDED
File without changes
yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (193 Bytes)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py
ADDED
@@ -0,0 +1,7 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .augment import *
from .base import *
from .random_resolution import *

yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (306 Bytes)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc
ADDED
Binary file (11.4 kB)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py
ADDED
@@ -0,0 +1,6 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .bbox import *
from .color_aug import *

yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (277 Bytes)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc
ADDED
Binary file (1.44 kB)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc
ADDED
Binary file (5.17 kB)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py
ADDED
@@ -0,0 +1,30 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np

__all__ = ["rand_bbox"]


def rand_bbox(
    h: int,
    w: int,
    lam: float,
    rand_func: callable = np.random.uniform,
) -> tuple[int, int, int, int]:
    """randomly sample bbox, used in cutmix"""
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = w * cut_rat
    cut_h = h * cut_rat

    # uniform
    cx = rand_func(0, w)
    cy = rand_func(0, h)

    bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
    bby1 = int(np.clip(cy - cut_h / 2, 0, h))
    bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
    bby2 = int(np.clip(cy + cut_h / 2, 0, h))

    return bbx1, bby1, bbx2, bby2

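`rand_bbox` samples a box whose area is roughly `1 - lam` of the image, which is how CutMix picks the region to paste from a second image. A minimal sketch of that use is below; `cutmix_pair` and the two synthetic images are hypothetical names for illustration, only `rand_bbox` comes from the file above.

```python
# Illustrative CutMix-style use of rand_bbox; the helper and inputs are placeholders.
import numpy as np

from efficientvit.apps.data_provider.augment.bbox import rand_bbox


def cutmix_pair(img_a: np.ndarray, img_b: np.ndarray, lam: float) -> np.ndarray:
    """Paste a random crop of img_b into img_a; the crop covers roughly (1 - lam) of the area."""
    h, w = img_a.shape[:2]
    x1, y1, x2, y2 = rand_bbox(h, w, lam)
    mixed = img_a.copy()
    mixed[y1:y2, x1:x2] = img_b[y1:y2, x1:x2]
    return mixed


a = np.zeros((224, 224, 3), dtype=np.uint8)           # stand-in images
b = np.full((224, 224, 3), 255, dtype=np.uint8)
mixed = cutmix_pair(a, b, lam=0.7)                     # roughly 30% of pixels come from b
```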
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py
ADDED
@@ -0,0 +1,84 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from timm.data.auto_augment import rand_augment_transform

__all__ = ["ColorAug", "RandAug"]


class ImageAug:
    def aug_image(self, image: Image.Image) -> Image.Image:
        raise NotImplementedError

    def __call__(
        self, feed_dict: dict or np.ndarray or Image.Image
    ) -> dict or np.ndarray or Image.Image:
        if isinstance(feed_dict, dict):
            output_dict = feed_dict
            image = feed_dict[self.key]
        else:
            output_dict = None
            image = feed_dict
        is_ndarray = isinstance(image, np.ndarray)
        if is_ndarray:
            image = Image.fromarray(image)

        image = self.aug_image(image)

        if is_ndarray:
            image = np.array(image)

        if output_dict is None:
            return image
        else:
            output_dict[self.key] = image
            return output_dict


class ColorAug(transforms.ColorJitter, ImageAug):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
        super().__init__(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.key = key

    def aug_image(self, image: Image.Image) -> Image.Image:
        return transforms.ColorJitter.forward(self, image)

    def forward(
        self, feed_dict: dict or np.ndarray or Image.Image
    ) -> dict or np.ndarray or Image.Image:
        return ImageAug.__call__(self, feed_dict)


class RandAug(ImageAug):
    def __init__(
        self, config: dict[str, any], mean: tuple[float, float, float], key="data"
    ):
        n = config.get("n", 2)
        m = config.get("m", 9)
        mstd = config.get("mstd", 1.0)
        inc = config.get("inc", 1)
        tpct = config.get("tpct", 0.45)
        config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"

        aa_params = dict(
            translate_pct=tpct,
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
            interpolation=Image.BICUBIC,
        )
        self.aug_op = rand_augment_transform(config_str, aa_params)
        self.key = key

    def aug_image(self, image: Image.Image) -> Image.Image:
        return self.aug_op(image)

    def __repr__(self):
        return self.aug_op.__repr__()

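Both augmenters accept either a plain image (PIL or ndarray) or a feed dict, in which case only the entry under `key` is transformed. A short hedged sketch of that dual interface, with example config values and a synthetic image in place of real data:

```python
# Illustrative usage of ColorAug / RandAug; the config values and inputs are examples, not repo defaults.
import numpy as np
from PIL import Image

from efficientvit.apps.data_provider.augment.color_aug import ColorAug, RandAug

color_aug = ColorAug(brightness=0.4, contrast=0.4, saturation=0.4, key="data")
rand_aug = RandAug(config={"n": 2, "m": 9}, mean=(0.485, 0.456, 0.406), key="data")

rgb = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)   # stand-in for a real photo
image = Image.fromarray(rgb)

feed_dict = {"data": np.array(image), "label": 0}
feed_dict = rand_aug(feed_dict)          # dict in, dict out: only feed_dict["data"] is augmented
augmented_image = color_aug(image)       # plain image in, plain image out
```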
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py
ADDED
@@ -0,0 +1,223 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy
import warnings

import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

from efficientvit.apps.data_provider.random_resolution import RRSController
from efficientvit.models.utils import val2tuple

__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]


def parse_image_size(size: int or str) -> tuple[int, int]:
    if isinstance(size, str):
        size = [int(val) for val in size.split("-")]
        return size[0], size[1]
    else:
        return val2tuple(size, 2)


def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
    g = torch.Generator()
    g.manual_seed(seed)  # set random seed before sampling validation set
    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()

    dropped_indexes = rand_indexes[:drop_size]
    remaining_indexes = rand_indexes[drop_size:]

    dropped_dataset = copy.deepcopy(dataset)
    for key in keys:
        setattr(
            dropped_dataset,
            key,
            [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
        )
        setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
    return dataset, dropped_dataset


class DataProvider:
    data_keys = ("samples",)
    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    name: str

    def __init__(
        self,
        train_batch_size: int,
        test_batch_size: int or None,
        valid_size: int or float or None,
        n_worker: int,
        image_size: int or list[int] or str or list[str],
        num_replicas: int or None = None,
        rank: int or None = None,
        train_ratio: float or None = None,
        drop_last: bool = False,
    ):
        warnings.filterwarnings("ignore")
        super().__init__()

        # batch_size & valid_size
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size or self.train_batch_size
        self.valid_size = valid_size

        # image size
        if isinstance(image_size, list):
            self.image_size = [parse_image_size(size) for size in image_size]
            self.image_size.sort()  # e.g., 160 -> 224
            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
        else:
            self.image_size = parse_image_size(image_size)
            RRSController.IMAGE_SIZE_LIST = [self.image_size]
            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size

        # distributed configs
        self.num_replicas = num_replicas
        self.rank = rank

        # build datasets
        train_dataset, val_dataset, test_dataset = self.build_datasets()

        if train_ratio is not None and train_ratio < 1.0:
            assert 0 < train_ratio < 1
            _, train_dataset = random_drop_data(
                train_dataset,
                int(train_ratio * len(train_dataset)),
                self.SUB_SEED,
                self.data_keys,
            )

        # build data loader
        self.train = self.build_dataloader(
            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
        )
        self.valid = self.build_dataloader(
            val_dataset, test_batch_size, n_worker, drop_last=False, train=False
        )
        self.test = self.build_dataloader(
            test_dataset, test_batch_size, n_worker, drop_last=False, train=False
        )
        if self.valid is None:
            self.valid = self.test
        self.sub_train = None

    @property
    def data_shape(self) -> tuple[int, ...]:
        return 3, self.active_image_size[0], self.active_image_size[1]

    def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
        raise NotImplementedError

    def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
        raise NotImplementedError

    def build_datasets(self) -> tuple[any, any, any]:
        raise NotImplementedError

    def build_dataloader(
        self,
        dataset: any or None,
        batch_size: int,
        n_worker: int,
        drop_last: bool,
        train: bool,
    ):
        if dataset is None:
            return None
        if isinstance(self.image_size, list) and train:
            from efficientvit.apps.data_provider.random_resolution._data_loader import \
                RRSDataLoader

            dataloader_class = RRSDataLoader
        else:
            dataloader_class = torch.utils.data.DataLoader
        if self.num_replicas is None:
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )
        else:
            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
            return dataloader_class(
                dataset=dataset,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=n_worker,
                pin_memory=True,
                drop_last=drop_last,
            )

    def set_epoch(self, epoch: int) -> None:
        RRSController.set_epoch(epoch, len(self.train))
        if isinstance(self.train.sampler, DistributedSampler):
            self.train.sampler.set_epoch(epoch)

    def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
        self.active_image_size = val2tuple(new_size, 2)
        new_transform = self.build_valid_transform(self.active_image_size)
        # change the transform of the valid and test set
        self.valid.dataset.transform = self.test.dataset.transform = new_transform

    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
        if self.valid_size is not None:
            if 0 < self.valid_size < 1:
                valid_size = int(self.valid_size * len(train_dataset))
            else:
                assert self.valid_size >= 1
                valid_size = int(self.valid_size)
            train_dataset, val_dataset = random_drop_data(
                train_dataset,
                valid_size,
                self.VALID_SEED,
                self.data_keys,
            )
            val_dataset.transform = valid_transform
        else:
            val_dataset = None
        return train_dataset, val_dataset

    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
        # used for resetting BN running statistics
        if self.sub_train is None:
            self.sub_train = {}
        if self.active_image_size in self.sub_train:
            return self.sub_train[self.active_image_size]

        # construct dataset and dataloader
        train_dataset = copy.deepcopy(self.train.dataset)
        if n_samples < len(train_dataset):
            _, train_dataset = random_drop_data(
                train_dataset,
                n_samples,
                self.SUB_SEED,
                self.data_keys,
            )
        RRSController.ACTIVE_SIZE = self.active_image_size
        train_dataset.transform = self.build_train_transform(
            image_size=self.active_image_size
        )
        data_loader = self.build_dataloader(
            train_dataset, batch_size, self.train.num_workers, True, False
        )

        # pre-fetch data
        self.sub_train[self.active_image_size] = [
            data
            for data in data_loader
            for _ in range(max(1, n_samples // len(train_dataset)))
        ]

        return self.sub_train[self.active_image_size]

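`DataProvider` is an abstract base: concrete providers supply `build_datasets` and the two transform builders, and the base class wires up the train/valid/test loaders (switching to `RRSDataLoader` when a list of image sizes is given). A minimal hedged skeleton of a subclass is sketched below; `ToyDataProvider` and the `FakeData` datasets are placeholders for illustration, not part of the repository.

```python
# Minimal sketch of a concrete DataProvider subclass, using a placeholder dataset.
import torchvision.transforms as transforms
from torchvision.datasets import FakeData

from efficientvit.apps.data_provider import DataProvider


class ToyDataProvider(DataProvider):
    name = "toy"  # hypothetical provider name

    def build_valid_transform(self, image_size=None):
        image_size = image_size or self.active_image_size
        return transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])

    def build_train_transform(self, image_size=None):
        return self.build_valid_transform(image_size)

    def build_datasets(self):
        train = FakeData(size=64, image_size=(3, 224, 224), transform=self.build_train_transform())
        test = FakeData(size=16, image_size=(3, 224, 224), transform=self.build_valid_transform())
        return train, None, test  # no separate validation split in this toy example


provider = ToyDataProvider(
    train_batch_size=8, test_batch_size=8, valid_size=None,
    n_worker=0, image_size=224,
)
images, labels = next(iter(provider.train))  # (8, 3, 224, 224) batch
```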
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""Random resolution data loader compatible with multi-processing and distributed training.

Replace Pytorch's DataLoader with RRSDataLoader to support random resolution
at the training time, resolution sampling is controlled by RRSController
"""

from .controller import *

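The docstring above describes the intended swap: use `RRSDataLoader` where `torch.utils.data.DataLoader` would normally go, and let `RRSController` pick the training resolution per epoch, which is exactly how `DataProvider.build_dataloader` and `DataProvider.set_epoch` use it. A hedged standalone sketch follows; the candidate resolutions and the toy tensor dataset are placeholders, and only `IMAGE_SIZE_LIST`, `ACTIVE_SIZE`, and `set_epoch` are taken from how `base.py` drives the controller.

```python
# Hedged sketch of the DataLoader swap described in the docstring; the dataset is a placeholder,
# so no actual resizing happens here, it only shows where the controller hooks in per epoch.
import torch
from torch.utils.data import TensorDataset

from efficientvit.apps.data_provider.random_resolution import RRSController
from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]  # candidate resolutions
RRSController.ACTIVE_SIZE = (224, 224)

dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 10, (64,)))
train_loader = RRSDataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

for epoch in range(2):
    RRSController.set_epoch(epoch, len(train_loader))  # re-sample the resolution schedule
    for images, labels in train_loader:
        pass  # training step goes here
```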
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (527 Bytes)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc
ADDED
Binary file (5.7 kB)
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py
ADDED
@@ -0,0 +1,1603 @@
r"""This file is based on torch/utils/data/data_loader.py

Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter

To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`.
"""

import functools
import itertools
import logging
import multiprocessing as python_multiprocessing
import os
import queue
import threading
import warnings
from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
                    TypeVar, Union)

import torch
import torch.distributed as dist
import torch.multiprocessing as multiprocessing
import torch.utils.data.graph_settings
from torch._utils import ExceptionWrapper
from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
                              IterDataPipe, MapDataPipe, RandomSampler,
                              Sampler, SequentialSampler, _utils)
from torch.utils.data.datapipes.datapipe import (
    _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)

from ._data_worker import _worker_loop

__all__ = ["RRSDataLoader"]

T_co = TypeVar("T_co", covariant=True)
T = TypeVar("T")
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]


# These functions used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert

get_worker_info = _utils.worker.get_worker_info

logger = logging.getLogger(__name__)


class _DatasetKind:
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(
                dataset, auto_collation, collate_fn, drop_last
            )
        else:
            return _utils.fetch._IterableDatasetFetcher(
                dataset, auto_collation, collate_fn, drop_last
            )


class _InfiniteConstantSampler(Sampler):
    r"""Analogous to ``itertools.repeat(None, None)``.
    Used as sampler for :class:`~torch.utils.data.IterableDataset`.

    Args:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super().__init__(None)

    def __iter__(self):
        while True:
            yield None


def _get_distributed_settings():
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    else:
        return 1, 0


def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
    global_worker_id = worker_id
    info = torch.utils.data.get_worker_info()
    assert info is not None
    total_workers = info.num_workers
|
103 |
+
datapipe = info.dataset
|
104 |
+
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
|
105 |
+
# To distribute elements across distributed processes evenly, we should shard data on distributed
|
106 |
+
# processes first then shard on worker processes
|
107 |
+
total_workers *= world_size
|
108 |
+
global_worker_id = global_worker_id * world_size + rank_id
|
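# Worked example (illustrative comment, not part of the upstream file): with
# world_size=2, rank_id=1, 3 workers per rank, and worker_id=2, the lines above
# give total_workers = 3 * 2 = 6 and global_worker_id = 2 * 2 + 1 = 5, so rank 0
# owns global shards {0, 2, 4} and rank 1 owns {1, 3, 5}.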
109 |
+
# For BC, use default SHARDING_PRIORITIES
|
110 |
+
torch.utils.data.graph_settings.apply_sharding(
|
111 |
+
datapipe, total_workers, global_worker_id
|
112 |
+
)
|
113 |
+
if worker_init_fn is not None:
|
114 |
+
worker_init_fn(worker_id)
|
115 |
+
|
116 |
+
|
117 |
+
def _share_dist_seed(generator, pg):
|
118 |
+
_shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
|
119 |
+
if isinstance(pg, dist.ProcessGroup):
|
120 |
+
dist.broadcast(_shared_seed, src=0, group=pg)
|
121 |
+
return _shared_seed.item()
|
122 |
+
|
123 |
+
|
124 |
+
class RRSDataLoader(Generic[T_co]):
|
125 |
+
r"""
|
126 |
+
Data loader. Combines a dataset and a sampler, and provides an iterable over
|
127 |
+
the given dataset.
|
128 |
+
|
129 |
+
The :class:`~torch.utils.data.DataLoader` supports both map-style and
|
130 |
+
iterable-style datasets with single- or multi-process loading, customizing
|
131 |
+
loading order and optional automatic batching (collation) and memory pinning.
|
132 |
+
|
133 |
+
See :py:mod:`torch.utils.data` documentation page for more details.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
dataset (Dataset): dataset from which to load the data.
|
137 |
+
batch_size (int, optional): how many samples per batch to load
|
138 |
+
(default: ``1``).
|
139 |
+
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
140 |
+
at every epoch (default: ``False``).
|
141 |
+
sampler (Sampler or Iterable, optional): defines the strategy to draw
|
142 |
+
samples from the dataset. Can be any ``Iterable`` with ``__len__``
|
143 |
+
implemented. If specified, :attr:`shuffle` must not be specified.
|
144 |
+
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
|
145 |
+
returns a batch of indices at a time. Mutually exclusive with
|
146 |
+
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
|
147 |
+
and :attr:`drop_last`.
|
148 |
+
num_workers (int, optional): how many subprocesses to use for data
|
149 |
+
loading. ``0`` means that the data will be loaded in the main process.
|
150 |
+
(default: ``0``)
|
151 |
+
collate_fn (Callable, optional): merges a list of samples to form a
|
152 |
+
mini-batch of Tensor(s). Used when using batched loading from a
|
153 |
+
map-style dataset.
|
154 |
+
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
|
155 |
+
into device/CUDA pinned memory before returning them. If your data elements
|
156 |
+
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
|
157 |
+
see the example below.
|
158 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
159 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
160 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
161 |
+
will be smaller. (default: ``False``)
|
162 |
+
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
163 |
+
from workers. Should always be non-negative. (default: ``0``)
|
164 |
+
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
|
165 |
+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
166 |
+
input, after seeding and before data loading. (default: ``None``)
|
167 |
+
generator (torch.Generator, optional): If not ``None``, this RNG will be used
|
168 |
+
by RandomSampler to generate random indexes and multiprocessing to generate
|
169 |
+
`base_seed` for workers. (default: ``None``)
|
170 |
+
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
|
171 |
+
in advance by each worker. ``2`` means there will be a total of
|
172 |
+
2 * num_workers batches prefetched across all workers. (default value depends
|
173 |
+
on the set value for num_workers. If value of num_workers=0 default is ``None``.
|
174 |
+
Otherwise if value of num_workers>0 default is ``2``).
|
175 |
+
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
|
176 |
+
the worker processes after a dataset has been consumed once. This keeps
|
177 |
+
the workers' `Dataset` instances alive. (default: ``False``)
|
178 |
+
pin_memory_device (str, optional): the data loader will copy Tensors
|
179 |
+
into device pinned memory before returning them if pin_memory is set to true.
|
180 |
+
|
181 |
+
|
182 |
+
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
|
183 |
+
cannot be an unpicklable object, e.g., a lambda function. See
|
184 |
+
:ref:`multiprocessing-best-practices` on more details related
|
185 |
+
to multiprocessing in PyTorch.
|
186 |
+
|
187 |
+
.. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
|
188 |
+
When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
|
189 |
+
it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
|
190 |
+
rounding depending on :attr:`drop_last`, regardless of multi-process loading
|
191 |
+
configurations. This represents the best guess PyTorch can make because PyTorch
|
192 |
+
trusts user :attr:`dataset` code in correctly handling multi-process
|
193 |
+
loading to avoid duplicate data.
|
194 |
+
|
195 |
+
However, if sharding results in multiple workers having incomplete last batches,
|
196 |
+
this estimate can still be inaccurate, because (1) an otherwise complete batch can
|
197 |
+
be broken into multiple ones and (2) more than one batch worth of samples can be
|
198 |
+
dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
|
199 |
+
cases in general.
|
200 |
+
|
201 |
+
See `Dataset Types`_ for more details on these two types of datasets and how
|
202 |
+
:class:`~torch.utils.data.IterableDataset` interacts with
|
203 |
+
`Multi-process data loading`_.
|
204 |
+
|
205 |
+
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
|
206 |
+
:ref:`data-loading-randomness` notes for random seed related questions.
|
207 |
+
"""
|
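# Usage sketch (illustrative comment, not part of the upstream file): a minimal
# loop over a map-style dataset. `my_dataset` and the batch size are assumed
# placeholders, and each sample is assumed to be an (image, label) pair;
# otherwise RRSDataLoader is used like torch.utils.data.DataLoader.
#
#     loader = RRSDataLoader(my_dataset, batch_size=32, shuffle=True,
#                            num_workers=4, pin_memory=True)
#     for images, labels in loader:
#         ...  # one training step per batch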
208 |
+
|
209 |
+
dataset: Dataset[T_co]
|
210 |
+
batch_size: Optional[int]
|
211 |
+
num_workers: int
|
212 |
+
pin_memory: bool
|
213 |
+
drop_last: bool
|
214 |
+
timeout: float
|
215 |
+
sampler: Union[Sampler, Iterable]
|
216 |
+
pin_memory_device: str
|
217 |
+
prefetch_factor: Optional[int]
|
218 |
+
_iterator: Optional["_BaseDataLoaderIter"]
|
219 |
+
__initialized = False
|
220 |
+
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
dataset: Dataset[T_co],
|
224 |
+
batch_size: Optional[int] = 1,
|
225 |
+
shuffle: Optional[bool] = None,
|
226 |
+
sampler: Union[Sampler, Iterable, None] = None,
|
227 |
+
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
|
228 |
+
num_workers: int = 0,
|
229 |
+
collate_fn: Optional[_collate_fn_t] = None,
|
230 |
+
pin_memory: bool = False,
|
231 |
+
drop_last: bool = False,
|
232 |
+
timeout: float = 0,
|
233 |
+
worker_init_fn: Optional[_worker_init_fn_t] = None,
|
234 |
+
multiprocessing_context=None,
|
235 |
+
generator=None,
|
236 |
+
*,
|
237 |
+
prefetch_factor: Optional[int] = None,
|
238 |
+
persistent_workers: bool = False,
|
239 |
+
pin_memory_device: str = ""
|
240 |
+
):
|
241 |
+
torch._C._log_api_usage_once("python.data_loader")
|
242 |
+
|
243 |
+
if num_workers < 0:
|
244 |
+
raise ValueError(
|
245 |
+
"num_workers option should be non-negative; "
|
246 |
+
"use num_workers=0 to disable multiprocessing."
|
247 |
+
)
|
248 |
+
|
249 |
+
if timeout < 0:
|
250 |
+
raise ValueError("timeout option should be non-negative")
|
251 |
+
|
252 |
+
if num_workers == 0 and prefetch_factor is not None:
|
253 |
+
raise ValueError(
|
254 |
+
"prefetch_factor option could only be specified in multiprocessing."
|
255 |
+
"let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
|
256 |
+
)
|
257 |
+
elif num_workers > 0 and prefetch_factor is None:
|
258 |
+
prefetch_factor = 2
|
259 |
+
elif prefetch_factor is not None and prefetch_factor < 0:
|
260 |
+
raise ValueError("prefetch_factor option should be non-negative")
|
261 |
+
|
262 |
+
if persistent_workers and num_workers == 0:
|
263 |
+
raise ValueError("persistent_workers option needs num_workers > 0")
|
264 |
+
|
265 |
+
self.dataset = dataset
|
266 |
+
self.num_workers = num_workers
|
267 |
+
self.prefetch_factor = prefetch_factor
|
268 |
+
self.pin_memory = pin_memory
|
269 |
+
self.pin_memory_device = pin_memory_device
|
270 |
+
self.timeout = timeout
|
271 |
+
self.worker_init_fn = worker_init_fn
|
272 |
+
self.multiprocessing_context = multiprocessing_context
|
273 |
+
|
274 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
275 |
+
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
|
276 |
+
if isinstance(self.dataset, IterDataPipe):
|
277 |
+
self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
|
278 |
+
elif isinstance(self.dataset, MapDataPipe):
|
279 |
+
self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
|
280 |
+
|
281 |
+
# Arg-check dataset related before checking samplers because we want to
|
282 |
+
# tell users that iterable-style datasets are incompatible with custom
|
283 |
+
# samplers first, so that they don't learn that this combo doesn't work
|
284 |
+
# after spending time fixing the custom sampler errors.
|
285 |
+
if isinstance(dataset, IterableDataset):
|
286 |
+
self._dataset_kind = _DatasetKind.Iterable
|
287 |
+
# NOTE [ Custom Samplers and IterableDataset ]
|
288 |
+
#
|
289 |
+
# `IterableDataset` does not support custom `batch_sampler` or
|
290 |
+
# `sampler` since the key is irrelevant (unless we support
|
291 |
+
# generator-style dataset one day...).
|
292 |
+
#
|
293 |
+
# For `sampler`, we always create a dummy sampler. This is an
|
294 |
+
# infinite sampler even when the dataset may have an implemented
|
295 |
+
# finite `__len__` because in multi-process data loading, naive
|
296 |
+
# settings will return duplicated data (which may be desired), and
|
297 |
+
# thus using a sampler with length matching that of dataset will
|
298 |
+
# cause data loss (you may have duplicates of the first couple
|
299 |
+
# batches, but never see anything afterwards). Therefore,
|
300 |
+
# `IterableDataset` always uses an infinite sampler, an instance of
|
301 |
+
# `_InfiniteConstantSampler` defined above.
|
302 |
+
#
|
303 |
+
# A custom `batch_sampler` essentially only controls the batch size.
|
304 |
+
# However, it is unclear how useful it would be since an iterable-style
|
305 |
+
# dataset can handle that within itself. Moreover, it is pointless
|
306 |
+
# in multi-process data loading as the assignment order of batches
|
307 |
+
# to workers is an implementation detail so users can not control
|
308 |
+
# how to batchify each worker's iterable. Thus, we disable this
|
309 |
+
# option. If this turns out to be useful in future, we can re-enable
|
310 |
+
# this, and support custom samplers that specify the assignments to
|
311 |
+
# specific workers.
|
312 |
+
if isinstance(dataset, IterDataPipe):
|
313 |
+
if shuffle is not None:
|
314 |
+
dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
|
315 |
+
dataset, shuffle=shuffle
|
316 |
+
)
|
317 |
+
# We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
|
318 |
+
elif shuffle not in {False, None}:
|
319 |
+
raise ValueError(
|
320 |
+
"DataLoader with IterableDataset: expected unspecified "
|
321 |
+
"shuffle option, but got shuffle={}".format(shuffle)
|
322 |
+
)
|
323 |
+
|
324 |
+
if sampler is not None:
|
325 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
326 |
+
raise ValueError(
|
327 |
+
"DataLoader with IterableDataset: expected unspecified "
|
328 |
+
"sampler option, but got sampler={}".format(sampler)
|
329 |
+
)
|
330 |
+
elif batch_sampler is not None:
|
331 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
332 |
+
raise ValueError(
|
333 |
+
"DataLoader with IterableDataset: expected unspecified "
|
334 |
+
"batch_sampler option, but got batch_sampler={}".format(
|
335 |
+
batch_sampler
|
336 |
+
)
|
337 |
+
)
|
338 |
+
else:
|
339 |
+
shuffle = bool(shuffle)
|
340 |
+
self._dataset_kind = _DatasetKind.Map
|
341 |
+
|
342 |
+
if sampler is not None and shuffle:
|
343 |
+
raise ValueError("sampler option is mutually exclusive with " "shuffle")
|
344 |
+
|
345 |
+
if batch_sampler is not None:
|
346 |
+
# auto_collation with custom batch_sampler
|
347 |
+
if batch_size != 1 or shuffle or sampler is not None or drop_last:
|
348 |
+
raise ValueError(
|
349 |
+
"batch_sampler option is mutually exclusive "
|
350 |
+
"with batch_size, shuffle, sampler, and "
|
351 |
+
"drop_last"
|
352 |
+
)
|
353 |
+
batch_size = None
|
354 |
+
drop_last = False
|
355 |
+
elif batch_size is None:
|
356 |
+
# no auto_collation
|
357 |
+
if drop_last:
|
358 |
+
raise ValueError(
|
359 |
+
"batch_size=None option disables auto-batching "
|
360 |
+
"and is mutually exclusive with drop_last"
|
361 |
+
)
|
362 |
+
|
363 |
+
if sampler is None: # give default samplers
|
364 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
365 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
366 |
+
sampler = _InfiniteConstantSampler()
|
367 |
+
else: # map-style
|
368 |
+
if shuffle:
|
369 |
+
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
|
370 |
+
else:
|
371 |
+
sampler = SequentialSampler(dataset) # type: ignore[arg-type]
|
372 |
+
|
373 |
+
if batch_size is not None and batch_sampler is None:
|
374 |
+
# auto_collation without custom batch_sampler
|
375 |
+
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
|
376 |
+
|
377 |
+
self.batch_size = batch_size
|
378 |
+
self.drop_last = drop_last
|
379 |
+
self.sampler = sampler
|
380 |
+
self.batch_sampler = batch_sampler
|
381 |
+
self.generator = generator
|
382 |
+
|
383 |
+
if collate_fn is None:
|
384 |
+
if self._auto_collation:
|
385 |
+
collate_fn = _utils.collate.default_collate
|
386 |
+
else:
|
387 |
+
collate_fn = _utils.collate.default_convert
|
388 |
+
|
389 |
+
self.collate_fn = collate_fn
|
390 |
+
self.persistent_workers = persistent_workers
|
391 |
+
|
392 |
+
self.__initialized = True
|
393 |
+
self._IterableDataset_len_called = (
|
394 |
+
None # See NOTE [ IterableDataset and __len__ ]
|
395 |
+
)
|
396 |
+
|
397 |
+
self._iterator = None
|
398 |
+
|
399 |
+
self.check_worker_number_rationality()
|
400 |
+
|
401 |
+
torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
|
402 |
+
|
403 |
+
def _get_iterator(self) -> "_BaseDataLoaderIter":
|
404 |
+
if self.num_workers == 0:
|
405 |
+
return _SingleProcessDataLoaderIter(self)
|
406 |
+
else:
|
407 |
+
self.check_worker_number_rationality()
|
408 |
+
return _MultiProcessingDataLoaderIter(self)
|
409 |
+
|
410 |
+
@property
|
411 |
+
def multiprocessing_context(self):
|
412 |
+
return self.__multiprocessing_context
|
413 |
+
|
414 |
+
@multiprocessing_context.setter
|
415 |
+
def multiprocessing_context(self, multiprocessing_context):
|
416 |
+
if multiprocessing_context is not None:
|
417 |
+
if self.num_workers > 0:
|
418 |
+
if isinstance(multiprocessing_context, str):
|
419 |
+
valid_start_methods = multiprocessing.get_all_start_methods()
|
420 |
+
if multiprocessing_context not in valid_start_methods:
|
421 |
+
raise ValueError(
|
422 |
+
(
|
423 |
+
"multiprocessing_context option "
|
424 |
+
"should specify a valid start method in {!r}, but got "
|
425 |
+
"multiprocessing_context={!r}"
|
426 |
+
).format(valid_start_methods, multiprocessing_context)
|
427 |
+
)
|
428 |
+
multiprocessing_context = multiprocessing.get_context(
|
429 |
+
multiprocessing_context
|
430 |
+
)
|
431 |
+
|
432 |
+
if not isinstance(
|
433 |
+
multiprocessing_context, python_multiprocessing.context.BaseContext
|
434 |
+
):
|
435 |
+
raise TypeError(
|
436 |
+
(
|
437 |
+
"multiprocessing_context option should be a valid context "
|
438 |
+
"object or a string specifying the start method, but got "
|
439 |
+
"multiprocessing_context={}"
|
440 |
+
).format(multiprocessing_context)
|
441 |
+
)
|
442 |
+
else:
|
443 |
+
raise ValueError(
|
444 |
+
(
|
445 |
+
"multiprocessing_context can only be used with "
|
446 |
+
"multi-process loading (num_workers > 0), but got "
|
447 |
+
"num_workers={}"
|
448 |
+
).format(self.num_workers)
|
449 |
+
)
|
450 |
+
|
451 |
+
self.__multiprocessing_context = multiprocessing_context
|
452 |
+
|
453 |
+
def __setattr__(self, attr, val):
|
454 |
+
if self.__initialized and attr in (
|
455 |
+
"batch_size",
|
456 |
+
"batch_sampler",
|
457 |
+
"sampler",
|
458 |
+
"drop_last",
|
459 |
+
"dataset",
|
460 |
+
"persistent_workers",
|
461 |
+
):
|
462 |
+
raise ValueError(
|
463 |
+
"{} attribute should not be set after {} is "
|
464 |
+
"initialized".format(attr, self.__class__.__name__)
|
465 |
+
)
|
466 |
+
|
467 |
+
super().__setattr__(attr, val)
|
468 |
+
|
469 |
+
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
|
470 |
+
# since '_BaseDataLoaderIter' references 'DataLoader'.
|
471 |
+
def __iter__(self) -> "_BaseDataLoaderIter":
|
472 |
+
# When using a single worker the returned iterator should be
|
473 |
+
# created every time to avoid resetting its state
|
474 |
+
# However, in the case of a multiple workers iterator
|
475 |
+
# the iterator is only created once in the lifetime of the
|
476 |
+
# DataLoader object so that workers can be reused
|
477 |
+
if self.persistent_workers and self.num_workers > 0:
|
478 |
+
if self._iterator is None:
|
479 |
+
self._iterator = self._get_iterator()
|
480 |
+
else:
|
481 |
+
self._iterator._reset(self)
|
482 |
+
return self._iterator
|
483 |
+
else:
|
484 |
+
return self._get_iterator()
|
485 |
+
|
486 |
+
@property
|
487 |
+
def _auto_collation(self):
|
488 |
+
return self.batch_sampler is not None
|
489 |
+
|
490 |
+
@property
|
491 |
+
def _index_sampler(self):
|
492 |
+
# The actual sampler used for generating indices for `_DatasetFetcher`
|
493 |
+
# (see _utils/fetch.py) to read data at each time. This would be
|
494 |
+
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
|
495 |
+
# We can't change `.sampler` and `.batch_sampler` attributes for BC
|
496 |
+
# reasons.
|
497 |
+
if self._auto_collation:
|
498 |
+
return self.batch_sampler
|
499 |
+
else:
|
500 |
+
return self.sampler
|
501 |
+
|
502 |
+
def __len__(self) -> int:
|
503 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
504 |
+
# NOTE [ IterableDataset and __len__ ]
|
505 |
+
#
|
506 |
+
# For `IterableDataset`, `__len__` could be inaccurate when one naively
|
507 |
+
# does multi-processing data loading, since the samples will be duplicated.
|
508 |
+
# However, no real use case should be actually using that behavior, so
|
509 |
+
# it should count as a user error. We should generally trust user
|
510 |
+
# code to do the proper thing (e.g., configure each replica differently
|
511 |
+
# in `__iter__`), and give us the correct `__len__` if they choose to
|
512 |
+
# implement it (this will still throw if the dataset does not implement
|
513 |
+
# a `__len__`).
|
514 |
+
#
|
515 |
+
# To provide a further warning, we track if `__len__` was called on the
|
516 |
+
# `DataLoader`, save the returned value in `self._len_called`, and warn
|
517 |
+
# if the iterator ends up yielding more than this number of samples.
|
518 |
+
|
519 |
+
# Cannot statically verify that dataset is Sized
|
520 |
+
length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
|
521 |
+
if (
|
522 |
+
self.batch_size is not None
|
523 |
+
): # IterableDataset doesn't allow custom sampler or batch_sampler
|
524 |
+
from math import ceil
|
525 |
+
|
526 |
+
if self.drop_last:
|
527 |
+
length = length // self.batch_size
|
528 |
+
else:
|
529 |
+
length = ceil(length / self.batch_size)
|
530 |
+
return length
|
531 |
+
else:
|
532 |
+
return len(self._index_sampler)
|
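# Worked example (illustrative comment, not part of the upstream file): for an
# IterableDataset reporting len(dataset) == 10 with batch_size == 3, the branch
# above estimates 10 // 3 == 3 batches when drop_last is True and
# ceil(10 / 3) == 4 batches otherwise; for map-style datasets the length is
# simply the length of the (batch) sampler.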
533 |
+
|
534 |
+
def check_worker_number_rationality(self):
|
535 |
+
# This function checks whether the dataloader's worker number is rational based on
|
536 |
+
# the current system's resources. The current rule is that if the number of workers this
|
537 |
+
# Dataloader will create is bigger than the number of logical cpus that it is allowed to
|
538 |
+
# use, then we will pop up a warning to let the user pay attention.
|
539 |
+
#
|
540 |
+
# eg. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
|
541 |
+
# threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
|
542 |
+
# DataLoader process can use half of them which is 32, then the rational max number of
|
543 |
+
# worker that initiated from this process is 32.
|
544 |
+
# Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
|
545 |
+
# So the warning message is triggered to notify the user to lower the worker number if
|
546 |
+
# necessary.
|
547 |
+
#
|
548 |
+
#
|
549 |
+
# [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
|
550 |
+
# available (available on most Linux systems, but not on OSX and Windows).
|
551 |
+
# When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
|
552 |
+
# it doesn't respect cpuset.
|
553 |
+
# We don't take threading into account since each worker process is single threaded
|
554 |
+
# at this time.
|
555 |
+
#
|
556 |
+
# We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
|
557 |
+
# other than `torch.set_num_threads` to 1 in the worker process, if the passing
|
558 |
+
# in functions use 3rd party modules that rely on those threading flags to determine
|
559 |
+
# how many thread to create (eg. numpy, etc), then it is caller's responsibility to
|
560 |
+
# set those flags correctly.
|
561 |
+
def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
|
562 |
+
|
563 |
+
suggested_max_worker_msg = (
|
564 |
+
(
|
565 |
+
(
|
566 |
+
"Our suggested max number of worker in current system is {}{}, which is smaller "
|
567 |
+
"than what this DataLoader is going to create."
|
568 |
+
).format(
|
569 |
+
num_worker_suggest,
|
570 |
+
(
|
571 |
+
""
|
572 |
+
if cpuset_checked
|
573 |
+
else " (`cpuset` is not taken into account)"
|
574 |
+
),
|
575 |
+
)
|
576 |
+
)
|
577 |
+
if num_worker_suggest is not None
|
578 |
+
else (
|
579 |
+
"DataLoader is not able to compute a suggested max number of worker in current system."
|
580 |
+
)
|
581 |
+
)
|
582 |
+
|
583 |
+
warn_msg = (
|
584 |
+
"This DataLoader will create {} worker processes in total. {} "
|
585 |
+
"Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
|
586 |
+
"lower the worker number to avoid potential slowness/freeze if necessary."
|
587 |
+
).format(num_worker_created, suggested_max_worker_msg)
|
588 |
+
return warn_msg
|
589 |
+
|
590 |
+
if not self.num_workers or self.num_workers == 0:
|
591 |
+
return
|
592 |
+
|
593 |
+
# try to compute a suggested max number of worker based on system's resource
|
594 |
+
max_num_worker_suggest = None
|
595 |
+
cpuset_checked = False
|
596 |
+
if hasattr(os, "sched_getaffinity"):
|
597 |
+
try:
|
598 |
+
max_num_worker_suggest = len(os.sched_getaffinity(0))
|
599 |
+
cpuset_checked = True
|
600 |
+
except Exception:
|
601 |
+
pass
|
602 |
+
if max_num_worker_suggest is None:
|
603 |
+
# os.cpu_count() could return Optional[int]
|
604 |
+
# get cpu count first and check None in order to satisfy mypy check
|
605 |
+
cpu_count = os.cpu_count()
|
606 |
+
if cpu_count is not None:
|
607 |
+
max_num_worker_suggest = cpu_count
|
608 |
+
|
609 |
+
if max_num_worker_suggest is None:
|
610 |
+
warnings.warn(
|
611 |
+
_create_warning_msg(
|
612 |
+
max_num_worker_suggest, self.num_workers, cpuset_checked
|
613 |
+
)
|
614 |
+
)
|
615 |
+
return
|
616 |
+
|
617 |
+
if self.num_workers > max_num_worker_suggest:
|
618 |
+
warnings.warn(
|
619 |
+
_create_warning_msg(
|
620 |
+
max_num_worker_suggest, self.num_workers, cpuset_checked
|
621 |
+
)
|
622 |
+
)
|
623 |
+
|
624 |
+
|
625 |
+
class _BaseDataLoaderIter:
|
626 |
+
def __init__(self, loader: RRSDataLoader) -> None:
|
627 |
+
self._dataset = loader.dataset
|
628 |
+
self._shared_seed = None
|
629 |
+
self._pg = None
|
630 |
+
if isinstance(self._dataset, IterDataPipe):
|
631 |
+
if dist.is_available() and dist.is_initialized():
|
632 |
+
self._pg = dist.new_group(backend="gloo")
|
633 |
+
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
|
634 |
+
shared_rng = torch.Generator()
|
635 |
+
shared_rng.manual_seed(self._shared_seed)
|
636 |
+
self._dataset = torch.utils.data.graph_settings.apply_random_seed(
|
637 |
+
self._dataset, shared_rng
|
638 |
+
)
|
639 |
+
self._dataset_kind = loader._dataset_kind
|
640 |
+
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
641 |
+
self._auto_collation = loader._auto_collation
|
642 |
+
self._drop_last = loader.drop_last
|
643 |
+
self._index_sampler = loader._index_sampler
|
644 |
+
self._num_workers = loader.num_workers
|
645 |
+
ws, rank = _get_distributed_settings()
|
646 |
+
self._world_size = ws
|
647 |
+
self._rank = rank
|
648 |
+
# for other backends, pin_memory_device needs to be set. if not set,
|
649 |
+
# the default behaviour is the CUDA device. if pin_memory_device is selected
|
650 |
+
# and pin_memory is not set, the default behaviour is false.
|
651 |
+
if len(loader.pin_memory_device) == 0:
|
652 |
+
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
|
653 |
+
self._pin_memory_device = None
|
654 |
+
else:
|
655 |
+
if not loader.pin_memory:
|
656 |
+
warn_msg = (
|
657 |
+
"pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
|
658 |
+
"please set pin_memory to true, if you need to use the device pin memory"
|
659 |
+
)
|
660 |
+
warnings.warn(warn_msg)
|
661 |
+
|
662 |
+
self._pin_memory = loader.pin_memory
|
663 |
+
self._pin_memory_device = loader.pin_memory_device
|
664 |
+
self._timeout = loader.timeout
|
665 |
+
self._collate_fn = loader.collate_fn
|
666 |
+
self._sampler_iter = iter(self._index_sampler)
|
667 |
+
self._base_seed = (
|
668 |
+
torch.empty((), dtype=torch.int64)
|
669 |
+
.random_(generator=loader.generator)
|
670 |
+
.item()
|
671 |
+
)
|
672 |
+
self._persistent_workers = loader.persistent_workers
|
673 |
+
self._num_yielded = 0
|
674 |
+
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
|
675 |
+
self.__class__.__name__
|
676 |
+
)
|
677 |
+
|
678 |
+
def __iter__(self) -> "_BaseDataLoaderIter":
|
679 |
+
return self
|
680 |
+
|
681 |
+
def _reset(self, loader, first_iter=False):
|
682 |
+
self._sampler_iter = iter(self._index_sampler)
|
683 |
+
self._num_yielded = 0
|
684 |
+
self._IterableDataset_len_called = loader._IterableDataset_len_called
|
685 |
+
if isinstance(self._dataset, IterDataPipe):
|
686 |
+
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
|
687 |
+
shared_rng = torch.Generator()
|
688 |
+
shared_rng.manual_seed(self._shared_seed)
|
689 |
+
self._dataset = torch.utils.data.graph_settings.apply_random_seed(
|
690 |
+
self._dataset, shared_rng
|
691 |
+
)
|
692 |
+
|
693 |
+
def _next_index(self):
|
694 |
+
return next(self._sampler_iter) # may raise StopIteration
|
695 |
+
|
696 |
+
def _next_data(self):
|
697 |
+
raise NotImplementedError
|
698 |
+
|
699 |
+
def __next__(self) -> Any:
|
700 |
+
with torch.autograd.profiler.record_function(self._profile_name):
|
701 |
+
if self._sampler_iter is None:
|
702 |
+
# TODO(https://github.com/pytorch/pytorch/issues/76750)
|
703 |
+
self._reset() # type: ignore[call-arg]
|
704 |
+
data = self._next_data()
|
705 |
+
self._num_yielded += 1
|
706 |
+
if (
|
707 |
+
self._dataset_kind == _DatasetKind.Iterable
|
708 |
+
and self._IterableDataset_len_called is not None
|
709 |
+
and self._num_yielded > self._IterableDataset_len_called
|
710 |
+
):
|
711 |
+
warn_msg = (
|
712 |
+
"Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
|
713 |
+
"samples have been fetched. "
|
714 |
+
).format(
|
715 |
+
self._dataset, self._IterableDataset_len_called, self._num_yielded
|
716 |
+
)
|
717 |
+
if self._num_workers > 0:
|
718 |
+
warn_msg += (
|
719 |
+
"For multiprocessing data-loading, this could be caused by not properly configuring the "
|
720 |
+
"IterableDataset replica at each worker. Please see "
|
721 |
+
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
|
722 |
+
)
|
723 |
+
warnings.warn(warn_msg)
|
724 |
+
return data
|
725 |
+
|
726 |
+
def __len__(self) -> int:
|
727 |
+
return len(self._index_sampler)
|
728 |
+
|
729 |
+
def __getstate__(self):
|
730 |
+
# TODO: add limited pickling support for sharing an iterator
|
731 |
+
# across multiple threads for HOGWILD.
|
732 |
+
# Probably the best way to do this is by moving the sample pushing
|
733 |
+
# to a separate thread and then just sharing the data queue
|
734 |
+
# but signalling the end is tricky without a non-blocking API
|
735 |
+
raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
|
736 |
+
|
737 |
+
|
738 |
+
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
|
739 |
+
def __init__(self, loader):
|
740 |
+
super().__init__(loader)
|
741 |
+
assert self._timeout == 0
|
742 |
+
assert self._num_workers == 0
|
743 |
+
|
744 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
745 |
+
# Taking care of distributed sharding
|
746 |
+
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
|
747 |
+
# For BC, use default SHARDING_PRIORITIES
|
748 |
+
torch.utils.data.graph_settings.apply_sharding(
|
749 |
+
self._dataset, self._world_size, self._rank
|
750 |
+
)
|
751 |
+
|
752 |
+
self._dataset_fetcher = _DatasetKind.create_fetcher(
|
753 |
+
self._dataset_kind,
|
754 |
+
self._dataset,
|
755 |
+
self._auto_collation,
|
756 |
+
self._collate_fn,
|
757 |
+
self._drop_last,
|
758 |
+
)
|
759 |
+
|
760 |
+
def _next_data(self):
|
761 |
+
index = self._next_index() # may raise StopIteration
|
762 |
+
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
|
763 |
+
if self._pin_memory:
|
764 |
+
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
|
765 |
+
return data
|
766 |
+
|
767 |
+
|
768 |
+
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
|
769 |
+
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
|
770 |
+
|
771 |
+
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
|
772 |
+
#
|
773 |
+
# Preliminary:
|
774 |
+
#
|
775 |
+
# Our data model looks like this (queues are indicated with curly brackets):
|
776 |
+
#
|
777 |
+
# main process ||
|
778 |
+
# | ||
|
779 |
+
# {index_queue} ||
|
780 |
+
# | ||
|
781 |
+
# worker processes || DATA
|
782 |
+
# | ||
|
783 |
+
# {worker_result_queue} || FLOW
|
784 |
+
# | ||
|
785 |
+
# pin_memory_thread of main process || DIRECTION
|
786 |
+
# | ||
|
787 |
+
# {data_queue} ||
|
788 |
+
# | ||
|
789 |
+
# data output \/
|
790 |
+
#
|
791 |
+
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
|
792 |
+
# `pin_memory=False`.
|
793 |
+
#
|
794 |
+
#
|
795 |
+
# Terminating multiprocessing logic requires very careful design. In
|
796 |
+
# particular, we need to make sure that
|
797 |
+
#
|
798 |
+
# 1. The iterator gracefully exits the workers when its last reference is
|
799 |
+
# gone or it is depleted.
|
800 |
+
#
|
801 |
+
# In this case, the workers should be gracefully exited because the
|
802 |
+
# main process may still need to continue to run, and we want cleaning
|
803 |
+
# up code in the workers to be executed (e.g., releasing GPU memory).
|
804 |
+
# Naturally, we implement the shutdown logic in `__del__` of
|
805 |
+
# DataLoaderIterator.
|
806 |
+
#
|
807 |
+
# We delay the discussion on the logic in this case until later.
|
808 |
+
#
|
809 |
+
# 2. The iterator exits the workers when the loader process and/or worker
|
810 |
+
# processes exits normally or with error.
|
811 |
+
#
|
812 |
+
# We set all workers and `pin_memory_thread` to have `daemon=True`.
|
813 |
+
#
|
814 |
+
# You may ask, why can't we make the workers non-daemonic, and
|
815 |
+
# gracefully exit using the same logic as we have in `__del__` when the
|
816 |
+
# iterator gets deleted (see 1 above)?
|
817 |
+
#
|
818 |
+
# First of all, `__del__` is **not** guaranteed to be called when
|
819 |
+
# interpreter exits. Even if it is called, by the time it executes,
|
820 |
+
# many Python core library resources may already be freed, and even
|
821 |
+
# simple things like acquiring an internal lock of a queue may hang.
|
822 |
+
# Therefore, in this case, we actually need to prevent `__del__` from
|
823 |
+
# being executed, and rely on the automatic termination of daemonic
|
824 |
+
# children.
|
825 |
+
#
|
826 |
+
# Thus, we register an `atexit` hook that sets a global flag
|
827 |
+
# `_utils.python_exit_status`. Since `atexit` hooks are executed in the
|
828 |
+
# reverse order of registration, we are guaranteed that this flag is
|
829 |
+
# set before library resources we use are freed (which, at least in
|
830 |
+
# CPython, is done via an `atexit` handler defined in
|
831 |
+
# `multiprocessing/util.py`
|
832 |
+
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
|
833 |
+
# registered when an object requiring this mechanism is first
|
834 |
+
# created, e.g., `mp.Queue`
|
835 |
+
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
|
836 |
+
# https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
|
837 |
+
# )
|
838 |
+
#
|
839 |
+
# So in `__del__`, we check if `_utils.python_exit_status` is set or
|
840 |
+
# `None` (freed), and perform no-op if so.
|
841 |
+
#
|
842 |
+
# However, simply letting library clean-up codes run can also be bad,
|
843 |
+
# because such codes (i.e., `multiprocessing.util._exit_function()`)
|
844 |
+
# include join putting threads for `mp.Queue`, which can be blocking.
|
845 |
+
# Hence, the main process putting threads are called with
|
846 |
+
# `cancel_join_thread` at creation. See later section
|
847 |
+
# [ 3b. A process won't hang when putting into a queue; ]
|
848 |
+
# for more details.
|
849 |
+
#
|
850 |
+
# Here are two example cases where library clean-up codes can run
|
851 |
+
# before `__del__` is called:
|
852 |
+
#
|
853 |
+
# 1. If we hold onto a reference to the iterator, it more often
|
854 |
+
# than not tries to do `multiprocessing` library cleaning before
|
855 |
+
# clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
|
856 |
+
# and thus prevents our cleaning-up code to run first.
|
857 |
+
#
|
858 |
+
# 2. A similar issue arises when a `DataLoader` is used in a subprocess.
|
859 |
+
# When a process ends, it shuts all its daemonic children
|
860 |
+
# down with a SIGTERM (instead of joining them without a timeout).
|
861 |
+
# Similarly for threads, but by a different mechanism. This fact,
|
862 |
+
# together with a few implementation details of multiprocessing, forces
|
863 |
+
# us to make workers daemonic. All of our problems arise when a
|
864 |
+
# DataLoader is used in a subprocess, and are caused by multiprocessing
|
865 |
+
# code which looks more or less like this:
|
866 |
+
#
|
867 |
+
# try:
|
868 |
+
# your_function_using_a_dataloader()
|
869 |
+
# finally:
|
870 |
+
# multiprocessing.util._exit_function()
|
871 |
+
#
|
872 |
+
# The joining/termination mentioned above happens inside
|
873 |
+
# `_exit_function()`. Now, if `your_function_using_a_dataloader()`
|
874 |
+
# throws, the stack trace stored in the exception will prevent the
|
875 |
+
# frame which uses `DataLoaderIter` to be freed. If the frame has any
|
876 |
+
# reference to the `DataLoaderIter` (e.g., in a method of the iter),
|
877 |
+
# its `__del__`, which starts the shutdown procedure, will not be
|
878 |
+
# called. That, in turn, means that workers aren't notified. Attempting
|
879 |
+
# to join in `_exit_function` will then result in a hang.
|
880 |
+
#
|
881 |
+
# For context, `_exit_function` is also registered as an `atexit` call.
|
882 |
+
# So it is unclear to me (@ssnl) why this is needed in a finally block.
|
883 |
+
# The code dates back to 2008 and there is no comment on the original
|
884 |
+
# PEP 371 or patch https://bugs.python.org/issue3050 (containing both
|
885 |
+
# the finally block and the `atexit` registration) that explains this.
|
886 |
+
#
|
887 |
+
#
|
888 |
+
# Finally, another choice is to just shutdown workers with logic in 1
|
889 |
+
# above whenever we see an error in `next`. This isn't ideal because
|
890 |
+
# a. It prevents users from using try-catch to resume data loading.
|
891 |
+
# b. It doesn't prevent hanging if users have references to the
|
892 |
+
# iterator.
|
893 |
+
#
|
894 |
+
# 3. All processes exit if any of them die unexpectedly by fatal signals.
|
895 |
+
#
|
896 |
+
# As shown above, the workers are set as daemonic children of the main
|
897 |
+
# process. However, automatic cleaning-up of such child processes only
|
898 |
+
# happens if the parent process exits gracefully (e.g., not via fatal
|
899 |
+
# signals like SIGKILL). So we must ensure that each process will exit
|
900 |
+
# even if the process that should send/receive data to/from it were
|
901 |
+
# killed, i.e.,
|
902 |
+
#
|
903 |
+
# a. A process won't hang when getting from a queue.
|
904 |
+
#
|
905 |
+
# Even with carefully designed data dependencies (i.e., a `put()`
|
906 |
+
# always corresponding to a `get()`), hanging on `get()` can still
|
907 |
+
# happen when data in queue is corrupted (e.g., due to
|
908 |
+
# `cancel_join_thread` or unexpected exit).
|
909 |
+
#
|
910 |
+
# For child exit, we set a timeout whenever we try to get data
|
911 |
+
# from `data_queue`, and check the workers' status on each timeout
|
912 |
+
# and error.
|
913 |
+
# See `_DataLoaderiter._get_batch()` and
|
914 |
+
# `_DataLoaderiter._try_get_data()` for details.
|
915 |
+
#
|
916 |
+
# Additionally, for child exit on non-Windows platforms, we also
|
917 |
+
# register a SIGCHLD handler (which is supported on Windows) on
|
918 |
+
# the main process, which checks if any of the workers fail in the
|
919 |
+
# (Python) handler. This is more efficient and faster in detecting
|
920 |
+
# worker failures, compared to only using the above mechanism.
|
921 |
+
# See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
|
922 |
+
#
|
923 |
+
# For `.get()` calls where the sender(s) is not the workers, we
|
924 |
+
# guard them with timeouts, and check the status of the sender
|
925 |
+
# when timeout happens:
|
926 |
+
# + in the workers, the `_utils.worker.ManagerWatchdog` class
|
927 |
+
# checks the status of the main process.
|
928 |
+
# + if `pin_memory=True`, when getting from `pin_memory_thread`,
|
929 |
+
# check `pin_memory_thread` status periodically until `.get()`
|
930 |
+
# returns or see that `pin_memory_thread` died.
|
931 |
+
#
|
932 |
+
# b. A process won't hang when putting into a queue;
|
933 |
+
#
|
934 |
+
# We use `mp.Queue` which has a separate background thread to put
|
935 |
+
# objects from an unbounded buffer array. The background thread is
|
936 |
+
# daemonic and usually automatically joined when the process
|
937 |
+
# *exits*.
|
938 |
+
#
|
939 |
+
# In case that the receiver has ended abruptly while
|
940 |
+
# reading from the pipe, the join will hang forever. The usual
|
941 |
+
# solution for this in Python is calling `q.cancel_join_thread`,
|
942 |
+
# which prevents automatically joining it when finalizing
|
943 |
+
# (exiting).
|
944 |
+
#
|
945 |
+
# Nonetheless, `cancel_join_thread` must only be called when the
|
946 |
+
# queue is **not** going to be read from or write into by another
|
947 |
+
# process, because it may hold onto a lock or leave corrupted data
|
948 |
+
# in the queue, leading other readers/writers to hang.
|
949 |
+
#
|
950 |
+
# Hence,
|
951 |
+
# + For worker processes, we only do so (for their output
|
952 |
+
# queues, i.e., `worker_result_queue`) before exiting.
|
953 |
+
# + For `pin_memory_thread`, its output queue `data_queue` is a
|
954 |
+
# `queue.Queue` that does blocking `put` if the queue is full.
|
955 |
+
# So there is no above problem, but as a result, in
|
956 |
+
# `_pin_memory_loop`, we do need to wrap the `put` in a loop
|
957 |
+
# that breaks not only upon success, but also when the main
|
958 |
+
# process stops reading, i.e., is shutting down.
|
959 |
+
# + For loader process, we `cancel_join_thread()` for all
|
960 |
+
# `_index_queues` because the whole purpose of workers and
|
961 |
+
# `pin_memory_thread` is to serve the loader process. If
|
962 |
+
# loader process is already exiting, we don't really care if
|
963 |
+
# the queues are corrupted.
|
964 |
+
#
|
965 |
+
#
|
966 |
+
# Now let's get back to 1:
|
967 |
+
# how we gracefully exit the workers when the last reference to the
|
968 |
+
# iterator is gone.
|
969 |
+
#
|
970 |
+
# To achieve this, we implement the following logic along with the design
|
971 |
+
# choices mentioned above:
|
972 |
+
#
|
973 |
+
# `workers_done_event`:
|
974 |
+
# A `multiprocessing.Event` shared among the main process and all worker
|
975 |
+
# processes. This is used to signal the workers that the iterator is
|
976 |
+
# shutting down. After it is set, they will not send processed data to
|
977 |
+
# queues anymore, and only wait for the final `None` before exiting.
|
978 |
+
# `done_event` isn't strictly needed. I.e., we can just check for `None`
|
979 |
+
# from the input queue, but it allows us to skip wasting resources
|
980 |
+
# processing data if we are already shutting down.
|
981 |
+
#
|
982 |
+
# `pin_memory_thread_done_event`:
|
983 |
+
# A `threading.Event` for a similar purpose to that of
|
984 |
+
# `workers_done_event`, but is for the `pin_memory_thread`. The reason
|
985 |
+
# that separate events are needed is that `pin_memory_thread` reads from
|
986 |
+
# the output queue of the workers. But the workers, upon seeing that
|
987 |
+
# `workers_done_event` is set, only want to see the final `None`, and are
|
988 |
+
# not required to flush all data in the output queue (e.g., it may call
|
989 |
+
# `cancel_join_thread` on that queue if its `IterableDataset` iterator
|
990 |
+
# happens to exhaust coincidentally, which is out of the control of the
|
991 |
+
# main process). Thus, since we will exit `pin_memory_thread` before the
|
992 |
+
# workers (see below), two separate events are used.
|
993 |
+
#
|
994 |
+
# NOTE: In short, the protocol is that the main process will set these
|
995 |
+
# `done_event`s and then the corresponding processes/threads a `None`,
|
996 |
+
# and that they may exit at any time after receiving the `None`.
|
997 |
+
#
|
998 |
+
# NOTE: Using `None` as the final signal is valid, since normal data will
|
999 |
+
# always be a 2-tuple with the 1st element being the index of the data
|
1000 |
+
# transferred (different from dataset index/key), and the 2nd being
|
1001 |
+
# either the dataset key or the data sample (depending on which part
|
1002 |
+
# of the data model the queue is at).
|
1003 |
+
#
|
1004 |
+
# [ worker processes ]
|
1005 |
+
# While loader process is alive:
|
1006 |
+
# Get from `index_queue`.
|
1007 |
+
# If get anything else,
|
1008 |
+
# Check `workers_done_event`.
|
1009 |
+
# If set, continue to next iteration
|
1010 |
+
# i.e., keep getting until see the `None`, then exit.
|
1011 |
+
# Otherwise, process data:
|
1012 |
+
# If is fetching from an `IterableDataset` and the iterator
|
1013 |
+
# is exhausted, send an `_IterableDatasetStopIteration`
|
1014 |
+
# object to signal iteration end. The main process, upon
|
1015 |
+
# receiving such an object, will send `None` to this
|
1016 |
+
# worker and not use the corresponding `index_queue`
|
1017 |
+
# anymore.
|
1018 |
+
# If timed out,
|
1019 |
+
# No matter `workers_done_event` is set (still need to see `None`)
|
1020 |
+
# or not, must continue to next iteration.
|
1021 |
+
# (outside loop)
|
1022 |
+
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
|
1023 |
+
# `data_queue.cancel_join_thread()`. (Everything is ending here:
|
1024 |
+
# main process won't read from it;
|
1025 |
+
# other workers will also call
|
1026 |
+
# `cancel_join_thread`.)
|
1027 |
+
#
|
1028 |
+
# [ pin_memory_thread ]
|
1029 |
+
# # No need to check main thread. If this thread is alive, the main loader
|
1030 |
+
# # thread must be alive, because this thread is set as daemonic.
|
1031 |
+
# While `pin_memory_thread_done_event` is not set:
|
1032 |
+
# Get from `index_queue`.
|
1033 |
+
# If timed out, continue to get in the next iteration.
|
1034 |
+
# Otherwise, process data.
|
1035 |
+
# While `pin_memory_thread_done_event` is not set:
|
1036 |
+
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
|
1037 |
+
# If timed out, continue to put in the next iteration.
|
1038 |
+
# Otherwise, break, i.e., continuing to the out loop.
|
1039 |
+
#
|
1040 |
+
# NOTE: we don't check the status of the main thread because
|
1041 |
+
# 1. if the process is killed by fatal signal, `pin_memory_thread`
|
1042 |
+
# ends.
|
1043 |
+
# 2. in other cases, either the cleaning-up in __del__ or the
|
1044 |
+
# automatic exit of daemonic thread will take care of it.
|
1045 |
+
# This won't busy-wait either because `.get(timeout)` does not
|
1046 |
+
# busy-wait.
|
1047 |
+
#
|
1048 |
+
# [ main process ]
|
1049 |
+
# In the DataLoader Iter's `__del__`
|
1050 |
+
# b. Exit `pin_memory_thread`
|
1051 |
+
# i. Set `pin_memory_thread_done_event`.
|
1052 |
+
# ii Put `None` in `worker_result_queue`.
|
1053 |
+
# iii. Join the `pin_memory_thread`.
|
1054 |
+
# iv. `worker_result_queue.cancel_join_thread()`.
|
1055 |
+
#
|
1056 |
+
# c. Exit the workers.
|
1057 |
+
# i. Set `workers_done_event`.
|
1058 |
+
# ii. Put `None` in each worker's `index_queue`.
|
1059 |
+
# iii. Join the workers.
|
1060 |
+
# iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
|
1061 |
+
#
|
1062 |
+
# NOTE: (c) is better placed after (b) because it may leave corrupted
|
1063 |
+
# data in `worker_result_queue`, which `pin_memory_thread`
|
1064 |
+
# reads from, in which case the `pin_memory_thread` can only
|
1065 |
+
# happen at timing out, which is slow. Nonetheless, same thing
|
1066 |
+
# happens if a worker is killed by signal at unfortunate times,
|
1067 |
+
# but in other cases, we are better off having a non-corrupted
|
1068 |
+
# `worker_result_queue` for `pin_memory_thread`.
|
1069 |
+
#
|
1070 |
+
# NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
|
1071 |
+
# can be omitted
|
1072 |
+
#
|
1073 |
+
# NB: `done_event`s isn't strictly needed. E.g., we can just check for
|
1074 |
+
# `None` from `index_queue`, but it allows us to skip wasting resources
|
1075 |
+
# processing indices already in `index_queue` if we are already shutting
|
1076 |
+
# down.
|
1077 |
+
|
1078 |
+
def __init__(self, loader):
|
1079 |
+
super().__init__(loader)
|
1080 |
+
|
1081 |
+
self._prefetch_factor = loader.prefetch_factor
|
1082 |
+
|
1083 |
+
assert self._num_workers > 0
|
1084 |
+
assert self._prefetch_factor > 0
|
1085 |
+
|
1086 |
+
if loader.multiprocessing_context is None:
|
1087 |
+
multiprocessing_context = multiprocessing
|
1088 |
+
else:
|
1089 |
+
multiprocessing_context = loader.multiprocessing_context
|
1090 |
+
|
1091 |
+
self._worker_init_fn = loader.worker_init_fn
|
1092 |
+
|
1093 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
1094 |
+
# Additional worker init function will take care of sharding in MP and Distributed
|
1095 |
+
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
|
1096 |
+
self._worker_init_fn = functools.partial(
|
1097 |
+
_sharding_worker_init_fn,
|
1098 |
+
self._worker_init_fn,
|
1099 |
+
self._world_size,
|
1100 |
+
self._rank,
|
1101 |
+
)
|
1102 |
+
|
1103 |
+
# No certainty which module multiprocessing_context is
|
1104 |
+
self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
|
1105 |
+
self._worker_pids_set = False
|
1106 |
+
self._shutdown = False
|
1107 |
+
self._workers_done_event = multiprocessing_context.Event()
|
1108 |
+
|
1109 |
+
self._index_queues = []
|
1110 |
+
self._workers = []
|
1111 |
+
for i in range(self._num_workers):
|
1112 |
+
# No certainty which module multiprocessing_context is
|
1113 |
+
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
|
1114 |
+
# Need to `cancel_join_thread` here!
|
1115 |
+
# See sections (2) and (3b) above.
|
1116 |
+
index_queue.cancel_join_thread()
|
1117 |
+
w = multiprocessing_context.Process(
|
1118 |
+
target=_worker_loop,
|
1119 |
+
args=(
|
1120 |
+
self._dataset_kind,
|
1121 |
+
self._dataset,
|
1122 |
+
index_queue,
|
1123 |
+
self._worker_result_queue,
|
1124 |
+
self._workers_done_event,
|
1125 |
+
self._auto_collation,
|
1126 |
+
self._collate_fn,
|
1127 |
+
self._drop_last,
|
1128 |
+
self._base_seed,
|
1129 |
+
self._worker_init_fn,
|
1130 |
+
i,
|
1131 |
+
self._num_workers,
|
1132 |
+
self._persistent_workers,
|
1133 |
+
self._shared_seed,
|
1134 |
+
),
|
1135 |
+
)
|
1136 |
+
w.daemon = True
|
1137 |
+
# NB: Process.start() actually takes some time as it needs to
|
1138 |
+
# start a process and pass the arguments over via a pipe.
|
1139 |
+
# Therefore, we only add a worker to self._workers list after
|
1140 |
+
# it started, so that we do not call .join() if program dies
|
1141 |
+
# before it starts, and __del__ tries to join but will get:
|
1142 |
+
# AssertionError: can only join a started process.
|
1143 |
+
w.start()
|
1144 |
+
self._index_queues.append(index_queue)
|
1145 |
+
self._workers.append(w)
|
1146 |
+
|
1147 |
+
if self._pin_memory:
|
1148 |
+
self._pin_memory_thread_done_event = threading.Event()
|
1149 |
+
|
1150 |
+
# Queue is not type-annotated
|
1151 |
+
self._data_queue = queue.Queue() # type: ignore[var-annotated]
|
1152 |
+
if self._pin_memory_device == "xpu":
|
1153 |
+
current_device = torch.xpu.current_device() # type: ignore[attr-defined]
|
1154 |
+
else:
|
1155 |
+
current_device = torch.cuda.current_device() # choose cuda for default
|
1156 |
+
pin_memory_thread = threading.Thread(
|
1157 |
+
target=_utils.pin_memory._pin_memory_loop,
|
1158 |
+
args=(
|
1159 |
+
self._worker_result_queue,
|
1160 |
+
self._data_queue,
|
1161 |
+
current_device,
|
1162 |
+
self._pin_memory_thread_done_event,
|
1163 |
+
self._pin_memory_device,
|
1164 |
+
),
|
1165 |
+
)
|
1166 |
+
pin_memory_thread.daemon = True
|
1167 |
+
pin_memory_thread.start()
|
1168 |
+
# Similar to workers (see comment above), we only register
|
1169 |
+
# pin_memory_thread once it is started.
|
1170 |
+
self._pin_memory_thread = pin_memory_thread
|
1171 |
+
else:
|
1172 |
+
self._data_queue = self._worker_result_queue
|
1173 |
+
|
1174 |
+
# In some rare cases, persistent workers (daemonic processes)
|
1175 |
+
# would be terminated before `__del__` of iterator is invoked
|
1176 |
+
# when main process exits
|
1177 |
+
# It would cause failure when pin_memory_thread tries to read
|
1178 |
+
# corrupted data from worker_result_queue
|
1179 |
+
# atexit is used to shutdown thread and child processes in the
|
1180 |
+
# right sequence before main process exits
|
1181 |
+
if self._persistent_workers and self._pin_memory:
|
1182 |
+
import atexit
|
1183 |
+
|
1184 |
+
for w in self._workers:
|
1185 |
+
atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
|
1186 |
+
|
1187 |
+
# .pid can be None only before process is spawned (not the case, so ignore)
|
1188 |
+
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
|
1189 |
+
_utils.signal_handling._set_SIGCHLD_handler()
|
1190 |
+
self._worker_pids_set = True
|
1191 |
+
self._reset(loader, first_iter=True)
|
1192 |
+
|
1193 |
+
def _reset(self, loader, first_iter=False):
|
1194 |
+
super()._reset(loader, first_iter)
|
1195 |
+
self._send_idx = 0 # idx of the next task to be sent to workers
|
1196 |
+
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
|
1197 |
+
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
|
1198 |
+
# map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
|
1199 |
+
#                  \ (worker_id, data)   if data is already fetched (out-of-order)
|
1200 |
+
self._task_info = {}
|
1201 |
+
self._tasks_outstanding = (
|
1202 |
+
0 # always equal to count(v for v in task_info.values() if len(v) == 1)
|
1203 |
+
)
|
1204 |
+
# A list of booleans representing whether each worker still has work to
|
1205 |
+
# do, i.e., not having exhausted its iterable dataset object. It always
|
1206 |
+
# contains all `True`s if not using an iterable-style dataset
|
1207 |
+
# (i.e., if kind != Iterable).
|
1208 |
+
# Not that this indicates that a worker still has work to do *for this epoch*.
|
1209 |
+
# It does not mean that a worker is dead. In case of `_persistent_workers`,
|
1210 |
+
# the worker will be reset to available in the next epoch.
|
1211 |
+
self._workers_status = [True for i in range(self._num_workers)]
|
1212 |
+
# Reset the worker queue cycle so it resumes next epoch at worker 0
|
1213 |
+
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
|
1214 |
+
# We resume the prefetching in case it was enabled
|
1215 |
+
if not first_iter:
|
1216 |
+
for idx in range(self._num_workers):
|
1217 |
+
self._index_queues[idx].put(
|
1218 |
+
_utils.worker._ResumeIteration(self._shared_seed)
|
1219 |
+
)
|
1220 |
+
resume_iteration_cnt = self._num_workers
|
1221 |
+
while resume_iteration_cnt > 0:
|
1222 |
+
return_idx, return_data = self._get_data()
|
1223 |
+
if isinstance(return_idx, _utils.worker._ResumeIteration):
|
1224 |
+
assert return_data is None
|
1225 |
+
resume_iteration_cnt -= 1
|
1226 |
+
# prime the prefetch loop
|
1227 |
+
for _ in range(self._prefetch_factor * self._num_workers):
|
1228 |
+
self._try_put_index()
|
1229 |
+
|
1230 |
+
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
|
1231 |
+
# Tries to fetch data from `self._data_queue` once for a given timeout.
|
1232 |
+
# This can also be used as inner loop of fetching without timeout, with
|
1233 |
+
# the sender status as the loop condition.
|
1234 |
+
#
|
1235 |
+
# This raises a `RuntimeError` if any worker died expectedly. This error
|
1236 |
+
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
|
1237 |
+
# (only for non-Windows platforms), or the manual check below on errors
|
1238 |
+
# and timeouts.
|
1239 |
+
#
|
1240 |
+
# Returns a 2-tuple:
|
1241 |
+
# (bool: whether successfully get data, any: data if successful else None)
|
1242 |
+
try:
|
1243 |
+
data = self._data_queue.get(timeout=timeout)
|
1244 |
+
return (True, data)
|
1245 |
+
except Exception as e:
|
1246 |
+
# At timeout and error, we manually check whether any worker has
|
1247 |
+
# failed. Note that this is the only mechanism for Windows to detect
|
1248 |
+
# worker failures.
|
1249 |
+
failed_workers = []
|
1250 |
+
for worker_id, w in enumerate(self._workers):
|
1251 |
+
if self._workers_status[worker_id] and not w.is_alive():
|
1252 |
+
failed_workers.append(w)
|
1253 |
+
self._mark_worker_as_unavailable(worker_id)
|
1254 |
+
if len(failed_workers) > 0:
|
1255 |
+
pids_str = ", ".join(str(w.pid) for w in failed_workers)
|
1256 |
+
raise RuntimeError(
|
1257 |
+
"DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
|
1258 |
+
) from e
|
1259 |
+
if isinstance(e, queue.Empty):
|
1260 |
+
return (False, None)
|
1261 |
+
import errno
|
1262 |
+
import tempfile
|
1263 |
+
|
1264 |
+
try:
|
1265 |
+
# Raise an exception if we are this close to the FDs limit.
|
1266 |
+
# Apparently, trying to open only one file is not a sufficient
|
1267 |
+
# test.
|
1268 |
+
# See NOTE [ DataLoader on Linux and open files limit ]
|
1269 |
+
fds_limit_margin = 10
|
1270 |
+
fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
|
1271 |
+
except OSError as e:
|
1272 |
+
if e.errno == errno.EMFILE:
|
1273 |
+
raise RuntimeError(
|
1274 |
+
"Too many open files. Communication with the"
|
1275 |
+
" workers is no longer possible. Please increase the"
|
1276 |
+
" limit using `ulimit -n` in the shell or change the"
|
1277 |
+
" sharing strategy by calling"
|
1278 |
+
" `torch.multiprocessing.set_sharing_strategy('file_system')`"
|
1279 |
+
" at the beginning of your code"
|
1280 |
+
) from None
|
1281 |
+
raise
|
1282 |
+
|
1283 |
+
# NOTE [ DataLoader on Linux and open files limit ]
|
1284 |
+
#
|
1285 |
+
# On Linux when DataLoader is used with multiprocessing we pass the data between
|
1286 |
+
# the root process and the workers through SHM files. We remove those files from
|
1287 |
+
# the filesystem as soon as they are created and keep them alive by
|
1288 |
+
# passing around their file descriptors through AF_UNIX sockets. (See
|
1289 |
+
# docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
|
1290 |
+
# the wiki (https://github.com/pytorch/pytorch/wiki).)
|
1291 |
+
#
|
1292 |
+
# This sometimes leads us to exceeding the open files limit. When that happens,
|
1293 |
+
# and the offending file descriptor is coming over a socket, the `socket` Python
|
1294 |
+
# package silently strips the file descriptor from the message, setting only the
|
1295 |
+
# `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
|
1296 |
+
# it _indicates that some control data were discarded due to lack of space in
|
1297 |
+
# the buffer for ancillary data_). This might reflect the C implementation of
|
1298 |
+
# AF_UNIX sockets.
|
1299 |
+
#
|
1300 |
+
# This behaviour can be reproduced with the script and instructions at the
|
1301 |
+
# bottom of this note.
|
1302 |
+
#
|
1303 |
+
# When that happens, the standard Python `multiprocessing` (and not
|
1304 |
+
# `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
|
1305 |
+
#
|
1306 |
+
# Sometimes, instead of the FD being stripped, you may get an `OSError:
|
1307 |
+
# Too many open files`, both in the script below and in DataLoader. However,
|
1308 |
+
# this is rare and seems to be nondeterministic.
|
1309 |
+
#
|
1310 |
+
#
|
1311 |
+
# #!/usr/bin/env python3
|
1312 |
+
# import sys
|
1313 |
+
# import socket
|
1314 |
+
# import os
|
1315 |
+
# import array
|
1316 |
+
# import shutil
|
1317 |
+
# import socket
|
1318 |
+
#
|
1319 |
+
#
|
1320 |
+
# if len(sys.argv) != 4:
|
1321 |
+
# print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
|
1322 |
+
# sys.exit(1)
|
1323 |
+
#
|
1324 |
+
# if __name__ == '__main__':
|
1325 |
+
# dirname = sys.argv[1]
|
1326 |
+
# sock_path = dirname + "/sock"
|
1327 |
+
# iterations = int(sys.argv[2])
|
1328 |
+
# def dummy_path(i):
|
1329 |
+
# return dirname + "/" + str(i) + ".dummy"
|
1330 |
+
#
|
1331 |
+
#
|
1332 |
+
# if sys.argv[3] == 'send':
|
1333 |
+
# while not os.path.exists(sock_path):
|
1334 |
+
# pass
|
1335 |
+
# client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
1336 |
+
# client.connect(sock_path)
|
1337 |
+
# for i in range(iterations):
|
1338 |
+
# fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
|
1339 |
+
# ancdata = array.array('i', [fd])
|
1340 |
+
# msg = bytes([i % 256])
|
1341 |
+
# print("Sending fd ", fd, " (iteration #", i, ")")
|
1342 |
+
# client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
|
1343 |
+
#
|
1344 |
+
#
|
1345 |
+
# else:
|
1346 |
+
# assert sys.argv[3] == 'recv'
|
1347 |
+
#
|
1348 |
+
# if os.path.exists(dirname):
|
1349 |
+
# raise Exception("Directory exists")
|
1350 |
+
#
|
1351 |
+
# os.mkdir(dirname)
|
1352 |
+
#
|
1353 |
+
# print("Opening socket...")
|
1354 |
+
# server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
|
1355 |
+
# server.bind(sock_path)
|
1356 |
+
#
|
1357 |
+
# print("Listening...")
|
1358 |
+
# for i in range(iterations):
|
1359 |
+
# a = array.array('i')
|
1360 |
+
# msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
|
1361 |
+
# assert(len(ancdata) == 1)
|
1362 |
+
# cmsg_level, cmsg_type, cmsg_data = ancdata[0]
|
1363 |
+
# a.frombytes(cmsg_data)
|
1364 |
+
# print("Received fd ", a[0], " (iteration #", i, ")")
|
1365 |
+
#
|
1366 |
+
# shutil.rmtree(dirname)
|
1367 |
+
#
|
1368 |
+
# Steps to reproduce:
|
1369 |
+
#
|
1370 |
+
# 1. Run two shells and set lower file descriptor limit in the receiving one:
|
1371 |
+
# (shell1) ulimit -n 1020
|
1372 |
+
# (shell2) ulimit -n 1022
|
1373 |
+
#
|
1374 |
+
# 2. Run the script above with the `recv` option in the first shell
|
1375 |
+
# (shell1) ./test_socket.py sock_tmp 1017 recv
|
1376 |
+
#
|
1377 |
+
# 3. Run the script with the `send` option in the second shell:
|
1378 |
+
# (shell2) ./test_socket.py sock_tmp 1017 send
|
1379 |
+
|
1380 |
+
def _get_data(self):
|
1381 |
+
# Fetches data from `self._data_queue`.
|
1382 |
+
#
|
1383 |
+
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
|
1384 |
+
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
|
1385 |
+
# in a loop. This is the only mechanism to detect worker failures for
|
1386 |
+
# Windows. For other platforms, a SIGCHLD handler is also used for
|
1387 |
+
# worker failure detection.
|
1388 |
+
#
|
1389 |
+
# If `pin_memory=True`, we also need check if `pin_memory_thread` had
|
1390 |
+
# died at timeouts.
|
1391 |
+
if self._timeout > 0:
|
1392 |
+
success, data = self._try_get_data(self._timeout)
|
1393 |
+
if success:
|
1394 |
+
return data
|
1395 |
+
else:
|
1396 |
+
raise RuntimeError(
|
1397 |
+
"DataLoader timed out after {} seconds".format(self._timeout)
|
1398 |
+
)
|
1399 |
+
elif self._pin_memory:
|
1400 |
+
while self._pin_memory_thread.is_alive():
|
1401 |
+
success, data = self._try_get_data()
|
1402 |
+
if success:
|
1403 |
+
return data
|
1404 |
+
else:
|
1405 |
+
# while condition is false, i.e., pin_memory_thread died.
|
1406 |
+
raise RuntimeError("Pin memory thread exited unexpectedly")
|
1407 |
+
# In this case, `self._data_queue` is a `queue.Queue`,. But we don't
|
1408 |
+
# need to call `.task_done()` because we don't use `.join()`.
|
1409 |
+
else:
|
1410 |
+
while True:
|
1411 |
+
success, data = self._try_get_data()
|
1412 |
+
if success:
|
1413 |
+
return data
|
1414 |
+
|
1415 |
+
def _next_data(self):
|
1416 |
+
while True:
|
1417 |
+
# If the worker responsible for `self._rcvd_idx` has already ended
|
1418 |
+
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
|
1419 |
+
# we try to advance `self._rcvd_idx` to find the next valid index.
|
1420 |
+
#
|
1421 |
+
# This part needs to run in the loop because both the `self._get_data()`
|
1422 |
+
# call and `_IterableDatasetStopIteration` check below can mark
|
1423 |
+
# extra worker(s) as dead.
|
1424 |
+
while self._rcvd_idx < self._send_idx:
|
1425 |
+
info = self._task_info[self._rcvd_idx]
|
1426 |
+
worker_id = info[0]
|
1427 |
+
if (
|
1428 |
+
len(info) == 2 or self._workers_status[worker_id]
|
1429 |
+
): # has data or is still active
|
1430 |
+
break
|
1431 |
+
del self._task_info[self._rcvd_idx]
|
1432 |
+
self._rcvd_idx += 1
|
1433 |
+
else:
|
1434 |
+
# no valid `self._rcvd_idx` is found (i.e., didn't break)
|
1435 |
+
if not self._persistent_workers:
|
1436 |
+
self._shutdown_workers()
|
1437 |
+
raise StopIteration
|
1438 |
+
|
1439 |
+
# Now `self._rcvd_idx` is the batch index we want to fetch
|
1440 |
+
|
1441 |
+
# Check if the next sample has already been generated
|
1442 |
+
if len(self._task_info[self._rcvd_idx]) == 2:
|
1443 |
+
data = self._task_info.pop(self._rcvd_idx)[1]
|
1444 |
+
return self._process_data(data)
|
1445 |
+
|
1446 |
+
assert not self._shutdown and self._tasks_outstanding > 0
|
1447 |
+
idx, data = self._get_data()
|
1448 |
+
self._tasks_outstanding -= 1
|
1449 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
1450 |
+
# Check for _IterableDatasetStopIteration
|
1451 |
+
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
|
1452 |
+
if self._persistent_workers:
|
1453 |
+
self._workers_status[data.worker_id] = False
|
1454 |
+
else:
|
1455 |
+
self._mark_worker_as_unavailable(data.worker_id)
|
1456 |
+
self._try_put_index()
|
1457 |
+
continue
|
1458 |
+
|
1459 |
+
if idx != self._rcvd_idx:
|
1460 |
+
# store out-of-order samples
|
1461 |
+
self._task_info[idx] += (data,)
|
1462 |
+
else:
|
1463 |
+
del self._task_info[idx]
|
1464 |
+
return self._process_data(data)
|
1465 |
+
|
1466 |
+
def _try_put_index(self):
|
1467 |
+
assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
|
1468 |
+
|
1469 |
+
try:
|
1470 |
+
index = self._next_index()
|
1471 |
+
except StopIteration:
|
1472 |
+
return
|
1473 |
+
for _ in range(self._num_workers): # find the next active worker, if any
|
1474 |
+
worker_queue_idx = next(self._worker_queue_idx_cycle)
|
1475 |
+
if self._workers_status[worker_queue_idx]:
|
1476 |
+
break
|
1477 |
+
else:
|
1478 |
+
# not found (i.e., didn't break)
|
1479 |
+
return
|
1480 |
+
|
1481 |
+
self._index_queues[worker_queue_idx].put((self._send_idx, index))
|
1482 |
+
self._task_info[self._send_idx] = (worker_queue_idx,)
|
1483 |
+
self._tasks_outstanding += 1
|
1484 |
+
self._send_idx += 1
|
1485 |
+
|
1486 |
+
def _process_data(self, data):
|
1487 |
+
self._rcvd_idx += 1
|
1488 |
+
self._try_put_index()
|
1489 |
+
if isinstance(data, ExceptionWrapper):
|
1490 |
+
data.reraise()
|
1491 |
+
return data
|
1492 |
+
|
1493 |
+
def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
|
1494 |
+
# Mark a worker as having finished its work e.g., due to
|
1495 |
+
# exhausting an `IterableDataset`. This should be used only when this
|
1496 |
+
# `_MultiProcessingDataLoaderIter` is going to continue running.
|
1497 |
+
|
1498 |
+
assert self._workers_status[worker_id] or (
|
1499 |
+
self._persistent_workers and shutdown
|
1500 |
+
)
|
1501 |
+
|
1502 |
+
# Signal termination to that specific worker.
|
1503 |
+
q = self._index_queues[worker_id]
|
1504 |
+
# Indicate that no more data will be put on this queue by the current
|
1505 |
+
# process.
|
1506 |
+
q.put(None)
|
1507 |
+
|
1508 |
+
# Note that we don't actually join the worker here, nor do we remove the
|
1509 |
+
# worker's pid from C side struct because (1) joining may be slow, and
|
1510 |
+
# (2) since we don't join, the worker may still raise error, and we
|
1511 |
+
# prefer capturing those, rather than ignoring them, even though they
|
1512 |
+
# are raised after the worker has finished its job.
|
1513 |
+
# Joinning is deferred to `_shutdown_workers`, which it is called when
|
1514 |
+
# all workers finish their jobs (e.g., `IterableDataset` replicas) or
|
1515 |
+
# when this iterator is garbage collected.
|
1516 |
+
|
1517 |
+
self._workers_status[worker_id] = False
|
1518 |
+
|
1519 |
+
assert self._workers_done_event.is_set() == shutdown
|
1520 |
+
|
1521 |
+
def _shutdown_workers(self):
|
1522 |
+
# Called when shutting down this `_MultiProcessingDataLoaderIter`.
|
1523 |
+
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
|
1524 |
+
# the logic of this function.
|
1525 |
+
if (
|
1526 |
+
_utils is None
|
1527 |
+
or _utils.python_exit_status is True
|
1528 |
+
or _utils.python_exit_status is None
|
1529 |
+
):
|
1530 |
+
# See (2) of the note. If Python is shutting down, do no-op.
|
1531 |
+
return
|
1532 |
+
# Normal exit when last reference is gone / iterator is depleted.
|
1533 |
+
# See (1) and the second half of the note.
|
1534 |
+
if not self._shutdown:
|
1535 |
+
self._shutdown = True
|
1536 |
+
try:
|
1537 |
+
# Normal exit when last reference is gone / iterator is depleted.
|
1538 |
+
# See (1) and the second half of the note.
|
1539 |
+
|
1540 |
+
# Exit `pin_memory_thread` first because exiting workers may leave
|
1541 |
+
# corrupted data in `worker_result_queue` which `pin_memory_thread`
|
1542 |
+
# reads from.
|
1543 |
+
if hasattr(self, "_pin_memory_thread"):
|
1544 |
+
# Use hasattr in case error happens before we set the attribute.
|
1545 |
+
self._pin_memory_thread_done_event.set()
|
1546 |
+
# Send something to pin_memory_thread in case it is waiting
|
1547 |
+
# so that it can wake up and check `pin_memory_thread_done_event`
|
1548 |
+
self._worker_result_queue.put((None, None))
|
1549 |
+
self._pin_memory_thread.join()
|
1550 |
+
self._worker_result_queue.cancel_join_thread()
|
1551 |
+
self._worker_result_queue.close()
|
1552 |
+
|
1553 |
+
# Exit workers now.
|
1554 |
+
self._workers_done_event.set()
|
1555 |
+
for worker_id in range(len(self._workers)):
|
1556 |
+
# Get number of workers from `len(self._workers)` instead of
|
1557 |
+
# `self._num_workers` in case we error before starting all
|
1558 |
+
# workers.
|
1559 |
+
# If we are using workers_status with persistent_workers
|
1560 |
+
# we have to shut it down because the worker is paused
|
1561 |
+
if self._persistent_workers or self._workers_status[worker_id]:
|
1562 |
+
self._mark_worker_as_unavailable(worker_id, shutdown=True)
|
1563 |
+
for w in self._workers:
|
1564 |
+
# We should be able to join here, but in case anything went
|
1565 |
+
# wrong, we set a timeout and if the workers fail to join,
|
1566 |
+
# they are killed in the `finally` block.
|
1567 |
+
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
|
1568 |
+
for q in self._index_queues:
|
1569 |
+
q.cancel_join_thread()
|
1570 |
+
q.close()
|
1571 |
+
finally:
|
1572 |
+
# Even though all this function does is putting into queues that
|
1573 |
+
# we have called `cancel_join_thread` on, weird things can
|
1574 |
+
# happen when a worker is killed by a signal, e.g., hanging in
|
1575 |
+
# `Event.set()`. So we need to guard this with SIGCHLD handler,
|
1576 |
+
# and remove pids from the C side data structure only at the
|
1577 |
+
# end.
|
1578 |
+
#
|
1579 |
+
# FIXME: Unfortunately, for Windows, we are missing a worker
|
1580 |
+
# error detection mechanism here in this function, as it
|
1581 |
+
# doesn't provide a SIGCHLD handler.
|
1582 |
+
if self._worker_pids_set:
|
1583 |
+
_utils.signal_handling._remove_worker_pids(id(self))
|
1584 |
+
self._worker_pids_set = False
|
1585 |
+
for w in self._workers:
|
1586 |
+
if w.is_alive():
|
1587 |
+
# Existing mechanisms try to make the workers exit
|
1588 |
+
# peacefully, but in case that we unfortunately reach
|
1589 |
+
# here, which we shouldn't, (e.g., pytorch/pytorch#39570),
|
1590 |
+
# we kill the worker.
|
1591 |
+
w.terminate()
|
1592 |
+
|
1593 |
+
# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
|
1594 |
+
@staticmethod
|
1595 |
+
def _clean_up_worker(w):
|
1596 |
+
try:
|
1597 |
+
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
|
1598 |
+
finally:
|
1599 |
+
if w.is_alive():
|
1600 |
+
w.terminate()
|
1601 |
+
|
1602 |
+
def __del__(self):
|
1603 |
+
self._shutdown_workers()
|
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
r""""This file is based on torch/utils/data/_utils/worker.py
|
2 |
+
|
3 |
+
Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
4 |
+
These **needs** to be in global scope since Py2 doesn't support serializing
|
5 |
+
static methods.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import queue
|
10 |
+
import random
|
11 |
+
from dataclasses import dataclass
|
12 |
+
from typing import TYPE_CHECKING, Optional, Union
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch._utils import ExceptionWrapper
|
16 |
+
from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
|
17 |
+
MP_STATUS_CHECK_INTERVAL, signal_handling)
|
18 |
+
|
19 |
+
if TYPE_CHECKING:
|
20 |
+
from torch.utils.data import Dataset
|
21 |
+
|
22 |
+
from .controller import RRSController
|
23 |
+
|
24 |
+
if IS_WINDOWS:
|
25 |
+
import ctypes
|
26 |
+
from ctypes.wintypes import BOOL, DWORD, HANDLE
|
27 |
+
|
28 |
+
# On Windows, the parent ID of the worker process remains unchanged when the manager process
|
29 |
+
# is gone, and the only way to check it through OS is to let the worker have a process handle
|
30 |
+
# of the manager and ask if the process status has changed.
|
31 |
+
class ManagerWatchdog:
|
32 |
+
def __init__(self):
|
33 |
+
self.manager_pid = os.getppid()
|
34 |
+
|
35 |
+
# mypy cannot detect this code is windows only
|
36 |
+
self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
|
37 |
+
self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
|
38 |
+
self.kernel32.OpenProcess.restype = HANDLE
|
39 |
+
self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
|
40 |
+
self.kernel32.WaitForSingleObject.restype = DWORD
|
41 |
+
|
42 |
+
# Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
|
43 |
+
SYNCHRONIZE = 0x00100000
|
44 |
+
self.manager_handle = self.kernel32.OpenProcess(
|
45 |
+
SYNCHRONIZE, 0, self.manager_pid
|
46 |
+
)
|
47 |
+
|
48 |
+
if not self.manager_handle:
|
49 |
+
raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
|
50 |
+
|
51 |
+
self.manager_dead = False
|
52 |
+
|
53 |
+
def is_alive(self):
|
54 |
+
if not self.manager_dead:
|
55 |
+
# Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
|
56 |
+
self.manager_dead = (
|
57 |
+
self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
|
58 |
+
)
|
59 |
+
return not self.manager_dead
|
60 |
+
|
61 |
+
else:
|
62 |
+
|
63 |
+
class ManagerWatchdog: # type: ignore[no-redef]
|
64 |
+
def __init__(self):
|
65 |
+
self.manager_pid = os.getppid()
|
66 |
+
self.manager_dead = False
|
67 |
+
|
68 |
+
def is_alive(self):
|
69 |
+
if not self.manager_dead:
|
70 |
+
self.manager_dead = os.getppid() != self.manager_pid
|
71 |
+
return not self.manager_dead
|
72 |
+
|
73 |
+
|
74 |
+
_worker_info = None
|
75 |
+
|
76 |
+
|
77 |
+
class WorkerInfo:
|
78 |
+
id: int
|
79 |
+
num_workers: int
|
80 |
+
seed: int
|
81 |
+
dataset: "Dataset"
|
82 |
+
__initialized = False
|
83 |
+
|
84 |
+
def __init__(self, **kwargs):
|
85 |
+
for k, v in kwargs.items():
|
86 |
+
setattr(self, k, v)
|
87 |
+
self.__keys = tuple(kwargs.keys())
|
88 |
+
self.__initialized = True
|
89 |
+
|
90 |
+
def __setattr__(self, key, val):
|
91 |
+
if self.__initialized:
|
92 |
+
raise RuntimeError(
|
93 |
+
"Cannot assign attributes to {} objects".format(self.__class__.__name__)
|
94 |
+
)
|
95 |
+
return super().__setattr__(key, val)
|
96 |
+
|
97 |
+
def __repr__(self):
|
98 |
+
items = []
|
99 |
+
for k in self.__keys:
|
100 |
+
items.append("{}={}".format(k, getattr(self, k)))
|
101 |
+
return "{}({})".format(self.__class__.__name__, ", ".join(items))
|
102 |
+
|
103 |
+
|
104 |
+
def get_worker_info() -> Optional[WorkerInfo]:
|
105 |
+
r"""Returns the information about the current
|
106 |
+
:class:`~torch.utils.data.DataLoader` iterator worker process.
|
107 |
+
|
108 |
+
When called in a worker, this returns an object guaranteed to have the
|
109 |
+
following attributes:
|
110 |
+
|
111 |
+
* :attr:`id`: the current worker id.
|
112 |
+
* :attr:`num_workers`: the total number of workers.
|
113 |
+
* :attr:`seed`: the random seed set for the current worker. This value is
|
114 |
+
determined by main process RNG and the worker id. See
|
115 |
+
:class:`~torch.utils.data.DataLoader`'s documentation for more details.
|
116 |
+
* :attr:`dataset`: the copy of the dataset object in **this** process. Note
|
117 |
+
that this will be a different object in a different process than the one
|
118 |
+
in the main process.
|
119 |
+
|
120 |
+
When called in the main process, this returns ``None``.
|
121 |
+
|
122 |
+
.. note::
|
123 |
+
When used in a :attr:`worker_init_fn` passed over to
|
124 |
+
:class:`~torch.utils.data.DataLoader`, this method can be useful to
|
125 |
+
set up each worker process differently, for instance, using ``worker_id``
|
126 |
+
to configure the ``dataset`` object to only read a specific fraction of a
|
127 |
+
sharded dataset, or use ``seed`` to seed other libraries used in dataset
|
128 |
+
code.
|
129 |
+
"""
|
130 |
+
return _worker_info
|
131 |
+
|
132 |
+
|
133 |
+
r"""Dummy class used to signal the end of an IterableDataset"""
|
134 |
+
|
135 |
+
|
136 |
+
@dataclass(frozen=True)
|
137 |
+
class _IterableDatasetStopIteration:
|
138 |
+
worker_id: int
|
139 |
+
|
140 |
+
|
141 |
+
r"""Dummy class used to resume the fetching when worker reuse is enabled"""
|
142 |
+
|
143 |
+
|
144 |
+
@dataclass(frozen=True)
|
145 |
+
class _ResumeIteration:
|
146 |
+
seed: Optional[int] = None
|
147 |
+
|
148 |
+
|
149 |
+
# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
|
150 |
+
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
|
151 |
+
# It's MIT licensed, here is the copyright:
|
152 |
+
|
153 |
+
# Copyright (c) 2015 Melissa E. O'Neill
|
154 |
+
# Copyright (c) 2019 NumPy Developers
|
155 |
+
#
|
156 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
157 |
+
# of this software and associated documentation files (the "Software"), to deal
|
158 |
+
# in the Software without restriction, including without limitation the rights
|
159 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
160 |
+
# copies of the Software, and to permit persons to whom the Software is
|
161 |
+
# furnished to do so, subject to the following conditions:
|
162 |
+
#
|
163 |
+
# The above copyright notice and this permission notice shall be included in
|
164 |
+
# all copies or substantial portions of the Software.
|
165 |
+
#
|
166 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
167 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
168 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
169 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
170 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
171 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
172 |
+
# SOFTWARE.
|
173 |
+
|
174 |
+
|
175 |
+
# This function generates an array of int32 as the seed for
|
176 |
+
# `numpy.random`, in order to prevent state collision due to same
|
177 |
+
# seed and algorithm for `numpy.random` and `random` modules.
|
178 |
+
# TODO: Implement `SeedSequence` like object for `torch.random`
|
179 |
+
def _generate_state(base_seed, worker_id):
|
180 |
+
INIT_A = 0x43B0D7E5
|
181 |
+
MULT_A = 0x931E8875
|
182 |
+
INIT_B = 0x8B51F9DD
|
183 |
+
MULT_B = 0x58F38DED
|
184 |
+
MIX_MULT_L = 0xCA01F9DD
|
185 |
+
MIX_MULT_R = 0x4973F715
|
186 |
+
XSHIFT = 4 * 8 // 2
|
187 |
+
MASK32 = 0xFFFFFFFF
|
188 |
+
|
189 |
+
entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
|
190 |
+
pool = [0] * 4
|
191 |
+
|
192 |
+
hash_const_A = INIT_A
|
193 |
+
|
194 |
+
def hash(value):
|
195 |
+
nonlocal hash_const_A
|
196 |
+
value = (value ^ hash_const_A) & MASK32
|
197 |
+
hash_const_A = (hash_const_A * MULT_A) & MASK32
|
198 |
+
value = (value * hash_const_A) & MASK32
|
199 |
+
value = (value ^ (value >> XSHIFT)) & MASK32
|
200 |
+
return value
|
201 |
+
|
202 |
+
def mix(x, y):
|
203 |
+
result_x = (MIX_MULT_L * x) & MASK32
|
204 |
+
result_y = (MIX_MULT_R * y) & MASK32
|
205 |
+
result = (result_x - result_y) & MASK32
|
206 |
+
result = (result ^ (result >> XSHIFT)) & MASK32
|
207 |
+
return result
|
208 |
+
|
209 |
+
# Add in the entropy to the pool.
|
210 |
+
for i in range(len(pool)):
|
211 |
+
pool[i] = hash(entropy[i])
|
212 |
+
|
213 |
+
# Mix all bits together so late bits can affect earlier bits.
|
214 |
+
for i_src in range(len(pool)):
|
215 |
+
for i_dst in range(len(pool)):
|
216 |
+
if i_src != i_dst:
|
217 |
+
pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
|
218 |
+
|
219 |
+
hash_const_B = INIT_B
|
220 |
+
state = []
|
221 |
+
for i_dst in range(4):
|
222 |
+
data_val = pool[i_dst]
|
223 |
+
data_val = (data_val ^ hash_const_B) & MASK32
|
224 |
+
hash_const_B = (hash_const_B * MULT_B) & MASK32
|
225 |
+
data_val = (data_val * hash_const_B) & MASK32
|
226 |
+
data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
|
227 |
+
state.append(data_val)
|
228 |
+
return state
|
229 |
+
|
230 |
+
|
231 |
+
def _worker_loop(
|
232 |
+
dataset_kind,
|
233 |
+
dataset,
|
234 |
+
index_queue,
|
235 |
+
data_queue,
|
236 |
+
done_event,
|
237 |
+
auto_collation,
|
238 |
+
collate_fn,
|
239 |
+
drop_last,
|
240 |
+
base_seed,
|
241 |
+
init_fn,
|
242 |
+
worker_id,
|
243 |
+
num_workers,
|
244 |
+
persistent_workers,
|
245 |
+
shared_seed,
|
246 |
+
):
|
247 |
+
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
|
248 |
+
# logic of this function.
|
249 |
+
|
250 |
+
try:
|
251 |
+
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
|
252 |
+
# module's handlers are executed after Python returns from C low-level
|
253 |
+
# handlers, likely when the same fatal signal had already happened
|
254 |
+
# again.
|
255 |
+
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
|
256 |
+
signal_handling._set_worker_signal_handlers()
|
257 |
+
|
258 |
+
torch.set_num_threads(1)
|
259 |
+
seed = base_seed + worker_id
|
260 |
+
random.seed(seed)
|
261 |
+
torch.manual_seed(seed)
|
262 |
+
if HAS_NUMPY:
|
263 |
+
np_seed = _generate_state(base_seed, worker_id)
|
264 |
+
import numpy as np
|
265 |
+
|
266 |
+
np.random.seed(np_seed)
|
267 |
+
|
268 |
+
from torch.utils.data import IterDataPipe
|
269 |
+
from torch.utils.data.graph_settings import apply_random_seed
|
270 |
+
|
271 |
+
shared_rng = torch.Generator()
|
272 |
+
if isinstance(dataset, IterDataPipe):
|
273 |
+
assert shared_seed is not None
|
274 |
+
shared_rng.manual_seed(shared_seed)
|
275 |
+
dataset = apply_random_seed(dataset, shared_rng)
|
276 |
+
|
277 |
+
global _worker_info
|
278 |
+
_worker_info = WorkerInfo(
|
279 |
+
id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
|
280 |
+
)
|
281 |
+
|
282 |
+
from torch.utils.data import _DatasetKind
|
283 |
+
|
284 |
+
init_exception = None
|
285 |
+
|
286 |
+
try:
|
287 |
+
if init_fn is not None:
|
288 |
+
init_fn(worker_id)
|
289 |
+
|
290 |
+
fetcher = _DatasetKind.create_fetcher(
|
291 |
+
dataset_kind, dataset, auto_collation, collate_fn, drop_last
|
292 |
+
)
|
293 |
+
except Exception:
|
294 |
+
init_exception = ExceptionWrapper(
|
295 |
+
where="in DataLoader worker process {}".format(worker_id)
|
296 |
+
)
|
297 |
+
|
298 |
+
# When using Iterable mode, some worker can exit earlier than others due
|
299 |
+
# to the IterableDataset behaving differently for different workers.
|
300 |
+
# When such things happen, an `_IterableDatasetStopIteration` object is
|
301 |
+
# sent over to the main process with the ID of this worker, so that the
|
302 |
+
# main process won't send more tasks to this worker, and will send
|
303 |
+
# `None` to this worker to properly exit it.
|
304 |
+
#
|
305 |
+
# Note that we cannot set `done_event` from a worker as it is shared
|
306 |
+
# among all processes. Instead, we set the `iteration_end` flag to
|
307 |
+
# signify that the iterator is exhausted. When either `done_event` or
|
308 |
+
# `iteration_end` is set, we skip all processing step and just wait for
|
309 |
+
# `None`.
|
310 |
+
iteration_end = False
|
311 |
+
|
312 |
+
watchdog = ManagerWatchdog()
|
313 |
+
|
314 |
+
while watchdog.is_alive():
|
315 |
+
try:
|
316 |
+
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
317 |
+
except queue.Empty:
|
318 |
+
continue
|
319 |
+
if isinstance(r, _ResumeIteration):
|
320 |
+
# Acknowledge the main process
|
321 |
+
data_queue.put((r, None))
|
322 |
+
iteration_end = False
|
323 |
+
|
324 |
+
if isinstance(dataset, IterDataPipe):
|
325 |
+
assert r.seed is not None
|
326 |
+
shared_rng.manual_seed(r.seed)
|
327 |
+
dataset = apply_random_seed(dataset, shared_rng)
|
328 |
+
|
329 |
+
# Recreate the fetcher for worker-reuse policy
|
330 |
+
fetcher = _DatasetKind.create_fetcher(
|
331 |
+
dataset_kind, dataset, auto_collation, collate_fn, drop_last
|
332 |
+
)
|
333 |
+
continue
|
334 |
+
elif r is None:
|
335 |
+
# Received the final signal
|
336 |
+
assert done_event.is_set() or iteration_end
|
337 |
+
break
|
338 |
+
elif done_event.is_set() or iteration_end:
|
339 |
+
# `done_event` is set. But I haven't received the final signal
|
340 |
+
# (None) yet. I will keep continuing until get it, and skip the
|
341 |
+
# processing steps.
|
342 |
+
continue
|
343 |
+
idx, index = r
|
344 |
+
""" Added """
|
345 |
+
RRSController.sample_resolution(batch_id=idx)
|
346 |
+
""" Added """
|
347 |
+
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
|
348 |
+
if init_exception is not None:
|
349 |
+
data = init_exception
|
350 |
+
init_exception = None
|
351 |
+
else:
|
352 |
+
try:
|
353 |
+
data = fetcher.fetch(index)
|
354 |
+
except Exception as e:
|
355 |
+
if (
|
356 |
+
isinstance(e, StopIteration)
|
357 |
+
and dataset_kind == _DatasetKind.Iterable
|
358 |
+
):
|
359 |
+
data = _IterableDatasetStopIteration(worker_id)
|
360 |
+
# Set `iteration_end`
|
361 |
+
# (1) to save future `next(...)` calls, and
|
362 |
+
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
|
363 |
+
iteration_end = True
|
364 |
+
else:
|
365 |
+
# It is important that we don't store exc_info in a variable.
|
366 |
+
# `ExceptionWrapper` does the correct thing.
|
367 |
+
# See NOTE [ Python Traceback Reference Cycle Problem ]
|
368 |
+
data = ExceptionWrapper(
|
369 |
+
where="in DataLoader worker process {}".format(worker_id)
|
370 |
+
)
|
371 |
+
data_queue.put((idx, data))
|
372 |
+
del data, idx, index, r # save memory
|
373 |
+
except KeyboardInterrupt:
|
374 |
+
# Main process will raise KeyboardInterrupt anyways.
|
375 |
+
pass
|
376 |
+
if done_event.is_set():
|
377 |
+
data_queue.cancel_join_thread()
|
378 |
+
data_queue.close()
|
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
4 |
+
|
5 |
+
import copy
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
import torchvision.transforms.functional as F
|
10 |
+
|
11 |
+
from efficientvit.models.utils import torch_random_choices
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
"RRSController",
|
15 |
+
"get_interpolate",
|
16 |
+
"MyRandomResizedCrop",
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
class RRSController:
|
21 |
+
ACTIVE_SIZE = (224, 224)
|
22 |
+
IMAGE_SIZE_LIST = [(224, 224)]
|
23 |
+
|
24 |
+
CHOICE_LIST = None
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def get_candidates() -> list[tuple[int, int]]:
|
28 |
+
return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
|
29 |
+
|
30 |
+
@staticmethod
|
31 |
+
def sample_resolution(batch_id: int) -> None:
|
32 |
+
RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def set_epoch(epoch: int, batch_per_epoch: int) -> None:
|
36 |
+
g = torch.Generator()
|
37 |
+
g.manual_seed(epoch)
|
38 |
+
RRSController.CHOICE_LIST = torch_random_choices(
|
39 |
+
RRSController.get_candidates(),
|
40 |
+
g,
|
41 |
+
batch_per_epoch,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def get_interpolate(name: str) -> F.InterpolationMode:
|
46 |
+
mapping = {
|
47 |
+
"nearest": F.InterpolationMode.NEAREST,
|
48 |
+
"bilinear": F.InterpolationMode.BILINEAR,
|
49 |
+
"bicubic": F.InterpolationMode.BICUBIC,
|
50 |
+
"box": F.InterpolationMode.BOX,
|
51 |
+
"hamming": F.InterpolationMode.HAMMING,
|
52 |
+
"lanczos": F.InterpolationMode.LANCZOS,
|
53 |
+
}
|
54 |
+
if name in mapping:
|
55 |
+
return mapping[name]
|
56 |
+
elif name == "random":
|
57 |
+
return torch_random_choices(
|
58 |
+
[
|
59 |
+
F.InterpolationMode.NEAREST,
|
60 |
+
F.InterpolationMode.BILINEAR,
|
61 |
+
F.InterpolationMode.BICUBIC,
|
62 |
+
F.InterpolationMode.BOX,
|
63 |
+
F.InterpolationMode.HAMMING,
|
64 |
+
F.InterpolationMode.LANCZOS,
|
65 |
+
],
|
66 |
+
)
|
67 |
+
else:
|
68 |
+
raise NotImplementedError
|
69 |
+
|
70 |
+
|
71 |
+
class MyRandomResizedCrop(transforms.RandomResizedCrop):
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
scale=(0.08, 1.0),
|
75 |
+
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
76 |
+
interpolation: str = "random",
|
77 |
+
):
|
78 |
+
super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
|
79 |
+
self.interpolation = interpolation
|
80 |
+
|
81 |
+
def forward(self, img: torch.Tensor) -> torch.Tensor:
|
82 |
+
i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
|
83 |
+
target_size = RRSController.ACTIVE_SIZE
|
84 |
+
return F.resized_crop(
|
85 |
+
img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
|
86 |
+
)
|
87 |
+
|
88 |
+
def __repr__(self) -> str:
|
89 |
+
format_string = self.__class__.__name__
|
90 |
+
format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
|
91 |
+
format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
|
92 |
+
format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
|
93 |
+
format_string += f"\tinterpolation={self.interpolation})"
|
94 |
+
return format_string
|
yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
4 |
+
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
from copy import deepcopy
|
8 |
+
|
9 |
+
import torch.backends.cudnn
|
10 |
+
import torch.distributed
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
from efficientvit.apps.data_provider import DataProvider
|
14 |
+
from efficientvit.apps.trainer.run_config import RunConfig
|
15 |
+
from efficientvit.apps.utils import (dist_init, dump_config,
|
16 |
+
get_dist_local_rank, get_dist_rank,
|
17 |
+
get_dist_size, init_modules, is_master,
|
18 |
+
load_config, partial_update_config,
|
19 |
+
zero_last_gamma)
|
20 |
+
from efficientvit.models.utils import (build_kwargs_from_config,
|
21 |
+
load_state_dict_from_file)
|
22 |
+
|
23 |
+
__all__ = [
|
24 |
+
"save_exp_config",
|
25 |
+
"setup_dist_env",
|
26 |
+
"setup_seed",
|
27 |
+
"setup_exp_config",
|
28 |
+
"setup_data_provider",
|
29 |
+
"setup_run_config",
|
30 |
+
"init_model",
|
31 |
+
]
|
32 |
+
|
33 |
+
|
34 |
+
def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
|
35 |
+
if not is_master():
|
36 |
+
return
|
37 |
+
dump_config(exp_config, os.path.join(path, name))
|
38 |
+
|
39 |
+
|
40 |
+
def setup_dist_env(gpu: str or None = None) -> None:
|
41 |
+
if gpu is not None:
|
42 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
|
43 |
+
if not torch.distributed.is_initialized():
|
44 |
+
dist_init()
|
45 |
+
torch.backends.cudnn.benchmark = True
|
46 |
+
torch.cuda.set_device(get_dist_local_rank())
|
47 |
+
|
48 |
+
|
49 |
+
def setup_seed(manual_seed: int, resume: bool) -> None:
|
50 |
+
if resume:
|
51 |
+
manual_seed = int(time.time())
|
52 |
+
manual_seed = get_dist_rank() + manual_seed
|
53 |
+
torch.manual_seed(manual_seed)
|
54 |
+
torch.cuda.manual_seed_all(manual_seed)
|
55 |
+
|
56 |
+
|
57 |
+
def setup_exp_config(
|
58 |
+
config_path: str, recursive=True, opt_args: dict or None = None
|
59 |
+
) -> dict:
|
60 |
+
# load config
|
61 |
+
if not os.path.isfile(config_path):
|
62 |
+
raise ValueError(config_path)
|
63 |
+
|
64 |
+
fpaths = [config_path]
|
65 |
+
if recursive:
|
66 |
+
extension = os.path.splitext(config_path)[1]
|
67 |
+
while os.path.dirname(config_path) != config_path:
|
68 |
+
config_path = os.path.dirname(config_path)
|
69 |
+
fpath = os.path.join(config_path, "default" + extension)
|
70 |
+
if os.path.isfile(fpath):
|
71 |
+
fpaths.append(fpath)
|
72 |
+
fpaths = fpaths[::-1]
|
73 |
+
|
74 |
+
default_config = load_config(fpaths[0])
|
75 |
+
exp_config = deepcopy(default_config)
|
76 |
+
for fpath in fpaths[1:]:
|
77 |
+
partial_update_config(exp_config, load_config(fpath))
|
78 |
+
# update config via args
|
79 |
+
if opt_args is not None:
|
80 |
+
partial_update_config(exp_config, opt_args)
|
81 |
+
|
82 |
+
return exp_config
|
83 |
+
|
84 |
+
|
85 |
+
def setup_data_provider(
|
86 |
+
exp_config: dict,
|
87 |
+
data_provider_classes: list[type[DataProvider]],
|
88 |
+
is_distributed: bool = True,
|
89 |
+
) -> DataProvider:
|
90 |
+
dp_config = exp_config["data_provider"]
|
91 |
+
dp_config["num_replicas"] = get_dist_size() if is_distributed else None
|
92 |
+
dp_config["rank"] = get_dist_rank() if is_distributed else None
|
93 |
+
dp_config["test_batch_size"] = (
|
94 |
+
dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
|
95 |
+
)
|
96 |
+
dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
|
97 |
+
"base_batch_size"
|
98 |
+
]
|
99 |
+
|
100 |
+
data_provider_lookup = {
|
101 |
+
provider.name: provider for provider in data_provider_classes
|
102 |
+
}
|
103 |
+
data_provider_class = data_provider_lookup[dp_config["dataset"]]
|
104 |
+
|
105 |
+
data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
|
106 |
+
data_provider = data_provider_class(**data_provider_kwargs)
|
107 |
+
return data_provider
|
108 |
+
|
109 |
+
|
110 |
+
def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
|
111 |
+
exp_config["run_config"]["init_lr"] = (
|
112 |
+
exp_config["run_config"]["base_lr"] * get_dist_size()
|
113 |
+
)
|
114 |
+
|
115 |
+
run_config = run_config_cls(**exp_config["run_config"])
|
116 |
+
|
117 |
+
return run_config
|
118 |
+
|
119 |
+
|
120 |
+
def init_model(
|
121 |
+
network: nn.Module,
|
122 |
+
init_from: str or None = None,
|
123 |
+
backbone_init_from: str or None = None,
|
124 |
+
rand_init="trunc_normal",
|
125 |
+
last_gamma=None,
|
126 |
+
) -> None:
|
127 |
+
# initialization
|
128 |
+
init_modules(network, init_type=rand_init)
|
129 |
+
# zero gamma of last bn in each block
|
130 |
+
if last_gamma is not None:
|
131 |
+
zero_last_gamma(network, last_gamma)
|
132 |
+
|
133 |
+
# load weight
|
134 |
+
if init_from is not None and os.path.isfile(init_from):
|
135 |
+
network.load_state_dict(load_state_dict_from_file(init_from))
|
136 |
+
print(f"Loaded init from {init_from}")
|
137 |
+
elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
|
138 |
+
network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
|
139 |
+
print(f"Loaded backbone init from {backbone_init_from}")
|
140 |
+
else:
|
141 |
+
print(f"Random init ({rand_init}) with last gamma {last_gamma}")
|
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
4 |
+
|
5 |
+
from .base import *
|
6 |
+
from .run_config import *
|
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (264 Bytes). View file
|
|
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc
ADDED
Binary file (16.8 kB). View file
|
|
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc
ADDED
Binary file (7.04 kB). View file
|
|
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from efficientvit.apps.data_provider import DataProvider, parse_image_size
|
11 |
+
from efficientvit.apps.trainer.run_config import RunConfig
|
12 |
+
from efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
|
13 |
+
is_master)
|
14 |
+
from efficientvit.models.nn.norm import reset_bn
|
15 |
+
from efficientvit.models.utils import is_parallel, load_state_dict_from_file
|
16 |
+
|
17 |
+
__all__ = ["Trainer"]
|
18 |
+
|
19 |
+
|
20 |
+
class Trainer:
|
21 |
+
def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
|
22 |
+
self.path = os.path.realpath(os.path.expanduser(path))
|
23 |
+
self.model = model.cuda()
|
24 |
+
self.data_provider = data_provider
|
25 |
+
|
26 |
+
self.ema = None
|
27 |
+
|
28 |
+
self.checkpoint_path = os.path.join(self.path, "checkpoint")
|
29 |
+
self.logs_path = os.path.join(self.path, "logs")
|
30 |
+
for path in [self.path, self.checkpoint_path, self.logs_path]:
|
31 |
+
os.makedirs(path, exist_ok=True)
|
32 |
+
|
33 |
+
self.best_val = 0.0
|
34 |
+
self.start_epoch = 0
|
35 |
+
|
36 |
+
@property
|
37 |
+
def network(self) -> nn.Module:
|
38 |
+
return self.model.module if is_parallel(self.model) else self.model
|
39 |
+
|
40 |
+
@property
|
41 |
+
def eval_network(self) -> nn.Module:
|
42 |
+
if self.ema is None:
|
43 |
+
model = self.model
|
44 |
+
else:
|
45 |
+
model = self.ema.shadows
|
46 |
+
model = model.module if is_parallel(model) else model
|
47 |
+
return model
|
48 |
+
|
49 |
+
def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
|
50 |
+
if is_master():
|
51 |
+
fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
|
52 |
+
fout.write(log_str + "\n")
|
53 |
+
fout.flush()
|
54 |
+
fout.close()
|
55 |
+
if print_log:
|
56 |
+
print(log_str)
|
57 |
+
|
58 |
+
def save_model(
|
59 |
+
self,
|
60 |
+
checkpoint=None,
|
61 |
+
only_state_dict=True,
|
62 |
+
epoch=0,
|
63 |
+
model_name=None,
|
64 |
+
) -> None:
|
65 |
+
if is_master():
|
66 |
+
if checkpoint is None:
|
67 |
+
if only_state_dict:
|
68 |
+
checkpoint = {"state_dict": self.network.state_dict()}
|
69 |
+
else:
|
70 |
+
checkpoint = {
|
71 |
+
"state_dict": self.network.state_dict(),
|
72 |
+
"epoch": epoch,
|
73 |
+
"best_val": self.best_val,
|
74 |
+
"optimizer": self.optimizer.state_dict(),
|
75 |
+
"lr_scheduler": self.lr_scheduler.state_dict(),
|
76 |
+
"ema": self.ema.state_dict() if self.ema is not None else None,
|
77 |
+
"scaler": self.scaler.state_dict() if self.fp16 else None,
|
78 |
+
}
|
79 |
+
|
80 |
+
model_name = model_name or "checkpoint.pt"
|
81 |
+
|
82 |
+
latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
|
83 |
+
model_path = os.path.join(self.checkpoint_path, model_name)
|
84 |
+
with open(latest_fname, "w") as _fout:
|
85 |
+
_fout.write(model_path + "\n")
|
86 |
+
torch.save(checkpoint, model_path)
|
87 |
+
|
88 |
+
def load_model(self, model_fname=None) -> None:
|
89 |
+
latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
|
90 |
+
if model_fname is None and os.path.exists(latest_fname):
|
91 |
+
with open(latest_fname, "r") as fin:
|
92 |
+
model_fname = fin.readline()
|
93 |
+
if len(model_fname) > 0 and model_fname[-1] == "\n":
|
94 |
+
model_fname = model_fname[:-1]
|
95 |
+
try:
|
96 |
+
if model_fname is None:
|
97 |
+
model_fname = f"{self.checkpoint_path}/checkpoint.pt"
|
98 |
+
elif not os.path.exists(model_fname):
|
99 |
+
model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
|
100 |
+
if not os.path.exists(model_fname):
|
101 |
+
model_fname = f"{self.checkpoint_path}/checkpoint.pt"
|
102 |
+
print(f"=> loading checkpoint {model_fname}")
|
103 |
+
checkpoint = load_state_dict_from_file(model_fname, False)
|
104 |
+
except Exception:
|
105 |
+
self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
|
106 |
+
return
|
107 |
+
|
108 |
+
# load checkpoint
|
109 |
+
self.network.load_state_dict(checkpoint["state_dict"], strict=False)
|
110 |
+
log = []
|
111 |
+
if "epoch" in checkpoint:
|
112 |
+
self.start_epoch = checkpoint["epoch"] + 1
|
113 |
+
self.run_config.update_global_step(self.start_epoch)
|
114 |
+
log.append(f"epoch={self.start_epoch - 1}")
|
115 |
+
if "best_val" in checkpoint:
|
116 |
+
self.best_val = checkpoint["best_val"]
|
117 |
+
log.append(f"best_val={self.best_val:.2f}")
|
118 |
+
if "optimizer" in checkpoint:
|
119 |
+
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
120 |
+
log.append("optimizer")
|
121 |
+
if "lr_scheduler" in checkpoint:
|
122 |
+
self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
|
123 |
+
log.append("lr_scheduler")
|
124 |
+
if "ema" in checkpoint and self.ema is not None:
|
125 |
+
self.ema.load_state_dict(checkpoint["ema"])
|
126 |
+
log.append("ema")
|
127 |
+
if "scaler" in checkpoint and self.fp16:
|
128 |
+
self.scaler.load_state_dict(checkpoint["scaler"])
|
129 |
+
log.append("scaler")
|
130 |
+
self.write_log("Loaded: " + ", ".join(log))
|
131 |
+
|
132 |
+
""" validate """
|
133 |
+
|
134 |
+
def reset_bn(
|
135 |
+
self,
|
136 |
+
network: nn.Module or None = None,
|
137 |
+
subset_size: int = 16000,
|
138 |
+
subset_batch_size: int = 100,
|
139 |
+
data_loader=None,
|
140 |
+
progress_bar=False,
|
141 |
+
) -> None:
|
142 |
+
network = network or self.network
|
143 |
+
if data_loader is None:
|
144 |
+
data_loader = []
|
145 |
+
for data in self.data_provider.build_sub_train_loader(
|
146 |
+
subset_size, subset_batch_size
|
147 |
+
):
|
148 |
+
if isinstance(data, list):
|
149 |
+
data_loader.append(data[0])
|
150 |
+
elif isinstance(data, dict):
|
151 |
+
data_loader.append(data["data"])
|
152 |
+
elif isinstance(data, torch.Tensor):
|
153 |
+
data_loader.append(data)
|
154 |
+
else:
|
155 |
+
raise NotImplementedError
|
156 |
+
|
157 |
+
network.eval()
|
158 |
+
reset_bn(
|
159 |
+
network,
|
160 |
+
data_loader,
|
161 |
+
sync=True,
|
162 |
+
progress_bar=progress_bar,
|
163 |
+
)
|
164 |
+
|
165 |
+
    def _validate(self, model, data_loader, epoch) -> dict[str, any]:
        raise NotImplementedError

    def validate(
        self, model=None, data_loader=None, is_test=True, epoch=0
    ) -> dict[str, any]:
        model = model or self.eval_network
        if data_loader is None:
            if is_test:
                data_loader = self.data_provider.test
            else:
                data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch)

    def multires_validate(
        self,
        model=None,
        data_loader=None,
        is_test=True,
        epoch=0,
        eval_image_size=None,
    ) -> dict[str, dict[str, any]]:
        eval_image_size = eval_image_size or self.run_config.eval_image_size
        eval_image_size = eval_image_size or self.data_provider.image_size
        model = model or self.eval_network

        if not isinstance(eval_image_size, list):
            eval_image_size = [eval_image_size]

        output_dict = {}
        for r in eval_image_size:
            self.data_provider.assign_active_image_size(parse_image_size(r))
            if self.run_config.reset_bn:
                self.reset_bn(
                    network=model,
                    subset_size=self.run_config.reset_bn_size,
                    subset_batch_size=self.run_config.reset_bn_batch_size,
                    progress_bar=True,
                )
            output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
        return output_dict

    """ training """

    def prep_for_training(
        self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
    ) -> None:
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            static_graph=True,
        )

        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        if ema_decay is not None:
            self.ema = EMA(self.network, ema_decay)

        # fp16
        self.fp16 = fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

    def sync_model(self):
        print("Sync model")
        self.save_model(model_name="sync.pt")
        dist_barrier()
        checkpoint = torch.load(
            os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
        )
        dist_barrier()
        if is_master():
            os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
        dist_barrier()

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])

    def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
        for key in feed_dict:
            if isinstance(feed_dict[key], torch.Tensor):
                feed_dict[key] = feed_dict[key].cuda()
        return feed_dict

    def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
        raise NotImplementedError

    def after_step(self) -> None:
        self.scaler.unscale_(self.optimizer)
        # gradient clip
        if self.run_config.grad_clip is not None:
            torch.nn.utils.clip_grad_value_(
                self.model.parameters(), self.run_config.grad_clip
            )
        # update
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()
        self.run_config.step()
        # update ema
        if self.ema is not None:
            self.ema.step(self.network, self.run_config.global_step)

    def _train_one_epoch(self, epoch: int) -> dict[str, any]:
        raise NotImplementedError

    def train_one_epoch(self, epoch: int) -> dict[str, any]:
        self.model.train()

        self.data_provider.set_epoch(epoch)

        train_info_dict = self._train_one_epoch(epoch)

        return train_info_dict

    def train(self) -> None:
        raise NotImplementedError
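For orientation, a minimal sketch of how a concrete trainer might wire these hooks together. The MyTrainer class, the feed_dict keys ("data" and "label"), and the cross-entropy loss are hypothetical illustrations, and the base-class name Trainer is assumed from the earlier part of this file; only the before_step/run_step/after_step protocol comes from the code above.

# Hypothetical subclass showing the intended hook order; not part of this commit.
class MyTrainer(Trainer):
    def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
        with torch.autocast("cuda", enabled=self.fp16):
            output = self.model(feed_dict["data"])
            loss = torch.nn.functional.cross_entropy(output, feed_dict["label"])
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()  # scaled backward pass for fp16
        return {"loss": loss}

    def _train_one_epoch(self, epoch: int) -> dict[str, any]:
        step_info = {}
        for feed_dict in self.data_provider.train:
            feed_dict = self.before_step(feed_dict)  # move tensors to the GPU
            step_info = self.run_step(feed_dict)     # forward + backward
            self.after_step()                        # unscale, clip, optimizer/lr/EMA updates
        return {"train_loss": step_info["loss"].item()}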
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py
ADDED
@@ -0,0 +1,121 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import json

import numpy as np
import torch.nn as nn

from efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer

__all__ = ["Scheduler", "RunConfig"]


class Scheduler:
    PROGRESS = 0


class RunConfig:
    n_epochs: int
    init_lr: float
    warmup_epochs: int
    warmup_lr: float
    lr_schedule_name: str
    lr_schedule_param: dict
    optimizer_name: str
    optimizer_params: dict
    weight_decay: float
    no_wd_keys: list
    grad_clip: float  # allow none to turn off grad clipping
    reset_bn: bool
    reset_bn_size: int
    reset_bn_batch_size: int
    eval_image_size: list  # allow none to use image_size in data_provider

    @property
    def none_allowed(self):
        return ["grad_clip", "eval_image_size"]

    def __init__(self, **kwargs):  # arguments must be passed as kwargs
        for k, val in kwargs.items():
            setattr(self, k, val)

        # check that all relevant configs are there
        annotations = {}
        for clas in type(self).mro():
            if hasattr(clas, "__annotations__"):
                annotations.update(clas.__annotations__)
        for k, k_type in annotations.items():
            assert hasattr(
                self, k
            ), f"Key {k} with type {k_type} required for initialization."
            attr = getattr(self, k)
            if k in self.none_allowed:
                k_type = (k_type, type(None))
            assert isinstance(
                attr, k_type
            ), f"Key {k} must be type {k_type}, provided={attr}."

        self.global_step = 0
        self.batch_per_epoch = 1

    def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
        r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
        param_dict = {}
        for name, param in network.named_parameters():
            if param.requires_grad:
                opt_config = [self.weight_decay, self.init_lr]
                if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
                    if np.any([key in name for key in self.no_wd_keys]):
                        opt_config[0] = 0
                opt_key = json.dumps(opt_config)
                param_dict[opt_key] = param_dict.get(opt_key, []) + [param]

        net_params = []
        for opt_key, param_list in param_dict.items():
            wd, lr = json.loads(opt_key)
            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

        optimizer = build_optimizer(
            net_params, self.optimizer_name, self.optimizer_params, self.init_lr
        )
        # build lr scheduler
        if self.lr_schedule_name == "cosine":
            decay_steps = []
            for epoch in self.lr_schedule_param.get("step", []):
                decay_steps.append(epoch * self.batch_per_epoch)
            decay_steps.append(self.n_epochs * self.batch_per_epoch)
            decay_steps.sort()
            lr_scheduler = CosineLRwithWarmup(
                optimizer,
                self.warmup_epochs * self.batch_per_epoch,
                self.warmup_lr,
                decay_steps,
            )
        else:
            raise NotImplementedError
        return optimizer, lr_scheduler

    def update_global_step(self, epoch, batch_id=0) -> None:
        self.global_step = epoch * self.batch_per_epoch + batch_id
        Scheduler.PROGRESS = self.progress

    @property
    def progress(self) -> float:
        warmup_steps = self.warmup_epochs * self.batch_per_epoch
        steps = max(0, self.global_step - warmup_steps)
        return steps / (self.n_epochs * self.batch_per_epoch)

    def step(self) -> None:
        self.global_step += 1
        Scheduler.PROGRESS = self.progress

    def get_remaining_epoch(self, epoch, post=True) -> int:
        return self.n_epochs + self.warmup_epochs - epoch - int(post)

    def epoch_format(self, epoch: int) -> str:
        epoch_format = f"%.{len(str(self.n_epochs))}d"
        epoch_format = f"[{epoch_format}/{epoch_format}]"
        epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
        return epoch_format
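A minimal construction sketch for RunConfig with illustrative values only; the optimizer_name and optimizer_params entries are assumptions about what build_optimizer in efficientvit.apps.utils accepts, and model stands for any nn.Module.

# Illustrative only: every annotated field must be passed as a keyword argument;
# fields listed in none_allowed ("grad_clip", "eval_image_size") may be None.
run_config = RunConfig(
    n_epochs=100,
    init_lr=1e-3,
    warmup_epochs=5,
    warmup_lr=1e-5,
    lr_schedule_name="cosine",
    lr_schedule_param={},
    optimizer_name="adamw",                # assumed to be understood by build_optimizer
    optimizer_params={"betas": (0.9, 0.999)},
    weight_decay=0.05,
    no_wd_keys=["bias", "norm"],           # parameters matching these get weight_decay=0
    grad_clip=None,                        # None disables gradient clipping
    reset_bn=False,
    reset_bn_size=1600,
    reset_bn_batch_size=64,
    eval_image_size=None,                  # None falls back to the data provider's size
)
run_config.batch_per_epoch = 1000          # must be set before building the optimizer
optimizer, lr_scheduler = run_config.build_optimizer(model)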
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py
ADDED
@@ -0,0 +1,12 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .dist import *
from .ema import *
from .export import *
from .init import *
from .lr import *
from .metric import *
from .misc import *
from .opt import *
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (412 Bytes).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc
ADDED
Binary file (3.92 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc
ADDED
Binary file (3.54 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc
ADDED
Binary file (2.57 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc
ADDED
Binary file (4.3 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc
ADDED
Binary file (3.02 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc
ADDED
Binary file (2.67 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc
ADDED
Binary file (5.06 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc
ADDED
Binary file (1.32 kB).
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py
ADDED
@@ -0,0 +1,73 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os

import torch
import torch.distributed

from efficientvit.models.utils.list import list_mean, list_sum

__all__ = [
    "dist_init",
    "get_dist_rank",
    "get_dist_size",
    "is_master",
    "dist_barrier",
    "get_dist_local_rank",
    "sync_tensor",
]


def dist_init() -> None:
    try:
        torch.distributed.init_process_group(backend="nccl")
        assert torch.distributed.is_initialized()
    except Exception:
        # use torchpack
        from torchpack import distributed as dist

        dist.init()
        os.environ["RANK"] = f"{dist.rank()}"
        os.environ["WORLD_SIZE"] = f"{dist.size()}"
        os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"


def get_dist_rank() -> int:
    return int(os.environ["RANK"])


def get_dist_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def is_master() -> bool:
    return get_dist_rank() == 0


def dist_barrier() -> None:
    torch.distributed.barrier()


def get_dist_local_rank() -> int:
    return int(os.environ["LOCAL_RANK"])


def sync_tensor(
    tensor: torch.Tensor or float, reduce="mean"
) -> torch.Tensor or list[torch.Tensor]:
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.Tensor(1).fill_(tensor).cuda()
    tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
    torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
    if reduce == "mean":
        return list_mean(tensor_list)
    elif reduce == "sum":
        return list_sum(tensor_list)
    elif reduce == "cat":
        return torch.cat(tensor_list, dim=0)
    elif reduce == "root":
        return tensor_list[0]
    else:
        return tensor_list
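A minimal sketch of how these helpers compose, assuming the process was launched by torchrun (or torchpack) so that RANK, WORLD_SIZE, and LOCAL_RANK end up in the environment; the loss value is a placeholder.

# Assumes a distributed launch; each rank runs this same code.
dist_init()
torch.cuda.set_device(get_dist_local_rank())

local_loss = 0.37                                   # per-rank scalar (placeholder)
mean_loss = sync_tensor(local_loss, reduce="mean")  # averaged over all ranks
if is_master():
    print(f"mean loss across {get_dist_size()} ranks: {mean_loss.item():.4f}")
dist_barrier()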
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py
ADDED
@@ -0,0 +1,50 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy
import math

import torch
import torch.nn as nn

from efficientvit.models.utils import is_parallel

__all__ = ["EMA"]


def update_ema(
    ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
) -> None:
    for k, v in ema.state_dict().items():
        if v.dtype.is_floating_point:
            v -= (1.0 - decay) * (v - new_state_dict[k].detach())


class EMA:
    def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
        self.shadows = copy.deepcopy(
            model.module if is_parallel(model) else model
        ).eval()
        self.decay = decay
        self.warmup_steps = warmup_steps

        for p in self.shadows.parameters():
            p.requires_grad = False

    def step(self, model: nn.Module, global_step: int) -> None:
        with torch.no_grad():
            msd = (model.module if is_parallel(model) else model).state_dict()
            update_ema(
                self.shadows,
                msd,
                self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
            )

    def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
        return {self.decay: self.shadows.state_dict()}

    def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
        for decay in state_dict:
            if decay == self.decay:
                self.shadows.load_state_dict(state_dict[decay])
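A minimal sketch of the EMA lifecycle, assuming model and train_loader are placeholders and the decay value is illustrative; the shadow copy is what gets evaluated and checkpointed, and the effective decay ramps up over warmup_steps.

# Illustrative lifecycle; model and train_loader stand in for real objects.
ema = EMA(model, decay=0.9998)

for global_step, feed_dict in enumerate(train_loader, start=1):
    ...  # forward, backward, optimizer step
    ema.step(model, global_step)        # blend current weights into the shadow model

eval_model = ema.shadows                # evaluate with the smoothed weights
checkpoint = {"ema": ema.state_dict()}  # keyed by decay, matching load_state_dict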
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py
ADDED
@@ -0,0 +1,47 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import io
import os

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify as simplify_func

__all__ = ["export_onnx"]


def export_onnx(
    model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
) -> None:
    """Export a model to a platform-specific onnx format.

    Args:
        model: a torch.nn.Module object.
        export_path: export location.
        sample_inputs: Any.
        simplify: a flag to turn on onnx-simplifier
        opset: int
    """
    model.eval()

    buffer = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
        buffer.seek(0, 0)
        if simplify:
            onnx_model = onnx.load_model(buffer)
            onnx_model, success = simplify_func(onnx_model)
            assert success
            new_buffer = io.BytesIO()
            onnx.save(onnx_model, new_buffer)
            buffer = new_buffer
            buffer.seek(0, 0)

    if buffer.getbuffer().nbytes > 0:
        save_dir = os.path.dirname(export_path)
        os.makedirs(save_dir, exist_ok=True)
        with open(export_path, "wb") as f:
            f.write(buffer.read())
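A minimal export sketch; the MyModel class and the 1x3x512x512 dummy input are assumptions, not something this commit prescribes.

# Illustrative call; the real input shape depends on the network being exported.
model = MyModel().eval()                   # hypothetical nn.Module
dummy_input = torch.randn(1, 3, 512, 512)  # assumed input shape
export_onnx(model, "exports/model.onnx", dummy_input, simplify=True, opset=11)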
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py
ADDED
@@ -0,0 +1,68 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["init_modules", "zero_last_gamma"]


def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
    _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}

    if isinstance(model, list):
        for sub_module in model:
            init_modules(sub_module, init_type)
    else:
        init_params = init_type.split("@")
        init_params = float(init_params[1]) if len(init_params) > 1 else None

        if init_type.startswith("trunc_normal"):
            init_func = lambda param: nn.init.trunc_normal_(
                param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
            )
        else:
            raise NotImplementedError

        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                init_func(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                init_func(m.weight)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            else:
                weight = getattr(m, "weight", None)
                bias = getattr(m, "bias", None)
                if isinstance(weight, torch.nn.Parameter):
                    init_func(weight)
                if isinstance(bias, torch.nn.Parameter):
                    bias.data.zero_()


def zero_last_gamma(model: nn.Module, init_val=0) -> None:
    import efficientvit.models.nn.ops as ops

    for m in model.modules():
        if isinstance(m, ops.ResidualBlock) and isinstance(
            m.shortcut, ops.IdentityLayer
        ):
            if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
                parent_module = m.main.point_conv
            elif isinstance(m.main, ops.ResBlock):
                parent_module = m.main.conv2
            elif isinstance(m.main, ops.ConvLayer):
                parent_module = m.main
            elif isinstance(m.main, (ops.LiteMLA)):
                parent_module = m.main.proj
            else:
                parent_module = None
            if parent_module is not None:
                norm = getattr(parent_module, "norm", None)
                if norm is not None:
                    nn.init.constant_(norm.weight, init_val)
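A minimal sketch of the two entry points above, with model standing in for any nn.Module; the "@" suffix on init_type overrides the default truncated-normal std of 0.02.

# Illustrative: "trunc_normal@0.01" parses the value after "@" as the std.
init_modules(model, init_type="trunc_normal@0.01")
zero_last_gamma(model)  # zero the final norm gamma inside identity-shortcut residual blocks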
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py
ADDED
@@ -0,0 +1,48 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import math

import torch

from efficientvit.models.utils.list import val2list

__all__ = ["CosineLRwithWarmup"]


class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        warmup_lr: float,
        decay_steps: int or list[int],
        last_epoch: int = -1,
    ) -> None:
        self.warmup_steps = warmup_steps
        self.warmup_lr = warmup_lr
        self.decay_steps = val2list(decay_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list[float]:
        if self.last_epoch < self.warmup_steps:
            return [
                (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
                + self.warmup_lr
                for base_lr in self.base_lrs
            ]
        else:
            current_steps = self.last_epoch - self.warmup_steps
            decay_steps = [0] + self.decay_steps
            idx = len(decay_steps) - 2
            for i, decay_step in enumerate(decay_steps[:-1]):
                if decay_step <= current_steps < decay_steps[i + 1]:
                    idx = i
                    break
            current_steps -= decay_steps[idx]
            decay_step = decay_steps[idx + 1] - decay_steps[idx]
            return [
                0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
                for base_lr in self.base_lrs
            ]
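A minimal sketch of driving CosineLRwithWarmup, assuming a hypothetical model and illustrative step counts; note that the scheduler is stepped once per iteration, not once per epoch.

# Illustrative values; a single int decay_steps is promoted to a list via val2list.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # model is a placeholder
steps_per_epoch, n_epochs, warmup_epochs = 1000, 90, 5
scheduler = CosineLRwithWarmup(
    optimizer,
    warmup_steps=warmup_epochs * steps_per_epoch,
    warmup_lr=1e-4,
    decay_steps=n_epochs * steps_per_epoch,
)
for _ in range(n_epochs * steps_per_epoch):
    optimizer.step()
    scheduler.step()  # per-iteration stepping drives the warmup and cosine phases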
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py
ADDED
@@ -0,0 +1,37 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch

from efficientvit.apps.utils.dist import sync_tensor

__all__ = ["AverageMeter"]


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self, is_distributed=True):
        self.is_distributed = is_distributed
        self.sum = 0
        self.count = 0

    def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
        return sync_tensor(val, reduce="sum") if self.is_distributed else val

    def update(self, val: torch.Tensor or int or float, delta_n=1):
        self.count += self._sync(delta_n)
        self.sum += self._sync(val * delta_n)

    def get_count(self) -> torch.Tensor or int or float:
        return (
            self.count.item()
            if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
            else self.count
        )

    @property
    def avg(self):
        avg = -1 if self.count == 0 else self.sum / self.count
        return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
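A minimal sketch of AverageMeter on a single process; the per-batch losses are made up. With is_distributed=True the counts and sums would instead be summed across ranks through sync_tensor.

# Illustrative single-process usage (is_distributed=False avoids any dist setup).
loss_meter = AverageMeter(is_distributed=False)
for batch_loss, batch_size in [(0.52, 32), (0.47, 32), (0.40, 16)]:
    loss_meter.update(batch_loss, delta_n=batch_size)
print(f"running average loss: {loss_meter.avg:.4f}")  # weighted by batch size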