Roll20 commited on
Commit
3c859e4
·
1 Parent(s): d022c6c

add lib/timm

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. lib/timm-0.5.4.dist-info/INSTALLER +1 -0
  2. lib/timm-0.5.4.dist-info/LICENSE +201 -0
  3. lib/timm-0.5.4.dist-info/METADATA +503 -0
  4. lib/timm-0.5.4.dist-info/RECORD +364 -0
  5. lib/timm-0.5.4.dist-info/REQUESTED +0 -0
  6. lib/timm-0.5.4.dist-info/WHEEL +5 -0
  7. lib/timm-0.5.4.dist-info/top_level.txt +1 -0
  8. lib/timm/__init__.py +4 -0
  9. lib/timm/__pycache__/__init__.cpython-310.pyc +0 -0
  10. lib/timm/__pycache__/version.cpython-310.pyc +0 -0
  11. lib/timm/data/__init__.py +12 -0
  12. lib/timm/data/__pycache__/__init__.cpython-310.pyc +0 -0
  13. lib/timm/data/__pycache__/auto_augment.cpython-310.pyc +0 -0
  14. lib/timm/data/__pycache__/config.cpython-310.pyc +0 -0
  15. lib/timm/data/__pycache__/constants.cpython-310.pyc +0 -0
  16. lib/timm/data/__pycache__/dataset.cpython-310.pyc +0 -0
  17. lib/timm/data/__pycache__/dataset_factory.cpython-310.pyc +0 -0
  18. lib/timm/data/__pycache__/distributed_sampler.cpython-310.pyc +0 -0
  19. lib/timm/data/__pycache__/loader.cpython-310.pyc +0 -0
  20. lib/timm/data/__pycache__/mixup.cpython-310.pyc +0 -0
  21. lib/timm/data/__pycache__/random_erasing.cpython-310.pyc +0 -0
  22. lib/timm/data/__pycache__/real_labels.cpython-310.pyc +0 -0
  23. lib/timm/data/__pycache__/tf_preprocessing.cpython-310.pyc +0 -0
  24. lib/timm/data/__pycache__/transforms.cpython-310.pyc +0 -0
  25. lib/timm/data/__pycache__/transforms_factory.cpython-310.pyc +0 -0
  26. lib/timm/data/auto_augment.py +865 -0
  27. lib/timm/data/config.py +78 -0
  28. lib/timm/data/constants.py +7 -0
  29. lib/timm/data/dataset.py +152 -0
  30. lib/timm/data/dataset_factory.py +143 -0
  31. lib/timm/data/distributed_sampler.py +129 -0
  32. lib/timm/data/loader.py +289 -0
  33. lib/timm/data/mixup.py +316 -0
  34. lib/timm/data/parsers/__init__.py +1 -0
  35. lib/timm/data/parsers/__pycache__/__init__.cpython-310.pyc +0 -0
  36. lib/timm/data/parsers/__pycache__/class_map.cpython-310.pyc +0 -0
  37. lib/timm/data/parsers/__pycache__/constants.cpython-310.pyc +0 -0
  38. lib/timm/data/parsers/__pycache__/parser.cpython-310.pyc +0 -0
  39. lib/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc +0 -0
  40. lib/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc +0 -0
  41. lib/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc +0 -0
  42. lib/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc +0 -0
  43. lib/timm/data/parsers/__pycache__/parser_tfds.cpython-310.pyc +0 -0
  44. lib/timm/data/parsers/class_map.py +19 -0
  45. lib/timm/data/parsers/constants.py +1 -0
  46. lib/timm/data/parsers/parser.py +17 -0
  47. lib/timm/data/parsers/parser_factory.py +29 -0
  48. lib/timm/data/parsers/parser_image_folder.py +69 -0
  49. lib/timm/data/parsers/parser_image_in_tar.py +222 -0
  50. lib/timm/data/parsers/parser_image_tar.py +72 -0
