yucornetto committed
Commit 51ce47d · verified · 1 Parent(s): 1071615
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. LICENSE +201 -0
  3. README.md +82 -14
  4. README_RAR.md +221 -0
  5. README_TiTok.md +213 -0
  6. assets/ILSVRC2012_val_00008636.png +0 -0
  7. assets/ILSVRC2012_val_00010240.png +0 -0
  8. assets/perf_comp.png +0 -0
  9. assets/random_vis_l32.png +3 -0
  10. assets/rar_overview.png +3 -0
  11. assets/recon_w_model_size_num_token.png +3 -0
  12. assets/speed_vs_perf.png +0 -0
  13. assets/titok_teaser.png +0 -0
  14. assets/vis1.png +3 -0
  15. assets/vis2.png +3 -0
  16. assets/vis3.png +3 -0
  17. configs/infer/titok_b64.yaml +39 -0
  18. configs/infer/titok_bl128_vae_c16.yaml +19 -0
  19. configs/infer/titok_bl128_vq8k.yaml +21 -0
  20. configs/infer/titok_bl64_vae_c16.yaml +19 -0
  21. configs/infer/titok_bl64_vq8k.yaml +21 -0
  22. configs/infer/titok_l32.yaml +40 -0
  23. configs/infer/titok_ll32_vae_c16.yaml +19 -0
  24. configs/infer/titok_s128.yaml +39 -0
  25. configs/infer/titok_sl256_vq8k.yaml +21 -0
  26. configs/training/generator/maskgit.yaml +86 -0
  27. configs/training/generator/rar.yaml +78 -0
  28. configs/training/stage1/titok_b64.yaml +70 -0
  29. configs/training/stage1/titok_l32.yaml +70 -0
  30. configs/training/stage1/titok_s128.yaml +70 -0
  31. configs/training/stage2/titok_b64.yaml +80 -0
  32. configs/training/stage2/titok_l32.yaml +79 -0
  33. configs/training/stage2/titok_s128.yaml +79 -0
  34. data/__init__.py +1 -0
  35. data/convert_imagenet_to_wds.py +68 -0
  36. data/webdataset_reader.py +227 -0
  37. demo.ipynb +0 -0
  38. demo_util.py +108 -0
  39. evaluator/__init__.py +1 -0
  40. evaluator/evaluator.py +245 -0
  41. evaluator/inception.py +231 -0
  42. imagenet_classes.py +1001 -0
  43. modeling/__init__.py +15 -0
  44. modeling/maskgit.py +282 -0
  45. modeling/modules/__init__.py +6 -0
  46. modeling/modules/base_model.py +127 -0
  47. modeling/modules/blocks.py +376 -0
  48. modeling/modules/discriminator.py +141 -0
  49. modeling/modules/ema_model.py +244 -0
  50. modeling/modules/losses.py +293 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/random_vis_l32.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/rar_overview.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/recon_w_model_size_num_token.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/vis1.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/vis2.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/vis3.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,14 +1,82 @@
1
- ---
2
- title: RAR
3
- emoji: 📈
4
- colorFrom: purple
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: Randomized Autoregressive Visual Generation
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # 1D Visual Tokenization and Generation
2
+
3
+ This repo hosts the code and models for the following projects:
4
+
5
+ - RAR: [Randomized Autoregressive Visual Generation](https://yucornetto.github.io/projects/rar.html)
6
+
7
+ - TiTok: [An Image is Worth 32 Tokens for Reconstruction and Generation](https://yucornetto.github.io/projects/titok.html)
8
+
9
+
10
+ ## Short Intro on [Randomized Autoregressive Visual Generation](https://arxiv.org/abs/2411.00776) ([README](README_RAR.md))
11
+
12
+ RAR is an autoregressive (AR) image generator fully compatible with language modeling. It introduces a randomness annealing strategy with a permuted training objective at no additional cost, which enhances the model's ability to learn bidirectional contexts while leaving the autoregressive framework intact. RAR achieves an FID score of 1.48, demonstrating state-of-the-art performance on the ImageNet-256 benchmark and significantly outperforming prior AR image generators.
13
+
14
+ <p>
15
+ <img src="assets/rar_overview.png" alt="teaser" width=90% height=90%>
16
+ </p>
17
+ <p>
18
+ <img src="assets/perf_comp.png" alt="teaser" width=90% height=90%>
19
+ </p>
20
+
21
+ See more details at [README_RAR](README_RAR.md).
22
+
23
+ ## Short Intro on [An Image is Worth 32 Tokens for Reconstruction and Generation](https://arxiv.org/abs/2406.07550) ([README](README_TiTok.md))
24
+
25
+ We present a compact 1D tokenizer which can represent an image with as few as 32 discrete tokens. As a result, it leads to a substantial speed-up in the sampling process (e.g., **410× faster** than DiT-XL/2) while maintaining competitive generation quality.
26
+
27
+ <p>
28
+ <img src="assets/titok_teaser.png" alt="teaser" width=90% height=90%>
29
+ </p>
30
+ <p>
31
+ <img src="assets/speed_vs_perf.png" alt="teaser" width=90% height=90%>
32
+ </p>
33
+
34
+ See more details at [README_TiTok](README_TiTok.md).
35
+
36
+ ## Updates
37
+ - 11/04/2024: We release the [tech report](https://arxiv.org/abs/2411.00776) and code for RAR models.
38
+ - 10/16/2024: We release an updated set of TiTok tokenizer weights trained with a new single-stage recipe, leading to easier training and better performance. We release weights of different model sizes for both the VQ and VAE variants of TiTok, which we hope will facilitate research in this area. More details will be available in a tech report later.
39
+ - 09/25/2024: TiTok is accepted by NeurIPS 2024.
40
+ - 09/11/2024: Release the training code of the generator based on TiTok.
41
+ - 08/28/2024: Release the training code of TiTok.
42
+ - 08/09/2024: Better support for loading pretrained weights from huggingface models, thanks to the help from [@NielsRogge](https://github.com/NielsRogge)!
43
+ - 07/03/2024: Evaluation scripts for reproducing the results reported in the paper, as well as checkpoints of TiTok-B64 and TiTok-S128, are available.
44
+ - 06/21/2024: Demo code and TiTok-L-32 checkpoints are released.
45
+ - 06/11/2024: The [tech report](https://arxiv.org/abs/2406.07550) of TiTok is available.
46
+
47
+
48
+ ## Installation
49
+ ```shell
50
+ pip3 install -r requirements.txt
51
+ ```
52
+
53
+ ## Citing
54
+ If you use our work in your research, please use the following BibTeX entry.
55
+
56
+ ```BibTeX
57
+ @article{yu2024randomized,
58
+ author = {Qihang Yu and Ju He and Xueqing Deng and Xiaohui Shen and Liang-Chieh Chen},
59
+ title = {Randomized Autoregressive Visual Generation},
60
+ journal = {arXiv preprint arXiv:2411.00776},
61
+ year = {2024}
62
+ }
63
+ ```
64
+
65
+ ```BibTeX
66
+ @article{yu2024an,
67
+ author = {Qihang Yu and Mark Weber and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen},
68
+ title = {An Image is Worth 32 Tokens for Reconstruction and Generation},
69
+ journal = {NeurIPS},
70
+ year = {2024}
71
+ }
72
+ ```
73
+
74
+ ## Acknowledgement
75
+
76
+ [MaskGIT](https://github.com/google-research/maskgit)
77
+
78
+ [Taming-Transformers](https://github.com/CompVis/taming-transformers)
79
+
80
+ [Open-MUSE](https://github.com/huggingface/open-muse)
81
+
82
+ [MUSE-Pytorch](https://github.com/baaivision/MUSE-Pytorch)
README_RAR.md ADDED
@@ -0,0 +1,221 @@
1
+ # Randomized Autoregressive Visual Generation
2
+
3
+
4
+ RAR is an autoregressive (AR) image generator fully compatible with language modeling. It introduces a randomness annealing strategy with a permuted training objective at no additional cost, which enhances the model's ability to learn bidirectional contexts while leaving the autoregressive framework intact. RAR achieves an FID score of 1.48, demonstrating state-of-the-art performance on the ImageNet-256 benchmark and significantly outperforming prior AR image generators.
5
+
6
+
7
+ <p>
8
+ <img src="assets/rar_overview.png" alt="teaser" width=90% height=90%>
9
+ </p>
10
+ <p>
11
+ <img src="assets/perf_comp.png" alt="teaser" width=90% height=90%>
12
+ </p>
13
+
14
+
15
+ ## 🚀 Contributions
16
+
17
+ #### We introduce RAR, an improved training strategy that enables a standard autoregressive image generator to achieve state-of-the-art performance.
18
+
19
+ #### The proposed RAR is extremely simple yet effective: during training, we randomly permute the input token sequence with probability r, where r starts at 1.0 and linearly decays to 0.0 over the course of training. This simple strategy enables better bidirectional representation learning, which is missing in standard raster-order-based AR image generator training.
20
+
21
+ #### RAR keeps the AR framework intact and is thus fully compatible with LLM optimization techniques such as KV caching, leading to significantly faster sampling than MAR-H or MaskBit while maintaining better performance.
22
+
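+ To make the annealing schedule above concrete, here is a minimal sketch of the permutation probability as a function of the training step. The anneal boundaries are taken from `configs/training/generator/rar.yaml`; the actual implementation lives in the training script, so treat this only as an illustration:
+ 
+ ```python
+ def permutation_prob(step, anneal_start=125_000, anneal_end=187_500):
+     """Probability r of permuting the input token sequence at a given training step.
+ 
+     r stays at 1.0 before anneal_start, decays linearly to 0.0 between
+     anneal_start and anneal_end, and stays at 0.0 (pure raster order) afterwards.
+     """
+     if step < anneal_start:
+         return 1.0
+     if step >= anneal_end:
+         return 0.0
+     return 1.0 - (step - anneal_start) / (anneal_end - anneal_start)
+ ```
+ 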
23
+ ## Model Zoo
24
+ | Model | Link | FID |
25
+ | ------------- | ------------- | ------------- |
26
+ | RAR-B | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_b.bin)| 1.95 (generation) |
27
+ | RAR-L | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_l.bin)| 1.70 (generation) |
28
+ | RAR-XL | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_xl.bin)| 1.50 (generation) |
29
+ | RAR-XXL | [checkpoint](https://huggingface.co/yucornetto/RAR/blob/main/rar_xxl.bin)| 1.48 (generation) |
30
+
31
+ Please note that these models are trained only on the limited academic dataset ImageNet, and they are intended for research purposes only.
32
+
33
+ ## Installation
34
+ ```shell
35
+ pip3 install -r requirements.txt
36
+ ```
37
+
38
+ ## Get Started
39
+ ```python
40
+ import torch
41
+ from PIL import Image
42
+ import numpy as np
43
+ import demo_util
44
+ from huggingface_hub import hf_hub_download
45
+ from utils.train_utils import create_pretrained_tokenizer
46
+
47
+
48
+ # Choose one from ["rar_b", "rar_l", "rar_xl", "rar_xxl"]
49
+ rar_model_name = ["rar_b", "rar_l", "rar_xl", "rar_xxl"][3]
50
+
51
+ # download the maskgit-vq tokenizer
52
+ hf_hub_download(repo_id="fun-research/TiTok", filename=f"maskgit-vqgan-imagenet-f16-256.bin", local_dir="./")
53
+ # download the rar generator weight
54
+ hf_hub_download(repo_id="yucornetto/RAR", filename=f"{rar_model_name}.bin", local_dir="./")
55
+
56
+ # load the RAR generator config and set the model size for the chosen checkpoint
+ # (the per-model values below follow the sampling commands later in this README)
+ config = demo_util.get_config("configs/training/generator/rar.yaml")
+ config.experiment.generator_checkpoint = f"{rar_model_name}.bin"
+ config.model.generator.hidden_size = {"rar_b": 768, "rar_l": 1024, "rar_xl": 1280, "rar_xxl": 1408}[rar_model_name]
+ config.model.generator.num_hidden_layers = {"rar_b": 24, "rar_l": 24, "rar_xl": 32, "rar_xxl": 40}[rar_model_name]
+ config.model.generator.num_attention_heads = 16
+ config.model.generator.intermediate_size = {"rar_b": 3072, "rar_l": 4096, "rar_xl": 5120, "rar_xxl": 6144}[rar_model_name]
60
+
61
+ device = "cuda"
62
+ # maskgit-vq as tokenizer
63
+ tokenizer = create_pretrained_tokenizer(config)
64
+ generator = demo_util.get_rar_generator(config)
65
+ tokenizer.to(device)
66
+ generator.to(device)
67
+
68
+ # generate an image
69
+ sample_labels = [torch.randint(0, 999, size=(1,)).item()] # random IN-1k class
70
+ generated_image = demo_util.sample_fn(
71
+ generator=generator,
72
+ tokenizer=tokenizer,
73
+ labels=sample_labels,
74
+ randomize_temperature=1.0,
75
+ guidance_scale=4.0,
76
+ guidance_scale_pow=0.0, # constant cfg
77
+ device=device
78
+ )
79
+ Image.fromarray(generated_image[0]).save(f"assets/rar_generated_{sample_labels[0]}.png")
80
+ ```
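+ 
+ The `guidance_scale_pow` argument controls how the classifier-free guidance scale is scheduled over the sampling steps: `guidance_scale_pow=0.0` corresponds to a constant scale, while larger values keep the guidance weak early and ramp it up towards the end. As a rough illustration only (the exact schedule is defined in the repo's sampling code, so this is an assumed form, not a reference implementation):
+ 
+ ```python
+ import math
+ 
+ def cfg_schedule(step, num_steps, guidance_scale, guidance_scale_pow):
+     # Power-cosine style ramp: with pow=0 the ratio is always 1 (constant cfg),
+     # with larger pow the scale starts near 1 and grows towards guidance_scale.
+     ratio = (1 - math.cos(math.pi * (step / num_steps) ** guidance_scale_pow)) / 2
+     return 1 + (guidance_scale - 1) * ratio
+ ```
+ 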
81
+
82
+ ## Testing on ImageNet-1K Benchmark
83
+
84
+ We provide a [sampling script](./sample_imagenet_rar.py) for reproducing the generation results on ImageNet-1K benchmark.
85
+ ```bash
86
+ # Prepare ADM evaluation script
87
+ git clone https://github.com/openai/guided-diffusion.git
88
+
89
+ wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
90
+ ```
91
+ ```bash
92
+ # Reproducing RAR-B
93
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
94
+ experiment.output_dir="rar_b" \
95
+ experiment.generator_checkpoint="rar_b.bin" \
96
+ model.generator.hidden_size=768 \
97
+ model.generator.num_hidden_layers=24 \
98
+ model.generator.num_attention_heads=16 \
99
+ model.generator.intermediate_size=3072 \
100
+ model.generator.randomize_temperature=1.0 \
101
+ model.generator.guidance_scale=16.0 \
102
+ model.generator.guidance_scale_pow=2.75
103
+ # Run eval script. The result FID should be ~1.95
104
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_b.npz
105
+
106
+ # Reproducing RAR-L
107
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
108
+ experiment.output_dir="rar_l" \
109
+ experiment.generator_checkpoint="rar_l.bin" \
110
+ model.generator.hidden_size=1024 \
111
+ model.generator.num_hidden_layers=24 \
112
+ model.generator.num_attention_heads=16 \
113
+ model.generator.intermediate_size=4096 \
114
+ model.generator.randomize_temperature=1.02 \
115
+ model.generator.guidance_scale=15.5 \
116
+ model.generator.guidance_scale_pow=2.5
117
+ # Run eval script. The result FID should be ~1.70
118
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_l.npz
119
+
120
+ # Reproducing RAR-XL
121
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
122
+ experiment.output_dir="rar_xl" \
123
+ experiment.generator_checkpoint="rar_xl.bin" \
124
+ model.generator.hidden_size=1280 \
125
+ model.generator.num_hidden_layers=32 \
126
+ model.generator.num_attention_heads=16 \
127
+ model.generator.intermediate_size=5120 \
128
+ model.generator.randomize_temperature=1.02 \
129
+ model.generator.guidance_scale=6.9 \
130
+ model.generator.guidance_scale_pow=1.5
131
+ # Run eval script. The result FID should be ~1.50
132
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_xl.npz
133
+
134
+ # Reproducing RAR-XXL
135
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_rar.py config=configs/training/generator/rar.yaml \
136
+ experiment.output_dir="rar_xxl" \
137
+ experiment.generator_checkpoint="rar_xxl.bin" \
138
+ model.generator.hidden_size=1408 \
139
+ model.generator.num_hidden_layers=40 \
140
+ model.generator.num_attention_heads=16 \
141
+ model.generator.intermediate_size=6144 \
142
+ model.generator.randomize_temperature=1.02 \
143
+ model.generator.guidance_scale=8.0 \
144
+ model.generator.guidance_scale_pow=1.2
145
+ # Run eval script. The result FID should be ~1.48
146
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz rar_xxl.npz
147
+ ```
148
+ ## Training Preparation
149
+ We pretokenize the whole dataset to speed up the training process. We have uploaded [it](https://huggingface.co/yucornetto/RAR/blob/main/maskgitvq.jsonl) so you can train RAR directly. The training script will download the prerequisite checkpoints and dataset automatically.
150
+
151
+ ## Training
152
+ We provide example commands to train RAR as follows:
153
+ ```bash
154
+ # Training for RAR-B
155
+ WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
156
+ experiment.project="rar" \
157
+ experiment.name="rar_b" \
158
+ experiment.output_dir="rar_b" \
159
+ model.generator.hidden_size=768 \
160
+ model.generator.num_hidden_layers=24 \
161
+ model.generator.num_attention_heads=16 \
162
+ model.generator.intermediate_size=3072
163
+
164
+ # Training for RAR-L
165
+ WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
166
+ experiment.project="rar" \
167
+ experiment.name="rar_l" \
168
+ experiment.output_dir="rar_l" \
169
+ model.generator.hidden_size=1024 \
170
+ model.generator.num_hidden_layers=24 \
171
+ model.generator.num_attention_heads=16 \
172
+ model.generator.intermediate_size=4096
173
+
174
+ # Training for RAR-XL
175
+ WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
176
+ experiment.project="rar" \
177
+ experiment.name="rar_xl" \
178
+ experiment.output_dir="rar_xl" \
179
+ model.generator.hidden_size=1280 \
180
+ model.generator.num_hidden_layers=32 \
181
+ model.generator.num_attention_heads=16 \
182
+ model.generator.intermediate_size=5120
183
+
184
+ # Training for RAR-XXL
185
+ WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_rar.py config=configs/training/generator/rar.yaml \
186
+ experiment.project="rar" \
187
+ experiment.name="rar_xxl" \
188
+ experiment.output_dir="rar_xxl" \
189
+ model.generator.hidden_size=1408 \
190
+ model.generator.num_hidden_layers=40 \
191
+ model.generator.num_attention_heads=16 \
192
+ model.generator.intermediate_size=6144
193
+ ```
194
+ You may remove the flag "WANDB_MODE=offline" to enable online wandb logging, if you have configured it.
195
+
196
+ Notably, you can enable gradient checkpointing by adding the flag "model.generator.use_checkpoint=True", and you can adjust the number of machines and GPUs to your needs. All RAR checkpoints were trained with a global batch size of 2048 (e.g., 4 machines × 8 GPUs × a per-GPU batch size of 64).
197
+
198
+
199
+ ## Visualizations
200
+ <p>
201
+ <img src="assets/vis1.png" alt="teaser" width=90% height=90%>
202
+ </p>
203
+ <p>
204
+ <img src="assets/vis2.png" alt="teaser" width=90% height=90%>
205
+ </p>
206
+ <p>
207
+ <img src="assets/vis3.png" alt="teaser" width=90% height=90%>
208
+ </p>
209
+
210
+
211
+ ## Citing
212
+ If you use our work in your research, please use the following BibTeX entry.
213
+
214
+ ```BibTeX
215
+ @article{yu2024randomized,
216
+ author = {Qihang Yu and Ju He and Xueqing Deng and Xiaohui Shen and Liang-Chieh Chen},
217
+ title = {Randomized Autoregressive Visual Generation},
218
+ journal = {arXiv preprint arXiv:2411.00776},
219
+ year = {2024}
220
+ }
221
+ ```
README_TiTok.md ADDED
@@ -0,0 +1,213 @@
1
+ # (NeurIPS 2024) Compact and Mighty - Image Tokenization with Only 32 Tokens for both Reconstruction and Generation!
2
+
3
+ <div align="center">
4
+
5
+ [![demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Online_Demo-blue)](https://huggingface.co/spaces/fun-research/TiTok)&nbsp;&nbsp;
6
+ [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://yucornetto.github.io/projects/titok.html)&nbsp;&nbsp;
7
+ [![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2406.07550)&nbsp;&nbsp;
8
+
9
+ </div>
10
+
11
+ We present a compact 1D tokenizer which can represent an image with as few as 32 discrete tokens. As a result, it leads to a substantial speed-up in the sampling process (e.g., **410× faster** than DiT-XL/2) while maintaining competitive generation quality.
12
+
13
+
14
+ <p>
15
+ <img src="assets/titok_teaser.png" alt="teaser" width=90% height=90%>
16
+ </p>
17
+ <p>
18
+ <img src="assets/speed_vs_perf.png" alt="teaser" width=90% height=90%>
19
+ </p>
20
+
21
+
22
+ ## 🚀 Contributions
23
+
24
+ #### We introduce a novel 1D image tokenization framework that breaks the grid constraints of 2D tokenization methods, leading to a much more flexible and compact latent image representation.
25
+
26
+ #### The proposed 1D tokenizer can tokenize a 256 × 256 image into as few as 32 discrete tokens, leading to a significant speed-up (hundreds of times faster than diffusion models) in the generation process, while maintaining state-of-the-art generation quality.
27
+
28
+ #### We conduct a series of experiments to probe the properties of the rarely studied 1D image tokenization, paving the way towards a compact latent space for efficient and effective image representation.
29
+
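+ For intuition on the compression, a standard 2D tokenizer with patch size 16 maps a 256 × 256 image to a 16 × 16 grid of 256 tokens, whereas TiTok-L-32 uses only 32 latent tokens. The numbers come from the configs in this repo (crop_size, vit_enc_patch_size, num_latent_tokens); the quick check below is only illustrative:
+ 
+ ```python
+ image_size, patch_size = 256, 16               # crop_size and vit_enc_patch_size in the configs
+ tokens_2d = (image_size // patch_size) ** 2    # 256 tokens for a grid-based tokenizer
+ tokens_1d = 32                                 # num_latent_tokens for TiTok-L-32
+ print(tokens_2d, tokens_1d, tokens_2d // tokens_1d)  # 256 32 8
+ ```
+ 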
30
+ ## Model Zoo
31
+ | Model | Link | FID |
32
+ | ------------- | ------------- | ------------- |
33
+ | TiTok-L-32 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_l32_imagenet)| 2.21 (reconstruction) |
34
+ | TiTok-B-64 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_b64_imagenet) | 1.70 (reconstruction) |
35
+ | TiTok-S-128 Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_s128_imagenet) | 1.71 (reconstruction) |
36
+ | TiTok-L-32 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_l32_imagenet) | 2.77 (generation) |
37
+ | TiTok-B-64 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_b64_imagenet) | 2.48 (generation) |
38
+ | TiTok-S-128 Generator | [checkpoint](https://huggingface.co/yucornetto/generator_titok_s128_imagenet) | 1.97 (generation) |
39
+ | TiTok-BL-64 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl64_vq8k_imagenet)| 2.06 (reconstruction) |
40
+ | TiTok-BL-128 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl128_vq8k_imagenet)| 1.49 (reconstruction) |
41
+ | TiTok-SL-256 VQ Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_sl256_vq8k_imagenet)| 1.03 (reconstruction) |
42
+ | TiTok-LL-32 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_ll32_vae_c16_imagenet)| 1.61 (reconstruction) |
43
+ | TiTok-BL-64 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl64_vae_c16_imagenet)| 1.25 (reconstruction) |
44
+ | TiTok-BL-128 VAE Tokenizer | [checkpoint](https://huggingface.co/yucornetto/tokenizer_titok_bl128_vae_c16_imagenet)| 0.84 (reconstruction) |
45
+
46
+ Please note that these models are trained only on the limited academic dataset ImageNet, and they are intended for research purposes only.
47
+
48
+ ## Installation
49
+ ```shell
50
+ pip3 install -r requirements.txt
51
+ ```
52
+
53
+ ## Get Started
54
+ ```python
55
+ import torch
56
+ from PIL import Image
57
+ import numpy as np
58
+ import demo_util
59
+ from huggingface_hub import hf_hub_download
60
+ from modeling.maskgit import ImageBert
61
+ from modeling.titok import TiTok
62
+
63
+ # Choose one from ["tokenizer_titok_l32_imagenet", "tokenizer_titok_b64_imagenet",
64
+ # "tokenizer_titok_s128_imagenet", "tokenizer_titok_bl128_vae_c16_imagenet", tokenizer_titok_bl64_vae_c16_imagenet",
65
+ # "tokenizer_titok_ll32_vae_c16_imagenet", "tokenizer_titok_sl256_vq8k_imagenet", "tokenizer_titok_bl128_vq8k_imagenet",
66
+ # "tokenizer_titok_bl64_vq8k_imagenet",]
67
+ titok_tokenizer = TiTok.from_pretrained("yucornetto/tokenizer_titok_l32_imagenet")
68
+ titok_tokenizer.eval()
69
+ titok_tokenizer.requires_grad_(False)
70
+ titok_generator = ImageBert.from_pretrained("yucornetto/generator_titok_l32_imagenet")
71
+ titok_generator.eval()
72
+ titok_generator.requires_grad_(False)
73
+
74
+ # or alternatively, downloads from hf
75
+ # hf_hub_download(repo_id="fun-research/TiTok", filename="tokenizer_titok_l32.bin", local_dir="./")
76
+ # hf_hub_download(repo_id="fun-research/TiTok", filename="generator_titok_l32.bin", local_dir="./")
77
+
78
+ # load config
79
+ # config = demo_util.get_config("configs/infer/titok_l32.yaml")
80
+ # titok_tokenizer = demo_util.get_titok_tokenizer(config)
81
+ # titok_generator = demo_util.get_titok_generator(config)
82
+
83
+ device = "cuda"
84
+ titok_tokenizer = titok_tokenizer.to(device)
85
+ titok_generator = titok_generator.to(device)
86
+
87
+ # reconstruct an image. I.e., image -> 32 tokens -> image
88
+ img_path = "assets/ILSVRC2012_val_00010240.png"
89
+ image = torch.from_numpy(np.array(Image.open(img_path)).astype(np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0
90
+ # tokenization
91
+ if titok_tokenizer.quantize_mode == "vq":
92
+ encoded_tokens = titok_tokenizer.encode(image.to(device))[1]["min_encoding_indices"]
93
+ elif titok_tokenizer.quantize_mode == "vae":
94
+ posteriors = titok_tokenizer.encode(image.to(device))[1]
95
+ encoded_tokens = posteriors.sample()
96
+ else:
97
+ raise NotImplementedError
98
+ # image assets/ILSVRC2012_val_00010240.png is encoded into tokens tensor([[[ 887, 3979, 349, 720, 2809, 2743, 2101, 603, 2205, 1508, 1891, 4015, 1317, 2956, 3774, 2296, 484, 2612, 3472, 2330, 3140, 3113, 1056, 3779, 654, 2360, 1901, 2908, 2169, 953, 1326, 2598]]], device='cuda:0'), with shape torch.Size([1, 1, 32])
99
+ print(f"image {img_path} is encoded into tokens {encoded_tokens}, with shape {encoded_tokens.shape}")
100
+ # de-tokenization
101
+ reconstructed_image = titok_tokenizer.decode_tokens(encoded_tokens)
102
+ reconstructed_image = torch.clamp(reconstructed_image, 0.0, 1.0)
103
+ reconstructed_image = (reconstructed_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0]
104
+ Image.fromarray(reconstructed_image).save("assets/ILSVRC2012_val_00010240_recon.png")
105
+
106
+ # generate an image
107
+ sample_labels = [torch.randint(0, 999, size=(1,)).item()] # random IN-1k class
108
+ generated_image = demo_util.sample_fn(
109
+ generator=titok_generator,
110
+ tokenizer=titok_tokenizer,
111
+ labels=sample_labels,
112
+ guidance_scale=4.5,
113
+ randomize_temperature=1.0,
114
+ num_sample_steps=8,
115
+ device=device
116
+ )
117
+ Image.fromarray(generated_image[0]).save(f"assets/generated_{sample_labels[0]}.png")
118
+ ```
119
+
120
+ We also provide a [jupyter notebook](demo.ipynb) for a quick tutorial on reconstructing and generating images with TiTok-L-32.
121
+
122
+ We also support TiTok with [HuggingFace 🤗 Demo](https://huggingface.co/spaces/fun-research/TiTok)!
123
+
124
+ ## Testing on ImageNet-1K Benchmark
125
+
126
+ We provide a [sampling script](./sample_imagenet_titok.py) for reproducing the generation results on ImageNet-1K benchmark.
127
+ ```bash
128
+ # Prepare ADM evaluation script
129
+ git clone https://github.com/openai/guided-diffusion.git
130
+
131
+ wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
132
+ ```
133
+ ```bash
134
+ # Reproducing TiTok-L-32
135
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_l32.yaml experiment.output_dir="titok_l_32"
136
+ # Run eval script. The result FID should be ~2.77
137
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_l_32.npz
138
+
139
+ # Reproducing TiTok-B-64
140
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_b64.yaml experiment.output_dir="titok_b_64"
141
+ # Run eval script. The result FID should be ~2.48
142
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_b_64.npz
143
+
144
+ # Reproducing TiTok-S-128
145
+ torchrun --nnodes=1 --nproc_per_node=8 --rdzv-endpoint=localhost:9999 sample_imagenet_titok.py config=configs/infer/titok_s128.yaml experiment.output_dir="titok_s_128"
146
+ # Run eval script. The result FID should be ~1.97
147
+ python3 guided-diffusion/evaluations/evaluator.py VIRTUAL_imagenet256_labeled.npz titok_s_128.npz
148
+ ```
149
+ ## Training Preparation
150
+ We use the [webdataset](https://github.com/webdataset/webdataset) format for data loading. To begin with, you need to convert the dataset into the webdataset format. An example script to convert ImageNet to wds format is provided [here](./data/convert_imagenet_to_wds.py).
151
+
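+ For reference, a converted shard can be iterated with the webdataset library roughly as follows. This is a minimal sketch assuming the shards store images under the "jpg" key and class labels under the "cls" key, as produced by typical ImageNet-to-wds conversion scripts; the repo's actual loader is data/webdataset_reader.py:
+ 
+ ```python
+ import webdataset as wds
+ 
+ # Stream samples straight from the tar shards without unpacking them.
+ dataset = (
+     wds.WebDataset("imagenet_sharded/train/imagenet-train-{0000..0252}.tar")
+     .decode("pil")              # decode images to PIL
+     .to_tuple("jpg", "cls")     # yield (image, label) pairs
+ )
+ 
+ for image, label in dataset:
+     print(image.size, label)
+     break
+ ```
+ 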
152
+ Furthermore, the stage1 training relies on a pre-trained MaskGIT-VQGAN to generate proxy codes as learning targets. You can convert the [official Jax weights](https://github.com/google-research/maskgit) to a PyTorch version using [this script](https://github.com/huggingface/open-muse/blob/main/scripts/convert_maskgit_vqgan.py). Alternatively, we provide a converted version at [HuggingFace](https://huggingface.co/fun-research/TiTok/blob/main/maskgit-vqgan-imagenet-f16-256.bin) and [Google Drive](https://drive.google.com/file/d/1DjZqzJrUt2hwpmUPkjGSBTFEJcOkLY-Q/view?usp=sharing). The MaskGIT-VQGAN weights will be downloaded automatically when you run the training script.
153
+
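+ If you prefer to fetch the converted MaskGIT-VQGAN checkpoint manually rather than relying on the automatic download, it can be pulled with `hf_hub_download` (the repo id and filename below match the ones used elsewhere in this repository; this is just a convenience snippet):
+ 
+ ```python
+ from huggingface_hub import hf_hub_download
+ 
+ # Download the converted MaskGIT-VQGAN weights used as the stage-1 proxy-code target.
+ hf_hub_download(repo_id="fun-research/TiTok", filename="maskgit-vqgan-imagenet-f16-256.bin", local_dir="./")
+ ```
+ 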
154
+ ## Training
155
+ We provide example commands to train TiTok as follows:
156
+ ```bash
157
+ # Training for TiTok-B64
158
+ # Stage 1
159
+ WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage1/titok_b64.yaml \
160
+ experiment.project="titok_b64_stage1" \
161
+ experiment.name="titok_b64_stage1_run1" \
162
+ experiment.output_dir="titok_b64_stage1_run1" \
163
+ training.per_gpu_batch_size=32
164
+
165
+ # Stage 2
166
+ WANDB_MODE=offline accelerate launch --num_machines=1 --num_processes=8 --machine_rank=0 --main_process_ip=127.0.0.1 --main_process_port=9999 --same_network scripts/train_titok.py config=configs/training/stage2/titok_b64.yaml \
167
+ experiment.project="titok_b64_stage2" \
168
+ experiment.name="titok_b64_stage2_run1" \
169
+ experiment.output_dir="titok_b64_stage2_run1" \
170
+ training.per_gpu_batch_size=32 \
171
+ experiment.init_weight=${PATH_TO_STAGE1_WEIGHT}
172
+
173
+ # Train Generator (TiTok-B64 as example)
174
+ WANDB_MODE=offline accelerate launch --num_machines=4 --num_processes=32 --machine_rank=${MACHINE_RANK} --main_process_ip=${ROOT_IP} --main_process_port=${ROOT_PORT} --same_network scripts/train_maskgit.py config=configs/training/generator/maskgit.yaml \
175
+ experiment.project="titok_generation" \
176
+ experiment.name="titok_b64_maskgit" \
177
+ experiment.output_dir="titok_b64_maskgit" \
178
+ experiment.tokenizer_checkpoint=${PATH_TO_STAGE1_or_STAGE2_WEIGHT}
179
+ ```
180
+ You may remove the flag "WANDB_MODE=offline" to enable online wandb logging, if you have configured it.
181
+
182
+ The config _titok_b64.yaml_ can be replaced with _titok_s128.yaml_ or _titok_l32.yaml_ for other TiTok variants.
183
+
184
+ ## Visualizations
185
+ <p>
186
+ <img src="assets/recon_w_model_size_num_token.png" alt="teaser" width=90% height=90%>
187
+ </p>
188
+ <p>
189
+ <img src="assets/random_vis_l32.png" alt="teaser" width=90% height=90%>
190
+ </p>
191
+
192
+
193
+ ## Citing
194
+ If you use our work in your research, please use the following BibTeX entry.
195
+
196
+ ```BibTeX
197
+ @inproceedings{yu2024an,
198
+ author = {Qihang Yu and Mark Weber and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen},
199
+ title = {An Image is Worth 32 Tokens for Reconstruction and Generation},
200
+ booktitle = {NeurIPS},
201
+ year = {2024}
202
+ }
203
+ ```
204
+
205
+ ## Acknowledgement
206
+
207
+ [MaskGIT](https://github.com/google-research/maskgit)
208
+
209
+ [Taming-Transformers](https://github.com/CompVis/taming-transformers)
210
+
211
+ [Open-MUSE](https://github.com/huggingface/open-muse)
212
+
213
+ [MUSE-Pytorch](https://github.com/baaivision/MUSE-Pytorch)
assets/ILSVRC2012_val_00008636.png ADDED
assets/ILSVRC2012_val_00010240.png ADDED
assets/perf_comp.png ADDED
assets/random_vis_l32.png ADDED

Git LFS Details

  • SHA256: ff40d0274f7d6656791e4fc72afbf0d46b0a3975803d6184a46baac0ab80438e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.27 MB
assets/rar_overview.png ADDED

Git LFS Details

  • SHA256: 0ce147111f79db6f3dbc0190e861b5b6efea4c5965422147054ca31aebe25ec5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.17 MB
assets/recon_w_model_size_num_token.png ADDED

Git LFS Details

  • SHA256: 8e5fe53bb8aa64fe918a33de92ac2d965d46871298eeec6fcd2a4a00f1b75386
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
assets/speed_vs_perf.png ADDED
assets/titok_teaser.png ADDED
assets/vis1.png ADDED

Git LFS Details

  • SHA256: f407322aa70ff4288db1ec8a3ddf6ab3e5d6149f33ce8e57e8493c8cb3ae7aca
  • Pointer size: 132 Bytes
  • Size of remote file: 3.54 MB
assets/vis2.png ADDED

Git LFS Details

  • SHA256: f739117053cdbff67aebd95f0bf9cfdb3f2fdf900c030776fdc615a3b1672a69
  • Pointer size: 132 Bytes
  • Size of remote file: 3.32 MB
assets/vis3.png ADDED

Git LFS Details

  • SHA256: 25914cd1bde3d34b6df46190ad77e046e6cfd57fdd58599cadd131f2ad92fb84
  • Pointer size: 132 Bytes
  • Size of remote file: 2.32 MB
configs/infer/titok_b64.yaml ADDED
@@ -0,0 +1,39 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "tokenizer_titok_b64.bin"
3
+ generator_checkpoint: "generator_titok_b64.bin"
4
+ output_dir: "titok_b_64"
5
+ model:
6
+ vq_model:
7
+ codebook_size: 4096
8
+ token_size: 12
9
+ use_l2_norm: True
10
+ commitment_cost: 0.25
11
+ # vit arch
12
+ vit_enc_model_size: "base"
13
+ vit_dec_model_size: "base"
14
+ vit_enc_patch_size: 16
15
+ vit_dec_patch_size: 16
16
+ num_latent_tokens: 64
17
+ finetune_decoder: True
18
+
19
+ generator:
20
+ model_type: "ViT"
21
+ hidden_size: 768
22
+ num_hidden_layers: 24
23
+ num_attention_heads: 16
24
+ intermediate_size: 3072
25
+ dropout: 0.1
26
+ attn_drop: 0.1
27
+ num_steps: 8
28
+ class_label_dropout: 0.1
29
+ image_seq_len: ${model.vq_model.num_latent_tokens}
30
+ condition_num_classes: 1000
31
+
32
+ # sampling hyper-params
33
+ randomize_temperature: 11.0
34
+ guidance_scale: 3.0
35
+ guidance_decay: "linear"
36
+
37
+ dataset:
38
+ preprocessing:
39
+ crop_size: 256
configs/infer/titok_bl128_vae_c16.yaml ADDED
@@ -0,0 +1,19 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "titok_bl128_vae_c16.bin"
3
+ output_dir: "titok_bl128_vae_c16"
4
+ model:
5
+ vq_model:
6
+ quantize_mode: "vae"
7
+ token_size: 16
8
+ # vit arch
9
+ vit_enc_model_size: "base"
10
+ vit_dec_model_size: "large"
11
+ vit_enc_patch_size: 16
12
+ vit_dec_patch_size: 16
13
+ num_latent_tokens: 128
14
+ finetune_decoder: False
15
+ is_legacy: False
16
+
17
+ dataset:
18
+ preprocessing:
19
+ crop_size: 256
configs/infer/titok_bl128_vq8k.yaml ADDED
@@ -0,0 +1,21 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "tokenizer_titok_bl128_vq8k.bin"
3
+ output_dir: "titok_bl128_vq8k"
4
+ model:
5
+ vq_model:
6
+ codebook_size: 8192
7
+ token_size: 64
8
+ use_l2_norm: False
9
+ commitment_cost: 0.25
10
+ # vit arch
11
+ vit_enc_model_size: "base"
12
+ vit_dec_model_size: "large"
13
+ vit_enc_patch_size: 16
14
+ vit_dec_patch_size: 16
15
+ num_latent_tokens: 128
16
+ finetune_decoder: False
17
+ is_legacy: False
18
+
19
+ dataset:
20
+ preprocessing:
21
+ crop_size: 256
configs/infer/titok_bl64_vae_c16.yaml ADDED
@@ -0,0 +1,19 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "titok_bl64_vae_c16.bin"
3
+ output_dir: "titok_bl64_vae_c16"
4
+ model:
5
+ vq_model:
6
+ quantize_mode: "vae"
7
+ token_size: 16
8
+ # vit arch
9
+ vit_enc_model_size: "base"
10
+ vit_dec_model_size: "large"
11
+ vit_enc_patch_size: 16
12
+ vit_dec_patch_size: 16
13
+ num_latent_tokens: 64
14
+ finetune_decoder: False
15
+ is_legacy: False
16
+
17
+ dataset:
18
+ preprocessing:
19
+ crop_size: 256
configs/infer/titok_bl64_vq8k.yaml ADDED
@@ -0,0 +1,21 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "tokenizer_titok_bl64_vq8k.bin"
3
+ output_dir: "titok_bl64_vq8k"
4
+ model:
5
+ vq_model:
6
+ codebook_size: 8192
7
+ token_size: 64
8
+ use_l2_norm: False
9
+ commitment_cost: 0.25
10
+ # vit arch
11
+ vit_enc_model_size: "base"
12
+ vit_dec_model_size: "large"
13
+ vit_enc_patch_size: 16
14
+ vit_dec_patch_size: 16
15
+ num_latent_tokens: 64
16
+ finetune_decoder: False
17
+ is_legacy: False
18
+
19
+ dataset:
20
+ preprocessing:
21
+ crop_size: 256
configs/infer/titok_l32.yaml ADDED
@@ -0,0 +1,40 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "tokenizer_titok_l32.bin"
3
+ generator_checkpoint: "generator_titok_l32.bin"
4
+ output_dir: "titok_l_32"
5
+ model:
6
+ vq_model:
7
+ codebook_size: 4096
8
+ token_size: 12
9
+ use_l2_norm: True
10
+ commitment_cost: 0.25
11
+ # vit arch
12
+ vit_enc_model_size: "large"
13
+ vit_dec_model_size: "large"
14
+ vit_enc_patch_size: 16
15
+ vit_dec_patch_size: 16
16
+ num_latent_tokens: 32
17
+ finetune_decoder: True
18
+
19
+ generator:
20
+ model_type: "ViT"
21
+ hidden_size: 768
22
+ num_hidden_layers: 24
23
+ num_attention_heads: 16
24
+ intermediate_size: 3072
25
+ dropout: 0.1
26
+ attn_drop: 0.1
27
+ num_steps: 8
28
+ class_label_dropout: 0.1
29
+ image_seq_len: ${model.vq_model.num_latent_tokens}
30
+ condition_num_classes: 1000
31
+
32
+ # sampling hyper-params
33
+ randomize_temperature: 9.5
34
+ guidance_scale: 4.5
35
+ guidance_decay: "linear"
36
+
37
+
38
+ dataset:
39
+ preprocessing:
40
+ crop_size: 256
configs/infer/titok_ll32_vae_c16.yaml ADDED
@@ -0,0 +1,19 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "titok_ll32_vae_c16.bin"
3
+ output_dir: "titok_ll32_vae_c16"
4
+ model:
5
+ vq_model:
6
+ quantize_mode: "vae"
7
+ token_size: 16
8
+ # vit arch
9
+ vit_enc_model_size: "large"
10
+ vit_dec_model_size: "large"
11
+ vit_enc_patch_size: 16
12
+ vit_dec_patch_size: 16
13
+ num_latent_tokens: 32
14
+ finetune_decoder: False
15
+ is_legacy: False
16
+
17
+ dataset:
18
+ preprocessing:
19
+ crop_size: 256
configs/infer/titok_s128.yaml ADDED
@@ -0,0 +1,39 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "tokenizer_titok_s128.bin"
3
+ generator_checkpoint: "generator_titok_s128.bin"
4
+ output_dir: "titok_s_128"
5
+ model:
6
+ vq_model:
7
+ codebook_size: 4096
8
+ token_size: 12
9
+ use_l2_norm: True
10
+ commitment_cost: 0.25
11
+ # vit arch
12
+ vit_enc_model_size: "small"
13
+ vit_dec_model_size: "small"
14
+ vit_enc_patch_size: 16
15
+ vit_dec_patch_size: 16
16
+ num_latent_tokens: 128
17
+ finetune_decoder: True
18
+
19
+ generator:
20
+ model_type: "UViT"
21
+ hidden_size: 1024
22
+ num_hidden_layers: 20
23
+ num_attention_heads: 16
24
+ intermediate_size: 4096
25
+ dropout: 0.1
26
+ attn_drop: 0.1
27
+ num_steps: 64
28
+ class_label_dropout: 0.1
29
+ image_seq_len: ${model.vq_model.num_latent_tokens}
30
+ condition_num_classes: 1000
31
+
32
+ # sampling hyper-params
33
+ randomize_temperature: 2.8
34
+ guidance_scale: 6.9
35
+ guidance_decay: "power-cosine"
36
+
37
+ dataset:
38
+ preprocessing:
39
+ crop_size: 256
configs/infer/titok_sl256_vq8k.yaml ADDED
@@ -0,0 +1,21 @@
1
+ experiment:
2
+ tokenizer_checkpoint: "tokenizer_titok_sl256_vq8k.bin"
3
+ output_dir: "titok_sl256_vq8k"
4
+ model:
5
+ vq_model:
6
+ codebook_size: 8192
7
+ token_size: 64
8
+ use_l2_norm: False
9
+ commitment_cost: 0.25
10
+ # vit arch
11
+ vit_enc_model_size: "small"
12
+ vit_dec_model_size: "large"
13
+ vit_enc_patch_size: 16
14
+ vit_dec_patch_size: 16
15
+ num_latent_tokens: 256
16
+ finetune_decoder: False
17
+ is_legacy: False
18
+
19
+ dataset:
20
+ preprocessing:
21
+ crop_size: 256
configs/training/generator/maskgit.yaml ADDED
@@ -0,0 +1,86 @@
1
+ experiment:
2
+ project: "titok_generation"
3
+ name: "titok_b64_maskgit"
4
+ max_train_examples: 1_281_167
5
+ save_every: 50_000
6
+ eval_every: 50_000
7
+ generate_every: 5_000
8
+ log_every: 50
9
+ log_grad_norm_every: 1_000
10
+ resume: True
11
+ tokenizer_checkpoint: "tokenizer_titok_b64.bin"
12
+
13
+ model:
14
+ vq_model:
15
+ codebook_size: 4096
16
+ token_size: 12
17
+ use_l2_norm: True
18
+ commitment_cost: 0.25
19
+ # vit arch
20
+ vit_enc_model_size: "base"
21
+ vit_dec_model_size: "base"
22
+ vit_enc_patch_size: 16
23
+ vit_dec_patch_size: 16
24
+ num_latent_tokens: 64
25
+ finetune_decoder: True
26
+
27
+ generator:
28
+ model_type: "ViT"
29
+ hidden_size: 768
30
+ num_hidden_layers: 24
31
+ num_attention_heads: 16
32
+ intermediate_size: 3072
33
+ dropout: 0.1
34
+ attn_drop: 0.1
35
+ num_steps: 8
36
+ class_label_dropout: 0.1
37
+ image_seq_len: ${model.vq_model.num_latent_tokens}
38
+ condition_num_classes: 1000
39
+
40
+ # sampling hyper-params on the flight
41
+ randomize_temperature: 1.0
42
+ guidance_scale: 4.5
43
+ guidance_decay: "constant"
44
+
45
+ losses:
46
+ label_smoothing: 0.1
47
+ loss_weight_unmasked_token: 0.1
48
+
49
+ dataset:
50
+ params:
51
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
52
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
53
+ num_workers_per_gpu: 12
54
+ preprocessing:
55
+ resize_shorter_edge: 256
56
+ crop_size: 256
57
+ random_crop: False
58
+ random_flip: True
59
+
60
+ optimizer:
61
+ name: adamw
62
+ params:
63
+ learning_rate: 2e-4
64
+ beta1: 0.9
65
+ beta2: 0.96
66
+ weight_decay: 0.03
67
+
68
+
69
+ lr_scheduler:
70
+ scheduler: "cosine"
71
+ params:
72
+ learning_rate: ${optimizer.params.learning_rate}
73
+ warmup_steps: 10_000
74
+ end_lr: 1e-5
75
+
76
+
77
+ training:
78
+ gradient_accumulation_steps: 1
79
+ per_gpu_batch_size: 64 # 32 GPU, total batch size 2048
80
+ mixed_precision: "bf16"
81
+ enable_tf32: True
82
+ enable_wandb: True
83
+ use_ema: True
84
+ seed: 42
85
+ max_train_steps: 500_000
86
+ max_grad_norm: 1.0
configs/training/generator/rar.yaml ADDED
@@ -0,0 +1,78 @@
1
+ experiment:
2
+ project: "rar_generation"
3
+ name: "rar-b"
4
+ max_train_examples: 1_281_167
5
+ save_every: 50_000
6
+ eval_every: 50_000
7
+ generate_every: 5_000
8
+ log_every: 50
9
+ log_grad_norm_every: 1_000
10
+ resume: True
11
+
12
+ model:
13
+ vq_model:
14
+ codebook_size: 1024
15
+ token_size: 256
16
+ num_latent_tokens: 256
17
+ finetune_decoder: False
18
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
19
+
20
+ generator:
21
+ hidden_size: 768
22
+ num_hidden_layers: 24
23
+ num_attention_heads: 16
24
+ intermediate_size: 3072
25
+ dropout: 0.1
26
+ attn_drop: 0.1
27
+ class_label_dropout: 0.1
28
+ image_seq_len: 256
29
+ condition_num_classes: 1000
30
+
31
+ # sampling hyper-params for RAR-B
32
+ randomize_temperature: 1.0
33
+ guidance_scale: 16.0
34
+ guidance_scale_pow: 2.75
35
+ use_checkpoint: False # True to save memory
36
+
37
+ randomness_anneal_start: 125000 # 200 epoch
38
+ randomness_anneal_end: 187500 # 300 epoch
39
+
40
+ dataset:
41
+ params:
42
+ # use pretokenized dataset for speed-up
43
+ pretokenization: "maskgitvq.jsonl"
44
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
45
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
46
+ num_workers_per_gpu: 12
47
+ preprocessing:
48
+ resize_shorter_edge: 256
49
+ crop_size: 256
50
+ random_crop: False
51
+ random_flip: True
52
+
53
+ optimizer:
54
+ name: adamw
55
+ params:
56
+ learning_rate: 4e-4
57
+ beta1: 0.9
58
+ beta2: 0.96
59
+ weight_decay: 0.03
60
+
61
+
62
+ lr_scheduler:
63
+ scheduler: "cosine"
64
+ params:
65
+ learning_rate: ${optimizer.params.learning_rate}
66
+ warmup_steps: 62_500 # 100 epochs with bsz 2048
67
+ end_lr: 1e-5
68
+
69
+ training:
70
+ gradient_accumulation_steps: 1
71
+ per_gpu_batch_size: 64 # 32 GPU, total batch size 2048
72
+ mixed_precision: "bf16"
73
+ enable_tf32: True
74
+ enable_wandb: True
75
+ use_ema: False
76
+ seed: 42
77
+ max_train_steps: 250_000
78
+ max_grad_norm: 1.0
configs/training/stage1/titok_b64.yaml ADDED
@@ -0,0 +1,70 @@
1
+ experiment:
2
+ project: "titok_b_64_stage1"
3
+ name: "titok_b_64_stage1_run1"
4
+ output_dir: "titok_b_64_stage1_run1"
5
+ max_train_examples: 1_281_167
6
+ save_every: 50_000
7
+ eval_every: 50_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+ init_weight: ""
13
+
14
+
15
+ model:
16
+ vq_model:
17
+ codebook_size: 4096
18
+ token_size: 12
19
+ use_l2_norm: True
20
+ commitment_cost: 0.25
21
+ # vit arch
22
+ vit_enc_model_size: "base"
23
+ vit_dec_model_size: "base"
24
+ vit_enc_patch_size: 16
25
+ vit_dec_patch_size: 16
26
+ num_latent_tokens: 64
27
+ finetune_decoder: False
28
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
29
+
30
+ losses:
31
+ quantizer_weight: 1.0
32
+
33
+ dataset:
34
+ params:
35
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
36
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
37
+ num_workers_per_gpu: 12
38
+ preprocessing:
39
+ resize_shorter_edge: 256
40
+ crop_size: 256
41
+ random_crop: True
42
+ random_flip: True
43
+
44
+
45
+ optimizer:
46
+ name: adamw
47
+ params:
48
+ learning_rate: 1e-4
49
+ beta1: 0.9
50
+ beta2: 0.99
51
+ weight_decay: 1e-4
52
+
53
+ lr_scheduler:
54
+ scheduler: "cosine"
55
+ params:
56
+ learning_rate: ${optimizer.params.learning_rate}
57
+ warmup_steps: 10_000
58
+ end_lr: 1e-5
59
+
60
+ training:
61
+ gradient_accumulation_steps: 1
62
+ per_gpu_batch_size: 32
63
+ mixed_precision: "fp16"
64
+ enable_tf32: True
65
+ enable_wandb: True
66
+ use_ema: True
67
+ seed: 42
68
+ max_train_steps: 1_000_000
69
+ num_generated_images: 2
70
+ max_grad_norm: 1.0
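For reference, the lr_scheduler block above describes linear warmup followed by cosine decay; a minimal sketch under that assumption (warmup to the base rate over 10k steps, cosine decay down to end_lr by max_train_steps):

import math

def lr_at(step, base_lr=1e-4, end_lr=1e-5, warmup=10_000, total=1_000_000):
    if step < warmup:
        return base_lr * step / warmup                    # linear warmup
    progress = (step - warmup) / max(1, total - warmup)   # 0 -> 1 over the decay phase
    return end_lr + 0.5 * (base_lr - end_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(5_000), lr_at(10_000), lr_at(1_000_000))      # mid-warmup, peak, end of training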
configs/training/stage1/titok_l32.yaml ADDED
@@ -0,0 +1,70 @@
1
+ experiment:
2
+ project: "titok_l_32_stage1"
3
+ name: "titok_l_32_stage1_run1"
4
+ output_dir: "titok_l_32_stage1_run1"
5
+ max_train_examples: 1_281_167
6
+ save_every: 50_000
7
+ eval_every: 50_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+ init_weight: ""
13
+
14
+
15
+ model:
16
+ vq_model:
17
+ codebook_size: 4096
18
+ token_size: 12
19
+ use_l2_norm: True
20
+ commitment_cost: 0.25
21
+ # vit arch
22
+ vit_enc_model_size: "large"
23
+ vit_dec_model_size: "large"
24
+ vit_enc_patch_size: 16
25
+ vit_dec_patch_size: 16
26
+ num_latent_tokens: 32
27
+ finetune_decoder: False
28
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
29
+
30
+ losses:
31
+ quantizer_weight: 1.0
32
+
33
+ dataset:
34
+ params:
35
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
36
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
37
+ num_workers_per_gpu: 12
38
+ preprocessing:
39
+ resize_shorter_edge: 256
40
+ crop_size: 256
41
+ random_crop: True
42
+ random_flip: True
43
+
44
+ optimizer:
45
+ name: adamw
46
+ params:
47
+ learning_rate: 1e-4
48
+ beta1: 0.9
49
+ beta2: 0.99
50
+ weight_decay: 1e-4
51
+ epsilon: 1e-8
52
+
53
+ lr_scheduler:
54
+ scheduler: "cosine"
55
+ params:
56
+ learning_rate: ${optimizer.params.learning_rate}
57
+ warmup_steps: 10_000
58
+ end_lr: 1e-5
59
+
60
+ training:
61
+ gradient_accumulation_steps: 1
62
+ per_gpu_batch_size: 32
63
+ mixed_precision: "fp16"
64
+ enable_tf32: True
65
+ enable_wandb: True
66
+ use_ema: True
67
+ seed: 42
68
+ max_train_steps: 1_000_000
69
+ num_generated_images: 2
70
+ max_grad_norm: 1.0
configs/training/stage1/titok_s128.yaml ADDED
@@ -0,0 +1,70 @@
1
+ experiment:
2
+ project: "titok_s_128_stage1"
3
+ name: "titok_s_128_stage1_run1"
4
+ output_dir: "titok_s_128_stage1_run1"
5
+ max_train_examples: 1_281_167
6
+ save_every: 50_000
7
+ eval_every: 50_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+ init_weight: ""
13
+
14
+
15
+ model:
16
+ vq_model:
17
+ codebook_size: 4096
18
+ token_size: 12
19
+ use_l2_norm: True
20
+ commitment_cost: 0.25
21
+ # vit arch
22
+ vit_enc_model_size: "small"
23
+ vit_dec_model_size: "small"
24
+ vit_enc_patch_size: 16
25
+ vit_dec_patch_size: 16
26
+ num_latent_tokens: 128
27
+ finetune_decoder: False
28
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
29
+
30
+ losses:
31
+ quantizer_weight: 1.0
32
+
33
+ dataset:
34
+ params:
35
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
36
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
37
+ num_workers_per_gpu: 12
38
+ preprocessing:
39
+ resize_shorter_edge: 256
40
+ crop_size: 256
41
+ random_crop: True
42
+ random_flip: True
43
+
44
+ optimizer:
45
+ name: adamw
46
+ params:
47
+ learning_rate: 1e-4
48
+ beta1: 0.9
49
+ beta2: 0.99
50
+ weight_decay: 1e-4
51
+ epsilon: 1e-8
52
+
53
+ lr_scheduler:
54
+ scheduler: "cosine"
55
+ params:
56
+ learning_rate: ${optimizer.params.learning_rate}
57
+ warmup_steps: 10_000
58
+ end_lr: 1e-5
59
+
60
+ training:
61
+ gradient_accumulation_steps: 1
62
+ per_gpu_batch_size: 32
63
+ mixed_precision: "fp16"
64
+ enable_tf32: True
65
+ enable_wandb: True
66
+ use_ema: True
67
+ seed: 42
68
+ max_train_steps: 1_000_000
69
+ num_generated_images: 2
70
+ max_grad_norm: 1.0
configs/training/stage2/titok_b64.yaml ADDED
@@ -0,0 +1,80 @@
1
+ experiment:
2
+ project: "titok_b_64_stage2"
3
+ name: "titok_b_64_stage2_run1"
4
+ output_dir: "titok_b_64_stage2_run1"
5
+ max_train_examples: 1_281_167
6
+ save_every: 50_000
7
+ eval_every: 50_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+ init_weight: "titok_b_64_stage1.bin"
13
+
14
+
15
+ model:
16
+ vq_model:
17
+ codebook_size: 4096
18
+ token_size: 12
19
+ use_l2_norm: True
20
+ commitment_cost: 0.0
21
+ # vit arch
22
+ vit_enc_model_size: "base"
23
+ vit_dec_model_size: "base"
24
+ vit_enc_patch_size: 16
25
+ vit_dec_patch_size: 16
26
+ num_latent_tokens: 64
27
+ finetune_decoder: True
28
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
29
+
30
+ losses:
31
+ discriminator_start: 20_000
32
+ quantizer_weight: 0.0
33
+ discriminator_factor: 1.0
34
+ discriminator_weight: 0.01
35
+ perceptual_loss: "convnext_s"
36
+ perceptual_weight: 0.1
37
+ reconstruction_loss: "l2"
38
+ reconstruction_weight: 1.0
39
+ lecam_regularization_weight: 0.001
40
+
41
+
42
+ dataset:
43
+ params:
44
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
45
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
46
+ num_workers_per_gpu: 12
47
+ preprocessing:
48
+ resize_shorter_edge: 256
49
+ crop_size: 256
50
+ random_crop: True
51
+ random_flip: True
52
+
53
+
54
+ optimizer:
55
+ name: adamw
56
+ params:
57
+ learning_rate: 1e-4
58
+ discriminator_learning_rate: 1e-4
59
+ beta1: 0.9
60
+ beta2: 0.999
61
+ weight_decay: 1e-4
62
+
63
+ lr_scheduler:
64
+ scheduler: "cosine"
65
+ params:
66
+ learning_rate: ${optimizer.params.learning_rate}
67
+ warmup_steps: 5_000
68
+ end_lr: 1e-5
69
+
70
+ training:
71
+ gradient_accumulation_steps: 1
72
+ per_gpu_batch_size: 32
73
+ mixed_precision: "fp16"
74
+ enable_tf32: True
75
+ enable_wandb: True
76
+ use_ema: True
77
+ seed: 42
78
+ max_train_steps: 500_000
79
+ num_generated_images: 2
80
+ max_grad_norm: 1.0
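One note on the losses block above: discriminator_start is the usual VQGAN-style delay before the adversarial term contributes. A minimal sketch of that gating (an illustration, not the repo's loss implementation):

def adversarial_weight(global_step: int,
                       discriminator_start: int = 20_000,
                       discriminator_factor: float = 1.0,
                       discriminator_weight: float = 0.01) -> float:
    # The generator's GAN term is weighted zero until discriminator_start is reached.
    if global_step < discriminator_start:
        return 0.0
    return discriminator_factor * discriminator_weight

# total_loss = reconstruction + 0.1 * perceptual + adversarial_weight(step) * gan_loss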
configs/training/stage2/titok_l32.yaml ADDED
@@ -0,0 +1,79 @@
1
+ experiment:
2
+ project: "titok_l_32_stage2"
3
+ name: "titok_l_32_stage2_run1"
4
+ output_dir: "titok_l_32_stage2_run1"
5
+ max_train_examples: 1_281_167
6
+ save_every: 50_000
7
+ eval_every: 50_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+ init_weight: "titok_l_32_stage1.bin"
13
+
14
+
15
+ model:
16
+ vq_model:
17
+ codebook_size: 4096
18
+ token_size: 12
19
+ use_l2_norm: True
20
+ commitment_cost: 0.0
21
+ # vit arch
22
+ vit_enc_model_size: "large"
23
+ vit_dec_model_size: "large"
24
+ vit_enc_patch_size: 16
25
+ vit_dec_patch_size: 16
26
+ num_latent_tokens: 32
27
+ finetune_decoder: True
28
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
29
+
30
+ losses:
31
+ discriminator_start: 20_000
32
+ quantizer_weight: 0.0
33
+ discriminator_factor: 1.0
34
+ discriminator_weight: 0.01
35
+ perceptual_loss: "convnext_s"
36
+ perceptual_weight: 0.1
37
+ reconstruction_loss: "l2"
38
+ reconstruction_weight: 1.0
39
+ lecam_regularization_weight: 0.001
40
+
41
+ dataset:
42
+ params:
43
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
44
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
45
+ num_workers_per_gpu: 12
46
+ preprocessing:
47
+ resize_shorter_edge: 256
48
+ crop_size: 256
49
+ random_crop: True
50
+ random_flip: True
51
+
52
+
53
+ optimizer:
54
+ name: adamw
55
+ params:
56
+ learning_rate: 1e-4
57
+ discriminator_learning_rate: 1e-4
58
+ beta1: 0.9
59
+ beta2: 0.999
60
+ weight_decay: 1e-4
61
+
62
+ lr_scheduler:
63
+ scheduler: "cosine"
64
+ params:
65
+ learning_rate: ${optimizer.params.learning_rate}
66
+ warmup_steps: 5_000
67
+ end_lr: 1e-5
68
+
69
+ training:
70
+ gradient_accumulation_steps: 1
71
+ per_gpu_batch_size: 32
72
+ mixed_precision: "fp16"
73
+ enable_tf32: True
74
+ enable_wandb: True
75
+ use_ema: True
76
+ seed: 42
77
+ max_train_steps: 500_000
78
+ num_generated_images: 2
79
+ max_grad_norm: 1.0
configs/training/stage2/titok_s128.yaml ADDED
@@ -0,0 +1,79 @@
1
+ experiment:
2
+ project: "titok_s_128_stage2"
3
+ name: "titok_s_128_stage2_run1"
4
+ output_dir: "titok_s_128_stage2_run1"
5
+ max_train_examples: 1_281_167
6
+ save_every: 50_000
7
+ eval_every: 50_000
8
+ generate_every: 5_000
9
+ log_every: 50
10
+ log_grad_norm_every: 1_000
11
+ resume: True
12
+ init_weight: "titok_s_128_stage1.bin"
13
+
14
+
15
+ model:
16
+ vq_model:
17
+ codebook_size: 4096
18
+ token_size: 12
19
+ use_l2_norm: True
20
+ commitment_cost: 0.0
21
+ # vit arch
22
+ vit_enc_model_size: "small"
23
+ vit_dec_model_size: "small"
24
+ vit_enc_patch_size: 16
25
+ vit_dec_patch_size: 16
26
+ num_latent_tokens: 128
27
+ finetune_decoder: True
28
+ pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"
29
+
30
+ losses:
31
+ discriminator_start: 20_000
32
+ quantizer_weight: 0.0
33
+ discriminator_factor: 1.0
34
+ discriminator_weight: 0.01
35
+ perceptual_loss: "convnext_s"
36
+ perceptual_weight: 0.1
37
+ reconstruction_loss: "l2"
38
+ reconstruction_weight: 1.0
39
+ lecam_regularization_weight: 0.001
40
+
41
+ dataset:
42
+ params:
43
+ train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
44
+ eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
45
+ num_workers_per_gpu: 12
46
+ preprocessing:
47
+ resize_shorter_edge: 256
48
+ crop_size: 256
49
+ random_crop: True
50
+ random_flip: True
51
+
52
+
53
+ optimizer:
54
+ name: adamw
55
+ params:
56
+ learning_rate: 1e-4
57
+ discriminator_learning_rate: 1e-4
58
+ beta1: 0.9
59
+ beta2: 0.999
60
+ weight_decay: 1e-4
61
+
62
+ lr_scheduler:
63
+ scheduler: "cosine"
64
+ params:
65
+ learning_rate: ${optimizer.params.learning_rate}
66
+ warmup_steps: 5_000
67
+ end_lr: 1e-5
68
+
69
+ training:
70
+ gradient_accumulation_steps: 1
71
+ per_gpu_batch_size: 32
72
+ mixed_precision: "fp16"
73
+ enable_tf32: True
74
+ enable_wandb: True
75
+ use_ema: True
76
+ seed: 42
77
+ max_train_steps: 500_000
78
+ num_generated_images: 2
79
+ max_grad_norm: 1.0
data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .webdataset_reader import SimpleImageDataset, PretoeknizedDataSetJSONL
data/convert_imagenet_to_wds.py ADDED
@@ -0,0 +1,68 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # Adapted from https://github.com/webdataset/webdataset-imagenet/blob/main/convert-imagenet.py
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ import time
22
+
23
+ import webdataset as wds
24
+ from datasets import load_dataset
25
+
26
+
27
+ def convert_imagenet_to_wds(output_dir, max_train_samples_per_shard, max_val_samples_per_shard):
28
+ assert not os.path.exists(os.path.join(output_dir, "imagenet-train-000000.tar"))
29
+ assert not os.path.exists(os.path.join(output_dir, "imagenet-val-000000.tar"))
30
+
31
+ opat = os.path.join(output_dir, "imagenet-train-%06d.tar")
32
+ output = wds.ShardWriter(opat, maxcount=max_train_samples_per_shard)
33
+ dataset = load_dataset("imagenet-1k", streaming=True, split="train", use_auth_token=True)
34
+ now = time.time()
35
+ for i, example in enumerate(dataset):
36
+ if i % max_train_samples_per_shard == 0:
37
+ print(i, file=sys.stderr)
38
+ img, label = example["image"], example["label"]
39
+ output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
40
+ output.close()
41
+ time_taken = time.time() - now
42
+ print(f"Wrote {i+1} train examples in {time_taken // 3600} hours.")
43
+
44
+ opat = os.path.join(output_dir, "imagenet-val-%06d.tar")
45
+ output = wds.ShardWriter(opat, maxcount=max_val_samples_per_shard)
46
+ dataset = load_dataset("imagenet-1k", streaming=True, split="validation", use_auth_token=True)
47
+ now = time.time()
48
+ for i, example in enumerate(dataset):
49
+ if i % max_val_samples_per_shard == 0:
50
+ print(i, file=sys.stderr)
51
+ img, label = example["image"], example["label"]
52
+ output.write({"__key__": "%08d" % i, "jpg": img.convert("RGB"), "cls": label})
53
+ output.close()
54
+ time_taken = time.time() - now
55
+ print(f"Wrote {i+1} val examples in {time_taken // 60} min.")
56
+
57
+
58
+ if __name__ == "__main__":
59
+ # create parser object
60
+ parser = argparse.ArgumentParser()
61
+ parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory.")
62
+ parser.add_argument("--max_train_samples_per_shard", type=int, default=4000, help="Path to the output directory.")
63
+ parser.add_argument("--max_val_samples_per_shard", type=int, default=1000, help="Path to the output directory.")
64
+ args = parser.parse_args()
65
+
66
+ # create output directory
67
+ os.makedirs(args.output_dir, exist_ok=True)
68
+ convert_imagenet_to_wds(args.output_dir, args.max_train_samples_per_shard, args.max_val_samples_per_shard)
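A shard written by the script above can be spot-checked with webdataset; a minimal sketch (the shard path is an assumption based on the naming pattern used by the script):

import webdataset as wds

ds = wds.WebDataset("imagenet_sharded/train/imagenet-train-000000.tar").decode("pil")
sample = next(iter(ds))
# Expect the sample key, the integer class label, and a decoded PIL image.
print(sample["__key__"], sample["cls"], sample["jpg"].size)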
data/webdataset_reader.py ADDED
@@ -0,0 +1,227 @@
1
+ """This file contains the definition of data loader using webdataset.
2
+
3
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5
+
6
+ Reference:
7
+ https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
8
+ https://github.com/huggingface/open-muse/blob/main/training/data.py
9
+ """
10
+
11
+ import math
12
+ from typing import List, Union, Text
13
+ import webdataset as wds
14
+ import torch
15
+ from torch.utils.data import default_collate
16
+ from torchvision import transforms
17
+ from torch.utils.data import Dataset
18
+ import linecache
19
+ import json
20
+
21
+
22
+ def filter_keys(key_set):
23
+ def _f(dictionary):
24
+ return {k: v for k, v in dictionary.items() if k in key_set}
25
+
26
+ return _f
27
+
28
+
29
+ class ImageTransform:
30
+ def __init__(self,
31
+ resize_shorter_edge: int = 256,
32
+ crop_size: int = 256,
33
+ random_crop: bool = True,
34
+ random_flip: bool = True,
35
+ normalize_mean: List[float] = [0., 0., 0.],
36
+ normalize_std: List[float] = [1., 1., 1.]):
37
+ """Initializes the WebDatasetReader with specified augmentation parameters.
38
+
39
+ Args:
40
+ resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
41
+ crop_size: An integer, the size to crop the input image to.
42
+ random_crop: A boolean, whether to use random crop augmentation during training.
43
+ random_flip: A boolean, whether to use random flipping augmentation during training.
44
+ normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
45
+ normalize_std: A list of float, the normalization std used to normalize the image tensor.
46
+
47
+ Raises:
48
+ NotImplementedError: If the interpolation mode is not one of ["bicubic", "bilinear"].
49
+ """
50
+ train_transform = []
51
+ interpolation = transforms.InterpolationMode.BICUBIC
52
+
53
+ train_transform.append(
54
+ transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True))
55
+ if random_crop:
56
+ train_transform.append(transforms.RandomCrop(crop_size))
57
+ else:
58
+ train_transform.append(transforms.CenterCrop(crop_size))
59
+ if random_flip:
60
+ train_transform.append(transforms.RandomHorizontalFlip())
61
+ train_transform.append(transforms.ToTensor())
62
+ # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1],
63
+ # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1].
64
+ train_transform.append(transforms.Normalize(normalize_mean, normalize_std))
65
+
66
+ self.train_transform = transforms.Compose(train_transform)
67
+ self.eval_transform = transforms.Compose(
68
+ [
69
+ # Note that we always resize to crop_size during eval to ensure the results
70
+ # can be compared against reference numbers on ImageNet etc.
71
+ transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
72
+ transforms.CenterCrop(crop_size),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize(normalize_mean, normalize_std)
75
+ ]
76
+ )
77
+ print(f"self.train_transform: {self.train_transform}")
78
+ print(f"self.eval_transform: {self.eval_transform}")
79
+
80
+
81
+ class SimpleImageDataset:
82
+ def __init__(
83
+ self,
84
+ train_shards_path: Union[Text, List[Text]],
85
+ eval_shards_path: Union[Text, List[Text]],
86
+ num_train_examples: int,
87
+ per_gpu_batch_size: int,
88
+ global_batch_size: int,
89
+ num_workers_per_gpu: int = 12,
90
+ resize_shorter_edge: int = 256,
91
+ crop_size: int = 256,
92
+ random_crop = True,
93
+ random_flip = True,
94
+ normalize_mean: List[float] = [0., 0., 0.],
95
+ normalize_std: List[float] = [1., 1., 1.],
96
+ ):
97
+ """Initializes the WebDatasetReader class.
98
+
99
+ Args:
100
+ train_shards_path: A string or list of string, path to the training data shards in webdataset format.
101
+ eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
102
+ num_train_examples: An integer, total number of training examples.
103
+ per_gpu_batch_size: An integer, number of examples per GPU batch.
104
+ global_batch_size: An integer, total number of examples in a batch across all GPUs.
105
+ num_workers_per_gpu: An integer, number of workers per GPU.
106
+ resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
107
+ crop_size: An integer, the size to crop the input image to.
108
+ random_crop: A boolean, whether to use random crop augmentation during training.
109
+ random_flip: A boolean, whether to use random flipping augmentation during training.
110
+ normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
111
+ normalize_std: A list of float, the normalization std used to normalize the image tensor.
112
+ """
113
+ transform = ImageTransform(
114
+ resize_shorter_edge, crop_size, random_crop, random_flip,
115
+ normalize_mean, normalize_std)
116
+
117
+ train_processing_pipeline = [
118
+ wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
119
+ wds.rename(
120
+ image="jpg;png;jpeg;webp",
121
+ class_id="cls",
122
+ handler=wds.warn_and_continue,
123
+ ),
124
+ wds.map(filter_keys(set(["image", "class_id", "filename"]))),
125
+ wds.map_dict(
126
+ image=transform.train_transform,
127
+ class_id=lambda x: int(x),
128
+ handler=wds.warn_and_continue,
129
+ ),
130
+ ]
131
+
132
+ test_processing_pipeline = [
133
+ wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
134
+ wds.rename(
135
+ image="jpg;png;jpeg;webp",
136
+ class_id="cls",
137
+ handler=wds.warn_and_continue,
138
+ ),
139
+ wds.map(filter_keys(set(["image", "class_id", "filename"]))),
140
+ wds.map_dict(
141
+ image=transform.eval_transform,
142
+ class_id=lambda x: int(x),
143
+ handler=wds.warn_and_continue,
144
+ ),
145
+ ]
146
+
147
+ # Create train dataset and loader.
148
+ pipeline = [
149
+ wds.ResampledShards(train_shards_path),
150
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
151
+ wds.shuffle(bufsize=5000,
152
+ initial=1000),
153
+ *train_processing_pipeline,
154
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
155
+ ]
156
+
157
+ num_batches = math.ceil(num_train_examples / global_batch_size)
158
+ num_worker_batches = math.ceil(num_train_examples /
159
+ (global_batch_size * num_workers_per_gpu))
160
+ num_batches = num_worker_batches * num_workers_per_gpu
161
+ num_samples = num_batches * global_batch_size
162
+
163
+ # Each worker is iterating over the complete dataset.
164
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
165
+ self._train_dataloader = wds.WebLoader(
166
+ self._train_dataset,
167
+ batch_size=None,
168
+ shuffle=False,
169
+ num_workers=num_workers_per_gpu,
170
+ pin_memory=True,
171
+ persistent_workers=True,
172
+ )
173
+ # Add meta-data to dataloader instance for convenience.
174
+ self._train_dataloader.num_batches = num_batches
175
+ self._train_dataloader.num_samples = num_samples
176
+
177
+ # Create eval dataset and loader.
178
+ pipeline = [
179
+ wds.SimpleShardList(eval_shards_path),
180
+ wds.split_by_worker,
181
+ wds.tarfile_to_samples(handler=wds.ignore_and_continue),
182
+ *test_processing_pipeline,
183
+ wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
184
+ ]
185
+ self._eval_dataset = wds.DataPipeline(*pipeline)
186
+ self._eval_dataloader = wds.WebLoader(
187
+ self._eval_dataset,
188
+ batch_size=None,
189
+ shuffle=False,
190
+ num_workers=num_workers_per_gpu,
191
+ pin_memory=True,
192
+ persistent_workers=True,
193
+ )
194
+
195
+ @property
196
+ def train_dataset(self):
197
+ return self._train_dataset
198
+
199
+ @property
200
+ def train_dataloader(self):
201
+ return self._train_dataloader
202
+
203
+ @property
204
+ def eval_dataset(self):
205
+ return self._eval_dataset
206
+
207
+ @property
208
+ def eval_dataloader(self):
209
+ return self._eval_dataloader
210
+
211
+
212
+ class PretoeknizedDataSetJSONL(Dataset):
213
+ def __init__(self, data_path):
214
+ super().__init__()
215
+ self.jsonl_file = data_path
216
+ self.num_lines = sum(1 for _ in open(self.jsonl_file))
217
+ # Ensure the file is cached
218
+ linecache.checkcache(self.jsonl_file)
219
+ print("Number of data:", self.num_lines)
220
+
221
+ def __len__(self):
222
+ return self.num_lines
223
+
224
+ def __getitem__(self, idx):
225
+ line = linecache.getline(self.jsonl_file, idx + 1).strip()
226
+ data = json.loads(line)
227
+ return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])
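A minimal usage sketch for SimpleImageDataset defined above, with values mirroring the training configs (a single-process run is assumed, so the global batch size equals the per-GPU batch size):

from data import SimpleImageDataset

dataset = SimpleImageDataset(
    train_shards_path="imagenet_sharded/train/imagenet-train-{0000..0252}.tar",
    eval_shards_path="imagenet_sharded/val/imagenet-val-{0000..0009}.tar",
    num_train_examples=1_281_167,
    per_gpu_batch_size=32,
    global_batch_size=32,
    num_workers_per_gpu=12,
)
batch = next(iter(dataset.train_dataloader))
print(batch["image"].shape, batch["class_id"].shape)  # [32, 3, 256, 256] and [32]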
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
demo_util.py ADDED
@@ -0,0 +1,108 @@
1
+ """Demo file for sampling images from TiTok.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ """
17
+
18
+
19
+ import torch
20
+
21
+ from omegaconf import OmegaConf
22
+ from modeling.titok import TiTok
23
+ from modeling.maskgit import ImageBert, UViTBert
24
+ from modeling.rar import RAR
25
+
26
+
27
+ def get_config_cli():
28
+ cli_conf = OmegaConf.from_cli()
29
+
30
+ yaml_conf = OmegaConf.load(cli_conf.config)
31
+ conf = OmegaConf.merge(yaml_conf, cli_conf)
32
+
33
+ return conf
34
+
35
+ def get_config(config_path):
36
+ conf = OmegaConf.load(config_path)
37
+ return conf
38
+
39
+ def get_titok_tokenizer(config):
40
+ tokenizer = TiTok(config)
41
+ tokenizer.load_state_dict(torch.load(config.experiment.tokenizer_checkpoint, map_location="cpu"))
42
+ tokenizer.eval()
43
+ tokenizer.requires_grad_(False)
44
+ return tokenizer
45
+
46
+ def get_titok_generator(config):
47
+ if config.model.generator.model_type == "ViT":
48
+ model_cls = ImageBert
49
+ elif config.model.generator.model_type == "UViT":
50
+ model_cls = UViTBert
51
+ else:
52
+ raise ValueError(f"Unsupported model type {config.model.generator.model_type}")
53
+ generator = model_cls(config)
54
+ generator.load_state_dict(torch.load(config.experiment.generator_checkpoint, map_location="cpu"))
55
+ generator.eval()
56
+ generator.requires_grad_(False)
57
+ return generator
58
+
59
+ def get_rar_generator(config):
60
+ model_cls = RAR
61
+ generator = model_cls(config)
62
+ generator.load_state_dict(torch.load(config.experiment.generator_checkpoint, map_location="cpu"))
63
+ generator.eval()
64
+ generator.requires_grad_(False)
65
+ generator.set_random_ratio(0)
66
+ return generator
67
+
68
+
69
+ @torch.no_grad()
70
+ def sample_fn(generator,
71
+ tokenizer,
72
+ labels=None,
73
+ guidance_scale=3.0,
74
+ guidance_decay="constant",
75
+ guidance_scale_pow=3.0,
76
+ randomize_temperature=2.0,
77
+ softmax_temperature_annealing=False,
78
+ num_sample_steps=8,
79
+ device="cuda",
80
+ return_tensor=False):
81
+ generator.eval()
82
+ tokenizer.eval()
83
+ if labels is None:
84
+ # goldfish, chicken, tiger, cat, hourglass, ship, dog, race car, airliner, teddy bear, random
85
+ labels = [1, 7, 282, 604, 724, 179, 751, 404, 850, torch.randint(0, 999, size=(1,))]
86
+
87
+ if not isinstance(labels, torch.Tensor):
88
+ labels = torch.LongTensor(labels).to(device)
89
+
90
+ generated_tokens = generator.generate(
91
+ condition=labels,
92
+ guidance_scale=guidance_scale,
93
+ guidance_decay=guidance_decay,
94
+ guidance_scale_pow=guidance_scale_pow,
95
+ randomize_temperature=randomize_temperature,
96
+ softmax_temperature_annealing=softmax_temperature_annealing,
97
+ num_sample_steps=num_sample_steps)
98
+
99
+ generated_image = tokenizer.decode_tokens(
100
+ generated_tokens.view(generated_tokens.shape[0], -1)
101
+ )
102
+ if return_tensor:
103
+ return generated_image
104
+
105
+ generated_image = torch.clamp(generated_image, 0.0, 1.0)
106
+ generated_image = (generated_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
107
+
108
+ return generated_image
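A minimal end-to-end sampling sketch built on the helpers above; the config path and the checkpoint fields it references are assumptions and must point at downloaded weights:

from PIL import Image
import demo_util

config = demo_util.get_config("configs/infer/titok_l32.yaml")
tokenizer = demo_util.get_titok_tokenizer(config).to("cuda")
generator = demo_util.get_titok_generator(config).to("cuda")

# 207 = golden retriever, 250 = Siberian husky (see imagenet_classes.py)
images = demo_util.sample_fn(generator, tokenizer, labels=[207, 250], device="cuda")
for i, img in enumerate(images):
    Image.fromarray(img).save(f"sample_{i}.png")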
evaluator/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .evaluator import VQGANEvaluator
evaluator/evaluator.py ADDED
@@ -0,0 +1,245 @@
1
+ """This file contains a class to evalute the reconstruction results.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ """
17
+
18
+ import warnings
19
+
20
+ from typing import Sequence, Optional, Mapping, Text
21
+ import numpy as np
22
+ from scipy import linalg
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from .inception import get_inception_model
27
+
28
+
29
+ def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
30
+ """Computes covariance of the input tensor.
31
+
32
+ Args:
33
+ sigma: A torch.Tensor, sum of outer products of input features.
34
+ total: A torch.Tensor, sum of all input features.
35
+ num_examples: An integer, number of examples in the input tensor.
36
+ Returns:
37
+ A torch.Tensor, covariance of the input tensor.
38
+ """
39
+ if num_examples == 0:
40
+ return torch.zeros_like(sigma)
41
+
42
+ sub_matrix = torch.outer(total, total)
43
+ sub_matrix = sub_matrix / num_examples
44
+
45
+ return (sigma - sub_matrix) / (num_examples - 1)
46
+
47
+
48
+ class VQGANEvaluator:
49
+ def __init__(
50
+ self,
51
+ device,
52
+ enable_rfid: bool = True,
53
+ enable_inception_score: bool = True,
54
+ enable_codebook_usage_measure: bool = False,
55
+ enable_codebook_entropy_measure: bool = False,
56
+ num_codebook_entries: int = 1024
57
+ ):
58
+ """Initializes VQGAN Evaluator.
59
+
60
+ Args:
61
+ device: The device to use for evaluation.
62
+ enable_rfid: A boolean, whether enabling rFID score.
63
+ enable_inception_score: A boolean, whether enabling Inception Score.
64
+ enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
65
+ enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
66
+ num_codebook_entries: An integer, the number of codebook entries.
67
+ """
68
+ self._device = device
69
+
70
+ self._enable_rfid = enable_rfid
71
+ self._enable_inception_score = enable_inception_score
72
+ self._enable_codebook_usage_measure = enable_codebook_usage_measure
73
+ self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
74
+ self._num_codebook_entries = num_codebook_entries
75
+
76
+ # Variables related to Inception score and rFID.
77
+ self._inception_model = None
78
+ self._is_num_features = 0
79
+ self._rfid_num_features = 0
80
+ if self._enable_inception_score or self._enable_rfid:
81
+ self._rfid_num_features = 2048
82
+ self._is_num_features = 1008
83
+ self._inception_model = get_inception_model().to(self._device)
84
+ self._inception_model.eval()
85
+ self._is_eps = 1e-16
86
+ self._rfid_eps = 1e-6
87
+
88
+ self.reset_metrics()
89
+
90
+ def reset_metrics(self):
91
+ """Resets all metrics."""
92
+ self._num_examples = 0
93
+ self._num_updates = 0
94
+
95
+ self._is_prob_total = torch.zeros(
96
+ self._is_num_features, dtype=torch.float64, device=self._device
97
+ )
98
+ self._is_total_kl_d = torch.zeros(
99
+ self._is_num_features, dtype=torch.float64, device=self._device
100
+ )
101
+ self._rfid_real_sigma = torch.zeros(
102
+ (self._rfid_num_features, self._rfid_num_features),
103
+ dtype=torch.float64, device=self._device
104
+ )
105
+ self._rfid_real_total = torch.zeros(
106
+ self._rfid_num_features, dtype=torch.float64, device=self._device
107
+ )
108
+ self._rfid_fake_sigma = torch.zeros(
109
+ (self._rfid_num_features, self._rfid_num_features),
110
+ dtype=torch.float64, device=self._device
111
+ )
112
+ self._rfid_fake_total = torch.zeros(
113
+ self._rfid_num_features, dtype=torch.float64, device=self._device
114
+ )
115
+
116
+ self._set_of_codebook_indices = set()
117
+ self._codebook_frequencies = torch.zeros((self._num_codebook_entries), dtype=torch.float64, device=self._device)
118
+
119
+ def update(
120
+ self,
121
+ real_images: torch.Tensor,
122
+ fake_images: torch.Tensor,
123
+ codebook_indices: Optional[torch.Tensor] = None
124
+ ):
125
+ """Updates the metrics with the given images.
126
+
127
+ Args:
128
+ real_images: A torch.Tensor, the real images.
129
+ fake_images: A torch.Tensor, the fake images.
130
+ codebook_indices: A torch.Tensor, the indices of the codebooks for each image.
131
+
132
+ Raises:
133
+ ValueError: If the fake images are not in RGB (3 channels).
134
+ ValueError: If the fake and real images have different shape.
135
+ """
136
+
137
+ batch_size = real_images.shape[0]
138
+ dim = tuple(range(1, real_images.ndim))
139
+ self._num_examples += batch_size
140
+ self._num_updates += 1
141
+
142
+ if self._enable_inception_score or self._enable_rfid:
143
+ # Quantize to uint8 as a real image.
144
+ fake_inception_images = (fake_images * 255).to(torch.uint8)
145
+ features_fake = self._inception_model(fake_inception_images)
146
+ inception_logits_fake = features_fake["logits_unbiased"]
147
+ inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1)
148
+
149
+ if self._enable_inception_score:
150
+ probabilities_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64)
151
+
152
+ log_prob = torch.log(inception_probabilities_fake + self._is_eps)
153
+ if log_prob.dtype != inception_probabilities_fake.dtype:
154
+ log_prob = log_prob.to(inception_probabilities_fake)
155
+ kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64)
156
+
157
+ self._is_prob_total += probabilities_sum
158
+ self._is_total_kl_d += kl_sum
159
+
160
+ if self._enable_rfid:
161
+ real_inception_images = (real_images * 255).to(torch.uint8)
162
+ features_real = self._inception_model(real_inception_images)
163
+ if (features_real['2048'].shape[0] != features_fake['2048'].shape[0] or
164
+ features_real['2048'].shape[1] != features_fake['2048'].shape[1]):
165
+ raise ValueError(f"Number of features should be equal for real and fake.")
166
+
167
+ for f_real, f_fake in zip(features_real['2048'], features_fake['2048']):
168
+ self._rfid_real_total += f_real
169
+ self._rfid_fake_total += f_fake
170
+
171
+ self._rfid_real_sigma += torch.outer(f_real, f_real)
172
+ self._rfid_fake_sigma += torch.outer(f_fake, f_fake)
173
+
174
+ if self._enable_codebook_usage_measure:
175
+ self._set_of_codebook_indices |= set(torch.unique(codebook_indices, sorted=False).tolist())
176
+
177
+ if self._enable_codebook_entropy_measure:
178
+ entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True)
179
+ self._codebook_frequencies.index_add_(0, entries.int(), counts.double())
180
+
181
+
182
+ def result(self) -> Mapping[Text, torch.Tensor]:
183
+ """Returns the evaluation result."""
184
+ eval_score = {}
185
+
186
+ if self._num_examples < 1:
187
+ raise ValueError("No examples to evaluate.")
188
+
189
+ if self._enable_inception_score:
190
+ mean_probs = self._is_prob_total / self._num_examples
191
+ log_mean_probs = torch.log(mean_probs + self._is_eps)
192
+ if log_mean_probs.dtype != self._is_prob_total.dtype:
193
+ log_mean_probs = log_mean_probs.to(self._is_prob_total)
194
+ excess_entropy = self._is_prob_total * log_mean_probs
195
+ avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples
196
+
197
+ inception_score = torch.exp(avg_kl_d).item()
198
+ eval_score["InceptionScore"] = inception_score
199
+
200
+ if self._enable_rfid:
201
+ mu_real = self._rfid_real_total / self._num_examples
202
+ mu_fake = self._rfid_fake_total / self._num_examples
203
+ sigma_real = get_covariance(self._rfid_real_sigma, self._rfid_real_total, self._num_examples)
204
+ sigma_fake = get_covariance(self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples)
205
+
206
+ mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
207
+ sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()
208
+
209
+ diff = mu_real - mu_fake
210
+
211
+ # Product might be almost singular.
212
+ covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)
213
+ # Numerical error might give slight imaginary component.
214
+ if np.iscomplexobj(covmean):
215
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
216
+ m = np.max(np.abs(covmean.imag))
217
+ raise ValueError("Imaginary component {}".format(m))
218
+ covmean = covmean.real
219
+
220
+ tr_covmean = np.trace(covmean)
221
+
222
+ if not np.isfinite(covmean).all():
223
+ tr_covmean = np.sum(np.sqrt((
224
+ (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
225
+ / (self._rfid_eps * self._rfid_eps)
226
+ ))
227
+
228
+ rfid = float(diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
229
+ - 2 * tr_covmean
230
+ )
231
+ if torch.isnan(torch.tensor(rfid)) or torch.isinf(torch.tensor(rfid)):
232
+ warnings.warn("The product of covariance of train and test features is out of bounds.")
233
+
234
+ eval_score["rFID"] = rfid
235
+
236
+ if self._enable_codebook_usage_measure:
237
+ usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
238
+ eval_score["CodebookUsage"] = usage
239
+
240
+ if self._enable_codebook_entropy_measure:
241
+ probs = self._codebook_frequencies / self._codebook_frequencies.sum()
242
+ entropy = (-torch.log2(probs + 1e-8) * probs).sum()
243
+ eval_score["CodebookEntropy"] = entropy
244
+
245
+ return eval_score
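A minimal usage sketch for VQGANEvaluator above; random tensors stand in for real data, and images are expected in [0, 1]:

import torch
from evaluator import VQGANEvaluator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evaluator = VQGANEvaluator(device, enable_rfid=False, enable_inception_score=True)

originals = torch.rand(8, 3, 256, 256, device=device)        # placeholder batch
reconstructions = torch.rand(8, 3, 256, 256, device=device)  # placeholder batch
evaluator.update(originals, reconstructions)
print(evaluator.result())  # {"InceptionScore": ...}

# For rFID, construct with enable_rfid=True and accumulate update() calls over the
# whole validation set before calling result(); the covariance estimate needs many samples.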
evaluator/inception.py ADDED
@@ -0,0 +1,231 @@
1
+ """This file is for Inception model borrowed from torch metrics / fidelity.
2
+
3
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5
+
6
+ Reference:
7
+ https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
8
+ """
9
+ # Copyright The Lightning team.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ import torch
23
+ import torch.nn.functional as F
24
+
25
+ from torch_fidelity.feature_extractor_base import FeatureExtractorBase
26
+ from torch_fidelity.helpers import vassert
27
+ from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
28
+ from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
29
+
30
+ try:
31
+ from torchvision.models.utils import load_state_dict_from_url
32
+ except ImportError:
33
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
34
+
35
+
36
+ # Note: Compared shasum and models should be the same.
37
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
38
+
39
+ class FeatureExtractorInceptionV3(FeatureExtractorBase):
40
+ INPUT_IMAGE_SIZE = 299
41
+
42
+ def __init__(
43
+ self,
44
+ name,
45
+ features_list,
46
+ **kwargs,
47
+ ):
48
+ """
49
+ InceptionV3 feature extractor for 2D RGB 24bit images.
50
+
51
+ Args:
52
+
53
+ name (str): Unique name of the feature extractor, must be the same as used in
54
+ :func:`register_feature_extractor`.
55
+
56
+ features_list (list): A list of the requested feature names, which will be produced for each input. This
57
+ feature extractor provides the following features:
58
+
59
+ - '64'
60
+ - '192'
61
+ - '768'
62
+ - '2048'
63
+ - 'logits_unbiased'
64
+ - 'logits'
65
+
66
+ """
67
+ super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
68
+ self.feature_extractor_internal_dtype = torch.float64
69
+
70
+ self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
71
+ self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
72
+ self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
73
+ self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
74
+
75
+ self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
76
+ self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
77
+ self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
78
+
79
+ self.Mixed_5b = InceptionA(192, pool_features=32)
80
+ self.Mixed_5c = InceptionA(256, pool_features=64)
81
+ self.Mixed_5d = InceptionA(288, pool_features=64)
82
+ self.Mixed_6a = InceptionB(288)
83
+ self.Mixed_6b = InceptionC(768, channels_7x7=128)
84
+ self.Mixed_6c = InceptionC(768, channels_7x7=160)
85
+ self.Mixed_6d = InceptionC(768, channels_7x7=160)
86
+ self.Mixed_6e = InceptionC(768, channels_7x7=192)
87
+
88
+ self.Mixed_7a = InceptionD(768)
89
+ self.Mixed_7b = InceptionE_1(1280)
90
+ self.Mixed_7c = InceptionE_2(2048)
91
+ self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
92
+
93
+ self.fc = torch.nn.Linear(2048, 1008)
94
+
95
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
96
+ #state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
97
+ self.load_state_dict(state_dict)
98
+
99
+ self.to(self.feature_extractor_internal_dtype)
100
+ self.requires_grad_(False)
101
+ self.eval()
102
+
103
+ def forward(self, x):
104
+ vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
105
+ vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')
106
+ features = {}
107
+ remaining_features = self.features_list.copy()
108
+
109
+ x = x.to(self.feature_extractor_internal_dtype)
110
+ # N x 3 x ? x ?
111
+
112
+ x = interpolate_bilinear_2d_like_tensorflow1x(
113
+ x,
114
+ size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
115
+ align_corners=False,
116
+ )
117
+ # N x 3 x 299 x 299
118
+
119
+ # x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # really happening in graph
120
+ x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too
121
+ # N x 3 x 299 x 299
122
+
123
+ x = self.Conv2d_1a_3x3(x)
124
+ # N x 32 x 149 x 149
125
+ x = self.Conv2d_2a_3x3(x)
126
+ # N x 32 x 147 x 147
127
+ x = self.Conv2d_2b_3x3(x)
128
+ # N x 64 x 147 x 147
129
+ x = self.MaxPool_1(x)
130
+ # N x 64 x 73 x 73
131
+
132
+ if '64' in remaining_features:
133
+ features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
134
+ remaining_features.remove('64')
135
+ if len(remaining_features) == 0:
136
+ return features
137
+
138
+ x = self.Conv2d_3b_1x1(x)
139
+ # N x 80 x 73 x 73
140
+ x = self.Conv2d_4a_3x3(x)
141
+ # N x 192 x 71 x 71
142
+ x = self.MaxPool_2(x)
143
+ # N x 192 x 35 x 35
144
+
145
+ if '192' in remaining_features:
146
+ features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
147
+ remaining_features.remove('192')
148
+ if len(remaining_features) == 0:
149
+ return features
150
+
151
+ x = self.Mixed_5b(x)
152
+ # N x 256 x 35 x 35
153
+ x = self.Mixed_5c(x)
154
+ # N x 288 x 35 x 35
155
+ x = self.Mixed_5d(x)
156
+ # N x 288 x 35 x 35
157
+ x = self.Mixed_6a(x)
158
+ # N x 768 x 17 x 17
159
+ x = self.Mixed_6b(x)
160
+ # N x 768 x 17 x 17
161
+ x = self.Mixed_6c(x)
162
+ # N x 768 x 17 x 17
163
+ x = self.Mixed_6d(x)
164
+ # N x 768 x 17 x 17
165
+ x = self.Mixed_6e(x)
166
+ # N x 768 x 17 x 17
167
+
168
+ if '768' in remaining_features:
169
+ features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
170
+ remaining_features.remove('768')
171
+ if len(remaining_features) == 0:
172
+ return features
173
+
174
+ x = self.Mixed_7a(x)
175
+ # N x 1280 x 8 x 8
176
+ x = self.Mixed_7b(x)
177
+ # N x 2048 x 8 x 8
178
+ x = self.Mixed_7c(x)
179
+ # N x 2048 x 8 x 8
180
+ x = self.AvgPool(x)
181
+ # N x 2048 x 1 x 1
182
+
183
+ x = torch.flatten(x, 1)
184
+ # N x 2048
185
+
186
+ if '2048' in remaining_features:
187
+ features['2048'] = x
188
+ remaining_features.remove('2048')
189
+ if len(remaining_features) == 0:
190
+ return features
191
+
192
+ if 'logits_unbiased' in remaining_features:
193
+ x = x.mm(self.fc.weight.T)
194
+ # N x 1008 (num_classes)
195
+ features['logits_unbiased'] = x
196
+ remaining_features.remove('logits_unbiased')
197
+ if len(remaining_features) == 0:
198
+ return features
199
+
200
+ x = x + self.fc.bias.unsqueeze(0)
201
+ else:
202
+ x = self.fc(x)
203
+ # N x 1008 (num_classes)
204
+
205
+ features['logits'] = x
206
+ return features
207
+
208
+ @staticmethod
209
+ def get_provided_features_list():
210
+ return '64', '192', '768', '2048', 'logits_unbiased', 'logits'
211
+
212
+ @staticmethod
213
+ def get_default_feature_layer_for_metric(metric):
214
+ return {
215
+ 'isc': 'logits_unbiased',
216
+ 'fid': '2048',
217
+ 'kid': '2048',
218
+ 'prc': '2048',
219
+ }[metric]
220
+
221
+ @staticmethod
222
+ def can_be_compiled():
223
+ return True
224
+
225
+ @staticmethod
226
+ def get_dummy_input_for_compile():
227
+ return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
228
+
229
+ def get_inception_model():
230
+ model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
231
+ return model
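A minimal sketch of querying the feature extractor above directly; input must be uint8 RGB in NCHW layout (it is resized to 299x299 internally):

import torch
from evaluator.inception import get_inception_model

model = get_inception_model()
images = (torch.rand(4, 3, 256, 256) * 255).to(torch.uint8)
with torch.no_grad():
    feats = model(images)
print(feats["2048"].shape, feats["logits_unbiased"].shape)  # [4, 2048] and [4, 1008]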
imagenet_classes.py ADDED
@@ -0,0 +1,1001 @@
1
+ imagenet_idx2classname = {
2
+ 0: 'tench, Tinca tinca',
3
+ 1: 'goldfish, Carassius auratus',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
5
+ 3: 'tiger shark, Galeocerdo cuvieri',
6
+ 4: 'hammerhead, hammerhead shark',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo',
8
+ 6: 'stingray',
9
+ 7: 'cock',
10
+ 8: 'hen',
11
+ 9: 'ostrich, Struthio camelus',
12
+ 10: 'brambling, Fringilla montifringilla',
13
+ 11: 'goldfinch, Carduelis carduelis',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus',
15
+ 13: 'junco, snowbird',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
17
+ 15: 'robin, American robin, Turdus migratorius',
18
+ 16: 'bulbul',
19
+ 17: 'jay',
20
+ 18: 'magpie',
21
+ 19: 'chickadee',
22
+ 20: 'water ouzel, dipper',
23
+ 21: 'kite',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
25
+ 23: 'vulture',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
27
+ 25: 'European fire salamander, Salamandra salamandra',
28
+ 26: 'common newt, Triturus vulgaris',
29
+ 27: 'eft',
30
+ 28: 'spotted salamander, Ambystoma maculatum',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
32
+ 30: 'bullfrog, Rana catesbeiana',
33
+ 31: 'tree frog, tree-frog',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
37
+ 35: 'mud turtle',
38
+ 36: 'terrapin',
39
+ 37: 'box turtle, box tortoise',
40
+ 38: 'banded gecko',
41
+ 39: 'common iguana, iguana, Iguana iguana',
42
+ 40: 'American chameleon, anole, Anolis carolinensis',
43
+ 41: 'whiptail, whiptail lizard',
44
+ 42: 'agama',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi',
46
+ 44: 'alligator lizard',
47
+ 45: 'Gila monster, Heloderma suspectum',
48
+ 46: 'green lizard, Lacerta viridis',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
52
+ 50: 'American alligator, Alligator mississipiensis',
53
+ 51: 'triceratops',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake',
56
+ 54: 'hognose snake, puff adder, sand viper',
57
+ 55: 'green snake, grass snake',
58
+ 56: 'king snake, kingsnake',
59
+ 57: 'garter snake, grass snake',
60
+ 58: 'water snake',
61
+ 59: 'vine snake',
62
+ 60: 'night snake, Hypsiglena torquata',
63
+ 61: 'boa constrictor, Constrictor constrictor',
64
+ 62: 'rock python, rock snake, Python sebae',
65
+ 63: 'Indian cobra, Naja naja',
66
+ 64: 'green mamba',
67
+ 65: 'sea snake',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
71
+ 69: 'trilobite',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
73
+ 71: 'scorpion',
74
+ 72: 'black and gold garden spider, Argiope aurantia',
75
+ 73: 'barn spider, Araneus cavaticus',
76
+ 74: 'garden spider, Aranea diademata',
77
+ 75: 'black widow, Latrodectus mactans',
78
+ 76: 'tarantula',
79
+ 77: 'wolf spider, hunting spider',
80
+ 78: 'tick',
81
+ 79: 'centipede',
82
+ 80: 'black grouse',
83
+ 81: 'ptarmigan',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
86
+ 84: 'peacock',
87
+ 85: 'quail',
88
+ 86: 'partridge',
89
+ 87: 'African grey, African gray, Psittacus erithacus',
90
+ 88: 'macaw',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
92
+ 90: 'lorikeet',
93
+ 91: 'coucal',
94
+ 92: 'bee eater',
95
+ 93: 'hornbill',
96
+ 94: 'hummingbird',
97
+ 95: 'jacamar',
98
+ 96: 'toucan',
99
+ 97: 'drake',
100
+ 98: 'red-breasted merganser, Mergus serrator',
101
+ 99: 'goose',
102
+ 100: 'black swan, Cygnus atratus',
103
+ 101: 'tusker',
104
+ 102: 'echidna, spiny anteater, anteater',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
106
+ 104: 'wallaby, brush kangaroo',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
108
+ 106: 'wombat',
109
+ 107: 'jellyfish',
110
+ 108: 'sea anemone, anemone',
111
+ 109: 'brain coral',
112
+ 110: 'flatworm, platyhelminth',
113
+ 111: 'nematode, nematode worm, roundworm',
114
+ 112: 'conch',
115
+ 113: 'snail',
116
+ 114: 'slug',
117
+ 115: 'sea slug, nudibranch',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
120
+ 118: 'Dungeness crab, Cancer magister',
121
+ 119: 'rock crab, Cancer irroratus',
122
+ 120: 'fiddler crab',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
127
+ 125: 'hermit crab',
128
+ 126: 'isopod',
129
+ 127: 'white stork, Ciconia ciconia',
130
+ 128: 'black stork, Ciconia nigra',
131
+ 129: 'spoonbill',
132
+ 130: 'flamingo',
133
+ 131: 'little blue heron, Egretta caerulea',
134
+ 132: 'American egret, great white heron, Egretta albus',
135
+ 133: 'bittern',
136
+ 134: 'crane',
137
+ 135: 'limpkin, Aramus pictus',
138
+ 136: 'European gallinule, Porphyrio porphyrio',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
140
+ 138: 'bustard',
141
+ 139: 'ruddy turnstone, Arenaria interpres',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
143
+ 141: 'redshank, Tringa totanus',
144
+ 142: 'dowitcher',
145
+ 143: 'oystercatcher, oyster catcher',
146
+ 144: 'pelican',
147
+ 145: 'king penguin, Aptenodytes patagonica',
148
+ 146: 'albatross, mollymawk',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
151
+ 149: 'dugong, Dugong dugon',
152
+ 150: 'sea lion',
153
+ 151: 'Chihuahua',
154
+ 152: 'Japanese spaniel',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese',
156
+ 154: 'Pekinese, Pekingese, Peke',
157
+ 155: 'Shih-Tzu',
158
+ 156: 'Blenheim spaniel',
159
+ 157: 'papillon',
160
+ 158: 'toy terrier',
161
+ 159: 'Rhodesian ridgeback',
162
+ 160: 'Afghan hound, Afghan',
163
+ 161: 'basset, basset hound',
164
+ 162: 'beagle',
165
+ 163: 'bloodhound, sleuthhound',
166
+ 164: 'bluetick',
167
+ 165: 'black-and-tan coonhound',
168
+ 166: 'Walker hound, Walker foxhound',
169
+ 167: 'English foxhound',
170
+ 168: 'redbone',
171
+ 169: 'borzoi, Russian wolfhound',
172
+ 170: 'Irish wolfhound',
173
+ 171: 'Italian greyhound',
174
+ 172: 'whippet',
175
+ 173: 'Ibizan hound, Ibizan Podenco',
176
+ 174: 'Norwegian elkhound, elkhound',
177
+ 175: 'otterhound, otter hound',
178
+ 176: 'Saluki, gazelle hound',
179
+ 177: 'Scottish deerhound, deerhound',
180
+ 178: 'Weimaraner',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
183
+ 181: 'Bedlington terrier',
184
+ 182: 'Border terrier',
185
+ 183: 'Kerry blue terrier',
186
+ 184: 'Irish terrier',
187
+ 185: 'Norfolk terrier',
188
+ 186: 'Norwich terrier',
189
+ 187: 'Yorkshire terrier',
190
+ 188: 'wire-haired fox terrier',
191
+ 189: 'Lakeland terrier',
192
+ 190: 'Sealyham terrier, Sealyham',
193
+ 191: 'Airedale, Airedale terrier',
194
+ 192: 'cairn, cairn terrier',
195
+ 193: 'Australian terrier',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
197
+ 195: 'Boston bull, Boston terrier',
198
+ 196: 'miniature schnauzer',
199
+ 197: 'giant schnauzer',
200
+ 198: 'standard schnauzer',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
202
+ 200: 'Tibetan terrier, chrysanthemum dog',
203
+ 201: 'silky terrier, Sydney silky',
204
+ 202: 'soft-coated wheaten terrier',
205
+ 203: 'West Highland white terrier',
206
+ 204: 'Lhasa, Lhasa apso',
207
+ 205: 'flat-coated retriever',
208
+ 206: 'curly-coated retriever',
209
+ 207: 'golden retriever',
210
+ 208: 'Labrador retriever',
211
+ 209: 'Chesapeake Bay retriever',
212
+ 210: 'German short-haired pointer',
213
+ 211: 'vizsla, Hungarian pointer',
214
+ 212: 'English setter',
215
+ 213: 'Irish setter, red setter',
216
+ 214: 'Gordon setter',
217
+ 215: 'Brittany spaniel',
218
+ 216: 'clumber, clumber spaniel',
219
+ 217: 'English springer, English springer spaniel',
220
+ 218: 'Welsh springer spaniel',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
222
+ 220: 'Sussex spaniel',
223
+ 221: 'Irish water spaniel',
224
+ 222: 'kuvasz',
225
+ 223: 'schipperke',
226
+ 224: 'groenendael',
227
+ 225: 'malinois',
228
+ 226: 'briard',
229
+ 227: 'kelpie',
230
+ 228: 'komondor',
231
+ 229: 'Old English sheepdog, bobtail',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
233
+ 231: 'collie',
234
+ 232: 'Border collie',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
236
+ 234: 'Rottweiler',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
238
+ 236: 'Doberman, Doberman pinscher',
239
+ 237: 'miniature pinscher',
240
+ 238: 'Greater Swiss Mountain dog',
241
+ 239: 'Bernese mountain dog',
242
+ 240: 'Appenzeller',
243
+ 241: 'EntleBucher',
244
+ 242: 'boxer',
245
+ 243: 'bull mastiff',
246
+ 244: 'Tibetan mastiff',
247
+ 245: 'French bulldog',
248
+ 246: 'Great Dane',
249
+ 247: 'Saint Bernard, St Bernard',
250
+ 248: 'Eskimo dog, husky',
251
+ 249: 'malamute, malemute, Alaskan malamute',
252
+ 250: 'Siberian husky',
253
+ 251: 'dalmatian, coach dog, carriage dog',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
255
+ 253: 'basenji',
256
+ 254: 'pug, pug-dog',
257
+ 255: 'Leonberg',
258
+ 256: 'Newfoundland, Newfoundland dog',
259
+ 257: 'Great Pyrenees',
260
+ 258: 'Samoyed, Samoyede',
261
+ 259: 'Pomeranian',
262
+ 260: 'chow, chow chow',
263
+ 261: 'keeshond',
264
+ 262: 'Brabancon griffon',
265
+ 263: 'Pembroke, Pembroke Welsh corgi',
266
+ 264: 'Cardigan, Cardigan Welsh corgi',
267
+ 265: 'toy poodle',
268
+ 266: 'miniature poodle',
269
+ 267: 'standard poodle',
270
+ 268: 'Mexican hairless',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo',
276
+ 274: 'dhole, Cuon alpinus',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
278
+ 276: 'hyena, hyaena',
279
+ 277: 'red fox, Vulpes vulpes',
280
+ 278: 'kit fox, Vulpes macrotis',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
283
+ 281: 'tabby, tabby cat',
284
+ 282: 'tiger cat',
285
+ 283: 'Persian cat',
286
+ 284: 'Siamese cat, Siamese',
287
+ 285: 'Egyptian cat',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
289
+ 287: 'lynx, catamount',
290
+ 288: 'leopard, Panthera pardus',
291
+ 289: 'snow leopard, ounce, Panthera uncia',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
293
+ 291: 'lion, king of beasts, Panthera leo',
294
+ 292: 'tiger, Panthera tigris',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus',
296
+ 294: 'brown bear, bruin, Ursus arctos',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
300
+ 298: 'mongoose',
301
+ 299: 'meerkat, mierkat',
302
+ 300: 'tiger beetle',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
304
+ 302: 'ground beetle, carabid beetle',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
306
+ 304: 'leaf beetle, chrysomelid',
307
+ 305: 'dung beetle',
308
+ 306: 'rhinoceros beetle',
309
+ 307: 'weevil',
310
+ 308: 'fly',
311
+ 309: 'bee',
312
+ 310: 'ant, emmet, pismire',
313
+ 311: 'grasshopper, hopper',
314
+ 312: 'cricket',
315
+ 313: 'walking stick, walkingstick, stick insect',
316
+ 314: 'cockroach, roach',
317
+ 315: 'mantis, mantid',
318
+ 316: 'cicada, cicala',
319
+ 317: 'leafhopper',
320
+ 318: 'lacewing, lacewing fly',
321
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ 320: 'damselfly',
323
+ 321: 'admiral',
324
+ 322: 'ringlet, ringlet butterfly',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
326
+ 324: 'cabbage butterfly',
327
+ 325: 'sulphur butterfly, sulfur butterfly',
328
+ 326: 'lycaenid, lycaenid butterfly',
329
+ 327: 'starfish, sea star',
330
+ 328: 'sea urchin',
331
+ 329: 'sea cucumber, holothurian',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
333
+ 331: 'hare',
334
+ 332: 'Angora, Angora rabbit',
335
+ 333: 'hamster',
336
+ 334: 'porcupine, hedgehog',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
338
+ 336: 'marmot',
339
+ 337: 'beaver',
340
+ 338: 'guinea pig, Cavia cobaya',
341
+ 339: 'sorrel',
342
+ 340: 'zebra',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
344
+ 342: 'wild boar, boar, Sus scrofa',
345
+ 343: 'warthog',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
347
+ 345: 'ox',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
349
+ 347: 'bison',
350
+ 348: 'ram, tup',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
352
+ 350: 'ibex, Capra ibex',
353
+ 351: 'hartebeest',
354
+ 352: 'impala, Aepyceros melampus',
355
+ 353: 'gazelle',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
357
+ 355: 'llama',
358
+ 356: 'weasel',
359
+ 357: 'mink',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
362
+ 360: 'otter',
363
+ 361: 'skunk, polecat, wood pussy',
364
+ 362: 'badger',
365
+ 363: 'armadillo',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
368
+ 366: 'gorilla, Gorilla gorilla',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes',
370
+ 368: 'gibbon, Hylobates lar',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
372
+ 370: 'guenon, guenon monkey',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas',
374
+ 372: 'baboon',
375
+ 373: 'macaque',
376
+ 374: 'langur',
377
+ 375: 'colobus, colobus monkey',
378
+ 376: 'proboscis monkey, Nasalis larvatus',
379
+ 377: 'marmoset',
380
+ 378: 'capuchin, ringtail, Cebus capucinus',
381
+ 379: 'howler monkey, howler',
382
+ 380: 'titi, titi monkey',
383
+ 381: 'spider monkey, Ateles geoffroyi',
384
+ 382: 'squirrel monkey, Saimiri sciureus',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
387
+ 385: 'Indian elephant, Elephas maximus',
388
+ 386: 'African elephant, Loxodonta africana',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
391
+ 389: 'barracouta, snoek',
392
+ 390: 'eel',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
394
+ 392: 'rock beauty, Holocanthus tricolor',
395
+ 393: 'anemone fish',
396
+ 394: 'sturgeon',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
398
+ 396: 'lionfish',
399
+ 397: 'puffer, pufferfish, blowfish, globefish',
400
+ 398: 'abacus',
401
+ 399: 'abaya',
402
+ 400: "academic gown, academic robe, judge's robe",
403
+ 401: 'accordion, piano accordion, squeeze box',
404
+ 402: 'acoustic guitar',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
406
+ 404: 'airliner',
407
+ 405: 'airship, dirigible',
408
+ 406: 'altar',
409
+ 407: 'ambulance',
410
+ 408: 'amphibian, amphibious vehicle',
411
+ 409: 'analog clock',
412
+ 410: 'apiary, bee house',
413
+ 411: 'apron',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
415
+ 413: 'assault rifle, assault gun',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
417
+ 415: 'bakery, bakeshop, bakehouse',
418
+ 416: 'balance beam, beam',
419
+ 417: 'balloon',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
421
+ 419: 'Band Aid',
422
+ 420: 'banjo',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail',
424
+ 422: 'barbell',
425
+ 423: 'barber chair',
426
+ 424: 'barbershop',
427
+ 425: 'barn',
428
+ 426: 'barometer',
429
+ 427: 'barrel, cask',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
431
+ 429: 'baseball',
432
+ 430: 'basketball',
433
+ 431: 'bassinet',
434
+ 432: 'bassoon',
435
+ 433: 'bathing cap, swimming cap',
436
+ 434: 'bath towel',
437
+ 435: 'bathtub, bathing tub, bath, tub',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
439
+ 437: 'beacon, lighthouse, beacon light, pharos',
440
+ 438: 'beaker',
441
+ 439: 'bearskin, busby, shako',
442
+ 440: 'beer bottle',
443
+ 441: 'beer glass',
444
+ 442: 'bell cote, bell cot',
445
+ 443: 'bib',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
447
+ 445: 'bikini, two-piece',
448
+ 446: 'binder, ring-binder',
449
+ 447: 'binoculars, field glasses, opera glasses',
450
+ 448: 'birdhouse',
451
+ 449: 'boathouse',
452
+ 450: 'bobsled, bobsleigh, bob',
453
+ 451: 'bolo tie, bolo, bola tie, bola',
454
+ 452: 'bonnet, poke bonnet',
455
+ 453: 'bookcase',
456
+ 454: 'bookshop, bookstore, bookstall',
457
+ 455: 'bottlecap',
458
+ 456: 'bow',
459
+ 457: 'bow tie, bow-tie, bowtie',
460
+ 458: 'brass, memorial tablet, plaque',
461
+ 459: 'brassiere, bra, bandeau',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
463
+ 461: 'breastplate, aegis, egis',
464
+ 462: 'broom',
465
+ 463: 'bucket, pail',
466
+ 464: 'buckle',
467
+ 465: 'bulletproof vest',
468
+ 466: 'bullet train, bullet',
469
+ 467: 'butcher shop, meat market',
470
+ 468: 'cab, hack, taxi, taxicab',
471
+ 469: 'caldron, cauldron',
472
+ 470: 'candle, taper, wax light',
473
+ 471: 'cannon',
474
+ 472: 'canoe',
475
+ 473: 'can opener, tin opener',
476
+ 474: 'cardigan',
477
+ 475: 'car mirror',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
479
+ 477: "carpenter's kit, tool kit",
480
+ 478: 'carton',
481
+ 479: 'car wheel',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
483
+ 481: 'cassette',
484
+ 482: 'cassette player',
485
+ 483: 'castle',
486
+ 484: 'catamaran',
487
+ 485: 'CD player',
488
+ 486: 'cello, violoncello',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
490
+ 488: 'chain',
491
+ 489: 'chainlink fence',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
493
+ 491: 'chain saw, chainsaw',
494
+ 492: 'chest',
495
+ 493: 'chiffonier, commode',
496
+ 494: 'chime, bell, gong',
497
+ 495: 'china cabinet, china closet',
498
+ 496: 'Christmas stocking',
499
+ 497: 'church, church building',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
501
+ 499: 'cleaver, meat cleaver, chopper',
502
+ 500: 'cliff dwelling',
503
+ 501: 'cloak',
504
+ 502: 'clog, geta, patten, sabot',
505
+ 503: 'cocktail shaker',
506
+ 504: 'coffee mug',
507
+ 505: 'coffeepot',
508
+ 506: 'coil, spiral, volute, whorl, helix',
509
+ 507: 'combination lock',
510
+ 508: 'computer keyboard, keypad',
511
+ 509: 'confectionery, confectionary, candy store',
512
+ 510: 'container ship, containership, container vessel',
513
+ 511: 'convertible',
514
+ 512: 'corkscrew, bottle screw',
515
+ 513: 'cornet, horn, trumpet, trump',
516
+ 514: 'cowboy boot',
517
+ 515: 'cowboy hat, ten-gallon hat',
518
+ 516: 'cradle',
519
+ 517: 'crane',
520
+ 518: 'crash helmet',
521
+ 519: 'crate',
522
+ 520: 'crib, cot',
523
+ 521: 'Crock Pot',
524
+ 522: 'croquet ball',
525
+ 523: 'crutch',
526
+ 524: 'cuirass',
527
+ 525: 'dam, dike, dyke',
528
+ 526: 'desk',
529
+ 527: 'desktop computer',
530
+ 528: 'dial telephone, dial phone',
531
+ 529: 'diaper, nappy, napkin',
532
+ 530: 'digital clock',
533
+ 531: 'digital watch',
534
+ 532: 'dining table, board',
535
+ 533: 'dishrag, dishcloth',
536
+ 534: 'dishwasher, dish washer, dishwashing machine',
537
+ 535: 'disk brake, disc brake',
538
+ 536: 'dock, dockage, docking facility',
539
+ 537: 'dogsled, dog sled, dog sleigh',
540
+ 538: 'dome',
541
+ 539: 'doormat, welcome mat',
542
+ 540: 'drilling platform, offshore rig',
543
+ 541: 'drum, membranophone, tympan',
544
+ 542: 'drumstick',
545
+ 543: 'dumbbell',
546
+ 544: 'Dutch oven',
547
+ 545: 'electric fan, blower',
548
+ 546: 'electric guitar',
549
+ 547: 'electric locomotive',
550
+ 548: 'entertainment center',
551
+ 549: 'envelope',
552
+ 550: 'espresso maker',
553
+ 551: 'face powder',
554
+ 552: 'feather boa, boa',
555
+ 553: 'file, file cabinet, filing cabinet',
556
+ 554: 'fireboat',
557
+ 555: 'fire engine, fire truck',
558
+ 556: 'fire screen, fireguard',
559
+ 557: 'flagpole, flagstaff',
560
+ 558: 'flute, transverse flute',
561
+ 559: 'folding chair',
562
+ 560: 'football helmet',
563
+ 561: 'forklift',
564
+ 562: 'fountain',
565
+ 563: 'fountain pen',
566
+ 564: 'four-poster',
567
+ 565: 'freight car',
568
+ 566: 'French horn, horn',
569
+ 567: 'frying pan, frypan, skillet',
570
+ 568: 'fur coat',
571
+ 569: 'garbage truck, dustcart',
572
+ 570: 'gasmask, respirator, gas helmet',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
574
+ 572: 'goblet',
575
+ 573: 'go-kart',
576
+ 574: 'golf ball',
577
+ 575: 'golfcart, golf cart',
578
+ 576: 'gondola',
579
+ 577: 'gong, tam-tam',
580
+ 578: 'gown',
581
+ 579: 'grand piano, grand',
582
+ 580: 'greenhouse, nursery, glasshouse',
583
+ 581: 'grille, radiator grille',
584
+ 582: 'grocery store, grocery, food market, market',
585
+ 583: 'guillotine',
586
+ 584: 'hair slide',
587
+ 585: 'hair spray',
588
+ 586: 'half track',
589
+ 587: 'hammer',
590
+ 588: 'hamper',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
592
+ 590: 'hand-held computer, hand-held microcomputer',
593
+ 591: 'handkerchief, hankie, hanky, hankey',
594
+ 592: 'hard disc, hard disk, fixed disk',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp',
596
+ 594: 'harp',
597
+ 595: 'harvester, reaper',
598
+ 596: 'hatchet',
599
+ 597: 'holster',
600
+ 598: 'home theater, home theatre',
601
+ 599: 'honeycomb',
602
+ 600: 'hook, claw',
603
+ 601: 'hoopskirt, crinoline',
604
+ 602: 'horizontal bar, high bar',
605
+ 603: 'horse cart, horse-cart',
606
+ 604: 'hourglass',
607
+ 605: 'iPod',
608
+ 606: 'iron, smoothing iron',
609
+ 607: "jack-o'-lantern",
610
+ 608: 'jean, blue jean, denim',
611
+ 609: 'jeep, landrover',
612
+ 610: 'jersey, T-shirt, tee shirt',
613
+ 611: 'jigsaw puzzle',
614
+ 612: 'jinrikisha, ricksha, rickshaw',
615
+ 613: 'joystick',
616
+ 614: 'kimono',
617
+ 615: 'knee pad',
618
+ 616: 'knot',
619
+ 617: 'lab coat, laboratory coat',
620
+ 618: 'ladle',
621
+ 619: 'lampshade, lamp shade',
622
+ 620: 'laptop, laptop computer',
623
+ 621: 'lawn mower, mower',
624
+ 622: 'lens cap, lens cover',
625
+ 623: 'letter opener, paper knife, paperknife',
626
+ 624: 'library',
627
+ 625: 'lifeboat',
628
+ 626: 'lighter, light, igniter, ignitor',
629
+ 627: 'limousine, limo',
630
+ 628: 'liner, ocean liner',
631
+ 629: 'lipstick, lip rouge',
632
+ 630: 'Loafer',
633
+ 631: 'lotion',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
635
+ 633: "loupe, jeweler's loupe",
636
+ 634: 'lumbermill, sawmill',
637
+ 635: 'magnetic compass',
638
+ 636: 'mailbag, postbag',
639
+ 637: 'mailbox, letter box',
640
+ 638: 'maillot',
641
+ 639: 'maillot, tank suit',
642
+ 640: 'manhole cover',
643
+ 641: 'maraca',
644
+ 642: 'marimba, xylophone',
645
+ 643: 'mask',
646
+ 644: 'matchstick',
647
+ 645: 'maypole',
648
+ 646: 'maze, labyrinth',
649
+ 647: 'measuring cup',
650
+ 648: 'medicine chest, medicine cabinet',
651
+ 649: 'megalith, megalithic structure',
652
+ 650: 'microphone, mike',
653
+ 651: 'microwave, microwave oven',
654
+ 652: 'military uniform',
655
+ 653: 'milk can',
656
+ 654: 'minibus',
657
+ 655: 'miniskirt, mini',
658
+ 656: 'minivan',
659
+ 657: 'missile',
660
+ 658: 'mitten',
661
+ 659: 'mixing bowl',
662
+ 660: 'mobile home, manufactured home',
663
+ 661: 'Model T',
664
+ 662: 'modem',
665
+ 663: 'monastery',
666
+ 664: 'monitor',
667
+ 665: 'moped',
668
+ 666: 'mortar',
669
+ 667: 'mortarboard',
670
+ 668: 'mosque',
671
+ 669: 'mosquito net',
672
+ 670: 'motor scooter, scooter',
673
+ 671: 'mountain bike, all-terrain bike, off-roader',
674
+ 672: 'mountain tent',
675
+ 673: 'mouse, computer mouse',
676
+ 674: 'mousetrap',
677
+ 675: 'moving van',
678
+ 676: 'muzzle',
679
+ 677: 'nail',
680
+ 678: 'neck brace',
681
+ 679: 'necklace',
682
+ 680: 'nipple',
683
+ 681: 'notebook, notebook computer',
684
+ 682: 'obelisk',
685
+ 683: 'oboe, hautboy, hautbois',
686
+ 684: 'ocarina, sweet potato',
687
+ 685: 'odometer, hodometer, mileometer, milometer',
688
+ 686: 'oil filter',
689
+ 687: 'organ, pipe organ',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
691
+ 689: 'overskirt',
692
+ 690: 'oxcart',
693
+ 691: 'oxygen mask',
694
+ 692: 'packet',
695
+ 693: 'paddle, boat paddle',
696
+ 694: 'paddlewheel, paddle wheel',
697
+ 695: 'padlock',
698
+ 696: 'paintbrush',
699
+ 697: "pajama, pyjama, pj's, jammies",
700
+ 698: 'palace',
701
+ 699: 'panpipe, pandean pipe, syrinx',
702
+ 700: 'paper towel',
703
+ 701: 'parachute, chute',
704
+ 702: 'parallel bars, bars',
705
+ 703: 'park bench',
706
+ 704: 'parking meter',
707
+ 705: 'passenger car, coach, carriage',
708
+ 706: 'patio, terrace',
709
+ 707: 'pay-phone, pay-station',
710
+ 708: 'pedestal, plinth, footstall',
711
+ 709: 'pencil box, pencil case',
712
+ 710: 'pencil sharpener',
713
+ 711: 'perfume, essence',
714
+ 712: 'Petri dish',
715
+ 713: 'photocopier',
716
+ 714: 'pick, plectrum, plectron',
717
+ 715: 'pickelhaube',
718
+ 716: 'picket fence, paling',
719
+ 717: 'pickup, pickup truck',
720
+ 718: 'pier',
721
+ 719: 'piggy bank, penny bank',
722
+ 720: 'pill bottle',
723
+ 721: 'pillow',
724
+ 722: 'ping-pong ball',
725
+ 723: 'pinwheel',
726
+ 724: 'pirate, pirate ship',
727
+ 725: 'pitcher, ewer',
728
+ 726: "plane, carpenter's plane, woodworking plane",
729
+ 727: 'planetarium',
730
+ 728: 'plastic bag',
731
+ 729: 'plate rack',
732
+ 730: 'plow, plough',
733
+ 731: "plunger, plumber's helper",
734
+ 732: 'Polaroid camera, Polaroid Land camera',
735
+ 733: 'pole',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
737
+ 735: 'poncho',
738
+ 736: 'pool table, billiard table, snooker table',
739
+ 737: 'pop bottle, soda bottle',
740
+ 738: 'pot, flowerpot',
741
+ 739: "potter's wheel",
742
+ 740: 'power drill',
743
+ 741: 'prayer rug, prayer mat',
744
+ 742: 'printer',
745
+ 743: 'prison, prison house',
746
+ 744: 'projectile, missile',
747
+ 745: 'projector',
748
+ 746: 'puck, hockey puck',
749
+ 747: 'punching bag, punch bag, punching ball, punchball',
750
+ 748: 'purse',
751
+ 749: 'quill, quill pen',
752
+ 750: 'quilt, comforter, comfort, puff',
753
+ 751: 'racer, race car, racing car',
754
+ 752: 'racket, racquet',
755
+ 753: 'radiator',
756
+ 754: 'radio, wireless',
757
+ 755: 'radio telescope, radio reflector',
758
+ 756: 'rain barrel',
759
+ 757: 'recreational vehicle, RV, R.V.',
760
+ 758: 'reel',
761
+ 759: 'reflex camera',
762
+ 760: 'refrigerator, icebox',
763
+ 761: 'remote control, remote',
764
+ 762: 'restaurant, eating house, eating place, eatery',
765
+ 763: 'revolver, six-gun, six-shooter',
766
+ 764: 'rifle',
767
+ 765: 'rocking chair, rocker',
768
+ 766: 'rotisserie',
769
+ 767: 'rubber eraser, rubber, pencil eraser',
770
+ 768: 'rugby ball',
771
+ 769: 'rule, ruler',
772
+ 770: 'running shoe',
773
+ 771: 'safe',
774
+ 772: 'safety pin',
775
+ 773: 'saltshaker, salt shaker',
776
+ 774: 'sandal',
777
+ 775: 'sarong',
778
+ 776: 'sax, saxophone',
779
+ 777: 'scabbard',
780
+ 778: 'scale, weighing machine',
781
+ 779: 'school bus',
782
+ 780: 'schooner',
783
+ 781: 'scoreboard',
784
+ 782: 'screen, CRT screen',
785
+ 783: 'screw',
786
+ 784: 'screwdriver',
787
+ 785: 'seat belt, seatbelt',
788
+ 786: 'sewing machine',
789
+ 787: 'shield, buckler',
790
+ 788: 'shoe shop, shoe-shop, shoe store',
791
+ 789: 'shoji',
792
+ 790: 'shopping basket',
793
+ 791: 'shopping cart',
794
+ 792: 'shovel',
795
+ 793: 'shower cap',
796
+ 794: 'shower curtain',
797
+ 795: 'ski',
798
+ 796: 'ski mask',
799
+ 797: 'sleeping bag',
800
+ 798: 'slide rule, slipstick',
801
+ 799: 'sliding door',
802
+ 800: 'slot, one-armed bandit',
803
+ 801: 'snorkel',
804
+ 802: 'snowmobile',
805
+ 803: 'snowplow, snowplough',
806
+ 804: 'soap dispenser',
807
+ 805: 'soccer ball',
808
+ 806: 'sock',
809
+ 807: 'solar dish, solar collector, solar furnace',
810
+ 808: 'sombrero',
811
+ 809: 'soup bowl',
812
+ 810: 'space bar',
813
+ 811: 'space heater',
814
+ 812: 'space shuttle',
815
+ 813: 'spatula',
816
+ 814: 'speedboat',
817
+ 815: "spider web, spider's web",
818
+ 816: 'spindle',
819
+ 817: 'sports car, sport car',
820
+ 818: 'spotlight, spot',
821
+ 819: 'stage',
822
+ 820: 'steam locomotive',
823
+ 821: 'steel arch bridge',
824
+ 822: 'steel drum',
825
+ 823: 'stethoscope',
826
+ 824: 'stole',
827
+ 825: 'stone wall',
828
+ 826: 'stopwatch, stop watch',
829
+ 827: 'stove',
830
+ 828: 'strainer',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
832
+ 830: 'stretcher',
833
+ 831: 'studio couch, day bed',
834
+ 832: 'stupa, tope',
835
+ 833: 'submarine, pigboat, sub, U-boat',
836
+ 834: 'suit, suit of clothes',
837
+ 835: 'sundial',
838
+ 836: 'sunglass',
839
+ 837: 'sunglasses, dark glasses, shades',
840
+ 838: 'sunscreen, sunblock, sun blocker',
841
+ 839: 'suspension bridge',
842
+ 840: 'swab, swob, mop',
843
+ 841: 'sweatshirt',
844
+ 842: 'swimming trunks, bathing trunks',
845
+ 843: 'swing',
846
+ 844: 'switch, electric switch, electrical switch',
847
+ 845: 'syringe',
848
+ 846: 'table lamp',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
850
+ 848: 'tape player',
851
+ 849: 'teapot',
852
+ 850: 'teddy, teddy bear',
853
+ 851: 'television, television system',
854
+ 852: 'tennis ball',
855
+ 853: 'thatch, thatched roof',
856
+ 854: 'theater curtain, theatre curtain',
857
+ 855: 'thimble',
858
+ 856: 'thresher, thrasher, threshing machine',
859
+ 857: 'throne',
860
+ 858: 'tile roof',
861
+ 859: 'toaster',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
863
+ 861: 'toilet seat',
864
+ 862: 'torch',
865
+ 863: 'totem pole',
866
+ 864: 'tow truck, tow car, wrecker',
867
+ 865: 'toyshop',
868
+ 866: 'tractor',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
870
+ 868: 'tray',
871
+ 869: 'trench coat',
872
+ 870: 'tricycle, trike, velocipede',
873
+ 871: 'trimaran',
874
+ 872: 'tripod',
875
+ 873: 'triumphal arch',
876
+ 874: 'trolleybus, trolley coach, trackless trolley',
877
+ 875: 'trombone',
878
+ 876: 'tub, vat',
879
+ 877: 'turnstile',
880
+ 878: 'typewriter keyboard',
881
+ 879: 'umbrella',
882
+ 880: 'unicycle, monocycle',
883
+ 881: 'upright, upright piano',
884
+ 882: 'vacuum, vacuum cleaner',
885
+ 883: 'vase',
886
+ 884: 'vault',
887
+ 885: 'velvet',
888
+ 886: 'vending machine',
889
+ 887: 'vestment',
890
+ 888: 'viaduct',
891
+ 889: 'violin, fiddle',
892
+ 890: 'volleyball',
893
+ 891: 'waffle iron',
894
+ 892: 'wall clock',
895
+ 893: 'wallet, billfold, notecase, pocketbook',
896
+ 894: 'wardrobe, closet, press',
897
+ 895: 'warplane, military plane',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
899
+ 897: 'washer, automatic washer, washing machine',
900
+ 898: 'water bottle',
901
+ 899: 'water jug',
902
+ 900: 'water tower',
903
+ 901: 'whiskey jug',
904
+ 902: 'whistle',
905
+ 903: 'wig',
906
+ 904: 'window screen',
907
+ 905: 'window shade',
908
+ 906: 'Windsor tie',
909
+ 907: 'wine bottle',
910
+ 908: 'wing',
911
+ 909: 'wok',
912
+ 910: 'wooden spoon',
913
+ 911: 'wool, woolen, woollen',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
915
+ 913: 'wreck',
916
+ 914: 'yawl',
917
+ 915: 'yurt',
918
+ 916: 'web site, website, internet site, site',
919
+ 917: 'comic book',
920
+ 918: 'crossword puzzle, crossword',
921
+ 919: 'street sign',
922
+ 920: 'traffic light, traffic signal, stoplight',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
924
+ 922: 'menu',
925
+ 923: 'plate',
926
+ 924: 'guacamole',
927
+ 925: 'consomme',
928
+ 926: 'hot pot, hotpot',
929
+ 927: 'trifle',
930
+ 928: 'ice cream, icecream',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle',
932
+ 930: 'French loaf',
933
+ 931: 'bagel, beigel',
934
+ 932: 'pretzel',
935
+ 933: 'cheeseburger',
936
+ 934: 'hotdog, hot dog, red hot',
937
+ 935: 'mashed potato',
938
+ 936: 'head cabbage',
939
+ 937: 'broccoli',
940
+ 938: 'cauliflower',
941
+ 939: 'zucchini, courgette',
942
+ 940: 'spaghetti squash',
943
+ 941: 'acorn squash',
944
+ 942: 'butternut squash',
945
+ 943: 'cucumber, cuke',
946
+ 944: 'artichoke, globe artichoke',
947
+ 945: 'bell pepper',
948
+ 946: 'cardoon',
949
+ 947: 'mushroom',
950
+ 948: 'Granny Smith',
951
+ 949: 'strawberry',
952
+ 950: 'orange',
953
+ 951: 'lemon',
954
+ 952: 'fig',
955
+ 953: 'pineapple, ananas',
956
+ 954: 'banana',
957
+ 955: 'jackfruit, jak, jack',
958
+ 956: 'custard apple',
959
+ 957: 'pomegranate',
960
+ 958: 'hay',
961
+ 959: 'carbonara',
962
+ 960: 'chocolate sauce, chocolate syrup',
963
+ 961: 'dough',
964
+ 962: 'meat loaf, meatloaf',
965
+ 963: 'pizza, pizza pie',
966
+ 964: 'potpie',
967
+ 965: 'burrito',
968
+ 966: 'red wine',
969
+ 967: 'espresso',
970
+ 968: 'cup',
971
+ 969: 'eggnog',
972
+ 970: 'alp',
973
+ 971: 'bubble',
974
+ 972: 'cliff, drop, drop-off',
975
+ 973: 'coral reef',
976
+ 974: 'geyser',
977
+ 975: 'lakeside, lakeshore',
978
+ 976: 'promontory, headland, head, foreland',
979
+ 977: 'sandbar, sand bar',
980
+ 978: 'seashore, coast, seacoast, sea-coast',
981
+ 979: 'valley, vale',
982
+ 980: 'volcano',
983
+ 981: 'ballplayer, baseball player',
984
+ 982: 'groom, bridegroom',
985
+ 983: 'scuba diver',
986
+ 984: 'rapeseed',
987
+ 985: 'daisy',
988
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ 987: 'corn',
990
+ 988: 'acorn',
991
+ 989: 'hip, rose hip, rosehip',
992
+ 990: 'buckeye, horse chestnut, conker',
993
+ 991: 'coral fungus',
994
+ 992: 'agaric',
995
+ 993: 'gyromitra',
996
+ 994: 'stinkhorn, carrion fungus',
997
+ 995: 'earthstar',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
999
+ 997: 'bolete',
1000
+ 998: 'ear, spike, capitulum',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
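The dictionary above maps every ImageNet-1k class index to its comma-separated synonym string, which is handy for labeling class-conditional samples. A minimal usage sketch, assuming the dictionary is exported from `imagenet_classes.py` under a name such as `imagenet_idx2classname` (the actual variable name is defined at the top of the file, outside this excerpt):

from imagenet_classes import imagenet_idx2classname  # hypothetical export name

def class_name(idx: int) -> str:
    # Keep only the first synonym for a compact label, e.g. 980 -> 'volcano'.
    return imagenet_idx2classname[idx].split(',')[0]

print(class_name(980))  # 'volcano'
print(class_name(151))  # 'Chihuahua'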
modeling/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """
2
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
modeling/maskgit.py ADDED
@@ -0,0 +1,282 @@
1
+ """This file contains implementation for MaskGIT model.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ Reference:
18
+ https://github.com/huggingface/open-muse
19
+ https://github.com/baaivision/MUSE-Pytorch
20
+ https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py
21
+ """
22
+
23
+ import torch
24
+ from torch import nn
25
+ import numpy as np
26
+ import math
27
+ import torch.utils.checkpoint
28
+ from transformers import BertConfig, BertModel
29
+ from einops import rearrange
30
+
31
+ import json
32
+ from huggingface_hub import PyTorchModelHubMixin
33
+ from omegaconf import OmegaConf
34
+ from pathlib import Path
35
+
36
+ from modeling.modules.base_model import BaseModel
37
+ from modeling.modules.blocks import UViTBlock
38
+
39
+
40
+ class ImageBert(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-generation"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"):
41
+ def __init__(self, config):
42
+
43
+ if isinstance(config, dict):
44
+ config = OmegaConf.create(config)
45
+
46
+ super().__init__()
47
+ self.config = config
48
+ self.target_codebook_size = config.model.vq_model.codebook_size
49
+ self.condition_num_classes = config.model.generator.condition_num_classes
50
+ self.image_seq_len = config.model.generator.image_seq_len
51
+ self.mask_token_id = self.target_codebook_size
52
+ self.hidden_size = config.model.generator.hidden_size
53
+ self.num_hidden_layers = config.model.generator.num_hidden_layers
54
+ self.num_attention_heads = config.model.generator.num_attention_heads
55
+ self.intermediate_size = config.model.generator.intermediate_size
56
+
57
+ self.model = BertModel(BertConfig(
58
+ vocab_size=self.target_codebook_size + self.condition_num_classes + 2,
59
+ hidden_size=self.hidden_size,
60
+ num_hidden_layers=self.num_hidden_layers,
61
+ num_attention_heads=self.num_attention_heads,
62
+ intermediate_size=self.intermediate_size,
63
+ hidden_act='gelu',
64
+ hidden_dropout_prob=config.model.generator.dropout,
65
+ attention_probs_dropout_prob=config.model.generator.attn_drop,
66
+ max_position_embeddings=config.model.generator.image_seq_len + 1,
67
+ initializer_range=0.02,
68
+ layer_norm_eps=1e-12,
69
+ pad_token_id=None,
70
+ position_embedding_type="absolute",
71
+ use_cache=True
72
+ ), add_pooling_layer=False)
73
+ self.model.lm_head = nn.Linear(self.hidden_size, self.target_codebook_size, bias=True)
74
+
75
+ self.model.post_init()
76
+
77
+ def _save_pretrained(self, save_directory: Path) -> None:
78
+ """Save weights and config to a local directory."""
79
+ # Assume 'self.config' is your DictConfig object
80
+ # Convert to a regular dictionary
81
+ dict_config = OmegaConf.to_container(self.config)
82
+ # Save as JSON
83
+ file_path = Path(save_directory) / "config.json"
84
+ with open(file_path, 'w') as json_file:
85
+ json.dump(dict_config, json_file, indent=4)
86
+ super()._save_pretrained(save_directory)
87
+
88
+ def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
89
+ # Token space:
90
+ # [0, codebook_size - 1] : those are the learned quantized image tokens
91
+ # codebook_size : the mask token used to mask image tokens
92
+ # [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
93
+ # codebook_size + 1 + nclass : the class drop label
94
+ drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
95
+ # Shift the classes
96
+ condition = condition + self.target_codebook_size + 1 # [0, 999] -> [codebook_size + 1, codebook_size + 999]
97
+ condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
98
+ # prepend condition token
99
+ if input_ids is not None:
100
+ input_ids = torch.cat([condition.view(condition.shape[0], -1),
101
+ input_ids.view(input_ids.shape[0], -1),], dim=1)
102
+ else:
103
+ # at least there should be masked token
104
+ raise NotImplementedError
105
+ model_output = self.model(input_ids=input_ids)
106
+ model_output = model_output[0]
107
+ return self.model.lm_head(model_output[:, 1:]) # remove cond
108
+
109
+ # ref: https://github.com/baaivision/MUSE-Pytorch/blob/master/libs/muse.py#L40
110
+ @torch.no_grad()
111
+ def generate(self,
112
+ condition,
113
+ guidance_scale=3.0,
114
+ guidance_decay="constant",
115
+ guidance_scale_pow=3.0,
116
+ randomize_temperature=4.5,
117
+ softmax_temperature_annealing=False,
118
+ num_sample_steps=8):
119
+ if guidance_decay not in ["constant", "linear", "power-cosine"]:
120
+ # constant: constant guidance scale
121
+ # linear: linear increasing the guidance scale as in MUSE
122
+ # power-cosine: the guidance schedule from MDT
123
+ raise ValueError(f"Unsupported guidance decay {guidance_decay}")
124
+ device = condition.device
125
+ ids = torch.full((condition.shape[0], self.image_seq_len),
126
+ self.mask_token_id, device=device)
127
+
128
+ cfg_scale = guidance_scale if guidance_decay == "constant" else 0.
129
+
130
+ for step in range(num_sample_steps):
131
+ ratio = 1. * (step + 1) / num_sample_steps
132
+ annealed_temp = randomize_temperature * (1.0 - ratio)
133
+ is_mask = (ids == self.mask_token_id)
134
+
135
+ if guidance_decay == "power-cosine":
136
+ # ref: https://github.com/sail-sg/MDT/blob/main/masked_diffusion/models.py#L501
137
+ guidance_scale_pow = torch.ones((1), device=device) * guidance_scale_pow
138
+ scale_step = (1 - torch.cos(((step / num_sample_steps) ** guidance_scale_pow) * torch.pi)) * 1/2
139
+ cfg_scale = (guidance_scale - 1) * scale_step + 1
140
+
141
+ if cfg_scale != 0:
142
+ cond_logits = self.forward(
143
+ ids, condition, cond_drop_prob=0.0
144
+ )
145
+ uncond_logits = self.forward(
146
+ ids, condition, cond_drop_prob=1.0
147
+ )
148
+ if guidance_decay == "power-cosine":
149
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
150
+ else:
151
+ logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale
152
+ else:
153
+ logits = self.forward(
154
+ ids, condition, cond_drop_prob=0.0
155
+ )
156
+
157
+ if softmax_temperature_annealing:
158
+ softmax_temperature = 0.5 + 0.8 * (1 - ratio)
159
+ logits = logits / softmax_temperature
160
+
161
+ # Add gumbel noise
162
+ def log(t, eps=1e-20):
163
+ return torch.log(t.clamp(min=eps))
164
+ def gumbel_noise(t):
165
+ noise = torch.zeros_like(t).uniform_(0, 1)
166
+ return -log(-log(noise))
167
+ def add_gumbel_noise(t, temperature):
168
+ return t + temperature * gumbel_noise(t)
169
+
170
+ sampled_ids = add_gumbel_noise(logits, annealed_temp).argmax(dim=-1)
171
+ sampled_logits = torch.squeeze(
172
+ torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
173
+ sampled_ids = torch.where(is_mask, sampled_ids, ids)
174
+ sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float()
175
+ # masking
176
+ mask_ratio = np.arccos(ratio) / (math.pi * 0.5)
177
+
178
+ mask_len = torch.Tensor([np.floor(self.image_seq_len * mask_ratio)]).to(device)
179
+ mask_len = torch.maximum(torch.Tensor([1]).to(device),
180
+ torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1,
181
+ mask_len))[0].squeeze()
182
+ confidence = add_gumbel_noise(sampled_logits, annealed_temp)
183
+ sorted_confidence, _ = torch.sort(confidence, axis=-1)
184
+ cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
185
+ masking = (confidence <= cut_off)
186
+ if step == num_sample_steps - 1:
187
+ ids = sampled_ids
188
+ else:
189
+ ids = torch.where(masking, self.mask_token_id, sampled_ids)
190
+
191
+ if guidance_decay == "linear":
192
+ cfg_scale = ratio * guidance_scale
193
+ return ids
194
+
195
+ def masking_input_tokens(self, input_tokens):
196
+ batch_size, seq_len = input_tokens.shape
197
+ device = input_tokens.device
198
+
199
+ timesteps = torch.zeros((batch_size,), device=device).float().uniform_(0, 1.0)
200
+ mask_ratio = torch.acos(timesteps) / (math.pi * 0.5) # arccos schedule
201
+ mask_ratio = torch.clamp(mask_ratio, min=1e-6, max=1.)
202
+ num_token_masked = (seq_len * mask_ratio).round().clamp(min=1)
203
+ batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1)
204
+ masks = batch_randperm < rearrange(num_token_masked, 'b -> b 1')
205
+ masked_tokens = torch.where(masks, self.mask_token_id, input_tokens)
206
+ return masked_tokens, masks
207
+
208
+
209
+ class UViTBert(ImageBert):
210
+ def __init__(self, config):
211
+ super().__init__(config=config)
212
+
213
+ del self.model
214
+
215
+ self.embeddings = nn.Embedding(
216
+ self.target_codebook_size + self.condition_num_classes + 2,
217
+ self.hidden_size)
218
+
219
+ self.pos_embed = nn.init.trunc_normal_(
220
+ nn.Parameter(torch.zeros(1, self.config.model.generator.image_seq_len + 1, self.hidden_size)), 0., 0.02)
221
+
222
+ self.in_blocks = nn.ModuleList([
223
+ UViTBlock(
224
+ dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
225
+ qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)
226
+ for _ in range(self.num_hidden_layers // 2)])
227
+
228
+ self.mid_block = UViTBlock(
229
+ dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
230
+ qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, use_checkpoint=False)
231
+
232
+ self.out_blocks = nn.ModuleList([
233
+ UViTBlock(
234
+ dim=self.hidden_size, num_heads=self.num_attention_heads, mlp_ratio=(self.intermediate_size / self.hidden_size),
235
+ qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, skip=True, use_checkpoint=False)
236
+ for _ in range(self.num_hidden_layers // 2)])
237
+
238
+ self.norm = nn.LayerNorm(self.hidden_size)
239
+ self.lm_head = nn.Linear(self.hidden_size,
240
+ self.target_codebook_size, bias=True)
241
+ self.apply(self._init_weights)
242
+
243
+ def _init_weights(self, m):
244
+ if isinstance(m, nn.Linear):
245
+ nn.init.trunc_normal_(m.weight, std=.02)
246
+ if isinstance(m, nn.Linear) and m.bias is not None:
247
+ nn.init.constant_(m.bias, 0)
248
+ elif isinstance(m, nn.Embedding):
249
+ m.weight.data = nn.init.trunc_normal_(m.weight.data, mean=0.0, std=0.02)
250
+ elif isinstance(m, nn.LayerNorm):
251
+ nn.init.constant_(m.bias, 0)
252
+ nn.init.constant_(m.weight, 1.0)
253
+
254
+ def forward(self, input_ids=None, condition=None, cond_drop_prob=0.1):
255
+ # Token space:
256
+ # [0, codebook_size - 1] : those are the learned quantized image tokens
257
+ # codebook_size : the mask token used to mask image tokens
258
+ # [codebook_size + 1, codebook_size + nclass] : the imagenet class tokens
259
+ # codebook_size + 1 + nclass : the class drop label
260
+ drop_label_mask = torch.rand_like(condition, dtype=torch.float) < cond_drop_prob
261
+ # Shift the classes
262
+ condition = condition + self.target_codebook_size + 1 # [0, 999] -> [codebook_size + 1, codebook_size + 999]
263
+ condition[drop_label_mask] = self.condition_num_classes + self.target_codebook_size + 1
264
+ # prepend condition token
265
+ if input_ids is not None:
266
+ input_ids = torch.cat([condition.view(condition.shape[0], -1),
267
+ input_ids.view(input_ids.shape[0], -1),], dim=1)
268
+ else:
269
+ # at least there should be masked token
270
+ raise NotImplementedError
271
+ # UViT forward
272
+ embeddings = self.embeddings(input_ids)
273
+ x = embeddings + self.pos_embed[:, :embeddings.shape[1]]
274
+ skips = []
275
+ for blk in self.in_blocks:
276
+ x = blk(x)
277
+ skips.append(x)
278
+ x = self.mid_block(x)
279
+ for blk in self.out_blocks:
280
+ x = blk(x, skips.pop())
281
+ x = self.norm(x)
282
+ return self.lm_head(x[:, 1:]) # remove cond
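`ImageBert` frames generation as iterative unmasking: all `image_seq_len` positions start as the mask token, and each of the `num_sample_steps` steps keeps the highest-confidence sampled tokens and re-masks the rest following the arccos schedule `mask_ratio = arccos(ratio) / (pi / 2)`, optionally with classifier-free guidance. A minimal sampling sketch, assuming only the config fields read in `__init__`; the values below are illustrative placeholders (real settings live in `configs/training/generator/maskgit.yaml`), and the step that decodes token ids back to pixels with the tokenizer is omitted:

import torch
from omegaconf import OmegaConf
from modeling.maskgit import ImageBert

# Placeholder hyper-parameters purely for illustration.
config = OmegaConf.create({
    "model": {
        "vq_model": {"codebook_size": 4096},
        "generator": {
            "condition_num_classes": 1000, "image_seq_len": 32,
            "hidden_size": 768, "num_hidden_layers": 12,
            "num_attention_heads": 12, "intermediate_size": 3072,
            "dropout": 0.1, "attn_drop": 0.1,
        },
    },
})
generator = ImageBert(config).eval()

# One ImageNet class id per sample, e.g. 980 = 'volcano', 151 = 'Chihuahua'.
condition = torch.tensor([980, 151])
token_ids = generator.generate(condition, guidance_scale=3.0, num_sample_steps=8)
print(token_ids.shape)  # (2, image_seq_len); pass these ids to the tokenizer's decoder to obtain images.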
modeling/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .base_model import BaseModel
2
+ from .ema_model import EMAModel
3
+ from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, MLMLoss, ARLoss
4
+ from .blocks import TiTokEncoder, TiTokDecoder, UViTBlock
5
+ from .maskgit_vqgan import Decoder as Pixel_Decoder
6
+ from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
modeling/modules/base_model.py ADDED
@@ -0,0 +1,127 @@
1
+ """This file contains some base class implementation for models.
2
+
3
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5
+
6
+ Reference:
7
+ https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
8
+ """
9
+ import os
10
+ from typing import Union, Callable, Dict, Optional
11
+
12
+ import torch
13
+
14
+
15
+ class BaseModel(torch.nn.Module):
16
+
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ def save_pretrained_weight(
21
+ self,
22
+ save_directory: Union[str, os.PathLike],
23
+ save_function: Callable = None,
24
+ state_dict: Optional[Dict[str, torch.Tensor]] = None,
25
+ ):
26
+ """Saves a model and its configuration file to a directory.
27
+
28
+ Args:
29
+ save_directory: A string or os.PathLike, directory to which to save.
30
+ Will be created if it doesn't exist.
31
+ save_function: A Callable function, the function to use to save the state dictionary.
32
+ Useful on distributed training like TPUs when one needs to replace `torch.save` with
33
+ another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`.
34
+ state_dict: A dictionary from str to torch.Tensor, the state dictionary to save.
35
+ If `None`, the model's state dictionary will be saved.
36
+ """
37
+ if os.path.isfile(save_directory):
38
+ print(f"Provided path ({save_directory}) should be a directory, not a file")
39
+ return
40
+
41
+ if save_function is None:
42
+ save_function = torch.save
43
+
44
+ os.makedirs(save_directory, exist_ok=True)
45
+
46
+ model_to_save = self
47
+
48
+ if state_dict is None:
49
+ state_dict = model_to_save.state_dict()
50
+ weights_name = "pytorch_model.bin"
51
+
52
+ save_function(state_dict, os.path.join(save_directory, weights_name))
53
+
54
+ print(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
55
+
56
+ def load_pretrained_weight(
57
+ self,
58
+ pretrained_model_path: Union[str, os.PathLike],
59
+ strict_loading: bool = True,
60
+ torch_dtype: Optional[torch.dtype] = None
61
+ ):
62
+ r"""Instantiates a pretrained pytorch model from a pre-trained model configuration.
63
+
64
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
65
+ the model, you should first set it back in training mode with `model.train()`.
66
+
67
+ Args:
68
+ pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights.
69
+
70
+ Raises:
71
+ ValueError: If pretrained_model_path does not exist.
72
+ """
73
+ # If pretrained_model_path is a file, set model_file to this file.
74
+ if os.path.isfile(pretrained_model_path):
75
+ model_file = pretrained_model_path
76
+ # If pretrained_model_path is a directory, set model_file to the path of the
77
+ # file "pytorch_model.bin" in this directory.
78
+ elif os.path.isdir(pretrained_model_path):
79
+ pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
80
+ if os.path.isfile(pretrained_model_path):
81
+ model_file = pretrained_model_path
82
+ else:
83
+ raise ValueError(f"{pretrained_model_path} does not exist")
84
+ else:
85
+ raise ValueError(f"{pretrained_model_path} does not exist")
86
+
87
+ # Load model state from checkpoint.
88
+ checkpoint = torch.load(model_file, map_location="cpu")
89
+ # Load state dictionary into self.
90
+ msg = self.load_state_dict(checkpoint, strict=strict_loading)
91
+ # Print information about loading weights.
92
+ print(f"loading weight from {model_file}, msg: {msg}")
93
+ # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype.
94
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
95
+ raise ValueError(
96
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
97
+ )
98
+ elif torch_dtype is not None:
99
+ self.to(torch_dtype)
100
+
101
+ # Set model in evaluation mode to deactivate DropOut modules by default.
102
+ self.eval()
103
+
104
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
105
+ """Gets the number of parameters in the module.
106
+
107
+ Args:
108
+ only_trainable: A boolean, whether to only include trainable parameters.
109
+ exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings.
110
+
111
+ Returns:
112
+ An integer, the number of parameters.
113
+ """
114
+
115
+ if exclude_embeddings:
116
+ embedding_param_names = [
117
+ f"{name}.weight"
118
+ for name, module_type in self.named_modules()
119
+ if isinstance(module_type, torch.nn.Embedding)
120
+ ]
121
+ non_embedding_parameters = [
122
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
123
+ ]
124
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
125
+ else:
126
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
127
+
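`BaseModel` gives every model in this repo a plain `pytorch_model.bin` save/load path that works independently of the Hugging Face hub mixin. A minimal round-trip sketch with a toy subclass (the real models need full configs, so a stand-in module is used here purely for illustration):

import torch
from modeling.modules.base_model import BaseModel

class TinyModel(BaseModel):
    # Stand-in subclass; any BaseModel subclass exposes the same helpers.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

model = TinyModel()
model.save_pretrained_weight("./tiny_ckpt")     # writes ./tiny_ckpt/pytorch_model.bin

restored = TinyModel()
restored.load_pretrained_weight("./tiny_ckpt")  # accepts a directory or the .bin file itself
print(restored.num_parameters())                # 72 (64 weights + 8 biases of the 8x8 linear layer)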
modeling/modules/blocks.py ADDED
@@ -0,0 +1,376 @@
1
+ """Building blocks for TiTok.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ Reference:
18
+ https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
19
+ https://github.com/baofff/U-ViT/blob/main/libs/timm.py
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from collections import OrderedDict
25
+ import einops
26
+ from einops.layers.torch import Rearrange
27
+
28
+
29
+ class ResidualAttentionBlock(nn.Module):
30
+ def __init__(
31
+ self,
32
+ d_model,
33
+ n_head,
34
+ mlp_ratio = 4.0,
35
+ act_layer = nn.GELU,
36
+ norm_layer = nn.LayerNorm
37
+ ):
38
+ super().__init__()
39
+
40
+ self.ln_1 = norm_layer(d_model)
41
+ self.attn = nn.MultiheadAttention(d_model, n_head)
42
+ self.mlp_ratio = mlp_ratio
43
+ # optionally we can disable the FFN
44
+ if mlp_ratio > 0:
45
+ self.ln_2 = norm_layer(d_model)
46
+ mlp_width = int(d_model * mlp_ratio)
47
+ self.mlp = nn.Sequential(OrderedDict([
48
+ ("c_fc", nn.Linear(d_model, mlp_width)),
49
+ ("gelu", act_layer()),
50
+ ("c_proj", nn.Linear(mlp_width, d_model))
51
+ ]))
52
+
53
+ def attention(
54
+ self,
55
+ x: torch.Tensor
56
+ ):
57
+ return self.attn(x, x, x, need_weights=False)[0]
58
+
59
+ def forward(
60
+ self,
61
+ x: torch.Tensor,
62
+ ):
63
+ attn_output = self.attention(x=self.ln_1(x))
64
+ x = x + attn_output
65
+ if self.mlp_ratio > 0:
66
+ x = x + self.mlp(self.ln_2(x))
67
+ return x
68
+
69
+ if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
70
+ ATTENTION_MODE = 'flash'
71
+ else:
72
+ try:
73
+ import xformers
74
+ import xformers.ops
75
+ ATTENTION_MODE = 'xformers'
76
+ except:
77
+ ATTENTION_MODE = 'math'
78
+ print(f'attention mode is {ATTENTION_MODE}')
79
+
80
+
81
+ class Attention(nn.Module):
82
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
83
+ super().__init__()
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
88
+ self.attn_drop = nn.Dropout(attn_drop)
89
+ self.proj = nn.Linear(dim, dim)
90
+ self.proj_drop = nn.Dropout(proj_drop)
91
+
92
+ def forward(self, x):
93
+ B, L, C = x.shape
94
+
95
+ qkv = self.qkv(x)
96
+ if ATTENTION_MODE == 'flash':
97
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
98
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
99
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
100
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
101
+ elif ATTENTION_MODE == 'xformers':
102
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
103
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
104
+ x = xformers.ops.memory_efficient_attention(q, k, v)
105
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
106
+ elif ATTENTION_MODE == 'math':
107
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
108
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
109
+ attn = (q @ k.transpose(-2, -1)) * self.scale
110
+ attn = attn.softmax(dim=-1)
111
+ attn = self.attn_drop(attn)
112
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
113
+ else:
114
+ raise NotImplementedError
115
+
116
+ x = self.proj(x)
117
+ x = self.proj_drop(x)
118
+ return x
119
+
120
+
121
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
122
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
123
+
124
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
125
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
126
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
127
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
128
+ 'survival rate' as the argument.
129
+
130
+ """
131
+ if drop_prob == 0. or not training:
132
+ return x
133
+ keep_prob = 1 - drop_prob
134
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
135
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
136
+ random_tensor.floor_() # binarize
137
+ output = x.div(keep_prob) * random_tensor
138
+ return output
139
+
140
+
141
+ class DropPath(nn.Module):
142
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
143
+ """
144
+ def __init__(self, drop_prob=None):
145
+ super(DropPath, self).__init__()
146
+ self.drop_prob = drop_prob
147
+
148
+ def forward(self, x):
149
+ return drop_path(x, self.drop_prob, self.training)
150
+
151
+
152
+ class Mlp(nn.Module):
153
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
154
+ super().__init__()
155
+ out_features = out_features or in_features
156
+ hidden_features = hidden_features or in_features
157
+ self.fc1 = nn.Linear(in_features, hidden_features)
158
+ self.act = act_layer()
159
+ self.fc2 = nn.Linear(hidden_features, out_features)
160
+ self.drop = nn.Dropout(drop)
161
+
162
+ def forward(self, x):
163
+ x = self.fc1(x)
164
+ x = self.act(x)
165
+ x = self.drop(x)
166
+ x = self.fc2(x)
167
+ x = self.drop(x)
168
+ return x
169
+
170
+
171
+ class UViTBlock(nn.Module):
172
+
173
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
174
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
175
+ super().__init__()
176
+ self.norm1 = norm_layer(dim)
177
+ self.attn = Attention(
178
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
179
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
180
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
181
+ self.norm2 = norm_layer(dim)
182
+ mlp_hidden_dim = int(dim * mlp_ratio)
183
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
184
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
185
+ self.use_checkpoint = use_checkpoint
186
+
187
+ def forward(self, x, skip=None):
188
+ if self.use_checkpoint:
189
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
190
+ else:
191
+ return self._forward(x, skip)
192
+
193
+ def _forward(self, x, skip=None):
194
+ if self.skip_linear is not None:
195
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
196
+ x = x + self.drop_path(self.attn(self.norm1(x)))
197
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
198
+ return x
199
+
200
+
201
+ def _expand_token(token, batch_size: int):
202
+ return token.unsqueeze(0).expand(batch_size, -1, -1)
203
+
204
+
205
+ class TiTokEncoder(nn.Module):
206
+ def __init__(self, config):
207
+ super().__init__()
208
+ self.config = config
209
+ self.image_size = config.dataset.preprocessing.crop_size
210
+ self.patch_size = config.model.vq_model.vit_enc_patch_size
211
+ self.grid_size = self.image_size // self.patch_size
212
+ self.model_size = config.model.vq_model.vit_enc_model_size
213
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
214
+ self.token_size = config.model.vq_model.token_size
215
+
216
+ if config.model.vq_model.get("quantize_mode", "vq") == "vae":
217
+ self.token_size = self.token_size * 2 # needs to split into mean and std
218
+
219
+ self.is_legacy = config.model.vq_model.get("is_legacy", True)
220
+
221
+ self.width = {
222
+ "small": 512,
223
+ "base": 768,
224
+ "large": 1024,
225
+ }[self.model_size]
226
+ self.num_layers = {
227
+ "small": 8,
228
+ "base": 12,
229
+ "large": 24,
230
+ }[self.model_size]
231
+ self.num_heads = {
232
+ "small": 8,
233
+ "base": 12,
234
+ "large": 16,
235
+ }[self.model_size]
236
+
237
+ self.patch_embed = nn.Conv2d(
238
+ in_channels=3, out_channels=self.width,
239
+ kernel_size=self.patch_size, stride=self.patch_size, bias=True)
240
+
241
+ scale = self.width ** -0.5
242
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
243
+ self.positional_embedding = nn.Parameter(
244
+ scale * torch.randn(self.grid_size ** 2 + 1, self.width))
245
+ self.latent_token_positional_embedding = nn.Parameter(
246
+ scale * torch.randn(self.num_latent_tokens, self.width))
247
+ self.ln_pre = nn.LayerNorm(self.width)
248
+ self.transformer = nn.ModuleList()
249
+ for i in range(self.num_layers):
250
+ self.transformer.append(ResidualAttentionBlock(
251
+ self.width, self.num_heads, mlp_ratio=4.0
252
+ ))
253
+ self.ln_post = nn.LayerNorm(self.width)
254
+ self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)
255
+
256
+ def forward(self, pixel_values, latent_tokens):
257
+ batch_size = pixel_values.shape[0]
258
+ x = pixel_values
259
+ x = self.patch_embed(x)
260
+ x = x.reshape(x.shape[0], x.shape[1], -1)
261
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
262
+ # class embeddings and positional embeddings
263
+ x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
264
+ x = x + self.positional_embedding.to(x.dtype) # shape = [*, grid ** 2 + 1, width]
265
+
266
+
267
+ latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
268
+ latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)
269
+ x = torch.cat([x, latent_tokens], dim=1)
270
+
271
+ x = self.ln_pre(x)
272
+ x = x.permute(1, 0, 2) # NLD -> LND
273
+ for i in range(self.num_layers):
274
+ x = self.transformer[i](x)
275
+ x = x.permute(1, 0, 2) # LND -> NLD
276
+
277
+ latent_tokens = x[:, 1+self.grid_size**2:]
278
+ latent_tokens = self.ln_post(latent_tokens)
279
+ # fake 2D shape
280
+ if self.is_legacy:
281
+ latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
282
+ else:
283
+ # Fix legacy problem.
284
+ latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)
285
+ latent_tokens = self.conv_out(latent_tokens)
286
+ latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
287
+ return latent_tokens
288
+
289
+
290
+ class TiTokDecoder(nn.Module):
291
+ def __init__(self, config):
292
+ super().__init__()
293
+ self.config = config
294
+ self.image_size = config.dataset.preprocessing.crop_size
295
+ self.patch_size = config.model.vq_model.vit_dec_patch_size
296
+ self.grid_size = self.image_size // self.patch_size
297
+ self.model_size = config.model.vq_model.vit_dec_model_size
298
+ self.num_latent_tokens = config.model.vq_model.num_latent_tokens
299
+ self.token_size = config.model.vq_model.token_size
300
+ self.is_legacy = config.model.vq_model.get("is_legacy", True)
301
+ self.width = {
302
+ "small": 512,
303
+ "base": 768,
304
+ "large": 1024,
305
+ }[self.model_size]
306
+ self.num_layers = {
307
+ "small": 8,
308
+ "base": 12,
309
+ "large": 24,
310
+ }[self.model_size]
311
+ self.num_heads = {
312
+ "small": 8,
313
+ "base": 12,
314
+ "large": 16,
315
+ }[self.model_size]
316
+
317
+ self.decoder_embed = nn.Linear(
318
+ self.token_size, self.width, bias=True)
319
+ scale = self.width ** -0.5
320
+ self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
321
+ self.positional_embedding = nn.Parameter(
322
+ scale * torch.randn(self.grid_size ** 2 + 1, self.width))
323
+ # add mask token and query pos embed
324
+ self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
325
+ self.latent_token_positional_embedding = nn.Parameter(
326
+ scale * torch.randn(self.num_latent_tokens, self.width))
327
+ self.ln_pre = nn.LayerNorm(self.width)
328
+ self.transformer = nn.ModuleList()
329
+ for i in range(self.num_layers):
330
+ self.transformer.append(ResidualAttentionBlock(
331
+ self.width, self.num_heads, mlp_ratio=4.0
332
+ ))
333
+ self.ln_post = nn.LayerNorm(self.width)
334
+
335
+ if self.is_legacy:
336
+ self.ffn = nn.Sequential(
337
+ nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
338
+ nn.Tanh(),
339
+ nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
340
+ )
341
+ self.conv_out = nn.Identity()
342
+ else:
343
+ # Directly predicting RGB pixels
344
+ self.ffn = nn.Sequential(
345
+ nn.Conv2d(self.width, self.patch_size * self.patch_size * 3, 1, padding=0, bias=True),
346
+ Rearrange('b (p1 p2 c) h w -> b c (h p1) (w p2)',
347
+ p1 = self.patch_size, p2 = self.patch_size),)
348
+ self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)
349
+
350
+ def forward(self, z_quantized):
351
+ N, C, H, W = z_quantized.shape
352
+ assert H == 1 and W == self.num_latent_tokens, f"{H}, {W}, {self.num_latent_tokens}"
353
+ x = z_quantized.reshape(N, C*H, W).permute(0, 2, 1) # NLD
354
+ x = self.decoder_embed(x)
355
+
356
+ batchsize, seq_len, _ = x.shape
357
+
358
+ mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(x.dtype)
359
+ mask_tokens = torch.cat([_expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),
360
+ mask_tokens], dim=1)
361
+ mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
362
+ x = x + self.latent_token_positional_embedding[:seq_len]
363
+ x = torch.cat([mask_tokens, x], dim=1)
364
+
365
+ x = self.ln_pre(x)
366
+ x = x.permute(1, 0, 2) # NLD -> LND
367
+ for i in range(self.num_layers):
368
+ x = self.transformer[i](x)
369
+ x = x.permute(1, 0, 2) # LND -> NLD
370
+ x = x[:, 1:1+self.grid_size**2] # remove cls embed
371
+ x = self.ln_post(x)
372
+ # N L D -> N D H W
373
+ x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)
374
+ x = self.ffn(x.contiguous())
375
+ x = self.conv_out(x)
376
+ return x
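To make the latent-token shape contract between TiTokEncoder and TiTokDecoder concrete, here is a minimal sketch with illustrative (hypothetical) config values; the real configs live under configs/, and this assumes the repository and its dependencies (omegaconf, einops) are available:

import torch
from omegaconf import OmegaConf
from modeling.modules.blocks import TiTokEncoder, TiTokDecoder

# Illustrative values only.
config = OmegaConf.create({
    "dataset": {"preprocessing": {"crop_size": 256}},
    "model": {"vq_model": {
        "vit_enc_patch_size": 16, "vit_dec_patch_size": 16,
        "vit_enc_model_size": "small", "vit_dec_model_size": "small",
        "num_latent_tokens": 32, "token_size": 12,
        "quantize_mode": "vq", "is_legacy": False,
    }},
})

encoder, decoder = TiTokEncoder(config), TiTokDecoder(config)
latent_tokens = torch.randn(32, encoder.width)  # a learnable parameter in the full TiTok model
images = torch.randn(2, 3, 256, 256)

z = encoder(images, latent_tokens)  # (2, token_size, 1, num_latent_tokens) = (2, 12, 1, 32)
recon = decoder(z)                  # (2, 3, 256, 256), since is_legacy=False predicts RGB directly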
modeling/modules/discriminator.py ADDED

@@ -0,0 +1,141 @@
1
+ """This file contains some base implementation for discrminators.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ TODO: Add reference to Mark Weber's tech report on the improved discriminator architecture.
18
+ """
19
+ import functools
20
+ import math
21
+ from typing import Tuple
22
+
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from .maskgit_vqgan import Conv2dSame
29
+
30
+
31
+ class BlurBlock(torch.nn.Module):
32
+ def __init__(self,
33
+ kernel: Tuple[int] = (1, 3, 3, 1)
34
+ ):
35
+ super().__init__()
36
+
37
+ kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
38
+ kernel = kernel[None, :] * kernel[:, None]
39
+ kernel /= kernel.sum()
40
+ kernel = kernel.unsqueeze(0).unsqueeze(0)
41
+ self.register_buffer("kernel", kernel)
42
+
43
+ def calc_same_pad(self, i: int, k: int, s: int) -> int:
44
+ return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)
45
+
46
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
47
+ ic, ih, iw = x.size()[-3:]
48
+ pad_h = self.calc_same_pad(i=ih, k=4, s=2)
49
+ pad_w = self.calc_same_pad(i=iw, k=4, s=2)
50
+ if pad_h > 0 or pad_w > 0:
51
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
52
+
53
+ weight = self.kernel.expand(ic, -1, -1, -1)
54
+
55
+ out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
56
+ return out
57
+
58
+
59
+ class NLayerDiscriminator(torch.nn.Module):
60
+ def __init__(
61
+ self,
62
+ num_channels: int = 3,
63
+ hidden_channels: int = 128,
64
+ num_stages: int = 3,
65
+ blur_resample: bool = True,
66
+ blur_kernel_size: int = 4
67
+ ):
68
+ """ Initializes the NLayerDiscriminator.
69
+
70
+ Args:
71
+ num_channels -> int: The number of input channels.
72
+ hidden_channels -> int: The number of hidden channels.
73
+ num_stages -> int: The number of stages.
74
+ blur_resample -> bool: Whether to use blur resampling.
75
+ blur_kernel_size -> int: The blur kernel size.
76
+ """
77
+ super().__init__()
78
+ assert num_stages > 0, "Discriminator cannot have 0 stages"
79
+ assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]"
80
+
81
+ in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
82
+ init_kernel_size = 5
83
+ activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
84
+
85
+ self.block_in = torch.nn.Sequential(
86
+ Conv2dSame(
87
+ num_channels,
88
+ hidden_channels,
89
+ kernel_size=init_kernel_size
90
+ ),
91
+ activation(),
92
+ )
93
+
94
+ BLUR_KERNEL_MAP = {
95
+ 3: (1,2,1),
96
+ 4: (1,3,3,1),
97
+ 5: (1,4,6,4,1),
98
+ }
99
+
100
+ discriminator_blocks = []
101
+ for i_level in range(num_stages):
102
+ in_channels = hidden_channels * in_channel_mult[i_level]
103
+ out_channels = hidden_channels * in_channel_mult[i_level + 1]
104
+ block = torch.nn.Sequential(
105
+ Conv2dSame(
106
+ in_channels,
107
+ out_channels,
108
+ kernel_size=3,
109
+ ),
110
+ torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]),
111
+ torch.nn.GroupNorm(32, out_channels),
112
+ activation(),
113
+ )
114
+ discriminator_blocks.append(block)
115
+
116
+ self.blocks = torch.nn.ModuleList(discriminator_blocks)
117
+
118
+ self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))
119
+
120
+ self.to_logits = torch.nn.Sequential(
121
+ Conv2dSame(out_channels, out_channels, 1),
122
+ activation(),
123
+ Conv2dSame(out_channels, 1, kernel_size=5)
124
+ )
125
+
126
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
127
+ """ Forward pass.
128
+
129
+ Args:
130
+ x -> torch.Tensor: The input tensor.
131
+
132
+ Returns:
133
+ output -> torch.Tensor: The output tensor.
134
+ """
135
+ hidden_states = self.block_in(x)
136
+ for block in self.blocks:
137
+ hidden_states = block(hidden_states)
138
+
139
+ hidden_states = self.pool(hidden_states)
140
+
141
+ return self.to_logits(hidden_states)
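A small usage sketch of the discriminator above (hypothetical shapes; the actual training wiring lives in modeling/modules/losses.py below). With three stages and the 16x16 adaptive pooling, a 256x256 image yields one logit per pooled spatial location:

import torch
from modeling.modules.discriminator import NLayerDiscriminator

disc = NLayerDiscriminator(num_channels=3, hidden_channels=128, num_stages=3)
real = torch.randn(2, 3, 256, 256)
fake = torch.randn(2, 3, 256, 256)

logits_real = disc(real)  # (2, 1, 16, 16)
logits_fake = disc(fake)

# Hinge-style objectives, matching the helpers defined in losses.py below.
d_loss = 0.5 * (torch.relu(1.0 - logits_real).mean() + torch.relu(1.0 + logits_fake).mean())
g_loss = -logits_fake.mean()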
modeling/modules/ema_model.py ADDED
@@ -0,0 +1,244 @@
1
+ """This file contains some base class implementation for EMA.
2
+
3
+ This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4
+ All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5
+
6
+ Reference:
7
+ https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8
8
+ """
9
+
10
+
11
+ import copy
12
+ from typing import Any, Iterable, Optional, Union
13
+
14
+ import torch
15
+
16
+
17
+ class EMAModel:
18
+ """Exponential Moving Average of models weights."""
19
+ def __init__(
20
+ self,
21
+ parameters: Iterable[torch.nn.Parameter],
22
+ decay: float = 0.9999,
23
+ min_decay: float = 0.0,
24
+ update_after_step: int = 0,
25
+ update_every: int = 1,
26
+ current_step: int = 0,
27
+ use_ema_warmup: bool = False,
28
+ inv_gamma: Union[float, int] = 1.0,
29
+ power: Union[float, int] = 2 / 3,
30
+ model_cls: Optional[Any] = None,
31
+ **model_config_kwargs
32
+ ):
33
+ """
34
+ Args:
35
+ parameters (Iterable[torch.nn.Parameter]): The parameters to track.
36
+ decay (float): The decay factor for the exponential moving average.
37
+ min_decay (float): The minimum decay factor for the exponential moving average.
38
+ update_after_step (int): The number of steps to wait before starting to update the EMA weights.
39
+ update_every (int): The number of steps between each EMA update.
40
+ current_step (int): The current training step.
41
+ use_ema_warmup (bool): Whether to use EMA warmup.
42
+ inv_gamma (float):
43
+ Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
44
+ power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
45
+
46
+ notes on EMA Warmup:
47
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
48
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
49
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
50
+ at 215.4k steps).
51
+ """
52
+
53
+ parameters = list(parameters)
54
+ self.shadow_params = [p.clone().detach() for p in parameters]
55
+ self.temp_stored_params = None
56
+
57
+ self.decay = decay
58
+ self.min_decay = min_decay
59
+ self.update_after_step = update_after_step
60
+ self.update_every = update_every
61
+ self.use_ema_warmup = use_ema_warmup
62
+ self.inv_gamma = inv_gamma
63
+ self.power = power
64
+ self.optimization_step = current_step
65
+ self.cur_decay_value = None # set in `step()`
66
+
67
+ self.model_cls = model_cls
68
+ self.model_config_kwargs = model_config_kwargs
69
+
70
+ @classmethod
71
+ def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel":
72
+ model = model_cls(**model_config_kwargs)
73
+ model.load_pretrained_weight(checkpoint)
74
+
75
+ ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs)
76
+ return ema_model
77
+
78
+ def save_pretrained(self, path):
79
+ if self.model_cls is None:
80
+ raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")
81
+
82
+ if self.model_config_kwargs is None:
83
+ raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.")
84
+
85
+ model = self.model_cls(**self.model_config_kwargs)
86
+ self.copy_to(model.parameters())
87
+ model.save_pretrained_weight(path)
88
+
89
+ def set_step(self, optimization_step: int):
90
+ self.optimization_step = optimization_step
91
+
92
+ def get_decay(self, optimization_step: int) -> float:
93
+ """Computes the decay factor for the exponential moving average."""
94
+ step = max(0, optimization_step - self.update_after_step - 1)
95
+
96
+ if step <= 0:
97
+ return 0.0
98
+
99
+ if self.use_ema_warmup:
100
+ cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
101
+ else:
102
+ cur_decay_value = (1 + step) / (10 + step)
103
+
104
+ cur_decay_value = min(cur_decay_value, self.decay)
105
+ # Make sure decay is not smaller than min_decay.
106
+ cur_decay_value = max(cur_decay_value, self.min_decay)
107
+ return cur_decay_value
108
+
109
+ @torch.no_grad()
110
+ def step(self, parameters: Iterable[torch.nn.Parameter]):
111
+ parameters = list(parameters)
112
+
113
+ self.optimization_step += 1
114
+
115
+ if (self.optimization_step - 1) % self.update_every != 0:
116
+ return
117
+
118
+ # Compute the decay factor for the exponential moving average.
119
+ decay = self.get_decay(self.optimization_step)
120
+ self.cur_decay_value = decay
121
+ one_minus_decay = 1 - decay
122
+
123
+ for s_param, param in zip(self.shadow_params, parameters):
124
+ if param.requires_grad:
125
+ s_param.sub_(one_minus_decay * (s_param - param))
126
+ else:
127
+ s_param.copy_(param)
128
+
129
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
130
+ """Copies current averaged parameters into given collection of parameters.
131
+
132
+ Args:
133
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
134
+ updated with the stored moving averages. If `None`, the parameters with which this
135
+ `ExponentialMovingAverage` was initialized will be used.
136
+ """
137
+ parameters = list(parameters)
138
+ for s_param, param in zip(self.shadow_params, parameters):
139
+ param.data.copy_(s_param.to(param.device).data)
140
+
141
+ def to(self, device=None, dtype=None) -> None:
142
+ r"""Moves internal buffers of the ExponentialMovingAverage to `device`.
143
+
144
+ Args:
145
+ device: like `device` argument to `torch.Tensor.to`
146
+ """
147
+ # .to() on the tensors handles None correctly
148
+ self.shadow_params = [
149
+ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
150
+ for p in self.shadow_params
151
+ ]
152
+
153
+ def state_dict(self) -> dict:
154
+ r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
155
+ checkpointing to save the ema state dict.
156
+ """
157
+ # Following PyTorch conventions, references to tensors are returned:
158
+ # "returns a reference to the state and not its copy!" -
159
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
160
+ return {
161
+ "decay": self.decay,
162
+ "min_decay": self.min_decay,
163
+ "optimization_step": self.optimization_step,
164
+ "update_after_step": self.update_after_step,
165
+ "use_ema_warmup": self.use_ema_warmup,
166
+ "inv_gamma": self.inv_gamma,
167
+ "power": self.power,
168
+ "shadow_params": self.shadow_params,
169
+ }
170
+
171
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
172
+ r"""
173
+ Save the current parameters for restoring later.
174
+ Args:
175
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
176
+ temporarily stored.
177
+ """
178
+ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
179
+
180
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
181
+ r"""Restores the parameters stored with the `store` method. Useful to validate
182
+ the model with EMA parameters without affecting the original optimization process.
183
+ Store the parameters before the `copy_to()` method. After validation (or
184
+ model saving), use this to restore the former parameters.
185
+
186
+ Args:
187
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
188
+ updated with the stored parameters. If `None`, the parameters with which this
189
+ `ExponentialMovingAverage` was initialized will be used.
190
+ """
191
+ if self.temp_stored_params is None:
192
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
193
+ for c_param, param in zip(self.temp_stored_params, parameters):
194
+ param.data.copy_(c_param.data)
195
+
196
+ # Better memory-wise.
197
+ self.temp_stored_params = None
198
+
199
+ def load_state_dict(self, state_dict: dict) -> None:
200
+ r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
201
+ ema state dict.
202
+
203
+ Args:
204
+ state_dict (dict): EMA state. Should be an object returned
205
+ from a call to :meth:`state_dict`.
206
+ """
207
+ # Deepcopy, to be consistent with module API
208
+ state_dict = copy.deepcopy(state_dict)
209
+
210
+ self.decay = state_dict.get("decay", self.decay)
211
+ if self.decay < 0.0 or self.decay > 1.0:
212
+ raise ValueError("Decay must be between 0 and 1")
213
+
214
+ self.min_decay = state_dict.get("min_decay", self.min_decay)
215
+ if not isinstance(self.min_decay, float):
216
+ raise ValueError("Invalid min_decay")
217
+
218
+ self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
219
+ if not isinstance(self.optimization_step, int):
220
+ raise ValueError("Invalid optimization_step")
221
+
222
+ self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
223
+ if not isinstance(self.update_after_step, int):
224
+ raise ValueError("Invalid update_after_step")
225
+
226
+ self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
227
+ if not isinstance(self.use_ema_warmup, bool):
228
+ raise ValueError("Invalid use_ema_warmup")
229
+
230
+ self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
231
+ if not isinstance(self.inv_gamma, (float, int)):
232
+ raise ValueError("Invalid inv_gamma")
233
+
234
+ self.power = state_dict.get("power", self.power)
235
+ if not isinstance(self.power, (float, int)):
236
+ raise ValueError("Invalid power")
237
+
238
+ shadow_params = state_dict.get("shadow_params", None)
239
+ if shadow_params is not None:
240
+ self.shadow_params = shadow_params
241
+ if not isinstance(self.shadow_params, list):
242
+ raise ValueError("shadow_params must be a list")
243
+ if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
244
+ raise ValueError("shadow_params must all be Tensors")
modeling/modules/losses.py ADDED
@@ -0,0 +1,293 @@
1
+ """This files contains training loss implementation.
2
+
3
+ Copyright (2024) Bytedance Ltd. and/or its affiliates
4
+
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+
17
+ Ref:
18
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py
19
+ """
20
+ from typing import Mapping, Text, Tuple
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from einops import rearrange
26
+ from torch.cuda.amp import autocast
27
+ from .perceptual_loss import PerceptualLoss
28
+ from .discriminator import NLayerDiscriminator
29
+
30
+
31
+ def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
32
+ """Hinge loss for discrminator.
33
+
34
+ This function is borrowed from
35
+ https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py#L20
36
+ """
37
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
38
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
39
+ d_loss = 0.5 * (loss_real + loss_fake)
40
+ return d_loss
41
+
42
+
43
+ def compute_lecam_loss(
44
+ logits_real_mean: torch.Tensor,
45
+ logits_fake_mean: torch.Tensor,
46
+ ema_logits_real_mean: torch.Tensor,
47
+ ema_logits_fake_mean: torch.Tensor
48
+ ) -> torch.Tensor:
49
+ """Computes the LeCam loss for the given average real and fake logits.
50
+
51
+ Args:
52
+ logits_real_mean -> torch.Tensor: The average real logits.
53
+ logits_fake_mean -> torch.Tensor: The average fake logits.
54
+ ema_logits_real_mean -> torch.Tensor: The EMA of the average real logits.
55
+ ema_logits_fake_mean -> torch.Tensor: The EMA of the average fake logits.
56
+
57
+ Returns:
58
+ lecam_loss -> torch.Tensor: The LeCam loss.
59
+ """
60
+ lecam_loss = torch.mean(torch.pow(F.relu(logits_real_mean - ema_logits_fake_mean), 2))
61
+ lecam_loss += torch.mean(torch.pow(F.relu(ema_logits_real_mean - logits_fake_mean), 2))
62
+ return lecam_loss
63
+
64
+
65
+ class ReconstructionLoss_Stage1(torch.nn.Module):
66
+ def __init__(
67
+ self,
68
+ config
69
+ ):
70
+ super().__init__()
71
+ loss_config = config.losses
72
+ self.quantizer_weight = loss_config.quantizer_weight
73
+ self.target_codebook_size = 1024
74
+
75
+ def forward(self,
76
+ target_codes: torch.Tensor,
77
+ reconstructions: torch.Tensor,
78
+ quantizer_loss: torch.Tensor,
79
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
80
+ return self._forward_generator(target_codes, reconstructions, quantizer_loss)
81
+
82
+ def _forward_generator(self,
83
+ target_codes: torch.Tensor,
84
+ reconstructions: torch.Tensor,
85
+ quantizer_loss: Mapping[Text, torch.Tensor],
86
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
87
+ reconstructions = reconstructions.contiguous()
88
+ loss_fct = nn.CrossEntropyLoss(reduction="mean")
89
+ batch_size = reconstructions.shape[0]
90
+ reconstruction_loss = loss_fct(reconstructions.view(batch_size, self.target_codebook_size, -1),
91
+ target_codes.view(batch_size, -1))
92
+ total_loss = reconstruction_loss + \
93
+ self.quantizer_weight * quantizer_loss["quantizer_loss"]
94
+
95
+ loss_dict = dict(
96
+ total_loss=total_loss.clone().detach(),
97
+ reconstruction_loss=reconstruction_loss.detach(),
98
+ quantizer_loss=(self.quantizer_weight * quantizer_loss["quantizer_loss"]).detach(),
99
+ commitment_loss=quantizer_loss["commitment_loss"].detach(),
100
+ codebook_loss=quantizer_loss["codebook_loss"].detach(),
101
+ )
102
+
103
+ return total_loss, loss_dict
104
+
105
+
106
+ class ReconstructionLoss_Stage2(torch.nn.Module):
107
+ def __init__(
108
+ self,
109
+ config
110
+ ):
111
+ """Initializes the losses module.
112
+
113
+ Args:
114
+ config: A dictionary holding the full configuration for the model, losses, and training setup.
115
+ """
116
+ super().__init__()
117
+ loss_config = config.losses
118
+ self.discriminator = NLayerDiscriminator()
119
+
120
+ self.reconstruction_loss = loss_config.reconstruction_loss
121
+ self.reconstruction_weight = loss_config.reconstruction_weight
122
+ self.quantizer_weight = loss_config.quantizer_weight
123
+ self.perceptual_loss = PerceptualLoss(
124
+ loss_config.perceptual_loss).eval()
125
+ self.perceptual_weight = loss_config.perceptual_weight
126
+ self.discriminator_iter_start = loss_config.discriminator_start
127
+
128
+ self.discriminator_factor = loss_config.discriminator_factor
129
+ self.discriminator_weight = loss_config.discriminator_weight
130
+ self.lecam_regularization_weight = loss_config.lecam_regularization_weight
131
+ self.lecam_ema_decay = loss_config.get("lecam_ema_decay", 0.999)
132
+ if self.lecam_regularization_weight > 0.0:
133
+ self.register_buffer("ema_real_logits_mean", torch.zeros((1)))
134
+ self.register_buffer("ema_fake_logits_mean", torch.zeros((1)))
135
+
136
+ self.config = config
137
+
138
+ @autocast(enabled=False)
139
+ def forward(self,
140
+ inputs: torch.Tensor,
141
+ reconstructions: torch.Tensor,
142
+ extra_result_dict: Mapping[Text, torch.Tensor],
143
+ global_step: int,
144
+ mode: str = "generator",
145
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
146
+ # Both inputs and reconstructions are in range [0, 1].
147
+ inputs = inputs.float()
148
+ reconstructions = reconstructions.float()
149
+
150
+ if mode == "generator":
151
+ return self._forward_generator(inputs, reconstructions, extra_result_dict, global_step)
152
+ elif mode == "discriminator":
153
+ return self._forward_discriminator(inputs, reconstructions, global_step)
154
+ else:
155
+ raise ValueError(f"Unsupported mode {mode}")
156
+
157
+ def should_discriminator_be_trained(self, global_step : int):
158
+ return global_step >= self.discriminator_iter_start
159
+
160
+ def _forward_generator(self,
161
+ inputs: torch.Tensor,
162
+ reconstructions: torch.Tensor,
163
+ extra_result_dict: Mapping[Text, torch.Tensor],
164
+ global_step: int
165
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
166
+ """Generator training step."""
167
+ inputs = inputs.contiguous()
168
+ reconstructions = reconstructions.contiguous()
169
+ if self.reconstruction_loss == "l1":
170
+ reconstruction_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
171
+ elif self.reconstruction_loss == "l2":
172
+ reconstruction_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
173
+ else:
174
+ raise ValueError(f"Unsuppored reconstruction_loss {self.reconstruction_loss}")
175
+ reconstruction_loss *= self.reconstruction_weight
176
+
177
+ # Compute perceptual loss.
178
+ perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()
179
+
180
+ # Compute discriminator loss.
181
+ generator_loss = torch.zeros((), device=inputs.device)
182
+ discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
183
+ d_weight = 1.0
184
+ if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
185
+ # Disable discriminator gradients.
186
+ for param in self.discriminator.parameters():
187
+ param.requires_grad = False
188
+ logits_fake = self.discriminator(reconstructions)
189
+ generator_loss = -torch.mean(logits_fake)
190
+
191
+ d_weight *= self.discriminator_weight
192
+
193
+ # Compute quantizer loss.
194
+ quantizer_loss = extra_result_dict["quantizer_loss"]
195
+ total_loss = (
196
+ reconstruction_loss
197
+ + self.perceptual_weight * perceptual_loss
198
+ + self.quantizer_weight * quantizer_loss
199
+ + d_weight * discriminator_factor * generator_loss
200
+ )
201
+ loss_dict = dict(
202
+ total_loss=total_loss.clone().detach(),
203
+ reconstruction_loss=reconstruction_loss.detach(),
204
+ perceptual_loss=(self.perceptual_weight * perceptual_loss).detach(),
205
+ quantizer_loss=(self.quantizer_weight * quantizer_loss).detach(),
206
+ weighted_gan_loss=(d_weight * discriminator_factor * generator_loss).detach(),
207
+ discriminator_factor=torch.tensor(discriminator_factor),
208
+ commitment_loss=extra_result_dict["commitment_loss"].detach(),
209
+ codebook_loss=extra_result_dict["codebook_loss"].detach(),
210
+ d_weight=d_weight,
211
+ gan_loss=generator_loss.detach(),
212
+ )
213
+
214
+ return total_loss, loss_dict
215
+
216
+ def _forward_discriminator(self,
217
+ inputs: torch.Tensor,
218
+ reconstructions: torch.Tensor,
219
+ global_step: int,
220
+ ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
221
+ """Discrminator training step."""
222
+ discriminator_factor = self.discriminator_factor if self.should_discriminator_be_trained(global_step) else 0
223
+ loss_dict = {}
224
+ # Turn the gradients on.
225
+ for param in self.discriminator.parameters():
226
+ param.requires_grad = True
227
+
228
+ real_images = inputs.detach().requires_grad_(True)
229
+ logits_real = self.discriminator(real_images)
230
+ logits_fake = self.discriminator(reconstructions.detach())
231
+
232
+ discriminator_loss = discriminator_factor * hinge_d_loss(logits_real=logits_real, logits_fake=logits_fake)
233
+
234
+ # optional lecam regularization
235
+ lecam_loss = torch.zeros((), device=inputs.device)
236
+ if self.lecam_regularization_weight > 0.0:
237
+ lecam_loss = compute_lecam_loss(
238
+ torch.mean(logits_real),
239
+ torch.mean(logits_fake),
240
+ self.ema_real_logits_mean,
241
+ self.ema_fake_logits_mean
242
+ ) * self.lecam_regularization_weight
243
+
244
+ self.ema_real_logits_mean = self.ema_real_logits_mean * self.lecam_ema_decay + torch.mean(logits_real).detach() * (1 - self.lecam_ema_decay)
245
+ self.ema_fake_logits_mean = self.ema_fake_logits_mean * self.lecam_ema_decay + torch.mean(logits_fake).detach() * (1 - self.lecam_ema_decay)
246
+
247
+ discriminator_loss += lecam_loss
248
+
249
+ loss_dict = dict(
250
+ discriminator_loss=discriminator_loss.detach(),
251
+ logits_real=logits_real.detach().mean(),
252
+ logits_fake=logits_fake.detach().mean(),
253
+ lecam_loss=lecam_loss.detach(),
254
+ )
255
+ return discriminator_loss, loss_dict
256
+
257
+
258
+ class MLMLoss(torch.nn.Module):
259
+ def __init__(self,
260
+ config):
261
+ super().__init__()
262
+ self.label_smoothing = config.losses.label_smoothing
263
+ self.loss_weight_unmasked_token = config.losses.loss_weight_unmasked_token
264
+ self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing,
265
+ reduction="none")
266
+
267
+ def forward(self, inputs: torch.Tensor, targets: torch.Tensor,
268
+ weights=None) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
269
+ inputs = rearrange(inputs, "b n c -> b c n")
270
+ loss = self.criterion(inputs, targets)
271
+ weights = weights.to(loss)
272
+ loss_weights = (1.0 - weights) * self.loss_weight_unmasked_token + weights # unmasked tokens (weight 0) are down-weighted by self.loss_weight_unmasked_token
273
+ loss = (loss * loss_weights).sum() / (loss_weights.sum() + 1e-8)
274
+ # we only compute correct tokens on masked tokens
275
+ correct_tokens = ((torch.argmax(inputs, dim=1) == targets) * weights).sum(dim=1) / (weights.sum(1) + 1e-8)
276
+ return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}
277
+
278
+
279
+ class ARLoss(torch.nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.target_vocab_size = config.model.vq_model.codebook_size
283
+ self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")
284
+
285
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
286
+ shift_logits = logits[..., :-1, :].permute(0, 2, 1).contiguous() # NLC->NCL
287
+ shift_labels = labels.contiguous()
288
+ shift_logits = shift_logits.view(shift_logits.shape[0], self.target_vocab_size, -1)
289
+ shift_labels = shift_labels.view(shift_labels.shape[0], -1)
290
+ shift_labels = shift_labels.to(shift_logits.device)
291
+ loss = self.criterion(shift_logits, shift_labels)
292
+ correct_tokens = (torch.argmax(shift_logits, dim=1) == shift_labels).sum(dim=1) / shift_labels.size(1)
293
+ return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}