Spaces: Running on Zero
yucornetto committed: upload

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +6 -0
- LICENSE +201 -0
- README.md +82 -14
- README_RAR.md +221 -0
- README_TiTok.md +213 -0
- assets/ILSVRC2012_val_00008636.png +0 -0
- assets/ILSVRC2012_val_00010240.png +0 -0
- assets/perf_comp.png +0 -0
- assets/random_vis_l32.png +3 -0
- assets/rar_overview.png +3 -0
- assets/recon_w_model_size_num_token.png +3 -0
- assets/speed_vs_perf.png +0 -0
- assets/titok_teaser.png +0 -0
- assets/vis1.png +3 -0
- assets/vis2.png +3 -0
- assets/vis3.png +3 -0
- configs/infer/titok_b64.yaml +39 -0
- configs/infer/titok_bl128_vae_c16.yaml +19 -0
- configs/infer/titok_bl128_vq8k.yaml +21 -0
- configs/infer/titok_bl64_vae_c16.yaml +19 -0
- configs/infer/titok_bl64_vq8k.yaml +21 -0
- configs/infer/titok_l32.yaml +40 -0
- configs/infer/titok_ll32_vae_c16.yaml +19 -0
- configs/infer/titok_s128.yaml +39 -0
- configs/infer/titok_sl256_vq8k.yaml +21 -0
- configs/training/generator/maskgit.yaml +86 -0
- configs/training/generator/rar.yaml +78 -0
- configs/training/stage1/titok_b64.yaml +70 -0
- configs/training/stage1/titok_l32.yaml +70 -0
- configs/training/stage1/titok_s128.yaml +70 -0
- configs/training/stage2/titok_b64.yaml +80 -0
- configs/training/stage2/titok_l32.yaml +79 -0
- configs/training/stage2/titok_s128.yaml +79 -0
- data/__init__.py +1 -0
- data/convert_imagenet_to_wds.py +68 -0
- data/webdataset_reader.py +227 -0
- demo.ipynb +0 -0
- demo_util.py +108 -0
- evaluator/__init__.py +1 -0
- evaluator/evaluator.py +245 -0
- evaluator/inception.py +231 -0
- imagenet_classes.py +1001 -0
- modeling/__init__.py +15 -0
- modeling/maskgit.py +282 -0
- modeling/modules/__init__.py +6 -0
- modeling/modules/base_model.py +127 -0
- modeling/modules/blocks.py +376 -0
- modeling/modules/discriminator.py +141 -0
- modeling/modules/ema_model.py +244 -0
- modeling/modules/losses.py +293 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/random_vis_l32.png filter=lfs diff=lfs merge=lfs -text
+assets/rar_overview.png filter=lfs diff=lfs merge=lfs -text
+assets/recon_w_model_size_num_token.png filter=lfs diff=lfs merge=lfs -text
+assets/vis1.png filter=lfs diff=lfs merge=lfs -text
+assets/vis2.png filter=lfs diff=lfs merge=lfs -text
+assets/vis3.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
The standard, unmodified text of the Apache License, Version 2.0, January 2004 (http://www.apache.org/licenses/) is added in full (201 lines): the definitions, the copyright and patent grants, the redistribution conditions, the submission-of-contributions and trademark clauses, the warranty disclaimer and limitation of liability, and the appendix with the boilerplate notice ("Copyright [yyyy] [name of copyright owner] ... Licensed under the Apache License, Version 2.0 ...").
README.md CHANGED
@@ -1,14 +1,82 @@
(The previous 14 lines are removed and replaced with the content below.)

# 1D Visual Tokenization and Generation

This repo hosts the code and models for the following projects:

- RAR: [Randomized Autoregressive Visual Generation](https://yucornetto.github.io/projects/rar.html)
- TiTok: [An Image is Worth 32 Tokens for Reconstruction and Generation](https://yucornetto.github.io/projects/titok.html)

## Short Intro on [Randomized Autoregressive Visual Generation](https://arxiv.org/abs/2411.00776) ([README](README_RAR.md))

RAR is an autoregressive (AR) image generator that is fully compatible with language modeling. It introduces a randomness annealing strategy with a permuted training objective at no additional cost, which enhances the model's ability to learn bidirectional contexts while leaving the autoregressive framework intact. RAR achieves an FID score of 1.48 on the ImageNet-256 benchmark, demonstrating state-of-the-art performance and significantly outperforming prior AR image generators.

<p>
<img src="assets/rar_overview.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/perf_comp.png" alt="teaser" width=90% height=90%>
</p>

See more details at [README_RAR](README_RAR.md).

## Short Intro on [An Image is Worth 32 Tokens for Reconstruction and Generation](https://arxiv.org/abs/2406.07550) ([README](README_TiTok.md))

We present a compact 1D tokenizer which can represent an image with as few as 32 discrete tokens. As a result, it leads to a substantial speed-up of the sampling process (e.g., **410× faster** than DiT-XL/2) while obtaining competitive generation quality.

<p>
<img src="assets/titok_teaser.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/speed_vs_perf.png" alt="teaser" width=90% height=90%>
</p>

See more details at [README_TiTok](README_TiTok.md).

## Updates
- 11/04/2024: We release the [tech report](https://arxiv.org/abs/2411.00776) and code for the RAR models.
- 10/16/2024: We release a set of TiTok tokenizer weights trained with an updated single-stage recipe, which leads to easier training and better performance. Weights of different model sizes are available for both the VQ and VAE variants of TiTok, and we hope they will facilitate research in this area. More details will be available in a later tech report.
- 09/25/2024: TiTok is accepted by NeurIPS 2024.
- 09/11/2024: Release of the training code for the TiTok-based generator.
- 08/28/2024: Release of the TiTok training code.
- 08/09/2024: Better support for loading pretrained weights from Hugging Face models, thanks to the help from [@NielsRogge](https://github.com/NielsRogge)!
- 07/03/2024: Evaluation scripts for reproducing the results reported in the paper, plus checkpoints of TiTok-B64 and TiTok-S128, are available.
- 06/21/2024: Demo code and TiTok-L-32 checkpoints released.
- 06/11/2024: The [tech report](https://arxiv.org/abs/2406.07550) of TiTok is available.

## Installation
```shell
pip3 install -r requirements.txt
```

## Citing
If you use our work in your research, please use the following BibTeX entries.

```BibTeX
@article{yu2024randomized,
  author  = {Qihang Yu and Ju He and Xueqing Deng and Xiaohui Shen and Liang-Chieh Chen},
  title   = {Randomized Autoregressive Visual Generation},
  journal = {arXiv preprint arXiv:2411.00776},
  year    = {2024}
}
```

```BibTeX
@inproceedings{yu2024an,
  author    = {Qihang Yu and Mark Weber and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen},
  title     = {An Image is Worth 32 Tokens for Reconstruction and Generation},
  booktitle = {NeurIPS},
  year      = {2024}
}
```

## Acknowledgement

[MaskGIT](https://github.com/google-research/maskgit)

[Taming-Transformers](https://github.com/CompVis/taming-transformers)

[Open-MUSE](https://github.com/huggingface/open-muse)

[MUSE-Pytorch](https://github.com/baaivision/MUSE-Pytorch)
README_RAR.md ADDED
@@ -0,0 +1,221 @@

# Randomized Autoregressive Visual Generation

RAR is an autoregressive (AR) image generator that is fully compatible with language modeling. It introduces a randomness annealing strategy with a permuted training objective at no additional cost, which enhances the model's ability to learn bidirectional contexts while leaving the autoregressive framework intact. RAR achieves an FID score of 1.48 on the ImageNet-256 benchmark, demonstrating state-of-the-art performance and significantly outperforming prior AR image generators.

<p>
<img src="assets/rar_overview.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/perf_comp.png" alt="teaser" width=90% height=90%>
</p>

## 🚀 Contributions

#### We introduce RAR, an improved training strategy that enables a standard autoregressive image generator to achieve state-of-the-art performance.

#### The proposed RAR is extremely simple yet effective: during training, we randomly permute the input token sequence with probability r, where r starts at 1.0 and linearly decays to 0.0 over the course of training. This simple strategy enables the better bidirectional representation learning that is missing in standard raster-order-based AR image generator training (see the sketch below).

#### RAR keeps the AR framework intact, so it is fully compatible with LLM optimization techniques such as KV-cache, leading to significantly faster sampling than MAR-H or MaskBit while maintaining better performance.
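To make the annealing strategy above concrete, here is a minimal PyTorch sketch of the schedule. It is illustrative only: the function names and the linear 1.0-to-0.0 endpoints follow the description above rather than the repository's training code, and the actual training must also reorder the prediction targets and positional conditioning consistently, which this sketch omits.

```python
import torch

def permutation_prob(step: int, total_steps: int) -> float:
    """Linearly anneal the permutation probability r from 1.0 down to 0.0."""
    t = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return 1.0 - t

def maybe_permute(token_ids: torch.Tensor, step: int, total_steps: int) -> torch.Tensor:
    """With probability r, train on a randomly permuted token order;
    otherwise keep the standard raster order. token_ids: (batch, seq_len)."""
    r = permutation_prob(step, total_steps)
    if torch.rand(()).item() < r:
        perm = torch.randperm(token_ids.shape[1], device=token_ids.device)
        return token_ids[:, perm]
    return token_ids
```

Early in training almost every batch sees a random order (encouraging bidirectional context), while by the end the model is trained almost exclusively in raster order, so standard AR sampling remains unchanged.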
## Model Zoo
| Model | Link | FID |
| ------------- | ------------- | ------------- |
| RAR-B | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_b.bin) | 1.95 (generation) |
| RAR-L | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_l.bin) | 1.70 (generation) |
| RAR-XL | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_xl.bin) | 1.50 (generation) |
| RAR-XXL | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_xxl.bin) | 1.48 (generation) |

Please note that these models are trained only on the limited academic ImageNet dataset and are intended for research purposes only.

## Installation
```shell
pip3 install -r requirements.txt
```

## Get Started
```python
import torch
from PIL import Image
import numpy as np
import demo_util
from huggingface_hub import hf_hub_download
from utils.train_utils import create_pretrained_tokenizer


# Choose one from ["rar_b", "rar_l", "rar_xl", "rar_xxl"]
rar_model_name = ["rar_b", "rar_l", "rar_xl", "rar_xxl"][3]

# download the maskgit-vq tokenizer
hf_hub_download(repo_id="fun-research/TiTok", filename="maskgit-vqgan-imagenet-f16-256.bin", local_dir="./")
# download the rar generator weight
hf_hub_download(repo_id="yucornetto/RAR", filename=f"{rar_model_name}.bin", local_dir="./")

# load config (the RAR generator config also drives inference; the override below is
# assumed to be accepted the same way as the command-line overrides in the next section)
config = demo_util.get_config("configs/training/generator/rar.yaml")
config.experiment.generator_checkpoint = f"{rar_model_name}.bin"
# for model sizes other than RAR-B, also override config.model.generator.hidden_size,
# num_hidden_layers, num_attention_heads, and intermediate_size to match the checkpoint
# (see the sampling commands below for the exact values)

device = "cuda"
# maskgit-vq as tokenizer
tokenizer = create_pretrained_tokenizer(config)
generator = demo_util.get_rar_generator(config)
tokenizer.to(device)
generator.to(device)

# generate an image
sample_labels = [torch.randint(0, 999, size=(1,)).item()]  # random IN-1k class
generated_image = demo_util.sample_fn(
    generator=generator,
    tokenizer=tokenizer,
    labels=sample_labels,
    randomize_temperature=1.0,
    guidance_scale=4.0,
    guidance_scale_pow=0.0,  # constant cfg
    device=device
)
Image.fromarray(generated_image[0]).save(f"assets/rar_generated_{sample_labels[0]}.png")
```

## Testing on ImageNet-1K Benchmark

We provide a [sampling script](./sample_imagenet_rar.py) for reproducing the generation results on the ImageNet-1K benchmark.
```bash
# Prepare ADM evaluation script
git clone https://github.com/openai/guided-diffusion.git

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
```
```bash
# Reproducing RAR-B
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
    experiment.output_dir="rar_b" \
    experiment.generator_checkpoint="rar_b.bin" \
    model.generator.hidden_size=768 \
    model.generator.num_hidden_layers=24 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=3072 \
    model.generator.randomize_temperature=1.0 \
    model.generator.guidance_scale=16.0 \
    model.generator.guidance_scale_pow=2.75
# Run eval script. The resulting FID should be ~1.95
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_b.npz

# Reproducing RAR-L
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
    experiment.output_dir="rar_l" \
    experiment.generator_checkpoint="rar_l.bin" \
    model.generator.hidden_size=1024 \
    model.generator.num_hidden_layers=24 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=4096 \
    model.generator.randomize_temperature=1.02 \
    model.generator.guidance_scale=15.5 \
    model.generator.guidance_scale_pow=2.5
# Run eval script. The resulting FID should be ~1.70
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_l.npz

# Reproducing RAR-XL
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
    experiment.output_dir="rar_xl" \
    experiment.generator_checkpoint="rar_xl.bin" \
    model.generator.hidden_size=1280 \
    model.generator.num_hidden_layers=32 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=5120 \
    model.generator.randomize_temperature=1.02 \
    model.generator.guidance_scale=6.9 \
    model.generator.guidance_scale_pow=1.5
# Run eval script. The resulting FID should be ~1.50
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_xl.npz

# Reproducing RAR-XXL
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
    experiment.output_dir="rar_xxl" \
    experiment.generator_checkpoint="rar_xxl.bin" \
    model.generator.hidden_size=1408 \
    model.generator.num_hidden_layers=40 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=6144 \
    model.generator.randomize_temperature=1.02 \
    model.generator.guidance_scale=8.0 \
    model.generator.guidance_scale_pow=1.2
# Run eval script. The resulting FID should be ~1.48
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_xxl.npz
```
## Training Preparation
We pre-tokenize the whole dataset to speed up the training process. We have uploaded [the pre-tokenized dataset](https://huggingface.co/yucornetto/RAR/blob/main/maskgitvq.jsonl) so you can train RAR directly. The training script will download the prerequisite checkpoints and dataset automatically.
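As a rough illustration of what pre-tokenization means here (a hypothetical sketch, not the script that produced `maskgitvq.jsonl`; the tokenizer's `encode` call and the jsonl field names are assumptions), each training image is encoded once into discrete VQ indices and cached so the VQ encoder never runs during generator training:

```python
import json
import torch

@torch.no_grad()
def pretokenize_to_jsonl(tokenizer, dataloader, out_path="pretokenized.jsonl", device="cuda"):
    """Encode every (image, label) batch once and cache the discrete token ids."""
    tokenizer = tokenizer.to(device).eval()
    with open(out_path, "w") as f:
        for images, labels in dataloader:
            # `encode` is assumed to return discrete code indices of shape (B, seq_len)
            codes = tokenizer.encode(images.to(device))
            for code_row, label in zip(codes, labels):
                record = {"class_id": int(label), "tokens": code_row.flatten().tolist()}
                f.write(json.dumps(record) + "\n")
```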
## Training
We provide example commands to train RAR as follows:
```bash
# Training for RAR-B
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
    experiment.project="rar" \
    experiment.name="rar_b" \
    experiment.output_dir="rar_b" \
    model.generator.hidden_size=768 \
    model.generator.num_hidden_layers=24 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=3072

# Training for RAR-L
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
    experiment.project="rar" \
    experiment.name="rar_l" \
    experiment.output_dir="rar_l" \
    model.generator.hidden_size=1024 \
    model.generator.num_hidden_layers=24 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=4096

# Training for RAR-XL
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
    experiment.project="rar" \
    experiment.name="rar_xl" \
    experiment.output_dir="rar_xl" \
    model.generator.hidden_size=1280 \
    model.generator.num_hidden_layers=32 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=5120

# Training for RAR-XXL
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
    experiment.project="rar" \
    experiment.name="rar_xxl" \
    experiment.output_dir="rar_xxl" \
    model.generator.hidden_size=1408 \
    model.generator.num_hidden_layers=40 \
    model.generator.num_attention_heads=16 \
    model.generator.intermediate_size=6144
```
You may remove the flag "WANDB_MODE=offline" to enable online wandb logging, if you have configured it.

Notably, you can enable gradient checkpointing by adding the flag "model.generator.use_checkpoint=True", and you can adjust the number of machines and GPUs to your own needs. All RAR checkpoints were trained with a global batch size of 2048.


## Visualizations
<p>
<img src="assets/vis1.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/vis2.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/vis3.png" alt="teaser" width=90% height=90%>
</p>


## Citing
If you use our work in your research, please use the following BibTeX entry.

```BibTeX
@article{yu2024randomized,
  author  = {Qihang Yu and Ju He and Xueqing Deng and Xiaohui Shen and Liang-Chieh Chen},
  title   = {Randomized Autoregressive Visual Generation},
  journal = {arXiv preprint arXiv:2411.00776},
  year    = {2024}
}
```
README_TiTok.md ADDED
@@ -0,0 +1,213 @@

# (NeurIPS 2024) Compact and Mighty - Image Tokenization with Only 32 Tokens for both Reconstruction and Generation!

<div align="center">

[![demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Online_Demo-blue)](https://huggingface.co/spaces/fun-research/TiTok)
[![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://yucornetto.github.io/projects/titok.html)
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2406.07550)

</div>

We present a compact 1D tokenizer which can represent an image with as few as 32 discrete tokens. As a result, it leads to a substantial speed-up of the sampling process (e.g., **410× faster** than DiT-XL/2) while obtaining competitive generation quality.

<p>
<img src="assets/titok_teaser.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/speed_vs_perf.png" alt="teaser" width=90% height=90%>
</p>

## 🚀 Contributions

#### We introduce a novel 1D image tokenization framework that breaks the grid constraints of existing 2D tokenization methods, leading to a much more flexible and compact image latent representation (a short token-count comparison follows below).

#### The proposed 1D tokenizer can tokenize a 256 × 256 image into as few as 32 discrete tokens, leading to a significant speed-up (hundreds of times faster than diffusion models) in the generation process, while maintaining state-of-the-art generation quality.

#### We conduct a series of experiments to probe the properties of the rarely studied 1D image tokenization, paving the way towards a compact latent space for efficient and effective image representation.
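To make the compactness claim concrete, here is a quick back-of-the-envelope comparison (plain arithmetic for a 256 × 256 input, not a measurement from the paper):

```python
# A typical 2D tokenizer keeps a spatial grid: a 256x256 image with a
# downsampling factor of 16 becomes a 16x16 grid, i.e. 256 tokens.
image_size, downsample_factor = 256, 16
tokens_2d = (image_size // downsample_factor) ** 2   # 256 tokens

# TiTok instead learns a fixed-length 1D latent sequence decoupled from the
# spatial grid; TiTok-L-32 uses only 32 such tokens.
tokens_1d = 32

print(tokens_2d, tokens_1d, tokens_2d // tokens_1d)  # 256 32 8
```

Since the generator's cost grows with the token sequence length, shrinking the sequence from 256 to 32 tokens is where most of the sampling speed-up comes from.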
## Model Zoo
| Model | Link | FID |
| ------------- | ------------- | ------------- |
| TiTok-L-32 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_l32_imagenet) | 2.21 (reconstruction) |
| TiTok-B-64 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_b64_imagenet) | 1.70 (reconstruction) |
| TiTok-S-128 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_s128_imagenet) | 1.71 (reconstruction) |
| TiTok-L-32 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_l32_imagenet) | 2.77 (generation) |
| TiTok-B-64 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_b64_imagenet) | 2.48 (generation) |
| TiTok-S-128 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_s128_imagenet) | 1.97 (generation) |
| TiTok-BL-64 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl64_vq8k_imagenet) | 2.06 (reconstruction) |
| TiTok-BL-128 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl128_vq8k_imagenet) | 1.49 (reconstruction) |
| TiTok-SL-256 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_sl256_vq8k_imagenet) | 1.03 (reconstruction) |
| TiTok-LL-32 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_ll32_vae_c16_imagenet) | 1.61 (reconstruction) |
| TiTok-BL-64 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl64_vae_c16_imagenet) | 1.25 (reconstruction) |
| TiTok-BL-128 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl128_vae_c16_imagenet) | 0.84 (reconstruction) |

Please note that these models are trained only on the limited academic ImageNet dataset and are intended for research purposes only.

## Installation
```shell
pip3 install -r requirements.txt
```

## Get Started
```python
import torch
from PIL import Image
import numpy as np
import demo_util
from huggingface_hub import hf_hub_download
from modeling.maskgit import ImageBert
from modeling.titok import TiTok

# Choose one from ["tokenizer_titok_l32_imagenet", "tokenizer_titok_b64_imagenet",
# "tokenizer_titok_s128_imagenet", "tokenizer_titok_bl128_vae_c16_imagenet", "tokenizer_titok_bl64_vae_c16_imagenet",
# "tokenizer_titok_ll32_vae_c16_imagenet", "tokenizer_titok_sl256_vq8k_imagenet", "tokenizer_titok_bl128_vq8k_imagenet",
# "tokenizer_titok_bl64_vq8k_imagenet"]
titok_tokenizer = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")
titok_tokenizer.eval()
titok_tokenizer.requires_grad_(False)
titok_generator = ImageBert.from_pretrained("yucornetto/generator_titok_l32_imagenet")
titok_generator.eval()
titok_generator.requires_grad_(False)

# or alternatively, download the raw checkpoints from HF
# hf_hub_download(repo_id="fun-research/TiTok", filename="tokenizer_titok_l32.bin", local_dir="./")
# hf_hub_download(repo_id="fun-research/TiTok", filename="generator_titok_l32.bin", local_dir="./")
# and load them via the config
# config = demo_util.get_config("configs/infer/titok_l32.yaml")
# titok_tokenizer = demo_util.get_titok_tokenizer(config)
# titok_generator = demo_util.get_titok_generator(config)

device = "cuda"
titok_tokenizer = titok_tokenizer.to(device)
titok_generator = titok_generator.to(device)

# reconstruct an image, i.e., image -> 32 tokens -> image
img_path = "assets/ILSVRC2012_val_00010240.png"
image = torch.from_numpy(np.array(Image.open(img_path)).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0
# tokenization
if titok_tokenizer.quantize_mode == "vq":
    encoded_tokens = titok_tokenizer.encode(image.to(device))[1]["min_encoding_indices"]
elif titok_tokenizer.quantize_mode == "vae":
    posteriors = titok_tokenizer.encode(image.to(device))[1]
    encoded_tokens = posteriors.sample()
else:
    raise NotImplementedError
# image assets/ILSVRC2012_val_00010240.png is encoded into tokens
# tensor([[[ 887, 3979,  349,  720, 2809, 2743, 2101,  603, 2205, 1508, 1891, 4015,
#           1317, 2956, 3774, 2296,  484, 2612, 3472, 2330, 3140, 3113, 1056, 3779,
#            654, 2360, 1901, 2908, 2169,  953, 1326, 2598]]], device='cuda:0'),
# with shape torch.Size([1, 1, 32])
print(f"image {img_path} is encoded into tokens {encoded_tokens}, with shape {encoded_tokens.shape}")
# de-tokenization
reconstructed_image = titok_tokenizer.decode_tokens(encoded_tokens)
reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
Image.fromarray(reconstructed_image).save("assets/ILSVRC2012_val_00010240_recon.png")

# generate an image
sample_labels = [torch.randint(0, 999, size=(1,)).item()]  # random IN-1k class
generated_image = demo_util.sample_fn(
    generator=titok_generator,
    tokenizer=titok_tokenizer,
    labels=sample_labels,
    guidance_scale=4.5,
    randomize_temperature=1.0,
    num_sample_steps=8,
    device=device
)
Image.fromarray(generated_image[0]).save(f"assets/generated_{sample_labels[0]}.png")
```

We also provide a [jupyter notebook](demo.ipynb) for a quick tutorial on reconstructing and generating images with TiTok-L-32.

We also support TiTok with a [HuggingFace 🤗 Demo](https://huggingface.co/spaces/fun-research/TiTok)!

## Testing on ImageNet-1K Benchmark

We provide a [sampling script](./sample_imagenet_titok.py) for reproducing the generation results on the ImageNet-1K benchmark.
```bash
# Prepare ADM evaluation script
git clone https://github.com/openai/guided-diffusion.git

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
```
```bash
# Reproducing TiTok-L-32
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_l32.yaml experiment.output_dir="titok_l_32"
# Run eval script. The resulting FID should be ~2.77
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_l_32.npz

# Reproducing TiTok-B-64
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_b64.yaml experiment.output_dir="titok_b_64"
# Run eval script. The resulting FID should be ~2.48
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_b_64.npz

# Reproducing TiTok-S-128
torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_s128.yaml experiment.output_dir="titok_s_128"
# Run eval script. The resulting FID should be ~1.97
python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_s_128.npz
```
## Training Preparation
We use the [webdataset](https://github.com/webdataset/webdataset) format for data loading. To begin with, you need to convert the dataset into the webdataset format. An example script to convert ImageNet to wds format is provided [here](./data/convert_imagenet_to_wds.py); the sketch below illustrates the idea.
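A minimal sketch of such a conversion is shown below. It is not the provided conversion script: it assumes the `webdataset` and `torchvision` packages, a standard ImageFolder layout, and arbitrary shard naming and per-shard counts.

```python
import webdataset as wds
from torchvision.datasets import ImageFolder

def convert_to_wds(imagenet_root: str,
                   out_pattern: str = "imagenet-train-%04d.tar",
                   max_per_shard: int = 5000) -> None:
    """Pack (image, label) pairs into .tar shards readable by webdataset."""
    dataset = ImageFolder(imagenet_root)
    with wds.ShardWriter(out_pattern, maxcount=max_per_shard) as sink:
        for idx, (path, label) in enumerate(dataset.samples):
            with open(path, "rb") as f:
                sink.write({
                    "__key__": f"{idx:08d}",  # unique sample key
                    "jpg": f.read(),          # raw image bytes
                    "cls": label,             # integer class id
                })
```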
Furthermore, the stage-1 training relies on a pre-trained MaskGIT-VQGAN to generate proxy codes as learning targets. You can convert the [official Jax weights](https://github.com/google-research/maskgit) to a PyTorch version using [this script](https://github.com/huggingface/open-muse/blob/main/scripts/convert_maskgit_vqgan.py). Alternatively, we provide a converted version at [HuggingFace](https://huggingface.co/fun-research/TiTok/blob/main/maskgit-vqgan-imagenet-f16-256.bin) and [Google Drive](https://drive.google.com/file/d/1DjZqzJrUt2hwpmUPkjGSBTFEJcOkLY-Q/view?usp=sharing). The MaskGIT-VQGAN weights will be downloaded automatically when you run the training script.

## Training
We provide example commands to train TiTok as follows:
```bash
# Training for TiTok-B64
# Stage 1
WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage1/titok_b64.yaml \
    experiment.project="titok_b64_stage1" \
    experiment.name="titok_b64_stage1_run1" \
    experiment.output_dir="titok_b64_stage1_run1" \
    training.per_gpu_batch_size=32

# Stage 2
WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage2/titok_b64.yaml \
    experiment.project="titok_b64_stage2" \
    experiment.name="titok_b64_stage2_run1" \
    experiment.output_dir="titok_b64_stage2_run1" \
    training.per_gpu_batch_size=32 \
    experiment.init_weight=${PATH_TO_STAGE1_WEIGHT}

# Train Generator (TiTok-B64 as example)
WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_maskgit.py config=configs/training/generator/maskgit.yaml \
    experiment.project="titok_generation" \
    experiment.name="titok_b64_maskgit" \
    experiment.output_dir="titok_b64_maskgit" \
    experiment.tokenizer_checkpoint=${PATH_TO_STAGE1_or_STAGE2_WEIGHT}
```
You may remove the flag "WANDB_MODE=offline" to enable online wandb logging, if you have configured it.

The config _titok_b64.yaml_ can be replaced with _titok_s128.yaml_ or _titok_l32.yaml_ for other TiTok variants.

## Visualizations
<p>
<img src="assets/recon_w_model_size_num_token.png" alt="teaser" width=90% height=90%>
</p>
<p>
<img src="assets/random_vis_l32.png" alt="teaser" width=90% height=90%>
</p>


## Citing
If you use our work in your research, please use the following BibTeX entry.

```BibTeX
@inproceedings{yu2024an,
  author    = {Qihang Yu and Mark Weber and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen},
  title     = {An Image is Worth 32 Tokens for Reconstruction and Generation},
  booktitle = {NeurIPS},
  year      = {2024}
}
```

## Acknowledgement

[MaskGIT](https://github.com/google-research/maskgit)

[Taming-Transformers](https://github.com/CompVis/taming-transformers)

[Open-MUSE](https://github.com/huggingface/open-muse)

[MUSE-Pytorch](https://github.com/baaivision/MUSE-Pytorch)
assets/ILSVRC2012_val_00008636.png ADDED
assets/ILSVRC2012_val_00010240.png ADDED
assets/perf_comp.png ADDED
assets/random_vis_l32.png ADDED (Git LFS)
assets/rar_overview.png ADDED (Git LFS)
assets/recon_w_model_size_num_token.png ADDED (Git LFS)
assets/speed_vs_perf.png ADDED
assets/titok_teaser.png ADDED
assets/vis1.png ADDED (Git LFS)
assets/vis2.png ADDED (Git LFS)
assets/vis3.png ADDED (Git LFS)
configs/infer/titok_b64.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "tokenizer_titok_b64.bin"
|
3 |
+
generator_checkpoint: "generator_titok_b64.bin"
|
4 |
+
output_dir: "titok_b_64"
|
5 |
+
model:
|
6 |
+
vq_model:
|
7 |
+
codebook_size: 4096
|
8 |
+
token_size: 12
|
9 |
+
use_l2_norm: True
|
10 |
+
commitment_cost: 0.25
|
11 |
+
# vit arch
|
12 |
+
vit_enc_model_size: "base"
|
13 |
+
vit_dec_model_size: "base"
|
14 |
+
vit_enc_patch_size: 16
|
15 |
+
vit_dec_patch_size: 16
|
16 |
+
num_latent_tokens: 64
|
17 |
+
finetune_decoder: True
|
18 |
+
|
19 |
+
generator:
|
20 |
+
model_type: "ViT"
|
21 |
+
hidden_size: 768
|
22 |
+
num_hidden_layers: 24
|
23 |
+
num_attention_heads: 16
|
24 |
+
intermediate_size: 3072
|
25 |
+
dropout: 0.1
|
26 |
+
attn_drop: 0.1
|
27 |
+
num_steps: 8
|
28 |
+
class_label_dropout: 0.1
|
29 |
+
image_seq_len: ${model.vq_model.num_latent_tokens}
|
30 |
+
condition_num_classes: 1000
|
31 |
+
|
32 |
+
# sampling hyper-params
|
33 |
+
randomize_temperature: 11.0
|
34 |
+
guidance_scale: 3.0
|
35 |
+
guidance_decay: "linear"
|
36 |
+
|
37 |
+
dataset:
|
38 |
+
preprocessing:
|
39 |
+
crop_size: 256
|
configs/infer/titok_bl128_vae_c16.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "titok_bl128_vae_c16.bin"
|
3 |
+
output_dir: "titok_bl128_vae_c16"
|
4 |
+
model:
|
5 |
+
vq_model:
|
6 |
+
quantize_mode: "vae"
|
7 |
+
token_size: 16
|
8 |
+
# vit arch
|
9 |
+
vit_enc_model_size: "base"
|
10 |
+
vit_dec_model_size: "large"
|
11 |
+
vit_enc_patch_size: 16
|
12 |
+
vit_dec_patch_size: 16
|
13 |
+
num_latent_tokens: 128
|
14 |
+
finetune_decoder: False
|
15 |
+
is_legacy: False
|
16 |
+
|
17 |
+
dataset:
|
18 |
+
preprocessing:
|
19 |
+
crop_size: 256
|
configs/infer/titok_bl128_vq8k.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "tokenizer_titok_bl128_vq8k.bin"
|
3 |
+
output_dir: "titok_bl128_vq8k"
|
4 |
+
model:
|
5 |
+
vq_model:
|
6 |
+
codebook_size: 8192
|
7 |
+
token_size: 64
|
8 |
+
use_l2_norm: False
|
9 |
+
commitment_cost: 0.25
|
10 |
+
# vit arch
|
11 |
+
vit_enc_model_size: "base"
|
12 |
+
vit_dec_model_size: "large"
|
13 |
+
vit_enc_patch_size: 16
|
14 |
+
vit_dec_patch_size: 16
|
15 |
+
num_latent_tokens: 128
|
16 |
+
finetune_decoder: False
|
17 |
+
is_legacy: False
|
18 |
+
|
19 |
+
dataset:
|
20 |
+
preprocessing:
|
21 |
+
crop_size: 256
|
configs/infer/titok_bl64_vae_c16.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "titok_bl64_vae_c16.bin"
|
3 |
+
output_dir: "titok_bl64_vae_c16"
|
4 |
+
model:
|
5 |
+
vq_model:
|
6 |
+
quantize_mode: "vae"
|
7 |
+
token_size: 16
|
8 |
+
# vit arch
|
9 |
+
vit_enc_model_size: "base"
|
10 |
+
vit_dec_model_size: "large"
|
11 |
+
vit_enc_patch_size: 16
|
12 |
+
vit_dec_patch_size: 16
|
13 |
+
num_latent_tokens: 64
|
14 |
+
finetune_decoder: False
|
15 |
+
is_legacy: False
|
16 |
+
|
17 |
+
dataset:
|
18 |
+
preprocessing:
|
19 |
+
crop_size: 256
|
configs/infer/titok_bl64_vq8k.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "tokenizer_titok_bl64_vq8k.bin"
|
3 |
+
output_dir: "titok_bl64_vq8k"
|
4 |
+
model:
|
5 |
+
vq_model:
|
6 |
+
codebook_size: 8192
|
7 |
+
token_size: 64
|
8 |
+
use_l2_norm: False
|
9 |
+
commitment_cost: 0.25
|
10 |
+
# vit arch
|
11 |
+
vit_enc_model_size: "base"
|
12 |
+
vit_dec_model_size: "large"
|
13 |
+
vit_enc_patch_size: 16
|
14 |
+
vit_dec_patch_size: 16
|
15 |
+
num_latent_tokens: 64
|
16 |
+
finetune_decoder: False
|
17 |
+
is_legacy: False
|
18 |
+
|
19 |
+
dataset:
|
20 |
+
preprocessing:
|
21 |
+
crop_size: 256
|
configs/infer/titok_l32.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "tokenizer_titok_l32.bin"
|
3 |
+
generator_checkpoint: "generator_titok_l32.bin"
|
4 |
+
output_dir: "titok_l_32"
|
5 |
+
model:
|
6 |
+
vq_model:
|
7 |
+
codebook_size: 4096
|
8 |
+
token_size: 12
|
9 |
+
use_l2_norm: True
|
10 |
+
commitment_cost: 0.25
|
11 |
+
# vit arch
|
12 |
+
vit_enc_model_size: "large"
|
13 |
+
vit_dec_model_size: "large"
|
14 |
+
vit_enc_patch_size: 16
|
15 |
+
vit_dec_patch_size: 16
|
16 |
+
num_latent_tokens: 32
|
17 |
+
finetune_decoder: True
|
18 |
+
|
19 |
+
generator:
|
20 |
+
model_type: "ViT"
|
21 |
+
hidden_size: 768
|
22 |
+
num_hidden_layers: 24
|
23 |
+
num_attention_heads: 16
|
24 |
+
intermediate_size: 3072
|
25 |
+
dropout: 0.1
|
26 |
+
attn_drop: 0.1
|
27 |
+
num_steps: 8
|
28 |
+
class_label_dropout: 0.1
|
29 |
+
image_seq_len: ${model.vq_model.num_latent_tokens}
|
30 |
+
condition_num_classes: 1000
|
31 |
+
|
32 |
+
# sampling hyper-params
|
33 |
+
randomize_temperature: 9.5
|
34 |
+
guidance_scale: 4.5
|
35 |
+
guidance_decay: "linear"
|
36 |
+
|
37 |
+
|
38 |
+
dataset:
|
39 |
+
preprocessing:
|
40 |
+
crop_size: 256
|
configs/infer/titok_ll32_vae_c16.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "titok_ll32_vae_c16.bin"
|
3 |
+
output_dir: "titok_ll32_vae_c16"
|
4 |
+
model:
|
5 |
+
vq_model:
|
6 |
+
quantize_mode: "vae"
|
7 |
+
token_size: 16
|
8 |
+
# vit arch
|
9 |
+
vit_enc_model_size: "large"
|
10 |
+
vit_dec_model_size: "large"
|
11 |
+
vit_enc_patch_size: 16
|
12 |
+
vit_dec_patch_size: 16
|
13 |
+
num_latent_tokens: 32
|
14 |
+
finetune_decoder: False
|
15 |
+
is_legacy: False
|
16 |
+
|
17 |
+
dataset:
|
18 |
+
preprocessing:
|
19 |
+
crop_size: 256
|
configs/infer/titok_s128.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "tokenizer_titok_s128.bin"
|
3 |
+
generator_checkpoint: "generator_titok_s128.bin"
|
4 |
+
output_dir: "titok_s_128"
|
5 |
+
model:
|
6 |
+
vq_model:
|
7 |
+
codebook_size: 4096
|
8 |
+
token_size: 12
|
9 |
+
use_l2_norm: True
|
10 |
+
commitment_cost: 0.25
|
11 |
+
# vit arch
|
12 |
+
vit_enc_model_size: "small"
|
13 |
+
vit_dec_model_size: "small"
|
14 |
+
vit_enc_patch_size: 16
|
15 |
+
vit_dec_patch_size: 16
|
16 |
+
num_latent_tokens: 128
|
17 |
+
finetune_decoder: True
|
18 |
+
|
19 |
+
generator:
|
20 |
+
model_type: "UViT"
|
21 |
+
hidden_size: 1024
|
22 |
+
num_hidden_layers: 20
|
23 |
+
num_attention_heads: 16
|
24 |
+
intermediate_size: 4096
|
25 |
+
dropout: 0.1
|
26 |
+
attn_drop: 0.1
|
27 |
+
num_steps: 64
|
28 |
+
class_label_dropout: 0.1
|
29 |
+
image_seq_len: ${model.vq_model.num_latent_tokens}
|
30 |
+
condition_num_classes: 1000
|
31 |
+
|
32 |
+
# sampling hyper-params
|
33 |
+
randomize_temperature: 2.8
|
34 |
+
guidance_scale: 6.9
|
35 |
+
guidance_decay: "power-cosine"
|
36 |
+
|
37 |
+
dataset:
|
38 |
+
preprocessing:
|
39 |
+
crop_size: 256
|
configs/infer/titok_sl256_vq8k.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment:
|
2 |
+
tokenizer_checkpoint: "tokenizer_titok_sl256_vq8k.bin"
|
3 |
+
output_dir: "titok_sl256_vq8k"
|
4 |
+
model:
|
5 |
+
vq_model:
|
6 |
+
codebook_size: 8192
|
7 |
+
token_size: 64
|
8 |
+
use_l2_norm: False
|
9 |
+
commitment_cost: 0.25
|
10 |
+
# vit arch
|
11 |
+
vit_enc_model_size: "small"
|
12 |
+
vit_dec_model_size: "large"
|
13 |
+
vit_enc_patch_size: 16
|
14 |
+
vit_dec_patch_size: 16
|
15 |
+
num_latent_tokens: 256
|
16 |
+
finetune_decoder: False
|
17 |
+
is_legacy: False
|
18 |
+
|
19 |
+
dataset:
|
20 |
+
preprocessing:
|
21 |
+
crop_size: 256
|
configs/training/generator/maskgit.yaml
ADDED
@@ -0,0 +1,86 @@
+experiment:
+    project: "titok_generation"
+    name: "titok_b64_maskgit"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    tokenizer_checkpoint: "tokenizer_titok_b64.bin"
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.25
+        # vit arch
+        vit_enc_model_size: "base"
+        vit_dec_model_size: "base"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 64
+        finetune_decoder: True
+
+    generator:
+        model_type: "ViT"
+        hidden_size: 768
+        num_hidden_layers: 24
+        num_attention_heads: 16
+        intermediate_size: 3072
+        dropout: 0.1
+        attn_drop: 0.1
+        num_steps: 8
+        class_label_dropout: 0.1
+        image_seq_len: ${model.vq_model.num_latent_tokens}
+        condition_num_classes: 1000
+
+        # sampling hyper-params used on the fly
+        randomize_temperature: 1.0
+        guidance_scale: 4.5
+        guidance_decay: "constant"
+
+losses:
+    label_smoothing: 0.1
+    loss_weight_unmasked_token: 0.1
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: False
+        random_flip: True
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 2e-4
+        beta1: 0.9
+        beta2: 0.96
+        weight_decay: 0.03
+
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 10_000
+        end_lr: 1e-5
+
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 64 # 32 GPU, total batch size 2048
+    mixed_precision: "bf16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 500_000
+    max_grad_norm: 1.0
configs/training/generator/rar.yaml
ADDED
@@ -0,0 +1,78 @@
+experiment:
+    project: "rar_generation"
+    name: "rar-b"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+
+model:
+    vq_model:
+        codebook_size: 1024
+        token_size: 256
+        num_latent_tokens: 256
+        finetune_decoder: False
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+    generator:
+        hidden_size: 768
+        num_hidden_layers: 24
+        num_attention_heads: 16
+        intermediate_size: 3072
+        dropout: 0.1
+        attn_drop: 0.1
+        class_label_dropout: 0.1
+        image_seq_len: 256
+        condition_num_classes: 1000
+
+        # sampling hyper-params for RAR-B
+        randomize_temperature: 1.0
+        guidance_scale: 16.0
+        guidance_scale_pow: 2.75
+        use_checkpoint: False # True to save memory
+
+        randomness_anneal_start: 125000 # 200 epoch
+        randomness_anneal_end: 187500 # 300 epoch
+
+dataset:
+    params:
+        # use pretokenized dataset for speed-up
+        pretokenization: "maskgitvq.jsonl"
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: False
+        random_flip: True
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 4e-4
+        beta1: 0.9
+        beta2: 0.96
+        weight_decay: 0.03
+
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 62_500 # 100 epochs with bsz 2048
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 64 # 32 GPU, total batch size 2048
+    mixed_precision: "bf16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: False
+    seed: 42
+    max_train_steps: 250_000
+    max_grad_norm: 1.0
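Training and inference configs in this commit are plain OmegaConf YAML, so individual fields can be overridden from the command line; get_config_cli() in demo_util.py (added later in this commit) merges CLI dot-list arguments on top of the YAML file. A hedged sketch of that pattern (the train.py invocation in the comment is hypothetical; only the merge logic mirrors this commit):

# Sketch of the OmegaConf merge used by get_config_cli() in demo_util.py.
# Hypothetical invocation:
#   python train.py config=configs/training/generator/rar.yaml training.per_gpu_batch_size=32
from omegaconf import OmegaConf

cli_conf = OmegaConf.from_cli()              # dot-list args, e.g. training.per_gpu_batch_size=32
yaml_conf = OmegaConf.load(cli_conf.config)  # the YAML file shown above
conf = OmegaConf.merge(yaml_conf, cli_conf)  # CLI values take precedence over the YAML

print(conf.training.per_gpu_batch_size)
print(conf.optimizer.params.learning_rate)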
configs/training/stage1/titok_b64.yaml
ADDED
@@ -0,0 +1,70 @@
+experiment:
+    project: "titok_b_64_stage1"
+    name: "titok_b_64_stage1_run1"
+    output_dir: "titok_b_64_stage1_run1"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    init_weight: ""
+
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.25
+        # vit arch
+        vit_enc_model_size: "base"
+        vit_dec_model_size: "base"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 64
+        finetune_decoder: False
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+losses:
+    quantizer_weight: 1.0
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: True
+        random_flip: True
+
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 1e-4
+        beta1: 0.9
+        beta2: 0.99
+        weight_decay: 1e-4
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 10_000
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 32
+    mixed_precision: "fp16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 1_000_000
+    num_generated_images: 2
+    max_grad_norm: 1.0
configs/training/stage1/titok_l32.yaml
ADDED
@@ -0,0 +1,70 @@
+experiment:
+    project: "titok_l_32_stage1"
+    name: "titok_l_32_stage1_run1"
+    output_dir: "titok_l_32_stage1_run1"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    init_weight: ""
+
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.25
+        # vit arch
+        vit_enc_model_size: "large"
+        vit_dec_model_size: "large"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 32
+        finetune_decoder: False
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+losses:
+    quantizer_weight: 1.0
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: True
+        random_flip: True
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 1e-4
+        beta1: 0.9
+        beta2: 0.99
+        weight_decay: 1e-4
+        epsilon: 1e-8
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 10_000
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 32
+    mixed_precision: "fp16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 1_000_000
+    num_generated_images: 2
+    max_grad_norm: 1.0
configs/training/stage1/titok_s128.yaml
ADDED
@@ -0,0 +1,70 @@
+experiment:
+    project: "titok_s_128_stage1"
+    name: "titok_s_128_stage1_run1"
+    output_dir: "titok_s_128_stage1_run1"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    init_weight: ""
+
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.25
+        # vit arch
+        vit_enc_model_size: "small"
+        vit_dec_model_size: "small"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 128
+        finetune_decoder: False
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+losses:
+    quantizer_weight: 1.0
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: True
+        random_flip: True
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 1e-4
+        beta1: 0.9
+        beta2: 0.99
+        weight_decay: 1e-4
+        epsilon: 1e-8
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 10_000
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 32
+    mixed_precision: "fp16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 1_000_000
+    num_generated_images: 2
+    max_grad_norm: 1.0
configs/training/stage2/titok_b64.yaml
ADDED
@@ -0,0 +1,80 @@
+experiment:
+    project: "titok_b_64_stage2"
+    name: "titok_b_64_stage2_run1"
+    output_dir: "titok_b_64_stage2_run1"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    init_weight: "titok_b_64_stage1.bin"
+
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.0
+        # vit arch
+        vit_enc_model_size: "base"
+        vit_dec_model_size: "base"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 64
+        finetune_decoder: True
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+losses:
+    discriminator_start: 20_000
+    quantizer_weight: 0.0
+    discriminator_factor: 1.0
+    discriminator_weight: 0.01
+    perceptual_loss: "convnext_s"
+    perceptual_weight: 0.1
+    reconstruction_loss: "l2"
+    reconstruction_weight: 1.0
+    lecam_regularization_weight: 0.001
+
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: True
+        random_flip: True
+
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 1e-4
+        discriminator_learning_rate: 1e-4
+        beta1: 0.9
+        beta2: 0.999
+        weight_decay: 1e-4
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 5_000
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 32
+    mixed_precision: "fp16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 500_000
+    num_generated_images: 2
+    max_grad_norm: 1.0
configs/training/stage2/titok_l32.yaml
ADDED
@@ -0,0 +1,79 @@
+experiment:
+    project: "titok_l_32_stage2"
+    name: "titok_l_32_stage2_run1"
+    output_dir: "titok_l_32_stage2_run1"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    init_weight: "titok_l_32_stage1.bin"
+
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.0
+        # vit arch
+        vit_enc_model_size: "large"
+        vit_dec_model_size: "large"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 32
+        finetune_decoder: True
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+losses:
+    discriminator_start: 20_000
+    quantizer_weight: 0.0
+    discriminator_factor: 1.0
+    discriminator_weight: 0.01
+    perceptual_loss: "convnext_s"
+    perceptual_weight: 0.1
+    reconstruction_loss: "l2"
+    reconstruction_weight: 1.0
+    lecam_regularization_weight: 0.001
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: True
+        random_flip: True
+
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 1e-4
+        discriminator_learning_rate: 1e-4
+        beta1: 0.9
+        beta2: 0.999
+        weight_decay: 1e-4
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 5_000
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 32
+    mixed_precision: "fp16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 500_000
+    num_generated_images: 2
+    max_grad_norm: 1.0
configs/training/stage2/titok_s128.yaml
ADDED
@@ -0,0 +1,79 @@
+experiment:
+    project: "titok_s_128_stage2"
+    name: "titok_s_128_stage2_run1"
+    output_dir: "titok_s_128_stage2_run1"
+    max_train_examples: 1_281_167
+    save_every: 50_000
+    eval_every: 50_000
+    generate_every: 5_000
+    log_every: 50
+    log_grad_norm_every: 1_000
+    resume: True
+    init_weight: "titok_s_128_stage1.bin"
+
+
+model:
+    vq_model:
+        codebook_size: 4096
+        token_size: 12
+        use_l2_norm: True
+        commitment_cost: 0.0
+        # vit arch
+        vit_enc_model_size: "small"
+        vit_dec_model_size: "small"
+        vit_enc_patch_size: 16
+        vit_dec_patch_size: 16
+        num_latent_tokens: 128
+        finetune_decoder: True
+        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
+
+losses:
+    discriminator_start: 20_000
+    quantizer_weight: 0.0
+    discriminator_factor: 1.0
+    discriminator_weight: 0.01
+    perceptual_loss: "convnext_s"
+    perceptual_weight: 0.1
+    reconstruction_loss: "l2"
+    reconstruction_weight: 1.0
+    lecam_regularization_weight: 0.001
+
+dataset:
+    params:
+        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
+        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
+        num_workers_per_gpu: 12
+    preprocessing:
+        resize_shorter_edge: 256
+        crop_size: 256
+        random_crop: True
+        random_flip: True
+
+
+optimizer:
+    name: adamw
+    params:
+        learning_rate: 1e-4
+        discriminator_learning_rate: 1e-4
+        beta1: 0.9
+        beta2: 0.999
+        weight_decay: 1e-4
+
+lr_scheduler:
+    scheduler: "cosine"
+    params:
+        learning_rate: ${optimizer.params.learning_rate}
+        warmup_steps: 5_000
+        end_lr: 1e-5
+
+training:
+    gradient_accumulation_steps: 1
+    per_gpu_batch_size: 32
+    mixed_precision: "fp16"
+    enable_tf32: True
+    enable_wandb: True
+    use_ema: True
+    seed: 42
+    max_train_steps: 500_000
+    num_generated_images: 2
+    max_grad_norm: 1.0
data/__init__.py
ADDED
@@ -0,0 +1 @@
+from .webdataset_reader import SimpleImageDataset, PretoeknizedDataSetJSONL
data/convert_imagenet_to_wds.py
ADDED
@@ -0,0 +1,68 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py
+
+import argparse
+import os
+import sys
+import time
+
+import webdataset as wds
+from datasets import load_dataset
+
+
+def convert_imagenet_to_wds(output_dir, max_train_samples_per_shard, max_val_samples_per_shard):
+    assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar"))
+    assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar"))
+
+    opat = os.path.join(output_dir, "imagenet-train-%06d.tar")
+    output = wds.ShardWriter(opat, maxcount=max_train_samples_per_shard)
+    dataset = load_dataset("imagenet-1k", streaming=True, split="train", use_auth_token=True)
+    now = time.time()
+    for i, example in enumerate(dataset):
+        if i % max_train_samples_per_shard == 0:
+            print(i, file=sys.stderr)
+        img, label = example["image"], example["label"]
+        output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
+    output.close()
+    time_taken = time.time() - now
+    print(f"Wrote {i+1} train examples in {time_taken // 3600} hours.")
+
+    opat = os.path.join(output_dir, "imagenet-val-%06d.tar")
+    output = wds.ShardWriter(opat, maxcount=max_val_samples_per_shard)
+    dataset = load_dataset("imagenet-1k", streaming=True, split="validation", use_auth_token=True)
+    now = time.time()
+    for i, example in enumerate(dataset):
+        if i % max_val_samples_per_shard == 0:
+            print(i, file=sys.stderr)
+        img, label = example["image"], example["label"]
+        output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
+    output.close()
+    time_taken = time.time() - now
+    print(f"Wrote {i+1} val examples in {time_taken // 60} min.")
+
+
+if __name__ == "__main__":
+    # create parser object
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory.")
+    parser.add_argument("--max_train_samples_per_shard", type=int, default=4000, help="Maximum number of train samples per shard.")
+    parser.add_argument("--max_val_samples_per_shard", type=int, default=1000, help="Maximum number of validation samples per shard.")
+    args = parser.parse_args()
+
+    # create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+    convert_imagenet_to_wds(args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard)
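After running the conversion script above, each shard stores one sample per "jpg"/"cls" key pair written by wds.ShardWriter. A small sketch, assuming a first train shard exists at the path below, to sanity-check the output (the path is an assumption; the keys match what the script writes):

# Sketch: read back one shard produced by convert_imagenet_to_wds.py.
import webdataset as wds

dataset = (
    wds.WebDataset("imagenet_sharded/train/imagenet-train-000000.tar")
    .decode("pil")            # decode the "jpg" entry into a PIL image
    .to_tuple("jpg", "cls")   # yield (image, class_id) pairs
)

for image, class_id in dataset:
    print(image.size, int(class_id))
    break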
data/webdataset_reader.py
ADDED
@@ -0,0 +1,227 @@
+"""This file contains the definition of data loader using webdataset.
+
+This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
+All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
+
+Reference:
+    https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
+    https://github.com/huggingface/open-muse/blob/main/training/data.py
+"""
+
+import math
+from typing import List, Union, Text
+import webdataset as wds
+import torch
+from torch.utils.data import default_collate
+from torchvision import transforms
+from torch.utils.data import Dataset
+import linecache
+import json
+
+
+def filter_keys(key_set):
+    def _f(dictionary):
+        return {k: v for k, v in dictionary.items() if k in key_set}
+
+    return _f
+
+
+class ImageTransform:
+    def __init__(self,
+                 resize_shorter_edge: int = 256,
+                 crop_size: int = 256,
+                 random_crop: bool = True,
+                 random_flip: bool = True,
+                 normalize_mean: List[float] = [0., 0., 0.],
+                 normalize_std: List[float] = [1., 1., 1.]):
+        """Initializes the WebDatasetReader with specified augmentation parameters.
+
+        Args:
+            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
+            crop_size: An integer, the size to crop the input image to.
+            random_crop: A boolean, whether to use random crop augmentation during training.
+            random_flip: A boolean, whether to use random flipping augmentation during training.
+            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
+            normalize_std: A list of float, the normalization std used to normalize the image tensor.
+
+        Raises:
+            NotImplementedError: If the interpolation mode is not one of ["bicubic", "bilinear"].
+        """
+        train_transform = []
+        interpolation = transforms.InterpolationMode.BICUBIC
+
+        train_transform.append(
+            transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True))
+        if random_crop:
+            train_transform.append(transforms.RandomCrop(crop_size))
+        else:
+            train_transform.append(transforms.CenterCrop(crop_size))
+        if random_flip:
+            train_transform.append(transforms.RandomHorizontalFlip())
+        train_transform.append(transforms.ToTensor())
+        # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1],
+        # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1].
+        train_transform.append(transforms.Normalize(normalize_mean, normalize_std))
+
+        self.train_transform = transforms.Compose(train_transform)
+        self.eval_transform = transforms.Compose(
+            [
+                # Note that we always resize to crop_size during eval to ensure the results
+                # can be compared against reference numbers on ImageNet etc.
+                transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
+                transforms.CenterCrop(crop_size),
+                transforms.ToTensor(),
+                transforms.Normalize(normalize_mean, normalize_std)
+            ]
+        )
+        print(f"self.train_transform: {self.train_transform}")
+        print(f"self.eval_transform: {self.eval_transform}")
+
+
+class SimpleImageDataset:
+    def __init__(
+        self,
+        train_shards_path: Union[Text, List[Text]],
+        eval_shards_path: Union[Text, List[Text]],
+        num_train_examples: int,
+        per_gpu_batch_size: int,
+        global_batch_size: int,
+        num_workers_per_gpu: int = 12,
+        resize_shorter_edge: int = 256,
+        crop_size: int = 256,
+        random_crop = True,
+        random_flip = True,
+        normalize_mean: List[float] = [0., 0., 0.],
+        normalize_std: List[float] = [1., 1., 1.],
+    ):
+        """Initializes the WebDatasetReader class.
+
+        Args:
+            train_shards_path: A string or list of string, path to the training data shards in webdataset format.
+            eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
+            num_train_examples: An integer, total number of training examples.
+            per_gpu_batch_size: An integer, number of examples per GPU batch.
+            global_batch_size: An integer, total number of examples in a batch across all GPUs.
+            num_workers_per_gpu: An integer, number of workers per GPU.
+            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
+            crop_size: An integer, the size to crop the input image to.
+            random_crop: A boolean, whether to use random crop augmentation during training.
+            random_flip: A boolean, whether to use random flipping augmentation during training.
+            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
+            normalize_std: A list of float, the normalization std used to normalize the image tensor.
+        """
+        transform = ImageTransform(
+            resize_shorter_edge, crop_size, random_crop, random_flip,
+            normalize_mean, normalize_std)
+
+        train_processing_pipeline = [
+            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
+            wds.rename(
+                image="jpg;png;jpeg;webp",
+                class_id="cls",
+                handler=wds.warn_and_continue,
+            ),
+            wds.map(filter_keys(set(["image", "class_id", "filename"]))),
+            wds.map_dict(
+                image=transform.train_transform,
+                class_id=lambda x: int(x),
+                handler=wds.warn_and_continue,
+            ),
+        ]
+
+        test_processing_pipeline = [
+            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
+            wds.rename(
+                image="jpg;png;jpeg;webp",
+                class_id="cls",
+                handler=wds.warn_and_continue,
+            ),
+            wds.map(filter_keys(set(["image", "class_id", "filename"]))),
+            wds.map_dict(
+                image=transform.eval_transform,
+                class_id=lambda x: int(x),
+                handler=wds.warn_and_continue,
+            ),
+        ]
+
+        # Create train dataset and loader.
+        pipeline = [
+            wds.ResampledShards(train_shards_path),
+            wds.tarfile_to_samples(handler=wds.warn_and_continue),
+            wds.shuffle(bufsize=5000,
+                        initial=1000),
+            *train_processing_pipeline,
+            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+        ]
+
+        num_batches = math.ceil(num_train_examples / global_batch_size)
+        num_worker_batches = math.ceil(num_train_examples /
+                                       (global_batch_size * num_workers_per_gpu))
+        num_batches = num_worker_batches * num_workers_per_gpu
+        num_samples = num_batches * global_batch_size
+
+        # Each worker is iterating over the complete dataset.
+        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+        self._train_dataloader = wds.WebLoader(
+            self._train_dataset,
+            batch_size=None,
+            shuffle=False,
+            num_workers=num_workers_per_gpu,
+            pin_memory=True,
+            persistent_workers=True,
+        )
+        # Add meta-data to dataloader instance for convenience.
+        self._train_dataloader.num_batches = num_batches
+        self._train_dataloader.num_samples = num_samples
+
+        # Create eval dataset and loader.
+        pipeline = [
+            wds.SimpleShardList(eval_shards_path),
+            wds.split_by_worker,
+            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
+            *test_processing_pipeline,
+            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
+        ]
+        self._eval_dataset = wds.DataPipeline(*pipeline)
+        self._eval_dataloader = wds.WebLoader(
+            self._eval_dataset,
+            batch_size=None,
+            shuffle=False,
+            num_workers=num_workers_per_gpu,
+            pin_memory=True,
+            persistent_workers=True,
+        )
+
+    @property
+    def train_dataset(self):
+        return self._train_dataset
+
+    @property
+    def train_dataloader(self):
+        return self._train_dataloader
+
+    @property
+    def eval_dataset(self):
+        return self._eval_dataset
+
+    @property
+    def eval_dataloader(self):
+        return self._eval_dataloader
+
+
+class PretoeknizedDataSetJSONL(Dataset):
+    def __init__(self, data_path):
+        super().__init__()
+        self.jsonl_file = data_path
+        self.num_lines = sum(1 for _ in open(self.jsonl_file))
+        # Ensure the file is cached
+        linecache.checkcache(self.jsonl_file)
+        print("Number of data:", self.num_lines)
+
+    def __len__(self):
+        return self.num_lines
+
+    def __getitem__(self, idx):
+        line = linecache.getline(self.jsonl_file, idx + 1).strip()
+        data = json.loads(line)
+        return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
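SimpleImageDataset above wires the train and eval WebDataset pipelines together and exposes them as dataloader properties. A minimal construction sketch, assuming shards at the paths used in the training configs and a single-GPU setup (so global_batch_size equals per_gpu_batch_size; the batch sizes and worker count are placeholders):

# Sketch: build and iterate the loaders defined by SimpleImageDataset above.
from data.webdataset_reader import SimpleImageDataset

dataset = SimpleImageDataset(
    train_shards_path="imagenet_sharded/train/imagenet-train-{0000..0252}.tar",
    eval_shards_path="imagenet_sharded/val/imagenet-val-{0000..0009}.tar",
    num_train_examples=1_281_167,
    per_gpu_batch_size=32,
    global_batch_size=32,
    num_workers_per_gpu=4,
)

# Each batch is a dict with "image" and "class_id" keys collated by default_collate.
for batch in dataset.train_dataloader:
    images, class_ids = batch["image"], batch["class_id"]
    print(images.shape, class_ids.shape)  # e.g. torch.Size([32, 3, 256, 256])
    break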
demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
demo_util.py
ADDED
@@ -0,0 +1,108 @@
+"""Demo file for sampling images from TiTok.
+
+Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+
+import torch
+
+from omegaconf import OmegaConf
+from modeling.titok import TiTok
+from modeling.maskgit import ImageBert, UViTBert
+from modeling.rar import RAR
+
+
+def get_config_cli():
+    cli_conf = OmegaConf.from_cli()
+
+    yaml_conf = OmegaConf.load(cli_conf.config)
+    conf = OmegaConf.merge(yaml_conf, cli_conf)
+
+    return conf
+
+def get_config(config_path):
+    conf = OmegaConf.load(config_path)
+    return conf
+
+def get_titok_tokenizer(config):
+    tokenizer = TiTok(config)
+    tokenizer.load_state_dict(torch.load(config.experiment.tokenizer_checkpoint, map_location="cpu"))
+    tokenizer.eval()
+    tokenizer.requires_grad_(False)
+    return tokenizer
+
+def get_titok_generator(config):
+    if config.model.generator.model_type == "ViT":
+        model_cls = ImageBert
+    elif config.model.generator.model_type == "UViT":
+        model_cls = UViTBert
+    else:
+        raise ValueError(f"Unsupported model type {config.model.generator.model_type}")
+    generator = model_cls(config)
+    generator.load_state_dict(torch.load(config.experiment.generator_checkpoint, map_location="cpu"))
+    generator.eval()
+    generator.requires_grad_(False)
+    return generator
+
+def get_rar_generator(config):
+    model_cls = RAR
+    generator = model_cls(config)
+    generator.load_state_dict(torch.load(config.experiment.generator_checkpoint, map_location="cpu"))
+    generator.eval()
+    generator.requires_grad_(False)
+    generator.set_random_ratio(0)
+    return generator
+
+
+@torch.no_grad()
+def sample_fn(generator,
+              tokenizer,
+              labels=None,
+              guidance_scale=3.0,
+              guidance_decay="constant",
+              guidance_scale_pow=3.0,
+              randomize_temperature=2.0,
+              softmax_temperature_annealing=False,
+              num_sample_steps=8,
+              device="cuda",
+              return_tensor=False):
+    generator.eval()
+    tokenizer.eval()
+    if labels is None:
+        # goldfish, chicken, tiger, cat, hourglass, ship, dog, race car, airliner, teddy bear, random
+        labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, torch.randint(0, 999, size=(1,))]
+
+    if not isinstance(labels, torch.Tensor):
+        labels = torch.LongTensor(labels).to(device)
+
+    generated_tokens = generator.generate(
+        condition=labels,
+        guidance_scale=guidance_scale,
+        guidance_decay=guidance_decay,
+        guidance_scale_pow=guidance_scale_pow,
+        randomize_temperature=randomize_temperature,
+        softmax_temperature_annealing=softmax_temperature_annealing,
+        num_sample_steps=num_sample_steps)
+
+    generated_image = tokenizer.decode_tokens(
+        generated_tokens.view(generated_tokens.shape[0], -1)
+    )
+    if return_tensor:
+        return generated_image
+
+    generated_image = torch.clamp(generated_image, 0.0, 1.0)
+    generated_image = (generated_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
+
+    return generated_image
evaluator/__init__.py
ADDED
@@ -0,0 +1 @@
+from .evaluator import VQGANEvaluator
evaluator/evaluator.py
ADDED
@@ -0,0 +1,245 @@
+"""This file contains a class to evaluate the reconstruction results.
+
+Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import warnings
+
+from typing import Sequence, Optional, Mapping, Text
+import numpy as np
+from scipy import linalg
+import torch
+import torch.nn.functional as F
+
+from .inception import get_inception_model
+
+
+def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
+    """Computes covariance of the input tensor.
+
+    Args:
+        sigma: A torch.Tensor, sum of outer products of input features.
+        total: A torch.Tensor, sum of all input features.
+        num_examples: An integer, number of examples in the input tensor.
+    Returns:
+        A torch.Tensor, covariance of the input tensor.
+    """
+    if num_examples == 0:
+        return torch.zeros_like(sigma)
+
+    sub_matrix = torch.outer(total, total)
+    sub_matrix = sub_matrix / num_examples
+
+    return (sigma - sub_matrix) / (num_examples - 1)
+
+
+class VQGANEvaluator:
+    def __init__(
+        self,
+        device,
+        enable_rfid: bool = True,
+        enable_inception_score: bool = True,
+        enable_codebook_usage_measure: bool = False,
+        enable_codebook_entropy_measure: bool = False,
+        num_codebook_entries: int = 1024
+    ):
+        """Initializes VQGAN Evaluator.
+
+        Args:
+            device: The device to use for evaluation.
+            enable_rfid: A boolean, whether enabling rFID score.
+            enable_inception_score: A boolean, whether enabling Inception Score.
+            enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
+            enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
+            num_codebook_entries: An integer, the number of codebook entries.
+        """
+        self._device = device
+
+        self._enable_rfid = enable_rfid
+        self._enable_inception_score = enable_inception_score
+        self._enable_codebook_usage_measure = enable_codebook_usage_measure
+        self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
+        self._num_codebook_entries = num_codebook_entries
+
+        # Variables related to Inception score and rFID.
+        self._inception_model = None
+        self._is_num_features = 0
+        self._rfid_num_features = 0
+        if self._enable_inception_score or self._enable_rfid:
+            self._rfid_num_features = 2048
+            self._is_num_features = 1008
+            self._inception_model = get_inception_model().to(self._device)
+            self._inception_model.eval()
+        self._is_eps = 1e-16
+        self._rfid_eps = 1e-6
+
+        self.reset_metrics()
+
+    def reset_metrics(self):
+        """Resets all metrics."""
+        self._num_examples = 0
+        self._num_updates = 0
+
+        self._is_prob_total = torch.zeros(
+            self._is_num_features, dtype=torch.float64, device=self._device
+        )
+        self._is_total_kl_d = torch.zeros(
+            self._is_num_features, dtype=torch.float64, device=self._device
+        )
+        self._rfid_real_sigma = torch.zeros(
+            (self._rfid_num_features, self._rfid_num_features),
+            dtype=torch.float64, device=self._device
+        )
+        self._rfid_real_total = torch.zeros(
+            self._rfid_num_features, dtype=torch.float64, device=self._device
+        )
+        self._rfid_fake_sigma = torch.zeros(
+            (self._rfid_num_features, self._rfid_num_features),
+            dtype=torch.float64, device=self._device
+        )
+        self._rfid_fake_total = torch.zeros(
+            self._rfid_num_features, dtype=torch.float64, device=self._device
+        )
+
+        self._set_of_codebook_indices = set()
+        self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device)
+
+    def update(
+        self,
+        real_images: torch.Tensor,
+        fake_images: torch.Tensor,
+        codebook_indices: Optional[torch.Tensor] = None
+    ):
+        """Updates the metrics with the given images.
+
+        Args:
+            real_images: A torch.Tensor, the real images.
+            fake_images: A torch.Tensor, the fake images.
+            codebook_indices: A torch.Tensor, the indices of the codebooks for each image.
+
+        Raises:
+            ValueError: If the fake images are not in RGB (3 channel).
+            ValueError: If the fake and real images have different shape.
+        """
+
+        batch_size = real_images.shape[0]
+        dim = tuple(range(1, real_images.ndim))
+        self._num_examples += batch_size
+        self._num_updates += 1
+
+        if self._enable_inception_score or self._enable_rfid:
+            # Quantize to uint8 as a real image.
+            fake_inception_images = (fake_images * 255).to(torch.uint8)
+            features_fake = self._inception_model(fake_inception_images)
+            inception_logits_fake = features_fake["logits_unbiased"]
+            inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1)
+
+        if self._enable_inception_score:
+            probabiliies_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64)
+
+            log_prob = torch.log(inception_probabilities_fake + self._is_eps)
+            if log_prob.dtype != inception_probabilities_fake.dtype:
+                log_prob = log_prob.to(inception_probabilities_fake)
+            kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64)
+
+            self._is_prob_total += probabiliies_sum
+            self._is_total_kl_d += kl_sum
+
+        if self._enable_rfid:
+            real_inception_images = (real_images * 255).to(torch.uint8)
+            features_real = self._inception_model(real_inception_images)
+            if (features_real['2048'].shape[0] != features_fake['2048'].shape[0] or
+                features_real['2048'].shape[1] != features_fake['2048'].shape[1]):
+                raise ValueError(f"Number of features should be equal for real and fake.")
+
+            for f_real, f_fake in zip(features_real['2048'], features_fake['2048']):
+                self._rfid_real_total += f_real
+                self._rfid_fake_total += f_fake
+
+                self._rfid_real_sigma += torch.outer(f_real, f_real)
+                self._rfid_fake_sigma += torch.outer(f_fake, f_fake)
+
+        if self._enable_codebook_usage_measure:
+            self._set_of_codebook_indices |= set(torch.unique(codebook_indices, sorted=False).tolist())
+
+        if self._enable_codebook_entropy_measure:
+            entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True)
+            self._codebook_frequencies.index_add_(0, entries.int(), counts.double())
+
+
+    def result(self) -> Mapping[Text, torch.Tensor]:
+        """Returns the evaluation result."""
+        eval_score = {}
+
+        if self._num_examples < 1:
+            raise ValueError("No examples to evaluate.")
+
+        if self._enable_inception_score:
+            mean_probs = self._is_prob_total / self._num_examples
+            log_mean_probs = torch.log(mean_probs + self._is_eps)
+            if log_mean_probs.dtype != self._is_prob_total.dtype:
+                log_mean_probs = log_mean_probs.to(self._is_prob_total)
+            excess_entropy = self._is_prob_total * log_mean_probs
+            avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples
+
+            inception_score = torch.exp(avg_kl_d).item()
+            eval_score["InceptionScore"] = inception_score
+
+        if self._enable_rfid:
+            mu_real = self._rfid_real_total / self._num_examples
+            mu_fake = self._rfid_fake_total / self._num_examples
+            sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples)
+            sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples)
+
+            mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
+            sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()
+
+            diff = mu_real - mu_fake
+
+            # Product might be almost singular.
+            covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)
+            # Numerical error might give slight imaginary component.
+            if np.iscomplexobj(covmean):
+                if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+                    m = np.max(np.abs(covmean.imag))
+                    raise ValueError("Imaginary component {}".format(m))
+                covmean = covmean.real
+
+            tr_covmean = np.trace(covmean)
+
+            if not np.isfinite(covmean).all():
+                tr_covmean = np.sum(np.sqrt((
+                    (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
+                    / (self._rfid_eps * self._rfid_eps)
+                ))
+
+            rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
+                         - 2 * tr_covmean
+            )
+            if torch.isnan(torch.tensor(rfid)) or torch.isinf(torch.tensor(rfid)):
+                warnings.warn("The product of covariance of train and test features is out of bounds.")
+
+            eval_score["rFID"] = rfid
+
+        if self._enable_codebook_usage_measure:
+            usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
+            eval_score["CodebookUsage"] = usage
+
+        if self._enable_codebook_entropy_measure:
+            probs = self._codebook_frequencies / self._codebook_frequencies.sum()
+            entropy = (-torch.log2(probs + 1e-8) * probs).sum()
+            eval_score["CodebookEntropy"] = entropy
+
+        return eval_score
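VQGANEvaluator accumulates per-batch statistics through update() and reports the aggregated metrics from result(). A hedged usage sketch follows; the tokenizer forward signature and the assumption that images live in [0, 1] are inferred from the rest of this commit, and `eval_dataloader`/`tokenizer` are placeholders supplied elsewhere:

# Sketch: reconstruction evaluation loop with VQGANEvaluator above.
import torch
from evaluator import VQGANEvaluator

device = "cuda"
evaluator = VQGANEvaluator(device, enable_rfid=True, enable_inception_score=True)

with torch.no_grad():
    for batch in eval_dataloader:                    # placeholder dataloader yielding {"image": ...}
        images = batch["image"].to(device)
        reconstructed, _ = tokenizer(images)         # assumed tokenizer forward returning (recon, extra)
        evaluator.update(
            real_images=images.clamp(0.0, 1.0),
            fake_images=reconstructed.clamp(0.0, 1.0),
        )

print(evaluator.result())  # e.g. {"InceptionScore": ..., "rFID": ...}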
evaluator/inception.py
ADDED
@@ -0,0 +1,231 @@
+"""This file is for Inception model borrowed from torch metrics / fidelity.
+
+This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
+All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
+
+Reference:
+    https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
+"""
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+
+from torch_fidelity.feature_extractor_base import FeatureExtractorBase
+from torch_fidelity.helpers import vassert
+from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
+from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+
+# Note: Compared shasum and models should be the same.
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
+
+class FeatureExtractorInceptionV3(FeatureExtractorBase):
+    INPUT_IMAGE_SIZE = 299
+
+    def __init__(
+        self,
+        name,
+        features_list,
+        **kwargs,
+    ):
+        """
+        InceptionV3 feature extractor for 2D RGB 24bit images.
+
+        Args:
+
+            name (str): Unique name of the feature extractor, must be the same as used in
+                :func:`register_feature_extractor`.
+
+            features_list (list): A list of the requested feature names, which will be produced for each input. This
+                feature extractor provides the following features:
+
+                - '64'
+                - '192'
+                - '768'
+                - '2048'
+                - 'logits_unbiased'
+                - 'logits'
+
+        """
+        super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
+        self.feature_extractor_internal_dtype = torch.float64
+
+        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
+        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
+        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+        self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
+
+        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
+        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+        self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
+
+        self.Mixed_5b = InceptionA(192, pool_features=32)
+        self.Mixed_5c = InceptionA(256, pool_features=64)
+        self.Mixed_5d = InceptionA(288, pool_features=64)
+        self.Mixed_6a = InceptionB(288)
+        self.Mixed_6b = InceptionC(768, channels_7x7=128)
+        self.Mixed_6c = InceptionC(768, channels_7x7=160)
+        self.Mixed_6d = InceptionC(768, channels_7x7=160)
+        self.Mixed_6e = InceptionC(768, channels_7x7=192)
+
+        self.Mixed_7a = InceptionD(768)
+        self.Mixed_7b = InceptionE_1(1280)
+        self.Mixed_7c = InceptionE_2(2048)
+        self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
+
+        self.fc = torch.nn.Linear(2048, 1008)
+
+        state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+        #state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
+        self.load_state_dict(state_dict)
+
+        self.to(self.feature_extractor_internal_dtype)
+        self.requires_grad_(False)
+        self.eval()
+
+    def forward(self, x):
+        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
+        vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')
+        features = {}
+        remaining_features = self.features_list.copy()
+
+        x = x.to(self.feature_extractor_internal_dtype)
+        # N x 3 x ? x ?
+
+        x = interpolate_bilinear_2d_like_tensorflow1x(
+            x,
+            size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
+            align_corners=False,
+        )
+        # N x 3 x 299 x 299
+
+        # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # really happening in graph
+        x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too
+        # N x 3 x 299 x 299
+
+        x = self.Conv2d_1a_3x3(x)
+        # N x 32 x 149 x 149
+        x = self.Conv2d_2a_3x3(x)
+        # N x 32 x 147 x 147
+        x = self.Conv2d_2b_3x3(x)
+        # N x 64 x 147 x 147
+        x = self.MaxPool_1(x)
+        # N x 64 x 73 x 73
+
+        if '64' in remaining_features:
+            features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            remaining_features.remove('64')
+            if len(remaining_features) == 0:
+                return features
+
+        x = self.Conv2d_3b_1x1(x)
+        # N x 80 x 73 x 73
+        x = self.Conv2d_4a_3x3(x)
+        # N x 192 x 71 x 71
+        x = self.MaxPool_2(x)
+        # N x 192 x 35 x 35
+
+        if '192' in remaining_features:
+            features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
+            remaining_features.remove('192')
+            if len(remaining_features) == 0:
+                return features
+
+        x = self.Mixed_5b(x)
+        # N x 256 x 35 x 35
+        x = self.Mixed_5c(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_5d(x)
+        # N x 288 x 35 x 35
+        x = self.Mixed_6a(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6b(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6c(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6d(x)
+        # N x 768 x 17 x 17
+        x = self.Mixed_6e(x)
+        # N x 768 x 17 x 17
+
+        if '768' in remaining_features:
+            features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
+            remaining_features.remove('768')
+            if len(remaining_features) == 0:
+                return features
+
+        x = self.Mixed_7a(x)
+        # N x 1280 x 8 x 8
+        x = self.Mixed_7b(x)
+        # N x 2048 x 8 x 8
+        x = self.Mixed_7c(x)
+        # N x 2048 x 8 x 8
+        x = self.AvgPool(x)
+        # N x 2048 x 1 x 1
+
+        x = torch.flatten(x, 1)
+        # N x 2048
+
+        if '2048' in remaining_features:
+            features['2048'] = x
+            remaining_features.remove('2048')
+            if len(remaining_features) == 0:
+                return features
+
+        if 'logits_unbiased' in remaining_features:
+            x = x.mm(self.fc.weight.T)
+            # N x 1008 (num_classes)
+            features['logits_unbiased'] = x
|
196 |
+
remaining_features.remove('logits_unbiased')
|
197 |
+
if len(remaining_features) == 0:
|
198 |
+
return features
|
199 |
+
|
200 |
+
x = x + self.fc.bias.unsqueeze(0)
|
201 |
+
else:
|
202 |
+
x = self.fc(x)
|
203 |
+
# N x 1008 (num_classes)
|
204 |
+
|
205 |
+
features['logits'] = x
|
206 |
+
return features
|
207 |
+
|
208 |
+
@staticmethod
|
209 |
+
def get_provided_features_list():
|
210 |
+
return '64', '192', '768', '2048', 'logits_unbiased', 'logits'
|
211 |
+
|
212 |
+
@staticmethod
|
213 |
+
def get_default_feature_layer_for_metric(metric):
|
214 |
+
return {
|
215 |
+
'isc': 'logits_unbiased',
|
216 |
+
'fid': '2048',
|
217 |
+
'kid': '2048',
|
218 |
+
'prc': '2048',
|
219 |
+
}[metric]
|
220 |
+
|
221 |
+
@staticmethod
|
222 |
+
def can_be_compiled():
|
223 |
+
return True
|
224 |
+
|
225 |
+
@staticmethod
|
226 |
+
def get_dummy_input_for_compile():
|
227 |
+
return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
|
228 |
+
|
229 |
+
def get_inception_model():
|
230 |
+
model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
|
231 |
+
return model
|
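For reference, a minimal usage sketch of the feature extractor defined above, as it might be driven when collecting FID/IS statistics: inputs must be uint8 tensors in Bx3xHxW layout (the vassert checks enforce this), and the requested feature names select which activations come back. The `fake_images` batch below is a hypothetical stand-in, not data from this repository.

# Usage sketch (illustrative only, not part of the uploaded files).
import torch
from evaluator.inception import get_inception_model

inception = get_inception_model()  # frozen float64 InceptionV3, already in eval mode

# Hypothetical batch of 8 generated images: uint8, Bx3xHxW, values in [0, 255].
fake_images = (torch.rand(8, 3, 256, 256) * 255).to(torch.uint8)

with torch.no_grad():
    feats = inception(fake_images)

pool_2048 = feats["2048"]          # (8, 2048) pooled features, used for FID/KID
logits = feats["logits_unbiased"]  # (8, 1008) unbiased logits, used for Inception Score
print(pool_2048.shape, logits.shape)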
imagenet_classes.py
ADDED
@@ -0,0 +1,1001 @@
1 |
+
imagenet_idx2classname = {
|
2 |
+
0: 'tench, Tinca tinca',
|
3 |
+
1: 'goldfish, Carassius auratus',
|
4 |
+
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
|
5 |
+
3: 'tiger shark, Galeocerdo cuvieri',
|
6 |
+
4: 'hammerhead, hammerhead shark',
|
7 |
+
5: 'electric ray, crampfish, numbfish, torpedo',
|
8 |
+
6: 'stingray',
|
9 |
+
7: 'cock',
|
10 |
+
8: 'hen',
|
11 |
+
9: 'ostrich, Struthio camelus',
|
12 |
+
10: 'brambling, Fringilla montifringilla',
|
13 |
+
11: 'goldfinch, Carduelis carduelis',
|
14 |
+
12: 'house finch, linnet, Carpodacus mexicanus',
|
15 |
+
13: 'junco, snowbird',
|
16 |
+
14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
|
17 |
+
15: 'robin, American robin, Turdus migratorius',
|
18 |
+
16: 'bulbul',
|
19 |
+
17: 'jay',
|
20 |
+
18: 'magpie',
|
21 |
+
19: 'chickadee',
|
22 |
+
20: 'water ouzel, dipper',
|
23 |
+
21: 'kite',
|
24 |
+
22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
|
25 |
+
23: 'vulture',
|
26 |
+
24: 'great grey owl, great gray owl, Strix nebulosa',
|
27 |
+
25: 'European fire salamander, Salamandra salamandra',
|
28 |
+
26: 'common newt, Triturus vulgaris',
|
29 |
+
27: 'eft',
|
30 |
+
28: 'spotted salamander, Ambystoma maculatum',
|
31 |
+
29: 'axolotl, mud puppy, Ambystoma mexicanum',
|
32 |
+
30: 'bullfrog, Rana catesbeiana',
|
33 |
+
31: 'tree frog, tree-frog',
|
34 |
+
32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
|
35 |
+
33: 'loggerhead, loggerhead turtle, Caretta caretta',
|
36 |
+
34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
|
37 |
+
35: 'mud turtle',
|
38 |
+
36: 'terrapin',
|
39 |
+
37: 'box turtle, box tortoise',
|
40 |
+
38: 'banded gecko',
|
41 |
+
39: 'common iguana, iguana, Iguana iguana',
|
42 |
+
40: 'American chameleon, anole, Anolis carolinensis',
|
43 |
+
41: 'whiptail, whiptail lizard',
|
44 |
+
42: 'agama',
|
45 |
+
43: 'frilled lizard, Chlamydosaurus kingi',
|
46 |
+
44: 'alligator lizard',
|
47 |
+
45: 'Gila monster, Heloderma suspectum',
|
48 |
+
46: 'green lizard, Lacerta viridis',
|
49 |
+
47: 'African chameleon, Chamaeleo chamaeleon',
|
50 |
+
48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
|
51 |
+
49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
|
52 |
+
50: 'American alligator, Alligator mississipiensis',
|
53 |
+
51: 'triceratops',
|
54 |
+
52: 'thunder snake, worm snake, Carphophis amoenus',
|
55 |
+
53: 'ringneck snake, ring-necked snake, ring snake',
|
56 |
+
54: 'hognose snake, puff adder, sand viper',
|
57 |
+
55: 'green snake, grass snake',
|
58 |
+
56: 'king snake, kingsnake',
|
59 |
+
57: 'garter snake, grass snake',
|
60 |
+
58: 'water snake',
|
61 |
+
59: 'vine snake',
|
62 |
+
60: 'night snake, Hypsiglena torquata',
|
63 |
+
61: 'boa constrictor, Constrictor constrictor',
|
64 |
+
62: 'rock python, rock snake, Python sebae',
|
65 |
+
63: 'Indian cobra, Naja naja',
|
66 |
+
64: 'green mamba',
|
67 |
+
65: 'sea snake',
|
68 |
+
66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
|
69 |
+
67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
|
70 |
+
68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
|
71 |
+
69: 'trilobite',
|
72 |
+
70: 'harvestman, daddy longlegs, Phalangium opilio',
|
73 |
+
71: 'scorpion',
|
74 |
+
72: 'black and gold garden spider, Argiope aurantia',
|
75 |
+
73: 'barn spider, Araneus cavaticus',
|
76 |
+
74: 'garden spider, Aranea diademata',
|
77 |
+
75: 'black widow, Latrodectus mactans',
|
78 |
+
76: 'tarantula',
|
79 |
+
77: 'wolf spider, hunting spider',
|
80 |
+
78: 'tick',
|
81 |
+
79: 'centipede',
|
82 |
+
80: 'black grouse',
|
83 |
+
81: 'ptarmigan',
|
84 |
+
82: 'ruffed grouse, partridge, Bonasa umbellus',
|
85 |
+
83: 'prairie chicken, prairie grouse, prairie fowl',
|
86 |
+
84: 'peacock',
|
87 |
+
85: 'quail',
|
88 |
+
86: 'partridge',
|
89 |
+
87: 'African grey, African gray, Psittacus erithacus',
|
90 |
+
88: 'macaw',
|
91 |
+
89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
|
92 |
+
90: 'lorikeet',
|
93 |
+
91: 'coucal',
|
94 |
+
92: 'bee eater',
|
95 |
+
93: 'hornbill',
|
96 |
+
94: 'hummingbird',
|
97 |
+
95: 'jacamar',
|
98 |
+
96: 'toucan',
|
99 |
+
97: 'drake',
|
100 |
+
98: 'red-breasted merganser, Mergus serrator',
|
101 |
+
99: 'goose',
|
102 |
+
100: 'black swan, Cygnus atratus',
|
103 |
+
101: 'tusker',
|
104 |
+
102: 'echidna, spiny anteater, anteater',
|
105 |
+
103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
|
106 |
+
104: 'wallaby, brush kangaroo',
|
107 |
+
105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
|
108 |
+
106: 'wombat',
|
109 |
+
107: 'jellyfish',
|
110 |
+
108: 'sea anemone, anemone',
|
111 |
+
109: 'brain coral',
|
112 |
+
110: 'flatworm, platyhelminth',
|
113 |
+
111: 'nematode, nematode worm, roundworm',
|
114 |
+
112: 'conch',
|
115 |
+
113: 'snail',
|
116 |
+
114: 'slug',
|
117 |
+
115: 'sea slug, nudibranch',
|
118 |
+
116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
|
119 |
+
117: 'chambered nautilus, pearly nautilus, nautilus',
|
120 |
+
118: 'Dungeness crab, Cancer magister',
|
121 |
+
119: 'rock crab, Cancer irroratus',
|
122 |
+
120: 'fiddler crab',
|
123 |
+
121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
|
124 |
+
122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
|
125 |
+
123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
|
126 |
+
124: 'crayfish, crawfish, crawdad, crawdaddy',
|
127 |
+
125: 'hermit crab',
|
128 |
+
126: 'isopod',
|
129 |
+
127: 'white stork, Ciconia ciconia',
|
130 |
+
128: 'black stork, Ciconia nigra',
|
131 |
+
129: 'spoonbill',
|
132 |
+
130: 'flamingo',
|
133 |
+
131: 'little blue heron, Egretta caerulea',
|
134 |
+
132: 'American egret, great white heron, Egretta albus',
|
135 |
+
133: 'bittern',
|
136 |
+
134: 'crane',
|
137 |
+
135: 'limpkin, Aramus pictus',
|
138 |
+
136: 'European gallinule, Porphyrio porphyrio',
|
139 |
+
137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
|
140 |
+
138: 'bustard',
|
141 |
+
139: 'ruddy turnstone, Arenaria interpres',
|
142 |
+
140: 'red-backed sandpiper, dunlin, Erolia alpina',
|
143 |
+
141: 'redshank, Tringa totanus',
|
144 |
+
142: 'dowitcher',
|
145 |
+
143: 'oystercatcher, oyster catcher',
|
146 |
+
144: 'pelican',
|
147 |
+
145: 'king penguin, Aptenodytes patagonica',
|
148 |
+
146: 'albatross, mollymawk',
|
149 |
+
147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
|
150 |
+
148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
|
151 |
+
149: 'dugong, Dugong dugon',
|
152 |
+
150: 'sea lion',
|
153 |
+
151: 'Chihuahua',
|
154 |
+
152: 'Japanese spaniel',
|
155 |
+
153: 'Maltese dog, Maltese terrier, Maltese',
|
156 |
+
154: 'Pekinese, Pekingese, Peke',
|
157 |
+
155: 'Shih-Tzu',
|
158 |
+
156: 'Blenheim spaniel',
|
159 |
+
157: 'papillon',
|
160 |
+
158: 'toy terrier',
|
161 |
+
159: 'Rhodesian ridgeback',
|
162 |
+
160: 'Afghan hound, Afghan',
|
163 |
+
161: 'basset, basset hound',
|
164 |
+
162: 'beagle',
|
165 |
+
163: 'bloodhound, sleuthhound',
|
166 |
+
164: 'bluetick',
|
167 |
+
165: 'black-and-tan coonhound',
|
168 |
+
166: 'Walker hound, Walker foxhound',
|
169 |
+
167: 'English foxhound',
|
170 |
+
168: 'redbone',
|
171 |
+
169: 'borzoi, Russian wolfhound',
|
172 |
+
170: 'Irish wolfhound',
|
173 |
+
171: 'Italian greyhound',
|
174 |
+
172: 'whippet',
|
175 |
+
173: 'Ibizan hound, Ibizan Podenco',
|
176 |
+
174: 'Norwegian elkhound, elkhound',
|
177 |
+
175: 'otterhound, otter hound',
|
178 |
+
176: 'Saluki, gazelle hound',
|
179 |
+
177: 'Scottish deerhound, deerhound',
|
180 |
+
178: 'Weimaraner',
|
181 |
+
179: 'Staffordshire bullterrier, Staffordshire bull terrier',
|
182 |
+
180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
|
183 |
+
181: 'Bedlington terrier',
|
184 |
+
182: 'Border terrier',
|
185 |
+
183: 'Kerry blue terrier',
|
186 |
+
184: 'Irish terrier',
|
187 |
+
185: 'Norfolk terrier',
|
188 |
+
186: 'Norwich terrier',
|
189 |
+
187: 'Yorkshire terrier',
|
190 |
+
188: 'wire-haired fox terrier',
|
191 |
+
189: 'Lakeland terrier',
|
192 |
+
190: 'Sealyham terrier, Sealyham',
|
193 |
+
191: 'Airedale, Airedale terrier',
|
194 |
+
192: 'cairn, cairn terrier',
|
195 |
+
193: 'Australian terrier',
|
196 |
+
194: 'Dandie Dinmont, Dandie Dinmont terrier',
|
197 |
+
195: 'Boston bull, Boston terrier',
|
198 |
+
196: 'miniature schnauzer',
|
199 |
+
197: 'giant schnauzer',
|
200 |
+
198: 'standard schnauzer',
|
201 |
+
199: 'Scotch terrier, Scottish terrier, Scottie',
|
202 |
+
200: 'Tibetan terrier, chrysanthemum dog',
|
203 |
+
201: 'silky terrier, Sydney silky',
|
204 |
+
202: 'soft-coated wheaten terrier',
|
205 |
+
203: 'West Highland white terrier',
|
206 |
+
204: 'Lhasa, Lhasa apso',
|
207 |
+
205: 'flat-coated retriever',
|
208 |
+
206: 'curly-coated retriever',
|
209 |
+
207: 'golden retriever',
|
210 |
+
208: 'Labrador retriever',
|
211 |
+
209: 'Chesapeake Bay retriever',
|
212 |
+
210: 'German short-haired pointer',
|
213 |
+
211: 'vizsla, Hungarian pointer',
|
214 |
+
212: 'English setter',
|
215 |
+
213: 'Irish setter, red setter',
|
216 |
+
214: 'Gordon setter',
|
217 |
+
215: 'Brittany spaniel',
|
218 |
+
216: 'clumber, clumber spaniel',
|
219 |
+
217: 'English springer, English springer spaniel',
|
220 |
+
218: 'Welsh springer spaniel',
|
221 |
+
219: 'cocker spaniel, English cocker spaniel, cocker',
|
222 |
+
220: 'Sussex spaniel',
|
223 |
+
221: 'Irish water spaniel',
|
224 |
+
222: 'kuvasz',
|
225 |
+
223: 'schipperke',
|
226 |
+
224: 'groenendael',
|
227 |
+
225: 'malinois',
|
228 |
+
226: 'briard',
|
229 |
+
227: 'kelpie',
|
230 |
+
228: 'komondor',
|
231 |
+
229: 'Old English sheepdog, bobtail',
|
232 |
+
230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
|
233 |
+
231: 'collie',
|
234 |
+
232: 'Border collie',
|
235 |
+
233: 'Bouvier des Flandres, Bouviers des Flandres',
|
236 |
+
234: 'Rottweiler',
|
237 |
+
235: 'German shepherd, German shepherd dog, German police dog, alsatian',
|
238 |
+
236: 'Doberman, Doberman pinscher',
|
239 |
+
237: 'miniature pinscher',
|
240 |
+
238: 'Greater Swiss Mountain dog',
|
241 |
+
239: 'Bernese mountain dog',
|
242 |
+
240: 'Appenzeller',
|
243 |
+
241: 'EntleBucher',
|
244 |
+
242: 'boxer',
|
245 |
+
243: 'bull mastiff',
|
246 |
+
244: 'Tibetan mastiff',
|
247 |
+
245: 'French bulldog',
|
248 |
+
246: 'Great Dane',
|
249 |
+
247: 'Saint Bernard, St Bernard',
|
250 |
+
248: 'Eskimo dog, husky',
|
251 |
+
249: 'malamute, malemute, Alaskan malamute',
|
252 |
+
250: 'Siberian husky',
|
253 |
+
251: 'dalmatian, coach dog, carriage dog',
|
254 |
+
252: 'affenpinscher, monkey pinscher, monkey dog',
|
255 |
+
253: 'basenji',
|
256 |
+
254: 'pug, pug-dog',
|
257 |
+
255: 'Leonberg',
|
258 |
+
256: 'Newfoundland, Newfoundland dog',
|
259 |
+
257: 'Great Pyrenees',
|
260 |
+
258: 'Samoyed, Samoyede',
|
261 |
+
259: 'Pomeranian',
|
262 |
+
260: 'chow, chow chow',
|
263 |
+
261: 'keeshond',
|
264 |
+
262: 'Brabancon griffon',
|
265 |
+
263: 'Pembroke, Pembroke Welsh corgi',
|
266 |
+
264: 'Cardigan, Cardigan Welsh corgi',
|
267 |
+
265: 'toy poodle',
|
268 |
+
266: 'miniature poodle',
|
269 |
+
267: 'standard poodle',
|
270 |
+
268: 'Mexican hairless',
|
271 |
+
269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
|
272 |
+
270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
|
273 |
+
271: 'red wolf, maned wolf, Canis rufus, Canis niger',
|
274 |
+
272: 'coyote, prairie wolf, brush wolf, Canis latrans',
|
275 |
+
273: 'dingo, warrigal, warragal, Canis dingo',
|
276 |
+
274: 'dhole, Cuon alpinus',
|
277 |
+
275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
|
278 |
+
276: 'hyena, hyaena',
|
279 |
+
277: 'red fox, Vulpes vulpes',
|
280 |
+
278: 'kit fox, Vulpes macrotis',
|
281 |
+
279: 'Arctic fox, white fox, Alopex lagopus',
|
282 |
+
280: 'grey fox, gray fox, Urocyon cinereoargenteus',
|
283 |
+
281: 'tabby, tabby cat',
|
284 |
+
282: 'tiger cat',
|
285 |
+
283: 'Persian cat',
|
286 |
+
284: 'Siamese cat, Siamese',
|
287 |
+
285: 'Egyptian cat',
|
288 |
+
286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
|
289 |
+
287: 'lynx, catamount',
|
290 |
+
288: 'leopard, Panthera pardus',
|
291 |
+
289: 'snow leopard, ounce, Panthera uncia',
|
292 |
+
290: 'jaguar, panther, Panthera onca, Felis onca',
|
293 |
+
291: 'lion, king of beasts, Panthera leo',
|
294 |
+
292: 'tiger, Panthera tigris',
|
295 |
+
293: 'cheetah, chetah, Acinonyx jubatus',
|
296 |
+
294: 'brown bear, bruin, Ursus arctos',
|
297 |
+
295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
|
298 |
+
296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
|
299 |
+
297: 'sloth bear, Melursus ursinus, Ursus ursinus',
|
300 |
+
298: 'mongoose',
|
301 |
+
299: 'meerkat, mierkat',
|
302 |
+
300: 'tiger beetle',
|
303 |
+
301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
|
304 |
+
302: 'ground beetle, carabid beetle',
|
305 |
+
303: 'long-horned beetle, longicorn, longicorn beetle',
|
306 |
+
304: 'leaf beetle, chrysomelid',
|
307 |
+
305: 'dung beetle',
|
308 |
+
306: 'rhinoceros beetle',
|
309 |
+
307: 'weevil',
|
310 |
+
308: 'fly',
|
311 |
+
309: 'bee',
|
312 |
+
310: 'ant, emmet, pismire',
|
313 |
+
311: 'grasshopper, hopper',
|
314 |
+
312: 'cricket',
|
315 |
+
313: 'walking stick, walkingstick, stick insect',
|
316 |
+
314: 'cockroach, roach',
|
317 |
+
315: 'mantis, mantid',
|
318 |
+
316: 'cicada, cicala',
|
319 |
+
317: 'leafhopper',
|
320 |
+
318: 'lacewing, lacewing fly',
|
321 |
+
319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
322 |
+
320: 'damselfly',
|
323 |
+
321: 'admiral',
|
324 |
+
322: 'ringlet, ringlet butterfly',
|
325 |
+
323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
|
326 |
+
324: 'cabbage butterfly',
|
327 |
+
325: 'sulphur butterfly, sulfur butterfly',
|
328 |
+
326: 'lycaenid, lycaenid butterfly',
|
329 |
+
327: 'starfish, sea star',
|
330 |
+
328: 'sea urchin',
|
331 |
+
329: 'sea cucumber, holothurian',
|
332 |
+
330: 'wood rabbit, cottontail, cottontail rabbit',
|
333 |
+
331: 'hare',
|
334 |
+
332: 'Angora, Angora rabbit',
|
335 |
+
333: 'hamster',
|
336 |
+
334: 'porcupine, hedgehog',
|
337 |
+
335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
|
338 |
+
336: 'marmot',
|
339 |
+
337: 'beaver',
|
340 |
+
338: 'guinea pig, Cavia cobaya',
|
341 |
+
339: 'sorrel',
|
342 |
+
340: 'zebra',
|
343 |
+
341: 'hog, pig, grunter, squealer, Sus scrofa',
|
344 |
+
342: 'wild boar, boar, Sus scrofa',
|
345 |
+
343: 'warthog',
|
346 |
+
344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
|
347 |
+
345: 'ox',
|
348 |
+
346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
|
349 |
+
347: 'bison',
|
350 |
+
348: 'ram, tup',
|
351 |
+
349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
|
352 |
+
350: 'ibex, Capra ibex',
|
353 |
+
351: 'hartebeest',
|
354 |
+
352: 'impala, Aepyceros melampus',
|
355 |
+
353: 'gazelle',
|
356 |
+
354: 'Arabian camel, dromedary, Camelus dromedarius',
|
357 |
+
355: 'llama',
|
358 |
+
356: 'weasel',
|
359 |
+
357: 'mink',
|
360 |
+
358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
|
361 |
+
359: 'black-footed ferret, ferret, Mustela nigripes',
|
362 |
+
360: 'otter',
|
363 |
+
361: 'skunk, polecat, wood pussy',
|
364 |
+
362: 'badger',
|
365 |
+
363: 'armadillo',
|
366 |
+
364: 'three-toed sloth, ai, Bradypus tridactylus',
|
367 |
+
365: 'orangutan, orang, orangutang, Pongo pygmaeus',
|
368 |
+
366: 'gorilla, Gorilla gorilla',
|
369 |
+
367: 'chimpanzee, chimp, Pan troglodytes',
|
370 |
+
368: 'gibbon, Hylobates lar',
|
371 |
+
369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
|
372 |
+
370: 'guenon, guenon monkey',
|
373 |
+
371: 'patas, hussar monkey, Erythrocebus patas',
|
374 |
+
372: 'baboon',
|
375 |
+
373: 'macaque',
|
376 |
+
374: 'langur',
|
377 |
+
375: 'colobus, colobus monkey',
|
378 |
+
376: 'proboscis monkey, Nasalis larvatus',
|
379 |
+
377: 'marmoset',
|
380 |
+
378: 'capuchin, ringtail, Cebus capucinus',
|
381 |
+
379: 'howler monkey, howler',
|
382 |
+
380: 'titi, titi monkey',
|
383 |
+
381: 'spider monkey, Ateles geoffroyi',
|
384 |
+
382: 'squirrel monkey, Saimiri sciureus',
|
385 |
+
383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
|
386 |
+
384: 'indri, indris, Indri indri, Indri brevicaudatus',
|
387 |
+
385: 'Indian elephant, Elephas maximus',
|
388 |
+
386: 'African elephant, Loxodonta africana',
|
389 |
+
387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
|
390 |
+
388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
|
391 |
+
389: 'barracouta, snoek',
|
392 |
+
390: 'eel',
|
393 |
+
391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
|
394 |
+
392: 'rock beauty, Holocanthus tricolor',
|
395 |
+
393: 'anemone fish',
|
396 |
+
394: 'sturgeon',
|
397 |
+
395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
|
398 |
+
396: 'lionfish',
|
399 |
+
397: 'puffer, pufferfish, blowfish, globefish',
|
400 |
+
398: 'abacus',
|
401 |
+
399: 'abaya',
|
402 |
+
400: "academic gown, academic robe, judge's robe",
|
403 |
+
401: 'accordion, piano accordion, squeeze box',
|
404 |
+
402: 'acoustic guitar',
|
405 |
+
403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
|
406 |
+
404: 'airliner',
|
407 |
+
405: 'airship, dirigible',
|
408 |
+
406: 'altar',
|
409 |
+
407: 'ambulance',
|
410 |
+
408: 'amphibian, amphibious vehicle',
|
411 |
+
409: 'analog clock',
|
412 |
+
410: 'apiary, bee house',
|
413 |
+
411: 'apron',
|
414 |
+
412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
|
415 |
+
413: 'assault rifle, assault gun',
|
416 |
+
414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
|
417 |
+
415: 'bakery, bakeshop, bakehouse',
|
418 |
+
416: 'balance beam, beam',
|
419 |
+
417: 'balloon',
|
420 |
+
418: 'ballpoint, ballpoint pen, ballpen, Biro',
|
421 |
+
419: 'Band Aid',
|
422 |
+
420: 'banjo',
|
423 |
+
421: 'bannister, banister, balustrade, balusters, handrail',
|
424 |
+
422: 'barbell',
|
425 |
+
423: 'barber chair',
|
426 |
+
424: 'barbershop',
|
427 |
+
425: 'barn',
|
428 |
+
426: 'barometer',
|
429 |
+
427: 'barrel, cask',
|
430 |
+
428: 'barrow, garden cart, lawn cart, wheelbarrow',
|
431 |
+
429: 'baseball',
|
432 |
+
430: 'basketball',
|
433 |
+
431: 'bassinet',
|
434 |
+
432: 'bassoon',
|
435 |
+
433: 'bathing cap, swimming cap',
|
436 |
+
434: 'bath towel',
|
437 |
+
435: 'bathtub, bathing tub, bath, tub',
|
438 |
+
436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
|
439 |
+
437: 'beacon, lighthouse, beacon light, pharos',
|
440 |
+
438: 'beaker',
|
441 |
+
439: 'bearskin, busby, shako',
|
442 |
+
440: 'beer bottle',
|
443 |
+
441: 'beer glass',
|
444 |
+
442: 'bell cote, bell cot',
|
445 |
+
443: 'bib',
|
446 |
+
444: 'bicycle-built-for-two, tandem bicycle, tandem',
|
447 |
+
445: 'bikini, two-piece',
|
448 |
+
446: 'binder, ring-binder',
|
449 |
+
447: 'binoculars, field glasses, opera glasses',
|
450 |
+
448: 'birdhouse',
|
451 |
+
449: 'boathouse',
|
452 |
+
450: 'bobsled, bobsleigh, bob',
|
453 |
+
451: 'bolo tie, bolo, bola tie, bola',
|
454 |
+
452: 'bonnet, poke bonnet',
|
455 |
+
453: 'bookcase',
|
456 |
+
454: 'bookshop, bookstore, bookstall',
|
457 |
+
455: 'bottlecap',
|
458 |
+
456: 'bow',
|
459 |
+
457: 'bow tie, bow-tie, bowtie',
|
460 |
+
458: 'brass, memorial tablet, plaque',
|
461 |
+
459: 'brassiere, bra, bandeau',
|
462 |
+
460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
|
463 |
+
461: 'breastplate, aegis, egis',
|
464 |
+
462: 'broom',
|
465 |
+
463: 'bucket, pail',
|
466 |
+
464: 'buckle',
|
467 |
+
465: 'bulletproof vest',
|
468 |
+
466: 'bullet train, bullet',
|
469 |
+
467: 'butcher shop, meat market',
|
470 |
+
468: 'cab, hack, taxi, taxicab',
|
471 |
+
469: 'caldron, cauldron',
|
472 |
+
470: 'candle, taper, wax light',
|
473 |
+
471: 'cannon',
|
474 |
+
472: 'canoe',
|
475 |
+
473: 'can opener, tin opener',
|
476 |
+
474: 'cardigan',
|
477 |
+
475: 'car mirror',
|
478 |
+
476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
|
479 |
+
477: "carpenter's kit, tool kit",
|
480 |
+
478: 'carton',
|
481 |
+
479: 'car wheel',
|
482 |
+
480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
|
483 |
+
481: 'cassette',
|
484 |
+
482: 'cassette player',
|
485 |
+
483: 'castle',
|
486 |
+
484: 'catamaran',
|
487 |
+
485: 'CD player',
|
488 |
+
486: 'cello, violoncello',
|
489 |
+
487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
|
490 |
+
488: 'chain',
|
491 |
+
489: 'chainlink fence',
|
492 |
+
490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
|
493 |
+
491: 'chain saw, chainsaw',
|
494 |
+
492: 'chest',
|
495 |
+
493: 'chiffonier, commode',
|
496 |
+
494: 'chime, bell, gong',
|
497 |
+
495: 'china cabinet, china closet',
|
498 |
+
496: 'Christmas stocking',
|
499 |
+
497: 'church, church building',
|
500 |
+
498: 'cinema, movie theater, movie theatre, movie house, picture palace',
|
501 |
+
499: 'cleaver, meat cleaver, chopper',
|
502 |
+
500: 'cliff dwelling',
|
503 |
+
501: 'cloak',
|
504 |
+
502: 'clog, geta, patten, sabot',
|
505 |
+
503: 'cocktail shaker',
|
506 |
+
504: 'coffee mug',
|
507 |
+
505: 'coffeepot',
|
508 |
+
506: 'coil, spiral, volute, whorl, helix',
|
509 |
+
507: 'combination lock',
|
510 |
+
508: 'computer keyboard, keypad',
|
511 |
+
509: 'confectionery, confectionary, candy store',
|
512 |
+
510: 'container ship, containership, container vessel',
|
513 |
+
511: 'convertible',
|
514 |
+
512: 'corkscrew, bottle screw',
|
515 |
+
513: 'cornet, horn, trumpet, trump',
|
516 |
+
514: 'cowboy boot',
|
517 |
+
515: 'cowboy hat, ten-gallon hat',
|
518 |
+
516: 'cradle',
|
519 |
+
517: 'crane',
|
520 |
+
518: 'crash helmet',
|
521 |
+
519: 'crate',
|
522 |
+
520: 'crib, cot',
|
523 |
+
521: 'Crock Pot',
|
524 |
+
522: 'croquet ball',
|
525 |
+
523: 'crutch',
|
526 |
+
524: 'cuirass',
|
527 |
+
525: 'dam, dike, dyke',
|
528 |
+
526: 'desk',
|
529 |
+
527: 'desktop computer',
|
530 |
+
528: 'dial telephone, dial phone',
|
531 |
+
529: 'diaper, nappy, napkin',
|
532 |
+
530: 'digital clock',
|
533 |
+
531: 'digital watch',
|
534 |
+
532: 'dining table, board',
|
535 |
+
533: 'dishrag, dishcloth',
|
536 |
+
534: 'dishwasher, dish washer, dishwashing machine',
|
537 |
+
535: 'disk brake, disc brake',
|
538 |
+
536: 'dock, dockage, docking facility',
|
539 |
+
537: 'dogsled, dog sled, dog sleigh',
|
540 |
+
538: 'dome',
|
541 |
+
539: 'doormat, welcome mat',
|
542 |
+
540: 'drilling platform, offshore rig',
|
543 |
+
541: 'drum, membranophone, tympan',
|
544 |
+
542: 'drumstick',
|
545 |
+
543: 'dumbbell',
|
546 |
+
544: 'Dutch oven',
|
547 |
+
545: 'electric fan, blower',
|
548 |
+
546: 'electric guitar',
|
549 |
+
547: 'electric locomotive',
|
550 |
+
548: 'entertainment center',
|
551 |
+
549: 'envelope',
|
552 |
+
550: 'espresso maker',
|
553 |
+
551: 'face powder',
|
554 |
+
552: 'feather boa, boa',
|
555 |
+
553: 'file, file cabinet, filing cabinet',
|
556 |
+
554: 'fireboat',
|
557 |
+
555: 'fire engine, fire truck',
|
558 |
+
556: 'fire screen, fireguard',
|
559 |
+
557: 'flagpole, flagstaff',
|
560 |
+
558: 'flute, transverse flute',
|
561 |
+
559: 'folding chair',
|
562 |
+
560: 'football helmet',
|
563 |
+
561: 'forklift',
|
564 |
+
562: 'fountain',
|
565 |
+
563: 'fountain pen',
|
566 |
+
564: 'four-poster',
|
567 |
+
565: 'freight car',
|
568 |
+
566: 'French horn, horn',
|
569 |
+
567: 'frying pan, frypan, skillet',
|
570 |
+
568: 'fur coat',
|
571 |
+
569: 'garbage truck, dustcart',
|
572 |
+
570: 'gasmask, respirator, gas helmet',
|
573 |
+
571: 'gas pump, gasoline pump, petrol pump, island dispenser',
|
574 |
+
572: 'goblet',
|
575 |
+
573: 'go-kart',
|
576 |
+
574: 'golf ball',
|
577 |
+
575: 'golfcart, golf cart',
|
578 |
+
576: 'gondola',
|
579 |
+
577: 'gong, tam-tam',
|
580 |
+
578: 'gown',
|
581 |
+
579: 'grand piano, grand',
|
582 |
+
580: 'greenhouse, nursery, glasshouse',
|
583 |
+
581: 'grille, radiator grille',
|
584 |
+
582: 'grocery store, grocery, food market, market',
|
585 |
+
583: 'guillotine',
|
586 |
+
584: 'hair slide',
|
587 |
+
585: 'hair spray',
|
588 |
+
586: 'half track',
|
589 |
+
587: 'hammer',
|
590 |
+
588: 'hamper',
|
591 |
+
589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
|
592 |
+
590: 'hand-held computer, hand-held microcomputer',
|
593 |
+
591: 'handkerchief, hankie, hanky, hankey',
|
594 |
+
592: 'hard disc, hard disk, fixed disk',
|
595 |
+
593: 'harmonica, mouth organ, harp, mouth harp',
|
596 |
+
594: 'harp',
|
597 |
+
595: 'harvester, reaper',
|
598 |
+
596: 'hatchet',
|
599 |
+
597: 'holster',
|
600 |
+
598: 'home theater, home theatre',
|
601 |
+
599: 'honeycomb',
|
602 |
+
600: 'hook, claw',
|
603 |
+
601: 'hoopskirt, crinoline',
|
604 |
+
602: 'horizontal bar, high bar',
|
605 |
+
603: 'horse cart, horse-cart',
|
606 |
+
604: 'hourglass',
|
607 |
+
605: 'iPod',
|
608 |
+
606: 'iron, smoothing iron',
|
609 |
+
607: "jack-o'-lantern",
|
610 |
+
608: 'jean, blue jean, denim',
|
611 |
+
609: 'jeep, landrover',
|
612 |
+
610: 'jersey, T-shirt, tee shirt',
|
613 |
+
611: 'jigsaw puzzle',
|
614 |
+
612: 'jinrikisha, ricksha, rickshaw',
|
615 |
+
613: 'joystick',
|
616 |
+
614: 'kimono',
|
617 |
+
615: 'knee pad',
|
618 |
+
616: 'knot',
|
619 |
+
617: 'lab coat, laboratory coat',
|
620 |
+
618: 'ladle',
|
621 |
+
619: 'lampshade, lamp shade',
|
622 |
+
620: 'laptop, laptop computer',
|
623 |
+
621: 'lawn mower, mower',
|
624 |
+
622: 'lens cap, lens cover',
|
625 |
+
623: 'letter opener, paper knife, paperknife',
|
626 |
+
624: 'library',
|
627 |
+
625: 'lifeboat',
|
628 |
+
626: 'lighter, light, igniter, ignitor',
|
629 |
+
627: 'limousine, limo',
|
630 |
+
628: 'liner, ocean liner',
|
631 |
+
629: 'lipstick, lip rouge',
|
632 |
+
630: 'Loafer',
|
633 |
+
631: 'lotion',
|
634 |
+
632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
|
635 |
+
633: "loupe, jeweler's loupe",
|
636 |
+
634: 'lumbermill, sawmill',
|
637 |
+
635: 'magnetic compass',
|
638 |
+
636: 'mailbag, postbag',
|
639 |
+
637: 'mailbox, letter box',
|
640 |
+
638: 'maillot',
|
641 |
+
639: 'maillot, tank suit',
|
642 |
+
640: 'manhole cover',
|
643 |
+
641: 'maraca',
|
644 |
+
642: 'marimba, xylophone',
|
645 |
+
643: 'mask',
|
646 |
+
644: 'matchstick',
|
647 |
+
645: 'maypole',
|
648 |
+
646: 'maze, labyrinth',
|
649 |
+
647: 'measuring cup',
|
650 |
+
648: 'medicine chest, medicine cabinet',
|
651 |
+
649: 'megalith, megalithic structure',
|
652 |
+
650: 'microphone, mike',
|
653 |
+
651: 'microwave, microwave oven',
|
654 |
+
652: 'military uniform',
|
655 |
+
653: 'milk can',
|
656 |
+
654: 'minibus',
|
657 |
+
655: 'miniskirt, mini',
|
658 |
+
656: 'minivan',
|
659 |
+
657: 'missile',
|
660 |
+
658: 'mitten',
|
661 |
+
659: 'mixing bowl',
|
662 |
+
660: 'mobile home, manufactured home',
|
663 |
+
661: 'Model T',
|
664 |
+
662: 'modem',
|
665 |
+
663: 'monastery',
|
666 |
+
664: 'monitor',
|
667 |
+
665: 'moped',
|
668 |
+
666: 'mortar',
|
669 |
+
667: 'mortarboard',
|
670 |
+
668: 'mosque',
|
671 |
+
669: 'mosquito net',
|
672 |
+
670: 'motor scooter, scooter',
|
673 |
+
671: 'mountain bike, all-terrain bike, off-roader',
|
674 |
+
672: 'mountain tent',
|
675 |
+
673: 'mouse, computer mouse',
|
676 |
+
674: 'mousetrap',
|
677 |
+
675: 'moving van',
|
678 |
+
676: 'muzzle',
|
679 |
+
677: 'nail',
|
680 |
+
678: 'neck brace',
|
681 |
+
679: 'necklace',
|
682 |
+
680: 'nipple',
|
683 |
+
681: 'notebook, notebook computer',
|
684 |
+
682: 'obelisk',
|
685 |
+
683: 'oboe, hautboy, hautbois',
|
686 |
+
684: 'ocarina, sweet potato',
|
687 |
+
685: 'odometer, hodometer, mileometer, milometer',
|
688 |
+
686: 'oil filter',
|
689 |
+
687: 'organ, pipe organ',
|
690 |
+
688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
|
691 |
+
689: 'overskirt',
|
692 |
+
690: 'oxcart',
|
693 |
+
691: 'oxygen mask',
|
694 |
+
692: 'packet',
|
695 |
+
693: 'paddle, boat paddle',
|
696 |
+
694: 'paddlewheel, paddle wheel',
|
697 |
+
695: 'padlock',
|
698 |
+
696: 'paintbrush',
|
699 |
+
697: "pajama, pyjama, pj's, jammies",
|
700 |
+
698: 'palace',
|
701 |
+
699: 'panpipe, pandean pipe, syrinx',
|
702 |
+
700: 'paper towel',
|
703 |
+
701: 'parachute, chute',
|
704 |
+
702: 'parallel bars, bars',
|
705 |
+
703: 'park bench',
|
706 |
+
704: 'parking meter',
|
707 |
+
705: 'passenger car, coach, carriage',
|
708 |
+
706: 'patio, terrace',
|
709 |
+
707: 'pay-phone, pay-station',
|
710 |
+
708: 'pedestal, plinth, footstall',
|
711 |
+
709: 'pencil box, pencil case',
|
712 |
+
710: 'pencil sharpener',
|
713 |
+
711: 'perfume, essence',
|
714 |
+
712: 'Petri dish',
|
715 |
+
713: 'photocopier',
|
716 |
+
714: 'pick, plectrum, plectron',
|
717 |
+
715: 'pickelhaube',
|
718 |
+
716: 'picket fence, paling',
|
719 |
+
717: 'pickup, pickup truck',
|
720 |
+
718: 'pier',
|
721 |
+
719: 'piggy bank, penny bank',
|
722 |
+
720: 'pill bottle',
|
723 |
+
721: 'pillow',
|
724 |
+
722: 'ping-pong ball',
|
725 |
+
723: 'pinwheel',
|
726 |
+
724: 'pirate, pirate ship',
|
727 |
+
725: 'pitcher, ewer',
|
728 |
+
726: "plane, carpenter's plane, woodworking plane",
|
729 |
+
727: 'planetarium',
|
730 |
+
728: 'plastic bag',
|
731 |
+
729: 'plate rack',
|
732 |
+
730: 'plow, plough',
|
733 |
+
731: "plunger, plumber's helper",
|
734 |
+
732: 'Polaroid camera, Polaroid Land camera',
|
735 |
+
733: 'pole',
|
736 |
+
734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
|
737 |
+
735: 'poncho',
|
738 |
+
736: 'pool table, billiard table, snooker table',
|
739 |
+
737: 'pop bottle, soda bottle',
|
740 |
+
738: 'pot, flowerpot',
|
741 |
+
739: "potter's wheel",
|
742 |
+
740: 'power drill',
|
743 |
+
741: 'prayer rug, prayer mat',
|
744 |
+
742: 'printer',
|
745 |
+
743: 'prison, prison house',
|
746 |
+
744: 'projectile, missile',
|
747 |
+
745: 'projector',
|
748 |
+
746: 'puck, hockey puck',
|
749 |
+
747: 'punching bag, punch bag, punching ball, punchball',
|
750 |
+
748: 'purse',
|
751 |
+
749: 'quill, quill pen',
|
752 |
+
750: 'quilt, comforter, comfort, puff',
|
753 |
+
751: 'racer, race car, racing car',
|
754 |
+
752: 'racket, racquet',
|
755 |
+
753: 'radiator',
|
756 |
+
754: 'radio, wireless',
|
757 |
+
755: 'radio telescope, radio reflector',
|
758 |
+
756: 'rain barrel',
|
759 |
+
757: 'recreational vehicle, RV, R.V.',
|
760 |
+
758: 'reel',
|
761 |
+
759: 'reflex camera',
|
762 |
+
760: 'refrigerator, icebox',
|
763 |
+
761: 'remote control, remote',
|
764 |
+
762: 'restaurant, eating house, eating place, eatery',
|
765 |
+
763: 'revolver, six-gun, six-shooter',
|
766 |
+
764: 'rifle',
|
767 |
+
765: 'rocking chair, rocker',
|
768 |
+
766: 'rotisserie',
|
769 |
+
767: 'rubber eraser, rubber, pencil eraser',
|
770 |
+
768: 'rugby ball',
|
771 |
+
769: 'rule, ruler',
|
772 |
+
770: 'running shoe',
|
773 |
+
771: 'safe',
|
774 |
+
772: 'safety pin',
|
775 |
+
773: 'saltshaker, salt shaker',
|
776 |
+
774: 'sandal',
|
777 |
+
775: 'sarong',
|
778 |
+
776: 'sax, saxophone',
|
779 |
+
777: 'scabbard',
|
780 |
+
778: 'scale, weighing machine',
|
781 |
+
779: 'school bus',
|
782 |
+
780: 'schooner',
|
783 |
+
781: 'scoreboard',
|
784 |
+
782: 'screen, CRT screen',
|
785 |
+
783: 'screw',
|
786 |
+
784: 'screwdriver',
|
787 |
+
785: 'seat belt, seatbelt',
|
788 |
+
786: 'sewing machine',
|
789 |
+
787: 'shield, buckler',
|
790 |
+
788: 'shoe shop, shoe-shop, shoe store',
|
791 |
+
789: 'shoji',
|
792 |
+
790: 'shopping basket',
|
793 |
+
791: 'shopping cart',
|
794 |
+
792: 'shovel',
|
795 |
+
793: 'shower cap',
|
796 |
+
794: 'shower curtain',
|
797 |
+
795: 'ski',
|
798 |
+
796: 'ski mask',
|
799 |
+
797: 'sleeping bag',
|
800 |
+
798: 'slide rule, slipstick',
|
801 |
+
799: 'sliding door',
|
802 |
+
800: 'slot, one-armed bandit',
|
803 |
+
801: 'snorkel',
|
804 |
+
802: 'snowmobile',
|
805 |
+
803: 'snowplow, snowplough',
|
806 |
+
804: 'soap dispenser',
|
807 |
+
805: 'soccer ball',
|
808 |
+
806: 'sock',
|
809 |
+
807: 'solar dish, solar collector, solar furnace',
|
810 |
+
808: 'sombrero',
|
811 |
+
809: 'soup bowl',
|
812 |
+
810: 'space bar',
|
813 |
+
811: 'space heater',
|
814 |
+
812: 'space shuttle',
|
815 |
+
813: 'spatula',
|
816 |
+
814: 'speedboat',
|
817 |
+
815: "spider web, spider's web",
|
818 |
+
816: 'spindle',
|
819 |
+
817: 'sports car, sport car',
|
820 |
+
818: 'spotlight, spot',
|
821 |
+
819: 'stage',
|
822 |
+
820: 'steam locomotive',
|
823 |
+
821: 'steel arch bridge',
|
824 |
+
822: 'steel drum',
|
825 |
+
823: 'stethoscope',
|
826 |
+
824: 'stole',
|
827 |
+
825: 'stone wall',
|
828 |
+
826: 'stopwatch, stop watch',
|
829 |
+
827: 'stove',
|
830 |
+
828: 'strainer',
|
831 |
+
829: 'streetcar, tram, tramcar, trolley, trolley car',
|
832 |
+
830: 'stretcher',
|
833 |
+
831: 'studio couch, day bed',
|
834 |
+
832: 'stupa, tope',
|
835 |
+
833: 'submarine, pigboat, sub, U-boat',
|
836 |
+
834: 'suit, suit of clothes',
|
837 |
+
835: 'sundial',
|
838 |
+
836: 'sunglass',
|
839 |
+
837: 'sunglasses, dark glasses, shades',
|
840 |
+
838: 'sunscreen, sunblock, sun blocker',
|
841 |
+
839: 'suspension bridge',
|
842 |
+
840: 'swab, swob, mop',
|
843 |
+
841: 'sweatshirt',
|
844 |
+
842: 'swimming trunks, bathing trunks',
|
845 |
+
843: 'swing',
|
846 |
+
844: 'switch, electric switch, electrical switch',
|
847 |
+
845: 'syringe',
|
848 |
+
846: 'table lamp',
|
849 |
+
847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
|
850 |
+
848: 'tape player',
|
851 |
+
849: 'teapot',
|
852 |
+
850: 'teddy, teddy bear',
|
853 |
+
851: 'television, television system',
|
854 |
+
852: 'tennis ball',
|
855 |
+
853: 'thatch, thatched roof',
|
856 |
+
854: 'theater curtain, theatre curtain',
|
857 |
+
855: 'thimble',
|
858 |
+
856: 'thresher, thrasher, threshing machine',
|
859 |
+
857: 'throne',
|
860 |
+
858: 'tile roof',
|
861 |
+
859: 'toaster',
|
862 |
+
860: 'tobacco shop, tobacconist shop, tobacconist',
|
863 |
+
861: 'toilet seat',
|
864 |
+
862: 'torch',
|
865 |
+
863: 'totem pole',
|
866 |
+
864: 'tow truck, tow car, wrecker',
|
867 |
+
865: 'toyshop',
|
868 |
+
866: 'tractor',
|
869 |
+
867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
|
870 |
+
868: 'tray',
|
871 |
+
869: 'trench coat',
|
872 |
+
870: 'tricycle, trike, velocipede',
|
873 |
+
871: 'trimaran',
|
874 |
+
872: 'tripod',
|
875 |
+
873: 'triumphal arch',
|
876 |
+
874: 'trolleybus, trolley coach, trackless trolley',
|
877 |
+
875: 'trombone',
|
878 |
+
876: 'tub, vat',
|
879 |
+
877: 'turnstile',
|
880 |
+
878: 'typewriter keyboard',
|
881 |
+
879: 'umbrella',
|
882 |
+
880: 'unicycle, monocycle',
|
883 |
+
881: 'upright, upright piano',
|
884 |
+
882: 'vacuum, vacuum cleaner',
|
885 |
+
883: 'vase',
|
886 |
+
884: 'vault',
|
887 |
+
885: 'velvet',
|
888 |
+
886: 'vending machine',
|
889 |
+
887: 'vestment',
|
890 |
+
888: 'viaduct',
|
891 |
+
889: 'violin, fiddle',
|
892 |
+
890: 'volleyball',
|
893 |
+
891: 'waffle iron',
|
894 |
+
892: 'wall clock',
|
895 |
+
893: 'wallet, billfold, notecase, pocketbook',
|
896 |
+
894: 'wardrobe, closet, press',
|
897 |
+
895: 'warplane, military plane',
|
898 |
+
896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
|
899 |
+
897: 'washer, automatic washer, washing machine',
|
900 |
+
898: 'water bottle',
|
901 |
+
899: 'water jug',
|
902 |
+
900: 'water tower',
|
903 |
+
901: 'whiskey jug',
|
904 |
+
902: 'whistle',
|
905 |
+
903: 'wig',
|
906 |
+
904: 'window screen',
|
907 |
+
905: 'window shade',
|
908 |
+
906: 'Windsor tie',
|
909 |
+
907: 'wine bottle',
|
910 |
+
908: 'wing',
|
911 |
+
909: 'wok',
|
912 |
+
910: 'wooden spoon',
|
913 |
+
911: 'wool, woolen, woollen',
|
914 |
+
912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
|
915 |
+
913: 'wreck',
|
916 |
+
914: 'yawl',
|
917 |
+
915: 'yurt',
|
918 |
+
916: 'web site, website, internet site, site',
|
919 |
+
917: 'comic book',
|
920 |
+
918: 'crossword puzzle, crossword',
|
921 |
+
919: 'street sign',
|
922 |
+
920: 'traffic light, traffic signal, stoplight',
|
923 |
+
921: 'book jacket, dust cover, dust jacket, dust wrapper',
|
924 |
+
922: 'menu',
|
925 |
+
923: 'plate',
|
926 |
+
924: 'guacamole',
|
927 |
+
925: 'consomme',
|
928 |
+
926: 'hot pot, hotpot',
|
929 |
+
927: 'trifle',
|
930 |
+
928: 'ice cream, icecream',
|
931 |
+
929: 'ice lolly, lolly, lollipop, popsicle',
|
932 |
+
930: 'French loaf',
|
933 |
+
931: 'bagel, beigel',
|
934 |
+
932: 'pretzel',
|
935 |
+
933: 'cheeseburger',
|
936 |
+
934: 'hotdog, hot dog, red hot',
|
937 |
+
935: 'mashed potato',
|
938 |
+
936: 'head cabbage',
|
939 |
+
937: 'broccoli',
|
940 |
+
938: 'cauliflower',
|
941 |
+
939: 'zucchini, courgette',
|
942 |
+
940: 'spaghetti squash',
|
943 |
+
941: 'acorn squash',
|
944 |
+
942: 'butternut squash',
|
945 |
+
943: 'cucumber, cuke',
|
946 |
+
944: 'artichoke, globe artichoke',
|
947 |
+
945: 'bell pepper',
|
948 |
+
946: 'cardoon',
|
949 |
+
947: 'mushroom',
|
950 |
+
948: 'Granny Smith',
|
951 |
+
949: 'strawberry',
|
952 |
+
950: 'orange',
|
953 |
+
951: 'lemon',
|
954 |
+
952: 'fig',
|
955 |
+
953: 'pineapple, ananas',
|
956 |
+
954: 'banana',
|
957 |
+
955: 'jackfruit, jak, jack',
|
958 |
+
956: 'custard apple',
|
959 |
+
957: 'pomegranate',
|
960 |
+
958: 'hay',
|
961 |
+
959: 'carbonara',
|
962 |
+
960: 'chocolate sauce, chocolate syrup',
|
963 |
+
961: 'dough',
|
964 |
+
962: 'meat loaf, meatloaf',
|
965 |
+
963: 'pizza, pizza pie',
|
966 |
+
964: 'potpie',
|
967 |
+
965: 'burrito',
|
968 |
+
966: 'red wine',
|
969 |
+
967: 'espresso',
|
970 |
+
968: 'cup',
|
971 |
+
969: 'eggnog',
|
972 |
+
970: 'alp',
|
973 |
+
971: 'bubble',
|
974 |
+
972: 'cliff, drop, drop-off',
|
975 |
+
973: 'coral reef',
|
976 |
+
974: 'geyser',
|
977 |
+
975: 'lakeside, lakeshore',
|
978 |
+
976: 'promontory, headland, head, foreland',
|
979 |
+
977: 'sandbar, sand bar',
|
980 |
+
978: 'seashore, coast, seacoast, sea-coast',
|
981 |
+
979: 'valley, vale',
|
982 |
+
980: 'volcano',
|
983 |
+
981: 'ballplayer, baseball player',
|
984 |
+
982: 'groom, bridegroom',
|
985 |
+
983: 'scuba diver',
|
986 |
+
984: 'rapeseed',
|
987 |
+
985: 'daisy',
|
988 |
+
986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
|
989 |
+
987: 'corn',
|
990 |
+
988: 'acorn',
|
991 |
+
989: 'hip, rose hip, rosehip',
|
992 |
+
990: 'buckeye, horse chestnut, conker',
|
993 |
+
991: 'coral fungus',
|
994 |
+
992: 'agaric',
|
995 |
+
993: 'gyromitra',
|
996 |
+
994: 'stinkhorn, carrion fungus',
|
997 |
+
995: 'earthstar',
|
998 |
+
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
|
999 |
+
997: 'bolete',
|
1000 |
+
998: 'ear, spike, capitulum',
|
1001 |
+
999: 'toilet tissue, toilet paper, bathroom tissue'}
|
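The mapping above is the standard 1000-class ImageNet label table. A small sketch of how it can be used to turn class indices (for example, the conditioning labels fed to the generators defined below) into human-readable names; the sampled indices here are illustrative only.

# Usage sketch (illustrative only).
import torch
from imagenet_classes import imagenet_idx2classname

# Hypothetical conditioning labels for a batch of 4 samples.
condition = torch.tensor([207, 282, 985, 0])

names = [imagenet_idx2classname[int(idx)] for idx in condition]
# -> ['golden retriever', 'tiger cat', 'daisy', 'tench, Tinca tinca']
print(names)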
modeling/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
Copyright (2024) Bytedance Ltd. and/or its affiliates

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
modeling/maskgit.py
ADDED
@@ -0,0 +1,282 @@
"""This file contains implementation for MaskGIT model.

Copyright (2024) Bytedance Ltd. and/or its affiliates

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Reference:
    https://github.com/huggingface/open-muse
    https://github.com/baaivision/MUSE-Pytorch
    https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py
"""

import torch
from torch import nn
import numpy as np
import math
import torch.utils.checkpoint
from transformers import BertConfig, BertModel
from einops import rearrange

import json
from huggingface_hub import PyTorchModelHubMixin
from omegaconf import OmegaConf
from pathlib import Path

from modeling.modules.base_model import BaseModel
from modeling.modules.blocks import UViTBlock


class ImageBert(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-generation"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"):
    def __init__(self, config):

        if isinstance(config, dict):
            config = OmegaConf.create(config)

        super().__init__()
        self.config = config
        self.target_codebook_size = config.model.vq_model.codebook_size
        self.condition_num_classes = config.model.generator.condition_num_classes
        self.image_seq_len = config.model.generator.image_seq_len
        self.mask_token_id = self.target_codebook_size
        self.hidden_size = config.model.generator.hidden_size
        self.num_hidden_layers = config.model.generator.num_hidden_layers
        self.num_attention_heads = config.model.generator.num_attention_heads
        self.intermediate_size = config.model.generator.intermediate_size

        self.model = BertModel(BertConfig(
            vocab_size=self.target_codebook_size + self.condition_num_classes + 2,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act='gelu',
            hidden_dropout_prob=config.model.generator.dropout,
            attention_probs_dropout_prob=config.model.generator.attn_drop,
            max_position_embeddings=config.model.generator.image_seq_len + 1,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=None,
            position_embedding_type="absolute",
            use_cache=True
        ), add_pooling_layer=False)
        self.model.lm_head = nn.Linear(self.hidden_size, self.target_codebook_size, bias=True)

        self.model.post_init()

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save weights and config to a local directory."""
        # 'self.config' is a DictConfig object; convert it to a regular dictionary.
        dict_config = OmegaConf.to_container(self.config)
        # Save as JSON.
        file_path = Path(save_directory) / "config.json"
        with open(file_path, 'w') as json_file:
            json.dump(dict_config, json_file, indent=4)
        super()._save_pretrained(save_directory)

    def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
        # Token space:
        #  [0, codebook_size - 1]                      : the learned quantized image tokens
        #  codebook_size                               : the mask token used to mask image tokens
        #  [codebook_size + 1, codebook_size + nclass] : the ImageNet class tokens
        #  codebook_size + 1 + nclass                  : the class drop label
        drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
        # Shift the classes
        condition = condition + self.target_codebook_size + 1  # [0, 999] -> [codebook_size + 1, codebook_size + 999]
        condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
        # Prepend the condition token
        if input_ids is not None:
            input_ids = torch.cat([condition.view(condition.shape[0], -1),
                                   input_ids.view(input_ids.shape[0], -1),], dim=1)
        else:
            # At least the masked tokens should be provided
            raise NotImplementedError
        model_output = self.model(input_ids=input_ids)
        model_output = model_output[0]
        return self.model.lm_head(model_output[:, 1:])  # remove cond

    # ref: https://github.com/baaivision/MUSE-Pytorch/blob/master/libs/muse.py#L40
    @torch.no_grad()
    def generate(self,
                 condition,
                 guidance_scale=3.0,
                 guidance_decay="constant",
                 guidance_scale_pow=3.0,
                 randomize_temperature=4.5,
                 softmax_temperature_annealing=False,
                 num_sample_steps=8):
        if guidance_decay not in ["constant", "linear", "power-cosine"]:
            # constant: constant guidance scale
            # linear: linearly increasing guidance scale, as in MUSE
            # power-cosine: the guidance schedule from MDT
            raise ValueError(f"Unsupported guidance decay {guidance_decay}")
        device = condition.device
        ids = torch.full((condition.shape[0], self.image_seq_len),
                         self.mask_token_id, device=device)

        cfg_scale = guidance_scale if guidance_decay == "constant" else 0.

        for step in range(num_sample_steps):
            ratio = 1. * (step + 1) / num_sample_steps
            annealed_temp = randomize_temperature * (1.0 - ratio)
            is_mask = (ids == self.mask_token_id)

            if guidance_decay == "power-cosine":
                # ref: https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L501
                guidance_scale_pow = torch.ones((1), device=device) * guidance_scale_pow
                scale_step = (1 - torch.cos(((step / num_sample_steps) ** guidance_scale_pow) * torch.pi)) * 1/2
                cfg_scale = (guidance_scale - 1) * scale_step + 1

            if cfg_scale != 0:
                cond_logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )
                uncond_logits = self.forward(
                    ids, condition, cond_drop_prob=1.0
                )
                if guidance_decay == "power-cosine":
                    logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
                else:
                    logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale
            else:
                logits = self.forward(
                    ids, condition, cond_drop_prob=0.0
                )

            if softmax_temperature_annealing:
                softmax_temperature = 0.5 + 0.8 * (1 - ratio)
                logits = logits / softmax_temperature

            # Add gumbel noise
            def log(t, eps=1e-20):
                return torch.log(t.clamp(min=eps))
            def gumbel_noise(t):
                noise = torch.zeros_like(t).uniform_(0, 1)
                return -log(-log(noise))
            def add_gumbel_noise(t, temperature):
                return t + temperature * gumbel_noise(t)

            sampled_ids = add_gumbel_noise(logits, annealed_temp).argmax(dim=-1)
            sampled_logits = torch.squeeze(
                torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            sampled_ids = torch.where(is_mask, sampled_ids, ids)
            sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()
            # masking
            mask_ratio = np.arccos(ratio) / (math.pi * 0.5)

            mask_len = torch.Tensor([np.floor(self.image_seq_len * mask_ratio)]).to(device)
            mask_len = torch.maximum(torch.Tensor([1]).to(device),
                                     torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1,
                                                   mask_len))[0].squeeze()
            confidence = add_gumbel_noise(sampled_logits, annealed_temp)
            sorted_confidence, _ = torch.sort(confidence, axis=-1)
            cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
            masking = (confidence <= cut_off)
            if step == num_sample_steps - 1:
                ids = sampled_ids
            else:
                ids = torch.where(masking, self.mask_token_id, sampled_ids)

            if guidance_decay == "linear":
                cfg_scale = ratio * guidance_scale
        return ids

    def masking_input_tokens(self, input_tokens):
        batch_size, seq_len = input_tokens.shape
        device = input_tokens.device

        timesteps = torch.zeros((batch_size,), device=device).float().uniform_(0, 1.0)
        mask_ratio = torch.acos(timesteps) / (math.pi * 0.5)  # arccos schedule
        mask_ratio = torch.clamp(mask_ratio, min=1e-6, max=1.)
        num_token_masked = (seq_len * mask_ratio).round().clamp(min=1)
        batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1)
        masks = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
        masked_tokens = torch.where(masks, self.mask_token_id, input_tokens)
        return masked_tokens, masks


class UViTBert(ImageBert):
    def __init__(self, config):
        super().__init__(config=config)

        del self.model

        self.embeddings = nn.Embedding(
            self.target_codebook_size + self.condition_num_classes + 2,
            self.hidden_size)

        self.pos_embed = nn.init.trunc_normal_(
            nn.Parameter(torch.zeros(1, self.config.model.generator.image_seq_len + 1, self.hidden_size)), 0., 0.02)
+
self.in_blocks = nn.ModuleList([
|
223 |
+
UViTBlock(
|
224 |
+
dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
|
225 |
+
qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)
|
226 |
+
for _ in range(self.num_hidden_layers // 2)])
|
227 |
+
|
228 |
+
self.mid_block = UViTBlock(
|
229 |
+
dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
|
230 |
+
qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)
|
231 |
+
|
232 |
+
self.out_blocks = nn.ModuleList([
|
233 |
+
UViTBlock(
|
234 |
+
dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
|
235 |
+
qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, skip=True, use_checkpoint=False)
|
236 |
+
for _ in range(self.num_hidden_layers // 2)])
|
237 |
+
|
238 |
+
self.norm = nn.LayerNorm(self.hidden_size)
|
239 |
+
self.lm_head = nn.Linear(self.hidden_size,
|
240 |
+
self.target_codebook_size, bias=True)
|
241 |
+
self.apply(self._init_weights)
|
242 |
+
|
243 |
+
def _init_weights(self, m):
|
244 |
+
if isinstance(m, nn.Linear):
|
245 |
+
nn.init.trunc_normal_(m.weight, std=.02)
|
246 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
247 |
+
nn.init.constant_(m.bias, 0)
|
248 |
+
elif isinstance(m, nn.Embedding):
|
249 |
+
m.weight.data = nn.init.trunc_normal_(m.weight.data, mean=0.0, std=0.02)
|
250 |
+
elif isinstance(m, nn.LayerNorm):
|
251 |
+
nn.init.constant_(m.bias, 0)
|
252 |
+
nn.init.constant_(m.weight, 1.0)
|
253 |
+
|
254 |
+
def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
|
255 |
+
# Token space:
|
256 |
+
# [0, codebook_size - 1] : those are the learned quantized image tokens
|
257 |
+
# codebook_size : the mask token used to mask image tokens
|
258 |
+
# [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
|
259 |
+
# codebook_size + 1 + nclass : the class drop label
|
260 |
+
drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
|
261 |
+
# Shift the classes
|
262 |
+
condition = condition + self.target_codebook_size + 1 # [0, 999] -> [codebook_size + 1, codebook_size + 999]
|
263 |
+
condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
|
264 |
+
# prepend condition token
|
265 |
+
if input_ids is not None:
|
266 |
+
input_ids = torch.cat([condition.view(condition.shape[0], -1),
|
267 |
+
input_ids.view(input_ids.shape[0], -1),], dim=1)
|
268 |
+
else:
|
269 |
+
# at least there should be masked token
|
270 |
+
raise NotImplementedError
|
271 |
+
# UViT forward
|
272 |
+
embeddings = self.embeddings(input_ids)
|
273 |
+
x = embeddings + self.pos_embed[:, :embeddings.shape[1]]
|
274 |
+
skips = []
|
275 |
+
for blk in self.in_blocks:
|
276 |
+
x = blk(x)
|
277 |
+
skips.append(x)
|
278 |
+
x = self.mid_block(x)
|
279 |
+
for blk in self.out_blocks:
|
280 |
+
x = blk(x, skips.pop())
|
281 |
+
x = self.norm(x)
|
282 |
+
return self.lm_head(x[:, 1:]) # remove cond
|
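Both `generate` and `masking_input_tokens` above anneal the number of masked tokens with an arccos schedule. The following standalone sketch is not part of the commit; it only evaluates that schedule so the step-wise behaviour is easy to see.

    import math
    import numpy as np

    # Arccos masking schedule used above: at progress `ratio` in (0, 1], the
    # fraction of tokens kept masked for the next step is arccos(ratio) / (pi / 2).
    def masked_fraction(ratio: float) -> float:
        return float(np.arccos(ratio) / (math.pi * 0.5))

    for step in (1, 4, 8):
        ratio = step / 8
        print(f"step {step}/8 -> keep ~{masked_fraction(ratio):.2f} of the sequence masked")

At the final step the ratio reaches 1 and the masked fraction drops to 0, so every position ends up committed to a sampled token.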
modeling/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .base_model import BaseModel
+from .ema_model import EMAModel
+from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, MLMLoss, ARLoss
+from .blocks import TiTokEncoder, TiTokDecoder, UViTBlock
+from .maskgit_vqgan import Decoder as Pixel_Decoder
+from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
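The package __init__ simply re-exports the pieces defined in the modules below. A typical import in training or demo code (assuming the repository root is on PYTHONPATH) would look like:

    from modeling.modules import TiTokEncoder, TiTokDecoder, EMAModel, ReconstructionLoss_Stage2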
modeling/modules/base_model.py
ADDED
@@ -0,0 +1,127 @@
+"""This file contains some base class implementation for models.
+
+This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
+All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
+
+Reference:
+    https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
+"""
+import os
+from typing import Union, Callable, Dict, Optional
+
+import torch
+
+
+class BaseModel(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def save_pretrained_weight(
+        self,
+        save_directory: Union[str, os.PathLike],
+        save_function: Callable = None,
+        state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """Saves a model and its configuration file to a directory.
+
+        Args:
+            save_directory: A string or os.PathLike, directory to which to save.
+                Will be created if it doesn't exist.
+            save_function: A Callable function, the function to use to save the state dictionary.
+                Useful on distributed training like TPUs when one needs to replace `torch.save` by
+                another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`.
+            state_dict: A dictionary from str to torch.Tensor, the state dictionary to save.
+                If `None`, the model's state dictionary will be saved.
+        """
+        if os.path.isfile(save_directory):
+            print(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        model_to_save = self
+
+        if state_dict is None:
+            state_dict = model_to_save.state_dict()
+        weights_name = "pytorch_model.bin"
+
+        save_function(state_dict, os.path.join(save_directory, weights_name))
+
+        print(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+
+    def load_pretrained_weight(
+        self,
+        pretrained_model_path: Union[str, os.PathLike],
+        strict_loading: bool = True,
+        torch_dtype: Optional[torch.dtype] = None
+    ):
+        r"""Instantiates a pretrained pytorch model from a pre-trained model configuration.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        Args:
+            pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights.
+
+        Raises:
+            ValueError: If pretrained_model_path does not exist.
+        """
+        # If pretrained_model_path is a file, set model_file to this file.
+        if os.path.isfile(pretrained_model_path):
+            model_file = pretrained_model_path
+        # If pretrained_model_path is a directory, set model_file to the path of the
+        # file "pytorch_model.bin" in this directory.
+        elif os.path.isdir(pretrained_model_path):
+            pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
+            if os.path.isfile(pretrained_model_path):
+                model_file = pretrained_model_path
+            else:
+                raise ValueError(f"{pretrained_model_path} does not exist")
+        else:
+            raise ValueError(f"{pretrained_model_path} does not exist")
+
+        # Load model state from checkpoint.
+        checkpoint = torch.load(model_file, map_location="cpu")
+        # Load state dictionary into self.
+        msg = self.load_state_dict(checkpoint, strict=strict_loading)
+        # Print information about loading weights.
+        print(f"loading weight from {model_file}, msg: {msg}")
+        # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype.
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+            raise ValueError(
+                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+            )
+        elif torch_dtype is not None:
+            self.to(torch_dtype)
+
+        # Set model in evaluation mode to deactivate DropOut modules by default.
+        self.eval()
+
+    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+        """Gets the number of parameters in the module.
+
+        Args:
+            only_trainable: A boolean, whether to only include trainable parameters.
+            exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings.
+
+        Returns:
+            An integer, the number of parameters.
+        """
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight"
+                for name, module_type in self.named_modules()
+                if isinstance(module_type, torch.nn.Embedding)
+            ]
+            non_embedding_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+        else:
+            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
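For orientation, a minimal usage sketch of the helpers above; `ToyModel` and the checkpoint directory are illustrative assumptions, not part of this commit:

    import torch
    from modeling.modules.base_model import BaseModel

    class ToyModel(BaseModel):  # hypothetical subclass, for illustration only
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

    model = ToyModel()
    model.save_pretrained_weight("./toy_ckpt")   # writes ./toy_ckpt/pytorch_model.bin
    model.load_pretrained_weight("./toy_ckpt")   # reloads the weights and switches to eval()
    print(model.num_parameters(only_trainable=True))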
modeling/modules/blocks.py
ADDED
@@ -0,0 +1,376 @@
+"""Building blocks for TiTok.
+
+Copyright (2024) Bytedance Ltd. and/or its affiliates
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Reference:
+    https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
+    https://github.com/baofff/U-ViT/blob/main/libs/timm.py
+"""
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+import einops
+from einops.layers.torch import Rearrange
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        n_head,
+        mlp_ratio=4.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm
+    ):
+        super().__init__()
+
+        self.ln_1 = norm_layer(d_model)
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.mlp_ratio = mlp_ratio
+        # optionally we can disable the FFN
+        if mlp_ratio > 0:
+            self.ln_2 = norm_layer(d_model)
+            mlp_width = int(d_model * mlp_ratio)
+            self.mlp = nn.Sequential(OrderedDict([
+                ("c_fc", nn.Linear(d_model, mlp_width)),
+                ("gelu", act_layer()),
+                ("c_proj", nn.Linear(mlp_width, d_model))
+            ]))
+
+    def attention(
+        self,
+        x: torch.Tensor
+    ):
+        return self.attn(x, x, x, need_weights=False)[0]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ):
+        attn_output = self.attention(x=self.ln_1(x))
+        x = x + attn_output
+        if self.mlp_ratio > 0:
+            x = x + self.mlp(self.ln_2(x))
+        return x
+
+if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+    ATTENTION_MODE = 'flash'
+else:
+    try:
+        import xformers
+        import xformers.ops
+        ATTENTION_MODE = 'xformers'
+    except ImportError:
+        ATTENTION_MODE = 'math'
+print(f'attention mode is {ATTENTION_MODE}')
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, L, C = x.shape
+
+        qkv = self.qkv(x)
+        if ATTENTION_MODE == 'flash':
+            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
+            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
+            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+            x = einops.rearrange(x, 'B H L D -> B L (H D)')
+        elif ATTENTION_MODE == 'xformers':
+            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
+            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
+            x = xformers.ops.memory_efficient_attention(q, k, v)
+            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
+        elif ATTENTION_MODE == 'math':
+            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
+            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+        else:
+            raise NotImplementedError
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class UViTBlock(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
+        self.use_checkpoint = use_checkpoint
+
+    def forward(self, x, skip=None):
+        if self.use_checkpoint:
+            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
+        else:
+            return self._forward(x, skip)
+
+    def _forward(self, x, skip=None):
+        if self.skip_linear is not None:
+            x = self.skip_linear(torch.cat([x, skip], dim=-1))
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+def _expand_token(token, batch_size: int):
+    return token.unsqueeze(0).expand(batch_size, -1, -1)
+
+
+class TiTokEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.image_size = config.dataset.preprocessing.crop_size
+        self.patch_size = config.model.vq_model.vit_enc_patch_size
+        self.grid_size = self.image_size // self.patch_size
+        self.model_size = config.model.vq_model.vit_enc_model_size
+        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+        self.token_size = config.model.vq_model.token_size
+
+        if config.model.vq_model.get("quantize_mode", "vq") == "vae":
+            self.token_size = self.token_size * 2  # needs to split into mean and std
+
+        self.is_legacy = config.model.vq_model.get("is_legacy", True)
+
+        self.width = {
+            "small": 512,
+            "base": 768,
+            "large": 1024,
+        }[self.model_size]
+        self.num_layers = {
+            "small": 8,
+            "base": 12,
+            "large": 24,
+        }[self.model_size]
+        self.num_heads = {
+            "small": 8,
+            "base": 12,
+            "large": 16,
+        }[self.model_size]
+
+        self.patch_embed = nn.Conv2d(
+            in_channels=3, out_channels=self.width,
+            kernel_size=self.patch_size, stride=self.patch_size, bias=True)
+
+        scale = self.width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+        self.positional_embedding = nn.Parameter(
+            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
+        self.latent_token_positional_embedding = nn.Parameter(
+            scale * torch.randn(self.num_latent_tokens, self.width))
+        self.ln_pre = nn.LayerNorm(self.width)
+        self.transformer = nn.ModuleList()
+        for i in range(self.num_layers):
+            self.transformer.append(ResidualAttentionBlock(
+                self.width, self.num_heads, mlp_ratio=4.0
+            ))
+        self.ln_post = nn.LayerNorm(self.width)
+        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
+
+    def forward(self, pixel_values, latent_tokens):
+        batch_size = pixel_values.shape[0]
+        x = pixel_values
+        x = self.patch_embed(x)
+        x = x.reshape(x.shape[0], x.shape[1], -1)
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        # class embeddings and positional embeddings
+        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
+        x = x + self.positional_embedding.to(x.dtype)  # shape = [*, grid ** 2 + 1, width]
+
+        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
+        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
+        x = torch.cat([x, latent_tokens], dim=1)
+
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        for i in range(self.num_layers):
+            x = self.transformer[i](x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        latent_tokens = x[:, 1+self.grid_size**2:]
+        latent_tokens = self.ln_post(latent_tokens)
+        # fake 2D shape
+        if self.is_legacy:
+            latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
+        else:
+            # Fix legacy problem.
+            latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
+        latent_tokens = self.conv_out(latent_tokens)
+        latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
+        return latent_tokens
+
+
+class TiTokDecoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.image_size = config.dataset.preprocessing.crop_size
+        self.patch_size = config.model.vq_model.vit_dec_patch_size
+        self.grid_size = self.image_size // self.patch_size
+        self.model_size = config.model.vq_model.vit_dec_model_size
+        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
+        self.token_size = config.model.vq_model.token_size
+        self.is_legacy = config.model.vq_model.get("is_legacy", True)
+        self.width = {
+            "small": 512,
+            "base": 768,
+            "large": 1024,
+        }[self.model_size]
+        self.num_layers = {
+            "small": 8,
+            "base": 12,
+            "large": 24,
+        }[self.model_size]
+        self.num_heads = {
+            "small": 8,
+            "base": 12,
+            "large": 16,
+        }[self.model_size]
+
+        self.decoder_embed = nn.Linear(
+            self.token_size, self.width, bias=True)
+        scale = self.width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
+        self.positional_embedding = nn.Parameter(
+            scale * torch.randn(self.grid_size ** 2 + 1, self.width))
+        # add mask token and query pos embed
+        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
+        self.latent_token_positional_embedding = nn.Parameter(
+            scale * torch.randn(self.num_latent_tokens, self.width))
+        self.ln_pre = nn.LayerNorm(self.width)
+        self.transformer = nn.ModuleList()
+        for i in range(self.num_layers):
+            self.transformer.append(ResidualAttentionBlock(
+                self.width, self.num_heads, mlp_ratio=4.0
+            ))
+        self.ln_post = nn.LayerNorm(self.width)
+
+        if self.is_legacy:
+            self.ffn = nn.Sequential(
+                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
+                nn.Tanh(),
+                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
+            )
+            self.conv_out = nn.Identity()
+        else:
+            # Directly predicting RGB pixels
+            self.ffn = nn.Sequential(
+                nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
+                Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
+                          p1=self.patch_size, p2=self.patch_size),)
+            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)
+
+    def forward(self, z_quantized):
+        N, C, H, W = z_quantized.shape
+        assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
+        x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1)  # NLD
+        x = self.decoder_embed(x)
+
+        batchsize, seq_len, _ = x.shape
+
+        mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
+        mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
+                                 mask_tokens], dim=1)
+        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
+        x = x + self.latent_token_positional_embedding[:seq_len]
+        x = torch.cat([mask_tokens, x], dim=1)
+
+        x = self.ln_pre(x)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        for i in range(self.num_layers):
+            x = self.transformer[i](x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = x[:, 1:1+self.grid_size**2]  # remove cls embed
+        x = self.ln_post(x)
+        # N L D -> N D H W
+        x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
+        x = self.ffn(x.contiguous())
+        x = self.conv_out(x)
+        return x
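A small shape-check sketch for the attention blocks above; the sizes are arbitrary examples (assuming the repository root is on PYTHONPATH), not values taken from the shipped configs:

    import torch
    from modeling.modules.blocks import ResidualAttentionBlock, UViTBlock

    # ResidualAttentionBlock expects LND (sequence, batch, width), as used by TiTokEncoder/TiTokDecoder.
    x = torch.randn(16, 2, 512)
    print(ResidualAttentionBlock(d_model=512, n_head=8, mlp_ratio=4.0)(x).shape)  # torch.Size([16, 2, 512])

    # UViTBlock works on NLD (batch, tokens, dim), as used by UViTBert in modeling/maskgit.py.
    y = torch.randn(2, 257, 768)
    print(UViTBlock(dim=768, num_heads=12, mlp_ratio=4.0)(y).shape)               # torch.Size([2, 257, 768])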
modeling/modules/discriminator.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file contains some base implementation for discrminators.
|
2 |
+
|
3 |
+
Copyright (2024) Bytedance Ltd. and/or its affiliates
|
4 |
+
|
5 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
you may not use this file except in compliance with the License.
|
7 |
+
You may obtain a copy of the License at
|
8 |
+
|
9 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
|
11 |
+
Unless required by applicable law or agreed to in writing, software
|
12 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
See the License for the specific language governing permissions and
|
15 |
+
limitations under the License.
|
16 |
+
|
17 |
+
TODO: Add reference to Mark Weber's tech report on the improved discriminator architecture.
|
18 |
+
"""
|
19 |
+
import functools
|
20 |
+
import math
|
21 |
+
from typing import Tuple
|
22 |
+
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn as nn
|
26 |
+
import torch.nn.functional as F
|
27 |
+
|
28 |
+
from .maskgit_vqgan import Conv2dSame
|
29 |
+
|
30 |
+
|
31 |
+
class BlurBlock(torch.nn.Module):
|
32 |
+
def __init__(self,
|
33 |
+
kernel: Tuple[int] = (1, 3, 3, 1)
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
|
38 |
+
kernel = kernel[None, :] * kernel[:, None]
|
39 |
+
kernel /= kernel.sum()
|
40 |
+
kernel = kernel.unsqueeze(0).unsqueeze(0)
|
41 |
+
self.register_buffer("kernel", kernel)
|
42 |
+
|
43 |
+
def calc_same_pad(self, i: int, k: int, s: int) -> int:
|
44 |
+
return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
47 |
+
ic, ih, iw = x.size()[-3:]
|
48 |
+
pad_h = self.calc_same_pad(i=ih, k=4, s=2)
|
49 |
+
pad_w = self.calc_same_pad(i=iw, k=4, s=2)
|
50 |
+
if pad_h > 0 or pad_w > 0:
|
51 |
+
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
52 |
+
|
53 |
+
weight = self.kernel.expand(ic, -1, -1, -1)
|
54 |
+
|
55 |
+
out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
|
56 |
+
return out
|
57 |
+
|
58 |
+
|
59 |
+
class NLayerDiscriminator(torch.nn.Module):
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
num_channels: int = 3,
|
63 |
+
hidden_channels: int = 128,
|
64 |
+
num_stages: int = 3,
|
65 |
+
blur_resample: bool = True,
|
66 |
+
blur_kernel_size: int = 4
|
67 |
+
):
|
68 |
+
""" Initializes the NLayerDiscriminator.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
num_channels -> int: The number of input channels.
|
72 |
+
hidden_channels -> int: The number of hidden channels.
|
73 |
+
num_stages -> int: The number of stages.
|
74 |
+
blur_resample -> bool: Whether to use blur resampling.
|
75 |
+
blur_kernel_size -> int: The blur kernel size.
|
76 |
+
"""
|
77 |
+
super().__init__()
|
78 |
+
assert num_stages > 0, "Discriminator cannot have 0 stages"
|
79 |
+
assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"
|
80 |
+
|
81 |
+
in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
|
82 |
+
init_kernel_size = 5
|
83 |
+
activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
|
84 |
+
|
85 |
+
self.block_in = torch.nn.Sequential(
|
86 |
+
Conv2dSame(
|
87 |
+
num_channels,
|
88 |
+
hidden_channels,
|
89 |
+
kernel_size=init_kernel_size
|
90 |
+
),
|
91 |
+
activation(),
|
92 |
+
)
|
93 |
+
|
94 |
+
BLUR_KERNEL_MAP = {
|
95 |
+
3: (1,2,1),
|
96 |
+
4: (1,3,3,1),
|
97 |
+
5: (1,4,6,4,1),
|
98 |
+
}
|
99 |
+
|
100 |
+
discriminator_blocks = []
|
101 |
+
for i_level in range(num_stages):
|
102 |
+
in_channels = hidden_channels * in_channel_mult[i_level]
|
103 |
+
out_channels = hidden_channels * in_channel_mult[i_level + 1]
|
104 |
+
block = torch.nn.Sequential(
|
105 |
+
Conv2dSame(
|
106 |
+
in_channels,
|
107 |
+
out_channels,
|
108 |
+
kernel_size=3,
|
109 |
+
),
|
110 |
+
torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]),
|
111 |
+
torch.nn.GroupNorm(32, out_channels),
|
112 |
+
activation(),
|
113 |
+
)
|
114 |
+
discriminator_blocks.append(block)
|
115 |
+
|
116 |
+
self.blocks = torch.nn.ModuleList(discriminator_blocks)
|
117 |
+
|
118 |
+
self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))
|
119 |
+
|
120 |
+
self.to_logits = torch.nn.Sequential(
|
121 |
+
Conv2dSame(out_channels, out_channels, 1),
|
122 |
+
activation(),
|
123 |
+
Conv2dSame(out_channels, 1, kernel_size=5)
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
127 |
+
""" Forward pass.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
x -> torch.Tensor: The input tensor.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
output -> torch.Tensor: The output tensor.
|
134 |
+
"""
|
135 |
+
hidden_states = self.block_in(x)
|
136 |
+
for block in self.blocks:
|
137 |
+
hidden_states = block(hidden_states)
|
138 |
+
|
139 |
+
hidden_states = self.pool(hidden_states)
|
140 |
+
|
141 |
+
return self.to_logits(hidden_states)
|
modeling/modules/ema_model.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file contains some base class implementation for EMA.
|
2 |
+
|
3 |
+
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
|
4 |
+
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
|
5 |
+
|
6 |
+
Reference:
|
7 |
+
https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8
|
8 |
+
"""
|
9 |
+
|
10 |
+
|
11 |
+
import copy
|
12 |
+
from typing import Any, Iterable, Optional, Union
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
|
17 |
+
class EMAModel:
|
18 |
+
"""Exponential Moving Average of models weights."""
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
parameters: Iterable[torch.nn.Parameter],
|
22 |
+
decay: float = 0.9999,
|
23 |
+
min_decay: float = 0.0,
|
24 |
+
update_after_step: int = 0,
|
25 |
+
update_every: int = 1,
|
26 |
+
current_step: int = 0,
|
27 |
+
use_ema_warmup: bool = False,
|
28 |
+
inv_gamma: Union[float, int] = 1.0,
|
29 |
+
power: Union[float, int] = 2 / 3,
|
30 |
+
model_cls: Optional[Any] = None,
|
31 |
+
**model_config_kwargs
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
parameters (Iterable[torch.nn.Parameter]): The parameters to track.
|
36 |
+
decay (float): The decay factor for the exponential moving average.
|
37 |
+
min_decay (float): The minimum decay factor for the exponential moving average.
|
38 |
+
update_after_step (int): The number of steps to wait before starting to update the EMA weights.
|
39 |
+
update_every (int): The number of steps between each EMA update.
|
40 |
+
current_step (int): The current training step.
|
41 |
+
use_ema_warmup (bool): Whether to use EMA warmup.
|
42 |
+
inv_gamma (float):
|
43 |
+
Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
|
44 |
+
power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
|
45 |
+
|
46 |
+
notes on EMA Warmup:
|
47 |
+
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
48 |
+
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
49 |
+
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
50 |
+
at 215.4k steps).
|
51 |
+
"""
|
52 |
+
|
53 |
+
parameters = list(parameters)
|
54 |
+
self.shadow_params = [p.clone().detach() for p in parameters]
|
55 |
+
self.temp_stored_params = None
|
56 |
+
|
57 |
+
self.decay = decay
|
58 |
+
self.min_decay = min_decay
|
59 |
+
self.update_after_step = update_after_step
|
60 |
+
self.update_every = update_every
|
61 |
+
self.use_ema_warmup = use_ema_warmup
|
62 |
+
self.inv_gamma = inv_gamma
|
63 |
+
self.power = power
|
64 |
+
self.optimization_step = current_step
|
65 |
+
self.cur_decay_value = None # set in `step()`
|
66 |
+
|
67 |
+
self.model_cls = model_cls
|
68 |
+
self.model_config_kwargs = model_config_kwargs
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel":
|
72 |
+
model = model_cls(**model_config_kwargs)
|
73 |
+
model.load_pretrained_weight(checkpoint)
|
74 |
+
|
75 |
+
ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs)
|
76 |
+
return ema_model
|
77 |
+
|
78 |
+
def save_pretrained(self, path):
|
79 |
+
if self.model_cls is None:
|
80 |
+
raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
|
81 |
+
|
82 |
+
if self.model_config_kwargs is None:
|
83 |
+
raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.")
|
84 |
+
|
85 |
+
model = self.model_cls(**self.model_config_kwargs)
|
86 |
+
self.copy_to(model.parameters())
|
87 |
+
model.save_pretrained_weight(path)
|
88 |
+
|
89 |
+
def set_step(self, optimization_step: int):
|
90 |
+
self.optimization_step = optimization_step
|
91 |
+
|
92 |
+
def get_decay(self, optimization_step: int) -> float:
|
93 |
+
"""Computes the decay factor for the exponential moving average."""
|
94 |
+
step = max(0, optimization_step - self.update_after_step - 1)
|
95 |
+
|
96 |
+
if step <= 0:
|
97 |
+
return 0.0
|
98 |
+
|
99 |
+
if self.use_ema_warmup:
|
100 |
+
cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
101 |
+
else:
|
102 |
+
cur_decay_value = (1 + step) / (10 + step)
|
103 |
+
|
104 |
+
cur_decay_value = min(cur_decay_value, self.decay)
|
105 |
+
# Make sure decay is not smaller than min_decay.
|
106 |
+
cur_decay_value = max(cur_decay_value, self.min_decay)
|
107 |
+
return cur_decay_value
|
108 |
+
|
109 |
+
@torch.no_grad()
|
110 |
+
def step(self, parameters: Iterable[torch.nn.Parameter]):
|
111 |
+
parameters = list(parameters)
|
112 |
+
|
113 |
+
self.optimization_step += 1
|
114 |
+
|
115 |
+
if (self.optimization_step - 1) % self.update_every != 0:
|
116 |
+
return
|
117 |
+
|
118 |
+
# Compute the decay factor for the exponential moving average.
|
119 |
+
decay = self.get_decay(self.optimization_step)
|
120 |
+
self.cur_decay_value = decay
|
121 |
+
one_minus_decay = 1 - decay
|
122 |
+
|
123 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
124 |
+
if param.requires_grad:
|
125 |
+
s_param.sub_(one_minus_decay * (s_param - param))
|
126 |
+
else:
|
127 |
+
s_param.copy_(param)
|
128 |
+
|
129 |
+
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
130 |
+
"""Copies current averaged parameters into given collection of parameters.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
134 |
+
updated with the stored moving averages. If `None`, the parameters with which this
|
135 |
+
`ExponentialMovingAverage` was initialized will be used.
|
136 |
+
"""
|
137 |
+
parameters = list(parameters)
|
138 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
139 |
+
param.data.copy_(s_param.to(param.device).data)
|
140 |
+
|
141 |
+
def to(self, device=None, dtype=None) -> None:
|
142 |
+
r"""Moves internal buffers of the ExponentialMovingAverage to `device`.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
device: like `device` argument to `torch.Tensor.to`
|
146 |
+
"""
|
147 |
+
# .to() on the tensors handles None correctly
|
148 |
+
self.shadow_params = [
|
149 |
+
p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
|
150 |
+
for p in self.shadow_params
|
151 |
+
]
|
152 |
+
|
153 |
+
def state_dict(self) -> dict:
|
154 |
+
r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
|
155 |
+
checkpointing to save the ema state dict.
|
156 |
+
"""
|
157 |
+
# Following PyTorch conventions, references to tensors are returned:
|
158 |
+
# "returns a reference to the state and not its copy!" -
|
159 |
+
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
|
160 |
+
return {
|
161 |
+
"decay": self.decay,
|
162 |
+
"min_decay": self.min_decay,
|
163 |
+
"optimization_step": self.optimization_step,
|
164 |
+
"update_after_step": self.update_after_step,
|
165 |
+
"use_ema_warmup": self.use_ema_warmup,
|
166 |
+
"inv_gamma": self.inv_gamma,
|
167 |
+
"power": self.power,
|
168 |
+
"shadow_params": self.shadow_params,
|
169 |
+
}
|
170 |
+
|
171 |
+
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
172 |
+
r"""
|
173 |
+
Args:
|
174 |
+
Save the current parameters for restoring later.
|
175 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
176 |
+
temporarily stored.
|
177 |
+
"""
|
178 |
+
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
|
179 |
+
|
180 |
+
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
181 |
+
r"""Restores the parameters stored with the `store` method. Useful to validate
|
182 |
+
the model with EMA parameters without affecting the original optimization process.
|
183 |
+
Store the parameters before the `copy_to()` method. After validation (or
|
184 |
+
model saving), use this to restore the former parameters.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
188 |
+
updated with the stored parameters. If `None`, the parameters with which this
|
189 |
+
`ExponentialMovingAverage` was initialized will be used.
|
190 |
+
"""
|
191 |
+
if self.temp_stored_params is None:
|
192 |
+
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
|
193 |
+
for c_param, param in zip(self.temp_stored_params, parameters):
|
194 |
+
param.data.copy_(c_param.data)
|
195 |
+
|
196 |
+
# Better memory-wise.
|
197 |
+
self.temp_stored_params = None
|
198 |
+
|
199 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
200 |
+
r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
|
201 |
+
ema state dict.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
state_dict (dict): EMA state. Should be an object returned
|
205 |
+
from a call to :meth:`state_dict`.
|
206 |
+
"""
|
207 |
+
# Deepcopy, to be consistent with module API
|
208 |
+
state_dict = copy.deepcopy(state_dict)
|
209 |
+
|
210 |
+
self.decay = state_dict.get("decay", self.decay)
|
211 |
+
if self.decay < 0.0 or self.decay > 1.0:
|
212 |
+
raise ValueError("Decay must be between 0 and 1")
|
213 |
+
|
214 |
+
self.min_decay = state_dict.get("min_decay", self.min_decay)
|
215 |
+
if not isinstance(self.min_decay, float):
|
216 |
+
raise ValueError("Invalid min_decay")
|
217 |
+
|
218 |
+
self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
|
219 |
+
if not isinstance(self.optimization_step, int):
|
220 |
+
raise ValueError("Invalid optimization_step")
|
221 |
+
|
222 |
+
self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
|
223 |
+
if not isinstance(self.update_after_step, int):
|
224 |
+
raise ValueError("Invalid update_after_step")
|
225 |
+
|
226 |
+
self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
|
227 |
+
if not isinstance(self.use_ema_warmup, bool):
|
228 |
+
raise ValueError("Invalid use_ema_warmup")
|
229 |
+
|
230 |
+
self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
|
231 |
+
if not isinstance(self.inv_gamma, (float, int)):
|
232 |
+
raise ValueError("Invalid inv_gamma")
|
233 |
+
|
234 |
+
self.power = state_dict.get("power", self.power)
|
235 |
+
if not isinstance(self.power, (float, int)):
|
236 |
+
raise ValueError("Invalid power")
|
237 |
+
|
238 |
+
shadow_params = state_dict.get("shadow_params", None)
|
239 |
+
if shadow_params is not None:
|
240 |
+
self.shadow_params = shadow_params
|
241 |
+
if not isinstance(self.shadow_params, list):
|
242 |
+
raise ValueError("shadow_params must be a list")
|
243 |
+
if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
|
244 |
+
raise ValueError("shadow_params must all be Tensors")
|
modeling/modules/losses.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This files contains training loss implementation.
|
2 |
+
|
3 |
+
Copyright (2024) Bytedance Ltd. and/or its affiliates
|
4 |
+
|
5 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
you may not use this file except in compliance with the License.
|
7 |
+
You may obtain a copy of the License at
|
8 |
+
|
9 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
|
11 |
+
Unless required by applicable law or agreed to in writing, software
|
12 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
See the License for the specific language governing permissions and
|
15 |
+
limitations under the License.
|
16 |
+
|
17 |
+
Ref:
|
18 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py
|
19 |
+
"""
|
20 |
+
from typing import Mapping, Text, Tuple
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from einops import rearrange
|
26 |
+
from torch.cuda.amp import autocast
|
27 |
+
from .perceptual_loss import PerceptualLoss
|
28 |
+
from .discriminator import NLayerDiscriminator
|
29 |
+
|
30 |
+
|
31 |
+
def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
|
32 |
+
"""Hinge loss for discrminator.
|
33 |
+
|
34 |
+
This function is borrowed from
|
35 |
+
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
|
36 |
+
"""
|
37 |
+
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
38 |
+
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
39 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
40 |
+
return d_loss
|
41 |
+
|
42 |
+
|
43 |
+
def compute_lecam_loss(
|
44 |
+
logits_real_mean: torch.Tensor,
|
45 |
+
logits_fake_mean: torch.Tensor,
|
46 |
+
ema_logits_real_mean: torch.Tensor,
|
47 |
+
ema_logits_fake_mean: torch.Tensor
|
48 |
+
) -> torch.Tensor:
|
49 |
+
"""Computes the LeCam loss for the given average real and fake logits.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
logits_real_mean -> torch.Tensor: The average real logits.
|
53 |
+
logits_fake_mean -> torch.Tensor: The average fake logits.
|
54 |
+
ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
|
55 |
+
ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
lecam_loss -> torch.Tensor: The LeCam loss.
|
59 |
+
"""
|
60 |
+
lecam_loss = torch.mean(torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2))
|
61 |
+
lecam_loss += torch.mean(torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2))
|
62 |
+
return lecam_loss
|
63 |
+
|
64 |
+
|
65 |
+
class ReconstructionLoss_Stage1(torch.nn.Module):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
config
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
loss_config = config.losses
|
72 |
+
self.quantizer_weight = loss_config.quantizer_weight
|
73 |
+
self.target_codebook_size = 1024
|
74 |
+
|
75 |
+
def forward(self,
|
76 |
+
target_codes: torch.Tensor,
|
77 |
+
reconstructions: torch.Tensor,
|
78 |
+
quantizer_loss: torch.Tensor,
|
79 |
+
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
|
80 |
+
return self._forward_generator(target_codes, reconstructions, quantizer_loss)
|
81 |
+
|
82 |
+
def _forward_generator(self,
|
83 |
+
target_codes: torch.Tensor,
|
84 |
+
reconstructions: torch.Tensor,
|
85 |
+
quantizer_loss: Mapping[Text, torch.Tensor],
|
86 |
+
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
|
87 |
+
reconstructions = reconstructions.contiguous()
|
88 |
+
loss_fct = nn.CrossEntropyLoss(reduction="mean")
|
89 |
+
batch_size = reconstructions.shape[0]
|
90 |
+
reconstruction_loss = loss_fct(reconstructions.view(batch_size, self.target_codebook_size, -1),
|
91 |
+
target_codes.view(batch_size, -1))
|
92 |
+
total_loss = reconstruction_loss + \
|
93 |
+
self.quantizer_weight * quantizer_loss["quantizer_loss"]
|
94 |
+
|
95 |
+
loss_dict = dict(
|
96 |
+
total_loss=total_loss.clone().detach(),
|
97 |
+
reconstruction_loss=reconstruction_loss.detach(),
|
98 |
+
quantizer_loss=(self.quantizer_weight * quantizer_loss["quantizer_loss"]).detach(),
|
99 |
+
commitment_loss=quantizer_loss["commitment_loss"].detach(),
|
100 |
+
codebook_loss=quantizer_loss["codebook_loss"].detach(),
|
101 |
+
)
|
102 |
+
|
103 |
+
return total_loss, loss_dict
|
104 |
+
|
105 |
+
+class ReconstructionLoss_Stage2(torch.nn.Module):
+    def __init__(
+        self,
+        config
+    ):
+        """Initializes the losses module.
+
+        Args:
+            config: A dictionary, the configuration for the model and everything else.
+        """
+        super().__init__()
+        loss_config = config.losses
+        self.discriminator = NLayerDiscriminator()
+
+        self.reconstruction_loss = loss_config.reconstruction_loss
+        self.reconstruction_weight = loss_config.reconstruction_weight
+        self.quantizer_weight = loss_config.quantizer_weight
+        self.perceptual_loss = PerceptualLoss(
+            loss_config.perceptual_loss).eval()
+        self.perceptual_weight = loss_config.perceptual_weight
+        self.discriminator_iter_start = loss_config.discriminator_start
+
+        self.discriminator_factor = loss_config.discriminator_factor
+        self.discriminator_weight = loss_config.discriminator_weight
+        self.lecam_regularization_weight = loss_config.lecam_regularization_weight
+        self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999)
+        if self.lecam_regularization_weight > 0.0:
+            self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
+            self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))
+
+        self.config = config
+
+    @autocast(enabled=False)
+    def forward(self,
+                inputs: torch.Tensor,
+                reconstructions: torch.Tensor,
+                extra_result_dict: Mapping[Text, torch.Tensor],
+                global_step: int,
+                mode: str = "generator",
+                ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        # Both inputs and reconstructions are in range [0, 1].
+        inputs = inputs.float()
+        reconstructions = reconstructions.float()
+
+        if mode == "generator":
+            return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
+        elif mode == "discriminator":
+            return self._forward_discriminator(inputs, reconstructions, global_step)
+        else:
+            raise ValueError(f"Unsupported mode {mode}")
+
+    def should_discriminator_be_trained(self, global_step: int):
+        return global_step >= self.discriminator_iter_start
+
+    def _forward_generator(self,
+                           inputs: torch.Tensor,
+                           reconstructions: torch.Tensor,
+                           extra_result_dict: Mapping[Text, torch.Tensor],
+                           global_step: int
+                           ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        """Generator training step."""
+        inputs = inputs.contiguous()
+        reconstructions = reconstructions.contiguous()
+        if self.reconstruction_loss == "l1":
+            reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
+        elif self.reconstruction_loss == "l2":
+            reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
+        else:
+            raise ValueError(f"Unsupported reconstruction_loss {self.reconstruction_loss}")
+        reconstruction_loss *= self.reconstruction_weight
+
+        # Compute perceptual loss.
+        perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()
+
+        # Compute discriminator loss.
+        generator_loss = torch.zeros((), device=inputs.device)
+        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+        d_weight = 1.0
+        if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
+            # Disable discriminator gradients.
+            for param in self.discriminator.parameters():
+                param.requires_grad = False
+            logits_fake = self.discriminator(reconstructions)
+            generator_loss = -torch.mean(logits_fake)
+
+        d_weight *= self.discriminator_weight
+
+        # Compute quantizer loss.
+        quantizer_loss = extra_result_dict["quantizer_loss"]
+        total_loss = (
+            reconstruction_loss
+            + self.perceptual_weight * perceptual_loss
+            + self.quantizer_weight * quantizer_loss
+            + d_weight * discriminator_factor * generator_loss
+        )
+        loss_dict = dict(
+            total_loss=total_loss.clone().detach(),
+            reconstruction_loss=reconstruction_loss.detach(),
+            perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
+            quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
+            weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
+            discriminator_factor=torch.tensor(discriminator_factor),
+            commitment_loss=extra_result_dict["commitment_loss"].detach(),
+            codebook_loss=extra_result_dict["codebook_loss"].detach(),
+            d_weight=d_weight,
+            gan_loss=generator_loss.detach(),
+        )
+
+        return total_loss, loss_dict
+
+    def _forward_discriminator(self,
+                               inputs: torch.Tensor,
+                               reconstructions: torch.Tensor,
+                               global_step: int,
+                               ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        """Discriminator training step."""
+        discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
+        loss_dict = {}
+        # Turn the gradients on.
+        for param in self.discriminator.parameters():
+            param.requires_grad = True
+
+        real_images = inputs.detach().requires_grad_(True)
+        logits_real = self.discriminator(real_images)
+        logits_fake = self.discriminator(reconstructions.detach())
+
+        discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)
+
+        # Optional LeCam regularization.
+        lecam_loss = torch.zeros((), device=inputs.device)
+        if self.lecam_regularization_weight > 0.0:
+            lecam_loss = compute_lecam_loss(
+                torch.mean(logits_real),
+                torch.mean(logits_fake),
+                self.ema_real_logits_mean,
+                self.ema_fake_logits_mean
+            ) * self.lecam_regularization_weight
+
+            self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
+            self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)
+
+        discriminator_loss += lecam_loss
+
+        loss_dict = dict(
+            discriminator_loss=discriminator_loss.detach(),
+            logits_real=logits_real.detach().mean(),
+            logits_fake=logits_fake.detach().mean(),
+            lecam_loss=lecam_loss.detach(),
+        )
+        return discriminator_loss, loss_dict
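The forward method above dispatches to a generator step or a discriminator step via the mode argument; the adversarial terms only become active once global_step reaches discriminator_start. The discriminator step relies on hinge_d_loss and an optional LeCam regularizer referenced above. As a standalone illustration only, assuming the standard hinge formulation (a sketch with made-up values, not the repository's hinge_d_loss):

import torch
import torch.nn.functional as F

def hinge_loss_sketch(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    # Hinge GAN discriminator loss: push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)

logits_real = torch.tensor([1.5, 0.2, -0.3])   # made-up discriminator outputs on real images
logits_fake = torch.tensor([-1.2, 0.4, 0.1])   # made-up discriminator outputs on reconstructions
print(hinge_loss_sketch(logits_real, logits_fake).item())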
+class MLMLoss(torch.nn.Module):
+    def __init__(self,
+                 config):
+        super().__init__()
+        self.label_smoothing = config.losses.label_smoothing
+        self.loss_weight_unmasked_token = config.losses.loss_weight_unmasked_token
+        self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing,
+                                                   reduction="none")
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor,
+                weights=None) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        inputs = rearrange(inputs, "b n c -> b c n")
+        loss = self.criterion(inputs, targets)
+        weights = weights.to(loss)
+        loss_weights = (1.0 - weights) * self.loss_weight_unmasked_token + weights  # Unmasked positions get weight self.loss_weight_unmasked_token.
+        loss = (loss * loss_weights).sum() / (loss_weights.sum() + 1e-8)
+        # Token accuracy is computed on masked tokens only.
+        correct_tokens = ((torch.argmax(inputs, dim=1) == targets) * weights).sum(dim=1) / (weights.sum(1) + 1e-8)
+        return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}
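MLMLoss weights each token position so that unmasked tokens contribute only loss_weight_unmasked_token to the weighted average, while masked positions (weights equal to 1.0) carry full weight. A self-contained sketch of that weighting, with made-up sizes and a hypothetical config value:

import torch
from einops import rearrange

batch, seq_len, codebook = 2, 8, 16
loss_weight_unmasked_token = 0.1                       # hypothetical config value
logits = torch.randn(batch, seq_len, codebook)         # "b n c", as MLMLoss.forward expects
targets = torch.randint(0, codebook, (batch, seq_len))
weights = (torch.rand(batch, seq_len) < 0.5).float()   # 1.0 marks a masked position

per_token = torch.nn.CrossEntropyLoss(reduction="none")(
    rearrange(logits, "b n c -> b c n"), targets)
loss_weights = (1.0 - weights) * loss_weight_unmasked_token + weights
loss = (per_token * loss_weights).sum() / (loss_weights.sum() + 1e-8)
print(loss.item())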
+class ARLoss(torch.nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.target_vocab_size = config.model.vq_model.codebook_size
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")
+
+    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
+        shift_logits = logits[..., :-1, :].permute(0, 2, 1).contiguous()  # NLC -> NCL
+        shift_labels = labels.contiguous()
+        shift_logits = shift_logits.view(shift_logits.shape[0], self.target_vocab_size, -1)
+        shift_labels = shift_labels.view(shift_labels.shape[0], -1)
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = self.criterion(shift_logits, shift_labels)
+        correct_tokens = (torch.argmax(shift_logits, dim=1) == shift_labels).sum(dim=1) / shift_labels.size(1)
+        return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}