luxmorocco committed on
Commit
4efbc62
1 Parent(s): 09723cd

Upload 86 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete set.
Files changed (50)
  1. yolo-world-with-efficientvit-sam/.DS_Store +0 -0
  2. yolo-world-with-efficientvit-sam/.gitignore +171 -0
  3. yolo-world-with-efficientvit-sam/LICENSE +201 -0
  4. yolo-world-with-efficientvit-sam/Makefile +18 -0
  5. yolo-world-with-efficientvit-sam/README.md +68 -0
  6. yolo-world-with-efficientvit-sam/app.py +132 -0
  7. yolo-world-with-efficientvit-sam/efficientvit/__init__.py +0 -0
  8. yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc +0 -0
  9. yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc +0 -0
  10. yolo-world-with-efficientvit-sam/efficientvit/apps/__init__.py +0 -0
  11. yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc +0 -0
  12. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py +7 -0
  13. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc +0 -0
  14. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc +0 -0
  15. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py +6 -0
  16. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc +0 -0
  17. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc +0 -0
  18. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc +0 -0
  19. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py +30 -0
  20. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py +84 -0
  21. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py +223 -0
  22. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py +7 -0
  23. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc +0 -0
  24. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc +0 -0
  25. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py +1603 -0
  26. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py +378 -0
  27. yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py +94 -0
  28. yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py +141 -0
  29. yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py +6 -0
  30. yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc +0 -0
  31. yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc +0 -0
  32. yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc +0 -0
  33. yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py +297 -0
  34. yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py +121 -0
  35. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py +12 -0
  36. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  37. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc +0 -0
  38. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc +0 -0
  39. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc +0 -0
  40. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc +0 -0
  41. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc +0 -0
  42. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc +0 -0
  43. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc +0 -0
  44. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc +0 -0
  45. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py +73 -0
  46. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py +50 -0
  47. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py +47 -0
  48. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py +68 -0
  49. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py +48 -0
  50. yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py +37 -0
yolo-world-with-efficientvit-sam/.DS_Store ADDED
Binary file (6.15 kB).
 
yolo-world-with-efficientvit-sam/.gitignore ADDED
@@ -0,0 +1,171 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Model Weights
+ *.pth
+ *.pt
+
+ # Yolo-World
+ work_dirs
+ src
+
+ # Etc
+ .DS_Store
yolo-world-with-efficientvit-sam/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
yolo-world-with-efficientvit-sam/Makefile ADDED
@@ -0,0 +1,18 @@
+ EFFICIENTVIT_SAM_URL := "https://huggingface.co/han-cai/efficientvit-sam/resolve/main"
+ EFFICIENTVIT_SAM_MODEL := "xl1.pt"
+
+
+ define download
+ @if [ ! -f $(2) ]; then \
+ echo "Download $(2)..."; \
+ wget "$(1)/$(2)"; \
+ fi
+ endef
+
+
+ setup:
+ pip install -r requirements.txt
+
+
+ model:
+ $(call download,$(EFFICIENTVIT_SAM_URL),$(EFFICIENTVIT_SAM_MODEL))
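The `download` macro above fetches the EfficientViT-SAM checkpoint only when it is not already on disk. For readers who prefer to script that step without `make`, here is a minimal Python sketch of the same check-then-download logic; the URL and filename mirror the Makefile variables, while the helper itself (`download_if_missing`, `urllib`) is purely illustrative:

```python
# Minimal sketch of the Makefile's "download only if missing" step.
# The URL and filename come from the Makefile above; everything else is illustrative.
import os
import urllib.request

EFFICIENTVIT_SAM_URL = "https://huggingface.co/han-cai/efficientvit-sam/resolve/main"
EFFICIENTVIT_SAM_MODEL = "xl1.pt"


def download_if_missing(base_url: str, filename: str) -> None:
    """Fetch `filename` from `base_url` unless it already exists locally."""
    if os.path.isfile(filename):
        return
    print(f"Download {filename}...")
    urllib.request.urlretrieve(f"{base_url}/{filename}", filename)


if __name__ == "__main__":
    download_if_missing(EFFICIENTVIT_SAM_URL, EFFICIENTVIT_SAM_MODEL)
```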
yolo-world-with-efficientvit-sam/README.md ADDED
@@ -0,0 +1,68 @@
+ # YOLO-World + EfficientViT SAM
+
+ 🤗 [HuggingFace Space](https://huggingface.co/spaces/curt-park/yolo-world-with-efficientvit-sam)
+
+ ![example_0](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/326bde19-d535-4be5-829e-782fce0c1d00)
+
+ ## Prerequisites
+ This project is developed and tested on Python 3.10.
+
+ ```bash
+ # Create and activate a Python 3.10 environment.
+ conda create -n yolo-world-with-efficientvit-sam python=3.10 -y
+ conda activate yolo-world-with-efficientvit-sam
+ # Set up packages.
+ make setup
+ ```
+
+ ## How to Run
+ ```bash
+ python app.py
+ ```
+
+ Open http://127.0.0.1:7860/ in your web browser.
+
+ ![example_1](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/9388e4ee-6f71-4428-b17c-d218fd059949)
+
+ ## Core Components
+
+ ### YOLO-World
+ [YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
+ On the challenging LVIS dataset, YOLO-World achieves 35.4 AP at 52.0 FPS on a V100,
+ outperforming many state-of-the-art methods in both accuracy and speed.
+ ![image](https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/8a4a17bd-918d-478a-8451-f58e4a2dce79)
+ <img width="1024" src="https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/fce57405-e18d-45f3-bea8-fc3971faf975">
+
+ ### EfficientViT SAM
+ [EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a family of accelerated segment-anything models.
+ Thanks to its lightweight and hardware-efficient core building block,
+ it delivers a 48.9× measured TensorRT speedup over SAM-ViT-H on an A100 GPU without sacrificing performance.
+
+ <img width="1024" src="https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/9eec003f-47c9-43a5-86b0-82d6689e1bf9">
+ <img width="1024" src="https://github.com/Curt-Park/yolo-world-with-efficientvit-sam/assets/14961526/d79973bb-0d80-4b64-a175-252de56d0d09">
+
+ ## Powered By
+ ```
+ @misc{zhang2024efficientvitsam,
+   title={EfficientViT-SAM: Accelerated Segment Anything Model Without Performance Loss},
+   author={Zhuoyang Zhang and Han Cai and Song Han},
+   year={2024},
+   eprint={2402.05008},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV}
+ }
+
+ @article{cheng2024yolow,
+   title={YOLO-World: Real-Time Open-Vocabulary Object Detection},
+   author={Cheng, Tianheng and Song, Lin and Ge, Yixiao and Liu, Wenyu and Wang, Xinggang and Shan, Ying},
+   journal={arXiv preprint arXiv:2401.17270},
+   year={2024}
+ }
+
+ @article{cai2022efficientvit,
+   title={Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition},
+   author={Cai, Han and Gan, Chuang and Han, Song},
+   journal={arXiv preprint arXiv:2205.14756},
+   year={2022}
+ }
+ ```
yolo-world-with-efficientvit-sam/app.py ADDED
@@ -0,0 +1,132 @@
+ """Fast text to segmentation with yolo-world and efficient-vit sam."""
+ import os
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import supervision as sv
+ import torch
+ from inference.models import YOLOWorld
+
+ from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
+ from efficientvit.sam_model_zoo import create_sam_model
+
+
+ # Download model weights.
+ os.system("make model")
+
+ # Load models.
+ yolo_world = YOLOWorld(model_id="yolo_world/l")
+ #yolo_world = YOLOWorld("/Users/tounsi/Desktop/DOCTORIA/Doctoria\ Full\ Software/Doctoria\ CXR/Doctoria\ CXR\ Thoracic\ Abnormalities/YOLOv8/CXR\ YOLOv8l.pt")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ sam = EfficientViTSamPredictor(
+     create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
+ )
+
+
+ # Load annotators.
+ BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
+ MASK_ANNOTATOR = sv.MaskAnnotator()
+ LABEL_ANNOTATOR = sv.LabelAnnotator()
+
+
+ def detect(
+     image: np.ndarray,
+     query: str,
+     confidence_threshold: float,
+     nms_threshold: float,
+ ) -> np.ndarray:
+     # Preparation.
+     categories = [category.strip() for category in query.split(",")]
+     yolo_world.set_classes(categories)
+     print("categories:", categories)
+
+     # Object detection.
+     results = yolo_world.infer(image, confidence=confidence_threshold)
+     detections = sv.Detections.from_inference(results).with_nms(
+         class_agnostic=True, threshold=nms_threshold
+     )
+     print("detected:", detections)
+
+     # Segmentation.
+     sam.set_image(image, image_format="RGB")
+     masks = []
+     for xyxy in detections.xyxy:
+         mask, _, _ = sam.predict(box=xyxy, multimask_output=False)
+         masks.append(mask.squeeze())
+     detections.mask = np.array(masks)
+     print("masks shaped as", detections.mask.shape)
+
+     # Annotation.
+     output_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+     labels = [
+         f"{categories[class_id]}: {confidence:.2f}"
+         for class_id, confidence in zip(detections.class_id, detections.confidence)
+     ]
+     output_image = MASK_ANNOTATOR.annotate(output_image, detections)
+     output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
+     output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
+     return cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
+
+
+ app = gr.Interface(
+     fn=detect,
+     inputs=[
+         gr.Image(type="numpy", label="input image"),
+         gr.Text(info="separate multiple class names with a comma (,)"),
+         gr.Slider(
+             minimum=0,
+             maximum=1,
+             value=0.3,
+             step=0.01,
+             interactive=True,
+             label="Confidence Threshold",
+         ),
+         gr.Slider(
+             minimum=0,
+             maximum=1,
+             value=0.5,
+             step=0.01,
+             interactive=True,
+             label="NMS Threshold",
+         ),
+     ],
+     outputs=gr.Image(type="numpy", label="output image"),
+     allow_flagging="never",
+     title="Fast Text to Segmentation with YOLO-World + EfficientViT SAM",
+     description="""
+     ## Core components
+     ### YOLO-World
+     [YOLO-World](https://github.com/AILab-CVC/YOLO-World) is an open-vocabulary object detection model with high efficiency.
+     On the challenging LVIS dataset, YOLO-World achieves 35.4 AP with 52.0 FPS on V100,
+     which outperforms many state-of-the-art methods in terms of both accuracy and speed.
+
+     ### EfficientViT SAM
+     [EfficientViT SAM](https://github.com/mit-han-lab/efficientvit) is a new family of accelerated segment anything models.
+     Thanks to the lightweight and hardware-efficient core building block,
+     it delivers 48.9× measured TensorRT speedup on A100 GPU over SAM-ViT-H without sacrificing performance.
+
+     ## Demo especially powered by
+     Roboflow's [inference](https://github.com/roboflow/inference) and [supervision](https://github.com/roboflow/supervision).
+
+     ## Example images came from
+     [Segment Anything Demo](https://segment-anything.com/demo) and [Unsplash](https://unsplash.com/).
+     """,
+     examples=[
+         [
+             os.path.join(os.path.dirname(__file__), "examples/livingroom.jpg"),
+             "table, lamp, dog, sofa, plant, clock, carpet, frame on the wall",
+             0.05,
+             0.5
+         ],
+         [
+             os.path.join(os.path.dirname(__file__), "examples/cat_and_dogs.jpg"),
+             "cat, dog",
+             0.2,
+             0.5
+         ],
+     ],
+ )
+
+
+ app.launch(server_name="0.0.0.0")
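Stripped of the Gradio wiring, `detect` is a two-stage pipeline: YOLO-World turns the comma-separated text query into boxes, and each box is handed to EfficientViT-SAM as a prompt. The condensed sketch below restates that flow for headless use; the model IDs and the `xl1.pt` checkpoint follow app.py (so `make model` must already have been run), while the image path and query are placeholders:

```python
# Condensed, headless restatement of the detection -> box-prompted segmentation flow in app.py.
# "examples/cat_and_dogs.jpg" and the query are placeholders for illustration only.
import cv2
import numpy as np
import supervision as sv
import torch
from inference.models import YOLOWorld

from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
from efficientvit.sam_model_zoo import create_sam_model

device = "cuda" if torch.cuda.is_available() else "cpu"
yolo_world = YOLOWorld(model_id="yolo_world/l")
sam = EfficientViTSamPredictor(
    create_sam_model(name="xl1", weight_url="xl1.pt").to(device).eval()
)

image = cv2.cvtColor(cv2.imread("examples/cat_and_dogs.jpg"), cv2.COLOR_BGR2RGB)

# Stage 1: open-vocabulary detection from the text query.
yolo_world.set_classes(["cat", "dog"])
detections = sv.Detections.from_inference(
    yolo_world.infer(image, confidence=0.2)
).with_nms(class_agnostic=True, threshold=0.5)

# Stage 2: one SAM forward pass per detected box, using the box as the prompt.
sam.set_image(image, image_format="RGB")
detections.mask = np.array(
    [sam.predict(box=xyxy, multimask_output=False)[0].squeeze() for xyxy in detections.xyxy]
)
print("masks:", detections.mask.shape)
```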
yolo-world-with-efficientvit-sam/efficientvit/__init__.py ADDED
File without changes
yolo-world-with-efficientvit-sam/efficientvit/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (188 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/__pycache__/sam_model_zoo.cpython-311.pyc ADDED
Binary file (2.28 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/__init__.py ADDED
File without changes
yolo-world-with-efficientvit-sam/efficientvit/apps/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (193 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from .augment import *
+ from .base import *
+ from .random_resolution import *
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (306 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/__pycache__/base.cpython-311.pyc ADDED
Binary file (11.4 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ from .bbox import *
+ from .color_aug import *
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (277 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/bbox.cpython-311.pyc ADDED
Binary file (1.44 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/__pycache__/color_aug.cpython-311.pyc ADDED
Binary file (5.17 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/bbox.py ADDED
@@ -0,0 +1,30 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import numpy as np
+
+ __all__ = ["rand_bbox"]
+
+
+ def rand_bbox(
+     h: int,
+     w: int,
+     lam: float,
+     rand_func: callable = np.random.uniform,
+ ) -> tuple[int, int, int, int]:
+     """randomly sample bbox, used in cutmix"""
+     cut_rat = np.sqrt(1.0 - lam)
+     cut_w = w * cut_rat
+     cut_h = h * cut_rat
+
+     # uniform
+     cx = rand_func(0, w)
+     cy = rand_func(0, h)
+
+     bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
+     bby1 = int(np.clip(cy - cut_h / 2, 0, h))
+     bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
+     bby2 = int(np.clip(cy + cut_h / 2, 0, h))
+
+     return bbx1, bby1, bbx2, bby2
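`rand_bbox` only samples the box; in CutMix the box is then used to paste a patch from a second image and to correct the mixing coefficient by the area that was actually replaced. A hedged usage sketch with placeholder images (the Beta-sampled `lam` and the label-mixing correction are standard CutMix practice, not code from this repository):

```python
# Hedged sketch of how rand_bbox() is typically consumed in CutMix.
import numpy as np

from efficientvit.apps.data_provider.augment import rand_bbox

h, w = 224, 224
image_a = np.zeros((h, w, 3), dtype=np.uint8)    # placeholder images
image_b = np.full((h, w, 3), 255, dtype=np.uint8)

lam = np.random.beta(1.0, 1.0)                   # CutMix mixing coefficient
x1, y1, x2, y2 = rand_bbox(h, w, lam)

mixed = image_a.copy()
mixed[y1:y2, x1:x2] = image_b[y1:y2, x1:x2]      # paste the sampled patch

# Correct lam to the fraction of the image that still comes from image_a,
# which is what the mixed label should be weighted by.
lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / float(h * w)
```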
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/augment/color_aug.py ADDED
@@ -0,0 +1,84 @@
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+ # International Conference on Computer Vision (ICCV), 2023
+
+ import numpy as np
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from timm.data.auto_augment import rand_augment_transform
+
+ __all__ = ["ColorAug", "RandAug"]
+
+
+ class ImageAug:
+     def aug_image(self, image: Image.Image) -> Image.Image:
+         raise NotImplementedError
+
+     def __call__(
+         self, feed_dict: dict or np.ndarray or Image.Image
+     ) -> dict or np.ndarray or Image.Image:
+         if isinstance(feed_dict, dict):
+             output_dict = feed_dict
+             image = feed_dict[self.key]
+         else:
+             output_dict = None
+             image = feed_dict
+         is_ndarray = isinstance(image, np.ndarray)
+         if is_ndarray:
+             image = Image.fromarray(image)
+
+         image = self.aug_image(image)
+
+         if is_ndarray:
+             image = np.array(image)
+
+         if output_dict is None:
+             return image
+         else:
+             output_dict[self.key] = image
+             return output_dict
+
+
+ class ColorAug(transforms.ColorJitter, ImageAug):
+     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
+         super().__init__(
+             brightness=brightness,
+             contrast=contrast,
+             saturation=saturation,
+             hue=hue,
+         )
+         self.key = key
+
+     def aug_image(self, image: Image.Image) -> Image.Image:
+         return transforms.ColorJitter.forward(self, image)
+
+     def forward(
+         self, feed_dict: dict or np.ndarray or Image.Image
+     ) -> dict or np.ndarray or Image.Image:
+         return ImageAug.__call__(self, feed_dict)
+
+
+ class RandAug(ImageAug):
+     def __init__(
+         self, config: dict[str, any], mean: tuple[float, float, float], key="data"
+     ):
+         n = config.get("n", 2)
+         m = config.get("m", 9)
+         mstd = config.get("mstd", 1.0)
+         inc = config.get("inc", 1)
+         tpct = config.get("tpct", 0.45)
+         config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
+
+         aa_params = dict(
+             translate_pct=tpct,
+             img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+             interpolation=Image.BICUBIC,
+         )
+         self.aug_op = rand_augment_transform(config_str, aa_params)
+         self.key = key
+
+     def aug_image(self, image: Image.Image) -> Image.Image:
+         return self.aug_op(image)
+
+     def __repr__(self):
+         return self.aug_op.__repr__()
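Both augmenters accept either a bare image (PIL or ndarray) or a feed dict and write the result back under `key`. A hedged sketch composing them with torchvision's `Compose` on a dict-style sample; the jitter strengths and RandAug settings are illustrative, and the mean reuses the ImageNet statistics defined by `DataProvider.mean_std`:

```python
# Hedged sketch: composing ColorAug and RandAug on a dict-style sample.
import numpy as np
import torchvision.transforms as transforms

from efficientvit.apps.data_provider.augment import ColorAug, RandAug

augment = transforms.Compose(
    [
        ColorAug(brightness=0.4, contrast=0.4, saturation=0.4, key="data"),
        RandAug(config={"n": 2, "m": 9}, mean=(0.485, 0.456, 0.406), key="data"),
    ]
)

sample = {"data": np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8), "label": 0}
sample = augment(sample)  # the image is augmented under "data"; other keys pass through untouched
print(sample["data"].shape, sample["label"])
```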
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/base.py ADDED
@@ -0,0 +1,223 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+ import warnings
7
+
8
+ import torch.utils.data
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from efficientvit.apps.data_provider.random_resolution import RRSController
12
+ from efficientvit.models.utils import val2tuple
13
+
14
+ __all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
15
+
16
+
17
+ def parse_image_size(size: int or str) -> tuple[int, int]:
18
+ if isinstance(size, str):
19
+ size = [int(val) for val in size.split("-")]
20
+ return size[0], size[1]
21
+ else:
22
+ return val2tuple(size, 2)
23
+
24
+
25
+ def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
26
+ g = torch.Generator()
27
+ g.manual_seed(seed) # set random seed before sampling validation set
28
+ rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
29
+
30
+ dropped_indexes = rand_indexes[:drop_size]
31
+ remaining_indexes = rand_indexes[drop_size:]
32
+
33
+ dropped_dataset = copy.deepcopy(dataset)
34
+ for key in keys:
35
+ setattr(
36
+ dropped_dataset,
37
+ key,
38
+ [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
39
+ )
40
+ setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
41
+ return dataset, dropped_dataset
42
+
43
+
44
+ class DataProvider:
45
+ data_keys = ("samples",)
46
+ mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
47
+ SUB_SEED = 937162211 # random seed for sampling subset
48
+ VALID_SEED = 2147483647 # random seed for the validation set
49
+
50
+ name: str
51
+
52
+ def __init__(
53
+ self,
54
+ train_batch_size: int,
55
+ test_batch_size: int or None,
56
+ valid_size: int or float or None,
57
+ n_worker: int,
58
+ image_size: int or list[int] or str or list[str],
59
+ num_replicas: int or None = None,
60
+ rank: int or None = None,
61
+ train_ratio: float or None = None,
62
+ drop_last: bool = False,
63
+ ):
64
+ warnings.filterwarnings("ignore")
65
+ super().__init__()
66
+
67
+ # batch_size & valid_size
68
+ self.train_batch_size = train_batch_size
69
+ self.test_batch_size = test_batch_size or self.train_batch_size
70
+ self.valid_size = valid_size
71
+
72
+ # image size
73
+ if isinstance(image_size, list):
74
+ self.image_size = [parse_image_size(size) for size in image_size]
75
+ self.image_size.sort() # e.g., 160 -> 224
76
+ RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
77
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
78
+ else:
79
+ self.image_size = parse_image_size(image_size)
80
+ RRSController.IMAGE_SIZE_LIST = [self.image_size]
81
+ self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
82
+
83
+ # distributed configs
84
+ self.num_replicas = num_replicas
85
+ self.rank = rank
86
+
87
+ # build datasets
88
+ train_dataset, val_dataset, test_dataset = self.build_datasets()
89
+
90
+ if train_ratio is not None and train_ratio < 1.0:
91
+ assert 0 < train_ratio < 1
92
+ _, train_dataset = random_drop_data(
93
+ train_dataset,
94
+ int(train_ratio * len(train_dataset)),
95
+ self.SUB_SEED,
96
+ self.data_keys,
97
+ )
98
+
99
+ # build data loader
100
+ self.train = self.build_dataloader(
101
+ train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
102
+ )
103
+ self.valid = self.build_dataloader(
104
+ val_dataset, test_batch_size, n_worker, drop_last=False, train=False
105
+ )
106
+ self.test = self.build_dataloader(
107
+ test_dataset, test_batch_size, n_worker, drop_last=False, train=False
108
+ )
109
+ if self.valid is None:
110
+ self.valid = self.test
111
+ self.sub_train = None
112
+
113
+ @property
114
+ def data_shape(self) -> tuple[int, ...]:
115
+ return 3, self.active_image_size[0], self.active_image_size[1]
116
+
117
+ def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
118
+ raise NotImplementedError
119
+
120
+ def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
121
+ raise NotImplementedError
122
+
123
+ def build_datasets(self) -> tuple[any, any, any]:
124
+ raise NotImplementedError
125
+
126
+ def build_dataloader(
127
+ self,
128
+ dataset: any or None,
129
+ batch_size: int,
130
+ n_worker: int,
131
+ drop_last: bool,
132
+ train: bool,
133
+ ):
134
+ if dataset is None:
135
+ return None
136
+ if isinstance(self.image_size, list) and train:
137
+ from efficientvit.apps.data_provider.random_resolution._data_loader import \
138
+ RRSDataLoader
139
+
140
+ dataloader_class = RRSDataLoader
141
+ else:
142
+ dataloader_class = torch.utils.data.DataLoader
143
+ if self.num_replicas is None:
144
+ return dataloader_class(
145
+ dataset=dataset,
146
+ batch_size=batch_size,
147
+ shuffle=True,
148
+ num_workers=n_worker,
149
+ pin_memory=True,
150
+ drop_last=drop_last,
151
+ )
152
+ else:
153
+ sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
154
+ return dataloader_class(
155
+ dataset=dataset,
156
+ batch_size=batch_size,
157
+ sampler=sampler,
158
+ num_workers=n_worker,
159
+ pin_memory=True,
160
+ drop_last=drop_last,
161
+ )
162
+
163
+ def set_epoch(self, epoch: int) -> None:
164
+ RRSController.set_epoch(epoch, len(self.train))
165
+ if isinstance(self.train.sampler, DistributedSampler):
166
+ self.train.sampler.set_epoch(epoch)
167
+
168
+ def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
169
+ self.active_image_size = val2tuple(new_size, 2)
170
+ new_transform = self.build_valid_transform(self.active_image_size)
171
+ # change the transform of the valid and test set
172
+ self.valid.dataset.transform = self.test.dataset.transform = new_transform
173
+
174
+ def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
175
+ if self.valid_size is not None:
176
+ if 0 < self.valid_size < 1:
177
+ valid_size = int(self.valid_size * len(train_dataset))
178
+ else:
179
+ assert self.valid_size >= 1
180
+ valid_size = int(self.valid_size)
181
+ train_dataset, val_dataset = random_drop_data(
182
+ train_dataset,
183
+ valid_size,
184
+ self.VALID_SEED,
185
+ self.data_keys,
186
+ )
187
+ val_dataset.transform = valid_transform
188
+ else:
189
+ val_dataset = None
190
+ return train_dataset, val_dataset
191
+
192
+ def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
193
+ # used for resetting BN running statistics
194
+ if self.sub_train is None:
195
+ self.sub_train = {}
196
+ if self.active_image_size in self.sub_train:
197
+ return self.sub_train[self.active_image_size]
198
+
199
+ # construct dataset and dataloader
200
+ train_dataset = copy.deepcopy(self.train.dataset)
201
+ if n_samples < len(train_dataset):
202
+ _, train_dataset = random_drop_data(
203
+ train_dataset,
204
+ n_samples,
205
+ self.SUB_SEED,
206
+ self.data_keys,
207
+ )
208
+ RRSController.ACTIVE_SIZE = self.active_image_size
209
+ train_dataset.transform = self.build_train_transform(
210
+ image_size=self.active_image_size
211
+ )
212
+ data_loader = self.build_dataloader(
213
+ train_dataset, batch_size, self.train.num_workers, True, False
214
+ )
215
+
216
+ # pre-fetch data
217
+ self.sub_train[self.active_image_size] = [
218
+ data
219
+ for data in data_loader
220
+ for _ in range(max(1, n_samples // len(train_dataset)))
221
+ ]
222
+
223
+ return self.sub_train[self.active_image_size]
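Two details of base.py are easy to miss: `parse_image_size` accepts either an int or an `"H-W"` string (the form a multi-resolution `image_size` list uses), and `random_drop_data` deterministically carves a validation subset out of the training set with `VALID_SEED`. A hedged illustration of the size parsing; the int case assumes `val2tuple` simply broadcasts the value to a pair:

```python
# Hedged illustration of parse_image_size() from base.py.
from efficientvit.apps.data_provider.base import parse_image_size

print(parse_image_size("160-320"))  # -> (160, 320): "H-W" strings are split on "-"
print(parse_image_size(224))        # -> (224, 224), assuming val2tuple broadcasts the int
```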
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Random resolution data loader compatible with multi-processing and distributed training.
+
+ Replace PyTorch's DataLoader with RRSDataLoader to support random resolution
+ at training time; resolution sampling is controlled by RRSController.
+ """
+
+ from .controller import *
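The pattern the docstring describes is the one `DataProvider` already follows: build the training loader with `RRSDataLoader`, then call `RRSController.set_epoch` once per epoch so the per-batch resolution schedule is resampled. A hedged, minimal sketch with a placeholder dataset (in the real provider the training transform consults `RRSController.ACTIVE_SIZE` when resizing):

```python
# Hedged sketch of the random-resolution training pattern described above.
import torch
from torch.utils.data import TensorDataset

from efficientvit.apps.data_provider.random_resolution import RRSController
from efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader

# Placeholder dataset; in practice this is the provider's train set whose transform
# resizes each sample to RRSController.ACTIVE_SIZE.
train_dataset = TensorDataset(torch.randn(256, 3, 224, 224), torch.zeros(256, dtype=torch.long))

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]  # candidate resolutions
RRSController.ACTIVE_SIZE = (224, 224)

train_loader = RRSDataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)

for epoch in range(2):
    RRSController.set_epoch(epoch, len(train_loader))  # resample the per-batch resolution schedule
    for images, labels in train_loader:
        pass  # forward/backward would go here
```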
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (527 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/__pycache__/controller.cpython-311.pyc ADDED
Binary file (5.7 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_loader.py ADDED
@@ -0,0 +1,1603 @@
1
+ r"""This file is based on torch/utils/data/data_loader.py
2
+
3
+ Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
4
+
5
+ To support these two classes, in `./_utils` we define many utility methods and
6
+ functions to be run in multiprocessing. E.g., the data loading worker loop is
7
+ in `./_utils/worker.py`.
8
+ """
9
+
10
+ import functools
11
+ import itertools
12
+ import logging
13
+ import multiprocessing as python_multiprocessing
14
+ import os
15
+ import queue
16
+ import threading
17
+ import warnings
18
+ from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
19
+ TypeVar, Union)
20
+
21
+ import torch
22
+ import torch.distributed as dist
23
+ import torch.multiprocessing as multiprocessing
24
+ import torch.utils.data.graph_settings
25
+ from torch._utils import ExceptionWrapper
26
+ from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
27
+ IterDataPipe, MapDataPipe, RandomSampler,
28
+ Sampler, SequentialSampler, _utils)
29
+ from torch.utils.data.datapipes.datapipe import (
30
+ _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)
31
+
32
+ from ._data_worker import _worker_loop
33
+
34
+ __all__ = ["RRSDataLoader"]
35
+
36
+ T_co = TypeVar("T_co", covariant=True)
37
+ T = TypeVar("T")
38
+ _worker_init_fn_t = Callable[[int], None]
39
+
40
+ # Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
41
+ # type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
42
+ # See https://github.com/python/mypy/issues/3737.
43
+ _collate_fn_t = Callable[[List[T]], Any]
44
+
45
+
46
+ # These functions used to be defined in this file. However, it was moved to
47
+ # _utils/collate.py. Although it is rather hard to access this from user land
48
+ # (one has to explicitly directly `import torch.utils.data.dataloader`), there
49
+ # probably is user code out there using it. This aliasing maintains BC in this
50
+ # aspect.
51
+ default_collate: _collate_fn_t = _utils.collate.default_collate
52
+ default_convert = _utils.collate.default_convert
53
+
54
+ get_worker_info = _utils.worker.get_worker_info
55
+
56
+ logger = logging.getLogger(__name__)
57
+
58
+
59
+ class _DatasetKind:
60
+ Map = 0
61
+ Iterable = 1
62
+
63
+ @staticmethod
64
+ def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
65
+ if kind == _DatasetKind.Map:
66
+ return _utils.fetch._MapDatasetFetcher(
67
+ dataset, auto_collation, collate_fn, drop_last
68
+ )
69
+ else:
70
+ return _utils.fetch._IterableDatasetFetcher(
71
+ dataset, auto_collation, collate_fn, drop_last
72
+ )
73
+
74
+
75
+ class _InfiniteConstantSampler(Sampler):
76
+ r"""Analogous to ``itertools.repeat(None, None)``.
77
+ Used as sampler for :class:`~torch.utils.data.IterableDataset`.
78
+
79
+ Args:
80
+ data_source (Dataset): dataset to sample from
81
+ """
82
+
83
+ def __init__(self):
84
+ super().__init__(None)
85
+
86
+ def __iter__(self):
87
+ while True:
88
+ yield None
89
+
90
+
91
+ def _get_distributed_settings():
92
+ if dist.is_available() and dist.is_initialized():
93
+ return dist.get_world_size(), dist.get_rank()
94
+ else:
95
+ return 1, 0
96
+
97
+
98
+ def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
99
+ global_worker_id = worker_id
100
+ info = torch.utils.data.get_worker_info()
101
+ assert info is not None
102
+ total_workers = info.num_workers
103
+ datapipe = info.dataset
104
+ assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
105
+ # To distribute elements across distributed process evenly, we should shard data on distributed
106
+ # processes first then shard on worker processes
107
+ total_workers *= world_size
108
+ global_worker_id = global_worker_id * world_size + rank_id
109
+ # For BC, use default SHARDING_PRIORITIES
110
+ torch.utils.data.graph_settings.apply_sharding(
111
+ datapipe, total_workers, global_worker_id
112
+ )
113
+ if worker_init_fn is not None:
114
+ worker_init_fn(worker_id)
115
+
116
+
117
+ def _share_dist_seed(generator, pg):
118
+ _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
119
+ if isinstance(pg, dist.ProcessGroup):
120
+ dist.broadcast(_shared_seed, src=0, group=pg)
121
+ return _shared_seed.item()
122
+
123
+
124
+ class RRSDataLoader(Generic[T_co]):
125
+ r"""
126
+ Data loader. Combines a dataset and a sampler, and provides an iterable over
127
+ the given dataset.
128
+
129
+ The :class:`~torch.utils.data.DataLoader` supports both map-style and
130
+ iterable-style datasets with single- or multi-process loading, customizing
131
+ loading order and optional automatic batching (collation) and memory pinning.
132
+
133
+ See :py:mod:`torch.utils.data` documentation page for more details.
134
+
135
+ Args:
136
+ dataset (Dataset): dataset from which to load the data.
137
+ batch_size (int, optional): how many samples per batch to load
138
+ (default: ``1``).
139
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
140
+ at every epoch (default: ``False``).
141
+ sampler (Sampler or Iterable, optional): defines the strategy to draw
142
+ samples from the dataset. Can be any ``Iterable`` with ``__len__``
143
+ implemented. If specified, :attr:`shuffle` must not be specified.
144
+ batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
145
+ returns a batch of indices at a time. Mutually exclusive with
146
+ :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
147
+ and :attr:`drop_last`.
148
+ num_workers (int, optional): how many subprocesses to use for data
149
+ loading. ``0`` means that the data will be loaded in the main process.
150
+ (default: ``0``)
151
+ collate_fn (Callable, optional): merges a list of samples to form a
152
+ mini-batch of Tensor(s). Used when using batched loading from a
153
+ map-style dataset.
154
+ pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
155
+ into device/CUDA pinned memory before returning them. If your data elements
156
+ are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
157
+ see the example below.
158
+ drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
159
+ if the dataset size is not divisible by the batch size. If ``False`` and
160
+ the size of dataset is not divisible by the batch size, then the last batch
161
+ will be smaller. (default: ``False``)
162
+ timeout (numeric, optional): if positive, the timeout value for collecting a batch
163
+ from workers. Should always be non-negative. (default: ``0``)
164
+ worker_init_fn (Callable, optional): If not ``None``, this will be called on each
165
+ worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
166
+ input, after seeding and before data loading. (default: ``None``)
167
+ generator (torch.Generator, optional): If not ``None``, this RNG will be used
168
+ by RandomSampler to generate random indexes and multiprocessing to generate
169
+ `base_seed` for workers. (default: ``None``)
170
+ prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
171
+ in advance by each worker. ``2`` means there will be a total of
172
+ 2 * num_workers batches prefetched across all workers. (default value depends
173
+ on the set value for num_workers. If value of num_workers=0 default is ``None``.
174
+ Otherwise if value of num_workers>0 default is ``2``).
175
+ persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
176
+ the worker processes after a dataset has been consumed once. This allows to
177
+ maintain the workers `Dataset` instances alive. (default: ``False``)
178
+ pin_memory_device (str, optional): the data loader will copy Tensors
179
+ into device pinned memory before returning them if pin_memory is set to true.
180
+
181
+
182
+ .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
183
+ cannot be an unpicklable object, e.g., a lambda function. See
184
+ :ref:`multiprocessing-best-practices` on more details related
185
+ to multiprocessing in PyTorch.
186
+
187
+ .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
188
+ When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
189
+ it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
190
+ rounding depending on :attr:`drop_last`, regardless of multi-process loading
191
+ configurations. This represents the best guess PyTorch can make because PyTorch
192
+ trusts user :attr:`dataset` code in correctly handling multi-process
193
+ loading to avoid duplicate data.
194
+
195
+ However, if sharding results in multiple workers having incomplete last batches,
196
+ this estimate can still be inaccurate, because (1) an otherwise complete batch can
197
+ be broken into multiple ones and (2) more than one batch worth of samples can be
198
+ dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
199
+ cases in general.
200
+
201
+ See `Dataset Types`_ for more details on these two types of datasets and how
202
+ :class:`~torch.utils.data.IterableDataset` interacts with
203
+ `Multi-process data loading`_.
204
+
205
+ .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
206
+ :ref:`data-loading-randomness` notes for random seed related questions.
207
+ """
208
+
209
+ dataset: Dataset[T_co]
210
+ batch_size: Optional[int]
211
+ num_workers: int
212
+ pin_memory: bool
213
+ drop_last: bool
214
+ timeout: float
215
+ sampler: Union[Sampler, Iterable]
216
+ pin_memory_device: str
217
+ prefetch_factor: Optional[int]
218
+ _iterator: Optional["_BaseDataLoaderIter"]
219
+ __initialized = False
220
+
221
+ def __init__(
222
+ self,
223
+ dataset: Dataset[T_co],
224
+ batch_size: Optional[int] = 1,
225
+ shuffle: Optional[bool] = None,
226
+ sampler: Union[Sampler, Iterable, None] = None,
227
+ batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
228
+ num_workers: int = 0,
229
+ collate_fn: Optional[_collate_fn_t] = None,
230
+ pin_memory: bool = False,
231
+ drop_last: bool = False,
232
+ timeout: float = 0,
233
+ worker_init_fn: Optional[_worker_init_fn_t] = None,
234
+ multiprocessing_context=None,
235
+ generator=None,
236
+ *,
237
+ prefetch_factor: Optional[int] = None,
238
+ persistent_workers: bool = False,
239
+ pin_memory_device: str = ""
240
+ ):
241
+ torch._C._log_api_usage_once("python.data_loader")
242
+
243
+ if num_workers < 0:
244
+ raise ValueError(
245
+ "num_workers option should be non-negative; "
246
+ "use num_workers=0 to disable multiprocessing."
247
+ )
248
+
249
+ if timeout < 0:
250
+ raise ValueError("timeout option should be non-negative")
251
+
252
+ if num_workers == 0 and prefetch_factor is not None:
253
+ raise ValueError(
254
+ "prefetch_factor option could only be specified in multiprocessing."
255
+ "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
256
+ )
257
+ elif num_workers > 0 and prefetch_factor is None:
258
+ prefetch_factor = 2
259
+ elif prefetch_factor is not None and prefetch_factor < 0:
260
+ raise ValueError("prefetch_factor option should be non-negative")
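+ # (illustrative note, not original code: with these rules, num_workers=0 keeps
+ # prefetch_factor as None, while e.g. num_workers=4 with the default
+ # prefetch_factor=2 keeps up to 2 * 4 = 8 batches in flight across workers.)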
261
+
262
+ if persistent_workers and num_workers == 0:
263
+ raise ValueError("persistent_workers option needs num_workers > 0")
264
+
265
+ self.dataset = dataset
266
+ self.num_workers = num_workers
267
+ self.prefetch_factor = prefetch_factor
268
+ self.pin_memory = pin_memory
269
+ self.pin_memory_device = pin_memory_device
270
+ self.timeout = timeout
271
+ self.worker_init_fn = worker_init_fn
272
+ self.multiprocessing_context = multiprocessing_context
273
+
274
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
275
+ # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
276
+ if isinstance(self.dataset, IterDataPipe):
277
+ self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
278
+ elif isinstance(self.dataset, MapDataPipe):
279
+ self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
280
+
281
+ # Arg-check dataset related before checking samplers because we want to
282
+ # tell users that iterable-style datasets are incompatible with custom
283
+ # samplers first, so that they don't learn that this combo doesn't work
284
+ # after spending time fixing the custom sampler errors.
285
+ if isinstance(dataset, IterableDataset):
286
+ self._dataset_kind = _DatasetKind.Iterable
287
+ # NOTE [ Custom Samplers and IterableDataset ]
288
+ #
289
+ # `IterableDataset` does not support custom `batch_sampler` or
290
+ # `sampler` since the key is irrelevant (unless we support
291
+ # generator-style dataset one day...).
292
+ #
293
+ # For `sampler`, we always create a dummy sampler. This is an
294
+ # infinite sampler even when the dataset may have an implemented
295
+ # finite `__len__` because in multi-process data loading, naive
296
+ # settings will return duplicated data (which may be desired), and
297
+ # thus using a sampler with length matching that of dataset will
298
+ # cause data to be lost (you may have duplicates of the first couple
299
+ # batches, but never see anything afterwards). Therefore,
300
+ # `IterableDataset` always uses an infinite sampler, an instance of
301
+ # `_InfiniteConstantSampler` defined above.
302
+ #
303
+ # A custom `batch_sampler` essentially only controls the batch size.
304
+ # However, it is unclear how useful it would be since an iterable-style
305
+ # dataset can handle that within itself. Moreover, it is pointless
306
+ # in multi-process data loading as the assignment order of batches
307
+ # to workers is an implementation detail so users can not control
308
+ # how to batchify each worker's iterable. Thus, we disable this
309
+ # option. If this turns out to be useful in future, we can re-enable
310
+ # this, and support custom samplers that specify the assignments to
311
+ # specific workers.
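+ # (illustrative example, not from the original file: with num_workers=2 and an
+ # IterableDataset whose __iter__ naively yields range(4), each worker replays
+ # the full iterator and the loader yields every sample twice; sharding inside
+ # __iter__ via torch.utils.data.get_worker_info() is the usual fix, which is
+ # why no finite sampler is attached here.)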
312
+ if isinstance(dataset, IterDataPipe):
313
+ if shuffle is not None:
314
+ dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
315
+ dataset, shuffle=shuffle
316
+ )
317
+ # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
318
+ elif shuffle not in {False, None}:
319
+ raise ValueError(
320
+ "DataLoader with IterableDataset: expected unspecified "
321
+ "shuffle option, but got shuffle={}".format(shuffle)
322
+ )
323
+
324
+ if sampler is not None:
325
+ # See NOTE [ Custom Samplers and IterableDataset ]
326
+ raise ValueError(
327
+ "DataLoader with IterableDataset: expected unspecified "
328
+ "sampler option, but got sampler={}".format(sampler)
329
+ )
330
+ elif batch_sampler is not None:
331
+ # See NOTE [ Custom Samplers and IterableDataset ]
332
+ raise ValueError(
333
+ "DataLoader with IterableDataset: expected unspecified "
334
+ "batch_sampler option, but got batch_sampler={}".format(
335
+ batch_sampler
336
+ )
337
+ )
338
+ else:
339
+ shuffle = bool(shuffle)
340
+ self._dataset_kind = _DatasetKind.Map
341
+
342
+ if sampler is not None and shuffle:
343
+ raise ValueError("sampler option is mutually exclusive with " "shuffle")
344
+
345
+ if batch_sampler is not None:
346
+ # auto_collation with custom batch_sampler
347
+ if batch_size != 1 or shuffle or sampler is not None or drop_last:
348
+ raise ValueError(
349
+ "batch_sampler option is mutually exclusive "
350
+ "with batch_size, shuffle, sampler, and "
351
+ "drop_last"
352
+ )
353
+ batch_size = None
354
+ drop_last = False
355
+ elif batch_size is None:
356
+ # no auto_collation
357
+ if drop_last:
358
+ raise ValueError(
359
+ "batch_size=None option disables auto-batching "
360
+ "and is mutually exclusive with drop_last"
361
+ )
362
+
363
+ if sampler is None: # give default samplers
364
+ if self._dataset_kind == _DatasetKind.Iterable:
365
+ # See NOTE [ Custom Samplers and IterableDataset ]
366
+ sampler = _InfiniteConstantSampler()
367
+ else: # map-style
368
+ if shuffle:
369
+ sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
370
+ else:
371
+ sampler = SequentialSampler(dataset) # type: ignore[arg-type]
372
+
373
+ if batch_size is not None and batch_sampler is None:
374
+ # auto_collation without custom batch_sampler
375
+ batch_sampler = BatchSampler(sampler, batch_size, drop_last)
376
+
377
+ self.batch_size = batch_size
378
+ self.drop_last = drop_last
379
+ self.sampler = sampler
380
+ self.batch_sampler = batch_sampler
381
+ self.generator = generator
382
+
383
+ if collate_fn is None:
384
+ if self._auto_collation:
385
+ collate_fn = _utils.collate.default_collate
386
+ else:
387
+ collate_fn = _utils.collate.default_convert
388
+
389
+ self.collate_fn = collate_fn
390
+ self.persistent_workers = persistent_workers
391
+
392
+ self.__initialized = True
393
+ self._IterableDataset_len_called = (
394
+ None # See NOTE [ IterableDataset and __len__ ]
395
+ )
396
+
397
+ self._iterator = None
398
+
399
+ self.check_worker_number_rationality()
400
+
401
+ torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
402
+
403
+ def _get_iterator(self) -> "_BaseDataLoaderIter":
404
+ if self.num_workers == 0:
405
+ return _SingleProcessDataLoaderIter(self)
406
+ else:
407
+ self.check_worker_number_rationality()
408
+ return _MultiProcessingDataLoaderIter(self)
409
+
410
+ @property
411
+ def multiprocessing_context(self):
412
+ return self.__multiprocessing_context
413
+
414
+ @multiprocessing_context.setter
415
+ def multiprocessing_context(self, multiprocessing_context):
416
+ if multiprocessing_context is not None:
417
+ if self.num_workers > 0:
418
+ if isinstance(multiprocessing_context, str):
419
+ valid_start_methods = multiprocessing.get_all_start_methods()
420
+ if multiprocessing_context not in valid_start_methods:
421
+ raise ValueError(
422
+ (
423
+ "multiprocessing_context option "
424
+ "should specify a valid start method in {!r}, but got "
425
+ "multiprocessing_context={!r}"
426
+ ).format(valid_start_methods, multiprocessing_context)
427
+ )
428
+ multiprocessing_context = multiprocessing.get_context(
429
+ multiprocessing_context
430
+ )
431
+
432
+ if not isinstance(
433
+ multiprocessing_context, python_multiprocessing.context.BaseContext
434
+ ):
435
+ raise TypeError(
436
+ (
437
+ "multiprocessing_context option should be a valid context "
438
+ "object or a string specifying the start method, but got "
439
+ "multiprocessing_context={}"
440
+ ).format(multiprocessing_context)
441
+ )
442
+ else:
443
+ raise ValueError(
444
+ (
445
+ "multiprocessing_context can only be used with "
446
+ "multi-process loading (num_workers > 0), but got "
447
+ "num_workers={}"
448
+ ).format(self.num_workers)
449
+ )
450
+
451
+ self.__multiprocessing_context = multiprocessing_context
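+ # (usage sketch, not original code, assuming the dataset and worker_init_fn
+ # are picklable under the chosen start method:
+ #     loader = RRSDataLoader(ds, num_workers=2, multiprocessing_context="spawn"))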
452
+
453
+ def __setattr__(self, attr, val):
454
+ if self.__initialized and attr in (
455
+ "batch_size",
456
+ "batch_sampler",
457
+ "sampler",
458
+ "drop_last",
459
+ "dataset",
460
+ "persistent_workers",
461
+ ):
462
+ raise ValueError(
463
+ "{} attribute should not be set after {} is "
464
+ "initialized".format(attr, self.__class__.__name__)
465
+ )
466
+
467
+ super().__setattr__(attr, val)
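+ # (behaviour sketch, not original code: after construction, an assignment such
+ # as `loader.batch_size = 64` raises ValueError; build a new RRSDataLoader
+ # instead of mutating these attributes.)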
468
+
469
+ # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
470
+ # since '_BaseDataLoaderIter' references 'DataLoader'.
471
+ def __iter__(self) -> "_BaseDataLoaderIter":
472
+ # When using a single worker the returned iterator should be
473
+ # created every time to avoid resetting its state.
474
+ # However, in the case of a multi-worker iterator
475
+ # the iterator is only created once in the lifetime of the
476
+ # DataLoader object so that workers can be reused
477
+ if self.persistent_workers and self.num_workers > 0:
478
+ if self._iterator is None:
479
+ self._iterator = self._get_iterator()
480
+ else:
481
+ self._iterator._reset(self)
482
+ return self._iterator
483
+ else:
484
+ return self._get_iterator()
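+ # (illustrative note, not original code: with persistent_workers=True and
+ # num_workers > 0, two successive `for batch in loader:` loops reuse the same
+ # iterator and worker processes; otherwise a fresh iterator, and fresh workers,
+ # are created for every epoch.)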
485
+
486
+ @property
487
+ def _auto_collation(self):
488
+ return self.batch_sampler is not None
489
+
490
+ @property
491
+ def _index_sampler(self):
492
+ # The actual sampler used for generating indices for `_DatasetFetcher`
493
+ # (see _utils/fetch.py) to read data at each time. This would be
494
+ # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
495
+ # We can't change `.sampler` and `.batch_sampler` attributes for BC
496
+ # reasons.
497
+ if self._auto_collation:
498
+ return self.batch_sampler
499
+ else:
500
+ return self.sampler
501
+
502
+ def __len__(self) -> int:
503
+ if self._dataset_kind == _DatasetKind.Iterable:
504
+ # NOTE [ IterableDataset and __len__ ]
505
+ #
506
+ # For `IterableDataset`, `__len__` could be inaccurate when one naively
507
+ # does multi-processing data loading, since the samples will be duplicated.
508
+ # However, no real use case should be actually using that behavior, so
509
+ # it should count as a user error. We should generally trust user
510
+ # code to do the proper thing (e.g., configure each replica differently
511
+ # in `__iter__`), and give us the correct `__len__` if they choose to
512
+ # implement it (this will still throw if the dataset does not implement
513
+ # a `__len__`).
514
+ #
515
+ # To provide a further warning, we track if `__len__` was called on the
516
+ # `DataLoader`, save the returned value in `self._len_called`, and warn
517
+ # if the iterator ends up yielding more than this number of samples.
518
+
519
+ # Cannot statically verify that dataset is Sized
520
+ length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type]
521
+ if (
522
+ self.batch_size is not None
523
+ ): # IterableDataset doesn't allow custom sampler or batch_sampler
524
+ from math import ceil
525
+
526
+ if self.drop_last:
527
+ length = length // self.batch_size
528
+ else:
529
+ length = ceil(length / self.batch_size)
530
+ return length
531
+ else:
532
+ return len(self._index_sampler)
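+ # (worked example of the arithmetic above, not original code: an
+ # IterableDataset reporting len(dataset) == 10 with batch_size == 3 gives
+ # len(loader) == ceil(10 / 3) == 4 when drop_last=False and 10 // 3 == 3 when
+ # drop_last=True.)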
533
+
534
+ def check_worker_number_rationality(self):
535
+ # This function checks whether the dataloader's worker number is reasonable based on
536
+ # the current system's resources. The current rule is that if the number of workers this
537
+ # DataLoader will create is bigger than the number of logical CPUs it is allowed to
538
+ # use, then we emit a warning so the user can pay attention.
539
+ #
540
+ # e.g. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
541
+ # threads, then the total number of logical CPUs is 2 * 16 * 2 = 64. Let's say the current
542
+ # DataLoader process can use half of them, i.e. 32; then the reasonable max number of
543
+ # workers initiated from this process is 32.
544
+ # Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
545
+ # So the warning message is triggered to notify the user to lower the worker number if
546
+ # necessary.
547
+ #
548
+ #
549
+ # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
550
+ # available (which is the case on most Linux systems, but not on OSX and Windows).
551
+ # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
552
+ # it doesn't respect cpuset.
553
+ # We don't take threading into account since each worker process is single threaded
554
+ # at this time.
555
+ #
556
+ # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
557
+ # other than setting `torch.set_num_threads` to 1 in the worker process. If the passed-in
558
+ # functions use 3rd party modules that rely on those threading flags to determine
559
+ # how many threads to create (e.g. numpy, etc.), then it is the caller's responsibility to
560
+ # set those flags correctly.
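+ # (sketch, not original code, assuming a Linux host where os.sched_getaffinity
+ # exists: a caller could bound num_workers by the CPUs actually available to
+ # this process, e.g.
+ #     allowed = len(os.sched_getaffinity(0))
+ #     loader = RRSDataLoader(ds, num_workers=min(8, allowed)))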
561
+ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
562
+
563
+ suggested_max_worker_msg = (
564
+ (
565
+ (
566
+ "Our suggested max number of worker in current system is {}{}, which is smaller "
567
+ "than what this DataLoader is going to create."
568
+ ).format(
569
+ num_worker_suggest,
570
+ (
571
+ ""
572
+ if cpuset_checked
573
+ else " (`cpuset` is not taken into account)"
574
+ ),
575
+ )
576
+ )
577
+ if num_worker_suggest is not None
578
+ else (
579
+ "DataLoader is not able to compute a suggested max number of worker in current system."
580
+ )
581
+ )
582
+
583
+ warn_msg = (
584
+ "This DataLoader will create {} worker processes in total. {} "
585
+ "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
586
+ "lower the worker number to avoid potential slowness/freeze if necessary."
587
+ ).format(num_worker_created, suggested_max_worker_msg)
588
+ return warn_msg
589
+
590
+ if not self.num_workers or self.num_workers == 0:
591
+ return
592
+
593
+ # try to compute a suggested max number of worker based on system's resource
594
+ max_num_worker_suggest = None
595
+ cpuset_checked = False
596
+ if hasattr(os, "sched_getaffinity"):
597
+ try:
598
+ max_num_worker_suggest = len(os.sched_getaffinity(0))
599
+ cpuset_checked = True
600
+ except Exception:
601
+ pass
602
+ if max_num_worker_suggest is None:
603
+ # os.cpu_count() could return Optional[int]
604
+ # get cpu count first and check None in order to satisfy the mypy check
605
+ cpu_count = os.cpu_count()
606
+ if cpu_count is not None:
607
+ max_num_worker_suggest = cpu_count
608
+
609
+ if max_num_worker_suggest is None:
610
+ warnings.warn(
611
+ _create_warning_msg(
612
+ max_num_worker_suggest, self.num_workers, cpuset_checked
613
+ )
614
+ )
615
+ return
616
+
617
+ if self.num_workers > max_num_worker_suggest:
618
+ warnings.warn(
619
+ _create_warning_msg(
620
+ max_num_worker_suggest, self.num_workers, cpuset_checked
621
+ )
622
+ )
623
+
624
+
625
+ class _BaseDataLoaderIter:
626
+ def __init__(self, loader: RRSDataLoader) -> None:
627
+ self._dataset = loader.dataset
628
+ self._shared_seed = None
629
+ self._pg = None
630
+ if isinstance(self._dataset, IterDataPipe):
631
+ if dist.is_available() and dist.is_initialized():
632
+ self._pg = dist.new_group(backend="gloo")
633
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
634
+ shared_rng = torch.Generator()
635
+ shared_rng.manual_seed(self._shared_seed)
636
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
637
+ self._dataset, shared_rng
638
+ )
639
+ self._dataset_kind = loader._dataset_kind
640
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
641
+ self._auto_collation = loader._auto_collation
642
+ self._drop_last = loader.drop_last
643
+ self._index_sampler = loader._index_sampler
644
+ self._num_workers = loader.num_workers
645
+ ws, rank = _get_distributed_settings()
646
+ self._world_size = ws
647
+ self._rank = rank
648
+ # for other backends, pin_memory_device needs to be set. If it is not set,
649
+ # the default behaviour is the CUDA device. If pin_memory_device is selected
650
+ # and pin_memory is not set, the default behaviour is False.
651
+ if len(loader.pin_memory_device) == 0:
652
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
653
+ self._pin_memory_device = None
654
+ else:
655
+ if not loader.pin_memory:
656
+ warn_msg = (
657
+ "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
658
+ "please set pin_memory to true, if you need to use the device pin memory"
659
+ )
660
+ warnings.warn(warn_msg)
661
+
662
+ self._pin_memory = loader.pin_memory
663
+ self._pin_memory_device = loader.pin_memory_device
664
+ self._timeout = loader.timeout
665
+ self._collate_fn = loader.collate_fn
666
+ self._sampler_iter = iter(self._index_sampler)
667
+ self._base_seed = (
668
+ torch.empty((), dtype=torch.int64)
669
+ .random_(generator=loader.generator)
670
+ .item()
671
+ )
672
+ self._persistent_workers = loader.persistent_workers
673
+ self._num_yielded = 0
674
+ self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
675
+ self.__class__.__name__
676
+ )
677
+
678
+ def __iter__(self) -> "_BaseDataLoaderIter":
679
+ return self
680
+
681
+ def _reset(self, loader, first_iter=False):
682
+ self._sampler_iter = iter(self._index_sampler)
683
+ self._num_yielded = 0
684
+ self._IterableDataset_len_called = loader._IterableDataset_len_called
685
+ if isinstance(self._dataset, IterDataPipe):
686
+ self._shared_seed = _share_dist_seed(loader.generator, self._pg)
687
+ shared_rng = torch.Generator()
688
+ shared_rng.manual_seed(self._shared_seed)
689
+ self._dataset = torch.utils.data.graph_settings.apply_random_seed(
690
+ self._dataset, shared_rng
691
+ )
692
+
693
+ def _next_index(self):
694
+ return next(self._sampler_iter) # may raise StopIteration
695
+
696
+ def _next_data(self):
697
+ raise NotImplementedError
698
+
699
+ def __next__(self) -> Any:
700
+ with torch.autograd.profiler.record_function(self._profile_name):
701
+ if self._sampler_iter is None:
702
+ # TODO(https://github.com/pytorch/pytorch/issues/76750)
703
+ self._reset() # type: ignore[call-arg]
704
+ data = self._next_data()
705
+ self._num_yielded += 1
706
+ if (
707
+ self._dataset_kind == _DatasetKind.Iterable
708
+ and self._IterableDataset_len_called is not None
709
+ and self._num_yielded > self._IterableDataset_len_called
710
+ ):
711
+ warn_msg = (
712
+ "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
713
+ "samples have been fetched. "
714
+ ).format(
715
+ self._dataset, self._IterableDataset_len_called, self._num_yielded
716
+ )
717
+ if self._num_workers > 0:
718
+ warn_msg += (
719
+ "For multiprocessing data-loading, this could be caused by not properly configuring the "
720
+ "IterableDataset replica at each worker. Please see "
721
+ "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
722
+ )
723
+ warnings.warn(warn_msg)
724
+ return data
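+ # (illustrative note, not original code: with num_workers=2 and an unsharded
+ # IterableDataset of reported length 10, roughly 20 samples come back, so the
+ # warning above fires once the 11th sample is yielded.)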
725
+
726
+ def __len__(self) -> int:
727
+ return len(self._index_sampler)
728
+
729
+ def __getstate__(self):
730
+ # TODO: add limited pickling support for sharing an iterator
731
+ # across multiple threads for HOGWILD.
732
+ # Probably the best way to do this is by moving the sample pushing
733
+ # to a separate thread and then just sharing the data queue
734
+ # but signalling the end is tricky without a non-blocking API
735
+ raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
736
+
737
+
738
+ class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
739
+ def __init__(self, loader):
740
+ super().__init__(loader)
741
+ assert self._timeout == 0
742
+ assert self._num_workers == 0
743
+
744
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
745
+ # Taking care of distributed sharding
746
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
747
+ # For BC, use default SHARDING_PRIORITIES
748
+ torch.utils.data.graph_settings.apply_sharding(
749
+ self._dataset, self._world_size, self._rank
750
+ )
751
+
752
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
753
+ self._dataset_kind,
754
+ self._dataset,
755
+ self._auto_collation,
756
+ self._collate_fn,
757
+ self._drop_last,
758
+ )
759
+
760
+ def _next_data(self):
761
+ index = self._next_index() # may raise StopIteration
762
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
763
+ if self._pin_memory:
764
+ data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
765
+ return data
766
+
767
+
768
+ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
769
+ r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
770
+
771
+ # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
772
+ #
773
+ # Preliminary:
774
+ #
775
+ # Our data model looks like this (queues are indicated with curly brackets):
776
+ #
777
+ # main process ||
778
+ # | ||
779
+ # {index_queue} ||
780
+ # | ||
781
+ # worker processes || DATA
782
+ # | ||
783
+ # {worker_result_queue} || FLOW
784
+ # | ||
785
+ # pin_memory_thread of main process || DIRECTION
786
+ # | ||
787
+ # {data_queue} ||
788
+ # | ||
789
+ # data output \/
790
+ #
791
+ # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
792
+ # `pin_memory=False`.
793
+ #
794
+ #
795
+ # Terminating multiprocessing logic requires very careful design. In
796
+ # particular, we need to make sure that
797
+ #
798
+ # 1. The iterator gracefully exits the workers when its last reference is
799
+ # gone or it is depleted.
800
+ #
801
+ # In this case, the workers should be gracefully exited because the
802
+ # main process may still need to continue to run, and we want cleaning
803
+ # up code in the workers to be executed (e.g., releasing GPU memory).
804
+ # Naturally, we implement the shutdown logic in `__del__` of
805
+ # DataLoaderIterator.
806
+ #
807
+ # We delay the discussion on the logic in this case until later.
808
+ #
809
+ # 2. The iterator exits the workers when the loader process and/or worker
810
+ # processes exits normally or with error.
811
+ #
812
+ # We set all workers and `pin_memory_thread` to have `daemon=True`.
813
+ #
814
+ # You may ask, why can't we make the workers non-daemonic, and
815
+ # gracefully exit using the same logic as we have in `__del__` when the
816
+ # iterator gets deleted (see 1 above)?
817
+ #
818
+ # First of all, `__del__` is **not** guaranteed to be called when
819
+ # interpreter exits. Even if it is called, by the time it executes,
820
+ # many Python core library resources may already be freed, and even
821
+ # simple things like acquiring an internal lock of a queue may hang.
822
+ # Therefore, in this case, we actually need to prevent `__del__` from
823
+ # being executed, and rely on the automatic termination of daemonic
824
+ # children.
825
+ #
826
+ # Thus, we register an `atexit` hook that sets a global flag
827
+ # `_utils.python_exit_status`. Since `atexit` hooks are executed in the
828
+ # reverse order of registration, we are guaranteed that this flag is
829
+ # set before library resources we use are freed (which, at least in
830
+ # CPython, is done via an `atexit` handler defined in
831
+ # `multiprocessing/util.py`
832
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
833
+ # registered when an object requiring this mechanism is first
834
+ # created, e.g., `mp.Queue`
835
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
836
+ # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
837
+ # )
838
+ #
839
+ # So in `__del__`, we check if `_utils.python_exit_status` is set or
840
+ # `None` (freed), and perform no-op if so.
841
+ #
842
+ # However, simply letting library clean-up codes run can also be bad,
843
+ # because such codes (i.e., `multiprocessing.util._exit_function()`)
844
+ # include join putting threads for `mp.Queue`, which can be blocking.
845
+ # Hence, the main process putting threads are called with
846
+ # `cancel_join_thread` at creation. See later section
847
+ # [ 3b. A process won't hang when putting into a queue; ]
848
+ # for more details.
849
+ #
850
+ # Here are two example cases where library clean-up codes can run
851
+ # before `__del__` is called:
852
+ #
853
+ # 1. If we hold onto a reference to the iterator, it more often
854
+ # than not tries to do `multiprocessing` library cleaning before
855
+ # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
856
+ # and thus prevents our cleaning-up code from running first.
857
+ #
858
+ # 2. A similar issue arises when a `DataLoader` is used in a subprocess.
859
+ # When a process ends, it shuts all its daemonic children
860
+ # down with a SIGTERM (instead of joining them without a timeout).
861
+ # Similarly for threads, but by a different mechanism. This fact,
862
+ # together with a few implementation details of multiprocessing, forces
863
+ # us to make workers daemonic. All of our problems arise when a
864
+ # DataLoader is used in a subprocess, and are caused by multiprocessing
865
+ # code which looks more or less like this:
866
+ #
867
+ # try:
868
+ # your_function_using_a_dataloader()
869
+ # finally:
870
+ # multiprocessing.util._exit_function()
871
+ #
872
+ # The joining/termination mentioned above happens inside
873
+ # `_exit_function()`. Now, if `your_function_using_a_dataloader()`
874
+ # throws, the stack trace stored in the exception will prevent the
875
+ # frame which uses `DataLoaderIter` to be freed. If the frame has any
876
+ # reference to the `DataLoaderIter` (e.g., in a method of the iter),
877
+ # its `__del__`, which starts the shutdown procedure, will not be
878
+ # called. That, in turn, means that workers aren't notified. Attempting
879
+ # to join in `_exit_function` will then result in a hang.
880
+ #
881
+ # For context, `_exit_function` is also registered as an `atexit` call.
882
+ # So it is unclear to me (@ssnl) why this is needed in a finally block.
883
+ # The code dates back to 2008 and there is no comment on the original
884
+ # PEP 371 or patch https://bugs.python.org/issue3050 (containing both
885
+ # the finally block and the `atexit` registration) that explains this.
886
+ #
887
+ #
888
+ # Finally, another choice is to just shutdown workers with logic in 1
889
+ # above whenever we see an error in `next`. This isn't ideal because
890
+ # a. It prevents users from using try-catch to resume data loading.
891
+ # b. It doesn't prevent hanging if users have references to the
892
+ # iterator.
893
+ #
894
+ # 3. All processes exit if any of them die unexpectedly by fatal signals.
895
+ #
896
+ # As shown above, the workers are set as daemonic children of the main
897
+ # process. However, automatic cleaning-up of such child processes only
898
+ # happens if the parent process exits gracefully (e.g., not via fatal
899
+ # signals like SIGKILL). So we must ensure that each process will exit
900
+ # even the process that should send/receive data to/from it were
901
+ # killed, i.e.,
902
+ #
903
+ # a. A process won't hang when getting from a queue.
904
+ #
905
+ # Even with carefully designed data dependencies (i.e., a `put()`
906
+ # always corresponding to a `get()`), hanging on `get()` can still
907
+ # happen when data in queue is corrupted (e.g., due to
908
+ # `cancel_join_thread` or unexpected exit).
909
+ #
910
+ # For child exit, we set a timeout whenever we try to get data
911
+ # from `data_queue`, and check the workers' status on each timeout
912
+ # and error.
913
+ # See `_DataLoaderiter._get_batch()` and
914
+ # `_DataLoaderiter._try_get_data()` for details.
915
+ #
916
+ # Additionally, for child exit on non-Windows platforms, we also
917
+ # register a SIGCHLD handler (which is not supported on Windows) on
918
+ # the main process, which checks if any of the workers fail in the
919
+ # (Python) handler. This is more efficient and faster in detecting
920
+ # worker failures, compared to only using the above mechanism.
921
+ # See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
922
+ #
923
+ # For `.get()` calls where the sender(s) is not the workers, we
924
+ # guard them with timeouts, and check the status of the sender
925
+ # when timeout happens:
926
+ # + in the workers, the `_utils.worker.ManagerWatchdog` class
927
+ # checks the status of the main process.
928
+ # + if `pin_memory=True`, when getting from `pin_memory_thread`,
929
+ # check `pin_memory_thread` status periodically until `.get()`
930
+ # returns or see that `pin_memory_thread` died.
931
+ #
932
+ # b. A process won't hang when putting into a queue;
933
+ #
934
+ # We use `mp.Queue` which has a separate background thread to put
935
+ # objects from an unbounded buffer array. The background thread is
936
+ # daemonic and usually automatically joined when the process
937
+ # *exits*.
938
+ #
939
+ # In case that the receiver has ended abruptly while
940
+ # reading from the pipe, the join will hang forever. The usual
941
+ # solution for this in Python is calling `q.cancel_join_thread`,
942
+ # which prevents automatically joining it when finalizing
943
+ # (exiting).
944
+ #
945
+ # Nonetheless, `cancel_join_thread` must only be called when the
946
+ # queue is **not** going to be read from or written to by another
947
+ # process, because it may hold onto a lock or leave corrupted data
948
+ # in the queue, leading other readers/writers to hang.
949
+ #
950
+ # Hence,
951
+ # + For worker processes, we only do so (for their output
952
+ # queues, i.e., `worker_result_queue`) before exiting.
953
+ # + For `pin_memory_thread`, its output queue `data_queue` is a
954
+ # `queue.Queue` that does blocking `put` if the queue is full.
955
+ # So there is no above problem, but as a result, in
956
+ # `_pin_memory_loop`, we do need to wrap the `put` in a loop
957
+ # that breaks not only upon success, but also when the main
958
+ # process stops reading, i.e., is shutting down.
959
+ # + For loader process, we `cancel_join_thread()` for all
960
+ # `_index_queues` because the whole purpose of workers and
961
+ # `pin_memory_thread` is to serve the loader process. If
962
+ # loader process is already exiting, we don't really care if
963
+ # the queues are corrupted.
964
+ #
965
+ #
966
+ # Now let's get back to 1:
967
+ # how we gracefully exit the workers when the last reference to the
968
+ # iterator is gone.
969
+ #
970
+ # To achieve this, we implement the following logic along with the design
971
+ # choices mentioned above:
972
+ #
973
+ # `workers_done_event`:
974
+ # A `multiprocessing.Event` shared among the main process and all worker
975
+ # processes. This is used to signal the workers that the iterator is
976
+ # shutting down. After it is set, they will not send processed data to
977
+ # queues anymore, and only wait for the final `None` before exiting.
978
+ # `done_event` isn't strictly needed. I.e., we can just check for `None`
979
+ # from the input queue, but it allows us to skip wasting resources
980
+ # processing data if we are already shutting down.
981
+ #
982
+ # `pin_memory_thread_done_event`:
983
+ # A `threading.Event` for a similar purpose to that of
984
+ # `workers_done_event`, but is for the `pin_memory_thread`. The reason
985
+ # that separate events are needed is that `pin_memory_thread` reads from
986
+ # the output queue of the workers. But a worker, upon seeing that
986
+ # `workers_done_event` is set, only wants to see the final `None`, and is
988
+ # not required to flush all data in the output queue (e.g., it may call
989
+ # `cancel_join_thread` on that queue if its `IterableDataset` iterator
990
+ # happens to exhaust coincidentally, which is out of the control of the
991
+ # main process). Thus, since we will exit `pin_memory_thread` before the
992
+ # workers (see below), two separate events are used.
993
+ #
994
+ # NOTE: In short, the protocol is that the main process will set these
995
+ # `done_event`s and then the corresponding processes/threads a `None`,
996
+ # and that they may exit at any time after receiving the `None`.
997
+ #
998
+ # NOTE: Using `None` as the final signal is valid, since normal data will
999
+ # always be a 2-tuple with the 1st element being the index of the data
1000
+ # transferred (different from dataset index/key), and the 2nd being
1001
+ # either the dataset key or the data sample (depending on which part
1002
+ # of the data model the queue is at).
1003
+ #
1004
+ # [ worker processes ]
1005
+ # While loader process is alive:
1006
+ # Get from `index_queue`.
1007
+ # If we got something (i.e., did not time out),
1008
+ # Check `workers_done_event`.
1009
+ # If set, continue to next iteration
1010
+ # i.e., keep getting until see the `None`, then exit.
1011
+ # Otherwise, process data:
1012
+ # If is fetching from an `IterableDataset` and the iterator
1013
+ # is exhausted, send an `_IterableDatasetStopIteration`
1014
+ # object to signal iteration end. The main process, upon
1015
+ # receiving such an object, will send `None` to this
1016
+ # worker and not use the corresponding `index_queue`
1017
+ # anymore.
1018
+ # If timed out,
1019
+ # No matter whether `workers_done_event` is set (we still need to see `None`)
1020
+ # or not, we must continue to the next iteration.
1021
+ # (outside loop)
1022
+ # If `workers_done_event` is set, (this can be False with `IterableDataset`)
1023
+ # `data_queue.cancel_join_thread()`. (Everything is ending here:
1024
+ # main process won't read from it;
1025
+ # other workers will also call
1026
+ # `cancel_join_thread`.)
1027
+ #
1028
+ # [ pin_memory_thread ]
1029
+ # # No need to check main thread. If this thread is alive, the main loader
1030
+ # # thread must be alive, because this thread is set as daemonic.
1031
+ # While `pin_memory_thread_done_event` is not set:
1032
+ # Get from `worker_result_queue`.
1033
+ # If timed out, continue to get in the next iteration.
1034
+ # Otherwise, process data.
1035
+ # While `pin_memory_thread_done_event` is not set:
1036
+ # Put processed data to `data_queue` (a `queue.Queue` with blocking put)
1037
+ # If timed out, continue to put in the next iteration.
1038
+ # Otherwise, break, i.e., continuing to the out loop.
1039
+ #
1040
+ # NOTE: we don't check the status of the main thread because
1041
+ # 1. if the process is killed by fatal signal, `pin_memory_thread`
1042
+ # ends.
1043
+ # 2. in other cases, either the cleaning-up in __del__ or the
1044
+ # automatic exit of daemonic thread will take care of it.
1045
+ # This won't busy-wait either because `.get(timeout)` does not
1046
+ # busy-wait.
1047
+ #
1048
+ # [ main process ]
1049
+ # In the DataLoader Iter's `__del__`
1050
+ # b. Exit `pin_memory_thread`
1051
+ # i. Set `pin_memory_thread_done_event`.
1052
+ # ii. Put `None` in `worker_result_queue`.
1053
+ # iii. Join the `pin_memory_thread`.
1054
+ # iv. `worker_result_queue.cancel_join_thread()`.
1055
+ #
1056
+ # c. Exit the workers.
1057
+ # i. Set `workers_done_event`.
1058
+ # ii. Put `None` in each worker's `index_queue`.
1059
+ # iii. Join the workers.
1060
+ # iv. Call `.cancel_join_thread()` on each worker's `index_queue`.
1061
+ #
1062
+ # NOTE: (c) is better placed after (b) because it may leave corrupted
1063
+ # data in `worker_result_queue`, which `pin_memory_thread`
1064
+ # reads from, in which case the `pin_memory_thread` exit can only
1065
+ # happen after timing out, which is slow. Nonetheless, the same thing
1066
+ # happens if a worker is killed by signal at unfortunate times,
1067
+ # but in other cases, we are better off having a non-corrupted
1068
+ # `worker_result_queue` for `pin_memory_thread`.
1069
+ #
1070
+ # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1071
+ # can be omitted
1072
+ #
1073
+ # NB: `done_event`s aren't strictly needed. E.g., we can just check for
1074
+ # `None` from `index_queue`, but it allows us to skip wasting resources
1075
+ # processing indices already in `index_queue` if we are already shutting
1076
+ # down.
1077
+
1078
+ def __init__(self, loader):
1079
+ super().__init__(loader)
1080
+
1081
+ self._prefetch_factor = loader.prefetch_factor
1082
+
1083
+ assert self._num_workers > 0
1084
+ assert self._prefetch_factor > 0
1085
+
1086
+ if loader.multiprocessing_context is None:
1087
+ multiprocessing_context = multiprocessing
1088
+ else:
1089
+ multiprocessing_context = loader.multiprocessing_context
1090
+
1091
+ self._worker_init_fn = loader.worker_init_fn
1092
+
1093
+ # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1094
+ # Additional worker init function will take care of sharding in MP and Distributed
1095
+ if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1096
+ self._worker_init_fn = functools.partial(
1097
+ _sharding_worker_init_fn,
1098
+ self._worker_init_fn,
1099
+ self._world_size,
1100
+ self._rank,
1101
+ )
1102
+
1103
+ # No certainty which module multiprocessing_context is
1104
+ self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1105
+ self._worker_pids_set = False
1106
+ self._shutdown = False
1107
+ self._workers_done_event = multiprocessing_context.Event()
1108
+
1109
+ self._index_queues = []
1110
+ self._workers = []
1111
+ for i in range(self._num_workers):
1112
+ # No certainty which module multiprocessing_context is
1113
+ index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
1114
+ # Need to `cancel_join_thread` here!
1115
+ # See sections (2) and (3b) above.
1116
+ index_queue.cancel_join_thread()
1117
+ w = multiprocessing_context.Process(
1118
+ target=_worker_loop,
1119
+ args=(
1120
+ self._dataset_kind,
1121
+ self._dataset,
1122
+ index_queue,
1123
+ self._worker_result_queue,
1124
+ self._workers_done_event,
1125
+ self._auto_collation,
1126
+ self._collate_fn,
1127
+ self._drop_last,
1128
+ self._base_seed,
1129
+ self._worker_init_fn,
1130
+ i,
1131
+ self._num_workers,
1132
+ self._persistent_workers,
1133
+ self._shared_seed,
1134
+ ),
1135
+ )
1136
+ w.daemon = True
1137
+ # NB: Process.start() actually takes some time as it needs to
1138
+ # start a process and pass the arguments over via a pipe.
1139
+ # Therefore, we only add a worker to the self._workers list after
1140
+ # it has started, so that we do not call .join() if the program dies
1141
+ # before it starts, and __del__ tries to join but will get:
1142
+ # AssertionError: can only join a started process.
1143
+ w.start()
1144
+ self._index_queues.append(index_queue)
1145
+ self._workers.append(w)
1146
+
1147
+ if self._pin_memory:
1148
+ self._pin_memory_thread_done_event = threading.Event()
1149
+
1150
+ # Queue is not type-annotated
1151
+ self._data_queue = queue.Queue() # type: ignore[var-annotated]
1152
+ if self._pin_memory_device == "xpu":
1153
+ current_device = torch.xpu.current_device() # type: ignore[attr-defined]
1154
+ else:
1155
+ current_device = torch.cuda.current_device() # choose cuda for default
1156
+ pin_memory_thread = threading.Thread(
1157
+ target=_utils.pin_memory._pin_memory_loop,
1158
+ args=(
1159
+ self._worker_result_queue,
1160
+ self._data_queue,
1161
+ current_device,
1162
+ self._pin_memory_thread_done_event,
1163
+ self._pin_memory_device,
1164
+ ),
1165
+ )
1166
+ pin_memory_thread.daemon = True
1167
+ pin_memory_thread.start()
1168
+ # Similar to workers (see comment above), we only register
1169
+ # pin_memory_thread once it is started.
1170
+ self._pin_memory_thread = pin_memory_thread
1171
+ else:
1172
+ self._data_queue = self._worker_result_queue
1173
+
1174
+ # In some rare cases, persistent workers (daemonic processes)
1175
+ # would be terminated before the `__del__` of the iterator is invoked
1176
+ # when the main process exits.
1177
+ # This would cause a failure when pin_memory_thread tries to read
1178
+ # corrupted data from worker_result_queue.
1179
+ # atexit is used to shut down the thread and child processes in the
1180
+ # right sequence before the main process exits.
1181
+ if self._persistent_workers and self._pin_memory:
1182
+ import atexit
1183
+
1184
+ for w in self._workers:
1185
+ atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1186
+
1187
+ # .pid can be None only before process is spawned (not the case, so ignore)
1188
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
1189
+ _utils.signal_handling._set_SIGCHLD_handler()
1190
+ self._worker_pids_set = True
1191
+ self._reset(loader, first_iter=True)
1192
+
1193
+ def _reset(self, loader, first_iter=False):
1194
+ super()._reset(loader, first_iter)
1195
+ self._send_idx = 0 # idx of the next task to be sent to workers
1196
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
1197
+ # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1198
+ # map: task idx => - (worker_id,) if data isn't fetched (outstanding)
1199
+ # \ (worker_id, data) if data is already fetched (out-of-order)
1200
+ self._task_info = {}
1201
+ self._tasks_outstanding = (
1202
+ 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
1203
+ )
1204
+ # A list of booleans representing whether each worker still has work to
1205
+ # do, i.e., not having exhausted its iterable dataset object. It always
1206
+ # contains all `True`s if not using an iterable-style dataset
1207
+ # (i.e., if kind != Iterable).
1208
+ # Note that this indicates that a worker still has work to do *for this epoch*.
1209
+ # It does not mean that a worker is dead. In case of `_persistent_workers`,
1210
+ # the worker will be reset to available in the next epoch.
1211
+ self._workers_status = [True for i in range(self._num_workers)]
1212
+ # Reset the worker queue cycle so it resumes next epoch at worker 0
1213
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1214
+ # We resume the prefetching in case it was enabled
1215
+ if not first_iter:
1216
+ for idx in range(self._num_workers):
1217
+ self._index_queues[idx].put(
1218
+ _utils.worker._ResumeIteration(self._shared_seed)
1219
+ )
1220
+ resume_iteration_cnt = self._num_workers
1221
+ while resume_iteration_cnt > 0:
1222
+ return_idx, return_data = self._get_data()
1223
+ if isinstance(return_idx, _utils.worker._ResumeIteration):
1224
+ assert return_data is None
1225
+ resume_iteration_cnt -= 1
1226
+ # prime the prefetch loop
1227
+ for _ in range(self._prefetch_factor * self._num_workers):
1228
+ self._try_put_index()
1229
+
1230
+ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1231
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
1232
+ # This can also be used as inner loop of fetching without timeout, with
1233
+ # the sender status as the loop condition.
1234
+ #
1235
+ # This raises a `RuntimeError` if any worker died unexpectedly. This error
1236
+ # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1237
+ # (only for non-Windows platforms), or the manual check below on errors
1238
+ # and timeouts.
1239
+ #
1240
+ # Returns a 2-tuple:
1241
+ # (bool: whether successfully get data, any: data if successful else None)
1242
+ try:
1243
+ data = self._data_queue.get(timeout=timeout)
1244
+ return (True, data)
1245
+ except Exception as e:
1246
+ # At timeout and error, we manually check whether any worker has
1247
+ # failed. Note that this is the only mechanism for Windows to detect
1248
+ # worker failures.
1249
+ failed_workers = []
1250
+ for worker_id, w in enumerate(self._workers):
1251
+ if self._workers_status[worker_id] and not w.is_alive():
1252
+ failed_workers.append(w)
1253
+ self._mark_worker_as_unavailable(worker_id)
1254
+ if len(failed_workers) > 0:
1255
+ pids_str = ", ".join(str(w.pid) for w in failed_workers)
1256
+ raise RuntimeError(
1257
+ "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
1258
+ ) from e
1259
+ if isinstance(e, queue.Empty):
1260
+ return (False, None)
1261
+ import errno
1262
+ import tempfile
1263
+
1264
+ try:
1265
+ # Raise an exception if we are this close to the FDs limit.
1266
+ # Apparently, trying to open only one file is not a sufficient
1267
+ # test.
1268
+ # See NOTE [ DataLoader on Linux and open files limit ]
1269
+ fds_limit_margin = 10
1270
+ fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1271
+ except OSError as e:
1272
+ if e.errno == errno.EMFILE:
1273
+ raise RuntimeError(
1274
+ "Too many open files. Communication with the"
1275
+ " workers is no longer possible. Please increase the"
1276
+ " limit using `ulimit -n` in the shell or change the"
1277
+ " sharing strategy by calling"
1278
+ " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1279
+ " at the beginning of your code"
1280
+ ) from None
1281
+ raise
1282
+
1283
+ # NOTE [ DataLoader on Linux and open files limit ]
1284
+ #
1285
+ # On Linux when DataLoader is used with multiprocessing we pass the data between
1286
+ # the root process and the workers through SHM files. We remove those files from
1287
+ # the filesystem as soon as they are created and keep them alive by
1288
+ # passing around their file descriptors through AF_UNIX sockets. (See
1289
+ # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes' in
1290
+ # the wiki (https://github.com/pytorch/pytorch/wiki).)
1291
+ #
1292
+ # This sometimes leads us to exceeding the open files limit. When that happens,
1293
+ # and the offending file descriptor is coming over a socket, the `socket` Python
1294
+ # package silently strips the file descriptor from the message, setting only the
1295
+ # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1296
+ # it _indicates that some control data were discarded due to lack of space in
1297
+ # the buffer for ancillary data_). This might reflect the C implementation of
1298
+ # AF_UNIX sockets.
1299
+ #
1300
+ # This behaviour can be reproduced with the script and instructions at the
1301
+ # bottom of this note.
1302
+ #
1303
+ # When that happens, the standard Python `multiprocessing` (and not
1304
+ # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1305
+ #
1306
+ # Sometimes, instead of the FD being stripped, you may get an `OSError:
1307
+ # Too many open files`, both in the script below and in DataLoader. However,
1308
+ # this is rare and seems to be nondeterministic.
1309
+ #
1310
+ #
1311
+ # #!/usr/bin/env python3
1312
+ # import sys
1313
+ # import socket
1314
+ # import os
1315
+ # import array
1316
+ # import shutil
1317
+ # import socket
1318
+ #
1319
+ #
1320
+ # if len(sys.argv) != 4:
1321
+ # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1322
+ # sys.exit(1)
1323
+ #
1324
+ # if __name__ == '__main__':
1325
+ # dirname = sys.argv[1]
1326
+ # sock_path = dirname + "/sock"
1327
+ # iterations = int(sys.argv[2])
1328
+ # def dummy_path(i):
1329
+ # return dirname + "/" + str(i) + ".dummy"
1330
+ #
1331
+ #
1332
+ # if sys.argv[3] == 'send':
1333
+ # while not os.path.exists(sock_path):
1334
+ # pass
1335
+ # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1336
+ # client.connect(sock_path)
1337
+ # for i in range(iterations):
1338
+ # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1339
+ # ancdata = array.array('i', [fd])
1340
+ # msg = bytes([i % 256])
1341
+ # print("Sending fd ", fd, " (iteration #", i, ")")
1342
+ # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1343
+ #
1344
+ #
1345
+ # else:
1346
+ # assert sys.argv[3] == 'recv'
1347
+ #
1348
+ # if os.path.exists(dirname):
1349
+ # raise Exception("Directory exists")
1350
+ #
1351
+ # os.mkdir(dirname)
1352
+ #
1353
+ # print("Opening socket...")
1354
+ # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1355
+ # server.bind(sock_path)
1356
+ #
1357
+ # print("Listening...")
1358
+ # for i in range(iterations):
1359
+ # a = array.array('i')
1360
+ # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1361
+ # assert(len(ancdata) == 1)
1362
+ # cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1363
+ # a.frombytes(cmsg_data)
1364
+ # print("Received fd ", a[0], " (iteration #", i, ")")
1365
+ #
1366
+ # shutil.rmtree(dirname)
1367
+ #
1368
+ # Steps to reproduce:
1369
+ #
1370
+ # 1. Run two shells and set lower file descriptor limit in the receiving one:
1371
+ # (shell1) ulimit -n 1020
1372
+ # (shell2) ulimit -n 1022
1373
+ #
1374
+ # 2. Run the script above with the `recv` option in the first shell
1375
+ # (shell1) ./test_socket.py sock_tmp 1017 recv
1376
+ #
1377
+ # 3. Run the script with the `send` option in the second shell:
1378
+ # (shell2) ./test_socket.py sock_tmp 1017 send
1379
+
1380
+ def _get_data(self):
1381
+ # Fetches data from `self._data_queue`.
1382
+ #
1383
+ # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1384
+ # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1385
+ # in a loop. This is the only mechanism to detect worker failures for
1386
+ # Windows. For other platforms, a SIGCHLD handler is also used for
1387
+ # worker failure detection.
1388
+ #
1389
+ # If `pin_memory=True`, we also need to check if `pin_memory_thread` has
1390
+ # died at timeouts.
1391
+ if self._timeout > 0:
1392
+ success, data = self._try_get_data(self._timeout)
1393
+ if success:
1394
+ return data
1395
+ else:
1396
+ raise RuntimeError(
1397
+ "DataLoader timed out after {} seconds".format(self._timeout)
1398
+ )
1399
+ elif self._pin_memory:
1400
+ while self._pin_memory_thread.is_alive():
1401
+ success, data = self._try_get_data()
1402
+ if success:
1403
+ return data
1404
+ else:
1405
+ # while condition is false, i.e., pin_memory_thread died.
1406
+ raise RuntimeError("Pin memory thread exited unexpectedly")
1407
+ # In this case, `self._data_queue` is a `queue.Queue`. But we don't
1408
+ # need to call `.task_done()` because we don't use `.join()`.
1409
+ else:
1410
+ while True:
1411
+ success, data = self._try_get_data()
1412
+ if success:
1413
+ return data
1414
+
1415
+ def _next_data(self):
1416
+ while True:
1417
+ # If the worker responsible for `self._rcvd_idx` has already ended
1418
+ # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1419
+ # we try to advance `self._rcvd_idx` to find the next valid index.
1420
+ #
1421
+ # This part needs to run in the loop because both the `self._get_data()`
1422
+ # call and `_IterableDatasetStopIteration` check below can mark
1423
+ # extra worker(s) as dead.
1424
+ while self._rcvd_idx < self._send_idx:
1425
+ info = self._task_info[self._rcvd_idx]
1426
+ worker_id = info[0]
1427
+ if (
1428
+ len(info) == 2 or self._workers_status[worker_id]
1429
+ ): # has data or is still active
1430
+ break
1431
+ del self._task_info[self._rcvd_idx]
1432
+ self._rcvd_idx += 1
1433
+ else:
1434
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
1435
+ if not self._persistent_workers:
1436
+ self._shutdown_workers()
1437
+ raise StopIteration
1438
+
1439
+ # Now `self._rcvd_idx` is the batch index we want to fetch
1440
+
1441
+ # Check if the next sample has already been generated
1442
+ if len(self._task_info[self._rcvd_idx]) == 2:
1443
+ data = self._task_info.pop(self._rcvd_idx)[1]
1444
+ return self._process_data(data)
1445
+
1446
+ assert not self._shutdown and self._tasks_outstanding > 0
1447
+ idx, data = self._get_data()
1448
+ self._tasks_outstanding -= 1
1449
+ if self._dataset_kind == _DatasetKind.Iterable:
1450
+ # Check for _IterableDatasetStopIteration
1451
+ if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1452
+ if self._persistent_workers:
1453
+ self._workers_status[data.worker_id] = False
1454
+ else:
1455
+ self._mark_worker_as_unavailable(data.worker_id)
1456
+ self._try_put_index()
1457
+ continue
1458
+
1459
+ if idx != self._rcvd_idx:
1460
+ # store out-of-order samples
1461
+ self._task_info[idx] += (data,)
1462
+ else:
1463
+ del self._task_info[idx]
1464
+ return self._process_data(data)
1465
+
1466
+ def _try_put_index(self):
1467
+ assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1468
+
1469
+ try:
1470
+ index = self._next_index()
1471
+ except StopIteration:
1472
+ return
1473
+ for _ in range(self._num_workers): # find the next active worker, if any
1474
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
1475
+ if self._workers_status[worker_queue_idx]:
1476
+ break
1477
+ else:
1478
+ # not found (i.e., didn't break)
1479
+ return
1480
+
1481
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
1482
+ self._task_info[self._send_idx] = (worker_queue_idx,)
1483
+ self._tasks_outstanding += 1
1484
+ self._send_idx += 1
1485
+
1486
+ def _process_data(self, data):
1487
+ self._rcvd_idx += 1
1488
+ self._try_put_index()
1489
+ if isinstance(data, ExceptionWrapper):
1490
+ data.reraise()
1491
+ return data
1492
+
1493
+ def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1494
+ # Mark a worker as having finished its work e.g., due to
1495
+ # exhausting an `IterableDataset`. This should be used only when this
1496
+ # `_MultiProcessingDataLoaderIter` is going to continue running.
1497
+
1498
+ assert self._workers_status[worker_id] or (
1499
+ self._persistent_workers and shutdown
1500
+ )
1501
+
1502
+ # Signal termination to that specific worker.
1503
+ q = self._index_queues[worker_id]
1504
+ # Indicate that no more data will be put on this queue by the current
1505
+ # process.
1506
+ q.put(None)
1507
+
1508
+ # Note that we don't actually join the worker here, nor do we remove the
1509
+ # worker's pid from C side struct because (1) joining may be slow, and
1510
+ # (2) since we don't join, the worker may still raise error, and we
1511
+ # prefer capturing those, rather than ignoring them, even though they
1512
+ # are raised after the worker has finished its job.
1513
+ # Joining is deferred to `_shutdown_workers`, which is called when
1514
+ # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1515
+ # when this iterator is garbage collected.
1516
+
1517
+ self._workers_status[worker_id] = False
1518
+
1519
+ assert self._workers_done_event.is_set() == shutdown
1520
+
1521
+ def _shutdown_workers(self):
1522
+ # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1523
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1524
+ # the logic of this function.
1525
+ if (
1526
+ _utils is None
1527
+ or _utils.python_exit_status is True
1528
+ or _utils.python_exit_status is None
1529
+ ):
1530
+ # See (2) of the note. If Python is shutting down, do no-op.
1531
+ return
1532
+ # Normal exit when last reference is gone / iterator is depleted.
1533
+ # See (1) and the second half of the note.
1534
+ if not self._shutdown:
1535
+ self._shutdown = True
1536
+ try:
1537
+ # Normal exit when last reference is gone / iterator is depleted.
1538
+ # See (1) and the second half of the note.
1539
+
1540
+ # Exit `pin_memory_thread` first because exiting workers may leave
1541
+ # corrupted data in `worker_result_queue` which `pin_memory_thread`
1542
+ # reads from.
1543
+ if hasattr(self, "_pin_memory_thread"):
1544
+ # Use hasattr in case error happens before we set the attribute.
1545
+ self._pin_memory_thread_done_event.set()
1546
+ # Send something to pin_memory_thread in case it is waiting
1547
+ # so that it can wake up and check `pin_memory_thread_done_event`
1548
+ self._worker_result_queue.put((None, None))
1549
+ self._pin_memory_thread.join()
1550
+ self._worker_result_queue.cancel_join_thread()
1551
+ self._worker_result_queue.close()
1552
+
1553
+ # Exit workers now.
1554
+ self._workers_done_event.set()
1555
+ for worker_id in range(len(self._workers)):
1556
+ # Get number of workers from `len(self._workers)` instead of
1557
+ # `self._num_workers` in case we error before starting all
1558
+ # workers.
1559
+ # If we are using workers_status with persistent_workers
1560
+ # we have to shut it down because the worker is paused
1561
+ if self._persistent_workers or self._workers_status[worker_id]:
1562
+ self._mark_worker_as_unavailable(worker_id, shutdown=True)
1563
+ for w in self._workers:
1564
+ # We should be able to join here, but in case anything went
1565
+ # wrong, we set a timeout and if the workers fail to join,
1566
+ # they are killed in the `finally` block.
1567
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1568
+ for q in self._index_queues:
1569
+ q.cancel_join_thread()
1570
+ q.close()
1571
+ finally:
1572
+ # Even though all this function does is putting into queues that
1573
+ # we have called `cancel_join_thread` on, weird things can
1574
+ # happen when a worker is killed by a signal, e.g., hanging in
1575
+ # `Event.set()`. So we need to guard this with SIGCHLD handler,
1576
+ # and remove pids from the C side data structure only at the
1577
+ # end.
1578
+ #
1579
+ # FIXME: Unfortunately, for Windows, we are missing a worker
1580
+ # error detection mechanism here in this function, as it
1581
+ # doesn't provide a SIGCHLD handler.
1582
+ if self._worker_pids_set:
1583
+ _utils.signal_handling._remove_worker_pids(id(self))
1584
+ self._worker_pids_set = False
1585
+ for w in self._workers:
1586
+ if w.is_alive():
1587
+ # Existing mechanisms try to make the workers exit
1588
+ # peacefully, but in case that we unfortunately reach
1589
+ # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
1590
+ # we kill the worker.
1591
+ w.terminate()
1592
+
1593
+ # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
1594
+ @staticmethod
1595
+ def _clean_up_worker(w):
1596
+ try:
1597
+ w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1598
+ finally:
1599
+ if w.is_alive():
1600
+ w.terminate()
1601
+
1602
+ def __del__(self):
1603
+ self._shutdown_workers()
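
A minimal, self-contained sketch (not the DataLoader API itself) of the out-of-order buffering used in `_next_data`/`_try_put_index` above: results that come back from workers ahead of `_rcvd_idx` wait in `_task_info` until the expected index catches up, so batches are always delivered in order.

```python
# Standalone illustration of the reordering idea above (an assumption-free
# toy, not the DataLoader): early results are buffered until their turn.
import random


def reordered_stream(arrivals):
    task_info = {}   # idx -> result that arrived early
    rcvd_idx = 0     # next index the consumer expects
    for idx, data in arrivals:
        if idx != rcvd_idx:
            task_info[idx] = data          # out of order: buffer it
            continue
        yield data                         # in order: deliver immediately
        rcvd_idx += 1
        while rcvd_idx in task_info:       # drain anything now unblocked
            yield task_info.pop(rcvd_idx)
            rcvd_idx += 1


if __name__ == "__main__":
    arrivals = [(i, f"batch-{i}") for i in range(6)]
    random.shuffle(arrivals)               # simulate workers finishing out of order
    print(list(reordered_stream(arrivals)))  # always batch-0 ... batch-5
```
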
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/_data_worker.py ADDED
@@ -0,0 +1,378 @@
1
+ r"""This file is based on torch/utils/data/_utils/worker.py
2
+
3
+ Contains definitions of the methods used by the _BaseDataLoaderIter workers.
4
+ These **need** to be in global scope since Py2 doesn't support serializing
5
+ static methods.
6
+ """
7
+
8
+ import os
9
+ import queue
10
+ import random
11
+ from dataclasses import dataclass
12
+ from typing import TYPE_CHECKING, Optional, Union
13
+
14
+ import torch
15
+ from torch._utils import ExceptionWrapper
16
+ from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
17
+ MP_STATUS_CHECK_INTERVAL, signal_handling)
18
+
19
+ if TYPE_CHECKING:
20
+ from torch.utils.data import Dataset
21
+
22
+ from .controller import RRSController
23
+
24
+ if IS_WINDOWS:
25
+ import ctypes
26
+ from ctypes.wintypes import BOOL, DWORD, HANDLE
27
+
28
+ # On Windows, the parent ID of the worker process remains unchanged when the manager process
29
+ # is gone, and the only way to check it through OS is to let the worker have a process handle
30
+ # of the manager and ask if the process status has changed.
31
+ class ManagerWatchdog:
32
+ def __init__(self):
33
+ self.manager_pid = os.getppid()
34
+
35
+ # mypy cannot detect this code is windows only
36
+ self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
37
+ self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
38
+ self.kernel32.OpenProcess.restype = HANDLE
39
+ self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
40
+ self.kernel32.WaitForSingleObject.restype = DWORD
41
+
42
+ # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
43
+ SYNCHRONIZE = 0x00100000
44
+ self.manager_handle = self.kernel32.OpenProcess(
45
+ SYNCHRONIZE, 0, self.manager_pid
46
+ )
47
+
48
+ if not self.manager_handle:
49
+ raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
50
+
51
+ self.manager_dead = False
52
+
53
+ def is_alive(self):
54
+ if not self.manager_dead:
55
+ # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
56
+ self.manager_dead = (
57
+ self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
58
+ )
59
+ return not self.manager_dead
60
+
61
+ else:
62
+
63
+ class ManagerWatchdog: # type: ignore[no-redef]
64
+ def __init__(self):
65
+ self.manager_pid = os.getppid()
66
+ self.manager_dead = False
67
+
68
+ def is_alive(self):
69
+ if not self.manager_dead:
70
+ self.manager_dead = os.getppid() != self.manager_pid
71
+ return not self.manager_dead
72
+
73
+
74
+ _worker_info = None
75
+
76
+
77
+ class WorkerInfo:
78
+ id: int
79
+ num_workers: int
80
+ seed: int
81
+ dataset: "Dataset"
82
+ __initialized = False
83
+
84
+ def __init__(self, **kwargs):
85
+ for k, v in kwargs.items():
86
+ setattr(self, k, v)
87
+ self.__keys = tuple(kwargs.keys())
88
+ self.__initialized = True
89
+
90
+ def __setattr__(self, key, val):
91
+ if self.__initialized:
92
+ raise RuntimeError(
93
+ "Cannot assign attributes to {} objects".format(self.__class__.__name__)
94
+ )
95
+ return super().__setattr__(key, val)
96
+
97
+ def __repr__(self):
98
+ items = []
99
+ for k in self.__keys:
100
+ items.append("{}={}".format(k, getattr(self, k)))
101
+ return "{}({})".format(self.__class__.__name__, ", ".join(items))
102
+
103
+
104
+ def get_worker_info() -> Optional[WorkerInfo]:
105
+ r"""Returns the information about the current
106
+ :class:`~torch.utils.data.DataLoader` iterator worker process.
107
+
108
+ When called in a worker, this returns an object guaranteed to have the
109
+ following attributes:
110
+
111
+ * :attr:`id`: the current worker id.
112
+ * :attr:`num_workers`: the total number of workers.
113
+ * :attr:`seed`: the random seed set for the current worker. This value is
114
+ determined by main process RNG and the worker id. See
115
+ :class:`~torch.utils.data.DataLoader`'s documentation for more details.
116
+ * :attr:`dataset`: the copy of the dataset object in **this** process. Note
117
+ that this will be a different object in a different process than the one
118
+ in the main process.
119
+
120
+ When called in the main process, this returns ``None``.
121
+
122
+ .. note::
123
+ When used in a :attr:`worker_init_fn` passed over to
124
+ :class:`~torch.utils.data.DataLoader`, this method can be useful to
125
+ set up each worker process differently, for instance, using ``worker_id``
126
+ to configure the ``dataset`` object to only read a specific fraction of a
127
+ sharded dataset, or use ``seed`` to seed other libraries used in dataset
128
+ code.
129
+ """
130
+ return _worker_info
131
+
132
+
133
+ r"""Dummy class used to signal the end of an IterableDataset"""
134
+
135
+
136
+ @dataclass(frozen=True)
137
+ class _IterableDatasetStopIteration:
138
+ worker_id: int
139
+
140
+
141
+ r"""Dummy class used to resume the fetching when worker reuse is enabled"""
142
+
143
+
144
+ @dataclass(frozen=True)
145
+ class _ResumeIteration:
146
+ seed: Optional[int] = None
147
+
148
+
149
+ # The function `_generate_state` is adapted from `numpy.random.SeedSequence`
150
+ # from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
151
+ # It's MIT licensed, here is the copyright:
152
+
153
+ # Copyright (c) 2015 Melissa E. O'Neill
154
+ # Copyright (c) 2019 NumPy Developers
155
+ #
156
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
157
+ # of this software and associated documentation files (the "Software"), to deal
158
+ # in the Software without restriction, including without limitation the rights
159
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
160
+ # copies of the Software, and to permit persons to whom the Software is
161
+ # furnished to do so, subject to the following conditions:
162
+ #
163
+ # The above copyright notice and this permission notice shall be included in
164
+ # all copies or substantial portions of the Software.
165
+ #
166
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
167
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
168
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
169
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
170
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
171
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
172
+ # SOFTWARE.
173
+
174
+
175
+ # This function generates an array of int32 as the seed for
176
+ # `numpy.random`, in order to prevent state collision due to same
177
+ # seed and algorithm for `numpy.random` and `random` modules.
178
+ # TODO: Implement `SeedSequence` like object for `torch.random`
179
+ def _generate_state(base_seed, worker_id):
180
+ INIT_A = 0x43B0D7E5
181
+ MULT_A = 0x931E8875
182
+ INIT_B = 0x8B51F9DD
183
+ MULT_B = 0x58F38DED
184
+ MIX_MULT_L = 0xCA01F9DD
185
+ MIX_MULT_R = 0x4973F715
186
+ XSHIFT = 4 * 8 // 2
187
+ MASK32 = 0xFFFFFFFF
188
+
189
+ entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
190
+ pool = [0] * 4
191
+
192
+ hash_const_A = INIT_A
193
+
194
+ def hash(value):
195
+ nonlocal hash_const_A
196
+ value = (value ^ hash_const_A) & MASK32
197
+ hash_const_A = (hash_const_A * MULT_A) & MASK32
198
+ value = (value * hash_const_A) & MASK32
199
+ value = (value ^ (value >> XSHIFT)) & MASK32
200
+ return value
201
+
202
+ def mix(x, y):
203
+ result_x = (MIX_MULT_L * x) & MASK32
204
+ result_y = (MIX_MULT_R * y) & MASK32
205
+ result = (result_x - result_y) & MASK32
206
+ result = (result ^ (result >> XSHIFT)) & MASK32
207
+ return result
208
+
209
+ # Add in the entropy to the pool.
210
+ for i in range(len(pool)):
211
+ pool[i] = hash(entropy[i])
212
+
213
+ # Mix all bits together so late bits can affect earlier bits.
214
+ for i_src in range(len(pool)):
215
+ for i_dst in range(len(pool)):
216
+ if i_src != i_dst:
217
+ pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
218
+
219
+ hash_const_B = INIT_B
220
+ state = []
221
+ for i_dst in range(4):
222
+ data_val = pool[i_dst]
223
+ data_val = (data_val ^ hash_const_B) & MASK32
224
+ hash_const_B = (hash_const_B * MULT_B) & MASK32
225
+ data_val = (data_val * hash_const_B) & MASK32
226
+ data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
227
+ state.append(data_val)
228
+ return state
229
+
230
+
231
+ def _worker_loop(
232
+ dataset_kind,
233
+ dataset,
234
+ index_queue,
235
+ data_queue,
236
+ done_event,
237
+ auto_collation,
238
+ collate_fn,
239
+ drop_last,
240
+ base_seed,
241
+ init_fn,
242
+ worker_id,
243
+ num_workers,
244
+ persistent_workers,
245
+ shared_seed,
246
+ ):
247
+ # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
248
+ # logic of this function.
249
+
250
+ try:
251
+ # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
252
+ # module's handlers are executed after Python returns from C low-level
253
+ # handlers, likely when the same fatal signal had already happened
254
+ # again.
255
+ # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
256
+ signal_handling._set_worker_signal_handlers()
257
+
258
+ torch.set_num_threads(1)
259
+ seed = base_seed + worker_id
260
+ random.seed(seed)
261
+ torch.manual_seed(seed)
262
+ if HAS_NUMPY:
263
+ np_seed = _generate_state(base_seed, worker_id)
264
+ import numpy as np
265
+
266
+ np.random.seed(np_seed)
267
+
268
+ from torch.utils.data import IterDataPipe
269
+ from torch.utils.data.graph_settings import apply_random_seed
270
+
271
+ shared_rng = torch.Generator()
272
+ if isinstance(dataset, IterDataPipe):
273
+ assert shared_seed is not None
274
+ shared_rng.manual_seed(shared_seed)
275
+ dataset = apply_random_seed(dataset, shared_rng)
276
+
277
+ global _worker_info
278
+ _worker_info = WorkerInfo(
279
+ id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
280
+ )
281
+
282
+ from torch.utils.data import _DatasetKind
283
+
284
+ init_exception = None
285
+
286
+ try:
287
+ if init_fn is not None:
288
+ init_fn(worker_id)
289
+
290
+ fetcher = _DatasetKind.create_fetcher(
291
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
292
+ )
293
+ except Exception:
294
+ init_exception = ExceptionWrapper(
295
+ where="in DataLoader worker process {}".format(worker_id)
296
+ )
297
+
298
+ # When using Iterable mode, some worker can exit earlier than others due
299
+ # to the IterableDataset behaving differently for different workers.
300
+ # When such things happen, an `_IterableDatasetStopIteration` object is
301
+ # sent over to the main process with the ID of this worker, so that the
302
+ # main process won't send more tasks to this worker, and will send
303
+ # `None` to this worker to properly exit it.
304
+ #
305
+ # Note that we cannot set `done_event` from a worker as it is shared
306
+ # among all processes. Instead, we set the `iteration_end` flag to
307
+ # signify that the iterator is exhausted. When either `done_event` or
308
+ # `iteration_end` is set, we skip all processing step and just wait for
309
+ # `None`.
310
+ iteration_end = False
311
+
312
+ watchdog = ManagerWatchdog()
313
+
314
+ while watchdog.is_alive():
315
+ try:
316
+ r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
317
+ except queue.Empty:
318
+ continue
319
+ if isinstance(r, _ResumeIteration):
320
+ # Acknowledge the main process
321
+ data_queue.put((r, None))
322
+ iteration_end = False
323
+
324
+ if isinstance(dataset, IterDataPipe):
325
+ assert r.seed is not None
326
+ shared_rng.manual_seed(r.seed)
327
+ dataset = apply_random_seed(dataset, shared_rng)
328
+
329
+ # Recreate the fetcher for worker-reuse policy
330
+ fetcher = _DatasetKind.create_fetcher(
331
+ dataset_kind, dataset, auto_collation, collate_fn, drop_last
332
+ )
333
+ continue
334
+ elif r is None:
335
+ # Received the final signal
336
+ assert done_event.is_set() or iteration_end
337
+ break
338
+ elif done_event.is_set() or iteration_end:
339
+ # `done_event` is set. But I haven't received the final signal
340
+ # (None) yet. I will keep continuing until I get it, and skip the
341
+ # processing steps.
342
+ continue
343
+ idx, index = r
344
+ """ Added """
345
+ RRSController.sample_resolution(batch_id=idx)
346
+ """ Added """
347
+ data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
348
+ if init_exception is not None:
349
+ data = init_exception
350
+ init_exception = None
351
+ else:
352
+ try:
353
+ data = fetcher.fetch(index)
354
+ except Exception as e:
355
+ if (
356
+ isinstance(e, StopIteration)
357
+ and dataset_kind == _DatasetKind.Iterable
358
+ ):
359
+ data = _IterableDatasetStopIteration(worker_id)
360
+ # Set `iteration_end`
361
+ # (1) to save future `next(...)` calls, and
362
+ # (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
363
+ iteration_end = True
364
+ else:
365
+ # It is important that we don't store exc_info in a variable.
366
+ # `ExceptionWrapper` does the correct thing.
367
+ # See NOTE [ Python Traceback Reference Cycle Problem ]
368
+ data = ExceptionWrapper(
369
+ where="in DataLoader worker process {}".format(worker_id)
370
+ )
371
+ data_queue.put((idx, data))
372
+ del data, idx, index, r # save memory
373
+ except KeyboardInterrupt:
374
+ # Main process will raise KeyboardInterrupt anyways.
375
+ pass
376
+ if done_event.is_set():
377
+ data_queue.cancel_join_thread()
378
+ data_queue.close()
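
A hedged usage sketch of the per-worker seeding that `_worker_loop` performs above, written against the public `torch.utils.data` API rather than this vendored copy; `ToyDataset` and `worker_init_fn` are illustrative stand-ins, not part of this repo.

```python
# Hedged sketch: consuming the per-worker seeding above through the public
# torch.utils.data API.
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, i):
        # each worker draws from its own RNG stream thanks to the seeding above
        return np.random.rand()


def worker_init_fn(worker_id: int) -> None:
    info = torch.utils.data.get_worker_info()
    np.random.seed(info.seed % 2**32)  # info.seed == base_seed + worker_id


if __name__ == "__main__":
    loader = DataLoader(ToyDataset(), num_workers=2, worker_init_fn=worker_init_fn)
    for batch in loader:
        print(batch)
```
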
yolo-world-with-efficientvit-sam/efficientvit/apps/data_provider/random_resolution/controller.py ADDED
@@ -0,0 +1,94 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as F
10
+
11
+ from efficientvit.models.utils import torch_random_choices
12
+
13
+ __all__ = [
14
+ "RRSController",
15
+ "get_interpolate",
16
+ "MyRandomResizedCrop",
17
+ ]
18
+
19
+
20
+ class RRSController:
21
+ ACTIVE_SIZE = (224, 224)
22
+ IMAGE_SIZE_LIST = [(224, 224)]
23
+
24
+ CHOICE_LIST = None
25
+
26
+ @staticmethod
27
+ def get_candidates() -> list[tuple[int, int]]:
28
+ return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
29
+
30
+ @staticmethod
31
+ def sample_resolution(batch_id: int) -> None:
32
+ RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
33
+
34
+ @staticmethod
35
+ def set_epoch(epoch: int, batch_per_epoch: int) -> None:
36
+ g = torch.Generator()
37
+ g.manual_seed(epoch)
38
+ RRSController.CHOICE_LIST = torch_random_choices(
39
+ RRSController.get_candidates(),
40
+ g,
41
+ batch_per_epoch,
42
+ )
43
+
44
+
45
+ def get_interpolate(name: str) -> F.InterpolationMode:
46
+ mapping = {
47
+ "nearest": F.InterpolationMode.NEAREST,
48
+ "bilinear": F.InterpolationMode.BILINEAR,
49
+ "bicubic": F.InterpolationMode.BICUBIC,
50
+ "box": F.InterpolationMode.BOX,
51
+ "hamming": F.InterpolationMode.HAMMING,
52
+ "lanczos": F.InterpolationMode.LANCZOS,
53
+ }
54
+ if name in mapping:
55
+ return mapping[name]
56
+ elif name == "random":
57
+ return torch_random_choices(
58
+ [
59
+ F.InterpolationMode.NEAREST,
60
+ F.InterpolationMode.BILINEAR,
61
+ F.InterpolationMode.BICUBIC,
62
+ F.InterpolationMode.BOX,
63
+ F.InterpolationMode.HAMMING,
64
+ F.InterpolationMode.LANCZOS,
65
+ ],
66
+ )
67
+ else:
68
+ raise NotImplementedError
69
+
70
+
71
+ class MyRandomResizedCrop(transforms.RandomResizedCrop):
72
+ def __init__(
73
+ self,
74
+ scale=(0.08, 1.0),
75
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
76
+ interpolation: str = "random",
77
+ ):
78
+ super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
79
+ self.interpolation = interpolation
80
+
81
+ def forward(self, img: torch.Tensor) -> torch.Tensor:
82
+ i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
83
+ target_size = RRSController.ACTIVE_SIZE
84
+ return F.resized_crop(
85
+ img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
86
+ )
87
+
88
+ def __repr__(self) -> str:
89
+ format_string = self.__class__.__name__
90
+ format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
91
+ format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
92
+ format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
93
+ format_string += f"\tinterpolation={self.interpolation})"
94
+ return format_string
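
A minimal sketch of the random-resolution schedule implemented by `RRSController` above: `set_epoch` seeds a generator with the epoch so every rank builds the same per-batch schedule, and `sample_resolution(batch_id)` activates one entry. The efficientvit-internal helper `torch_random_choices` is replaced here by a plain `torch.randint` draw, and the candidate list is hypothetical.

```python
# Illustration only: torch_random_choices is approximated with torch.randint.
import torch

candidates = [(160, 160), (192, 192), (224, 224), (256, 256)]
epoch, batch_per_epoch = 0, 5

g = torch.Generator()
g.manual_seed(epoch)  # same seed on every rank -> identical schedule everywhere
idx = torch.randint(len(candidates), (batch_per_epoch,), generator=g)
choice_list = [candidates[i] for i in idx.tolist()]  # what set_epoch() builds

for batch_id in range(batch_per_epoch):
    active_size = choice_list[batch_id]  # what sample_resolution() activates
    print(f"batch {batch_id}: crop to {active_size}")
```
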
yolo-world-with-efficientvit-sam/efficientvit/apps/setup.py ADDED
@@ -0,0 +1,141 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+ import time
7
+ from copy import deepcopy
8
+
9
+ import torch.backends.cudnn
10
+ import torch.distributed
11
+ import torch.nn as nn
12
+
13
+ from efficientvit.apps.data_provider import DataProvider
14
+ from efficientvit.apps.trainer.run_config import RunConfig
15
+ from efficientvit.apps.utils import (dist_init, dump_config,
16
+ get_dist_local_rank, get_dist_rank,
17
+ get_dist_size, init_modules, is_master,
18
+ load_config, partial_update_config,
19
+ zero_last_gamma)
20
+ from efficientvit.models.utils import (build_kwargs_from_config,
21
+ load_state_dict_from_file)
22
+
23
+ __all__ = [
24
+ "save_exp_config",
25
+ "setup_dist_env",
26
+ "setup_seed",
27
+ "setup_exp_config",
28
+ "setup_data_provider",
29
+ "setup_run_config",
30
+ "init_model",
31
+ ]
32
+
33
+
34
+ def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
35
+ if not is_master():
36
+ return
37
+ dump_config(exp_config, os.path.join(path, name))
38
+
39
+
40
+ def setup_dist_env(gpu: str or None = None) -> None:
41
+ if gpu is not None:
42
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu
43
+ if not torch.distributed.is_initialized():
44
+ dist_init()
45
+ torch.backends.cudnn.benchmark = True
46
+ torch.cuda.set_device(get_dist_local_rank())
47
+
48
+
49
+ def setup_seed(manual_seed: int, resume: bool) -> None:
50
+ if resume:
51
+ manual_seed = int(time.time())
52
+ manual_seed = get_dist_rank() + manual_seed
53
+ torch.manual_seed(manual_seed)
54
+ torch.cuda.manual_seed_all(manual_seed)
55
+
56
+
57
+ def setup_exp_config(
58
+ config_path: str, recursive=True, opt_args: dict or None = None
59
+ ) -> dict:
60
+ # load config
61
+ if not os.path.isfile(config_path):
62
+ raise ValueError(config_path)
63
+
64
+ fpaths = [config_path]
65
+ if recursive:
66
+ extension = os.path.splitext(config_path)[1]
67
+ while os.path.dirname(config_path) != config_path:
68
+ config_path = os.path.dirname(config_path)
69
+ fpath = os.path.join(config_path, "default" + extension)
70
+ if os.path.isfile(fpath):
71
+ fpaths.append(fpath)
72
+ fpaths = fpaths[::-1]
73
+
74
+ default_config = load_config(fpaths[0])
75
+ exp_config = deepcopy(default_config)
76
+ for fpath in fpaths[1:]:
77
+ partial_update_config(exp_config, load_config(fpath))
78
+ # update config via args
79
+ if opt_args is not None:
80
+ partial_update_config(exp_config, opt_args)
81
+
82
+ return exp_config
83
+
84
+
85
+ def setup_data_provider(
86
+ exp_config: dict,
87
+ data_provider_classes: list[type[DataProvider]],
88
+ is_distributed: bool = True,
89
+ ) -> DataProvider:
90
+ dp_config = exp_config["data_provider"]
91
+ dp_config["num_replicas"] = get_dist_size() if is_distributed else None
92
+ dp_config["rank"] = get_dist_rank() if is_distributed else None
93
+ dp_config["test_batch_size"] = (
94
+ dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
95
+ )
96
+ dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
97
+ "base_batch_size"
98
+ ]
99
+
100
+ data_provider_lookup = {
101
+ provider.name: provider for provider in data_provider_classes
102
+ }
103
+ data_provider_class = data_provider_lookup[dp_config["dataset"]]
104
+
105
+ data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
106
+ data_provider = data_provider_class(**data_provider_kwargs)
107
+ return data_provider
108
+
109
+
110
+ def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
111
+ exp_config["run_config"]["init_lr"] = (
112
+ exp_config["run_config"]["base_lr"] * get_dist_size()
113
+ )
114
+
115
+ run_config = run_config_cls(**exp_config["run_config"])
116
+
117
+ return run_config
118
+
119
+
120
+ def init_model(
121
+ network: nn.Module,
122
+ init_from: str or None = None,
123
+ backbone_init_from: str or None = None,
124
+ rand_init="trunc_normal",
125
+ last_gamma=None,
126
+ ) -> None:
127
+ # initialization
128
+ init_modules(network, init_type=rand_init)
129
+ # zero gamma of last bn in each block
130
+ if last_gamma is not None:
131
+ zero_last_gamma(network, last_gamma)
132
+
133
+ # load weight
134
+ if init_from is not None and os.path.isfile(init_from):
135
+ network.load_state_dict(load_state_dict_from_file(init_from))
136
+ print(f"Loaded init from {init_from}")
137
+ elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
138
+ network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
139
+ print(f"Loaded backbone init from {backbone_init_from}")
140
+ else:
141
+ print(f"Random init ({rand_init}) with last gamma {last_gamma}")
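
A small sketch of the merge order used by `setup_exp_config` above: parent-directory defaults first, then the experiment file, then command-line overrides, each applied as a partial update. The `partial_update` helper below is a stand-in assumed to behave like `partial_update_config`.

```python
# Assumed behaviour of partial_update_config: nested dicts are merged key by
# key, everything else is overwritten.
def partial_update(dst: dict, src: dict) -> dict:
    for k, v in src.items():
        if isinstance(v, dict) and isinstance(dst.get(k), dict):
            partial_update(dst[k], v)
        else:
            dst[k] = v
    return dst


default_cfg = {"run_config": {"n_epochs": 300, "base_lr": 0.1}}   # default.yaml
exp_cfg = {"run_config": {"base_lr": 0.05}}                       # experiment yaml
cli_opts = {"run_config": {"n_epochs": 10}}                       # opt_args

merged = {}
for cfg in (default_cfg, exp_cfg, cli_opts):  # defaults first, overrides last
    partial_update(merged, cfg)
print(merged)  # {'run_config': {'n_epochs': 10, 'base_lr': 0.05}}
```
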
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .base import *
6
+ from .run_config import *
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (264 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/base.cpython-311.pyc ADDED
Binary file (16.8 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/__pycache__/run_config.cpython-311.pyc ADDED
Binary file (7.04 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/base.py ADDED
@@ -0,0 +1,297 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from efficientvit.apps.data_provider import DataProvider, parse_image_size
11
+ from efficientvit.apps.trainer.run_config import RunConfig
12
+ from efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
13
+ is_master)
14
+ from efficientvit.models.nn.norm import reset_bn
15
+ from efficientvit.models.utils import is_parallel, load_state_dict_from_file
16
+
17
+ __all__ = ["Trainer"]
18
+
19
+
20
+ class Trainer:
21
+ def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
22
+ self.path = os.path.realpath(os.path.expanduser(path))
23
+ self.model = model.cuda()
24
+ self.data_provider = data_provider
25
+
26
+ self.ema = None
27
+
28
+ self.checkpoint_path = os.path.join(self.path, "checkpoint")
29
+ self.logs_path = os.path.join(self.path, "logs")
30
+ for path in [self.path, self.checkpoint_path, self.logs_path]:
31
+ os.makedirs(path, exist_ok=True)
32
+
33
+ self.best_val = 0.0
34
+ self.start_epoch = 0
35
+
36
+ @property
37
+ def network(self) -> nn.Module:
38
+ return self.model.module if is_parallel(self.model) else self.model
39
+
40
+ @property
41
+ def eval_network(self) -> nn.Module:
42
+ if self.ema is None:
43
+ model = self.model
44
+ else:
45
+ model = self.ema.shadows
46
+ model = model.module if is_parallel(model) else model
47
+ return model
48
+
49
+ def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
50
+ if is_master():
51
+ fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
52
+ fout.write(log_str + "\n")
53
+ fout.flush()
54
+ fout.close()
55
+ if print_log:
56
+ print(log_str)
57
+
58
+ def save_model(
59
+ self,
60
+ checkpoint=None,
61
+ only_state_dict=True,
62
+ epoch=0,
63
+ model_name=None,
64
+ ) -> None:
65
+ if is_master():
66
+ if checkpoint is None:
67
+ if only_state_dict:
68
+ checkpoint = {"state_dict": self.network.state_dict()}
69
+ else:
70
+ checkpoint = {
71
+ "state_dict": self.network.state_dict(),
72
+ "epoch": epoch,
73
+ "best_val": self.best_val,
74
+ "optimizer": self.optimizer.state_dict(),
75
+ "lr_scheduler": self.lr_scheduler.state_dict(),
76
+ "ema": self.ema.state_dict() if self.ema is not None else None,
77
+ "scaler": self.scaler.state_dict() if self.fp16 else None,
78
+ }
79
+
80
+ model_name = model_name or "checkpoint.pt"
81
+
82
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
83
+ model_path = os.path.join(self.checkpoint_path, model_name)
84
+ with open(latest_fname, "w") as _fout:
85
+ _fout.write(model_path + "\n")
86
+ torch.save(checkpoint, model_path)
87
+
88
+ def load_model(self, model_fname=None) -> None:
89
+ latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
90
+ if model_fname is None and os.path.exists(latest_fname):
91
+ with open(latest_fname, "r") as fin:
92
+ model_fname = fin.readline()
93
+ if len(model_fname) > 0 and model_fname[-1] == "\n":
94
+ model_fname = model_fname[:-1]
95
+ try:
96
+ if model_fname is None:
97
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
98
+ elif not os.path.exists(model_fname):
99
+ model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
100
+ if not os.path.exists(model_fname):
101
+ model_fname = f"{self.checkpoint_path}/checkpoint.pt"
102
+ print(f"=> loading checkpoint {model_fname}")
103
+ checkpoint = load_state_dict_from_file(model_fname, False)
104
+ except Exception:
105
+ self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
106
+ return
107
+
108
+ # load checkpoint
109
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
110
+ log = []
111
+ if "epoch" in checkpoint:
112
+ self.start_epoch = checkpoint["epoch"] + 1
113
+ self.run_config.update_global_step(self.start_epoch)
114
+ log.append(f"epoch={self.start_epoch - 1}")
115
+ if "best_val" in checkpoint:
116
+ self.best_val = checkpoint["best_val"]
117
+ log.append(f"best_val={self.best_val:.2f}")
118
+ if "optimizer" in checkpoint:
119
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
120
+ log.append("optimizer")
121
+ if "lr_scheduler" in checkpoint:
122
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
123
+ log.append("lr_scheduler")
124
+ if "ema" in checkpoint and self.ema is not None:
125
+ self.ema.load_state_dict(checkpoint["ema"])
126
+ log.append("ema")
127
+ if "scaler" in checkpoint and self.fp16:
128
+ self.scaler.load_state_dict(checkpoint["scaler"])
129
+ log.append("scaler")
130
+ self.write_log("Loaded: " + ", ".join(log))
131
+
132
+ """ validate """
133
+
134
+ def reset_bn(
135
+ self,
136
+ network: nn.Module or None = None,
137
+ subset_size: int = 16000,
138
+ subset_batch_size: int = 100,
139
+ data_loader=None,
140
+ progress_bar=False,
141
+ ) -> None:
142
+ network = network or self.network
143
+ if data_loader is None:
144
+ data_loader = []
145
+ for data in self.data_provider.build_sub_train_loader(
146
+ subset_size, subset_batch_size
147
+ ):
148
+ if isinstance(data, list):
149
+ data_loader.append(data[0])
150
+ elif isinstance(data, dict):
151
+ data_loader.append(data["data"])
152
+ elif isinstance(data, torch.Tensor):
153
+ data_loader.append(data)
154
+ else:
155
+ raise NotImplementedError
156
+
157
+ network.eval()
158
+ reset_bn(
159
+ network,
160
+ data_loader,
161
+ sync=True,
162
+ progress_bar=progress_bar,
163
+ )
164
+
165
+ def _validate(self, model, data_loader, epoch) -> dict[str, any]:
166
+ raise NotImplementedError
167
+
168
+ def validate(
169
+ self, model=None, data_loader=None, is_test=True, epoch=0
170
+ ) -> dict[str, any]:
171
+ model = model or self.eval_network
172
+ if data_loader is None:
173
+ if is_test:
174
+ data_loader = self.data_provider.test
175
+ else:
176
+ data_loader = self.data_provider.valid
177
+
178
+ model.eval()
179
+ return self._validate(model, data_loader, epoch)
180
+
181
+ def multires_validate(
182
+ self,
183
+ model=None,
184
+ data_loader=None,
185
+ is_test=True,
186
+ epoch=0,
187
+ eval_image_size=None,
188
+ ) -> dict[str, dict[str, any]]:
189
+ eval_image_size = eval_image_size or self.run_config.eval_image_size
190
+ eval_image_size = eval_image_size or self.data_provider.image_size
191
+ model = model or self.eval_network
192
+
193
+ if not isinstance(eval_image_size, list):
194
+ eval_image_size = [eval_image_size]
195
+
196
+ output_dict = {}
197
+ for r in eval_image_size:
198
+ self.data_provider.assign_active_image_size(parse_image_size(r))
199
+ if self.run_config.reset_bn:
200
+ self.reset_bn(
201
+ network=model,
202
+ subset_size=self.run_config.reset_bn_size,
203
+ subset_batch_size=self.run_config.reset_bn_batch_size,
204
+ progress_bar=True,
205
+ )
206
+ output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
207
+ return output_dict
208
+
209
+ """ training """
210
+
211
+ def prep_for_training(
212
+ self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
213
+ ) -> None:
214
+ self.run_config = run_config
215
+ self.model = nn.parallel.DistributedDataParallel(
216
+ self.model.cuda(),
217
+ device_ids=[get_dist_local_rank()],
218
+ static_graph=True,
219
+ )
220
+
221
+ self.run_config.global_step = 0
222
+ self.run_config.batch_per_epoch = len(self.data_provider.train)
223
+ assert self.run_config.batch_per_epoch > 0, "Training set is empty"
224
+
225
+ # build optimizer
226
+ self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
227
+
228
+ if ema_decay is not None:
229
+ self.ema = EMA(self.network, ema_decay)
230
+
231
+ # fp16
232
+ self.fp16 = fp16
233
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
234
+
235
+ def sync_model(self):
236
+ print("Sync model")
237
+ self.save_model(model_name="sync.pt")
238
+ dist_barrier()
239
+ checkpoint = torch.load(
240
+ os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
241
+ )
242
+ dist_barrier()
243
+ if is_master():
244
+ os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
245
+ dist_barrier()
246
+
247
+ # load checkpoint
248
+ self.network.load_state_dict(checkpoint["state_dict"], strict=False)
249
+ if "optimizer" in checkpoint:
250
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
251
+ if "lr_scheduler" in checkpoint:
252
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
253
+ if "ema" in checkpoint and self.ema is not None:
254
+ self.ema.load_state_dict(checkpoint["ema"])
255
+ if "scaler" in checkpoint and self.fp16:
256
+ self.scaler.load_state_dict(checkpoint["scaler"])
257
+
258
+ def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
259
+ for key in feed_dict:
260
+ if isinstance(feed_dict[key], torch.Tensor):
261
+ feed_dict[key] = feed_dict[key].cuda()
262
+ return feed_dict
263
+
264
+ def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
265
+ raise NotImplementedError
266
+
267
+ def after_step(self) -> None:
268
+ self.scaler.unscale_(self.optimizer)
269
+ # gradient clip
270
+ if self.run_config.grad_clip is not None:
271
+ torch.nn.utils.clip_grad_value_(
272
+ self.model.parameters(), self.run_config.grad_clip
273
+ )
274
+ # update
275
+ self.scaler.step(self.optimizer)
276
+ self.scaler.update()
277
+
278
+ self.lr_scheduler.step()
279
+ self.run_config.step()
280
+ # update ema
281
+ if self.ema is not None:
282
+ self.ema.step(self.network, self.run_config.global_step)
283
+
284
+ def _train_one_epoch(self, epoch: int) -> dict[str, any]:
285
+ raise NotImplementedError
286
+
287
+ def train_one_epoch(self, epoch: int) -> dict[str, any]:
288
+ self.model.train()
289
+
290
+ self.data_provider.set_epoch(epoch)
291
+
292
+ train_info_dict = self._train_one_epoch(epoch)
293
+
294
+ return train_info_dict
295
+
296
+ def train(self) -> None:
297
+ raise NotImplementedError
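
A stand-alone sketch (toy model, scaler disabled so it runs on CPU) of the step contract a `Trainer` subclass fills in above: `run_step` does the forward pass and scaled backward, and `after_step` unscales, clips, and advances the optimizer, mirroring `Trainer.after_step`.

```python
# Toy illustration of the run_step/after_step split; not the repo's Trainer.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled -> pass-through on CPU
grad_clip = 1.0


def run_step(feed_dict):
    out = model(feed_dict["data"])
    loss = nn.functional.mse_loss(out, feed_dict["label"])
    scaler.scale(loss).backward()
    return {"loss": loss.detach()}


def after_step():
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


feed = {"data": torch.randn(8, 4), "label": torch.randn(8, 1)}
print(run_step(feed)["loss"].item())
after_step()
```
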
yolo-world-with-efficientvit-sam/efficientvit/apps/trainer/run_config.py ADDED
@@ -0,0 +1,121 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import json
6
+
7
+ import numpy as np
8
+ import torch.nn as nn
9
+
10
+ from efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
11
+
12
+ __all__ = ["Scheduler", "RunConfig"]
13
+
14
+
15
+ class Scheduler:
16
+ PROGRESS = 0
17
+
18
+
19
+ class RunConfig:
20
+ n_epochs: int
21
+ init_lr: float
22
+ warmup_epochs: int
23
+ warmup_lr: float
24
+ lr_schedule_name: str
25
+ lr_schedule_param: dict
26
+ optimizer_name: str
27
+ optimizer_params: dict
28
+ weight_decay: float
29
+ no_wd_keys: list
30
+ grad_clip: float # allow none to turn off grad clipping
31
+ reset_bn: bool
32
+ reset_bn_size: int
33
+ reset_bn_batch_size: int
34
+ eval_image_size: list # allow none to use image_size in data_provider
35
+
36
+ @property
37
+ def none_allowed(self):
38
+ return ["grad_clip", "eval_image_size"]
39
+
40
+ def __init__(self, **kwargs): # arguments must be passed as kwargs
41
+ for k, val in kwargs.items():
42
+ setattr(self, k, val)
43
+
44
+ # check that all relevant configs are there
45
+ annotations = {}
46
+ for clas in type(self).mro():
47
+ if hasattr(clas, "__annotations__"):
48
+ annotations.update(clas.__annotations__)
49
+ for k, k_type in annotations.items():
50
+ assert hasattr(
51
+ self, k
52
+ ), f"Key {k} with type {k_type} required for initialization."
53
+ attr = getattr(self, k)
54
+ if k in self.none_allowed:
55
+ k_type = (k_type, type(None))
56
+ assert isinstance(
57
+ attr, k_type
58
+ ), f"Key {k} must be type {k_type}, provided={attr}."
59
+
60
+ self.global_step = 0
61
+ self.batch_per_epoch = 1
62
+
63
+ def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
64
+ r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
65
+ param_dict = {}
66
+ for name, param in network.named_parameters():
67
+ if param.requires_grad:
68
+ opt_config = [self.weight_decay, self.init_lr]
69
+ if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
70
+ if np.any([key in name for key in self.no_wd_keys]):
71
+ opt_config[0] = 0
72
+ opt_key = json.dumps(opt_config)
73
+ param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
74
+
75
+ net_params = []
76
+ for opt_key, param_list in param_dict.items():
77
+ wd, lr = json.loads(opt_key)
78
+ net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
79
+
80
+ optimizer = build_optimizer(
81
+ net_params, self.optimizer_name, self.optimizer_params, self.init_lr
82
+ )
83
+ # build lr scheduler
84
+ if self.lr_schedule_name == "cosine":
85
+ decay_steps = []
86
+ for epoch in self.lr_schedule_param.get("step", []):
87
+ decay_steps.append(epoch * self.batch_per_epoch)
88
+ decay_steps.append(self.n_epochs * self.batch_per_epoch)
89
+ decay_steps.sort()
90
+ lr_scheduler = CosineLRwithWarmup(
91
+ optimizer,
92
+ self.warmup_epochs * self.batch_per_epoch,
93
+ self.warmup_lr,
94
+ decay_steps,
95
+ )
96
+ else:
97
+ raise NotImplementedError
98
+ return optimizer, lr_scheduler
99
+
100
+ def update_global_step(self, epoch, batch_id=0) -> None:
101
+ self.global_step = epoch * self.batch_per_epoch + batch_id
102
+ Scheduler.PROGRESS = self.progress
103
+
104
+ @property
105
+ def progress(self) -> float:
106
+ warmup_steps = self.warmup_epochs * self.batch_per_epoch
107
+ steps = max(0, self.global_step - warmup_steps)
108
+ return steps / (self.n_epochs * self.batch_per_epoch)
109
+
110
+ def step(self) -> None:
111
+ self.global_step += 1
112
+ Scheduler.PROGRESS = self.progress
113
+
114
+ def get_remaining_epoch(self, epoch, post=True) -> int:
115
+ return self.n_epochs + self.warmup_epochs - epoch - int(post)
116
+
117
+ def epoch_format(self, epoch: int) -> str:
118
+ epoch_format = f"%.{len(str(self.n_epochs))}d"
119
+ epoch_format = f"[{epoch_format}/{epoch_format}]"
120
+ epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
121
+ return epoch_format
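
Hypothetical kwargs that would satisfy the annotation check in `RunConfig.__init__` above; the values are illustrative, not the repo's defaults, and `batch_per_epoch` still has to be set before `build_optimizer` is called.

```python
# Illustrative values only (assumptions); every annotated field must be present.
run_config_kwargs = dict(
    n_epochs=300,
    init_lr=0.4,
    warmup_epochs=5,
    warmup_lr=0.0,
    lr_schedule_name="cosine",
    lr_schedule_param={},
    optimizer_name="sgd",
    optimizer_params={"momentum": 0.9, "nesterov": True},
    weight_decay=4e-5,
    no_wd_keys=["bias", "norm"],
    grad_clip=None,        # in none_allowed, so None is accepted
    reset_bn=False,
    reset_bn_size=16000,
    reset_bn_batch_size=100,
    eval_image_size=None,  # in none_allowed, so None is accepted
)
# run_config = RunConfig(**run_config_kwargs)
# run_config.batch_per_epoch = len(train_loader)  # required before build_optimizer()
```
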
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ from .dist import *
6
+ from .ema import *
7
+ from .export import *
8
+ from .init import *
9
+ from .lr import *
10
+ from .metric import *
11
+ from .misc import *
12
+ from .opt import *
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (412 Bytes).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/dist.cpython-311.pyc ADDED
Binary file (3.92 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/ema.cpython-311.pyc ADDED
Binary file (3.54 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/export.cpython-311.pyc ADDED
Binary file (2.57 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/init.cpython-311.pyc ADDED
Binary file (4.3 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/lr.cpython-311.pyc ADDED
Binary file (3.02 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/metric.cpython-311.pyc ADDED
Binary file (2.67 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/misc.cpython-311.pyc ADDED
Binary file (5.06 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/__pycache__/opt.cpython-311.pyc ADDED
Binary file (1.32 kB).
 
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/dist.py ADDED
@@ -0,0 +1,73 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import os
6
+
7
+ import torch
8
+ import torch.distributed
9
+
10
+ from efficientvit.models.utils.list import list_mean, list_sum
11
+
12
+ __all__ = [
13
+ "dist_init",
14
+ "get_dist_rank",
15
+ "get_dist_size",
16
+ "is_master",
17
+ "dist_barrier",
18
+ "get_dist_local_rank",
19
+ "sync_tensor",
20
+ ]
21
+
22
+
23
+ def dist_init() -> None:
24
+ try:
25
+ torch.distributed.init_process_group(backend="nccl")
26
+ assert torch.distributed.is_initialized()
27
+ except Exception:
28
+ # use torchpack
29
+ from torchpack import distributed as dist
30
+
31
+ dist.init()
32
+ os.environ["RANK"] = f"{dist.rank()}"
33
+ os.environ["WORLD_SIZE"] = f"{dist.size()}"
34
+ os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"
35
+
36
+
37
+ def get_dist_rank() -> int:
38
+ return int(os.environ["RANK"])
39
+
40
+
41
+ def get_dist_size() -> int:
42
+ return int(os.environ["WORLD_SIZE"])
43
+
44
+
45
+ def is_master() -> bool:
46
+ return get_dist_rank() == 0
47
+
48
+
49
+ def dist_barrier() -> None:
50
+ torch.distributed.barrier()
51
+
52
+
53
+ def get_dist_local_rank() -> int:
54
+ return int(os.environ["LOCAL_RANK"])
55
+
56
+
57
+ def sync_tensor(
58
+ tensor: torch.Tensor or float, reduce="mean"
59
+ ) -> torch.Tensor or list[torch.Tensor]:
60
+ if not isinstance(tensor, torch.Tensor):
61
+ tensor = torch.Tensor(1).fill_(tensor).cuda()
62
+ tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
63
+ torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
64
+ if reduce == "mean":
65
+ return list_mean(tensor_list)
66
+ elif reduce == "sum":
67
+ return list_sum(tensor_list)
68
+ elif reduce == "cat":
69
+ return torch.cat(tensor_list, dim=0)
70
+ elif reduce == "root":
71
+ return tensor_list[0]
72
+ else:
73
+ return tensor_list
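
A hedged sketch of what `sync_tensor` above does for `reduce="mean"`, written against the public `torch.distributed` API and runnable as a single gloo process; in real use the process group comes from `dist_init()` under a multi-process launcher, and the MASTER_ADDR/MASTER_PORT values below are placeholders.

```python
# Single-process gloo illustration of the all_gather + mean pattern above.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

local_loss = torch.tensor([0.42])  # per-rank scalar, e.g. a validation loss
gathered = [torch.empty_like(local_loss) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, local_loss.contiguous())
mean_loss = torch.stack(gathered).mean()  # corresponds to reduce="mean"
print(mean_loss.item())

dist.destroy_process_group()
```
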
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/ema.py ADDED
@@ -0,0 +1,50 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import copy
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from efficientvit.models.utils import is_parallel
12
+
13
+ __all__ = ["EMA"]
14
+
15
+
16
+ def update_ema(
17
+ ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
18
+ ) -> None:
19
+ for k, v in ema.state_dict().items():
20
+ if v.dtype.is_floating_point:
21
+ v -= (1.0 - decay) * (v - new_state_dict[k].detach())
22
+
23
+
24
+ class EMA:
25
+ def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
26
+ self.shadows = copy.deepcopy(
27
+ model.module if is_parallel(model) else model
28
+ ).eval()
29
+ self.decay = decay
30
+ self.warmup_steps = warmup_steps
31
+
32
+ for p in self.shadows.parameters():
33
+ p.requires_grad = False
34
+
35
+ def step(self, model: nn.Module, global_step: int) -> None:
36
+ with torch.no_grad():
37
+ msd = (model.module if is_parallel(model) else model).state_dict()
38
+ update_ema(
39
+ self.shadows,
40
+ msd,
41
+ self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
42
+ )
43
+
44
+ def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
45
+ return {self.decay: self.shadows.state_dict()}
46
+
47
+ def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
48
+ for decay in state_dict:
49
+ if decay == self.decay:
50
+ self.shadows.load_state_dict(state_dict[decay])
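
A stand-alone sketch of the EMA update rule above: shadow weights drift toward the live weights with a decay that warms up from zero; the `is_parallel` handling is omitted here.

```python
# Toy EMA update mirroring update_ema/EMA.step above.
import copy
import math

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
shadow = copy.deepcopy(model).eval()
for p in shadow.parameters():
    p.requires_grad_(False)
decay, warmup_steps = 0.9998, 2000


def ema_step(global_step: int) -> None:
    d = decay * (1 - math.exp(-global_step / warmup_steps))  # warmed-up decay
    msd = model.state_dict()
    with torch.no_grad():
        for k, v in shadow.state_dict().items():
            if v.dtype.is_floating_point:
                v -= (1.0 - d) * (v - msd[k])  # in-place, shares storage with shadow


for step in range(1, 4):
    ema_step(step)
print("shadow mean weight:", shadow.weight.mean().item())
```
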
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/export.py ADDED
@@ -0,0 +1,47 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import io
6
+ import os
7
+
8
+ import onnx
9
+ import torch
10
+ import torch.nn as nn
11
+ from onnxsim import simplify as simplify_func
12
+
13
+ __all__ = ["export_onnx"]
14
+
15
+
16
+ def export_onnx(
17
+ model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
18
+ ) -> None:
19
+ """Export a model to a platform-specific onnx format.
20
+
21
+ Args:
22
+ model: a torch.nn.Module object.
23
+ export_path: export location.
24
+ sample_inputs: Any.
25
+ simplify: a flag to turn on onnx-simplifier
26
+ opset: int
27
+ """
28
+ model.eval()
29
+
30
+ buffer = io.BytesIO()
31
+ with torch.no_grad():
32
+ torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
33
+ buffer.seek(0, 0)
34
+ if simplify:
35
+ onnx_model = onnx.load_model(buffer)
36
+ onnx_model, success = simplify_func(onnx_model)
37
+ assert success
38
+ new_buffer = io.BytesIO()
39
+ onnx.save(onnx_model, new_buffer)
40
+ buffer = new_buffer
41
+ buffer.seek(0, 0)
42
+
43
+ if buffer.getbuffer().nbytes > 0:
44
+ save_dir = os.path.dirname(export_path)
45
+ os.makedirs(save_dir, exist_ok=True)
46
+ with open(export_path, "wb") as f:
47
+ f.write(buffer.read())
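
A hedged usage sketch for `export_onnx` above; it assumes `onnx` and `onnx-simplifier` are installed and that this repo is on PYTHONPATH, and uses a toy model and output path rather than an EfficientViT checkpoint.

```python
# Illustrative call only; the model and path are assumptions.
import torch
import torch.nn as nn

from efficientvit.apps.utils.export import export_onnx

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
dummy = torch.randn(1, 3, 224, 224)  # example input used for tracing
export_onnx(model, "out/toy.onnx", dummy, simplify=False, opset=11)
print("wrote out/toy.onnx")
```
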
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/init.py ADDED
@@ -0,0 +1,68 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.modules.batchnorm import _BatchNorm
8
+
9
+ __all__ = ["init_modules", "zero_last_gamma"]
10
+
11
+
12
+ def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
13
+ _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
14
+
15
+ if isinstance(model, list):
16
+ for sub_module in model:
17
+ init_modules(sub_module, init_type)
18
+ else:
19
+ init_params = init_type.split("@")
20
+ init_params = float(init_params[1]) if len(init_params) > 1 else None
21
+
22
+ if init_type.startswith("trunc_normal"):
23
+ init_func = lambda param: nn.init.trunc_normal_(
24
+ param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
25
+ )
26
+ else:
27
+ raise NotImplementedError
28
+
29
+ for m in model.modules():
30
+ if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
31
+ init_func(m.weight)
32
+ if m.bias is not None:
33
+ m.bias.data.zero_()
34
+ elif isinstance(m, nn.Embedding):
35
+ init_func(m.weight)
36
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
37
+ m.weight.data.fill_(1)
38
+ m.bias.data.zero_()
39
+ else:
40
+ weight = getattr(m, "weight", None)
41
+ bias = getattr(m, "bias", None)
42
+ if isinstance(weight, torch.nn.Parameter):
43
+ init_func(weight)
44
+ if isinstance(bias, torch.nn.Parameter):
45
+ bias.data.zero_()
46
+
47
+
48
+ def zero_last_gamma(model: nn.Module, init_val=0) -> None:
49
+ import efficientvit.models.nn.ops as ops
50
+
51
+ for m in model.modules():
52
+ if isinstance(m, ops.ResidualBlock) and isinstance(
53
+ m.shortcut, ops.IdentityLayer
54
+ ):
55
+ if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
56
+ parent_module = m.main.point_conv
57
+ elif isinstance(m.main, ops.ResBlock):
58
+ parent_module = m.main.conv2
59
+ elif isinstance(m.main, ops.ConvLayer):
60
+ parent_module = m.main
61
+ elif isinstance(m.main, (ops.LiteMLA)):
62
+ parent_module = m.main.proj
63
+ else:
64
+ parent_module = None
65
+ if parent_module is not None:
66
+ norm = getattr(parent_module, "norm", None)
67
+ if norm is not None:
68
+ nn.init.constant_(norm.weight, init_val)
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/lr.py ADDED
@@ -0,0 +1,48 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import math
6
+
7
+ import torch
8
+
9
+ from efficientvit.models.utils.list import val2list
10
+
11
+ __all__ = ["CosineLRwithWarmup"]
12
+
13
+
14
+ class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
15
+ def __init__(
16
+ self,
17
+ optimizer: torch.optim.Optimizer,
18
+ warmup_steps: int,
19
+ warmup_lr: float,
20
+ decay_steps: int or list[int],
21
+ last_epoch: int = -1,
22
+ ) -> None:
23
+ self.warmup_steps = warmup_steps
24
+ self.warmup_lr = warmup_lr
25
+ self.decay_steps = val2list(decay_steps)
26
+ super().__init__(optimizer, last_epoch)
27
+
28
+ def get_lr(self) -> list[float]:
29
+ if self.last_epoch < self.warmup_steps:
30
+ return [
31
+ (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
32
+ + self.warmup_lr
33
+ for base_lr in self.base_lrs
34
+ ]
35
+ else:
36
+ current_steps = self.last_epoch - self.warmup_steps
37
+ decay_steps = [0] + self.decay_steps
38
+ idx = len(decay_steps) - 2
39
+ for i, decay_step in enumerate(decay_steps[:-1]):
40
+ if decay_step <= current_steps < decay_steps[i + 1]:
41
+ idx = i
42
+ break
43
+ current_steps -= decay_steps[idx]
44
+ decay_step = decay_steps[idx + 1] - decay_steps[idx]
45
+ return [
46
+ 0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
47
+ for base_lr in self.base_lrs
48
+ ]
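
A hedged usage sketch for `CosineLRwithWarmup` above (assuming this repo is on PYTHONPATH): the scheduler is stepped once per training iteration, and `decay_steps` is given in iterations, not epochs.

```python
# Illustrative wiring only; step counts are assumptions.
import torch
import torch.nn as nn

from efficientvit.apps.utils.lr import CosineLRwithWarmup

model = nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.4)

steps_per_epoch, warmup_epochs, n_epochs = 100, 5, 300
scheduler = CosineLRwithWarmup(
    optimizer,
    warmup_steps=warmup_epochs * steps_per_epoch,
    warmup_lr=0.0,
    decay_steps=n_epochs * steps_per_epoch,
)

for _ in range(3):  # one scheduler step per training iteration
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())
```
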
yolo-world-with-efficientvit-sam/efficientvit/apps/utils/metric.py ADDED
@@ -0,0 +1,37 @@
1
+ # EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
2
+ # Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
3
+ # International Conference on Computer Vision (ICCV), 2023
4
+
5
+ import torch
6
+
7
+ from efficientvit.apps.utils.dist import sync_tensor
8
+
9
+ __all__ = ["AverageMeter"]
10
+
11
+
12
+ class AverageMeter:
13
+ """Computes and stores the average and current value."""
14
+
15
+ def __init__(self, is_distributed=True):
16
+ self.is_distributed = is_distributed
17
+ self.sum = 0
18
+ self.count = 0
19
+
20
+ def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
21
+ return sync_tensor(val, reduce="sum") if self.is_distributed else val
22
+
23
+ def update(self, val: torch.Tensor or int or float, delta_n=1):
24
+ self.count += self._sync(delta_n)
25
+ self.sum += self._sync(val * delta_n)
26
+
27
+ def get_count(self) -> torch.Tensor or int or float:
28
+ return (
29
+ self.count.item()
30
+ if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
31
+ else self.count
32
+ )
33
+
34
+ @property
35
+ def avg(self):
36
+ avg = -1 if self.count == 0 else self.sum / self.count
37
+ return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
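
A hedged usage sketch for `AverageMeter` above in a single-process setting (assuming this repo is on PYTHONPATH); with `is_distributed=True` the same calls would also sum counts and values across ranks via `sync_tensor`.

```python
# Illustrative single-process use; losses and batch sizes are made up.
from efficientvit.apps.utils.metric import AverageMeter

loss_meter = AverageMeter(is_distributed=False)
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    loss_meter.update(batch_loss, delta_n=batch_size)  # weighted by batch size

print(f"avg over {loss_meter.get_count()} samples: {loss_meter.avg:.3f}")
```
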