lib/timm-0.5.4.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
lib/timm-0.5.4.dist-info/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2019 Ross Wightman
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
lib/timm-0.5.4.dist-info/METADATA ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: timm
3
+ Version: 0.5.4
4
+ Summary: (Unofficial) PyTorch Image Models
5
+ Home-page: https://github.com/rwightman/pytorch-image-models
6
+ Author: Ross Wightman
7
+ Author-email: hello@rwightman.com
8
+ License: UNKNOWN
9
+ Keywords: pytorch pretrained models efficientnet mobilenetv3 mnasnet
10
+ Platform: UNKNOWN
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: Apache Software License
15
+ Classifier: Programming Language :: Python :: 3.6
16
+ Classifier: Programming Language :: Python :: 3.7
17
+ Classifier: Programming Language :: Python :: 3.8
18
+ Classifier: Topic :: Scientific/Engineering
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Classifier: Topic :: Software Development
21
+ Classifier: Topic :: Software Development :: Libraries
22
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
+ Requires-Python: >=3.6
24
+ Description-Content-Type: text/markdown
25
+ License-File: LICENSE
26
+ Requires-Dist: torch (>=1.4)
27
+ Requires-Dist: torchvision
28
+
29
+ # PyTorch Image Models
30
+ - [Sponsors](#sponsors)
31
+ - [What's New](#whats-new)
32
+ - [Introduction](#introduction)
33
+ - [Models](#models)
34
+ - [Features](#features)
35
+ - [Results](#results)
36
+ - [Getting Started (Documentation)](#getting-started-documentation)
37
+ - [Train, Validation, Inference Scripts](#train-validation-inference-scripts)
38
+ - [Awesome PyTorch Resources](#awesome-pytorch-resources)
39
+ - [Licenses](#licenses)
40
+ - [Citing](#citing)
41
+
42
+ ## Sponsors
43
+
44
+ A big thank you to my [GitHub Sponsors](https://github.com/sponsors/rwightman) for their support!
45
+
46
+ In addition to the sponsors at the link above, I've received hardware and/or cloud resources from
47
+ * Nvidia (https://www.nvidia.com/en-us/)
48
+ * TFRC (https://www.tensorflow.org/tfrc)
49
+
50
+ I'm fortunate to be able to dedicate significant time and money of my own supporting this and other open source projects. However, as the projects increase in scope, outside support is needed to continue with the current trajectory of cloud services, hardware, and electricity costs.
51
+
52
+ ## What's New
53
+
54
+ ### Jan 14, 2022
55
+ * Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
56
+ * Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
57
+ * Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
58
+ * `mnasnet_small` - 65.6 top-1
59
+ * `mobilenetv2_050` - 65.9
60
+ * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
61
+ * `semnasnet_075` - 73
62
+ * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
63
+ * TinyNet models added by [rsomani95](https://github.com/rsomani95)
64
+ * LCNet added via MobileNetV3 architecture
65
+
66
+ ### Nov 22, 2021
67
+ * A number of updated weights anew new model defs
68
+ * `eca_halonext26ts` - 79.5 @ 256
69
+ * `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
70
+ * `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
71
+ * `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
72
+ * `sebotnet33ts_256` (new) - 81.2 @ 224
73
+ * `lamhalobotnet50ts_256` - 81.5 @ 256
74
+ * `halonet50ts` - 81.7 @ 256
75
+ * `halo2botnet50ts_256` - 82.0 @ 256
76
+ * `resnet101` - 82.0 @ 224, 82.8 @ 288
77
+ * `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
78
+ * `resnet152` - 82.8 @ 224, 83.5 @ 288
79
+ * `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
80
+ * `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
81
+ * `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
82
+ * Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
83
+ * models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
84
+
85
+ ### Oct 19, 2021
86
+ * ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
87
+ * BCE loss and Repeated Augmentation support for RSB paper
88
+ * 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
89
+ * Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
90
+ * Halo (https://arxiv.org/abs/2103.12731)
91
+ * Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
92
+ * LambdaNetworks (https://arxiv.org/abs/2102.08602)
93
+ * A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
94
+ * ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
95
+ * freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
96
+
97
+ ### Aug 18, 2021
98
+ * Optimizer bonanza!
99
+ * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
100
+ * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
101
+ * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
102
+ * SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
103
+ * EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
104
+ * Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
105
+
106
+ ### July 12, 2021
107
+ * Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)
108
+
109
+ ### July 5-9, 2021
110
+ * Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res)
111
+ * top-1 82.34 @ 288x288 and 82.54 @ 320x320
112
+ * Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
113
+ * Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
114
+ * `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426
115
+
116
+ ### June 23, 2021
117
+ * Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
118
+
119
+ ### June 20, 2021
120
+ * Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
121
+ * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
122
+ * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the augreg weights
123
+ * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
124
+ * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
125
+ * `vit_deit_*` renamed to just `deit_*`
126
+ * Remove my old small model, replace with DeiT compatible small w/ AugReg weights
127
+ * Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params.
128
+ * Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
129
+ * Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
130
+ * Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
131
+ * NFNets and ResNetV2-BiT models work w/ Pytorch XLA now
132
+ * weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
133
+ * eps values adjusted, will be slight differences but should be quite close
134
+ * Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
135
+ * Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
136
+ * Please report any regressions, this PR touched quite a few models.
137
+
138
+ ### June 8, 2021
139
+ * Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
140
+ * Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
141
+ * NFNet inspired block layout with quad layer stem and no maxpool
142
+ * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288
143
+
144
+ ### May 25, 2021
145
+ * Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models
146
+ * Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl
147
+ * Fix a number of torchscript issues with various vision transformer models
148
+ * Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models
149
+ * More flexible pos embedding resize (non-square) for ViT and TnT. Thanks [Alexander Soare](https://github.com/alexander-soare)
150
+ * Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
151
+
152
+ ### May 14, 2021
153
+ * Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
154
+ * 1k trained variants: `tf_efficientnetv2_s/m/l`
155
+ * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
156
+ * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
157
+ * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
158
+ * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
159
+ * Some blank `efficientnetv2_*` models in-place for future native PyTorch training
160
+
161
+ ### May 5, 2021
162
+ * Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
163
+ * Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
164
+ * Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
165
+ * Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
166
+ * Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
167
+ * Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
168
+ * Update ByoaNet attention modules
169
+ * Improve SA module inits
170
+ * Hack together experimental stand-alone Swin based attn module and `swinnet`
171
+ * Consistent '26t' model defs for experiments.
172
+ * Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1.
173
+ * WandB logging support
174
+
175
+ ### April 13, 2021
176
+ * Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer
177
+
178
+ ### April 12, 2021
179
+ * Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256.
180
+ * Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training.
181
+ * Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs
182
+ * Lambda Networks - https://arxiv.org/abs/2102.08602
183
+ * Bottleneck Transformers - https://arxiv.org/abs/2101.11605
184
+ * Halo Nets - https://arxiv.org/abs/2103.12731
185
+ * Adabelief optimizer contributed by Juntang Zhuang
186
+
187
+ ### April 1, 2021
188
+ * Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
189
+ * Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
190
+ * Merged distilled variant into main for torchscript compatibility
191
+ * Some `timm` cleanup/style tweaks and weights have hub download support
192
+ * Cleanup Vision Transformer (ViT) models
193
+ * Merge distilled (DeiT) model into main so that torchscript can work
194
+ * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
195
+ * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
196
+ * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants
197
+ * nn.Sequential for block stack (does not break downstream compat)
198
+ * TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
199
+ * Add RegNetY-160 weights from DeiT teacher model
200
+ * Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288
201
+ * Some fixes/improvements for TFDS dataset wrapper
202
+
203
+ ### March 17, 2021
204
+ * Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself.
205
+ * 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
206
+ * Uses SiLU activation, approx 2x faster than `dm_nfnet_f0` and 50% faster than `nfnet_f0s` w/ 1/3 param count
207
+ * Integrate [Hugging Face model hub](https://huggingface.co/models) into timm create_model and default_cfg handling for pretrained weight and config sharing (more on this soon!)
208
+ * Merge HardCoRe NAS models contributed by https://github.com/yoniaflalo
209
+ * Merge PyTorch trained EfficientNet-EL and pruned ES/EL variants contributed by [DeGirum](https://github.com/DeGirum)
210
+
211
+
212
+ ### March 7, 2021
213
+ * First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
214
+ * Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
215
+ * Tested with PyTorch 1.8 release. Updated CI to use 1.8.
216
+ * Benchmarked several arch on RTX 3090, Titan RTX, and V100 across 1.7.1, 1.8, NGC 20.12, and 21.02. Some interesting performance variations to take note of https://gist.github.com/rwightman/bb59f9e245162cee0e38bd66bd8cd77f
217
+
218
+ ### Feb 18, 2021
219
+ * Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
220
+ * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
221
+ * These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
222
+ * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
223
+ * Matching the original pre-processing as closely as possible I get these results:
224
+ * `dm_nfnet_f6` - 86.352
225
+ * `dm_nfnet_f5` - 86.100
226
+ * `dm_nfnet_f4` - 85.834
227
+ * `dm_nfnet_f3` - 85.676
228
+ * `dm_nfnet_f2` - 85.178
229
+ * `dm_nfnet_f1` - 84.696
230
+ * `dm_nfnet_f0` - 83.464
231
+
232
+ ### Feb 16, 2021
233
+ * Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
234
+ * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
235
+ * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
236
+ * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
237
+ * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
238
+
239
+ ### Feb 12, 2021
240
+ * Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
241
+
242
+ ### Feb 10, 2021
243
+ * First Normalization-Free model training experiments done,
244
+ * nf_resnet50 - 80.68 top-1 @ 288x288, 80.31 @ 256x256
245
+ * nf_regnet_b1 - 79.30 @ 288x288, 78.75 @ 256x256
246
+ * More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
247
+ * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
248
+ * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
249
+ * classic VGG (from torchvision, impl in `vgg.py`)
250
+ * Refinements to normalizer layer arg handling and normalizer+act layer handling in some models
251
+ * Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not.
252
+ * Fix a few bugs introduced since last pypi release
253
+
254
+ ### Feb 8, 2021
255
+ * Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352.
256
+ * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
257
+ * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256
258
+ * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320
259
+ * Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed).
260
+ * Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test.
261
+
262
+ ### Jan 30, 2021
263
+ * Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)
264
+
265
+ ### Jan 25, 2021
266
+ * Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
267
+ * Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
268
+ * ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
269
+ * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
270
+ * Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
271
+ * Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
272
+ * Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
273
+ * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
274
+ * Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
275
+ * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
276
+ * Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling
277
+
278
+ ### Jan 3, 2021
279
+ * Add SE-ResNet-152D weights
280
+ * 256x256 val, 0.94 crop top-1 - 83.75
281
+ * 320x320 val, 1.0 crop - 84.36
282
+ * Update [results files](results/)
283
+
284
+
285
+ ## Introduction
286
+
287
+ Py**T**orch **Im**age **M**odels (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.
288
+
289
+ The work of many others is present here. I've tried to make sure all source material is acknowledged via links to github, arxiv papers, etc in the README, documentation, and code docstrings. Please let me know if I missed anything.
290
+
291
+ ## Models
292
+
293
+ All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started.
294
+
295
+ A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
296
+
297
+ * Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
298
+ * BEiT - https://arxiv.org/abs/2106.08254
299
+ * Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
300
+ * Bottleneck Transformers - https://arxiv.org/abs/2101.11605
301
+ * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
302
+ * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
303
+ * ConvNeXt - https://arxiv.org/abs/2201.03545
304
+ * ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
305
+ * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
306
+ * DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
307
+ * DenseNet - https://arxiv.org/abs/1608.06993
308
+ * DLA - https://arxiv.org/abs/1707.06484
309
+ * DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629
310
+ * EfficientNet (MBConvNet Family)
311
+ * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252
312
+ * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665
313
+ * EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946
314
+ * EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
315
+ * EfficientNet V2 - https://arxiv.org/abs/2104.00298
316
+ * FBNet-C - https://arxiv.org/abs/1812.03443
317
+ * MixNet - https://arxiv.org/abs/1907.09595
318
+ * MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626
319
+ * MobileNet-V2 - https://arxiv.org/abs/1801.04381
320
+ * Single-Path NAS - https://arxiv.org/abs/1904.02877
321
+ * TinyNet - https://arxiv.org/abs/2010.14819
322
+ * GhostNet - https://arxiv.org/abs/1911.11907
323
+ * gMLP - https://arxiv.org/abs/2105.08050
324
+ * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
325
+ * Halo Nets - https://arxiv.org/abs/2103.12731
326
+ * HRNet - https://arxiv.org/abs/1908.07919
327
+ * Inception-V3 - https://arxiv.org/abs/1512.00567
328
+ * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
329
+ * Lambda Networks - https://arxiv.org/abs/2102.08602
330
+ * LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
331
+ * MLP-Mixer - https://arxiv.org/abs/2105.01601
332
+ * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
333
+ * FBNet-V3 - https://arxiv.org/abs/2006.02049
334
+ * HardCoRe-NAS - https://arxiv.org/abs/2102.11646
335
+ * LCNet - https://arxiv.org/abs/2109.15099
336
+ * NASNet-A - https://arxiv.org/abs/1707.07012
337
+ * NesT - https://arxiv.org/abs/2105.12723
338
+ * NFNet-F - https://arxiv.org/abs/2102.06171
339
+ * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
340
+ * PNasNet - https://arxiv.org/abs/1712.00559
341
+ * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
342
+ * RegNet - https://arxiv.org/abs/2003.13678
343
+ * RepVGG - https://arxiv.org/abs/2101.03697
344
+ * ResMLP - https://arxiv.org/abs/2105.03404
345
+ * ResNet/ResNeXt
346
+ * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
347
+ * ResNeXt - https://arxiv.org/abs/1611.05431
348
+ * 'Bag of Tricks' / Gluon C, D, E, S variations - https://arxiv.org/abs/1812.01187
349
+ * Weakly-supervised (WSL) Instagram pretrained / ImageNet tuned ResNeXt101 - https://arxiv.org/abs/1805.00932
350
+ * Semi-supervised (SSL) / Semi-weakly Supervised (SWSL) ResNet/ResNeXts - https://arxiv.org/abs/1905.00546
351
+ * ECA-Net (ECAResNet) - https://arxiv.org/abs/1910.03151v4
352
+ * Squeeze-and-Excitation Networks (SEResNet) - https://arxiv.org/abs/1709.01507
353
+ * ResNet-RS - https://arxiv.org/abs/2103.07579
354
+ * Res2Net - https://arxiv.org/abs/1904.01169
355
+ * ResNeSt - https://arxiv.org/abs/2004.08955
356
+ * ReXNet - https://arxiv.org/abs/2007.00992
357
+ * SelecSLS - https://arxiv.org/abs/1907.00837
358
+ * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
359
+ * Swin Transformer - https://arxiv.org/abs/2103.14030
360
+ * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
361
+ * TResNet - https://arxiv.org/abs/2003.13630
362
+ * Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
363
+ * Visformer - https://arxiv.org/abs/2104.12533
364
+ * Vision Transformer - https://arxiv.org/abs/2010.11929
365
+ * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
366
+ * Xception - https://arxiv.org/abs/1610.02357
367
+ * Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
368
+ * Xception (Modified Aligned, TF) - https://arxiv.org/abs/1802.02611
369
+ * XCiT (Cross-Covariance Image Transformers) - https://arxiv.org/abs/2106.09681
370
+
371
+ ## Features
372
+
373
+ Several (less common) features that I often utilize in my projects are included. Many of their additions are the reason why I maintain my own set of models, instead of using others' via PIP:
374
+
375
+ * All models have a common default configuration interface and API for
376
+ * accessing/changing the classifier - `get_classifier` and `reset_classifier`
377
+ * doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
378
+ * these makes it easy to write consistent network wrappers that work with any of the models
379
+ * All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
380
+ * `create_model(name, features_only=True, out_indices=..., output_stride=...)`
381
+ * `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
382
+ * `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
383
+ * feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
384
+ * All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
385
+ * High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes:
386
+ * NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
387
+ * PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
388
+ * PyTorch w/ single GPU single process (AMP optional)
389
+ * A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
390
+ * A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
391
+ * Learning rate schedulers
392
+ * Ideas adopted from
393
+ * [AllenNLP schedulers](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers)
394
+ * [FAIRseq lr_scheduler](https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler)
395
+ * SGDR: Stochastic Gradient Descent with Warm Restarts (https://arxiv.org/abs/1608.03983)
396
+ * Schedulers include `step`, `cosine` w/ restarts, `tanh` w/ restarts, `plateau`
397
+ * Optimizers:
398
+ * `rmsprop_tf` adapted from PyTorch RMSProp by myself. Reproduces much improved Tensorflow RMSProp behaviour.
399
+ * `radam` by [Liyuan Liu](https://github.com/LiyuanLucasLiu/RAdam) (https://arxiv.org/abs/1908.03265)
400
+ * `novograd` by [Masashi Kimura](https://github.com/convergence-lab/novograd) (https://arxiv.org/abs/1905.11286)
401
+ * `lookahead` adapted from impl by [Liam](https://github.com/alphadl/lookahead.pytorch) (https://arxiv.org/abs/1907.08610)
402
+ * `fused<name>` optimizers by name with [NVIDIA Apex](https://github.com/NVIDIA/apex/tree/master/apex/optimizers) installed
403
+ * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) (https://arxiv.org/abs/2006.08217)
404
+ * `adafactor` adapted from [FAIRSeq impl](https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py) (https://arxiv.org/abs/1804.04235)
405
+ * `adahessian` by [David Samuel](https://github.com/davda54/ada-hessian) (https://arxiv.org/abs/2006.00719)
406
+ * Random Erasing from [Zhun Zhong](https://github.com/zhunzhong07/Random-Erasing/blob/master/transforms.py) (https://arxiv.org/abs/1708.04896)
407
+ * Mixup (https://arxiv.org/abs/1710.09412)
408
+ * CutMix (https://arxiv.org/abs/1905.04899)
409
+ * AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
410
+ * AugMix w/ JSD loss (https://arxiv.org/abs/1912.02781), JSD w/ clean + augmented mixing support works with AutoAugment and RandAugment as well
411
+ * SplitBachNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
412
+ * DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
413
+ * DropBlock (https://arxiv.org/abs/1810.12890)
414
+ * Blur Pooling (https://arxiv.org/abs/1904.11486)
415
+ * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
416
+ * Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
417
+ * An extensive selection of channel and/or spatial attention modules:
418
+ * Bottleneck Transformer - https://arxiv.org/abs/2101.11605
419
+ * CBAM - https://arxiv.org/abs/1807.06521
420
+ * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
421
+ * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
422
+ * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
423
+ * Global Context (GC) - https://arxiv.org/abs/1904.11492
424
+ * Halo - https://arxiv.org/abs/2103.12731
425
+ * Involution - https://arxiv.org/abs/2103.06255
426
+ * Lambda Layer - https://arxiv.org/abs/2102.08602
427
+ * Non-Local (NL) - https://arxiv.org/abs/1711.07971
428
+ * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
429
+ * Selective Kernel (SK) - (https://arxiv.org/abs/1903.06586
430
+ * Split (SPLAT) - https://arxiv.org/abs/2004.08955
431
+ * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
432
+
433
+ ## Results
434
+
435
+ Model validation results can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/results/) and in the [results tables](results/README.md)
436
+
437
+ ## Getting Started (Documentation)
438
+
439
+ My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
440
+
441
+ [timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
442
+
443
+ [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.
444
+
445
+ ## Train, Validation, Inference Scripts
446
+
447
+ The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://rwightman.github.io/pytorch-image-models/scripts/) for some basics and [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) for some train examples that produce SOTA ImageNet results.
448
+
449
+ ## Awesome PyTorch Resources
450
+
451
+ One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.
452
+
453
+ ### Object Detection, Instance and Semantic Segmentation
454
+ * Detectron2 - https://github.com/facebookresearch/detectron2
455
+ * Segmentation Models (Semantic) - https://github.com/qubvel/segmentation_models.pytorch
456
+ * EfficientDet (Obj Det, Semantic soon) - https://github.com/rwightman/efficientdet-pytorch
457
+
458
+ ### Computer Vision / Image Augmentation
459
+ * Albumentations - https://github.com/albumentations-team/albumentations
460
+ * Kornia - https://github.com/kornia/kornia
461
+
462
+ ### Knowledge Distillation
463
+ * RepDistiller - https://github.com/HobbitLong/RepDistiller
464
+ * torchdistill - https://github.com/yoshitomo-matsubara/torchdistill
465
+
466
+ ### Metric Learning
467
+ * PyTorch Metric Learning - https://github.com/KevinMusgrave/pytorch-metric-learning
468
+
469
+ ### Training / Frameworks
470
+ * fastai - https://github.com/fastai/fastai
471
+
472
+ ## Licenses
473
+
474
+ ### Code
475
+ The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.
476
+
477
+ ### Pretrained Weights
478
+ So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (https://image-net.org/download). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.
479
+
480
+ #### Pretrained on more than ImageNet
481
+ Several weights included or references here were pretrained with proprietary datasets that I do not have access to. These include the Facebook WSL, SSL, SWSL ResNe(Xt) and the Google Noisy Student EfficientNet models. The Facebook models have an explicit non-commercial license (CC-BY-NC 4.0, https://github.com/facebookresearch/semi-supervised-ImageNet1K-models, https://github.com/facebookresearch/WSL-Images). The Google models do not appear to have any restriction beyond the Apache 2.0 license (and ImageNet concerns). In either case, you should contact Facebook or Google with any questions.
482
+
483
+ ## Citing
484
+
485
+ ### BibTeX
486
+
487
+ ```bibtex
488
+ @misc{rw2019timm,
489
+ author = {Ross Wightman},
490
+ title = {PyTorch Image Models},
491
+ year = {2019},
492
+ publisher = {GitHub},
493
+ journal = {GitHub repository},
494
+ doi = {10.5281/zenodo.4414861},
495
+ howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
496
+ }
497
+ ```
498
+
499
+ ### Latest DOI
500
+
501
+ [![DOI](https://zenodo.org/badge/168799526.svg)](https://zenodo.org/badge/latestdoi/168799526)
502
+
503
+
lib/timm-0.5.4.dist-info/RECORD ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ timm-0.5.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ timm-0.5.4.dist-info/LICENSE,sha256=cbERYg-jLBeoDM1tstp1nTGlkeSX2LXzghdPWdG1nUk,11343
3
+ timm-0.5.4.dist-info/METADATA,sha256=_o4k9R4FYZ1msA33NowGB-awL2oEA8zUsuZjK6xgB4c,36181
4
+ timm-0.5.4.dist-info/RECORD,,
5
+ timm-0.5.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ timm-0.5.4.dist-info/WHEEL,sha256=ewwEueio1C2XeHTvT17n8dZUJgOvyCWCt0WVNLClP9o,92
7
+ timm-0.5.4.dist-info/top_level.txt,sha256=Mi21FFh17x9WpGQnfmIkFMK_kFw_m3eb_G3J8-UJ5SY,5
8
+ timm/__init__.py,sha256=9mTvNS2J6SoMaBrYEv6Xcmc9EGMytuFwkJgUYCXSArg,286
9
+ timm/__pycache__/__init__.cpython-310.pyc,,
10
+ timm/__pycache__/version.cpython-310.pyc,,
11
+ timm/data/__init__.py,sha256=UY7Kh_mkF-oCUoBc9hZUF_3PUlqmjm7kGhSHu9RuHoQ,553
12
+ timm/data/__pycache__/__init__.cpython-310.pyc,,
13
+ timm/data/__pycache__/auto_augment.cpython-310.pyc,,
14
+ timm/data/__pycache__/config.cpython-310.pyc,,
15
+ timm/data/__pycache__/constants.cpython-310.pyc,,
16
+ timm/data/__pycache__/dataset.cpython-310.pyc,,
17
+ timm/data/__pycache__/dataset_factory.cpython-310.pyc,,
18
+ timm/data/__pycache__/distributed_sampler.cpython-310.pyc,,
19
+ timm/data/__pycache__/loader.cpython-310.pyc,,
20
+ timm/data/__pycache__/mixup.cpython-310.pyc,,
21
+ timm/data/__pycache__/random_erasing.cpython-310.pyc,,
22
+ timm/data/__pycache__/real_labels.cpython-310.pyc,,
23
+ timm/data/__pycache__/tf_preprocessing.cpython-310.pyc,,
24
+ timm/data/__pycache__/transforms.cpython-310.pyc,,
25
+ timm/data/__pycache__/transforms_factory.cpython-310.pyc,,
26
+ timm/data/auto_augment.py,sha256=C4rrMeP0oAAhqPLzdEc_MBimKGnmaq-J1e6ik0XO3Pg,31686
27
+ timm/data/config.py,sha256=grh9sGvu3YNQJm0pOd09ZVRvz8dnMdbbqoFJDLT_TJw,2915
28
+ timm/data/constants.py,sha256=xc_A1oVqqTSSr9fNPijM3eZPm8bgbh8OEhwdsS-XwlQ,303
29
+ timm/data/dataset.py,sha256=XUBknhmSG8iuGJrZKF47emzVTX1vgxxBN3uc0E3p40Q,4805
30
+ timm/data/dataset_factory.py,sha256=FR5gDs5vofvjrdDrs-SaHI-5jb56BVqp4ChrJT-TreI,5533
31
+ timm/data/distributed_sampler.py,sha256=ZQ7KA3xEyWDLumq2aHjjQy9o_SpZgkdk8dmF-eP9BsQ,5125
32
+ timm/data/loader.py,sha256=h087Nqq-5Ct2AN8g_uqDWESDgMlA9hYwT03Qlv4xBQU,9924
33
+ timm/data/mixup.py,sha256=xA8ZMdVoVIfwZMA9vMdLDkcosJYInaZnVsQW1QX3ano,14722
34
+ timm/data/parsers/__init__.py,sha256=f1ffuO9Pj74PBI2y3h1iweooYJlTJlADGs8wdQuaZmw,42
35
+ timm/data/parsers/__pycache__/__init__.cpython-310.pyc,,
36
+ timm/data/parsers/__pycache__/class_map.cpython-310.pyc,,
37
+ timm/data/parsers/__pycache__/constants.cpython-310.pyc,,
38
+ timm/data/parsers/__pycache__/parser.cpython-310.pyc,,
39
+ timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc,,
40
+ timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc,,
41
+ timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc,,
42
+ timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc,,
43
+ timm/data/parsers/__pycache__/parser_tfds.cpython-310.pyc,,
44
+ timm/data/parsers/class_map.py,sha256=rYqhlPYplTRj86FH1CZRf76NB5zUECX4-D93T-u-SjQ,759
45
+ timm/data/parsers/constants.py,sha256=X4eulfViOcrn8_k-Awi6LTGHFK5vZKXeS8exbD_KtT4,43
46
+ timm/data/parsers/parser.py,sha256=BoIya5V--BuyXWEeYzi-rg8wsXI0UiEiuxQxkho0dCY,487
47
+ timm/data/parsers/parser_factory.py,sha256=Z6Vku9nKq1lcKgqwjuYCdid4lJyJH7Pyf9QnoVvHbGU,1078
48
+ timm/data/parsers/parser_image_folder.py,sha256=Up4Yv5NPvE5xkXv2AtoBc2FQY9Xt5AqUOkDD5Qr0OMs,2508
49
+ timm/data/parsers/parser_image_in_tar.py,sha256=VNRtMtJX8WkcNPxfQ0THVVb2fFwAH8f7m9-RskaK0Ak,8987
50
+ timm/data/parsers/parser_image_tar.py,sha256=E48D2RfbCftAty7mXtdZIX8d0yIPCEQIagLVkliEJ1o,2589
51
+ timm/data/parsers/parser_tfds.py,sha256=iZ6F8GGw-YbqN8DYOLtqZLTwQwjOLG5So9Rpt614lhY,15819
52
+ timm/data/random_erasing.py,sha256=zCv0vGv32QCHUIamd_HC_GnjLlODoyJBDoRuV7S2XCI,4767
53
+ timm/data/real_labels.py,sha256=D9pgNrsyiPIZTDVYRqFmkIYyJi-Dplql4RONZ3NNFTM,1590
54
+ timm/data/tf_preprocessing.py,sha256=vuJSsleBnS41C9upL4JTuO399nrojHV_WCSwtQrv694,9120
55
+ timm/data/transforms.py,sha256=0mPf22INsd3XPuBwdAAuhUptWjC-j-b_hFija_Irw2k,6194
56
+ timm/data/transforms_factory.py,sha256=8UuHo_2DG8eYnYMAHDv2BY4uSob6PokqEJ-CmC5Rk4k,8351
57
+ timm/loss/__init__.py,sha256=iCNB9bUAf69neNe1_XO0eeg1QXuxu6jRTAuy4V9yFL8,245
58
+ timm/loss/__pycache__/__init__.cpython-310.pyc,,
59
+ timm/loss/__pycache__/asymmetric_loss.cpython-310.pyc,,
60
+ timm/loss/__pycache__/binary_cross_entropy.cpython-310.pyc,,
61
+ timm/loss/__pycache__/cross_entropy.cpython-310.pyc,,
62
+ timm/loss/__pycache__/jsd.cpython-310.pyc,,
63
+ timm/loss/asymmetric_loss.py,sha256=YkMktzxiXncKK_GF5yBGDONOUENSdpzb7FJluCxuSlw,3322
64
+ timm/loss/binary_cross_entropy.py,sha256=gs6iNMKB2clMixxFnIhYKOxZ7LwFWWiHGQhIzjqgAA4,2030
65
+ timm/loss/cross_entropy.py,sha256=XDE19FnhYjeudAerb6UulIID34AmZoXQ1CPEAjEkCQM,1145
66
+ timm/loss/jsd.py,sha256=MFe8H_JC1srFE_FKinF7jMVIQYgNWgeT7kZL9WeIXGI,1595
67
+ timm/models/__init__.py,sha256=YvxGghp2a4O-xIlWtOFEJFMM1mCElnuZtNG525WiTtU,1784
68
+ timm/models/__pycache__/__init__.cpython-310.pyc,,
69
+ timm/models/__pycache__/beit.cpython-310.pyc,,
70
+ timm/models/__pycache__/byoanet.cpython-310.pyc,,
71
+ timm/models/__pycache__/byobnet.cpython-310.pyc,,
72
+ timm/models/__pycache__/cait.cpython-310.pyc,,
73
+ timm/models/__pycache__/coat.cpython-310.pyc,,
74
+ timm/models/__pycache__/convit.cpython-310.pyc,,
75
+ timm/models/__pycache__/convmixer.cpython-310.pyc,,
76
+ timm/models/__pycache__/convnext.cpython-310.pyc,,
77
+ timm/models/__pycache__/crossvit.cpython-310.pyc,,
78
+ timm/models/__pycache__/cspnet.cpython-310.pyc,,
79
+ timm/models/__pycache__/densenet.cpython-310.pyc,,
80
+ timm/models/__pycache__/dla.cpython-310.pyc,,
81
+ timm/models/__pycache__/dpn.cpython-310.pyc,,
82
+ timm/models/__pycache__/efficientnet.cpython-310.pyc,,
83
+ timm/models/__pycache__/efficientnet_blocks.cpython-310.pyc,,
84
+ timm/models/__pycache__/efficientnet_builder.cpython-310.pyc,,
85
+ timm/models/__pycache__/factory.cpython-310.pyc,,
86
+ timm/models/__pycache__/features.cpython-310.pyc,,
87
+ timm/models/__pycache__/fx_features.cpython-310.pyc,,
88
+ timm/models/__pycache__/ghostnet.cpython-310.pyc,,
89
+ timm/models/__pycache__/gluon_resnet.cpython-310.pyc,,
90
+ timm/models/__pycache__/gluon_xception.cpython-310.pyc,,
91
+ timm/models/__pycache__/hardcorenas.cpython-310.pyc,,
92
+ timm/models/__pycache__/helpers.cpython-310.pyc,,
93
+ timm/models/__pycache__/hrnet.cpython-310.pyc,,
94
+ timm/models/__pycache__/hub.cpython-310.pyc,,
95
+ timm/models/__pycache__/inception_resnet_v2.cpython-310.pyc,,
96
+ timm/models/__pycache__/inception_v3.cpython-310.pyc,,
97
+ timm/models/__pycache__/inception_v4.cpython-310.pyc,,
98
+ timm/models/__pycache__/levit.cpython-310.pyc,,
99
+ timm/models/__pycache__/mlp_mixer.cpython-310.pyc,,
100
+ timm/models/__pycache__/mobilenetv3.cpython-310.pyc,,
101
+ timm/models/__pycache__/nasnet.cpython-310.pyc,,
102
+ timm/models/__pycache__/nest.cpython-310.pyc,,
103
+ timm/models/__pycache__/nfnet.cpython-310.pyc,,
104
+ timm/models/__pycache__/pit.cpython-310.pyc,,
105
+ timm/models/__pycache__/pnasnet.cpython-310.pyc,,
106
+ timm/models/__pycache__/registry.cpython-310.pyc,,
107
+ timm/models/__pycache__/regnet.cpython-310.pyc,,
108
+ timm/models/__pycache__/res2net.cpython-310.pyc,,
109
+ timm/models/__pycache__/resnest.cpython-310.pyc,,
110
+ timm/models/__pycache__/resnet.cpython-310.pyc,,
111
+ timm/models/__pycache__/resnetv2.cpython-310.pyc,,
112
+ timm/models/__pycache__/rexnet.cpython-310.pyc,,
113
+ timm/models/__pycache__/selecsls.cpython-310.pyc,,
114
+ timm/models/__pycache__/senet.cpython-310.pyc,,
115
+ timm/models/__pycache__/sknet.cpython-310.pyc,,
116
+ timm/models/__pycache__/swin_transformer.cpython-310.pyc,,
117
+ timm/models/__pycache__/tnt.cpython-310.pyc,,
118
+ timm/models/__pycache__/tresnet.cpython-310.pyc,,
119
+ timm/models/__pycache__/twins.cpython-310.pyc,,
120
+ timm/models/__pycache__/vgg.cpython-310.pyc,,
121
+ timm/models/__pycache__/visformer.cpython-310.pyc,,
122
+ timm/models/__pycache__/vision_transformer.cpython-310.pyc,,
123
+ timm/models/__pycache__/vision_transformer_hybrid.cpython-310.pyc,,
124
+ timm/models/__pycache__/vovnet.cpython-310.pyc,,
125
+ timm/models/__pycache__/xception.cpython-310.pyc,,
126
+ timm/models/__pycache__/xception_aligned.cpython-310.pyc,,
127
+ timm/models/__pycache__/xcit.cpython-310.pyc,,
128
+ timm/models/beit.py,sha256=xUTeREocSxQizBpZh9suEzSr6T7J0wxTCV-LU1AhVRw,18558
129
+ timm/models/byoanet.py,sha256=Oq7z6zIGNUAYtrhalg1v-PM1GUY60rTJgtZFHZESc4o,18350
130
+ timm/models/byobnet.py,sha256=Ztjh8PPTWtiI9c6qDmPXj_tXAzqGGe1gCTAkLtxnRcc,62040
131
+ timm/models/cait.py,sha256=AT5n5q3rzZQL8qdcPg27Q2YlHxHXJ-tUbEpbhwuKEV0,14940
132
+ timm/models/coat.py,sha256=Mf7HSiTcilmiJn5u07QUbhz713G28CCiD12Be8mRuAs,26936
133
+ timm/models/convit.py,sha256=CpsyDS0P-PlrAs-hHsXMJhLXR9YPLGYfGCsDk3t42S4,13952
134
+ timm/models/convmixer.py,sha256=YT1IrJaUb2HFyvVYsGJN9Fb4i4eYVGxPPjJMfffvG0Q,3631
135
+ timm/models/convnext.py,sha256=jTdEBzMG9ZH_zsRXKN-2VeGfpodIWUQlphf6imnFPSg,17427
136
+ timm/models/crossvit.py,sha256=CP52LuFfiGTMs3iQIFJYmZSkDl2pYQaVARddYl9kOtI,22472
137
+ timm/models/cspnet.py,sha256=sqHI7IU9sqo1NuwGEhoHvUIzMp95MWFQHA9xDgsXMWg,18221
138
+ timm/models/densenet.py,sha256=7BW80_eMdBd54qOygP0jQGDedk9JSy1YPu0x2F0oBDQ,15611
139
+ timm/models/dla.py,sha256=rx7v7egYQocwGtr5TIeslP0AaJ88oFr0MdK4xzr2XlY,17205
140
+ timm/models/dpn.py,sha256=tFcBpES6yFM8jeyvI-giarXTz9YEZy2fex2vSQ7Lnhg,12448
141
+ timm/models/efficientnet.py,sha256=xKOm9HrD3AqF1In6ahRNPGP8KF7NuhL7W7W7PmcqPT8,97919
142
+ timm/models/efficientnet_blocks.py,sha256=i32acRskLshWDVKxkCL8T728baoD80FQLb-6zl-wi00,12442
143
+ timm/models/efficientnet_builder.py,sha256=ovUDS8FzdJqqZhI7hgay85kFMPhN87VQXvwEL0TRGkA,19459
144
+ timm/models/factory.py,sha256=EaYqVnAFXXGfTCPyxe0z-kAvjz4evZx9VE_mklkYmrE,3305
145
+ timm/models/features.py,sha256=DnO84Xi2mvMWIzK4kyN2SCePsMPxehSLWFZEGc_MItc,12155
146
+ timm/models/fx_features.py,sha256=Lu8PCRUiGe0SrjxTOyDb4BKVqzG1QN7E5nRFKErgEa0,2855
147
+ timm/models/ghostnet.py,sha256=D6tvP5ClRU6vUQzZWgk2exlhCCp71w0C-wwpDV17dy4,9326
148
+ timm/models/gluon_resnet.py,sha256=jTOSW9-gS6mu1V8omQM82nHvihn3oiI_5uPnDYqsua0,11362
149
+ timm/models/gluon_xception.py,sha256=UF0JRxJ1vPaQhYuNxz0hGjlS5IPHPBgrt2HSk4r5-Sk,8705
150
+ timm/models/hardcorenas.py,sha256=rpTmpo_aHC-i64-f2txejPCaitStjlYu0_ERX_Bul6o,8036
151
+ timm/models/helpers.py,sha256=hYyaTpH_4Rcn93prm70f-E_ahqM9HQJ0Cw0ma4qnpAo,22659
152
+ timm/models/hrnet.py,sha256=8LtXT2oEF9twTHWrwJW18S9OUVXC9RqHwTCnckIXDOM,29402
153
+ timm/models/hub.py,sha256=AdmCNspFlzLLohb3tFPiHv6R3ZFyXM12CxpjNM5gyts,5988
154
+ timm/models/inception_resnet_v2.py,sha256=Wg19hoczgHbWqHP5Uanf7zyhSUIYxs6P6RS6U3yB4uY,12464
155
+ timm/models/inception_v3.py,sha256=Cd7W-J6tao7jkQ9-sDvy0Guv1PvWFI81kxe5pYrSFbM,17469
156
+ timm/models/inception_v4.py,sha256=k2Yqzw-Ak3YxhCmIHXOfBgulBDiOugdxMLNmROTOgpk,10804
157
+ timm/models/layers/__init__.py,sha256=CI57cbXYovHX-pFO8O42eQwAd8D_XAnaXPp0IEezKgU,2213
158
+ timm/models/layers/__pycache__/__init__.cpython-310.pyc,,
159
+ timm/models/layers/__pycache__/activations.cpython-310.pyc,,
160
+ timm/models/layers/__pycache__/activations_jit.cpython-310.pyc,,
161
+ timm/models/layers/__pycache__/activations_me.cpython-310.pyc,,
162
+ timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-310.pyc,,
163
+ timm/models/layers/__pycache__/attention_pool2d.cpython-310.pyc,,
164
+ timm/models/layers/__pycache__/blur_pool.cpython-310.pyc,,
165
+ timm/models/layers/__pycache__/bottleneck_attn.cpython-310.pyc,,
166
+ timm/models/layers/__pycache__/cbam.cpython-310.pyc,,
167
+ timm/models/layers/__pycache__/classifier.cpython-310.pyc,,
168
+ timm/models/layers/__pycache__/cond_conv2d.cpython-310.pyc,,
169
+ timm/models/layers/__pycache__/config.cpython-310.pyc,,
170
+ timm/models/layers/__pycache__/conv2d_same.cpython-310.pyc,,
171
+ timm/models/layers/__pycache__/conv_bn_act.cpython-310.pyc,,
172
+ timm/models/layers/__pycache__/create_act.cpython-310.pyc,,
173
+ timm/models/layers/__pycache__/create_attn.cpython-310.pyc,,
174
+ timm/models/layers/__pycache__/create_conv2d.cpython-310.pyc,,
175
+ timm/models/layers/__pycache__/create_norm_act.cpython-310.pyc,,
176
+ timm/models/layers/__pycache__/drop.cpython-310.pyc,,
177
+ timm/models/layers/__pycache__/eca.cpython-310.pyc,,
178
+ timm/models/layers/__pycache__/evo_norm.cpython-310.pyc,,
179
+ timm/models/layers/__pycache__/gather_excite.cpython-310.pyc,,
180
+ timm/models/layers/__pycache__/global_context.cpython-310.pyc,,
181
+ timm/models/layers/__pycache__/halo_attn.cpython-310.pyc,,
182
+ timm/models/layers/__pycache__/helpers.cpython-310.pyc,,
183
+ timm/models/layers/__pycache__/inplace_abn.cpython-310.pyc,,
184
+ timm/models/layers/__pycache__/lambda_layer.cpython-310.pyc,,
185
+ timm/models/layers/__pycache__/linear.cpython-310.pyc,,
186
+ timm/models/layers/__pycache__/median_pool.cpython-310.pyc,,
187
+ timm/models/layers/__pycache__/mixed_conv2d.cpython-310.pyc,,
188
+ timm/models/layers/__pycache__/mlp.cpython-310.pyc,,
189
+ timm/models/layers/__pycache__/non_local_attn.cpython-310.pyc,,
190
+ timm/models/layers/__pycache__/norm.cpython-310.pyc,,
191
+ timm/models/layers/__pycache__/norm_act.cpython-310.pyc,,
192
+ timm/models/layers/__pycache__/padding.cpython-310.pyc,,
193
+ timm/models/layers/__pycache__/patch_embed.cpython-310.pyc,,
194
+ timm/models/layers/__pycache__/pool2d_same.cpython-310.pyc,,
195
+ timm/models/layers/__pycache__/selective_kernel.cpython-310.pyc,,
196
+ timm/models/layers/__pycache__/separable_conv.cpython-310.pyc,,
197
+ timm/models/layers/__pycache__/space_to_depth.cpython-310.pyc,,
198
+ timm/models/layers/__pycache__/split_attn.cpython-310.pyc,,
199
+ timm/models/layers/__pycache__/split_batchnorm.cpython-310.pyc,,
200
+ timm/models/layers/__pycache__/squeeze_excite.cpython-310.pyc,,
201
+ timm/models/layers/__pycache__/std_conv.cpython-310.pyc,,
202
+ timm/models/layers/__pycache__/test_time_pool.cpython-310.pyc,,
203
+ timm/models/layers/__pycache__/trace_utils.cpython-310.pyc,,
204
+ timm/models/layers/__pycache__/weight_init.cpython-310.pyc,,
205
+ timm/models/layers/activations.py,sha256=iT_WSweK1B14fAamfJOFcgqQ5DoBI6Fvt-X4fRuYsSM,4040
206
+ timm/models/layers/activations_jit.py,sha256=BQI8MYjZhJx0w6yTUpsl_x_tg-XFdbYDkYy5EEBQYIQ,2529
207
+ timm/models/layers/activations_me.py,sha256=Qlrh-NWXxC6OsxI1wppeBdsd8UGWXT8ECH95tFaXGEQ,5886
208
+ timm/models/layers/adaptive_avgmax_pool.py,sha256=I-lZ-AdvlTYguuxyTBgDWcquqhFf00uVujfVqZfLv_A,3890
209
+ timm/models/layers/attention_pool2d.py,sha256=wYx25PT4KZNVPdHW07mDkT19SIEcEtz3SNzh88Q98tw,6866
210
+ timm/models/layers/blur_pool.py,sha256=gVZRqXFUpOuH_ui98sTSsXsmnWXBmy9PoPlXgaMV8Q4,1591
211
+ timm/models/layers/bottleneck_attn.py,sha256=HLuZbyep1Nf9Qq9Aei81kCzQMs6U1aQBQRLrOnjnkHo,6895
212
+ timm/models/layers/cbam.py,sha256=5A0MAd2BaBBYKzWjbN_t81olp1tDMAoun816OyT9bVA,4418
213
+ timm/models/layers/classifier.py,sha256=GHJ80KXZu8sXOLLAB5S8zODJUFNxcZ-iIlHrxO63ymU,2231
214
+ timm/models/layers/cond_conv2d.py,sha256=YMnfZ9MSQyqqPQ1VxZsZsWRG1FAgWNWl9MHzXWZ1mWE,5129
215
+ timm/models/layers/config.py,sha256=Nna27P_B1cy4obs4Ma5_sd5VlXe_sCEYjv9ttyNABcE,3069
216
+ timm/models/layers/conv2d_same.py,sha256=2fv1zNaZJZgFJ1P5quM9pikQV-Pf620HRLX8ygQFHGU,1490
217
+ timm/models/layers/conv_bn_act.py,sha256=gTJJYA4-QOpL4prtWHc5aPfWbqrR10SAvaUwfnRIPN4,1404
218
+ timm/models/layers/create_act.py,sha256=CmKQZ1vu3bT7EVwmfLjSSEY6qLwR3Dea7ZR8mul5xdI,5359
219
+ timm/models/layers/create_attn.py,sha256=Z7uwbr07LDSBEz8kUaDwCuw5JX3xM-xTkzvsCXV4Duw,3526
220
+ timm/models/layers/create_conv2d.py,sha256=UH4RvUhNCz1ohkM3ayC6kXwBLuO1ro8go4jpgLgyjLs,1500
221
+ timm/models/layers/create_norm_act.py,sha256=Ln0UOFtMIWVJeAePJWx5qsLNSDl-b7Br3Z7XSzEVBqc,3450
222
+ timm/models/layers/drop.py,sha256=TVGBZLvuSEQrpNLH_FZOdUqanFRuY9a9Qq6yN8JmTgs,6732
223
+ timm/models/layers/eca.py,sha256=MiVhboDUqLUfeubpypWfaR3LMLHwgLCNsWO3iemcQFs,6386
224
+ timm/models/layers/evo_norm.py,sha256=HiGnhaOUYKIhBnVG6d9bwCGHO_BhsG7e75hDRQfu0E4,3519
225
+ timm/models/layers/gather_excite.py,sha256=53DHt6cySjPqd9NW3voZuhw8b9nUzvsG9NVl_D-9NAo,3824
226
+ timm/models/layers/global_context.py,sha256=aZWvij4J-T5I1rdTK725D6R0fDuJyYPDaXvl36QMmkw,2445
227
+ timm/models/layers/halo_attn.py,sha256=zMJkf9S-ocCvrfvWOe0I97UHTpEQIkP381DON3OXm-c,10662
228
+ timm/models/layers/helpers.py,sha256=pHZa-j8xR-BWLgflFyzvtwn9o1m52t5V__KauMkOutA,748
229
+ timm/models/layers/inplace_abn.py,sha256=4-8ZyftTMoNaU59NvUaQH8qpqBYeQDTIjjvYEwo1Lzg,3353
230
+ timm/models/layers/lambda_layer.py,sha256=WSlH2buUBDxRd5RFou5IN7iTFGS4nZL--j-XhrdOci8,5941
231
+ timm/models/layers/linear.py,sha256=baS2Wpl0vYELnvhnQ6Lw65jVotaJz5iGbroJJ9JmIRM,743
232
+ timm/models/layers/median_pool.py,sha256=b02v36VGvs_gCY9NhVwU7-mglcXJHzrJVzcEpEUuHBI,1737
233
+ timm/models/layers/mixed_conv2d.py,sha256=mRSmAUtpgHya_RdnUq4j85K5QS7JFTdSPUXOUTKgpmA,1843
234
+ timm/models/layers/mlp.py,sha256=hOai0-VxpEFGV_Lf4sb82Ko2vU5gJK-2U2F2GvnIFtU,4097
235
+ timm/models/layers/non_local_attn.py,sha256=58GuC8xjOFedZjWClPO_Bc9UJRl3J7giVm60F6bcsYo,6209
236
+ timm/models/layers/norm.py,sha256=qC3u_ncT88rnhc_Z3c6KP0JRMveDgHfoVQuVDV8sPjg,876
237
+ timm/models/layers/norm_act.py,sha256=WPUkzeBWPw-IHF-FoqYTyeoALxVXxc0bq5gW4PJnRBA,3545
238
+ timm/models/layers/padding.py,sha256=BjbtxJmui-DyeroZohYMdRAj5mR4o6pWgW04imi3hDI,2167
239
+ timm/models/layers/patch_embed.py,sha256=GXjHXzEvmDAOndeBuO3z7C4pkON1hzZeGi7NovFt1A4,1490
240
+ timm/models/layers/pool2d_same.py,sha256=UsmtWna5k5kfVTP25T1-OKJOgtcfBQCqSy0FmaZbjRw,3045
241
+ timm/models/layers/selective_kernel.py,sha256=Ltlwuw3gFRlmAUpxt0i44PGGaY1b1nNcfDO5AbqaK5U,5349
242
+ timm/models/layers/separable_conv.py,sha256=WrbdyyBteT3eqe5rN7TIrkqlWXfChP8f_0AMJS1sNDM,2528
243
+ timm/models/layers/space_to_depth.py,sha256=P-9czBYvbbPR7Ji-fB0dVFA_yABJQoO9C3IA4sL9ttI,1750
244
+ timm/models/layers/split_attn.py,sha256=HHpEHYCiBk0ecrnZpTVbb-4r4q1065WiCR6aC3AGfMU,3085
245
+ timm/models/layers/split_batchnorm.py,sha256=4ghGtliK5z0ZnzR29zJB_rN9BJPiGuy1PSltmVyF5Ww,3441
246
+ timm/models/layers/squeeze_excite.py,sha256=krSwp4li7AWhSP4zJLMZUwl_XIdHJk_Eb81iQ_s8FnA,3018
247
+ timm/models/layers/std_conv.py,sha256=zYhcKCbE0_Rqn422gEM9gr3LeBewu0CXKqvlsa9-M2Q,5887
248
+ timm/models/layers/test_time_pool.py,sha256=oQbw-agOC6sc3MjvbvsxrBtDa62r9gYiTEW01ICqjDY,1995
249
+ timm/models/layers/trace_utils.py,sha256=cbZufOaGKmhTGEMc52QAnqzGRTfn4vvzqsAOJaLKJQ8,335
250
+ timm/models/layers/weight_init.py,sha256=EbhK0ecja64ZJ4eXLpDVs7kLseumJBlSzWgbG9v5HL4,3324
251
+ timm/models/levit.py,sha256=2jBuY_jBe-orwUyhs7xfAs4sV0_aTSkZgPdWQh0-yi8,21163
252
+ timm/models/mlp_mixer.py,sha256=TC0HsK--Bxrp0aZddfCoO9bq5wpkAYtaJbElLmXbgCw,26042
253
+ timm/models/mobilenetv3.py,sha256=XJimeE7IFicHJtwqEd3SBOPhA-HJwgNUHF7oox2O_AE,26586
254
+ timm/models/nasnet.py,sha256=dZq20FEUebcRddTkpPVWpzL2HcATVttE4Ma6VsXIAzw,25944
255
+ timm/models/nest.py,sha256=dyR4nEKmtBdIp0nH1Fz6VcuQeAxYDF-8qQa8ROXTETc,19521
256
+ timm/models/nfnet.py,sha256=4A7W_-2Q-e2x5V1ovRYt0zzq7PNMvzG9RAlCeMPmtjU,41006
257
+ timm/models/pit.py,sha256=LpXaTnzYZ5WCfCtDnXjG2CJTh-A_lLEiqoRa2PrisME,13037
258
+ timm/models/pnasnet.py,sha256=8FISd7eO2N7flUeSsQwSyVm2DzhJiQWbm-ECtxwjh9w,14961
259
+ timm/models/pruned/ecaresnet101d_pruned.txt,sha256=1zA7XaxsTnFJxZ9PMbfMVST7wPSQcAV-UzSgdFfGgYY,8734
260
+ timm/models/pruned/ecaresnet50d_pruned.txt,sha256=J4AlTwabaSB6-XrINCPCDWMiM_FrdNjuJN_JJRb89WE,4520
261
+ timm/models/pruned/efficientnet_b1_pruned.txt,sha256=pNDm1EENJYMT8-GjXZ3kXWCXADLDun-4jfigh74RELE,18596
262
+ timm/models/pruned/efficientnet_b2_pruned.txt,sha256=e_oaSVM-Ux3NMVARynJ74YwjzxuBAX_w7kzOw9Ml3gM,18676
263
+ timm/models/pruned/efficientnet_b3_pruned.txt,sha256=A1DJEwjEmrg8oUr0QwzwBkdAJV3dVeUFjnO9pNC_0Pg,21133
264
+ timm/models/registry.py,sha256=YRSPZojqIfVIQ4z5yLK-LnNZZ2FXJ__cP3bNFVVbVck,5680
265
+ timm/models/regnet.py,sha256=Pk5Ybc-GP-6yHQ2uSn5nDQ3jkBl9C2rkCUW8wyluDfs,20998
266
+ timm/models/res2net.py,sha256=Przr9hV-FJQS7Ez0z_exGsPslnLefynBN4mAPGAtFH8,7865
267
+ timm/models/resnest.py,sha256=rOwUAV4vQHb89q54qVIra8eOXlpklqr8Aep7rd1Axig,10092
268
+ timm/models/resnet.py,sha256=zkPFrwSd2KXqCSMT7B0Zc_mXQ_D4mFKDkaigEF4XoW4,65986
269
+ timm/models/resnetv2.py,sha256=-4BvS5XIMFccvmaPRDWU6z7geIfaidix5-SwrVuvhug,28274
270
+ timm/models/rexnet.py,sha256=ee5vM9NPLRA-Ilo6ifJp39lGtXdlenAN1DALeVLHff8,9228
271
+ timm/models/selecsls.py,sha256=YrNit4pigYrqKuQPsBi8nzfXn4lVU8hva17bAcCB1a4,13124
272
+ timm/models/senet.py,sha256=MEvKLC_Msa1nK-_fzCDSjTgk9ysS4hHu7qmGBi4n7iU,17642
273
+ timm/models/sknet.py,sha256=3_MNhX4Jg3A8TAV7vJxcVaMKYbUXO-rF9JbNRliPD-w,8742
274
+ timm/models/swin_transformer.py,sha256=5So8rK82wHSu2ZkfHCptDWzrg4yZOGsLnM_iBd5PK3c,27372
275
+ timm/models/tnt.py,sha256=J28c3t1go0DPkVLf2Om6fwEES10cvCSxP9Ailp7dDt8,11209
276
+ timm/models/tresnet.py,sha256=2dUbOytAqAgFxNog9qFtM8M4wwep9-LOaOsgl7RanDs,11596
277
+ timm/models/twins.py,sha256=wL3gkk_Gwe220k_Td_VIHDcA1AiLBouGV4HRZJubjnE,17346
278
+ timm/models/vgg.py,sha256=ZXmmA7kUrvkwZJtAjjx_oDPzIY3S5yLWhGImCA0v1NE,10455
279
+ timm/models/visformer.py,sha256=iN7DjQ1aI3EQCil8g9Jv_S6a7IR2639BVObnVFXKwk0,16086
280
+ timm/models/vision_transformer.py,sha256=XINVbQd-cQTX0xBlnMSS1ZGy6rLUGYej-4iVs16rGNs,46475
281
+ timm/models/vision_transformer_hybrid.py,sha256=XGcP5sUmw-zixwi7UKDS3jx6PvAY0lUWnOc3DmSAX5s,16099
282
+ timm/models/vovnet.py,sha256=fnzwB_ciOog2Hbt-340QAT_UQq96HqIPRGNF09IuCcg,13845
283
+ timm/models/xception.py,sha256=lDcnUwDMVScwXzrplIeKPpSx_-Lr2ItfI0VPhP1KpjY,7388
284
+ timm/models/xception_aligned.py,sha256=BdbW2j5ql4mD0-TAqV6FtlnxSXSaK7Oe6TR-5TegzCQ,8948
285
+ timm/models/xcit.py,sha256=LMcM5J8aaKN119DYYZeisaaYNwgdMSlxYQujld4dMJE,35892
286
+ timm/optim/__init__.py,sha256=z1mMVKZ9loGnBRnS8SeE_F259As2DbXSXzHcRpuWy2E,484
287
+ timm/optim/__pycache__/__init__.cpython-310.pyc,,
288
+ timm/optim/__pycache__/adabelief.cpython-310.pyc,,
289
+ timm/optim/__pycache__/adafactor.cpython-310.pyc,,
290
+ timm/optim/__pycache__/adahessian.cpython-310.pyc,,
291
+ timm/optim/__pycache__/adamp.cpython-310.pyc,,
292
+ timm/optim/__pycache__/adamw.cpython-310.pyc,,
293
+ timm/optim/__pycache__/lamb.cpython-310.pyc,,
294
+ timm/optim/__pycache__/lars.cpython-310.pyc,,
295
+ timm/optim/__pycache__/lookahead.cpython-310.pyc,,
296
+ timm/optim/__pycache__/madgrad.cpython-310.pyc,,
297
+ timm/optim/__pycache__/nadam.cpython-310.pyc,,
298
+ timm/optim/__pycache__/nvnovograd.cpython-310.pyc,,
299
+ timm/optim/__pycache__/optim_factory.cpython-310.pyc,,
300
+ timm/optim/__pycache__/radam.cpython-310.pyc,,
301
+ timm/optim/__pycache__/rmsprop_tf.cpython-310.pyc,,
302
+ timm/optim/__pycache__/sgdp.cpython-310.pyc,,
303
+ timm/optim/adabelief.py,sha256=n8nVbFX0TrCgkI98s7sV9D1l_rwPoqgVdfUW1KxGMPY,9827
304
+ timm/optim/adafactor.py,sha256=UOYdbisCGOXJJF4sklBa4XEb3m68IyV6IkzcEopGack,7459
305
+ timm/optim/adahessian.py,sha256=vJtQ8bZTGLrkMYuGPOJdgO-5V8hjVvM2Il-HSqg59Ao,6535
306
+ timm/optim/adamp.py,sha256=PSJYfobQvxy9K0tdU6-mjaiF4BqhIXY9sHV2vposx5I,3574
307
+ timm/optim/adamw.py,sha256=OKSBGfaWs6DJC1aXJHadAp4FADAnDDwb-ZRKuPao7zk,5147
308
+ timm/optim/lamb.py,sha256=II9zTpcxWzNqgk4K-bs5VGKlQPabUolSAmHkcSjsqSU,9184
309
+ timm/optim/lars.py,sha256=8Ytu-q4FvXQWTEcP7R-8xSKdb72c2s1XhTvMzIshBME,5255
310
+ timm/optim/lookahead.py,sha256=nd42FXVedX6qlnyBXGMcxkj1IsUUOtwbVFa4dQYy83M,2463
311
+ timm/optim/madgrad.py,sha256=V3LJuPjGwiO7RdHAZFF0Qqa8JT8a9DJJLSEO2PCG7Ho,6893
312
+ timm/optim/nadam.py,sha256=ASEISt72rXnpfqVkKfgotJXBYpsyG9Pr17I8VFO6Eac,3871
313
+ timm/optim/nvnovograd.py,sha256=NkRLq007qqiRDrhqiZK1KP_kfCcFcDSYCWRcoYvddOQ,4856
314
+ timm/optim/optim_factory.py,sha256=te26CtKS1wOh1gwEeB6glHdMGxOyKH7xhtbC_V91upQ,8415
315
+ timm/optim/radam.py,sha256=dCeFJGKo5WC8w7Ad8tuldM6QFz41nYXJIYI5HkH6uxk,3468
316
+ timm/optim/rmsprop_tf.py,sha256=SX47YRaLPNB-YpJpLUbXqx21ZFoDPeqvpJX2kin4wCc,6143
317
+ timm/optim/sgdp.py,sha256=7f4ZMVHbjCTDTgPOZfE06S4lmdUBnIBCDr_Yzy1RFhY,2296
318
+ timm/scheduler/__init__.py,sha256=WhoyJyfj6SE2YIE06C1eMmnqr7tm5m17YG5s4uL9lXU,291
319
+ timm/scheduler/__pycache__/__init__.cpython-310.pyc,,
320
+ timm/scheduler/__pycache__/cosine_lr.cpython-310.pyc,,
321
+ timm/scheduler/__pycache__/multistep_lr.cpython-310.pyc,,
322
+ timm/scheduler/__pycache__/plateau_lr.cpython-310.pyc,,
323
+ timm/scheduler/__pycache__/poly_lr.cpython-310.pyc,,
324
+ timm/scheduler/__pycache__/scheduler.cpython-310.pyc,,
325
+ timm/scheduler/__pycache__/scheduler_factory.cpython-310.pyc,,
326
+ timm/scheduler/__pycache__/step_lr.cpython-310.pyc,,
327
+ timm/scheduler/__pycache__/tanh_lr.cpython-310.pyc,,
328
+ timm/scheduler/cosine_lr.py,sha256=VNWO_gQoFM436vnse7dgTZQtFbfx1PT3S-lwS1l4MLw,4161
329
+ timm/scheduler/multistep_lr.py,sha256=je8uCJHIlnyhvgFKhkrpSyQtCYl2G5QWWnZNVwPo8YQ,2098
330
+ timm/scheduler/plateau_lr.py,sha256=831LB8XE2nGCSmXSeWhSN8bDr_S4rS9mAR1TZNl-ttA,4140
331
+ timm/scheduler/poly_lr.py,sha256=k-XGB63zQdCYtUVP6WiUocaKgAV4dqwOHna9gw0p0tQ,4003
332
+ timm/scheduler/scheduler.py,sha256=t1F7sPeaMzTio6xeyxS9z2nLKMYt08c0xOPSEtCVUig,4750
333
+ timm/scheduler/scheduler_factory.py,sha256=hXi1jEFNLYuFqHPOOa1oW00g-dsqJ99ExhvjBAjSy0w,3682
334
+ timm/scheduler/step_lr.py,sha256=2L8uA_mq5_25jL0WASNnaQ3IkX9q5cZgE789PBNXryg,1902
335
+ timm/scheduler/tanh_lr.py,sha256=ApHa_ziKBwqZpcSm1xen_siQmFl-ja3yRJyck08F_04,3936
336
+ timm/utils/__init__.py,sha256=5Xoo-5kP6dZ5f6LV9u1DCRIqQrcX4OciHLs29DNcuEU,587
337
+ timm/utils/__pycache__/__init__.cpython-310.pyc,,
338
+ timm/utils/__pycache__/agc.cpython-310.pyc,,
339
+ timm/utils/__pycache__/checkpoint_saver.cpython-310.pyc,,
340
+ timm/utils/__pycache__/clip_grad.cpython-310.pyc,,
341
+ timm/utils/__pycache__/cuda.cpython-310.pyc,,
342
+ timm/utils/__pycache__/distributed.cpython-310.pyc,,
343
+ timm/utils/__pycache__/jit.cpython-310.pyc,,
344
+ timm/utils/__pycache__/log.cpython-310.pyc,,
345
+ timm/utils/__pycache__/metrics.cpython-310.pyc,,
346
+ timm/utils/__pycache__/misc.cpython-310.pyc,,
347
+ timm/utils/__pycache__/model.cpython-310.pyc,,
348
+ timm/utils/__pycache__/model_ema.cpython-310.pyc,,
349
+ timm/utils/__pycache__/random.cpython-310.pyc,,
350
+ timm/utils/__pycache__/summary.cpython-310.pyc,,
351
+ timm/utils/agc.py,sha256=6lZCChfbW0KGNMfkzztWD_NP87ESopjk24Xtb3WbBqU,1624
352
+ timm/utils/checkpoint_saver.py,sha256=RljigPicMAHnk48K2Qbl17cWnQepgO4QMZQ0FCjd8xw,6133
353
+ timm/utils/clip_grad.py,sha256=iYFEf7fvPbpyh5K1SI-EKey5Gqs2gztR9VUUGja0GB0,796
354
+ timm/utils/cuda.py,sha256=nerhMCalMMv1QHZXW1brcIvtzpWwalWYLTb6vJz2bnY,1703
355
+ timm/utils/distributed.py,sha256=MD1xa17GKPyPoUJ4lIkxucmD635GbKNuVvcsVCQlhsc,896
356
+ timm/utils/jit.py,sha256=OpEbtA3TgJTNDfKWX2oOsd1p4Sm5RQqZnZz5N78t2lk,648
357
+ timm/utils/log.py,sha256=BdZ2OqWo3v8d7wsDRJ-uACcoeNUhS8TJSwI3CYvq3Ss,1015
358
+ timm/utils/metrics.py,sha256=RSHpbbkyW6FsbxT6TzcBL7MZh4sv4A_GG1Bo8aN5qKc,901
359
+ timm/utils/misc.py,sha256=o6ZbvZJB-M6NX73E3ZFKzpKpbnV2bU-ZEj-RDlz-P58,644
360
+ timm/utils/model.py,sha256=IMc8JCt89gmCjlb68k-6RNb7wR4ZawWjG74Y4PMiSdo,12085
361
+ timm/utils/model_ema.py,sha256=PG0B6k198xf36sK1PrT32LaMe6zKlzUPPZmaPoSRbiQ,5670
362
+ timm/utils/random.py,sha256=Ysv6F3nIO8JYE8j6UrDxGyJDp3uNpq5v8U0KqL_8dic,178
363
+ timm/utils/summary.py,sha256=pdvBnXsLAS4CPGlMhB0lkUyizKGY4XP9M9f-VWlaJA0,1184
364
+ timm/version.py,sha256=YYYpMsC3ActrKnohyOQTZhA6i4ZGdIcgCekNZ1_03lo,22
lib/timm-0.5.4.dist-info/REQUESTED ADDED
File without changes
lib/timm-0.5.4.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.37.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
lib/timm-0.5.4.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ timm
lib/timm/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .version import __version__
2
+ from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
3
+ is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
4
+ get_model_default_value, is_model_pretrained
lib/timm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (549 Bytes). View file
 
lib/timm/__pycache__/version.cpython-310.pyc ADDED
Binary file (167 Bytes). View file
 
lib/timm/data/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
2
+ rand_augment_transform, auto_augment_transform
3
+ from .config import resolve_data_config
4
+ from .constants import *
5
+ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
6
+ from .dataset_factory import create_dataset
7
+ from .loader import create_loader
8
+ from .mixup import Mixup, FastCollateMixup
9
+ from .parsers import create_parser
10
+ from .real_labels import RealLabelsImagenet
11
+ from .transforms import *
12
+ from .transforms_factory import create_transform
lib/timm/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (860 Bytes). View file
 
lib/timm/data/__pycache__/auto_augment.cpython-310.pyc ADDED
Binary file (24 kB). View file
 
lib/timm/data/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.57 kB). View file
 
lib/timm/data/__pycache__/constants.cpython-310.pyc ADDED
Binary file (487 Bytes). View file
 
lib/timm/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.05 kB). View file
 
lib/timm/data/__pycache__/dataset_factory.cpython-310.pyc ADDED
Binary file (4.25 kB). View file
 
lib/timm/data/__pycache__/distributed_sampler.cpython-310.pyc ADDED
Binary file (4.12 kB). View file
 
lib/timm/data/__pycache__/loader.cpython-310.pyc ADDED
Binary file (7.96 kB). View file
 
lib/timm/data/__pycache__/mixup.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
lib/timm/data/__pycache__/random_erasing.cpython-310.pyc ADDED
Binary file (3.94 kB). View file
 
lib/timm/data/__pycache__/real_labels.cpython-310.pyc ADDED
Binary file (2.4 kB). View file
 
lib/timm/data/__pycache__/tf_preprocessing.cpython-310.pyc ADDED
Binary file (7.21 kB). View file
 
lib/timm/data/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (6.34 kB). View file
 
lib/timm/data/__pycache__/transforms_factory.cpython-310.pyc ADDED
Binary file (5.07 kB). View file
 
lib/timm/data/auto_augment.py ADDED
@@ -0,0 +1,865 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ AutoAugment, RandAugment, and AugMix for PyTorch
2
+
3
+ This code implements the searched ImageNet policies with various tweaks and improvements and
4
+ does not include any of the search code.
5
+
6
+ AA and RA Implementation adapted from:
7
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
8
+
9
+ AugMix adapted from:
10
+ https://github.com/google-research/augmix
11
+
12
+ Papers:
13
+ AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
14
+ Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
15
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
16
+ AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
17
+
18
+ Hacked together by / Copyright 2019, Ross Wightman
19
+ """
20
+ import random
21
+ import math
22
+ import re
23
+ from PIL import Image, ImageOps, ImageEnhance, ImageChops
24
+ import PIL
25
+ import numpy as np
26
+
27
+
28
+ _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
29
+
30
+ _FILL = (128, 128, 128)
31
+
32
+ _LEVEL_DENOM = 10. # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments
33
+
34
+ _HPARAMS_DEFAULT = dict(
35
+ translate_const=250,
36
+ img_mean=_FILL,
37
+ )
38
+
39
+ _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
40
+
41
+
42
+ def _interpolation(kwargs):
43
+ interpolation = kwargs.pop('resample', Image.BILINEAR)
44
+ if isinstance(interpolation, (list, tuple)):
45
+ return random.choice(interpolation)
46
+ else:
47
+ return interpolation
48
+
49
+
50
+ def _check_args_tf(kwargs):
51
+ if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
52
+ kwargs.pop('fillcolor')
53
+ kwargs['resample'] = _interpolation(kwargs)
54
+
55
+
56
+ def shear_x(img, factor, **kwargs):
57
+ _check_args_tf(kwargs)
58
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
59
+
60
+
61
+ def shear_y(img, factor, **kwargs):
62
+ _check_args_tf(kwargs)
63
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
64
+
65
+
66
+ def translate_x_rel(img, pct, **kwargs):
67
+ pixels = pct * img.size[0]
68
+ _check_args_tf(kwargs)
69
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
70
+
71
+
72
+ def translate_y_rel(img, pct, **kwargs):
73
+ pixels = pct * img.size[1]
74
+ _check_args_tf(kwargs)
75
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
76
+
77
+
78
+ def translate_x_abs(img, pixels, **kwargs):
79
+ _check_args_tf(kwargs)
80
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
81
+
82
+
83
+ def translate_y_abs(img, pixels, **kwargs):
84
+ _check_args_tf(kwargs)
85
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
86
+
87
+
88
+ def rotate(img, degrees, **kwargs):
89
+ _check_args_tf(kwargs)
90
+ if _PIL_VER >= (5, 2):
91
+ return img.rotate(degrees, **kwargs)
92
+ elif _PIL_VER >= (5, 0):
93
+ w, h = img.size
94
+ post_trans = (0, 0)
95
+ rotn_center = (w / 2.0, h / 2.0)
96
+ angle = -math.radians(degrees)
97
+ matrix = [
98
+ round(math.cos(angle), 15),
99
+ round(math.sin(angle), 15),
100
+ 0.0,
101
+ round(-math.sin(angle), 15),
102
+ round(math.cos(angle), 15),
103
+ 0.0,
104
+ ]
105
+
106
+ def transform(x, y, matrix):
107
+ (a, b, c, d, e, f) = matrix
108
+ return a * x + b * y + c, d * x + e * y + f
109
+
110
+ matrix[2], matrix[5] = transform(
111
+ -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
112
+ )
113
+ matrix[2] += rotn_center[0]
114
+ matrix[5] += rotn_center[1]
115
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
116
+ else:
117
+ return img.rotate(degrees, resample=kwargs['resample'])
118
+
119
+
120
+ def auto_contrast(img, **__):
121
+ return ImageOps.autocontrast(img)
122
+
123
+
124
+ def invert(img, **__):
125
+ return ImageOps.invert(img)
126
+
127
+
128
+ def equalize(img, **__):
129
+ return ImageOps.equalize(img)
130
+
131
+
132
+ def solarize(img, thresh, **__):
133
+ return ImageOps.solarize(img, thresh)
134
+
135
+
136
+ def solarize_add(img, add, thresh=128, **__):
137
+ lut = []
138
+ for i in range(256):
139
+ if i < thresh:
140
+ lut.append(min(255, i + add))
141
+ else:
142
+ lut.append(i)
143
+ if img.mode in ("L", "RGB"):
144
+ if img.mode == "RGB" and len(lut) == 256:
145
+ lut = lut + lut + lut
146
+ return img.point(lut)
147
+ else:
148
+ return img
149
+
150
+
151
+ def posterize(img, bits_to_keep, **__):
152
+ if bits_to_keep >= 8:
153
+ return img
154
+ return ImageOps.posterize(img, bits_to_keep)
155
+
156
+
157
+ def contrast(img, factor, **__):
158
+ return ImageEnhance.Contrast(img).enhance(factor)
159
+
160
+
161
+ def color(img, factor, **__):
162
+ return ImageEnhance.Color(img).enhance(factor)
163
+
164
+
165
+ def brightness(img, factor, **__):
166
+ return ImageEnhance.Brightness(img).enhance(factor)
167
+
168
+
169
+ def sharpness(img, factor, **__):
170
+ return ImageEnhance.Sharpness(img).enhance(factor)
171
+
172
+
173
+ def _randomly_negate(v):
174
+ """With 50% prob, negate the value"""
175
+ return -v if random.random() > 0.5 else v
176
+
177
+
178
+ def _rotate_level_to_arg(level, _hparams):
179
+ # range [-30, 30]
180
+ level = (level / _LEVEL_DENOM) * 30.
181
+ level = _randomly_negate(level)
182
+ return level,
183
+
184
+
185
+ def _enhance_level_to_arg(level, _hparams):
186
+ # range [0.1, 1.9]
187
+ return (level / _LEVEL_DENOM) * 1.8 + 0.1,
188
+
189
+
190
+ def _enhance_increasing_level_to_arg(level, _hparams):
191
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
192
+ # range [0.1, 1.9] if level <= _LEVEL_DENOM
193
+ level = (level / _LEVEL_DENOM) * .9
194
+ level = max(0.1, 1.0 + _randomly_negate(level)) # keep it >= 0.1
195
+ return level,
196
+
197
+
198
+ def _shear_level_to_arg(level, _hparams):
199
+ # range [-0.3, 0.3]
200
+ level = (level / _LEVEL_DENOM) * 0.3
201
+ level = _randomly_negate(level)
202
+ return level,
203
+
204
+
205
+ def _translate_abs_level_to_arg(level, hparams):
206
+ translate_const = hparams['translate_const']
207
+ level = (level / _LEVEL_DENOM) * float(translate_const)
208
+ level = _randomly_negate(level)
209
+ return level,
210
+
211
+
212
+ def _translate_rel_level_to_arg(level, hparams):
213
+ # default range [-0.45, 0.45]
214
+ translate_pct = hparams.get('translate_pct', 0.45)
215
+ level = (level / _LEVEL_DENOM) * translate_pct
216
+ level = _randomly_negate(level)
217
+ return level,
218
+
219
+
220
+ def _posterize_level_to_arg(level, _hparams):
221
+ # As per Tensorflow TPU EfficientNet impl
222
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
223
+ # intensity/severity of augmentation decreases with level
224
+ return int((level / _LEVEL_DENOM) * 4),
225
+
226
+
227
+ def _posterize_increasing_level_to_arg(level, hparams):
228
+ # As per Tensorflow models research and UDA impl
229
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
230
+ # intensity/severity of augmentation increases with level
231
+ return 4 - _posterize_level_to_arg(level, hparams)[0],
232
+
233
+
234
+ def _posterize_original_level_to_arg(level, _hparams):
235
+ # As per original AutoAugment paper description
236
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
237
+ # intensity/severity of augmentation decreases with level
238
+ return int((level / _LEVEL_DENOM) * 4) + 4,
239
+
240
+
241
+ def _solarize_level_to_arg(level, _hparams):
242
+ # range [0, 256]
243
+ # intensity/severity of augmentation decreases with level
244
+ return int((level / _LEVEL_DENOM) * 256),
245
+
246
+
247
+ def _solarize_increasing_level_to_arg(level, _hparams):
248
+ # range [0, 256]
249
+ # intensity/severity of augmentation increases with level
250
+ return 256 - _solarize_level_to_arg(level, _hparams)[0],
251
+
252
+
253
+ def _solarize_add_level_to_arg(level, _hparams):
254
+ # range [0, 110]
255
+ return int((level / _LEVEL_DENOM) * 110),
256
+
257
+
258
+ LEVEL_TO_ARG = {
259
+ 'AutoContrast': None,
260
+ 'Equalize': None,
261
+ 'Invert': None,
262
+ 'Rotate': _rotate_level_to_arg,
263
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
264
+ 'Posterize': _posterize_level_to_arg,
265
+ 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
266
+ 'PosterizeOriginal': _posterize_original_level_to_arg,
267
+ 'Solarize': _solarize_level_to_arg,
268
+ 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
269
+ 'SolarizeAdd': _solarize_add_level_to_arg,
270
+ 'Color': _enhance_level_to_arg,
271
+ 'ColorIncreasing': _enhance_increasing_level_to_arg,
272
+ 'Contrast': _enhance_level_to_arg,
273
+ 'ContrastIncreasing': _enhance_increasing_level_to_arg,
274
+ 'Brightness': _enhance_level_to_arg,
275
+ 'BrightnessIncreasing': _enhance_increasing_level_to_arg,
276
+ 'Sharpness': _enhance_level_to_arg,
277
+ 'SharpnessIncreasing': _enhance_increasing_level_to_arg,
278
+ 'ShearX': _shear_level_to_arg,
279
+ 'ShearY': _shear_level_to_arg,
280
+ 'TranslateX': _translate_abs_level_to_arg,
281
+ 'TranslateY': _translate_abs_level_to_arg,
282
+ 'TranslateXRel': _translate_rel_level_to_arg,
283
+ 'TranslateYRel': _translate_rel_level_to_arg,
284
+ }
285
+
286
+
287
+ NAME_TO_OP = {
288
+ 'AutoContrast': auto_contrast,
289
+ 'Equalize': equalize,
290
+ 'Invert': invert,
291
+ 'Rotate': rotate,
292
+ 'Posterize': posterize,
293
+ 'PosterizeIncreasing': posterize,
294
+ 'PosterizeOriginal': posterize,
295
+ 'Solarize': solarize,
296
+ 'SolarizeIncreasing': solarize,
297
+ 'SolarizeAdd': solarize_add,
298
+ 'Color': color,
299
+ 'ColorIncreasing': color,
300
+ 'Contrast': contrast,
301
+ 'ContrastIncreasing': contrast,
302
+ 'Brightness': brightness,
303
+ 'BrightnessIncreasing': brightness,
304
+ 'Sharpness': sharpness,
305
+ 'SharpnessIncreasing': sharpness,
306
+ 'ShearX': shear_x,
307
+ 'ShearY': shear_y,
308
+ 'TranslateX': translate_x_abs,
309
+ 'TranslateY': translate_y_abs,
310
+ 'TranslateXRel': translate_x_rel,
311
+ 'TranslateYRel': translate_y_rel,
312
+ }
313
+
314
+
315
+ class AugmentOp:
316
+
317
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
318
+ hparams = hparams or _HPARAMS_DEFAULT
319
+ self.name = name
320
+ self.aug_fn = NAME_TO_OP[name]
321
+ self.level_fn = LEVEL_TO_ARG[name]
322
+ self.prob = prob
323
+ self.magnitude = magnitude
324
+ self.hparams = hparams.copy()
325
+ self.kwargs = dict(
326
+ fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
327
+ resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
328
+ )
329
+
330
+ # If magnitude_std is > 0, we introduce some randomness
331
+ # in the usually fixed policy and sample magnitude from a normal distribution
332
+ # with mean `magnitude` and std-dev of `magnitude_std`.
333
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
334
+ # If magnitude_std is inf, we sample magnitude from a uniform distribution
335
+ self.magnitude_std = self.hparams.get('magnitude_std', 0)
336
+ self.magnitude_max = self.hparams.get('magnitude_max', None)
337
+
338
+ def __call__(self, img):
339
+ if self.prob < 1.0 and random.random() > self.prob:
340
+ return img
341
+ magnitude = self.magnitude
342
+ if self.magnitude_std > 0:
343
+ # magnitude randomization enabled
344
+ if self.magnitude_std == float('inf'):
345
+ magnitude = random.uniform(0, magnitude)
346
+ elif self.magnitude_std > 0:
347
+ magnitude = random.gauss(magnitude, self.magnitude_std)
348
+ # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
349
+ # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
350
+ upper_bound = self.magnitude_max or _LEVEL_DENOM
351
+ magnitude = max(0., min(magnitude, upper_bound))
352
+ level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
353
+ return self.aug_fn(img, *level_args, **self.kwargs)
354
+
355
+ def __repr__(self):
356
+ fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
357
+ fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
358
+ if self.magnitude_max is not None:
359
+ fs += f', mmax={self.magnitude_max}'
360
+ fs += ')'
361
+ return fs
362
+
363
+
364
+ def auto_augment_policy_v0(hparams):
365
+ # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
366
+ policy = [
367
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
368
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
369
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
370
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
371
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
372
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
373
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
374
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
375
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
376
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
377
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
378
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
379
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
380
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
381
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
382
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
383
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
384
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
385
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
386
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
387
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
388
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
389
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
390
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
391
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
392
+ ]
393
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
394
+ return pc
395
+
396
+
397
+ def auto_augment_policy_v0r(hparams):
398
+ # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
399
+ # in Google research implementation (number of bits discarded increases with magnitude)
400
+ policy = [
401
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
402
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
403
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
404
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
405
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
406
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
407
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
408
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
409
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
410
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
411
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
412
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
413
+ [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
414
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
415
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
416
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
417
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
418
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
419
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
420
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
421
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
422
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
423
+ [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
424
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
425
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
426
+ ]
427
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
428
+ return pc
429
+
430
+
431
+ def auto_augment_policy_original(hparams):
432
+ # ImageNet policy from https://arxiv.org/abs/1805.09501
433
+ policy = [
434
+ [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
435
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
436
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
437
+ [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
438
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
439
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
440
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
441
+ [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
442
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
443
+ [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
444
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
445
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
446
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
447
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
448
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
449
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
450
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
451
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
452
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
453
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
454
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
455
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
456
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
457
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
458
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
459
+ ]
460
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
461
+ return pc
462
+
463
+
464
+ def auto_augment_policy_originalr(hparams):
465
+ # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
466
+ policy = [
467
+ [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
468
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
469
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
470
+ [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
471
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
472
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
473
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
474
+ [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
475
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
476
+ [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
477
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
478
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
479
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
480
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
481
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
482
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
483
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
484
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
485
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
486
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
487
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
488
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
489
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
490
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
491
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
492
+ ]
493
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
494
+ return pc
495
+
496
+
497
+ def auto_augment_policy(name='v0', hparams=None):
498
+ hparams = hparams or _HPARAMS_DEFAULT
499
+ if name == 'original':
500
+ return auto_augment_policy_original(hparams)
501
+ elif name == 'originalr':
502
+ return auto_augment_policy_originalr(hparams)
503
+ elif name == 'v0':
504
+ return auto_augment_policy_v0(hparams)
505
+ elif name == 'v0r':
506
+ return auto_augment_policy_v0r(hparams)
507
+ else:
508
+ assert False, 'Unknown AA policy (%s)' % name
509
+
510
+
511
+ class AutoAugment:
512
+
513
+ def __init__(self, policy):
514
+ self.policy = policy
515
+
516
+ def __call__(self, img):
517
+ sub_policy = random.choice(self.policy)
518
+ for op in sub_policy:
519
+ img = op(img)
520
+ return img
521
+
522
+ def __repr__(self):
523
+ fs = self.__class__.__name__ + f'(policy='
524
+ for p in self.policy:
525
+ fs += '\n\t['
526
+ fs += ', '.join([str(op) for op in p])
527
+ fs += ']'
528
+ fs += ')'
529
+ return fs
530
+
531
+
532
+ def auto_augment_transform(config_str, hparams):
533
+ """
534
+ Create a AutoAugment transform
535
+
536
+ :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
537
+ dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
538
+ The remaining sections, not order sepecific determine
539
+ 'mstd' - float std deviation of magnitude noise applied
540
+ Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
541
+
542
+ :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
543
+
544
+ :return: A PyTorch compatible Transform
545
+ """
546
+ config = config_str.split('-')
547
+ policy_name = config[0]
548
+ config = config[1:]
549
+ for c in config:
550
+ cs = re.split(r'(\d.*)', c)
551
+ if len(cs) < 2:
552
+ continue
553
+ key, val = cs[:2]
554
+ if key == 'mstd':
555
+ # noise param injected via hparams for now
556
+ hparams.setdefault('magnitude_std', float(val))
557
+ else:
558
+ assert False, 'Unknown AutoAugment config section'
559
+ aa_policy = auto_augment_policy(policy_name, hparams=hparams)
560
+ return AutoAugment(aa_policy)
561
+
562
+
563
+ _RAND_TRANSFORMS = [
564
+ 'AutoContrast',
565
+ 'Equalize',
566
+ 'Invert',
567
+ 'Rotate',
568
+ 'Posterize',
569
+ 'Solarize',
570
+ 'SolarizeAdd',
571
+ 'Color',
572
+ 'Contrast',
573
+ 'Brightness',
574
+ 'Sharpness',
575
+ 'ShearX',
576
+ 'ShearY',
577
+ 'TranslateXRel',
578
+ 'TranslateYRel',
579
+ #'Cutout' # NOTE I've implement this as random erasing separately
580
+ ]
581
+
582
+
583
+ _RAND_INCREASING_TRANSFORMS = [
584
+ 'AutoContrast',
585
+ 'Equalize',
586
+ 'Invert',
587
+ 'Rotate',
588
+ 'PosterizeIncreasing',
589
+ 'SolarizeIncreasing',
590
+ 'SolarizeAdd',
591
+ 'ColorIncreasing',
592
+ 'ContrastIncreasing',
593
+ 'BrightnessIncreasing',
594
+ 'SharpnessIncreasing',
595
+ 'ShearX',
596
+ 'ShearY',
597
+ 'TranslateXRel',
598
+ 'TranslateYRel',
599
+ #'Cutout' # NOTE I've implement this as random erasing separately
600
+ ]
601
+
602
+
603
+
604
+ # These experimental weights are based loosely on the relative improvements mentioned in paper.
605
+ # They may not result in increased performance, but could likely be tuned to so.
606
+ _RAND_CHOICE_WEIGHTS_0 = {
607
+ 'Rotate': 0.3,
608
+ 'ShearX': 0.2,
609
+ 'ShearY': 0.2,
610
+ 'TranslateXRel': 0.1,
611
+ 'TranslateYRel': 0.1,
612
+ 'Color': .025,
613
+ 'Sharpness': 0.025,
614
+ 'AutoContrast': 0.025,
615
+ 'Solarize': .005,
616
+ 'SolarizeAdd': .005,
617
+ 'Contrast': .005,
618
+ 'Brightness': .005,
619
+ 'Equalize': .005,
620
+ 'Posterize': 0,
621
+ 'Invert': 0,
622
+ }
623
+
624
+
625
+ def _select_rand_weights(weight_idx=0, transforms=None):
626
+ transforms = transforms or _RAND_TRANSFORMS
627
+ assert weight_idx == 0 # only one set of weights currently
628
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
629
+ probs = [rand_weights[k] for k in transforms]
630
+ probs /= np.sum(probs)
631
+ return probs
632
+
633
+
634
+ def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
635
+ hparams = hparams or _HPARAMS_DEFAULT
636
+ transforms = transforms or _RAND_TRANSFORMS
637
+ return [AugmentOp(
638
+ name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
639
+
640
+
641
+ class RandAugment:
642
+ def __init__(self, ops, num_layers=2, choice_weights=None):
643
+ self.ops = ops
644
+ self.num_layers = num_layers
645
+ self.choice_weights = choice_weights
646
+
647
+ def __call__(self, img):
648
+ # no replacement when using weighted choice
649
+ ops = np.random.choice(
650
+ self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
651
+ for op in ops:
652
+ img = op(img)
653
+ return img
654
+
655
+ def __repr__(self):
656
+ fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
657
+ for op in self.ops:
658
+ fs += f'\n\t{op}'
659
+ fs += ')'
660
+ return fs
661
+
662
+
663
+ def rand_augment_transform(config_str, hparams):
664
+ """
665
+ Create a RandAugment transform
666
+
667
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
668
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
669
+ sections, not order sepecific determine
670
+ 'm' - integer magnitude of rand augment
671
+ 'n' - integer num layers (number of transform ops selected per image)
672
+ 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
673
+ 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
674
+ 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
675
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
676
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
677
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
678
+
679
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
680
+
681
+ :return: A PyTorch compatible Transform
682
+ """
683
+ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
684
+ num_layers = 2 # default to 2 ops per image
685
+ weight_idx = None # default to no probability weights for op choice
686
+ transforms = _RAND_TRANSFORMS
687
+ config = config_str.split('-')
688
+ assert config[0] == 'rand'
689
+ config = config[1:]
690
+ for c in config:
691
+ cs = re.split(r'(\d.*)', c)
692
+ if len(cs) < 2:
693
+ continue
694
+ key, val = cs[:2]
695
+ if key == 'mstd':
696
+ # noise param / randomization of magnitude values
697
+ mstd = float(val)
698
+ if mstd > 100:
699
+ # use uniform sampling in 0 to magnitude if mstd is > 100
700
+ mstd = float('inf')
701
+ hparams.setdefault('magnitude_std', mstd)
702
+ elif key == 'mmax':
703
+ # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
704
+ hparams.setdefault('magnitude_max', int(val))
705
+ elif key == 'inc':
706
+ if bool(val):
707
+ transforms = _RAND_INCREASING_TRANSFORMS
708
+ elif key == 'm':
709
+ magnitude = int(val)
710
+ elif key == 'n':
711
+ num_layers = int(val)
712
+ elif key == 'w':
713
+ weight_idx = int(val)
714
+ else:
715
+ assert False, 'Unknown RandAugment config section'
716
+ ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
717
+ choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
718
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
719
+
720
+
721
+ _AUGMIX_TRANSFORMS = [
722
+ 'AutoContrast',
723
+ 'ColorIncreasing', # not in paper
724
+ 'ContrastIncreasing', # not in paper
725
+ 'BrightnessIncreasing', # not in paper
726
+ 'SharpnessIncreasing', # not in paper
727
+ 'Equalize',
728
+ 'Rotate',
729
+ 'PosterizeIncreasing',
730
+ 'SolarizeIncreasing',
731
+ 'ShearX',
732
+ 'ShearY',
733
+ 'TranslateXRel',
734
+ 'TranslateYRel',
735
+ ]
736
+
737
+
738
+ def augmix_ops(magnitude=10, hparams=None, transforms=None):
739
+ hparams = hparams or _HPARAMS_DEFAULT
740
+ transforms = transforms or _AUGMIX_TRANSFORMS
741
+ return [AugmentOp(
742
+ name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
743
+
744
+
745
+ class AugMixAugment:
746
+ """ AugMix Transform
747
+ Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
748
+ From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
749
+ https://arxiv.org/abs/1912.02781
750
+ """
751
+ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
752
+ self.ops = ops
753
+ self.alpha = alpha
754
+ self.width = width
755
+ self.depth = depth
756
+ self.blended = blended # blended mode is faster but not well tested
757
+
758
+ def _calc_blended_weights(self, ws, m):
759
+ ws = ws * m
760
+ cump = 1.
761
+ rws = []
762
+ for w in ws[::-1]:
763
+ alpha = w / cump
764
+ cump *= (1 - alpha)
765
+ rws.append(alpha)
766
+ return np.array(rws[::-1], dtype=np.float32)
767
+
768
+ def _apply_blended(self, img, mixing_weights, m):
769
+ # This is my first crack and implementing a slightly faster mixed augmentation. Instead
770
+ # of accumulating the mix for each chain in a Numpy array and then blending with original,
771
+ # it recomputes the blending coefficients and applies one PIL image blend per chain.
772
+ # TODO the results appear in the right ballpark but they differ by more than rounding.
773
+ img_orig = img.copy()
774
+ ws = self._calc_blended_weights(mixing_weights, m)
775
+ for w in ws:
776
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
777
+ ops = np.random.choice(self.ops, depth, replace=True)
778
+ img_aug = img_orig # no ops are in-place, deep copy not necessary
779
+ for op in ops:
780
+ img_aug = op(img_aug)
781
+ img = Image.blend(img, img_aug, w)
782
+ return img
783
+
784
+ def _apply_basic(self, img, mixing_weights, m):
785
+ # This is a literal adaptation of the paper/official implementation without normalizations and
786
+ # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
787
+ # typical augmentation transforms, could use a GPU / Kornia implementation.
788
+ img_shape = img.size[0], img.size[1], len(img.getbands())
789
+ mixed = np.zeros(img_shape, dtype=np.float32)
790
+ for mw in mixing_weights:
791
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
792
+ ops = np.random.choice(self.ops, depth, replace=True)
793
+ img_aug = img # no ops are in-place, deep copy not necessary
794
+ for op in ops:
795
+ img_aug = op(img_aug)
796
+ mixed += mw * np.asarray(img_aug, dtype=np.float32)
797
+ np.clip(mixed, 0, 255., out=mixed)
798
+ mixed = Image.fromarray(mixed.astype(np.uint8))
799
+ return Image.blend(img, mixed, m)
800
+
801
+ def __call__(self, img):
802
+ mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
803
+ m = np.float32(np.random.beta(self.alpha, self.alpha))
804
+ if self.blended:
805
+ mixed = self._apply_blended(img, mixing_weights, m)
806
+ else:
807
+ mixed = self._apply_basic(img, mixing_weights, m)
808
+ return mixed
809
+
810
+ def __repr__(self):
811
+ fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
812
+ for op in self.ops:
813
+ fs += f'\n\t{op}'
814
+ fs += ')'
815
+ return fs
816
+
817
+
818
+ def augment_and_mix_transform(config_str, hparams):
819
+ """ Create AugMix PyTorch transform
820
+
821
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
822
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
823
+ sections, not order sepecific determine
824
+ 'm' - integer magnitude (severity) of augmentation mix (default: 3)
825
+ 'w' - integer width of augmentation chain (default: 3)
826
+ 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
827
+ 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
828
+ 'mstd' - float std deviation of magnitude noise applied (default: 0)
829
+ Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
830
+
831
+ :param hparams: Other hparams (kwargs) for the Augmentation transforms
832
+
833
+ :return: A PyTorch compatible Transform
834
+ """
835
+ magnitude = 3
836
+ width = 3
837
+ depth = -1
838
+ alpha = 1.
839
+ blended = False
840
+ config = config_str.split('-')
841
+ assert config[0] == 'augmix'
842
+ config = config[1:]
843
+ for c in config:
844
+ cs = re.split(r'(\d.*)', c)
845
+ if len(cs) < 2:
846
+ continue
847
+ key, val = cs[:2]
848
+ if key == 'mstd':
849
+ # noise param injected via hparams for now
850
+ hparams.setdefault('magnitude_std', float(val))
851
+ elif key == 'm':
852
+ magnitude = int(val)
853
+ elif key == 'w':
854
+ width = int(val)
855
+ elif key == 'd':
856
+ depth = int(val)
857
+ elif key == 'a':
858
+ alpha = float(val)
859
+ elif key == 'b':
860
+ blended = bool(val)
861
+ else:
862
+ assert False, 'Unknown AugMix config section'
863
+ hparams.setdefault('magnitude_std', float('inf')) # default to uniform sampling (if not set via mstd arg)
864
+ ops = augmix_ops(magnitude=magnitude, hparams=hparams)
865
+ return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
lib/timm/data/config.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from .constants import *
3
+
4
+
5
+ _logger = logging.getLogger(__name__)
6
+
7
+
8
+ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
9
+ new_config = {}
10
+ default_cfg = default_cfg
11
+ if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
12
+ default_cfg = model.default_cfg
13
+
14
+ # Resolve input/image size
15
+ in_chans = 3
16
+ if 'chans' in args and args['chans'] is not None:
17
+ in_chans = args['chans']
18
+
19
+ input_size = (in_chans, 224, 224)
20
+ if 'input_size' in args and args['input_size'] is not None:
21
+ assert isinstance(args['input_size'], (tuple, list))
22
+ assert len(args['input_size']) == 3
23
+ input_size = tuple(args['input_size'])
24
+ in_chans = input_size[0] # input_size overrides in_chans
25
+ elif 'img_size' in args and args['img_size'] is not None:
26
+ assert isinstance(args['img_size'], int)
27
+ input_size = (in_chans, args['img_size'], args['img_size'])
28
+ else:
29
+ if use_test_size and 'test_input_size' in default_cfg:
30
+ input_size = default_cfg['test_input_size']
31
+ elif 'input_size' in default_cfg:
32
+ input_size = default_cfg['input_size']
33
+ new_config['input_size'] = input_size
34
+
35
+ # resolve interpolation method
36
+ new_config['interpolation'] = 'bicubic'
37
+ if 'interpolation' in args and args['interpolation']:
38
+ new_config['interpolation'] = args['interpolation']
39
+ elif 'interpolation' in default_cfg:
40
+ new_config['interpolation'] = default_cfg['interpolation']
41
+
42
+ # resolve dataset + model mean for normalization
43
+ new_config['mean'] = IMAGENET_DEFAULT_MEAN
44
+ if 'mean' in args and args['mean'] is not None:
45
+ mean = tuple(args['mean'])
46
+ if len(mean) == 1:
47
+ mean = tuple(list(mean) * in_chans)
48
+ else:
49
+ assert len(mean) == in_chans
50
+ new_config['mean'] = mean
51
+ elif 'mean' in default_cfg:
52
+ new_config['mean'] = default_cfg['mean']
53
+
54
+ # resolve dataset + model std deviation for normalization
55
+ new_config['std'] = IMAGENET_DEFAULT_STD
56
+ if 'std' in args and args['std'] is not None:
57
+ std = tuple(args['std'])
58
+ if len(std) == 1:
59
+ std = tuple(list(std) * in_chans)
60
+ else:
61
+ assert len(std) == in_chans
62
+ new_config['std'] = std
63
+ elif 'std' in default_cfg:
64
+ new_config['std'] = default_cfg['std']
65
+
66
+ # resolve default crop percentage
67
+ new_config['crop_pct'] = DEFAULT_CROP_PCT
68
+ if 'crop_pct' in args and args['crop_pct'] is not None:
69
+ new_config['crop_pct'] = args['crop_pct']
70
+ elif 'crop_pct' in default_cfg:
71
+ new_config['crop_pct'] = default_cfg['crop_pct']
72
+
73
+ if verbose:
74
+ _logger.info('Data processing configuration for current model + dataset:')
75
+ for n, v in new_config.items():
76
+ _logger.info('\t%s: %s' % (n, str(v)))
77
+
78
+ return new_config
lib/timm/data/constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ DEFAULT_CROP_PCT = 0.875
2
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
3
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
4
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
5
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
6
+ IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
7
+ IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
lib/timm/data/dataset.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Quick n Simple Image Folder, Tarfile based DataSet
2
+
3
+ Hacked together by / Copyright 2019, Ross Wightman
4
+ """
5
+ import torch.utils.data as data
6
+ import os
7
+ import torch
8
+ import logging
9
+
10
+ from PIL import Image
11
+
12
+ from .parsers import create_parser
13
+
14
+ _logger = logging.getLogger(__name__)
15
+
16
+
17
+ _ERROR_RETRY = 50
18
+
19
+
20
+ class ImageDataset(data.Dataset):
21
+
22
+ def __init__(
23
+ self,
24
+ root,
25
+ parser=None,
26
+ class_map=None,
27
+ load_bytes=False,
28
+ transform=None,
29
+ target_transform=None,
30
+ ):
31
+ if parser is None or isinstance(parser, str):
32
+ parser = create_parser(parser or '', root=root, class_map=class_map)
33
+ self.parser = parser
34
+ self.load_bytes = load_bytes
35
+ self.transform = transform
36
+ self.target_transform = target_transform
37
+ self._consecutive_errors = 0
38
+
39
+ def __getitem__(self, index):
40
+ img, target = self.parser[index]
41
+ try:
42
+ img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
43
+ except Exception as e:
44
+ _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
45
+ self._consecutive_errors += 1
46
+ if self._consecutive_errors < _ERROR_RETRY:
47
+ return self.__getitem__((index + 1) % len(self.parser))
48
+ else:
49
+ raise e
50
+ self._consecutive_errors = 0
51
+ if self.transform is not None:
52
+ img = self.transform(img)
53
+ if target is None:
54
+ target = -1
55
+ elif self.target_transform is not None:
56
+ target = self.target_transform(target)
57
+ return img, target
58
+
59
+ def __len__(self):
60
+ return len(self.parser)
61
+
62
+ def filename(self, index, basename=False, absolute=False):
63
+ return self.parser.filename(index, basename, absolute)
64
+
65
+ def filenames(self, basename=False, absolute=False):
66
+ return self.parser.filenames(basename, absolute)
67
+
68
+
69
+ class IterableImageDataset(data.IterableDataset):
70
+
71
+ def __init__(
72
+ self,
73
+ root,
74
+ parser=None,
75
+ split='train',
76
+ is_training=False,
77
+ batch_size=None,
78
+ repeats=0,
79
+ download=False,
80
+ transform=None,
81
+ target_transform=None,
82
+ ):
83
+ assert parser is not None
84
+ if isinstance(parser, str):
85
+ self.parser = create_parser(
86
+ parser, root=root, split=split, is_training=is_training,
87
+ batch_size=batch_size, repeats=repeats, download=download)
88
+ else:
89
+ self.parser = parser
90
+ self.transform = transform
91
+ self.target_transform = target_transform
92
+ self._consecutive_errors = 0
93
+
94
+ def __iter__(self):
95
+ for img, target in self.parser:
96
+ if self.transform is not None:
97
+ img = self.transform(img)
98
+ if self.target_transform is not None:
99
+ target = self.target_transform(target)
100
+ yield img, target
101
+
102
+ def __len__(self):
103
+ if hasattr(self.parser, '__len__'):
104
+ return len(self.parser)
105
+ else:
106
+ return 0
107
+
108
+ def filename(self, index, basename=False, absolute=False):
109
+ assert False, 'Filename lookup by index not supported, use filenames().'
110
+
111
+ def filenames(self, basename=False, absolute=False):
112
+ return self.parser.filenames(basename, absolute)
113
+
114
+
115
+ class AugMixDataset(torch.utils.data.Dataset):
116
+ """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
117
+
118
+ def __init__(self, dataset, num_splits=2):
119
+ self.augmentation = None
120
+ self.normalize = None
121
+ self.dataset = dataset
122
+ if self.dataset.transform is not None:
123
+ self._set_transforms(self.dataset.transform)
124
+ self.num_splits = num_splits
125
+
126
+ def _set_transforms(self, x):
127
+ assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
128
+ self.dataset.transform = x[0]
129
+ self.augmentation = x[1]
130
+ self.normalize = x[2]
131
+
132
+ @property
133
+ def transform(self):
134
+ return self.dataset.transform
135
+
136
+ @transform.setter
137
+ def transform(self, x):
138
+ self._set_transforms(x)
139
+
140
+ def _normalize(self, x):
141
+ return x if self.normalize is None else self.normalize(x)
142
+
143
+ def __getitem__(self, i):
144
+ x, y = self.dataset[i] # all splits share the same dataset base transform
145
+ x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
146
+ # run the full augmentation on the remaining splits
147
+ for _ in range(self.num_splits - 1):
148
+ x_list.append(self._normalize(self.augmentation(x)))
149
+ return tuple(x_list), y
150
+
151
+ def __len__(self):
152
+ return len(self.dataset)
lib/timm/data/dataset_factory.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Dataset Factory
2
+
3
+ Hacked together by / Copyright 2021, Ross Wightman
4
+ """
5
+ import os
6
+
7
+ from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
8
+ try:
9
+ from torchvision.datasets import Places365
10
+ has_places365 = True
11
+ except ImportError:
12
+ has_places365 = False
13
+ try:
14
+ from torchvision.datasets import INaturalist
15
+ has_inaturalist = True
16
+ except ImportError:
17
+ has_inaturalist = False
18
+
19
+ from .dataset import IterableImageDataset, ImageDataset
20
+
21
+ _TORCH_BASIC_DS = dict(
22
+ cifar10=CIFAR10,
23
+ cifar100=CIFAR100,
24
+ mnist=MNIST,
25
+ qmist=QMNIST,
26
+ kmnist=KMNIST,
27
+ fashion_mnist=FashionMNIST,
28
+ )
29
+ _TRAIN_SYNONYM = {'train', 'training'}
30
+ _EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
31
+
32
+
33
+ def _search_split(root, split):
34
+ # look for sub-folder with name of split in root and use that if it exists
35
+ split_name = split.split('[')[0]
36
+ try_root = os.path.join(root, split_name)
37
+ if os.path.exists(try_root):
38
+ return try_root
39
+
40
+ def _try(syn):
41
+ for s in syn:
42
+ try_root = os.path.join(root, s)
43
+ if os.path.exists(try_root):
44
+ return try_root
45
+ return root
46
+ if split_name in _TRAIN_SYNONYM:
47
+ root = _try(_TRAIN_SYNONYM)
48
+ elif split_name in _EVAL_SYNONYM:
49
+ root = _try(_EVAL_SYNONYM)
50
+ return root
51
+
52
+
53
+ def create_dataset(
54
+ name,
55
+ root,
56
+ split='validation',
57
+ search_split=True,
58
+ class_map=None,
59
+ load_bytes=False,
60
+ is_training=False,
61
+ download=False,
62
+ batch_size=None,
63
+ repeats=0,
64
+ **kwargs
65
+ ):
66
+ """ Dataset factory method
67
+
68
+ In parenthesis after each arg are the type of dataset supported for each arg, one of:
69
+ * folder - default, timm folder (or tar) based ImageDataset
70
+ * torch - torchvision based datasets
71
+ * TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
72
+ * all - any of the above
73
+
74
+ Args:
75
+ name: dataset name, empty is okay for folder based datasets
76
+ root: root folder of dataset (all)
77
+ split: dataset split (all)
78
+ search_split: search for split specific child fold from root so one can specify
79
+ `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
80
+ class_map: specify class -> index mapping via text file or dict (folder)
81
+ load_bytes: load data, return images as undecoded bytes (folder)
82
+ download: download dataset if not present and supported (TFDS, torch)
83
+ is_training: create dataset in train mode, this is different from the split.
84
+ For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
85
+ batch_size: batch size hint for (TFDS)
86
+ repeats: dataset repeats per iteration i.e. epoch (TFDS)
87
+ **kwargs: other args to pass to dataset
88
+
89
+ Returns:
90
+ Dataset object
91
+ """
92
+ name = name.lower()
93
+ if name.startswith('torch/'):
94
+ name = name.split('/', 2)[-1]
95
+ torch_kwargs = dict(root=root, download=download, **kwargs)
96
+ if name in _TORCH_BASIC_DS:
97
+ ds_class = _TORCH_BASIC_DS[name]
98
+ use_train = split in _TRAIN_SYNONYM
99
+ ds = ds_class(train=use_train, **torch_kwargs)
100
+ elif name == 'inaturalist' or name == 'inat':
101
+ assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
102
+ target_type = 'full'
103
+ split_split = split.split('/')
104
+ if len(split_split) > 1:
105
+ target_type = split_split[0].split('_')
106
+ if len(target_type) == 1:
107
+ target_type = target_type[0]
108
+ split = split_split[-1]
109
+ if split in _TRAIN_SYNONYM:
110
+ split = '2021_train'
111
+ elif split in _EVAL_SYNONYM:
112
+ split = '2021_valid'
113
+ ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
114
+ elif name == 'places365':
115
+ assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
116
+ if split in _TRAIN_SYNONYM:
117
+ split = 'train-standard'
118
+ elif split in _EVAL_SYNONYM:
119
+ split = 'val'
120
+ ds = Places365(split=split, **torch_kwargs)
121
+ elif name == 'imagenet':
122
+ if split in _EVAL_SYNONYM:
123
+ split = 'val'
124
+ ds = ImageNet(split=split, **torch_kwargs)
125
+ elif name == 'image_folder' or name == 'folder':
126
+ # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
127
+ if search_split and os.path.isdir(root):
128
+ # look for split specific sub-folder in root
129
+ root = _search_split(root, split)
130
+ ds = ImageFolder(root, **kwargs)
131
+ else:
132
+ assert False, f"Unknown torchvision dataset {name}"
133
+ elif name.startswith('tfds/'):
134
+ ds = IterableImageDataset(
135
+ root, parser=name, split=split, is_training=is_training,
136
+ download=download, batch_size=batch_size, repeats=repeats, **kwargs)
137
+ else:
138
+ # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
139
+ if search_split and os.path.isdir(root):
140
+ # look for split specific sub-folder in root
141
+ root = _search_split(root, split)
142
+ ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
143
+ return ds
lib/timm/data/distributed_sampler.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.utils.data import Sampler
4
+ import torch.distributed as dist
5
+
6
+
7
+ class OrderedDistributedSampler(Sampler):
8
+ """Sampler that restricts data loading to a subset of the dataset.
9
+ It is especially useful in conjunction with
10
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
11
+ process can pass a DistributedSampler instance as a DataLoader sampler,
12
+ and load a subset of the original dataset that is exclusive to it.
13
+ .. note::
14
+ Dataset is assumed to be of constant size.
15
+ Arguments:
16
+ dataset: Dataset used for sampling.
17
+ num_replicas (optional): Number of processes participating in
18
+ distributed training.
19
+ rank (optional): Rank of the current process within num_replicas.
20
+ """
21
+
22
+ def __init__(self, dataset, num_replicas=None, rank=None):
23
+ if num_replicas is None:
24
+ if not dist.is_available():
25
+ raise RuntimeError("Requires distributed package to be available")
26
+ num_replicas = dist.get_world_size()
27
+ if rank is None:
28
+ if not dist.is_available():
29
+ raise RuntimeError("Requires distributed package to be available")
30
+ rank = dist.get_rank()
31
+ self.dataset = dataset
32
+ self.num_replicas = num_replicas
33
+ self.rank = rank
34
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35
+ self.total_size = self.num_samples * self.num_replicas
36
+
37
+ def __iter__(self):
38
+ indices = list(range(len(self.dataset)))
39
+
40
+ # add extra samples to make it evenly divisible
41
+ indices += indices[:(self.total_size - len(indices))]
42
+ assert len(indices) == self.total_size
43
+
44
+ # subsample
45
+ indices = indices[self.rank:self.total_size:self.num_replicas]
46
+ assert len(indices) == self.num_samples
47
+
48
+ return iter(indices)
49
+
50
+ def __len__(self):
51
+ return self.num_samples
52
+
53
+
54
+ class RepeatAugSampler(Sampler):
55
+ """Sampler that restricts data loading to a subset of the dataset for distributed,
56
+ with repeated augmentation.
57
+ It ensures that different each augmented version of a sample will be visible to a
58
+ different process (GPU). Heavily based on torch.utils.data.DistributedSampler
59
+
60
+ This sampler was taken from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
61
+ Used in
62
+ Copyright (c) 2015-present, Facebook, Inc.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ dataset,
68
+ num_replicas=None,
69
+ rank=None,
70
+ shuffle=True,
71
+ num_repeats=3,
72
+ selected_round=256,
73
+ selected_ratio=0,
74
+ ):
75
+ if num_replicas is None:
76
+ if not dist.is_available():
77
+ raise RuntimeError("Requires distributed package to be available")
78
+ num_replicas = dist.get_world_size()
79
+ if rank is None:
80
+ if not dist.is_available():
81
+ raise RuntimeError("Requires distributed package to be available")
82
+ rank = dist.get_rank()
83
+ self.dataset = dataset
84
+ self.num_replicas = num_replicas
85
+ self.rank = rank
86
+ self.shuffle = shuffle
87
+ self.num_repeats = num_repeats
88
+ self.epoch = 0
89
+ self.num_samples = int(math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
90
+ self.total_size = self.num_samples * self.num_replicas
91
+ # Determine the number of samples to select per epoch for each rank.
92
+ # num_selected logic defaults to be the same as original RASampler impl, but this one can be tweaked
93
+ # via selected_ratio and selected_round args.
94
+ selected_ratio = selected_ratio or num_replicas # ratio to reduce selected samples by, num_replicas if 0
95
+ if selected_round:
96
+ self.num_selected_samples = int(math.floor(
97
+ len(self.dataset) // selected_round * selected_round / selected_ratio))
98
+ else:
99
+ self.num_selected_samples = int(math.ceil(len(self.dataset) / selected_ratio))
100
+
101
+ def __iter__(self):
102
+ # deterministically shuffle based on epoch
103
+ g = torch.Generator()
104
+ g.manual_seed(self.epoch)
105
+ if self.shuffle:
106
+ indices = torch.randperm(len(self.dataset), generator=g)
107
+ else:
108
+ indices = torch.arange(start=0, end=len(self.dataset))
109
+
110
+ # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
111
+ indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
112
+ # add extra samples to make it evenly divisible
113
+ padding_size = self.total_size - len(indices)
114
+ if padding_size > 0:
115
+ indices += indices[:padding_size]
116
+ assert len(indices) == self.total_size
117
+
118
+ # subsample per rank
119
+ indices = indices[self.rank:self.total_size:self.num_replicas]
120
+ assert len(indices) == self.num_samples
121
+
122
+ # return up to num selected samples
123
+ return iter(indices[:self.num_selected_samples])
124
+
125
+ def __len__(self):
126
+ return self.num_selected_samples
127
+
128
+ def set_epoch(self, epoch):
129
+ self.epoch = epoch
lib/timm/data/loader.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Loader Factory, Fast Collate, CUDA Prefetcher
2
+
3
+ Prefetcher and Fast Collate inspired by NVIDIA APEX example at
4
+ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
5
+
6
+ Hacked together by / Copyright 2019, Ross Wightman
7
+ """
8
+ import random
9
+ from functools import partial
10
+ from typing import Callable
11
+
12
+ import torch.utils.data
13
+ import numpy as np
14
+
15
+ from .transforms_factory import create_transform
16
+ from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
17
+ from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
18
+ from .random_erasing import RandomErasing
19
+ from .mixup import FastCollateMixup
20
+
21
+
22
+ def fast_collate(batch):
23
+ """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
24
+ assert isinstance(batch[0], tuple)
25
+ batch_size = len(batch)
26
+ if isinstance(batch[0][0], tuple):
27
+ # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
28
+ # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
29
+ inner_tuple_size = len(batch[0][0])
30
+ flattened_batch_size = batch_size * inner_tuple_size
31
+ targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
32
+ tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
33
+ for i in range(batch_size):
34
+ assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
35
+ for j in range(inner_tuple_size):
36
+ targets[i + j * batch_size] = batch[i][1]
37
+ tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
38
+ return tensor, targets
39
+ elif isinstance(batch[0][0], np.ndarray):
40
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
41
+ assert len(targets) == batch_size
42
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
43
+ for i in range(batch_size):
44
+ tensor[i] += torch.from_numpy(batch[i][0])
45
+ return tensor, targets
46
+ elif isinstance(batch[0][0], torch.Tensor):
47
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
48
+ assert len(targets) == batch_size
49
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
50
+ for i in range(batch_size):
51
+ tensor[i].copy_(batch[i][0])
52
+ return tensor, targets
53
+ else:
54
+ assert False
55
+
56
+
57
+ class PrefetchLoader:
58
+
59
+ def __init__(self,
60
+ loader,
61
+ mean=IMAGENET_DEFAULT_MEAN,
62
+ std=IMAGENET_DEFAULT_STD,
63
+ fp16=False,
64
+ re_prob=0.,
65
+ re_mode='const',
66
+ re_count=1,
67
+ re_num_splits=0):
68
+ self.loader = loader
69
+ self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
70
+ self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
71
+ self.fp16 = fp16
72
+ if fp16:
73
+ self.mean = self.mean.half()
74
+ self.std = self.std.half()
75
+ if re_prob > 0.:
76
+ self.random_erasing = RandomErasing(
77
+ probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
78
+ else:
79
+ self.random_erasing = None
80
+
81
+ def __iter__(self):
82
+ stream = torch.cuda.Stream()
83
+ first = True
84
+
85
+ for next_input, next_target in self.loader:
86
+ with torch.cuda.stream(stream):
87
+ next_input = next_input.cuda(non_blocking=True)
88
+ next_target = next_target.cuda(non_blocking=True)
89
+ if self.fp16:
90
+ next_input = next_input.half().sub_(self.mean).div_(self.std)
91
+ else:
92
+ next_input = next_input.float().sub_(self.mean).div_(self.std)
93
+ if self.random_erasing is not None:
94
+ next_input = self.random_erasing(next_input)
95
+
96
+ if not first:
97
+ yield input, target
98
+ else:
99
+ first = False
100
+
101
+ torch.cuda.current_stream().wait_stream(stream)
102
+ input = next_input
103
+ target = next_target
104
+
105
+ yield input, target
106
+
107
+ def __len__(self):
108
+ return len(self.loader)
109
+
110
+ @property
111
+ def sampler(self):
112
+ return self.loader.sampler
113
+
114
+ @property
115
+ def dataset(self):
116
+ return self.loader.dataset
117
+
118
+ @property
119
+ def mixup_enabled(self):
120
+ if isinstance(self.loader.collate_fn, FastCollateMixup):
121
+ return self.loader.collate_fn.mixup_enabled
122
+ else:
123
+ return False
124
+
125
+ @mixup_enabled.setter
126
+ def mixup_enabled(self, x):
127
+ if isinstance(self.loader.collate_fn, FastCollateMixup):
128
+ self.loader.collate_fn.mixup_enabled = x
129
+
130
+
131
+ def _worker_init(worker_id, worker_seeding='all'):
132
+ worker_info = torch.utils.data.get_worker_info()
133
+ assert worker_info.id == worker_id
134
+ if isinstance(worker_seeding, Callable):
135
+ seed = worker_seeding(worker_info)
136
+ random.seed(seed)
137
+ torch.manual_seed(seed)
138
+ np.random.seed(seed % (2 ** 32 - 1))
139
+ else:
140
+ assert worker_seeding in ('all', 'part')
141
+ # random / torch seed already called in dataloader iter class w/ worker_info.seed
142
+ # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
143
+ if worker_seeding == 'all':
144
+ np.random.seed(worker_info.seed % (2 ** 32 - 1))
145
+
146
+
147
+ def create_loader(
148
+ dataset,
149
+ input_size,
150
+ batch_size,
151
+ is_training=False,
152
+ use_prefetcher=True,
153
+ no_aug=False,
154
+ re_prob=0.,
155
+ re_mode='const',
156
+ re_count=1,
157
+ re_split=False,
158
+ scale=None,
159
+ ratio=None,
160
+ hflip=0.5,
161
+ vflip=0.,
162
+ color_jitter=0.4,
163
+ auto_augment=None,
164
+ num_aug_repeats=0,
165
+ num_aug_splits=0,
166
+ interpolation='bilinear',
167
+ mean=IMAGENET_DEFAULT_MEAN,
168
+ std=IMAGENET_DEFAULT_STD,
169
+ num_workers=1,
170
+ distributed=False,
171
+ crop_pct=None,
172
+ collate_fn=None,
173
+ pin_memory=False,
174
+ fp16=False,
175
+ tf_preprocessing=False,
176
+ use_multi_epochs_loader=False,
177
+ persistent_workers=True,
178
+ worker_seeding='all',
179
+ ):
180
+ re_num_splits = 0
181
+ if re_split:
182
+ # apply RE to second half of batch if no aug split otherwise line up with aug split
183
+ re_num_splits = num_aug_splits or 2
184
+ dataset.transform = create_transform(
185
+ input_size,
186
+ is_training=is_training,
187
+ use_prefetcher=use_prefetcher,
188
+ no_aug=no_aug,
189
+ scale=scale,
190
+ ratio=ratio,
191
+ hflip=hflip,
192
+ vflip=vflip,
193
+ color_jitter=color_jitter,
194
+ auto_augment=auto_augment,
195
+ interpolation=interpolation,
196
+ mean=mean,
197
+ std=std,
198
+ crop_pct=crop_pct,
199
+ tf_preprocessing=tf_preprocessing,
200
+ re_prob=re_prob,
201
+ re_mode=re_mode,
202
+ re_count=re_count,
203
+ re_num_splits=re_num_splits,
204
+ separate=num_aug_splits > 0,
205
+ )
206
+
207
+ sampler = None
208
+ if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
209
+ if is_training:
210
+ if num_aug_repeats:
211
+ sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
212
+ else:
213
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
214
+ else:
215
+ # This will add extra duplicate entries to result in equal num
216
+ # of samples per-process, will slightly alter validation results
217
+ sampler = OrderedDistributedSampler(dataset)
218
+ else:
219
+ assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
220
+
221
+ if collate_fn is None:
222
+ collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
223
+
224
+ loader_class = torch.utils.data.DataLoader
225
+ if use_multi_epochs_loader:
226
+ loader_class = MultiEpochsDataLoader
227
+
228
+ loader_args = dict(
229
+ batch_size=batch_size,
230
+ shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
231
+ num_workers=num_workers,
232
+ sampler=sampler,
233
+ collate_fn=collate_fn,
234
+ pin_memory=pin_memory,
235
+ drop_last=is_training,
236
+ worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
237
+ persistent_workers=persistent_workers
238
+ )
239
+ try:
240
+ loader = loader_class(dataset, **loader_args)
241
+ except TypeError as e:
242
+ loader_args.pop('persistent_workers') # only in Pytorch 1.7+
243
+ loader = loader_class(dataset, **loader_args)
244
+ if use_prefetcher:
245
+ prefetch_re_prob = re_prob if is_training and not no_aug else 0.
246
+ loader = PrefetchLoader(
247
+ loader,
248
+ mean=mean,
249
+ std=std,
250
+ fp16=fp16,
251
+ re_prob=prefetch_re_prob,
252
+ re_mode=re_mode,
253
+ re_count=re_count,
254
+ re_num_splits=re_num_splits
255
+ )
256
+
257
+ return loader
258
+
259
+
260
+ class MultiEpochsDataLoader(torch.utils.data.DataLoader):
261
+
262
+ def __init__(self, *args, **kwargs):
263
+ super().__init__(*args, **kwargs)
264
+ self._DataLoader__initialized = False
265
+ self.batch_sampler = _RepeatSampler(self.batch_sampler)
266
+ self._DataLoader__initialized = True
267
+ self.iterator = super().__iter__()
268
+
269
+ def __len__(self):
270
+ return len(self.batch_sampler.sampler)
271
+
272
+ def __iter__(self):
273
+ for i in range(len(self)):
274
+ yield next(self.iterator)
275
+
276
+
277
+ class _RepeatSampler(object):
278
+ """ Sampler that repeats forever.
279
+
280
+ Args:
281
+ sampler (Sampler)
282
+ """
283
+
284
+ def __init__(self, sampler):
285
+ self.sampler = sampler
286
+
287
+ def __iter__(self):
288
+ while True:
289
+ yield from iter(self.sampler)
lib/timm/data/mixup.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Mixup and Cutmix
2
+
3
+ Papers:
4
+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5
+
6
+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7
+
8
+ Code Reference:
9
+ CutMix: https://github.com/clovaai/CutMix-PyTorch
10
+
11
+ Hacked together by / Copyright 2019, Ross Wightman
12
+ """
13
+ import numpy as np
14
+ import torch
15
+
16
+
17
+ def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
18
+ x = x.long().view(-1, 1)
19
+ return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
20
+
21
+
22
+ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
23
+ off_value = smoothing / num_classes
24
+ on_value = 1. - smoothing + off_value
25
+ y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
26
+ y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
27
+ return y1 * lam + y2 * (1. - lam)
28
+
29
+
30
+ def rand_bbox(img_shape, lam, margin=0., count=None):
31
+ """ Standard CutMix bounding-box
32
+ Generates a random square bbox based on lambda value. This impl includes
33
+ support for enforcing a border margin as percent of bbox dimensions.
34
+
35
+ Args:
36
+ img_shape (tuple): Image shape as tuple
37
+ lam (float): Cutmix lambda value
38
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
39
+ count (int): Number of bbox to generate
40
+ """
41
+ ratio = np.sqrt(1 - lam)
42
+ img_h, img_w = img_shape[-2:]
43
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
44
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
45
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
46
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
47
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
48
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
49
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
50
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
51
+ return yl, yh, xl, xh
52
+
53
+
54
+ def rand_bbox_minmax(img_shape, minmax, count=None):
55
+ """ Min-Max CutMix bounding-box
56
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
57
+ based on min/max percent values applied to each dimension of the input image.
58
+
59
+ Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
60
+
61
+ Args:
62
+ img_shape (tuple): Image shape as tuple
63
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
64
+ count (int): Number of bbox to generate
65
+ """
66
+ assert len(minmax) == 2
67
+ img_h, img_w = img_shape[-2:]
68
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
69
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
70
+ yl = np.random.randint(0, img_h - cut_h, size=count)
71
+ xl = np.random.randint(0, img_w - cut_w, size=count)
72
+ yu = yl + cut_h
73
+ xu = xl + cut_w
74
+ return yl, yu, xl, xu
75
+
76
+
77
+ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
78
+ """ Generate bbox and apply lambda correction.
79
+ """
80
+ if ratio_minmax is not None:
81
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
82
+ else:
83
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
84
+ if correct_lam or ratio_minmax is not None:
85
+ bbox_area = (yu - yl) * (xu - xl)
86
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
87
+ return (yl, yu, xl, xu), lam
88
+
89
+
90
+ class Mixup:
91
+ """ Mixup/Cutmix that applies different params to each element or whole batch
92
+
93
+ Args:
94
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
95
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
96
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
97
+ prob (float): probability of applying mixup or cutmix per batch or element
98
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
99
+ mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
100
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
101
+ label_smoothing (float): apply label smoothing to the mixed target tensor
102
+ num_classes (int): number of classes for target
103
+ """
104
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
105
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
106
+ self.mixup_alpha = mixup_alpha
107
+ self.cutmix_alpha = cutmix_alpha
108
+ self.cutmix_minmax = cutmix_minmax
109
+ if self.cutmix_minmax is not None:
110
+ assert len(self.cutmix_minmax) == 2
111
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
112
+ self.cutmix_alpha = 1.0
113
+ self.mix_prob = prob
114
+ self.switch_prob = switch_prob
115
+ self.label_smoothing = label_smoothing
116
+ self.num_classes = num_classes
117
+ self.mode = mode
118
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
119
+ self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
120
+
121
+ def _params_per_elem(self, batch_size):
122
+ lam = np.ones(batch_size, dtype=np.float32)
123
+ use_cutmix = np.zeros(batch_size, dtype=np.bool)
124
+ if self.mixup_enabled:
125
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
126
+ use_cutmix = np.random.rand(batch_size) < self.switch_prob
127
+ lam_mix = np.where(
128
+ use_cutmix,
129
+ np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
130
+ np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
131
+ elif self.mixup_alpha > 0.:
132
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
133
+ elif self.cutmix_alpha > 0.:
134
+ use_cutmix = np.ones(batch_size, dtype=np.bool)
135
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
136
+ else:
137
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
138
+ lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
139
+ return lam, use_cutmix
140
+
141
+ def _params_per_batch(self):
142
+ lam = 1.
143
+ use_cutmix = False
144
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
145
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
146
+ use_cutmix = np.random.rand() < self.switch_prob
147
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
148
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
149
+ elif self.mixup_alpha > 0.:
150
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
151
+ elif self.cutmix_alpha > 0.:
152
+ use_cutmix = True
153
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
154
+ else:
155
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
156
+ lam = float(lam_mix)
157
+ return lam, use_cutmix
158
+
159
+ def _mix_elem(self, x):
160
+ batch_size = len(x)
161
+ lam_batch, use_cutmix = self._params_per_elem(batch_size)
162
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
163
+ for i in range(batch_size):
164
+ j = batch_size - i - 1
165
+ lam = lam_batch[i]
166
+ if lam != 1.:
167
+ if use_cutmix[i]:
168
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
169
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
170
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
171
+ lam_batch[i] = lam
172
+ else:
173
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
174
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
175
+
176
+ def _mix_pair(self, x):
177
+ batch_size = len(x)
178
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
179
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
180
+ for i in range(batch_size // 2):
181
+ j = batch_size - i - 1
182
+ lam = lam_batch[i]
183
+ if lam != 1.:
184
+ if use_cutmix[i]:
185
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
186
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
187
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
188
+ x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
189
+ lam_batch[i] = lam
190
+ else:
191
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
192
+ x[j] = x[j] * lam + x_orig[i] * (1 - lam)
193
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
194
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
195
+
196
+ def _mix_batch(self, x):
197
+ lam, use_cutmix = self._params_per_batch()
198
+ if lam == 1.:
199
+ return 1.
200
+ if use_cutmix:
201
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
202
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
203
+ x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
204
+ else:
205
+ x_flipped = x.flip(0).mul_(1. - lam)
206
+ x.mul_(lam).add_(x_flipped)
207
+ return lam
208
+
209
+ def __call__(self, x, target):
210
+ assert len(x) % 2 == 0, 'Batch size should be even when using this'
211
+ if self.mode == 'elem':
212
+ lam = self._mix_elem(x)
213
+ elif self.mode == 'pair':
214
+ lam = self._mix_pair(x)
215
+ else:
216
+ lam = self._mix_batch(x)
217
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
218
+ return x, target
219
+
220
+
221
+ class FastCollateMixup(Mixup):
222
+ """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
223
+
224
+ A Mixup impl that's performed while collating the batches.
225
+ """
226
+
227
+ def _mix_elem_collate(self, output, batch, half=False):
228
+ batch_size = len(batch)
229
+ num_elem = batch_size // 2 if half else batch_size
230
+ assert len(output) == num_elem
231
+ lam_batch, use_cutmix = self._params_per_elem(num_elem)
232
+ for i in range(num_elem):
233
+ j = batch_size - i - 1
234
+ lam = lam_batch[i]
235
+ mixed = batch[i][0]
236
+ if lam != 1.:
237
+ if use_cutmix[i]:
238
+ if not half:
239
+ mixed = mixed.copy()
240
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
241
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
242
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
243
+ lam_batch[i] = lam
244
+ else:
245
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
246
+ np.rint(mixed, out=mixed)
247
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
248
+ if half:
249
+ lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
250
+ return torch.tensor(lam_batch).unsqueeze(1)
251
+
252
+ def _mix_pair_collate(self, output, batch):
253
+ batch_size = len(batch)
254
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
255
+ for i in range(batch_size // 2):
256
+ j = batch_size - i - 1
257
+ lam = lam_batch[i]
258
+ mixed_i = batch[i][0]
259
+ mixed_j = batch[j][0]
260
+ assert 0 <= lam <= 1.0
261
+ if lam < 1.:
262
+ if use_cutmix[i]:
263
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
264
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
265
+ patch_i = mixed_i[:, yl:yh, xl:xh].copy()
266
+ mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
267
+ mixed_j[:, yl:yh, xl:xh] = patch_i
268
+ lam_batch[i] = lam
269
+ else:
270
+ mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
271
+ mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
272
+ mixed_i = mixed_temp
273
+ np.rint(mixed_j, out=mixed_j)
274
+ np.rint(mixed_i, out=mixed_i)
275
+ output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
276
+ output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
277
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
278
+ return torch.tensor(lam_batch).unsqueeze(1)
279
+
280
+ def _mix_batch_collate(self, output, batch):
281
+ batch_size = len(batch)
282
+ lam, use_cutmix = self._params_per_batch()
283
+ if use_cutmix:
284
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
285
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
286
+ for i in range(batch_size):
287
+ j = batch_size - i - 1
288
+ mixed = batch[i][0]
289
+ if lam != 1.:
290
+ if use_cutmix:
291
+ mixed = mixed.copy() # don't want to modify the original while iterating
292
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
293
+ else:
294
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
295
+ np.rint(mixed, out=mixed)
296
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
297
+ return lam
298
+
299
+ def __call__(self, batch, _=None):
300
+ batch_size = len(batch)
301
+ assert batch_size % 2 == 0, 'Batch size should be even when using this'
302
+ half = 'half' in self.mode
303
+ if half:
304
+ batch_size //= 2
305
+ output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
306
+ if self.mode == 'elem' or self.mode == 'half':
307
+ lam = self._mix_elem_collate(output, batch, half=half)
308
+ elif self.mode == 'pair':
309
+ lam = self._mix_pair_collate(output, batch)
310
+ else:
311
+ lam = self._mix_batch_collate(output, batch)
312
+ target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
313
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
314
+ target = target[:batch_size]
315
+ return output, target
316
+
lib/timm/data/parsers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .parser_factory import create_parser
lib/timm/data/parsers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (212 Bytes). View file
 
lib/timm/data/parsers/__pycache__/class_map.cpython-310.pyc ADDED
Binary file (942 Bytes). View file
 
lib/timm/data/parsers/__pycache__/constants.cpython-310.pyc ADDED
Binary file (199 Bytes). View file
 
lib/timm/data/parsers/__pycache__/parser.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
lib/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc ADDED
Binary file (877 Bytes). View file
 
lib/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc ADDED
Binary file (3.06 kB). View file
 
lib/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc ADDED
Binary file (7.68 kB). View file
 
lib/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc ADDED
Binary file (3.14 kB). View file
 
lib/timm/data/parsers/__pycache__/parser_tfds.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
lib/timm/data/parsers/class_map.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
+ def load_class_map(map_or_filename, root=''):
5
+ if isinstance(map_or_filename, dict):
6
+ assert dict, 'class_map dict must be non-empty'
7
+ return map_or_filename
8
+ class_map_path = map_or_filename
9
+ if not os.path.exists(class_map_path):
10
+ class_map_path = os.path.join(root, class_map_path)
11
+ assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
12
+ class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
13
+ if class_map_ext == '.txt':
14
+ with open(class_map_path) as f:
15
+ class_to_idx = {v.strip(): k for k, v in enumerate(f)}
16
+ else:
17
+ assert False, f'Unsupported class map file extension ({class_map_ext}).'
18
+ return class_to_idx
19
+
lib/timm/data/parsers/constants.py ADDED
@@ -0,0 +1 @@
 
 
1
+ IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
lib/timm/data/parsers/parser.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+
3
+
4
+ class Parser:
5
+ def __init__(self):
6
+ pass
7
+
8
+ @abstractmethod
9
+ def _filename(self, index, basename=False, absolute=False):
10
+ pass
11
+
12
+ def filename(self, index, basename=False, absolute=False):
13
+ return self._filename(index, basename=basename, absolute=absolute)
14
+
15
+ def filenames(self, basename=False, absolute=False):
16
+ return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
17
+
lib/timm/data/parsers/parser_factory.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from .parser_image_folder import ParserImageFolder
4
+ from .parser_image_tar import ParserImageTar
5
+ from .parser_image_in_tar import ParserImageInTar
6
+
7
+
8
+ def create_parser(name, root, split='train', **kwargs):
9
+ name = name.lower()
10
+ name = name.split('/', 2)
11
+ prefix = ''
12
+ if len(name) > 1:
13
+ prefix = name[0]
14
+ name = name[-1]
15
+
16
+ # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
17
+ # explicitly select other options shortly
18
+ if prefix == 'tfds':
19
+ from .parser_tfds import ParserTfds # defer tensorflow import
20
+ parser = ParserTfds(root, name, split=split, **kwargs)
21
+ else:
22
+ assert os.path.exists(root)
23
+ # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
24
+ # FIXME support split here, in parser?
25
+ if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
26
+ parser = ParserImageInTar(root, **kwargs)
27
+ else:
28
+ parser = ParserImageFolder(root, **kwargs)
29
+ return parser
lib/timm/data/parsers/parser_image_folder.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ A dataset parser that reads images from folders
2
+
3
+ Folders are scannerd recursively to find image files. Labels are based
4
+ on the folder hierarchy, just leaf folders by default.
5
+
6
+ Hacked together by / Copyright 2020 Ross Wightman
7
+ """
8
+ import os
9
+
10
+ from timm.utils.misc import natural_key
11
+
12
+ from .parser import Parser
13
+ from .class_map import load_class_map
14
+ from .constants import IMG_EXTENSIONS
15
+
16
+
17
+ def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
18
+ labels = []
19
+ filenames = []
20
+ for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
21
+ rel_path = os.path.relpath(root, folder) if (root != folder) else ''
22
+ label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
23
+ for f in files:
24
+ base, ext = os.path.splitext(f)
25
+ if ext.lower() in types:
26
+ filenames.append(os.path.join(root, f))
27
+ labels.append(label)
28
+ if class_to_idx is None:
29
+ # building class index
30
+ unique_labels = set(labels)
31
+ sorted_labels = list(sorted(unique_labels, key=natural_key))
32
+ class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
33
+ images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
34
+ if sort:
35
+ images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
36
+ return images_and_targets, class_to_idx
37
+
38
+
39
+ class ParserImageFolder(Parser):
40
+
41
+ def __init__(
42
+ self,
43
+ root,
44
+ class_map=''):
45
+ super().__init__()
46
+
47
+ self.root = root
48
+ class_to_idx = None
49
+ if class_map:
50
+ class_to_idx = load_class_map(class_map, root)
51
+ self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
52
+ if len(self.samples) == 0:
53
+ raise RuntimeError(
54
+ f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
55
+
56
+ def __getitem__(self, index):
57
+ path, target = self.samples[index]
58
+ return open(path, 'rb'), target
59
+
60
+ def __len__(self):
61
+ return len(self.samples)
62
+
63
+ def _filename(self, index, basename=False, absolute=False):
64
+ filename = self.samples[index][0]
65
+ if basename:
66
+ filename = os.path.basename(filename)
67
+ elif not absolute:
68
+ filename = os.path.relpath(filename, self.root)
69
+ return filename
lib/timm/data/parsers/parser_image_in_tar.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ A dataset parser that reads tarfile based datasets
2
+
3
+ This parser can read and extract image samples from:
4
+ * a single tar of image files
5
+ * a folder of multiple tarfiles containing imagefiles
6
+ * a tar of tars containing image files
7
+
8
+ Labels are based on the combined folder and/or tar name structure.
9
+
10
+ Hacked together by / Copyright 2020 Ross Wightman
11
+ """
12
+ import os
13
+ import tarfile
14
+ import pickle
15
+ import logging
16
+ import numpy as np
17
+ from glob import glob
18
+ from typing import List, Dict
19
+
20
+ from timm.utils.misc import natural_key
21
+
22
+ from .parser import Parser
23
+ from .class_map import load_class_map
24
+ from .constants import IMG_EXTENSIONS
25
+
26
+
27
+ _logger = logging.getLogger(__name__)
28
+ CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
29
+
30
+
31
+ class TarState:
32
+
33
+ def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
34
+ self.tf: tarfile.TarFile = tf
35
+ self.ti: tarfile.TarInfo = ti
36
+ self.children: Dict[str, TarState] = {} # child states (tars within tars)
37
+
38
+ def reset(self):
39
+ self.tf = None
40
+
41
+
42
+ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
43
+ sample_count = 0
44
+ for i, ti in enumerate(tf):
45
+ if not ti.isfile():
46
+ continue
47
+ dirname, basename = os.path.split(ti.path)
48
+ name, ext = os.path.splitext(basename)
49
+ ext = ext.lower()
50
+ if ext == '.tar':
51
+ with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
52
+ child_info = dict(
53
+ name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
54
+ sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
55
+ _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
56
+ parent_info['children'].append(child_info)
57
+ elif ext in extensions:
58
+ parent_info['samples'].append(ti)
59
+ sample_count += 1
60
+ return sample_count
61
+
62
+
63
+ def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
64
+ root_is_tar = False
65
+ if os.path.isfile(root):
66
+ assert os.path.splitext(root)[-1].lower() == '.tar'
67
+ tar_filenames = [root]
68
+ root, root_name = os.path.split(root)
69
+ root_name = os.path.splitext(root_name)[0]
70
+ root_is_tar = True
71
+ else:
72
+ root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
73
+ tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
74
+ num_tars = len(tar_filenames)
75
+ tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
76
+ assert num_tars, f'No .tar files found at specified path ({root}).'
77
+
78
+ _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
79
+ info = dict(tartrees=[])
80
+ cache_path = ''
81
+ if cache_tarinfo is None:
82
+ cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
83
+ if cache_tarinfo:
84
+ cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
85
+ cache_path = os.path.join(root, cache_filename)
86
+ if os.path.exists(cache_path):
87
+ _logger.info(f'Reading tar info from cache file {cache_path}.')
88
+ with open(cache_path, 'rb') as pf:
89
+ info = pickle.load(pf)
90
+ assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
91
+ else:
92
+ for i, fn in enumerate(tar_filenames):
93
+ path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
94
+ with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode
95
+ parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
96
+ num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
97
+ num_children = len(parent_info["children"])
98
+ _logger.debug(
99
+ f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
100
+ info['tartrees'].append(parent_info)
101
+ if cache_path:
102
+ _logger.info(f'Writing tar info to cache file {cache_path}.')
103
+ with open(cache_path, 'wb') as pf:
104
+ pickle.dump(info, pf)
105
+
106
+ samples = []
107
+ labels = []
108
+ build_class_map = False
109
+ if class_name_to_idx is None:
110
+ build_class_map = True
111
+
112
+ # Flatten tartree info into lists of samples and targets w/ targets based on label id via
113
+ # class map arg or from unique paths.
114
+ # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
115
+ # this covers my current use cases and keeps things a little easier to test for now.
116
+ tarfiles = []
117
+
118
+ def _label_from_paths(*path, leaf_only=True):
119
+ path = os.path.join(*path).strip(os.path.sep)
120
+ return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
121
+
122
+ def _add_samples(info, fn):
123
+ added = 0
124
+ for s in info['samples']:
125
+ label = _label_from_paths(info['path'], os.path.dirname(s.path))
126
+ if not build_class_map and label not in class_name_to_idx:
127
+ continue
128
+ samples.append((s, fn, info['ti']))
129
+ labels.append(label)
130
+ added += 1
131
+ return added
132
+
133
+ _logger.info(f'Collecting samples and building tar states.')
134
+ for parent_info in info['tartrees']:
135
+ # if tartree has children, we assume all samples are at the child level
136
+ tar_name = None if root_is_tar else parent_info['name']
137
+ tar_state = TarState()
138
+ parent_added = 0
139
+ for child_info in parent_info['children']:
140
+ child_added = _add_samples(child_info, fn=tar_name)
141
+ if child_added:
142
+ tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
143
+ parent_added += child_added
144
+ parent_added += _add_samples(parent_info, fn=tar_name)
145
+ if parent_added:
146
+ tarfiles.append((tar_name, tar_state))
147
+ del info
148
+
149
+ if build_class_map:
150
+ # build class index
151
+ sorted_labels = list(sorted(set(labels), key=natural_key))
152
+ class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
153
+
154
+ _logger.info(f'Mapping targets and sorting samples.')
155
+ samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
156
+ if sort:
157
+ samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
158
+ samples, targets = zip(*samples_and_targets)
159
+ samples = np.array(samples)
160
+ targets = np.array(targets)
161
+ _logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
162
+ return samples, targets, class_name_to_idx, tarfiles
163
+
164
+
165
+ class ParserImageInTar(Parser):
166
+ """ Multi-tarfile dataset parser where there is one .tar file per class
167
+ """
168
+
169
+ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
170
+ super().__init__()
171
+
172
+ class_name_to_idx = None
173
+ if class_map:
174
+ class_name_to_idx = load_class_map(class_map, root)
175
+ self.root = root
176
+ self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
177
+ self.root,
178
+ class_name_to_idx=class_name_to_idx,
179
+ cache_tarinfo=cache_tarinfo,
180
+ extensions=IMG_EXTENSIONS)
181
+ self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
182
+ if len(tarfiles) == 1 and tarfiles[0][0] is None:
183
+ self.root_is_tar = True
184
+ self.tar_state = tarfiles[0][1]
185
+ else:
186
+ self.root_is_tar = False
187
+ self.tar_state = dict(tarfiles)
188
+ self.cache_tarfiles = cache_tarfiles
189
+
190
+ def __len__(self):
191
+ return len(self.samples)
192
+
193
+ def __getitem__(self, index):
194
+ sample = self.samples[index]
195
+ target = self.targets[index]
196
+ sample_ti, parent_fn, child_ti = sample
197
+ parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
198
+
199
+ tf = None
200
+ cache_state = None
201
+ if self.cache_tarfiles:
202
+ cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
203
+ tf = cache_state.tf
204
+ if tf is None:
205
+ tf = tarfile.open(parent_abs)
206
+ if self.cache_tarfiles:
207
+ cache_state.tf = tf
208
+ if child_ti is not None:
209
+ ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
210
+ if ctf is None:
211
+ ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
212
+ if self.cache_tarfiles:
213
+ cache_state.children[child_ti.name].tf = ctf
214
+ tf = ctf
215
+
216
+ return tf.extractfile(sample_ti), target
217
+
218
+ def _filename(self, index, basename=False, absolute=False):
219
+ filename = self.samples[index][0].name
220
+ if basename:
221
+ filename = os.path.basename(filename)
222
+ return filename
lib/timm/data/parsers/parser_image_tar.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ A dataset parser that reads single tarfile based datasets
2
+
3
+ This parser can read datasets consisting if a single tarfile containing images.
4
+ I am planning to deprecated it in favour of ParerImageInTar.
5
+
6
+ Hacked together by / Copyright 2020 Ross Wightman
7
+ """
8
+ import os
9
+ import tarfile
10
+
11
+ from .parser import Parser
12
+ from .class_map import load_class_map
13
+ from .constants import IMG_EXTENSIONS
14
+ from timm.utils.misc import natural_key
15
+
16
+
17
+ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
18
+ files = []
19
+ labels = []
20
+ for ti in tarfile.getmembers():
21
+ if not ti.isfile():
22
+ continue
23
+ dirname, basename = os.path.split(ti.path)
24
+ label = os.path.basename(dirname)
25
+ ext = os.path.splitext(basename)[1]
26
+ if ext.lower() in IMG_EXTENSIONS:
27
+ files.append(ti)
28
+ labels.append(label)
29
+ if class_to_idx is None:
30
+ unique_labels = set(labels)
31
+ sorted_labels = list(sorted(unique_labels, key=natural_key))
32
+ class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
33
+ tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
34
+ if sort:
35
+ tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
36
+ return tarinfo_and_targets, class_to_idx
37
+
38
+
39
+ class ParserImageTar(Parser):
40
+ """ Single tarfile dataset where classes are mapped to folders within tar
41
+ NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
42
+ operate on folders of tars or tars in tars.
43
+ """
44
+ def __init__(self, root, class_map=''):
45
+ super().__init__()
46
+
47
+ class_to_idx = None
48
+ if class_map:
49
+ class_to_idx = load_class_map(class_map, root)
50
+ assert os.path.isfile(root)
51
+ self.root = root
52
+
53
+ with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
54
+ self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
55
+ self.imgs = self.samples
56
+ self.tarfile = None # lazy init in __getitem__
57
+
58
+ def __getitem__(self, index):
59
+ if self.tarfile is None:
60
+ self.tarfile = tarfile.open(self.root)
61
+ tarinfo, target = self.samples[index]
62
+ fileobj = self.tarfile.extractfile(tarinfo)
63
+ return fileobj, target
64
+
65
+ def __len__(self):
66
+ return len(self.samples)
67
+
68
+ def _filename(self, index, basename=False, absolute=False):
69
+ filename = self.samples[index][0].name
70
+ if basename:
71
+ filename = os.path.basename(filename)
72
+ return filename