add lib/timm
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full file list.
- lib/timm-0.5.4.dist-info/INSTALLER +1 -0
- lib/timm-0.5.4.dist-info/LICENSE +201 -0
- lib/timm-0.5.4.dist-info/METADATA +503 -0
- lib/timm-0.5.4.dist-info/RECORD +364 -0
- lib/timm-0.5.4.dist-info/REQUESTED +0 -0
- lib/timm-0.5.4.dist-info/WHEEL +5 -0
- lib/timm-0.5.4.dist-info/top_level.txt +1 -0
- lib/timm/__init__.py +4 -0
- lib/timm/__pycache__/__init__.cpython-310.pyc +0 -0
- lib/timm/__pycache__/version.cpython-310.pyc +0 -0
- lib/timm/data/__init__.py +12 -0
- lib/timm/data/__pycache__/__init__.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/auto_augment.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/config.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/constants.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/dataset.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/dataset_factory.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/distributed_sampler.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/loader.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/mixup.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/random_erasing.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/real_labels.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/tf_preprocessing.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/transforms.cpython-310.pyc +0 -0
- lib/timm/data/__pycache__/transforms_factory.cpython-310.pyc +0 -0
- lib/timm/data/auto_augment.py +865 -0
- lib/timm/data/config.py +78 -0
- lib/timm/data/constants.py +7 -0
- lib/timm/data/dataset.py +152 -0
- lib/timm/data/dataset_factory.py +143 -0
- lib/timm/data/distributed_sampler.py +129 -0
- lib/timm/data/loader.py +289 -0
- lib/timm/data/mixup.py +316 -0
- lib/timm/data/parsers/__init__.py +1 -0
- lib/timm/data/parsers/__pycache__/__init__.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/class_map.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/constants.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/parser.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc +0 -0
- lib/timm/data/parsers/__pycache__/parser_tfds.cpython-310.pyc +0 -0
- lib/timm/data/parsers/class_map.py +19 -0
- lib/timm/data/parsers/constants.py +1 -0
- lib/timm/data/parsers/parser.py +17 -0
- lib/timm/data/parsers/parser_factory.py +29 -0
- lib/timm/data/parsers/parser_image_folder.py +69 -0
- lib/timm/data/parsers/parser_image_in_tar.py +222 -0
- lib/timm/data/parsers/parser_image_tar.py +72 -0
lib/timm-0.5.4.dist-info/INSTALLER
ADDED
@@ -0,0 +1 @@
+pip
lib/timm-0.5.4.dist-info/LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2019 Ross Wightman
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
lib/timm-0.5.4.dist-info/METADATA
ADDED
@@ -0,0 +1,503 @@
+Metadata-Version: 2.1
+Name: timm
+Version: 0.5.4
+Summary: (Unofficial) PyTorch Image Models
+Home-page: https://github.com/rwightman/pytorch-image-models
+Author: Ross Wightman
+Author-email: hello@rwightman.com
+License: UNKNOWN
+Keywords: pytorch pretrained models efficientnet mobilenetv3 mnasnet
+Platform: UNKNOWN
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: Apache Software License
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Topic :: Scientific/Engineering
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Software Development
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Requires-Python: >=3.6
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: torch (>=1.4)
+Requires-Dist: torchvision
+
+# PyTorch Image Models
+- [Sponsors](#sponsors)
+- [What's New](#whats-new)
+- [Introduction](#introduction)
+- [Models](#models)
+- [Features](#features)
+- [Results](#results)
+- [Getting Started (Documentation)](#getting-started-documentation)
+- [Train, Validation, Inference Scripts](#train-validation-inference-scripts)
+- [Awesome PyTorch Resources](#awesome-pytorch-resources)
+- [Licenses](#licenses)
+- [Citing](#citing)
+
+## Sponsors
+
+A big thank you to my [GitHub Sponsors](https://github.com/sponsors/rwightman) for their support!
+
+In addition to the sponsors at the link above, I've received hardware and/or cloud resources from
+* Nvidia (https://www.nvidia.com/en-us/)
+* TFRC (https://www.tensorflow.org/tfrc)
+
+I'm fortunate to be able to dedicate significant time and money of my own to supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs.
+
+## What's New
+
+### Jan 14, 2022
+* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since the last pypi update, and riskier changes will be merged to the main branch soon...
+* Add ConvNeXT models w/ weights from the official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
+* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
+  * `mnasnet_small` - 65.6 top-1
+  * `mobilenetv2_050` - 65.9
+  * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
+  * `semnasnet_075` - 73
+  * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
+* TinyNet models added by [rsomani95](https://github.com/rsomani95)
+* LCNet added via MobileNetV3 architecture
+
+### Nov 22, 2021
+* A number of updated weights and new model defs
+  * `eca_halonext26ts` - 79.5 @ 256
+  * `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
+  * `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
+  * `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
+  * `sebotnet33ts_256` (new) - 81.2 @ 224
+  * `lamhalobotnet50ts_256` - 81.5 @ 256
+  * `halonet50ts` - 81.7 @ 256
+  * `halo2botnet50ts_256` - 82.0 @ 256
+  * `resnet101` - 82.0 @ 224, 82.8 @ 288
+  * `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
+  * `resnet152` - 82.8 @ 224, 83.5 @ 288
+  * `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
+  * `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
+* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks to [Martins Bruveris](https://github.com/martinsbruveris)
+* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
+  * models updated for tracing compatibility (almost full support with some distilled transformer exceptions)
+
+### Oct 19, 2021
+* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
+* BCE loss and Repeated Augmentation support for RSB paper
+* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottleneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
+* Working implementations of the following 2D self-attention modules (likely to have differences from the paper or eventual official impl):
+  * Halo (https://arxiv.org/abs/2103.12731)
+  * Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
+  * LambdaNetworks (https://arxiv.org/abs/2102.08602)
+* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture; details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
+* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
+* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
+
+### Aug 18, 2021
+* Optimizer bonanza!
+  * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
+  * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
+  * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
+  * SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
+* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit different and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
+* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
+
+### July 12, 2021
+* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)
+
+### July 5-9, 2021
+* Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res)
+  * top-1 82.34 @ 288x288 and 82.54 @ 320x320
+* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
+* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
+  * `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426
+
+### June 23, 2021
+* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
+
+### June 20, 2021
+* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
+  * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
+  * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the augreg weights
+  * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
+    * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
+  * `vit_deit_*` renamed to just `deit_*`
+  * Remove my old small model, replace with DeiT compatible small w/ AugReg weights
+* Add 1st training of my `gmixer_24_224` MLP w/ GLU, 78.1 top-1 w/ 25M params.
+* Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
+* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
+* Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
+* NFNets and ResNetV2-BiT models work w/ PyTorch XLA now
+  * weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
+  * eps values adjusted, will be slight differences but should be quite close
+* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
+* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
+* Please report any regressions, this PR touched quite a few models.
+
+### June 8, 2021
+* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
+* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
+  * NFNet inspired block layout with quad layer stem and no maxpool
+  * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288
+
+### May 25, 2021
+* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models
+* Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl
+* Fix a number of torchscript issues with various vision transformer models
+* Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models
+* More flexible pos embedding resize (non-square) for ViT and TnT. Thanks [Alexander Soare](https://github.com/alexander-soare)
+* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
+
+### May 14, 2021
+* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
+  * 1k trained variants: `tf_efficientnetv2_s/m/l`
+  * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
+  * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
+  * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
+  * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
+  * Some blank `efficientnetv2_*` models in-place for future native PyTorch training
+
+### May 5, 2021
+* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
+* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
+* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
+* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
+* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
+* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
+* Update ByoaNet attention modules
+  * Improve SA module inits
+  * Hack together experimental stand-alone Swin based attn module and `swinnet`
+  * Consistent '26t' model defs for experiments.
+* Add improved EfficientNet-V2S (prelim model def) weights. 83.8 top-1.
+* WandB logging support
+
+### April 13, 2021
+* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer
+
+### April 12, 2021
+* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256.
+* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training.
+* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs
+  * Lambda Networks - https://arxiv.org/abs/2102.08602
+  * Bottleneck Transformers - https://arxiv.org/abs/2101.11605
+  * Halo Nets - https://arxiv.org/abs/2103.12731
+* Adabelief optimizer contributed by Juntang Zhuang
+
+### April 1, 2021
+* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
+* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
+  * Merged distilled variant into main for torchscript compatibility
+  * Some `timm` cleanup/style tweaks and weights have hub download support
+* Cleanup Vision Transformer (ViT) models
+  * Merge distilled (DeiT) model into main so that torchscript can work
+  * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
+  * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
+  * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (ViT) with distilled variants
+  * nn.Sequential for block stack (does not break downstream compat)
+* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
+* Add RegNetY-160 weights from DeiT teacher model
+* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288
+* Some fixes/improvements for TFDS dataset wrapper
+
+### March 17, 2021
+* Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself.
+  * 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
+  * Uses SiLU activation, approx 2x faster than `dm_nfnet_f0` and 50% faster than `nfnet_f0s` w/ 1/3 param count
+* Integrate [Hugging Face model hub](https://huggingface.co/models) into timm create_model and default_cfg handling for pretrained weight and config sharing (more on this soon!)
+* Merge HardCoRe NAS models contributed by https://github.com/yoniaflalo
+* Merge PyTorch trained EfficientNet-EL and pruned ES/EL variants contributed by [DeGirum](https://github.com/DeGirum)
+
+
+### March 7, 2021
+* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
+* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
+* Tested with PyTorch 1.8 release. Updated CI to use 1.8.
+* Benchmarked several arch on RTX 3090, Titan RTX, and V100 across 1.7.1, 1.8, NGC 20.12, and 21.02. Some interesting performance variations to take note of https://gist.github.com/rwightman/bb59f9e245162cee0e38bd66bd8cd77f
+
+### Feb 18, 2021
+* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
+  * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
+  * These models are big, expect to run out of GPU memory. With the GELU activation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
+  * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
+  * Matching the original pre-processing as closely as possible I get these results:
+    * `dm_nfnet_f6` - 86.352
+    * `dm_nfnet_f5` - 86.100
+    * `dm_nfnet_f4` - 85.834
+    * `dm_nfnet_f3` - 85.676
+    * `dm_nfnet_f2` - 85.178
+    * `dm_nfnet_f1` - 84.696
+    * `dm_nfnet_f0` - 83.464
+
+### Feb 16, 2021
+* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
+  * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
+  * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
+  * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
+  * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
+
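The AGC rule described above is compact enough to sketch on its own. The following is a minimal, self-contained condensation of the paper's unit-wise rule (my own illustration, not the code shipped in `timm` or used by `train.py`; the function name and defaults are mine): each gradient is rescaled wherever its per-unit norm exceeds `clip_factor` times the matching parameter norm.

```python
import torch

@torch.no_grad()
def agc_(parameters, clip_factor=0.01, eps=1e-3):
    """Adaptive gradient clipping sketch: shrink grads where ||g|| / ||w|| > clip_factor."""
    for p in parameters:
        if p.grad is None:
            continue
        if p.ndim > 1:
            # unit-wise: one norm per output unit (dim 0), as in the paper
            dims = tuple(range(1, p.ndim))
            w_norm = p.norm(2, dim=dims, keepdim=True).clamp_(min=eps)
            g_norm = p.grad.norm(2, dim=dims, keepdim=True)
        else:
            w_norm = p.norm(2).clamp_(min=eps)
            g_norm = p.grad.norm(2)
        # scale factor <= 1; only units over the threshold are shrunk
        scale = (clip_factor * w_norm / g_norm.clamp(min=1e-6)).clamp(max=1.0)
        p.grad.mul_(scale)

# called between loss.backward() and optimizer.step(); the paper also
# suggests excluding the final classifier layer from clipping
```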
+### Feb 12, 2021
+* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
+
+### Feb 10, 2021
+* First Normalization-Free model training experiments done,
+  * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256
+  * nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256
+* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
+  * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
+  * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
+  * classic VGG (from torchvision, impl in `vgg.py`)
+* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models
+* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not.
+* Fix a few bugs introduced since last pypi release
+
+### Feb 8, 2021
+* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @ 320, test @ 352.
+  * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
+  * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256
+  * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320
+* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` was the only model w/ weights that was removed).
+* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test.
+
+### Jan 30, 2021
+* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)
+
+### Jan 25, 2021
+* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
+* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
+* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
+  * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
+* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
+* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
+* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
+  * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
+* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
+  * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
+* Models in this update should be stable w/ the possible exception of ViT/BiT; there is a possibility of some regressions with train/val scripts and dataset handling
+
+### Jan 3, 2021
+* Add SE-ResNet-152D weights
+  * 256x256 val, 0.94 crop top-1 - 83.75
+  * 320x320 val, 1.0 crop - 84.36
+* Update [results files](results/)
+
+
+## Introduction
+
+Py**T**orch **Im**age **M**odels (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with the ability to reproduce ImageNet training results.
+
+The work of many others is present here. I've tried to make sure all source material is acknowledged via links to github, arxiv papers, etc in the README, documentation, and code docstrings. Please let me know if I missed anything.
+
+## Models
+
+All model architecture families include variants with pretrained weights. There are specific model variants without any weights; it is NOT a bug. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started.
+
+A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
+
+* Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
+* BEiT - https://arxiv.org/abs/2106.08254
+* Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
+* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
+* CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
+* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
+* ConvNeXt - https://arxiv.org/abs/2201.03545
+* ConViT (Soft Convolutional Inductive Biases Vision Transformers) - https://arxiv.org/abs/2103.10697
+* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
+* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
+* DenseNet - https://arxiv.org/abs/1608.06993
+* DLA - https://arxiv.org/abs/1707.06484
+* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
+* EfficientNet (MBConvNet Family)
+  * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
+  * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
+  * EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946
+  * EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
+  * EfficientNet V2 - https://arxiv.org/abs/2104.00298
+  * FBNet-C - https://arxiv.org/abs/1812.03443
+  * MixNet - https://arxiv.org/abs/1907.09595
+  * MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626
+  * MobileNet-V2 - https://arxiv.org/abs/1801.04381
+  * Single-Path NAS - https://arxiv.org/abs/1904.02877
+  * TinyNet - https://arxiv.org/abs/2010.14819
+* GhostNet - https://arxiv.org/abs/1911.11907
+* gMLP - https://arxiv.org/abs/2105.08050
+* GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
+* Halo Nets - https://arxiv.org/abs/2103.12731
+* HRNet - https://arxiv.org/abs/1908.07919
+* Inception-V3 - https://arxiv.org/abs/1512.00567
+* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
+* Lambda Networks - https://arxiv.org/abs/2102.08602
+* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
+* MLP-Mixer - https://arxiv.org/abs/2105.01601
+* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
+  * FBNet-V3 - https://arxiv.org/abs/2006.02049
+  * HardCoRe-NAS - https://arxiv.org/abs/2102.11646
+  * LCNet - https://arxiv.org/abs/2109.15099
+* NASNet-A - https://arxiv.org/abs/1707.07012
+* NesT - https://arxiv.org/abs/2105.12723
+* NFNet-F - https://arxiv.org/abs/2102.06171
+* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
+* PNasNet - https://arxiv.org/abs/1712.00559
+* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
+* RegNet - https://arxiv.org/abs/2003.13678
+* RepVGG - https://arxiv.org/abs/2101.03697
+* ResMLP - https://arxiv.org/abs/2105.03404
+* ResNet/ResNeXt
+  * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
+  * ResNeXt - https://arxiv.org/abs/1611.05431
+  * 'Bag of Tricks' / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187
+  * Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932
+  * Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546
+  * ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4
+  * Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507
+  * ResNet-RS - https://arxiv.org/abs/2103.07579
+* Res2Net - https://arxiv.org/abs/1904.01169
+* ResNeSt - https://arxiv.org/abs/2004.08955
+* ReXNet - https://arxiv.org/abs/2007.00992
+* SelecSLS - https://arxiv.org/abs/1907.00837
+* Selective Kernel Networks - https://arxiv.org/abs/1903.06586
+* Swin Transformer - https://arxiv.org/abs/2103.14030
+* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
+* TResNet - https://arxiv.org/abs/2003.13630
+* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
+* Visformer - https://arxiv.org/abs/2104.12533
+* Vision Transformer - https://arxiv.org/abs/2010.11929
+* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
+* Xception - https://arxiv.org/abs/1610.02357
+* Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
+* Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611
+* XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681
+
+## Features
+
+Several (less common) features that I often utilize in my projects are included. Many of these additions are the reason why I maintain my own set of models, instead of using others' via PIP:
+
+* All models have a common default configuration interface and API for
+  * accessing/changing the classifier - `get_classifier` and `reset_classifier`
+  * doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
+  * this makes it easy to write consistent network wrappers that work with any of the models
+* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
+  * `create_model(name, features_only=True, out_indices=..., output_stride=...)`
+  * `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
+  * `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
+  * feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member (see the sketch after this list)
+* All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
+* High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes:
+  * NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
+  * PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
+  * PyTorch w/ single GPU single process (AMP optional)
+* A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
+* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance when doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported it (https://github.com/cypw/DPNs)
+* Learning rate schedulers
+  * Ideas adopted from
+    * [AllenNLP schedulers](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers)
+    * [FAIRseq lr_scheduler](https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler)
+    * SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983)
+  * Schedulers include `step`, `cosine` w/ restarts, `tanh` w/ restarts, `plateau`
+* Optimizers:
+  * `rmsprop_tf` adapted from PyTorch RMSProp by myself. Reproduces much improved Tensorflow RMSProp behaviour.
+  * `radam` by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) (https://arxiv.org/abs/1908.03265)
+  * `novograd` by [Masashi Kimura](https://github.com/convergence-lab/novograd) (https://arxiv.org/abs/1905.11286)
+  * `lookahead` adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610)
+  * `fused<name>` optimizers by name with [NVIDIA Apex](https://github.com/NVIDIA/apex/tree/master/apex/optimizers) installed
+  * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) (https://arxiv.org/abs/2006.08217)
+  * `adafactor` adapted from [FAIRSeq impl](https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py) (https://arxiv.org/abs/1804.04235)
+  * `adahessian` by [David Samuel](https://github.com/davda54/ada-hessian) (https://arxiv.org/abs/2006.00719)
+* Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896)
+* Mixup (https://arxiv.org/abs/1710.09412)
+* CutMix (https://arxiv.org/abs/1905.04899)
+* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
+* AugMix w/ JSD loss (https://arxiv.org/abs/1912.02781), JSD w/ clean + augmented mixing support works with AutoAugment and RandAugment as well
+* SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
+* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
+* DropBlock (https://arxiv.org/abs/1810.12890)
+* Blur Pooling (https://arxiv.org/abs/1904.11486)
+* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
+* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
+* An extensive selection of channel and/or spatial attention modules:
+  * Bottleneck Transformer - https://arxiv.org/abs/2101.11605
+  * CBAM - https://arxiv.org/abs/1807.06521
+  * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
+  * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
+  * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
+  * Global Context (GC) - https://arxiv.org/abs/1904.11492
+  * Halo - https://arxiv.org/abs/2103.12731
+  * Involution - https://arxiv.org/abs/2103.06255
+  * Lambda Layer - https://arxiv.org/abs/2102.08602
+  * Non-Local (NL) - https://arxiv.org/abs/1711.07971
+  * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
+  * Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
+  * Split (SPLAT) - https://arxiv.org/abs/2004.08955
+  * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
+
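To make the interface bullets above concrete, here is a short sketch of the common API in use (the model name, `out_indices` choice, and printed values are illustrative):

```python
import torch
import timm

# multi-scale feature maps via features_only
model = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3, 4))
feats = model(torch.randn(1, 3, 224, 224))   # list of 4 feature maps, C2..C5
print(model.feature_info.channels())          # e.g. [256, 512, 1024, 2048]
print(model.feature_info.reduction())         # e.g. [4, 8, 16, 32]

# classifier access / replacement on a standard model
clf = timm.create_model('resnet50', num_classes=10)
head = clf.get_classifier()                   # the final Linear layer
clf.reset_classifier(num_classes=0)           # 0 -> pooled features, no head

# pretrained weight loader adapting the stem from 3-channel to 1-channel input
gray = timm.create_model('resnet50', pretrained=True, in_chans=1)
```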
+## Results
+
+Model validation results can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/results/) and in the [results tables](results/README.md)
+
+## Getting Started (Documentation)
+
+My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
+
+[timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
+
+[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.
+
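Before diving into the docs, the basic workflow is only a few lines; a minimal sketch (architecture names here are just examples):

```python
import torch
import timm

# browse what's available
print(timm.list_models('*convnext*', pretrained=True)[:5])

# create a pretrained classifier and run a forward pass
model = timm.create_model('resnet50', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```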
+## Train, Validation, Inference Scripts
+
+The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://rwightman.github.io/pytorch-image-models/scripts/) for some basics and [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) for some train examples that produce SOTA ImageNet results.
+
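As a hint of the invocation style (mirroring the TFDS and tar examples in the changelog above; the paths, dataset, and hparams below are placeholders, not recommendations):

```
python train.py /data/my-dataset --model resnet50 --pretrained --num-classes 2 -b 64 --lr 1e-3 --amp
python validate.py /data/my-dataset --model resnet50 --checkpoint output/train/<run>/model_best.pth.tar --amp
```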
449 |
+
## Awesome PyTorch Resources
|
450 |
+
|
451 |
+
One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.
|
452 |
+
|
453 |
+
### Object Detection, Instance and Semantic Segmentation
|
454 |
+
* Detectron2 - https://github.com/facebookresearch/detectron2
|
455 |
+
* Segmentation Models (Semantic) - https://github.com/qubvel/segmentation_models.pytorch
|
456 |
+
* EfficientDet (Obj Det, Semantic soon) - https://github.com/rwightman/efficientdet-pytorch
|
457 |
+
|
458 |
+
### Computer Vision / Image Augmentation
|
459 |
+
* Albumentations - https://github.com/albumentations-team/albumentations
|
460 |
+
* Kornia - https://github.com/kornia/kornia
|
461 |
+
|
462 |
+
### Knowledge Distillation
|
463 |
+
* RepDistiller - https://github.com/HobbitLong/RepDistiller
|
464 |
+
* torchdistill - https://github.com/yoshitomo-matsubara/torchdistill
|
465 |
+
|
466 |
+
### Metric Learning
|
467 |
+
* PyTorch Metric Learning - https://github.com/KevinMusgrave/pytorch-metric-learning
|
468 |
+
|
469 |
+
### Training / Frameworks
|
470 |
+
* fastai - https://github.com/fastai/fastai
|
471 |
+
|
472 |
+
## Licenses

### Code
The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.

### Pretrained Weights
So far all of the pretrained weights available here are pretrained on ImageNet, with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.

#### Pretrained on more than ImageNet
Several weights included or referenced here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.

## Citing

### BibTeX

```bibtex
@misc{rw2019timm,
  author = {Ross Wightman},
  title = {PyTorch Image Models},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  doi = {10.5281/zenodo.4414861},
  howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}
```

### Latest DOI

[![DOI](https://zenodo.org/badge/168799526.svg)](https://zenodo.org/badge/latestdoi/168799526)

lib/timm-0.5.4.dist-info/RECORD
ADDED
@@ -0,0 +1,364 @@
timm-0.5.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
timm-0.5.4.dist-info/LICENSE,sha256=cbERYg-jLBeoDM1tstp1nTGlkeSX2LXzghdPWdG1nUk,11343
timm-0.5.4.dist-info/METADATA,sha256=_o4k9R4FYZ1msA33NowGB-awL2oEA8zUsuZjK6xgB4c,36181
timm-0.5.4.dist-info/RECORD,,
timm-0.5.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
timm-0.5.4.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
timm-0.5.4.dist-info/top_level.txt,sha256=Mi21FFh17x9WpGQnfmIkFMK_kFw_m3eb_G3J8-UJ5SY,5
timm/__init__.py,sha256=9mTvNS2J6SoMaBrYEv6Xcmc9EGMytuFwkJgUYCXSArg,286
timm/__pycache__/__init__.cpython-310.pyc,,
timm/__pycache__/version.cpython-310.pyc,,
timm/data/__init__.py,sha256=UY7Kh_mkF-oCUoBc9hZUF_3PUlqmjm7kGhSHu9RuHoQ,553
timm/data/__pycache__/__init__.cpython-310.pyc,,
timm/data/__pycache__/auto_augment.cpython-310.pyc,,
timm/data/__pycache__/config.cpython-310.pyc,,
timm/data/__pycache__/constants.cpython-310.pyc,,
timm/data/__pycache__/dataset.cpython-310.pyc,,
timm/data/__pycache__/dataset_factory.cpython-310.pyc,,
timm/data/__pycache__/distributed_sampler.cpython-310.pyc,,
timm/data/__pycache__/loader.cpython-310.pyc,,
timm/data/__pycache__/mixup.cpython-310.pyc,,
timm/data/__pycache__/random_erasing.cpython-310.pyc,,
timm/data/__pycache__/real_labels.cpython-310.pyc,,
timm/data/__pycache__/tf_preprocessing.cpython-310.pyc,,
timm/data/__pycache__/transforms.cpython-310.pyc,,
timm/data/__pycache__/transforms_factory.cpython-310.pyc,,
timm/data/auto_augment.py,sha256=C4rrMeP0oAAhqPLzdEc_MBimKGnmaq-J1e6ik0XO3Pg,31686
timm/data/config.py,sha256=grh9sGvu3YNQJm0pOd09ZVRvz8dnMdbbqoFJDLT_TJw,2915
timm/data/constants.py,sha256=xc_A1oVqqTSSr9fNPijM3eZPm8bgbh8OEhwdsS-XwlQ,303
timm/data/dataset.py,sha256=XUBknhmSG8iuGJrZKF47emzVTX1vgxxBN3uc0E3p40Q,4805
timm/data/dataset_factory.py,sha256=FR5gDs5vofvjrdDrs-SaHI-5jb56BVqp4ChrJT-TreI,5533
timm/data/distributed_sampler.py,sha256=ZQ7KA3xEyWDLumq2aHjjQy9o_SpZgkdk8dmF-eP9BsQ,5125
timm/data/loader.py,sha256=h087Nqq-5Ct2AN8g_uqDWESDgMlA9hYwT03Qlv4xBQU,9924
timm/data/mixup.py,sha256=xA8ZMdVoVIfwZMA9vMdLDkcosJYInaZnVsQW1QX3ano,14722
timm/data/parsers/__init__.py,sha256=f1ffuO9Pj74PBI2y3h1iweooYJlTJlADGs8wdQuaZmw,42
timm/data/parsers/__pycache__/__init__.cpython-310.pyc,,
timm/data/parsers/__pycache__/class_map.cpython-310.pyc,,
timm/data/parsers/__pycache__/constants.cpython-310.pyc,,
timm/data/parsers/__pycache__/parser.cpython-310.pyc,,
timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc,,
timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc,,
timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc,,
timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc,,
timm/data/parsers/__pycache__/parser_tfds.cpython-310.pyc,,
timm/data/parsers/class_map.py,sha256=rYqhlPYplTRj86FH1CZRf76NB5zUECX4-D93T-u-SjQ,759
timm/data/parsers/constants.py,sha256=X4eulfViOcrn8_k-Awi6LTGHFK5vZKXeS8exbD_KtT4,43
timm/data/parsers/parser.py,sha256=BoIya5V--BuyXWEeYzi-rg8wsXI0UiEiuxQxkho0dCY,487
timm/data/parsers/parser_factory.py,sha256=Z6Vku9nKq1lcKgqwjuYCdid4lJyJH7Pyf9QnoVvHbGU,1078
timm/data/parsers/parser_image_folder.py,sha256=Up4Yv5NPvE5xkXv2AtoBc2FQY9Xt5AqUOkDD5Qr0OMs,2508
timm/data/parsers/parser_image_in_tar.py,sha256=VNRtMtJX8WkcNPxfQ0THVVb2fFwAH8f7m9-RskaK0Ak,8987
timm/data/parsers/parser_image_tar.py,sha256=E48D2RfbCftAty7mXtdZIX8d0yIPCEQIagLVkliEJ1o,2589
timm/data/parsers/parser_tfds.py,sha256=iZ6F8GGw-YbqN8DYOLtqZLTwQwjOLG5So9Rpt614lhY,15819
timm/data/random_erasing.py,sha256=zCv0vGv32QCHUIamd_HC_GnjLlODoyJBDoRuV7S2XCI,4767
timm/data/real_labels.py,sha256=D9pgNrsyiPIZTDVYRqFmkIYyJi-Dplql4RONZ3NNFTM,1590
timm/data/tf_preprocessing.py,sha256=vuJSsleBnS41C9upL4JTuO399nrojHV_WCSwtQrv694,9120
timm/data/transforms.py,sha256=0mPf22INsd3XPuBwdAAuhUptWjC-j-b_hFija_Irw2k,6194
timm/data/transforms_factory.py,sha256=8UuHo_2DG8eYnYMAHDv2BY4uSob6PokqEJ-CmC5Rk4k,8351
timm/loss/__init__.py,sha256=iCNB9bUAf69neNe1_XO0eeg1QXuxu6jRTAuy4V9yFL8,245
timm/loss/__pycache__/__init__.cpython-310.pyc,,
timm/loss/__pycache__/asymmetric_loss.cpython-310.pyc,,
timm/loss/__pycache__/binary_cross_entropy.cpython-310.pyc,,
timm/loss/__pycache__/cross_entropy.cpython-310.pyc,,
timm/loss/__pycache__/jsd.cpython-310.pyc,,
timm/loss/asymmetric_loss.py,sha256=YkMktzxiXncKK_GF5yBGDONOUENSdpzb7FJluCxuSlw,3322
timm/loss/binary_cross_entropy.py,sha256=gs6iNMKB2clMixxFnIhYKOxZ7LwFWWiHGQhIzjqgAA4,2030
timm/loss/cross_entropy.py,sha256=XDE19FnhYjeudAerb6UulIID34AmZoXQ1CPEAjEkCQM,1145
timm/loss/jsd.py,sha256=MFe8H_JC1srFE_FKinF7jMVIQYgNWgeT7kZL9WeIXGI,1595
timm/models/__init__.py,sha256=YvxGghp2a4O-xIlWtOFEJFMM1mCElnuZtNG525WiTtU,1784
timm/models/__pycache__/__init__.cpython-310.pyc,,
timm/models/__pycache__/beit.cpython-310.pyc,,
timm/models/__pycache__/byoanet.cpython-310.pyc,,
timm/models/__pycache__/byobnet.cpython-310.pyc,,
timm/models/__pycache__/cait.cpython-310.pyc,,
timm/models/__pycache__/coat.cpython-310.pyc,,
timm/models/__pycache__/convit.cpython-310.pyc,,
timm/models/__pycache__/convmixer.cpython-310.pyc,,
timm/models/__pycache__/convnext.cpython-310.pyc,,
timm/models/__pycache__/crossvit.cpython-310.pyc,,
timm/models/__pycache__/cspnet.cpython-310.pyc,,
timm/models/__pycache__/densenet.cpython-310.pyc,,
timm/models/__pycache__/dla.cpython-310.pyc,,
timm/models/__pycache__/dpn.cpython-310.pyc,,
timm/models/__pycache__/efficientnet.cpython-310.pyc,,
timm/models/__pycache__/efficientnet_blocks.cpython-310.pyc,,
timm/models/__pycache__/efficientnet_builder.cpython-310.pyc,,
timm/models/__pycache__/factory.cpython-310.pyc,,
timm/models/__pycache__/features.cpython-310.pyc,,
timm/models/__pycache__/fx_features.cpython-310.pyc,,
timm/models/__pycache__/ghostnet.cpython-310.pyc,,
timm/models/__pycache__/gluon_resnet.cpython-310.pyc,,
timm/models/__pycache__/gluon_xception.cpython-310.pyc,,
timm/models/__pycache__/hardcorenas.cpython-310.pyc,,
timm/models/__pycache__/helpers.cpython-310.pyc,,
timm/models/__pycache__/hrnet.cpython-310.pyc,,
timm/models/__pycache__/hub.cpython-310.pyc,,
timm/models/__pycache__/inception_resnet_v2.cpython-310.pyc,,
timm/models/__pycache__/inception_v3.cpython-310.pyc,,
timm/models/__pycache__/inception_v4.cpython-310.pyc,,
timm/models/__pycache__/levit.cpython-310.pyc,,
timm/models/__pycache__/mlp_mixer.cpython-310.pyc,,
timm/models/__pycache__/mobilenetv3.cpython-310.pyc,,
timm/models/__pycache__/nasnet.cpython-310.pyc,,
timm/models/__pycache__/nest.cpython-310.pyc,,
timm/models/__pycache__/nfnet.cpython-310.pyc,,
timm/models/__pycache__/pit.cpython-310.pyc,,
timm/models/__pycache__/pnasnet.cpython-310.pyc,,
timm/models/__pycache__/registry.cpython-310.pyc,,
timm/models/__pycache__/regnet.cpython-310.pyc,,
timm/models/__pycache__/res2net.cpython-310.pyc,,
timm/models/__pycache__/resnest.cpython-310.pyc,,
timm/models/__pycache__/resnet.cpython-310.pyc,,
timm/models/__pycache__/resnetv2.cpython-310.pyc,,
timm/models/__pycache__/rexnet.cpython-310.pyc,,
timm/models/__pycache__/selecsls.cpython-310.pyc,,
timm/models/__pycache__/senet.cpython-310.pyc,,
timm/models/__pycache__/sknet.cpython-310.pyc,,
timm/models/__pycache__/swin_transformer.cpython-310.pyc,,
timm/models/__pycache__/tnt.cpython-310.pyc,,
timm/models/__pycache__/tresnet.cpython-310.pyc,,
timm/models/__pycache__/twins.cpython-310.pyc,,
timm/models/__pycache__/vgg.cpython-310.pyc,,
timm/models/__pycache__/visformer.cpython-310.pyc,,
timm/models/__pycache__/vision_transformer.cpython-310.pyc,,
timm/models/__pycache__/vision_transformer_hybrid.cpython-310.pyc,,
timm/models/__pycache__/vovnet.cpython-310.pyc,,
timm/models/__pycache__/xception.cpython-310.pyc,,
timm/models/__pycache__/xception_aligned.cpython-310.pyc,,
timm/models/__pycache__/xcit.cpython-310.pyc,,
timm/models/beit.py,sha256=xUTeREocSxQizBpZh9suEzSr6T7J0wxTCV-LU1AhVRw,18558
timm/models/byoanet.py,sha256=Oq7z6zIGNUAYtrhalg1v-PM1GUY60rTJgtZFHZESc4o,18350
timm/models/byobnet.py,sha256=Ztjh8PPTWtiI9c6qDmPXj_tXAzqGGe1gCTAkLtxnRcc,62040
timm/models/cait.py,sha256=AT5n5q3rzZQL8qdcPg27Q2YlHxHXJ-tUbEpbhwuKEV0,14940
timm/models/coat.py,sha256=Mf7HSiTcilmiJn5u07QUbhz713G28CCiD12Be8mRuAs,26936
timm/models/convit.py,sha256=CpsyDS0P-PlrAs-hHsXMJhLXR9YPLGYfGCsDk3t42S4,13952
timm/models/convmixer.py,sha256=YT1IrJaUb2HFyvVYsGJN9Fb4i4eYVGxPPjJMfffvG0Q,3631
timm/models/convnext.py,sha256=jTdEBzMG9ZH_zsRXKN-2VeGfpodIWUQlphf6imnFPSg,17427
timm/models/crossvit.py,sha256=CP52LuFfiGTMs3iQIFJYmZSkDl2pYQaVARddYl9kOtI,22472
timm/models/cspnet.py,sha256=sqHI7IU9sqo1NuwGEhoHvUIzMp95MWFQHA9xDgsXMWg,18221
timm/models/densenet.py,sha256=7BW80_eMdBd54qOygP0jQGDedk9JSy1YPu0x2F0oBDQ,15611
timm/models/dla.py,sha256=rx7v7egYQocwGtr5TIeslP0AaJ88oFr0MdK4xzr2XlY,17205
timm/models/dpn.py,sha256=tFcBpES6yFM8jeyvI-giarXTz9YEZy2fex2vSQ7Lnhg,12448
timm/models/efficientnet.py,sha256=xKOm9HrD3AqF1In6ahRNPGP8KF7NuhL7W7W7PmcqPT8,97919
timm/models/efficientnet_blocks.py,sha256=i32acRskLshWDVKxkCL8T728baoD80FQLb-6zl-wi00,12442
timm/models/efficientnet_builder.py,sha256=ovUDS8FzdJqqZhI7hgay85kFMPhN87VQXvwEL0TRGkA,19459
timm/models/factory.py,sha256=EaYqVnAFXXGfTCPyxe0z-kAvjz4evZx9VE_mklkYmrE,3305
timm/models/features.py,sha256=DnO84Xi2mvMWIzK4kyN2SCePsMPxehSLWFZEGc_MItc,12155
timm/models/fx_features.py,sha256=Lu8PCRUiGe0SrjxTOyDb4BKVqzG1QN7E5nRFKErgEa0,2855
timm/models/ghostnet.py,sha256=D6tvP5ClRU6vUQzZWgk2exlhCCp71w0C-wwpDV17dy4,9326
timm/models/gluon_resnet.py,sha256=jTOSW9-gS6mu1V8omQM82nHvihn3oiI_5uPnDYqsua0,11362
timm/models/gluon_xception.py,sha256=UF0JRxJ1vPaQhYuNxz0hGjlS5IPHPBgrt2HSk4r5-Sk,8705
timm/models/hardcorenas.py,sha256=rpTmpo_aHC-i64-f2txejPCaitStjlYu0_ERX_Bul6o,8036
timm/models/helpers.py,sha256=hYyaTpH_4Rcn93prm70f-E_ahqM9HQJ0Cw0ma4qnpAo,22659
timm/models/hrnet.py,sha256=8LtXT2oEF9twTHWrwJW18S9OUVXC9RqHwTCnckIXDOM,29402
timm/models/hub.py,sha256=AdmCNspFlzLLohb3tFPiHv6R3ZFyXM12CxpjNM5gyts,5988
timm/models/inception_resnet_v2.py,sha256=Wg19hoczgHbWqHP5Uanf7zyhSUIYxs6P6RS6U3yB4uY,12464
timm/models/inception_v3.py,sha256=Cd7W-J6tao7jkQ9-sDvy0Guv1PvWFI81kxe5pYrSFbM,17469
timm/models/inception_v4.py,sha256=k2Yqzw-Ak3YxhCmIHXOfBgulBDiOugdxMLNmROTOgpk,10804
timm/models/layers/__init__.py,sha256=CI57cbXYovHX-pFO8O42eQwAd8D_XAnaXPp0IEezKgU,2213
timm/models/layers/__pycache__/__init__.cpython-310.pyc,,
timm/models/layers/__pycache__/activations.cpython-310.pyc,,
timm/models/layers/__pycache__/activations_jit.cpython-310.pyc,,
timm/models/layers/__pycache__/activations_me.cpython-310.pyc,,
timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-310.pyc,,
timm/models/layers/__pycache__/attention_pool2d.cpython-310.pyc,,
timm/models/layers/__pycache__/blur_pool.cpython-310.pyc,,
timm/models/layers/__pycache__/bottleneck_attn.cpython-310.pyc,,
timm/models/layers/__pycache__/cbam.cpython-310.pyc,,
timm/models/layers/__pycache__/classifier.cpython-310.pyc,,
timm/models/layers/__pycache__/cond_conv2d.cpython-310.pyc,,
timm/models/layers/__pycache__/config.cpython-310.pyc,,
timm/models/layers/__pycache__/conv2d_same.cpython-310.pyc,,
timm/models/layers/__pycache__/conv_bn_act.cpython-310.pyc,,
timm/models/layers/__pycache__/create_act.cpython-310.pyc,,
timm/models/layers/__pycache__/create_attn.cpython-310.pyc,,
timm/models/layers/__pycache__/create_conv2d.cpython-310.pyc,,
timm/models/layers/__pycache__/create_norm_act.cpython-310.pyc,,
timm/models/layers/__pycache__/drop.cpython-310.pyc,,
timm/models/layers/__pycache__/eca.cpython-310.pyc,,
timm/models/layers/__pycache__/evo_norm.cpython-310.pyc,,
timm/models/layers/__pycache__/gather_excite.cpython-310.pyc,,
timm/models/layers/__pycache__/global_context.cpython-310.pyc,,
timm/models/layers/__pycache__/halo_attn.cpython-310.pyc,,
timm/models/layers/__pycache__/helpers.cpython-310.pyc,,
timm/models/layers/__pycache__/inplace_abn.cpython-310.pyc,,
timm/models/layers/__pycache__/lambda_layer.cpython-310.pyc,,
timm/models/layers/__pycache__/linear.cpython-310.pyc,,
timm/models/layers/__pycache__/median_pool.cpython-310.pyc,,
timm/models/layers/__pycache__/mixed_conv2d.cpython-310.pyc,,
timm/models/layers/__pycache__/mlp.cpython-310.pyc,,
timm/models/layers/__pycache__/non_local_attn.cpython-310.pyc,,
timm/models/layers/__pycache__/norm.cpython-310.pyc,,
timm/models/layers/__pycache__/norm_act.cpython-310.pyc,,
timm/models/layers/__pycache__/padding.cpython-310.pyc,,
timm/models/layers/__pycache__/patch_embed.cpython-310.pyc,,
timm/models/layers/__pycache__/pool2d_same.cpython-310.pyc,,
timm/models/layers/__pycache__/selective_kernel.cpython-310.pyc,,
timm/models/layers/__pycache__/separable_conv.cpython-310.pyc,,
timm/models/layers/__pycache__/space_to_depth.cpython-310.pyc,,
timm/models/layers/__pycache__/split_attn.cpython-310.pyc,,
timm/models/layers/__pycache__/split_batchnorm.cpython-310.pyc,,
timm/models/layers/__pycache__/squeeze_excite.cpython-310.pyc,,
timm/models/layers/__pycache__/std_conv.cpython-310.pyc,,
timm/models/layers/__pycache__/test_time_pool.cpython-310.pyc,,
timm/models/layers/__pycache__/trace_utils.cpython-310.pyc,,
timm/models/layers/__pycache__/weight_init.cpython-310.pyc,,
timm/models/layers/activations.py,sha256=iT_WSweK1B14fAamfJOFcgqQ5DoBI6Fvt-X4fRuYsSM,4040
timm/models/layers/activations_jit.py,sha256=BQI8MYjZhJx0w6yTUpsl_x_tg-XFdbYDkYy5EEBQYIQ,2529
timm/models/layers/activations_me.py,sha256=Qlrh-NWXxC6OsxI1wppeBdsd8UGWXT8ECH95tFaXGEQ,5886
timm/models/layers/adaptive_avgmax_pool.py,sha256=I-lZ-AdvlTYguuxyTBgDWcquqhFf00uVujfVqZfLv_A,3890
timm/models/layers/attention_pool2d.py,sha256=wYx25PT4KZNVPdHW07mDkT19SIEcEtz3SNzh88Q98tw,6866
timm/models/layers/blur_pool.py,sha256=gVZRqXFUpOuH_ui98sTSsXsmnWXBmy9PoPlXgaMV8Q4,1591
timm/models/layers/bottleneck_attn.py,sha256=HLuZbyep1Nf9Qq9Aei81kCzQMs6U1aQBQRLrOnjnkHo,6895
timm/models/layers/cbam.py,sha256=5A0MAd2BaBBYKzWjbN_t81olp1tDMAoun816OyT9bVA,4418
timm/models/layers/classifier.py,sha256=GHJ80KXZu8sXOLLAB5S8zODJUFNxcZ-iIlHrxO63ymU,2231
timm/models/layers/cond_conv2d.py,sha256=YMnfZ9MSQyqqPQ1VxZsZsWRG1FAgWNWl9MHzXWZ1mWE,5129
timm/models/layers/config.py,sha256=Nna27P_B1cy4obs4Ma5_sd5VlXe_sCEYjv9ttyNABcE,3069
timm/models/layers/conv2d_same.py,sha256=2fv1zNaZJZgFJ1P5quM9pikQV-Pf620HRLX8ygQFHGU,1490
timm/models/layers/conv_bn_act.py,sha256=gTJJYA4-QOpL4prtWHc5aPfWbqrR10SAvaUwfnRIPN4,1404
timm/models/layers/create_act.py,sha256=CmKQZ1vu3bT7EVwmfLjSSEY6qLwR3Dea7ZR8mul5xdI,5359
timm/models/layers/create_attn.py,sha256=Z7uwbr07LDSBEz8kUaDwCuw5JX3xM-xTkzvsCXV4Duw,3526
timm/models/layers/create_conv2d.py,sha256=UH4RvUhNCz1ohkM3ayC6kXwBLuO1ro8go4jpgLgyjLs,1500
timm/models/layers/create_norm_act.py,sha256=Ln0UOFtMIWVJeAePJWx5qsLNSDl-b7Br3Z7XSzEVBqc,3450
timm/models/layers/drop.py,sha256=TVGBZLvuSEQrpNLH_FZOdUqanFRuY9a9Qq6yN8JmTgs,6732
timm/models/layers/eca.py,sha256=MiVhboDUqLUfeubpypWfaR3LMLHwgLCNsWO3iemcQFs,6386
timm/models/layers/evo_norm.py,sha256=HiGnhaOUYKIhBnVG6d9bwCGHO_BhsG7e75hDRQfu0E4,3519
timm/models/layers/gather_excite.py,sha256=53DHt6cySjPqd9NW3voZuhw8b9nUzvsG9NVl_D-9NAo,3824
timm/models/layers/global_context.py,sha256=aZWvij4J-T5I1rdTK725D6R0fDuJyYPDaXvl36QMmkw,2445
timm/models/layers/halo_attn.py,sha256=zMJkf9S-ocCvrfvWOe0I97UHTpEQIkP381DON3OXm-c,10662
timm/models/layers/helpers.py,sha256=pHZa-j8xR-BWLgflFyzvtwn9o1m52t5V__KauMkOutA,748
timm/models/layers/inplace_abn.py,sha256=4-8ZyftTMoNaU59NvUaQH8qpqBYeQDTIjjvYEwo1Lzg,3353
timm/models/layers/lambda_layer.py,sha256=WSlH2buUBDxRd5RFou5IN7iTFGS4nZL--j-XhrdOci8,5941
timm/models/layers/linear.py,sha256=baS2Wpl0vYELnvhnQ6Lw65jVotaJz5iGbroJJ9JmIRM,743
timm/models/layers/median_pool.py,sha256=b02v36VGvs_gCY9NhVwU7-mglcXJHzrJVzcEpEUuHBI,1737
timm/models/layers/mixed_conv2d.py,sha256=mRSmAUtpgHya_RdnUq4j85K5QS7JFTdSPUXOUTKgpmA,1843
timm/models/layers/mlp.py,sha256=hOai0-VxpEFGV_Lf4sb82Ko2vU5gJK-2U2F2GvnIFtU,4097
timm/models/layers/non_local_attn.py,sha256=58GuC8xjOFedZjWClPO_Bc9UJRl3J7giVm60F6bcsYo,6209
timm/models/layers/norm.py,sha256=qC3u_ncT88rnhc_Z3c6KP0JRMveDgHfoVQuVDV8sPjg,876
timm/models/layers/norm_act.py,sha256=WPUkzeBWPw-IHF-FoqYTyeoALxVXxc0bq5gW4PJnRBA,3545
timm/models/layers/padding.py,sha256=BjbtxJmui-DyeroZohYMdRAj5mR4o6pWgW04imi3hDI,2167
timm/models/layers/patch_embed.py,sha256=GXjHXzEvmDAOndeBuO3z7C4pkON1hzZeGi7NovFt1A4,1490
timm/models/layers/pool2d_same.py,sha256=UsmtWna5k5kfVTP25T1-OKJOgtcfBQCqSy0FmaZbjRw,3045
timm/models/layers/selective_kernel.py,sha256=Ltlwuw3gFRlmAUpxt0i44PGGaY1b1nNcfDO5AbqaK5U,5349
timm/models/layers/separable_conv.py,sha256=WrbdyyBteT3eqe5rN7TIrkqlWXfChP8f_0AMJS1sNDM,2528
timm/models/layers/space_to_depth.py,sha256=P-9czBYvbbPR7Ji-fB0dVFA_yABJQoO9C3IA4sL9ttI,1750
timm/models/layers/split_attn.py,sha256=HHpEHYCiBk0ecrnZpTVbb-4r4q1065WiCR6aC3AGfMU,3085
timm/models/layers/split_batchnorm.py,sha256=4ghGtliK5z0ZnzR29zJB_rN9BJPiGuy1PSltmVyF5Ww,3441
timm/models/layers/squeeze_excite.py,sha256=krSwp4li7AWhSP4zJLMZUwl_XIdHJk_Eb81iQ_s8FnA,3018
timm/models/layers/std_conv.py,sha256=zYhcKCbE0_Rqn422gEM9gr3LeBewu0CXKqvlsa9-M2Q,5887
timm/models/layers/test_time_pool.py,sha256=oQbw-agOC6sc3MjvbvsxrBtDa62r9gYiTEW01ICqjDY,1995
timm/models/layers/trace_utils.py,sha256=cbZufOaGKmhTGEMc52QAnqzGRTfn4vvzqsAOJaLKJQ8,335
timm/models/layers/weight_init.py,sha256=EbhK0ecja64ZJ4eXLpDVs7kLseumJBlSzWgbG9v5HL4,3324
timm/models/levit.py,sha256=2jBuY_jBe-orwUyhs7xfAs4sV0_aTSkZgPdWQh0-yi8,21163
timm/models/mlp_mixer.py,sha256=TC0HsK--Bxrp0aZddfCoO9bq5wpkAYtaJbElLmXbgCw,26042
timm/models/mobilenetv3.py,sha256=XJimeE7IFicHJtwqEd3SBOPhA-HJwgNUHF7oox2O_AE,26586
timm/models/nasnet.py,sha256=dZq20FEUebcRddTkpPVWpzL2HcATVttE4Ma6VsXIAzw,25944
timm/models/nest.py,sha256=dyR4nEKmtBdIp0nH1Fz6VcuQeAxYDF-8qQa8ROXTETc,19521
timm/models/nfnet.py,sha256=4A7W_-2Q-e2x5V1ovRYt0zzq7PNMvzG9RAlCeMPmtjU,41006
timm/models/pit.py,sha256=LpXaTnzYZ5WCfCtDnXjG2CJTh-A_lLEiqoRa2PrisME,13037
timm/models/pnasnet.py,sha256=8FISd7eO2N7flUeSsQwSyVm2DzhJiQWbm-ECtxwjh9w,14961
timm/models/pruned/ecaresnet101d_pruned.txt,sha256=1zA7XaxsTnFJxZ9PMbfMVST7wPSQcAV-UzSgdFfGgYY,8734
timm/models/pruned/ecaresnet50d_pruned.txt,sha256=J4AlTwabaSB6-XrINCPCDWMiM_FrdNjuJN_JJRb89WE,4520
timm/models/pruned/efficientnet_b1_pruned.txt,sha256=pNDm1EENJYMT8-GjXZ3kXWCXADLDun-4jfigh74RELE,18596
timm/models/pruned/efficientnet_b2_pruned.txt,sha256=e_oaSVM-Ux3NMVARynJ74YwjzxuBAX_w7kzOw9Ml3gM,18676
timm/models/pruned/efficientnet_b3_pruned.txt,sha256=A1DJEwjEmrg8oUr0QwzwBkdAJV3dVeUFjnO9pNC_0Pg,21133
timm/models/registry.py,sha256=YRSPZojqIfVIQ4z5yLK-LnNZZ2FXJ__cP3bNFVVbVck,5680
timm/models/regnet.py,sha256=Pk5Ybc-GP-6yHQ2uSn5nDQ3jkBl9C2rkCUW8wyluDfs,20998
timm/models/res2net.py,sha256=Przr9hV-FJQS7Ez0z_exGsPslnLefynBN4mAPGAtFH8,7865
timm/models/resnest.py,sha256=rOwUAV4vQHb89q54qVIra8eOXlpklqr8Aep7rd1Axig,10092
timm/models/resnet.py,sha256=zkPFrwSd2KXqCSMT7B0Zc_mXQ_D4mFKDkaigEF4XoW4,65986
timm/models/resnetv2.py,sha256=-4BvS5XIMFccvmaPRDWU6z7geIfaidix5-SwrVuvhug,28274
timm/models/rexnet.py,sha256=ee5vM9NPLRA-Ilo6ifJp39lGtXdlenAN1DALeVLHff8,9228
timm/models/selecsls.py,sha256=YrNit4pigYrqKuQPsBi8nzfXn4lVU8hva17bAcCB1a4,13124
timm/models/senet.py,sha256=MEvKLC_Msa1nK-_fzCDSjTgk9ysS4hHu7qmGBi4n7iU,17642
timm/models/sknet.py,sha256=3_MNhX4Jg3A8TAV7vJxcVaMKYbUXO-rF9JbNRliPD-w,8742
timm/models/swin_transformer.py,sha256=5So8rK82wHSu2ZkfHCptDWzrg4yZOGsLnM_iBd5PK3c,27372
timm/models/tnt.py,sha256=J28c3t1go0DPkVLf2Om6fwEES10cvCSxP9Ailp7dDt8,11209
timm/models/tresnet.py,sha256=2dUbOytAqAgFxNog9qFtM8M4wwep9-LOaOsgl7RanDs,11596
timm/models/twins.py,sha256=wL3gkk_Gwe220k_Td_VIHDcA1AiLBouGV4HRZJubjnE,17346
timm/models/vgg.py,sha256=ZXmmA7kUrvkwZJtAjjx_oDPzIY3S5yLWhGImCA0v1NE,10455
timm/models/visformer.py,sha256=iN7DjQ1aI3EQCil8g9Jv_S6a7IR2639BVObnVFXKwk0,16086
timm/models/vision_transformer.py,sha256=XINVbQd-cQTX0xBlnMSS1ZGy6rLUGYej-4iVs16rGNs,46475
timm/models/vision_transformer_hybrid.py,sha256=XGcP5sUmw-zixwi7UKDS3jx6PvAY0lUWnOc3DmSAX5s,16099
timm/models/vovnet.py,sha256=fnzwB_ciOog2Hbt-340QAT_UQq96HqIPRGNF09IuCcg,13845
timm/models/xception.py,sha256=lDcnUwDMVScwXzrplIeKPpSx_-Lr2ItfI0VPhP1KpjY,7388
timm/models/xception_aligned.py,sha256=BdbW2j5ql4mD0-TAqV6FtlnxSXSaK7Oe6TR-5TegzCQ,8948
timm/models/xcit.py,sha256=LMcM5J8aaKN119DYYZeisaaYNwgdMSlxYQujld4dMJE,35892
timm/optim/__init__.py,sha256=z1mMVKZ9loGnBRnS8SeE_F259As2DbXSXzHcRpuWy2E,484
timm/optim/__pycache__/__init__.cpython-310.pyc,,
timm/optim/__pycache__/adabelief.cpython-310.pyc,,
timm/optim/__pycache__/adafactor.cpython-310.pyc,,
timm/optim/__pycache__/adahessian.cpython-310.pyc,,
timm/optim/__pycache__/adamp.cpython-310.pyc,,
timm/optim/__pycache__/adamw.cpython-310.pyc,,
timm/optim/__pycache__/lamb.cpython-310.pyc,,
timm/optim/__pycache__/lars.cpython-310.pyc,,
timm/optim/__pycache__/lookahead.cpython-310.pyc,,
timm/optim/__pycache__/madgrad.cpython-310.pyc,,
timm/optim/__pycache__/nadam.cpython-310.pyc,,
timm/optim/__pycache__/nvnovograd.cpython-310.pyc,,
timm/optim/__pycache__/optim_factory.cpython-310.pyc,,
timm/optim/__pycache__/radam.cpython-310.pyc,,
timm/optim/__pycache__/rmsprop_tf.cpython-310.pyc,,
timm/optim/__pycache__/sgdp.cpython-310.pyc,,
timm/optim/adabelief.py,sha256=n8nVbFX0TrCgkI98s7sV9D1l_rwPoqgVdfUW1KxGMPY,9827
timm/optim/adafactor.py,sha256=UOYdbisCGOXJJF4sklBa4XEb3m68IyV6IkzcEopGack,7459
timm/optim/adahessian.py,sha256=vJtQ8bZTGLrkMYuGPOJdgO-5V8hjVvM2Il-HSqg59Ao,6535
timm/optim/adamp.py,sha256=PSJYfobQvxy9K0tdU6-mjaiF4BqhIXY9sHV2vposx5I,3574
timm/optim/adamw.py,sha256=OKSBGfaWs6DJC1aXJHadAp4FADAnDDwb-ZRKuPao7zk,5147
timm/optim/lamb.py,sha256=II9zTpcxWzNqgk4K-bs5VGKlQPabUolSAmHkcSjsqSU,9184
timm/optim/lars.py,sha256=8Ytu-q4FvXQWTEcP7R-8xSKdb72c2s1XhTvMzIshBME,5255
timm/optim/lookahead.py,sha256=nd42FXVedX6qlnyBXGMcxkj1IsUUOtwbVFa4dQYy83M,2463
timm/optim/madgrad.py,sha256=V3LJuPjGwiO7RdHAZFF0Qqa8JT8a9DJJLSEO2PCG7Ho,6893
timm/optim/nadam.py,sha256=ASEISt72rXnpfqVkKfgotJXBYpsyG9Pr17I8VFO6Eac,3871
timm/optim/nvnovograd.py,sha256=NkRLq007qqiRDrhqiZK1KP_kfCcFcDSYCWRcoYvddOQ,4856
timm/optim/optim_factory.py,sha256=te26CtKS1wOh1gwEeB6glHdMGxOyKH7xhtbC_V91upQ,8415
timm/optim/radam.py,sha256=dCeFJGKo5WC8w7Ad8tuldM6QFz41nYXJIYI5HkH6uxk,3468
timm/optim/rmsprop_tf.py,sha256=SX47YRaLPNB-YpJpLUbXqx21ZFoDPeqvpJX2kin4wCc,6143
timm/optim/sgdp.py,sha256=7f4ZMVHbjCTDTgPOZfE06S4lmdUBnIBCDr_Yzy1RFhY,2296
timm/scheduler/__init__.py,sha256=WhoyJyfj6SE2YIE06C1eMmnqr7tm5m17YG5s4uL9lXU,291
timm/scheduler/__pycache__/__init__.cpython-310.pyc,,
timm/scheduler/__pycache__/cosine_lr.cpython-310.pyc,,
timm/scheduler/__pycache__/multistep_lr.cpython-310.pyc,,
timm/scheduler/__pycache__/plateau_lr.cpython-310.pyc,,
timm/scheduler/__pycache__/poly_lr.cpython-310.pyc,,
timm/scheduler/__pycache__/scheduler.cpython-310.pyc,,
timm/scheduler/__pycache__/scheduler_factory.cpython-310.pyc,,
timm/scheduler/__pycache__/step_lr.cpython-310.pyc,,
timm/scheduler/__pycache__/tanh_lr.cpython-310.pyc,,
timm/scheduler/cosine_lr.py,sha256=VNWO_gQoFM436vnse7dgTZQtFbfx1PT3S-lwS1l4MLw,4161
timm/scheduler/multistep_lr.py,sha256=je8uCJHIlnyhvgFKhkrpSyQtCYl2G5QWWnZNVwPo8YQ,2098
timm/scheduler/plateau_lr.py,sha256=831LB8XE2nGCSmXSeWhSN8bDr_S4rS9mAR1TZNl-ttA,4140
timm/scheduler/poly_lr.py,sha256=k-XGB63zQdCYtUVP6WiUocaKgAV4dqwOHna9gw0p0tQ,4003
timm/scheduler/scheduler.py,sha256=t1F7sPeaMzTio6xeyxS9z2nLKMYt08c0xOPSEtCVUig,4750
timm/scheduler/scheduler_factory.py,sha256=hXi1jEFNLYuFqHPOOa1oW00g-dsqJ99ExhvjBAjSy0w,3682
timm/scheduler/step_lr.py,sha256=2L8uA_mq5_25jL0WASNnaQ3IkX9q5cZgE789PBNXryg,1902
timm/scheduler/tanh_lr.py,sha256=ApHa_ziKBwqZpcSm1xen_siQmFl-ja3yRJyck08F_04,3936
timm/utils/__init__.py,sha256=5Xoo-5kP6dZ5f6LV9u1DCRIqQrcX4OciHLs29DNcuEU,587
timm/utils/__pycache__/__init__.cpython-310.pyc,,
timm/utils/__pycache__/agc.cpython-310.pyc,,
timm/utils/__pycache__/checkpoint_saver.cpython-310.pyc,,
timm/utils/__pycache__/clip_grad.cpython-310.pyc,,
timm/utils/__pycache__/cuda.cpython-310.pyc,,
timm/utils/__pycache__/distributed.cpython-310.pyc,,
timm/utils/__pycache__/jit.cpython-310.pyc,,
timm/utils/__pycache__/log.cpython-310.pyc,,
timm/utils/__pycache__/metrics.cpython-310.pyc,,
timm/utils/__pycache__/misc.cpython-310.pyc,,
timm/utils/__pycache__/model.cpython-310.pyc,,
timm/utils/__pycache__/model_ema.cpython-310.pyc,,
timm/utils/__pycache__/random.cpython-310.pyc,,
timm/utils/__pycache__/summary.cpython-310.pyc,,
timm/utils/agc.py,sha256=6lZCChfbW0KGNMfkzztWD_NP87ESopjk24Xtb3WbBqU,1624
timm/utils/checkpoint_saver.py,sha256=RljigPicMAHnk48K2Qbl17cWnQepgO4QMZQ0FCjd8xw,6133
timm/utils/clip_grad.py,sha256=iYFEf7fvPbpyh5K1SI-EKey5Gqs2gztR9VUUGja0GB0,796
timm/utils/cuda.py,sha256=nerhMCalMMv1QHZXW1brcIvtzpWwalWYLTb6vJz2bnY,1703
timm/utils/distributed.py,sha256=MD1xa17GKPyPoUJ4lIkxucmD635GbKNuVvcsVCQlhsc,896
timm/utils/jit.py,sha256=OpEbtA3TgJTNDfKWX2oOsd1p4Sm5RQqZnZz5N78t2lk,648
timm/utils/log.py,sha256=BdZ2OqWo3v8d7wsDRJ-uACcoeNUhS8TJSwI3CYvq3Ss,1015
timm/utils/metrics.py,sha256=RSHpbbkyW6FsbxT6TzcBL7MZh4sv4A_GG1Bo8aN5qKc,901
timm/utils/misc.py,sha256=o6ZbvZJB-M6NX73E3ZFKzpKpbnV2bU-ZEj-RDlz-P58,644
timm/utils/model.py,sha256=IMc8JCt89gmCjlb68k-6RNb7wR4ZawWjG74Y4PMiSdo,12085
timm/utils/model_ema.py,sha256=PG0B6k198xf36sK1PrT32LaMe6zKlzUPPZmaPoSRbiQ,5670
timm/utils/random.py,sha256=Ysv6F3nIO8JYE8j6UrDxGyJDp3uNpq5v8U0KqL_8dic,178
timm/utils/summary.py,sha256=pdvBnXsLAS4CPGlMhB0lkUyizKGY4XP9M9f-VWlaJA0,1184
timm/version.py,sha256=YYYpMsC3ActrKnohyOQTZhA6i4ZGdIcgCekNZ1_03lo,22

lib/timm-0.5.4.dist-info/REQUESTED
ADDED
File without changes

lib/timm-0.5.4.dist-info/WHEEL
ADDED
@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.37.0)
Root-Is-Purelib: true
Tag: py3-none-any

lib/timm-0.5.4.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
timm

lib/timm/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .version import __version__
from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
    is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
    get_model_default_value, is_model_pretrained

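These imports are the package's top-level API. A quick sketch of how the registry helpers compose (the model names returned are whatever this registry contains; the wildcard filter is illustrative):

```python
import timm

names = timm.list_models('*resnet*', pretrained=True)  # wildcard filter over the model registry
assert names and all(timm.is_model(n) for n in names)
model = timm.create_model(names[0])  # instantiate by name, randomly initialized
```
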
lib/timm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (549 Bytes)

lib/timm/__pycache__/version.cpython-310.pyc
ADDED
Binary file (167 Bytes)

lib/timm/data/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
    rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform

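As a sketch of how these exports fit together, `create_transform` builds the train/eval preprocessing pipelines, including the AutoAugment/RandAugment config strings handled by the `auto_augment.py` module added below (exact values here are illustrative):

```python
from PIL import Image
from timm.data import create_transform

# training transform with RandAugment: magnitude 9, magnitude noise std 0.5
train_tfm = create_transform(224, is_training=True, auto_augment='rand-m9-mstd0.5')
eval_tfm = create_transform(224, is_training=False)

img = Image.new('RGB', (320, 240))  # stand-in for a real photo
x = train_tfm(img)  # tensor of shape (3, 224, 224)
```
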
lib/timm/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (860 Bytes)

lib/timm/data/__pycache__/auto_augment.cpython-310.pyc
ADDED
Binary file (24 kB)

lib/timm/data/__pycache__/config.cpython-310.pyc
ADDED
Binary file (1.57 kB)

lib/timm/data/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (487 Bytes)

lib/timm/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (5.05 kB)

lib/timm/data/__pycache__/dataset_factory.cpython-310.pyc
ADDED
Binary file (4.25 kB)

lib/timm/data/__pycache__/distributed_sampler.cpython-310.pyc
ADDED
Binary file (4.12 kB)

lib/timm/data/__pycache__/loader.cpython-310.pyc
ADDED
Binary file (7.96 kB)

lib/timm/data/__pycache__/mixup.cpython-310.pyc
ADDED
Binary file (11.3 kB)

lib/timm/data/__pycache__/random_erasing.cpython-310.pyc
ADDED
Binary file (3.94 kB)

lib/timm/data/__pycache__/real_labels.cpython-310.pyc
ADDED
Binary file (2.4 kB)

lib/timm/data/__pycache__/tf_preprocessing.cpython-310.pyc
ADDED
Binary file (7.21 kB)

lib/timm/data/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (6.34 kB)

lib/timm/data/__pycache__/transforms_factory.cpython-310.pyc
ADDED
Binary file (5.07 kB)

lib/timm/data/auto_augment.py
ADDED
@@ -0,0 +1,865 @@
""" AutoAugment, RandAugment, and AugMix for PyTorch

This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.

AA and RA Implementation adapted from:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py

AugMix adapted from:
    https://github.com/google-research/augmix

Papers:
    AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
    Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
    AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781

Hacked together by / Copyright 2019, Ross Wightman
"""
import random
import math
import re
from PIL import Image, ImageOps, ImageEnhance, ImageChops
import PIL
import numpy as np


_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

_FILL = (128, 128, 128)

_LEVEL_DENOM = 10.  # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments

_HPARAMS_DEFAULT = dict(
    translate_const=250,
    img_mean=_FILL,
)

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop('resample', Image.BILINEAR)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
        kwargs.pop('fillcolor')
    kwargs['resample'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)


def shear_y(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)


def translate_x_rel(img, pct, **kwargs):
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_rel(img, pct, **kwargs):
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def translate_x_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)


def translate_y_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)


def rotate(img, degrees, **kwargs):
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs['resample'])


def auto_contrast(img, **__):
    return ImageOps.autocontrast(img)


def invert(img, **__):
    return ImageOps.invert(img)


def equalize(img, **__):
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _LEVEL_DENOM) * 30.
    level = _randomly_negate(level)
    return level,


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return (level / _LEVEL_DENOM) * 1.8 + 0.1,


def _enhance_increasing_level_to_arg(level, _hparams):
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9] if level <= _LEVEL_DENOM
    level = (level / _LEVEL_DENOM) * .9
    level = max(0.1, 1.0 + _randomly_negate(level))  # keep it >= 0.1
    return level,


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _LEVEL_DENOM) * 0.3
    level = _randomly_negate(level)
    return level,


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams['translate_const']
    level = (level / _LEVEL_DENOM) * float(translate_const)
    level = _randomly_negate(level)
    return level,


def _translate_rel_level_to_arg(level, hparams):
    # default range [-0.45, 0.45]
    translate_pct = hparams.get('translate_pct', 0.45)
    level = (level / _LEVEL_DENOM) * translate_pct
    level = _randomly_negate(level)
    return level,


def _posterize_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    return int((level / _LEVEL_DENOM) * 4),


def _posterize_increasing_level_to_arg(level, hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    return 4 - _posterize_level_to_arg(level, hparams)[0],


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    return int((level / _LEVEL_DENOM) * 4) + 4,


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    return int((level / _LEVEL_DENOM) * 256),


def _solarize_increasing_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    return 256 - _solarize_level_to_arg(level, _hparams)[0],


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return int((level / _LEVEL_DENOM) * 110),


LEVEL_TO_ARG = {
    'AutoContrast': None,
    'Equalize': None,
    'Invert': None,
    'Rotate': _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    'Posterize': _posterize_level_to_arg,
    'PosterizeIncreasing': _posterize_increasing_level_to_arg,
    'PosterizeOriginal': _posterize_original_level_to_arg,
    'Solarize': _solarize_level_to_arg,
    'SolarizeIncreasing': _solarize_increasing_level_to_arg,
    'SolarizeAdd': _solarize_add_level_to_arg,
    'Color': _enhance_level_to_arg,
    'ColorIncreasing': _enhance_increasing_level_to_arg,
    'Contrast': _enhance_level_to_arg,
    'ContrastIncreasing': _enhance_increasing_level_to_arg,
    'Brightness': _enhance_level_to_arg,
    'BrightnessIncreasing': _enhance_increasing_level_to_arg,
    'Sharpness': _enhance_level_to_arg,
    'SharpnessIncreasing': _enhance_increasing_level_to_arg,
    'ShearX': _shear_level_to_arg,
    'ShearY': _shear_level_to_arg,
    'TranslateX': _translate_abs_level_to_arg,
    'TranslateY': _translate_abs_level_to_arg,
    'TranslateXRel': _translate_rel_level_to_arg,
    'TranslateYRel': _translate_rel_level_to_arg,
}


NAME_TO_OP = {
    'AutoContrast': auto_contrast,
    'Equalize': equalize,
    'Invert': invert,
    'Rotate': rotate,
    'Posterize': posterize,
    'PosterizeIncreasing': posterize,
    'PosterizeOriginal': posterize,
    'Solarize': solarize,
    'SolarizeIncreasing': solarize,
    'SolarizeAdd': solarize_add,
    'Color': color,
    'ColorIncreasing': color,
    'Contrast': contrast,
    'ContrastIncreasing': contrast,
    'Brightness': brightness,
    'BrightnessIncreasing': brightness,
    'Sharpness': sharpness,
    'SharpnessIncreasing': sharpness,
    'ShearX': shear_x,
    'ShearY': shear_y,
    'TranslateX': translate_x_abs,
    'TranslateY': translate_y_abs,
    'TranslateXRel': translate_x_rel,
    'TranslateYRel': translate_y_rel,
}


class AugmentOp:

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = dict(
            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
        )

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        # If magnitude_std is inf, we sample magnitude from a uniform distribution
        self.magnitude_std = self.hparams.get('magnitude_std', 0)
        self.magnitude_max = self.hparams.get('magnitude_max', None)

    def __call__(self, img):
        if self.prob < 1.0 and random.random() > self.prob:
            return img
        magnitude = self.magnitude
        if self.magnitude_std > 0:
            # magnitude randomization enabled
            if self.magnitude_std == float('inf'):
                magnitude = random.uniform(0, magnitude)
            elif self.magnitude_std > 0:
                magnitude = random.gauss(magnitude, self.magnitude_std)
        # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
        # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
        upper_bound = self.magnitude_max or _LEVEL_DENOM
        magnitude = max(0., min(magnitude, upper_bound))
        level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
        return self.aug_fn(img, *level_args, **self.kwargs)

    def __repr__(self):
        fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
        fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
        if self.magnitude_max is not None:
            fs += f', mmax={self.magnitude_max}'
        fs += ')'
        return fs

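# NOTE: illustrative usage sketch (not part of the packaged file). A single
# AugmentOp applies its named op with probability `prob` at level `magnitude`;
# per the level mappings above, magnitude 9 for 'Rotate' maps to
# (9 / _LEVEL_DENOM) * 30 = 27 degrees, with the sign negated half the time:
#
#   op = AugmentOp('Rotate', prob=1.0, magnitude=9)
#   augmented = op(pil_img)  # pil_img: a PIL.Image.Image
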
def auto_augment_policy_v0(hparams):
    # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],  # This results in black image with Tpu posterize
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_v0r(hparams):
    # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
    # in Google research implementation (number of bits discarded increases with magnitude)
    policy = [
        [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
        [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
        [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
        [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
        [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
        [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
        [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
        [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
        [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
        [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
        [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
        [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
        [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
        [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
        [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
        [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
        [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
        [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
        [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
        [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
        [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
        [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
        [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_original(hparams):
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    policy = [
        [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy_originalr(hparams):
    # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
    policy = [
        [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]
    pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
    return pc


def auto_augment_policy(name='v0', hparams=None):
    hparams = hparams or _HPARAMS_DEFAULT
    if name == 'original':
        return auto_augment_policy_original(hparams)
    elif name == 'originalr':
        return auto_augment_policy_originalr(hparams)
    elif name == 'v0':
        return auto_augment_policy_v0(hparams)
    elif name == 'v0r':
        return auto_augment_policy_v0r(hparams)
    else:
        assert False, 'Unknown AA policy (%s)' % name


class AutoAugment:

    def __init__(self, policy):
        self.policy = policy

    def __call__(self, img):
        sub_policy = random.choice(self.policy)
        for op in sub_policy:
            img = op(img)
        return img

    def __repr__(self):
        fs = self.__class__.__name__ + f'(policy='
        for p in self.policy:
            fs += '\n\t['
            fs += ', '.join([str(op) for op in p])
            fs += ']'
        fs += ')'
        return fs

def auto_augment_transform(config_str, hparams):
|
533 |
+
"""
|
534 |
+
Create a AutoAugment transform
|
535 |
+
|
536 |
+
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
537 |
+
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
538 |
+
The remaining sections, not order sepecific determine
|
539 |
+
'mstd' - float std deviation of magnitude noise applied
|
540 |
+
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
541 |
+
|
542 |
+
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
543 |
+
|
544 |
+
:return: A PyTorch compatible Transform
|
545 |
+
"""
|
546 |
+
config = config_str.split('-')
|
547 |
+
policy_name = config[0]
|
548 |
+
config = config[1:]
|
549 |
+
for c in config:
|
550 |
+
cs = re.split(r'(\d.*)', c)
|
551 |
+
if len(cs) < 2:
|
552 |
+
continue
|
553 |
+
key, val = cs[:2]
|
554 |
+
if key == 'mstd':
|
555 |
+
# noise param injected via hparams for now
|
556 |
+
hparams.setdefault('magnitude_std', float(val))
|
557 |
+
else:
|
558 |
+
assert False, 'Unknown AutoAugment config section'
|
559 |
+
aa_policy = auto_augment_policy(policy_name, hparams=hparams)
|
560 |
+
return AutoAugment(aa_policy)
|
561 |
+
|
562 |
+
|
563 |
+
_RAND_TRANSFORMS = [
|
564 |
+
'AutoContrast',
|
565 |
+
'Equalize',
|
566 |
+
'Invert',
|
567 |
+
'Rotate',
|
568 |
+
'Posterize',
|
569 |
+
'Solarize',
|
570 |
+
'SolarizeAdd',
|
571 |
+
'Color',
|
572 |
+
'Contrast',
|
573 |
+
'Brightness',
|
574 |
+
'Sharpness',
|
575 |
+
'ShearX',
|
576 |
+
'ShearY',
|
577 |
+
'TranslateXRel',
|
578 |
+
'TranslateYRel',
|
579 |
+
#'Cutout' # NOTE I've implement this as random erasing separately
|
580 |
+
]
|
581 |
+
|
582 |
+
|
583 |
+
_RAND_INCREASING_TRANSFORMS = [
|
584 |
+
'AutoContrast',
|
585 |
+
'Equalize',
|
586 |
+
'Invert',
|
587 |
+
'Rotate',
|
588 |
+
'PosterizeIncreasing',
|
589 |
+
'SolarizeIncreasing',
|
590 |
+
'SolarizeAdd',
|
591 |
+
'ColorIncreasing',
|
592 |
+
'ContrastIncreasing',
|
593 |
+
'BrightnessIncreasing',
|
594 |
+
'SharpnessIncreasing',
|
595 |
+
'ShearX',
|
596 |
+
'ShearY',
|
597 |
+
'TranslateXRel',
|
598 |
+
'TranslateYRel',
|
599 |
+
#'Cutout' # NOTE I've implement this as random erasing separately
|
600 |
+
]
|
601 |
+
|
602 |
+
|
603 |
+
|
604 |
+
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
605 |
+
# They may not result in increased performance, but could likely be tuned to so.
|
606 |
+
_RAND_CHOICE_WEIGHTS_0 = {
|
607 |
+
'Rotate': 0.3,
|
608 |
+
'ShearX': 0.2,
|
609 |
+
'ShearY': 0.2,
|
610 |
+
'TranslateXRel': 0.1,
|
611 |
+
'TranslateYRel': 0.1,
|
612 |
+
'Color': .025,
|
613 |
+
'Sharpness': 0.025,
|
614 |
+
'AutoContrast': 0.025,
|
615 |
+
'Solarize': .005,
|
616 |
+
'SolarizeAdd': .005,
|
617 |
+
'Contrast': .005,
|
618 |
+
'Brightness': .005,
|
619 |
+
'Equalize': .005,
|
620 |
+
'Posterize': 0,
|
621 |
+
'Invert': 0,
|
622 |
+
}
|
623 |
+
|
624 |
+
|
625 |
+
def _select_rand_weights(weight_idx=0, transforms=None):
|
626 |
+
transforms = transforms or _RAND_TRANSFORMS
|
627 |
+
assert weight_idx == 0 # only one set of weights currently
|
628 |
+
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
629 |
+
probs = [rand_weights[k] for k in transforms]
|
630 |
+
probs /= np.sum(probs)
|
631 |
+
return probs
|
632 |
+
|
633 |
+
|
634 |
+
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
635 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
636 |
+
transforms = transforms or _RAND_TRANSFORMS
|
637 |
+
return [AugmentOp(
|
638 |
+
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
639 |
+
|
640 |
+
|
641 |
+
class RandAugment:
|
642 |
+
def __init__(self, ops, num_layers=2, choice_weights=None):
|
643 |
+
self.ops = ops
|
644 |
+
self.num_layers = num_layers
|
645 |
+
self.choice_weights = choice_weights
|
646 |
+
|
647 |
+
def __call__(self, img):
|
648 |
+
# no replacement when using weighted choice
|
649 |
+
ops = np.random.choice(
|
650 |
+
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
|
651 |
+
for op in ops:
|
652 |
+
img = op(img)
|
653 |
+
return img
|
654 |
+
|
655 |
+
def __repr__(self):
|
656 |
+
fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
|
657 |
+
for op in self.ops:
|
658 |
+
fs += f'\n\t{op}'
|
659 |
+
fs += ')'
|
660 |
+
return fs
|
661 |
+
|
662 |
+
|
663 |
+
def rand_augment_transform(config_str, hparams):
|
664 |
+
"""
|
665 |
+
Create a RandAugment transform
|
666 |
+
|
667 |
+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
668 |
+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
669 |
+
sections, not order sepecific determine
|
670 |
+
'm' - integer magnitude of rand augment
|
671 |
+
'n' - integer num layers (number of transform ops selected per image)
|
672 |
+
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
673 |
+
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
674 |
+
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
675 |
+
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
676 |
+
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
677 |
+
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
678 |
+
|
679 |
+
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
680 |
+
|
681 |
+
:return: A PyTorch compatible Transform
|
682 |
+
"""
|
683 |
+
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
|
684 |
+
num_layers = 2 # default to 2 ops per image
|
685 |
+
weight_idx = None # default to no probability weights for op choice
|
686 |
+
transforms = _RAND_TRANSFORMS
|
687 |
+
config = config_str.split('-')
|
688 |
+
assert config[0] == 'rand'
|
689 |
+
config = config[1:]
|
690 |
+
for c in config:
|
691 |
+
cs = re.split(r'(\d.*)', c)
|
692 |
+
if len(cs) < 2:
|
693 |
+
continue
|
694 |
+
key, val = cs[:2]
|
695 |
+
if key == 'mstd':
|
696 |
+
# noise param / randomization of magnitude values
|
697 |
+
mstd = float(val)
|
698 |
+
if mstd > 100:
|
699 |
+
# use uniform sampling in 0 to magnitude if mstd is > 100
|
700 |
+
mstd = float('inf')
|
701 |
+
hparams.setdefault('magnitude_std', mstd)
|
702 |
+
elif key == 'mmax':
|
703 |
+
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
|
704 |
+
hparams.setdefault('magnitude_max', int(val))
|
705 |
+
elif key == 'inc':
|
706 |
+
if bool(val):
|
707 |
+
transforms = _RAND_INCREASING_TRANSFORMS
|
708 |
+
elif key == 'm':
|
709 |
+
magnitude = int(val)
|
710 |
+
elif key == 'n':
|
711 |
+
num_layers = int(val)
|
712 |
+
elif key == 'w':
|
713 |
+
weight_idx = int(val)
|
714 |
+
else:
|
715 |
+
assert False, 'Unknown RandAugment config section'
|
716 |
+
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
717 |
+
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
718 |
+
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
719 |
+
|
720 |
+
|
721 |
+
_AUGMIX_TRANSFORMS = [
|
722 |
+
'AutoContrast',
|
723 |
+
'ColorIncreasing', # not in paper
|
724 |
+
'ContrastIncreasing', # not in paper
|
725 |
+
'BrightnessIncreasing', # not in paper
|
726 |
+
'SharpnessIncreasing', # not in paper
|
727 |
+
'Equalize',
|
728 |
+
'Rotate',
|
729 |
+
'PosterizeIncreasing',
|
730 |
+
'SolarizeIncreasing',
|
731 |
+
'ShearX',
|
732 |
+
'ShearY',
|
733 |
+
'TranslateXRel',
|
734 |
+
'TranslateYRel',
|
735 |
+
]
|
736 |
+
|
737 |
+
|
738 |
+
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
739 |
+
hparams = hparams or _HPARAMS_DEFAULT
|
740 |
+
transforms = transforms or _AUGMIX_TRANSFORMS
|
741 |
+
return [AugmentOp(
|
742 |
+
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
743 |
+
|
744 |
+
|
745 |
+
class AugMixAugment:
|
746 |
+
""" AugMix Transform
|
747 |
+
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
|
748 |
+
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
|
749 |
+
https://arxiv.org/abs/1912.02781
|
750 |
+
"""
|
751 |
+
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
|
752 |
+
self.ops = ops
|
753 |
+
self.alpha = alpha
|
754 |
+
self.width = width
|
755 |
+
self.depth = depth
|
756 |
+
self.blended = blended # blended mode is faster but not well tested
|
757 |
+
|
758 |
+
def _calc_blended_weights(self, ws, m):
|
759 |
+
ws = ws * m
|
760 |
+
cump = 1.
|
761 |
+
rws = []
|
762 |
+
for w in ws[::-1]:
|
763 |
+
alpha = w / cump
|
764 |
+
cump *= (1 - alpha)
|
765 |
+
rws.append(alpha)
|
766 |
+
return np.array(rws[::-1], dtype=np.float32)
|
767 |
+
|
768 |
+
def _apply_blended(self, img, mixing_weights, m):
|
769 |
+
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
|
770 |
+
# of accumulating the mix for each chain in a Numpy array and then blending with original,
|
771 |
+
# it recomputes the blending coefficients and applies one PIL image blend per chain.
|
772 |
+
# TODO the results appear in the right ballpark but they differ by more than rounding.
|
773 |
+
img_orig = img.copy()
|
774 |
+
ws = self._calc_blended_weights(mixing_weights, m)
|
775 |
+
for w in ws:
|
776 |
+
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
777 |
+
ops = np.random.choice(self.ops, depth, replace=True)
|
778 |
+
img_aug = img_orig # no ops are in-place, deep copy not necessary
|
779 |
+
for op in ops:
|
780 |
+
img_aug = op(img_aug)
|
781 |
+
img = Image.blend(img, img_aug, w)
|
782 |
+
return img
|
783 |
+
|
784 |
+
def _apply_basic(self, img, mixing_weights, m):
|
785 |
+
# This is a literal adaptation of the paper/official implementation without normalizations and
|
786 |
+
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
|
787 |
+
# typical augmentation transforms, could use a GPU / Kornia implementation.
|
788 |
+
img_shape = img.size[0], img.size[1], len(img.getbands())
|
789 |
+
mixed = np.zeros(img_shape, dtype=np.float32)
|
790 |
+
for mw in mixing_weights:
|
791 |
+
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
|
792 |
+
ops = np.random.choice(self.ops, depth, replace=True)
|
793 |
+
img_aug = img # no ops are in-place, deep copy not necessary
|
794 |
+
for op in ops:
|
795 |
+
img_aug = op(img_aug)
|
796 |
+
mixed += mw * np.asarray(img_aug, dtype=np.float32)
|
797 |
+
np.clip(mixed, 0, 255., out=mixed)
|
798 |
+
mixed = Image.fromarray(mixed.astype(np.uint8))
|
799 |
+
return Image.blend(img, mixed, m)
|
800 |
+
|
801 |
+
def __call__(self, img):
|
802 |
+
mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
|
803 |
+
m = np.float32(np.random.beta(self.alpha, self.alpha))
|
804 |
+
if self.blended:
|
805 |
+
mixed = self._apply_blended(img, mixing_weights, m)
|
806 |
+
else:
|
807 |
+
mixed = self._apply_basic(img, mixing_weights, m)
|
808 |
+
return mixed
|
809 |
+
|
810 |
+
def __repr__(self):
|
811 |
+
fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
|
812 |
+
for op in self.ops:
|
813 |
+
fs += f'\n\t{op}'
|
814 |
+
fs += ')'
|
815 |
+
return fs
|
816 |
+
|
817 |
+
|
818 |
+
def augment_and_mix_transform(config_str, hparams):
|
819 |
+
""" Create AugMix PyTorch transform
|
820 |
+
|
821 |
+
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
822 |
+
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
823 |
+
sections, not order sepecific determine
|
824 |
+
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
825 |
+
'w' - integer width of augmentation chain (default: 3)
|
826 |
+
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
827 |
+
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
828 |
+
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
829 |
+
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
830 |
+
|
831 |
+
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
832 |
+
|
833 |
+
:return: A PyTorch compatible Transform
|
834 |
+
"""
|
835 |
+
magnitude = 3
|
836 |
+
width = 3
|
837 |
+
depth = -1
|
838 |
+
alpha = 1.
|
839 |
+
blended = False
|
840 |
+
config = config_str.split('-')
|
841 |
+
assert config[0] == 'augmix'
|
842 |
+
config = config[1:]
|
843 |
+
for c in config:
|
844 |
+
cs = re.split(r'(\d.*)', c)
|
845 |
+
if len(cs) < 2:
|
846 |
+
continue
|
847 |
+
key, val = cs[:2]
|
848 |
+
if key == 'mstd':
|
849 |
+
# noise param injected via hparams for now
|
850 |
+
hparams.setdefault('magnitude_std', float(val))
|
851 |
+
elif key == 'm':
|
852 |
+
magnitude = int(val)
|
853 |
+
elif key == 'w':
|
854 |
+
width = int(val)
|
855 |
+
elif key == 'd':
|
856 |
+
depth = int(val)
|
857 |
+
elif key == 'a':
|
858 |
+
alpha = float(val)
|
859 |
+
elif key == 'b':
|
860 |
+
blended = bool(val)
|
861 |
+
else:
|
862 |
+
assert False, 'Unknown AugMix config section'
|
863 |
+
hparams.setdefault('magnitude_std', float('inf')) # default to uniform sampling (if not set via mstd arg)
|
864 |
+
ops = augmix_ops(magnitude=magnitude, hparams=hparams)
|
865 |
+
return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
|
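
The factory functions above (auto_augment_transform, rand_augment_transform, augment_and_mix_transform) are the entry points; each parses a config string into a transform that operates on PIL images. A minimal usage sketch, assuming an image file exists at the illustrative path; the hparams values (translate_const, img_mean) feed the AugmentOp fill / translation parameters and are illustrative, not part of this diff:

    from PIL import Image
    from timm.data.auto_augment import rand_augment_transform

    # magnitude 9, 2 layers, magnitude noise 0.5, 'increasing' severity op set
    tfm = rand_augment_transform(
        config_str='rand-m9-n2-mstd0.5-inc1',
        hparams=dict(translate_const=117, img_mean=(124, 116, 104)),
    )
    img = Image.open('example.jpg')  # hypothetical input image
    augmented = tfm(img)

The returned object composes directly into a torchvision transforms.Compose pipeline ahead of ToTensor.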
lib/timm/data/config.py
ADDED
@@ -0,0 +1,78 @@
import logging
from .constants import *


_logger = logging.getLogger(__name__)


def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
    new_config = {}
    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
        default_cfg = model.default_cfg

    # Resolve input/image size
    in_chans = 3
    if 'chans' in args and args['chans'] is not None:
        in_chans = args['chans']

    input_size = (in_chans, 224, 224)
    if 'input_size' in args and args['input_size'] is not None:
        assert isinstance(args['input_size'], (tuple, list))
        assert len(args['input_size']) == 3
        input_size = tuple(args['input_size'])
        in_chans = input_size[0]  # input_size overrides in_chans
    elif 'img_size' in args and args['img_size'] is not None:
        assert isinstance(args['img_size'], int)
        input_size = (in_chans, args['img_size'], args['img_size'])
    else:
        if use_test_size and 'test_input_size' in default_cfg:
            input_size = default_cfg['test_input_size']
        elif 'input_size' in default_cfg:
            input_size = default_cfg['input_size']
    new_config['input_size'] = input_size

    # resolve interpolation method
    new_config['interpolation'] = 'bicubic'
    if 'interpolation' in args and args['interpolation']:
        new_config['interpolation'] = args['interpolation']
    elif 'interpolation' in default_cfg:
        new_config['interpolation'] = default_cfg['interpolation']

    # resolve dataset + model mean for normalization
    new_config['mean'] = IMAGENET_DEFAULT_MEAN
    if 'mean' in args and args['mean'] is not None:
        mean = tuple(args['mean'])
        if len(mean) == 1:
            mean = tuple(list(mean) * in_chans)
        else:
            assert len(mean) == in_chans
        new_config['mean'] = mean
    elif 'mean' in default_cfg:
        new_config['mean'] = default_cfg['mean']

    # resolve dataset + model std deviation for normalization
    new_config['std'] = IMAGENET_DEFAULT_STD
    if 'std' in args and args['std'] is not None:
        std = tuple(args['std'])
        if len(std) == 1:
            std = tuple(list(std) * in_chans)
        else:
            assert len(std) == in_chans
        new_config['std'] = std
    elif 'std' in default_cfg:
        new_config['std'] = default_cfg['std']

    # resolve default crop percentage
    new_config['crop_pct'] = DEFAULT_CROP_PCT
    if 'crop_pct' in args and args['crop_pct'] is not None:
        new_config['crop_pct'] = args['crop_pct']
    elif 'crop_pct' in default_cfg:
        new_config['crop_pct'] = default_cfg['crop_pct']

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for n, v in new_config.items():
            _logger.info('\t%s: %s' % (n, str(v)))

    return new_config
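
resolve_data_config merges explicit args with a model's default_cfg to produce the preprocessing settings. A short usage sketch (the model name is illustrative):

    import timm
    from timm.data import resolve_data_config, create_transform

    model = timm.create_model('resnet50', pretrained=False)
    config = resolve_data_config({}, model=model, verbose=True)
    transform = create_transform(**config)  # eval transform built from the resolved config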
lib/timm/data/constants.py
ADDED
@@ -0,0 +1,7 @@
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
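
These constants are the normalization statistics the rest of the package defaults to. A tiny sketch of direct use with torchvision:

    from torchvision import transforms
    from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    # normalize with the statistics most timm ImageNet models were trained with
    normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)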
lib/timm/data/dataset.py
ADDED
@@ -0,0 +1,152 @@
""" Quick n Simple Image Folder, Tarfile based DataSet

Hacked together by / Copyright 2019, Ross Wightman
"""
import torch.utils.data as data
import os
import torch
import logging

from PIL import Image

from .parsers import create_parser

_logger = logging.getLogger(__name__)


_ERROR_RETRY = 50


class ImageDataset(data.Dataset):

    def __init__(
            self,
            root,
            parser=None,
            class_map=None,
            load_bytes=False,
            transform=None,
            target_transform=None,
    ):
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        self.target_transform = target_transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        img, target = self.parser[index]
        try:
            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                return self.__getitem__((index + 1) % len(self.parser))
            else:
                raise e
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = -1
        elif self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class IterableImageDataset(data.IterableDataset):

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            repeats=0,
            download=False,
            transform=None,
            target_transform=None,
    ):
        assert parser is not None
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training,
                batch_size=batch_size, repeats=repeats, download=download)
        else:
            self.parser = parser
        self.transform = transform
        self.target_transform = target_transform
        self._consecutive_errors = 0

    def __iter__(self):
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            yield img, target

    def __len__(self):
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            return 0

    def filename(self, index, basename=False, absolute=False):
        assert False, 'Filename lookup by index not supported, use filenames().'

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
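
A short usage sketch for the wrapper above; the dataset path is illustrative, and AugMixDataset only accepts its transform as a (base, augmentation, normalize) 3-tuple, which it splits across the clean and augmented branches:

    from timm.data.dataset import ImageDataset, AugMixDataset

    ds = ImageDataset('/data/imagenet/train')      # folder (or tar) based dataset, path illustrative
    augmix_ds = AugMixDataset(ds, num_splits=2)    # yields (clean, augmented) tuples per sample
    # augmix_ds.transform = (base_tfms, aug_tfms, normalize_tfms)  # must be a 3-tuple of transforms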
lib/timm/data/dataset_factory.py
ADDED
@@ -0,0 +1,143 @@
""" Dataset Factory

Hacked together by / Copyright 2021, Ross Wightman
"""
import os

from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
try:
    from torchvision.datasets import Places365
    has_places365 = True
except ImportError:
    has_places365 = False
try:
    from torchvision.datasets import INaturalist
    has_inaturalist = True
except ImportError:
    has_inaturalist = False

from .dataset import IterableImageDataset, ImageDataset

_TORCH_BASIC_DS = dict(
    cifar10=CIFAR10,
    cifar100=CIFAR100,
    mnist=MNIST,
    qmnist=QMNIST,
    kmnist=KMNIST,
    fashion_mnist=FashionMNIST,
)
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}


def _search_split(root, split):
    # look for sub-folder with name of split in root and use that if it exists
    split_name = split.split('[')[0]
    try_root = os.path.join(root, split_name)
    if os.path.exists(try_root):
        return try_root

    def _try(syn):
        for s in syn:
            try_root = os.path.join(root, s)
            if os.path.exists(try_root):
                return try_root
        return root
    if split_name in _TRAIN_SYNONYM:
        root = _try(_TRAIN_SYNONYM)
    elif split_name in _EVAL_SYNONYM:
        root = _try(_EVAL_SYNONYM)
    return root


def create_dataset(
        name,
        root,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        repeats=0,
        **kwargs
):
    """ Dataset factory method

    In parentheses after each arg is the type of dataset supported for that arg, one of:
      * folder - default, timm folder (or tar) based ImageDataset
      * torch - torchvision based datasets
      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
      * all - any of the above

    Args:
        name: dataset name, empty is okay for folder based datasets
        root: root folder of dataset (all)
        split: dataset split (all)
        search_split: search for split specific child folder from root so one can specify
            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
        class_map: specify class -> index mapping via text file or dict (folder)
        load_bytes: load data, return images as undecoded bytes (folder)
        download: download dataset if not present and supported (TFDS, torch)
        is_training: create dataset in train mode, this is different from the split.
            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
        batch_size: batch size hint (TFDS)
        repeats: dataset repeats per iteration i.e. epoch (TFDS)
        **kwargs: other args to pass to dataset

    Returns:
        Dataset object
    """
    name = name.lower()
    if name.startswith('torch/'):
        name = name.split('/', 2)[-1]
        torch_kwargs = dict(root=root, download=download, **kwargs)
        if name in _TORCH_BASIC_DS:
            ds_class = _TORCH_BASIC_DS[name]
            use_train = split in _TRAIN_SYNONYM
            ds = ds_class(train=use_train, **torch_kwargs)
        elif name == 'inaturalist' or name == 'inat':
            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for INaturalist'
            target_type = 'full'
            split_split = split.split('/')
            if len(split_split) > 1:
                target_type = split_split[0].split('_')
                if len(target_type) == 1:
                    target_type = target_type[0]
                split = split_split[-1]
            if split in _TRAIN_SYNONYM:
                split = '2021_train'
            elif split in _EVAL_SYNONYM:
                split = '2021_valid'
            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
        elif name == 'places365':
            assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
            if split in _TRAIN_SYNONYM:
                split = 'train-standard'
            elif split in _EVAL_SYNONYM:
                split = 'val'
            ds = Places365(split=split, **torch_kwargs)
        elif name == 'imagenet':
            if split in _EVAL_SYNONYM:
                split = 'val'
            ds = ImageNet(split=split, **torch_kwargs)
        elif name == 'image_folder' or name == 'folder':
            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
            if search_split and os.path.isdir(root):
                # look for split specific sub-folder in root
                root = _search_split(root, split)
            ds = ImageFolder(root, **kwargs)
        else:
            assert False, f"Unknown torchvision dataset {name}"
    elif name.startswith('tfds/'):
        ds = IterableImageDataset(
            root, parser=name, split=split, is_training=is_training,
            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
    else:
        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
        if search_split and os.path.isdir(root):
            # look for split specific sub-folder in root
            root = _search_split(root, split)
        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
    return ds
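
Two usage sketches for the factory above; the root paths are illustrative:

    from timm.data import create_dataset

    # folder-based dataset; search_split resolves the 'validation'/'val' sub-folder under root
    ds_eval = create_dataset('', root='/data/imagenet', split='validation')

    # torchvision dataset, downloaded on demand
    ds_cifar = create_dataset('torch/cifar10', root='/data/cifar', split='train', download=True)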
lib/timm/data/distributed_sampler.py
ADDED
@@ -0,0 +1,129 @@
import math
import torch
from torch.utils.data import Sampler
import torch.distributed as dist


class OrderedDistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples


class RepeatAugSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU). Heavily based on torch.utils.data.DistributedSampler

    This sampler was taken from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py,
    Copyright (c) 2015-present, Facebook, Inc.
    """

    def __init__(
            self,
            dataset,
            num_replicas=None,
            rank=None,
            shuffle=True,
            num_repeats=3,
            selected_round=256,
            selected_ratio=0,
    ):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.num_repeats = num_repeats
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Determine the number of samples to select per epoch for each rank.
        # num_selected logic defaults to be the same as original RASampler impl, but this one can be tweaked
        # via selected_ratio and selected_round args.
        selected_ratio = selected_ratio or num_replicas  # ratio to reduce selected samples by, num_replicas if 0
        if selected_round:
            self.num_selected_samples = int(math.floor(
                len(self.dataset) // selected_round * selected_round / selected_ratio))
        else:
            self.num_selected_samples = int(math.ceil(len(self.dataset) / selected_ratio))

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g)
        else:
            indices = torch.arange(start=0, end=len(self.dataset))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
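
A minimal sketch of wiring RepeatAugSampler into a DataLoader. The toy dataset and the explicit num_replicas/rank are assumptions so the snippet runs without an initialized process group; in real distributed use those come from torch.distributed:

    import torch
    from torch.utils.data import TensorDataset, DataLoader
    from timm.data.distributed_sampler import RepeatAugSampler

    # toy dataset large enough that the default selected_round=256 keeps samples
    dataset = TensorDataset(torch.randn(512, 3, 8, 8), torch.randint(0, 10, (512,)))
    sampler = RepeatAugSampler(dataset, num_replicas=1, rank=0, num_repeats=3)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # deterministic reshuffle each epoch
        for images, targets in loader:
            pass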
lib/timm/data/loader.py
ADDED
@@ -0,0 +1,289 @@
""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2019, Ross Wightman
"""
import random
from functools import partial
from typing import Callable

import torch.utils.data
import numpy as np

from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup


def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuples at position n end up in the nth chunk of a torch.split(tensor, batch_size)
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False


class PrefetchLoader:

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x


def _worker_init(worker_id, worker_seeding='all'):
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    if isinstance(worker_seeding, Callable):
        seed = worker_seeding(worker_info)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed % (2 ** 32 - 1))
    else:
        assert worker_seeding in ('all', 'part')
        # random / torch seed already called in dataloader iter class w/ worker_info.seed
        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
        if worker_seeding == 'all':
            np.random.seed(worker_info.seed % (2 ** 32 - 1))


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_repeats=0,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        fp16=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
        worker_seeding='all',
):
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        persistent_workers=persistent_workers
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in PyTorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            fp16=fp16,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader


class MultiEpochsDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
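
A minimal usage sketch tying create_dataset and create_loader together; the root path is illustrative, and use_prefetcher is disabled here because the prefetcher requires a CUDA device:

    from timm.data import create_dataset, create_loader

    dataset = create_dataset('', root='/data/imagenet', split='train', is_training=True)
    loader = create_loader(
        dataset,
        input_size=(3, 224, 224),
        batch_size=64,
        is_training=True,
        use_prefetcher=False,           # True enables the CUDA PrefetchLoader path
        auto_augment='rand-m9-mstd0.5', # RandAugment config string parsed in auto_augment.py
        num_workers=4,
    )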
lib/timm/data/mixup.py
ADDED
@@ -0,0 +1,316 @@
""" Mixup and Cutmix

Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)

CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch

Hacked together by / Copyright 2019, Ross Wightman
"""
import numpy as np
import torch


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)


def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
    return y1 * lam + y2 * (1. - lam)


def rand_bbox(img_shape, lam, margin=0., count=None):
    """ Standard CutMix bounding-box
    Generates a random square bbox based on lambda value. This impl includes
    support for enforcing a border margin as percent of bbox dimensions.

    Args:
        img_shape (tuple): Image shape as tuple
        lam (float): Cutmix lambda value
        margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
        count (int): Number of bbox to generate
    """
    ratio = np.sqrt(1 - lam)
    img_h, img_w = img_shape[-2:]
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh


def rand_bbox_minmax(img_shape, minmax, count=None):
    """ Min-Max CutMix bounding-box
    Inspired by Darknet cutmix impl, generates a random rectangular bbox
    based on min/max percent values applied to each dimension of the input image.

    Typical defaults for minmax are usually in the .2-.3 range for min and .8-.9 range for max.

    Args:
        img_shape (tuple): Image shape as tuple
        minmax (tuple or list): Min and max bbox ratios (as percent of image size)
        count (int): Number of bbox to generate
    """
    assert len(minmax) == 2
    img_h, img_w = img_shape[-2:]
    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    yl = np.random.randint(0, img_h - cut_h, size=count)
    xl = np.random.randint(0, img_w - cut_w, size=count)
    yu = yl + cut_h
    xu = xl + cut_w
    return yl, yu, xl, xu


def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
    """ Generate bbox and apply lambda correction.
    """
    if ratio_minmax is not None:
        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
    else:
        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
    if correct_lam or ratio_minmax is not None:
        bbox_area = (yu - yl) * (xu - xl)
        lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
    return (yl, yu, xl, xu), lam

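# A minimal sketch of how the helpers above combine into a batch-level CutMix
# (it mirrors what the Mixup class below does internally; the alpha value and
# tensor shapes are illustrative assumptions, not part of this diff):
#
#     lam = float(np.random.beta(1.0, 1.0))
#     x = torch.randn(8, 3, 224, 224)
#     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(x.shape, lam)
#     x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]  # paste region from reversed batch
#     # with correct_lam=True (the default), lam now reflects the true mixed-pixel ratio
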
class Mixup:
|
91 |
+
""" Mixup/Cutmix that applies different params to each element or whole batch
|
92 |
+
|
93 |
+
Args:
|
94 |
+
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
|
95 |
+
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
|
96 |
+
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
|
97 |
+
prob (float): probability of applying mixup or cutmix per batch or element
|
98 |
+
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
|
99 |
+
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
|
100 |
+
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
|
101 |
+
label_smoothing (float): apply label smoothing to the mixed target tensor
|
102 |
+
num_classes (int): number of classes for target
|
103 |
+
"""
|
104 |
+
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
|
105 |
+
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
|
106 |
+
self.mixup_alpha = mixup_alpha
|
107 |
+
self.cutmix_alpha = cutmix_alpha
|
108 |
+
self.cutmix_minmax = cutmix_minmax
|
109 |
+
if self.cutmix_minmax is not None:
|
110 |
+
assert len(self.cutmix_minmax) == 2
|
111 |
+
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
|
112 |
+
self.cutmix_alpha = 1.0
|
113 |
+
self.mix_prob = prob
|
114 |
+
self.switch_prob = switch_prob
|
115 |
+
self.label_smoothing = label_smoothing
|
116 |
+
self.num_classes = num_classes
|
117 |
+
self.mode = mode
|
118 |
+
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
|
119 |
+
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
|
120 |
+
|
+    def _params_per_elem(self, batch_size):
+        lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=np.bool)
+        if self.mixup_enabled:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=np.bool)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+        return lam, use_cutmix
+
+    def _params_per_batch(self):
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.mix_prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = float(lam_mix)
+        return lam, use_cutmix
+
+    def _mix_elem(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+    def _mix_pair(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+    def _mix_batch(self, x):
+        lam, use_cutmix = self._params_per_batch()
+        if lam == 1.:
+            return 1.
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+            x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+        else:
+            x_flipped = x.flip(0).mul_(1. - lam)
+            x.mul_(lam).add_(x_flipped)
+        return lam
+
+    def __call__(self, x, target):
+        assert len(x) % 2 == 0, 'Batch size should be even when using this'
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
+        return x, target
+
+
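A minimal training-loop sketch of how `Mixup` is typically wired up; `model` and `loader` are placeholders, and the pairing with `SoftTargetCrossEntropy` from `timm.loss` is the usual (but not mandatory) choice since the targets become soft distributions:

    from timm.data import Mixup
    from timm.loss import SoftTargetCrossEntropy

    mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0,
                     switch_prob=0.5, mode='batch', label_smoothing=0.1, num_classes=1000)
    criterion = SoftTargetCrossEntropy()   # targets are now soft/mixed distributions

    for inputs, targets in loader:         # assumed: a loader with an even batch size
        inputs, targets = mixup_fn(inputs, targets)
        loss = criterion(model(inputs), targets)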
+class FastCollateMixup(Mixup):
+    """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
+
+    A Mixup impl that's performed while collating the batches.
+    """
+
+    def _mix_elem_collate(self, output, batch, half=False):
+        batch_size = len(batch)
+        num_elem = batch_size // 2 if half else batch_size
+        assert len(output) == num_elem
+        lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        for i in range(num_elem):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    if not half:
+                        mixed = mixed.copy()
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.rint(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        if half:
+            lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_pair_collate(self, output, batch):
+        batch_size = len(batch)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            mixed_i = batch[i][0]
+            mixed_j = batch[j][0]
+            assert 0 <= lam <= 1.0
+            if lam < 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                    mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
+                    mixed_j[:, yl:yh, xl:xh] = patch_i
+                    lam_batch[i] = lam
+                else:
+                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                    mixed_i = mixed_temp
+                    np.rint(mixed_j, out=mixed_j)
+                    np.rint(mixed_i, out=mixed_i)
+            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return torch.tensor(lam_batch).unsqueeze(1)
+
+    def _mix_batch_collate(self, output, batch):
+        batch_size = len(batch)
+        lam, use_cutmix = self._params_per_batch()
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            mixed = batch[i][0]
+            if lam != 1.:
+                if use_cutmix:
+                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+                else:
+                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    np.rint(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return lam
+
+    def __call__(self, batch, _=None):
+        batch_size = len(batch)
+        assert batch_size % 2 == 0, 'Batch size should be even when using this'
+        half = 'half' in self.mode
+        if half:
+            batch_size //= 2
+        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        if self.mode == 'elem' or self.mode == 'half':
+            lam = self._mix_elem_collate(output, batch, half=half)
+        elif self.mode == 'pair':
+            lam = self._mix_pair_collate(output, batch)
+        else:
+            lam = self._mix_batch_collate(output, batch)
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+        target = target[:batch_size]
+        return output, target
+
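`FastCollateMixup` is meant to be passed as a DataLoader `collate_fn`, mixing the uint8 samples before they are transferred to GPU; timm's own `create_loader` does this wiring internally. A rough sketch (the `dataset` yielding `(uint8 CHW numpy array, int label)` pairs is an assumption here):

    import torch.utils.data
    from timm.data import FastCollateMixup

    collate = FastCollateMixup(mixup_alpha=0.2, cutmix_alpha=1.0, num_classes=1000)
    loader = torch.utils.data.DataLoader(
        dataset,             # assumed: yields (uint8 CHW numpy array, int label)
        batch_size=128,      # must be even for the pairwise flip mixing
        collate_fn=collate)
    # each batch is then (uint8 tensor [B, C, H, W], soft target tensor [B, num_classes])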
lib/timm/data/parsers/__init__.py
ADDED
@@ -0,0 +1 @@
+from .parser_factory import create_parser
lib/timm/data/parsers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (212 Bytes)
lib/timm/data/parsers/__pycache__/class_map.cpython-310.pyc
ADDED
Binary file (942 Bytes)
lib/timm/data/parsers/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (199 Bytes)
lib/timm/data/parsers/__pycache__/parser.cpython-310.pyc
ADDED
Binary file (1.11 kB)
lib/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc
ADDED
Binary file (877 Bytes)
lib/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc
ADDED
Binary file (3.06 kB)
lib/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc
ADDED
Binary file (7.68 kB)
lib/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc
ADDED
Binary file (3.14 kB)
lib/timm/data/parsers/__pycache__/parser_tfds.cpython-310.pyc
ADDED
Binary file (10.1 kB)
lib/timm/data/parsers/class_map.py
ADDED
@@ -0,0 +1,19 @@
+import os
+
+
+def load_class_map(map_or_filename, root=''):
+    if isinstance(map_or_filename, dict):
+        assert map_or_filename, 'class_map dict must be non-empty'
+        return map_or_filename
+    class_map_path = map_or_filename
+    if not os.path.exists(class_map_path):
+        class_map_path = os.path.join(root, class_map_path)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
+    class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
+    if class_map_ext == '.txt':
+        with open(class_map_path) as f:
+            class_to_idx = {v.strip(): k for k, v in enumerate(f)}
+    else:
+        assert False, f'Unsupported class map file extension ({class_map_ext}).'
+    return class_to_idx
+
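For a `.txt` class map, each line holds one class name and line order defines the index. A hypothetical example (file name and contents are illustrative):

    # class_map.txt, one class name per line:
    #   n01440764
    #   n01443537
    #   n01484850
    class_to_idx = load_class_map('class_map.txt', root='/data/imagenet')
    # -> {'n01440764': 0, 'n01443537': 1, 'n01484850': 2}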
lib/timm/data/parsers/constants.py
ADDED
@@ -0,0 +1 @@
+IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
lib/timm/data/parsers/parser.py
ADDED
@@ -0,0 +1,17 @@
+from abc import abstractmethod
+
+
+class Parser:
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def _filename(self, index, basename=False, absolute=False):
+        pass
+
+    def filename(self, index, basename=False, absolute=False):
+        return self._filename(index, basename=basename, absolute=absolute)
+
+    def filenames(self, basename=False, absolute=False):
+        return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
+
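`Parser` is a minimal base: subclasses supply indexing (`__getitem__`/`__len__`, returning a file-like object plus target) and `_filename`. A toy in-memory subclass, purely hypothetical, to show the contract:

    import io

    class ListParser(Parser):
        """ Hypothetical parser over an in-memory list of (bytes, target, name) tuples. """
        def __init__(self, items):
            super().__init__()
            self.items = items

        def __getitem__(self, index):
            data, target, _ = self.items[index]
            return io.BytesIO(data), target   # file-like object + integer target

        def __len__(self):
            return len(self.items)

        def _filename(self, index, basename=False, absolute=False):
            return self.items[index][2]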
lib/timm/data/parsers/parser_factory.py
ADDED
@@ -0,0 +1,29 @@
+import os
+
+from .parser_image_folder import ParserImageFolder
+from .parser_image_tar import ParserImageTar
+from .parser_image_in_tar import ParserImageInTar
+
+
+def create_parser(name, root, split='train', **kwargs):
+    name = name.lower()
+    name = name.split('/', 2)
+    prefix = ''
+    if len(name) > 1:
+        prefix = name[0]
+    name = name[-1]
+
+    # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
+    # explicitly select other options shortly
+    if prefix == 'tfds':
+        from .parser_tfds import ParserTfds  # defer tensorflow import
+        parser = ParserTfds(root, name, split=split, **kwargs)
+    else:
+        assert os.path.exists(root)
+        # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
+        # FIXME support split here, in parser?
+        if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+            parser = ParserImageInTar(root, **kwargs)
+        else:
+            parser = ParserImageFolder(root, **kwargs)
+    return parser
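Illustrative calls covering the three current selection paths (all directory and tar paths below are placeholders):

    parser = create_parser('', root='/data/imagenet/train')        # folder of class subdirs
    parser = create_parser('', root='/data/imagenet_train.tar')    # single .tar file
    parser = create_parser('tfds/imagenet2012', root='~/tensorflow_datasets',
                           split='train')                          # TFDS-backed dataset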
lib/timm/data/parsers/parser_image_folder.py
ADDED
@@ -0,0 +1,69 @@
+""" A dataset parser that reads images from folders
+
+Folders are scanned recursively to find image files. Labels are based
+on the folder hierarchy, just leaf folders by default.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
+    labels = []
+    filenames = []
+    for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
+        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+        for f in files:
+            base, ext = os.path.splitext(f)
+            if ext.lower() in types:
+                filenames.append(os.path.join(root, f))
+                labels.append(label)
+    if class_to_idx is None:
+        # building class index
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+    if sort:
+        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
+    return images_and_targets, class_to_idx
+
+
+class ParserImageFolder(Parser):
+
+    def __init__(
+            self,
+            root,
+            class_map=''):
+        super().__init__()
+
+        self.root = root
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+        if len(self.samples) == 0:
+            raise RuntimeError(
+                f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
+
+    def __getitem__(self, index):
+        path, target = self.samples[index]
+        return open(path, 'rb'), target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0]
+        if basename:
+            filename = os.path.basename(filename)
+        elif not absolute:
+            filename = os.path.relpath(filename, self.root)
+        return filename
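Given a hypothetical layout like `/data/train/cat/001.jpg` and `/data/train/dog/002.jpg`, the parser labels by leaf folder:

    parser = ParserImageFolder('/data/train')
    print(parser.class_to_idx)       # e.g. {'cat': 0, 'dog': 1}
    img_file, target = parser[0]     # open binary file handle + integer target
    print(parser.filename(0))        # path relative to root, e.g. 'cat/001.jpg'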
lib/timm/data/parsers/parser_image_in_tar.py
ADDED
@@ -0,0 +1,222 @@
+""" A dataset parser that reads tarfile based datasets
+
+This parser can read and extract image samples from:
+* a single tar of image files
+* a folder of multiple tarfiles containing imagefiles
+* a tar of tars containing image files
+
+Labels are based on the combined folder and/or tar name structure.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import tarfile
+import pickle
+import logging
+import numpy as np
+from glob import glob
+from typing import List, Dict
+
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+_logger = logging.getLogger(__name__)
+CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
+
+
+class TarState:
+
+    def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
+        self.tf: tarfile.TarFile = tf
+        self.ti: tarfile.TarInfo = ti
+        self.children: Dict[str, TarState] = {}  # child states (tars within tars)
+
+    def reset(self):
+        self.tf = None
+
+
+def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
+    sample_count = 0
+    for i, ti in enumerate(tf):
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        name, ext = os.path.splitext(basename)
+        ext = ext.lower()
+        if ext == '.tar':
+            with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
+                child_info = dict(
+                    name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
+                sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
+                _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
+            parent_info['children'].append(child_info)
+        elif ext in extensions:
+            parent_info['samples'].append(ti)
+            sample_count += 1
+    return sample_count
+
+
+def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
+    root_is_tar = False
+    if os.path.isfile(root):
+        assert os.path.splitext(root)[-1].lower() == '.tar'
+        tar_filenames = [root]
+        root, root_name = os.path.split(root)
+        root_name = os.path.splitext(root_name)[0]
+        root_is_tar = True
+    else:
+        root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
+        tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
+    num_tars = len(tar_filenames)
+    tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
+    assert num_tars, f'No .tar files found at specified path ({root}).'
+
+    _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
+    info = dict(tartrees=[])
+    cache_path = ''
+    if cache_tarinfo is None:
+        cache_tarinfo = True if tar_bytes > 10*1024**3 else False  # FIXME magic number, 10GB
+    if cache_tarinfo:
+        cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
+        cache_path = os.path.join(root, cache_filename)
+    if os.path.exists(cache_path):
+        _logger.info(f'Reading tar info from cache file {cache_path}.')
+        with open(cache_path, 'rb') as pf:
+            info = pickle.load(pf)
+        assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
+    else:
+        for i, fn in enumerate(tar_filenames):
+            path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
+            with tarfile.open(fn, mode='r|') as tf:  # tarinfo scans done in streaming mode
+                parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
+                num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
+                num_children = len(parent_info["children"])
+                _logger.debug(
+                    f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
+            info['tartrees'].append(parent_info)
+        if cache_path:
+            _logger.info(f'Writing tar info to cache file {cache_path}.')
+            with open(cache_path, 'wb') as pf:
+                pickle.dump(info, pf)
+
+    samples = []
+    labels = []
+    build_class_map = False
+    if class_name_to_idx is None:
+        build_class_map = True
+
+    # Flatten tartree info into lists of samples and targets w/ targets based on label id via
+    # class map arg or from unique paths.
+    # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
+    # this covers my current use cases and keeps things a little easier to test for now.
+    tarfiles = []
+
+    def _label_from_paths(*path, leaf_only=True):
+        path = os.path.join(*path).strip(os.path.sep)
+        return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
+
+    def _add_samples(info, fn):
+        added = 0
+        for s in info['samples']:
+            label = _label_from_paths(info['path'], os.path.dirname(s.path))
+            if not build_class_map and label not in class_name_to_idx:
+                continue
+            samples.append((s, fn, info['ti']))
+            labels.append(label)
+            added += 1
+        return added
+
+    _logger.info(f'Collecting samples and building tar states.')
+    for parent_info in info['tartrees']:
+        # if tartree has children, we assume all samples are at the child level
+        tar_name = None if root_is_tar else parent_info['name']
+        tar_state = TarState()
+        parent_added = 0
+        for child_info in parent_info['children']:
+            child_added = _add_samples(child_info, fn=tar_name)
+            if child_added:
+                tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
+            parent_added += child_added
+        parent_added += _add_samples(parent_info, fn=tar_name)
+        if parent_added:
+            tarfiles.append((tar_name, tar_state))
+    del info
+
+    if build_class_map:
+        # build class index
+        sorted_labels = list(sorted(set(labels), key=natural_key))
+        class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+
+    _logger.info(f'Mapping targets and sorting samples.')
+    samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
+    if sort:
+        samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
+    samples, targets = zip(*samples_and_targets)
+    samples = np.array(samples)
+    targets = np.array(targets)
+    _logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
+    return samples, targets, class_name_to_idx, tarfiles
+
+
+class ParserImageInTar(Parser):
+    """ Multi-tarfile dataset parser where there is one .tar file per class
+    """
+
+    def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
+        super().__init__()
+
+        class_name_to_idx = None
+        if class_map:
+            class_name_to_idx = load_class_map(class_map, root)
+        self.root = root
+        self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
+            self.root,
+            class_name_to_idx=class_name_to_idx,
+            cache_tarinfo=cache_tarinfo,
+            extensions=IMG_EXTENSIONS)
+        self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
+        if len(tarfiles) == 1 and tarfiles[0][0] is None:
+            self.root_is_tar = True
+            self.tar_state = tarfiles[0][1]
+        else:
+            self.root_is_tar = False
+            self.tar_state = dict(tarfiles)
+        self.cache_tarfiles = cache_tarfiles
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        sample = self.samples[index]
+        target = self.targets[index]
+        sample_ti, parent_fn, child_ti = sample
+        parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
+
+        tf = None
+        cache_state = None
+        if self.cache_tarfiles:
+            cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
+            tf = cache_state.tf
+        if tf is None:
+            tf = tarfile.open(parent_abs)
+            if self.cache_tarfiles:
+                cache_state.tf = tf
+        if child_ti is not None:
+            ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
+            if ctf is None:
+                ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
+                if self.cache_tarfiles:
+                    cache_state.children[child_ti.name].tf = ctf
+            tf = ctf
+
+        return tf.extractfile(sample_ti), target
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0].name
+        if basename:
+            filename = os.path.basename(filename)
+        return filename
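A usage sketch, assuming a folder of per-class tars such as `/data/imagenet/n01440764.tar`, `/data/imagenet/n01443537.tar`, and so on (paths hypothetical):

    parser = ParserImageInTar('/data/imagenet', cache_tarfiles=True)
    img_file, target = parser[0]              # extracted file object + class index
    cls = parser.class_idx_to_name[target]
    # With cache_tarfiles=True, each worker keeps its TarFile handles open in
    # TarState objects, so repeated reads avoid reopening the tars.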
lib/timm/data/parsers/parser_image_tar.py
ADDED
@@ -0,0 +1,72 @@
+""" A dataset parser that reads single tarfile based datasets
+
+This parser can read datasets consisting of a single tarfile containing images.
+I am planning to deprecate it in favour of ParserImageInTar.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import tarfile
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+from timm.utils.misc import natural_key
+
+
+def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
+    files = []
+    labels = []
+    for ti in tarfile.getmembers():
+        if not ti.isfile():
+            continue
+        dirname, basename = os.path.split(ti.path)
+        label = os.path.basename(dirname)
+        ext = os.path.splitext(basename)[1]
+        if ext.lower() in IMG_EXTENSIONS:
+            files.append(ti)
+            labels.append(label)
+    if class_to_idx is None:
+        unique_labels = set(labels)
+        sorted_labels = list(sorted(unique_labels, key=natural_key))
+        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
+    if sort:
+        tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+    return tarinfo_and_targets, class_to_idx
+
+
+class ParserImageTar(Parser):
+    """ Single tarfile dataset where classes are mapped to folders within tar
+    NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+    operate on folders of tars or tars in tars.
+    """
+    def __init__(self, root, class_map=''):
+        super().__init__()
+
+        class_to_idx = None
+        if class_map:
+            class_to_idx = load_class_map(class_map, root)
+        assert os.path.isfile(root)
+        self.root = root
+
+        with tarfile.open(root) as tf:  # cannot keep this open across processes, reopen later
+            self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
+        self.imgs = self.samples
+        self.tarfile = None  # lazy init in __getitem__
+
+    def __getitem__(self, index):
+        if self.tarfile is None:
+            self.tarfile = tarfile.open(self.root)
+        tarinfo, target = self.samples[index]
+        fileobj = self.tarfile.extractfile(tarinfo)
+        return fileobj, target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def _filename(self, index, basename=False, absolute=False):
+        filename = self.samples[index][0].name
+        if basename:
+            filename = os.path.basename(filename)
+        return filename
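For completeness, a sketch of this deprecated single-tar path, assuming the tar holds `class_dir/image.jpg` entries (path hypothetical):

    parser = ParserImageTar('/data/train.tar')
    img_file, target = parser[0]   # the tarfile handle is opened lazily on first access
    print(parser.filename(0))      # member name within the tar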