sushilnitham commited on
Commit
80de44d
1 Parent(s): f42b635

Upload 46 files

Browse files
Files changed (46) hide show
  1. .gitattributes +8 -27
  2. .gitignore +6 -0
  3. CITATION.cff +44 -0
  4. Docker/Dockerfile +17 -0
  5. Docker/README.md +16 -0
  6. Docker/build_docker.sh +1 -0
  7. LICENSE +201 -0
  8. Makefile +5 -0
  9. README.md +272 -0
  10. app/gradio/__pycache__/backend.cpython-311.pyc +0 -0
  11. app/gradio/app.py +53 -0
  12. app/gradio/backend.py +33 -0
  13. app/streamlit/__pycache__/backend.cpython-311.pyc +0 -0
  14. app/streamlit/app.py +108 -0
  15. app/streamlit/backend.py +33 -0
  16. app/streamlit/img/loading.gif +0 -0
  17. img/logo.png +0 -0
  18. pyproject.toml +2 -0
  19. run_docker_image.sh +4 -0
  20. setup.cfg +47 -0
  21. setup.py +4 -0
  22. src/dalle_mini/__init__.py +3 -0
  23. src/dalle_mini/data.py +461 -0
  24. src/dalle_mini/model/__init__.py +5 -0
  25. src/dalle_mini/model/configuration.py +185 -0
  26. src/dalle_mini/model/modeling.py +1953 -0
  27. src/dalle_mini/model/partitions.py +76 -0
  28. src/dalle_mini/model/processor.py +60 -0
  29. src/dalle_mini/model/text.py +262 -0
  30. src/dalle_mini/model/tokenizer.py +8 -0
  31. src/dalle_mini/model/utils.py +27 -0
  32. tools/dataset/encode_dataset.ipynb +371 -0
  33. tools/inference/inference_pipeline.ipynb +561 -0
  34. tools/inference/run_infer_notebook.sh +2 -0
  35. tools/train/config/mega/config.json +49 -0
  36. tools/train/config/micro/config.json +30 -0
  37. tools/train/config/mini/config.json +29 -0
  38. tools/train/config/mini_glu/config.json +30 -0
  39. tools/train/embeddings_retrain_preparation.ipynb +1218 -0
  40. tools/train/scalable_shampoo/README.md +7 -0
  41. tools/train/scalable_shampoo/distributed_shampoo.py +2452 -0
  42. tools/train/scalable_shampoo/quantization_utils.py +123 -0
  43. tools/train/scalable_shampoo/sm3.py +176 -0
  44. tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py +441 -0
  45. tools/train/sweep.yaml +49 -0
  46. tools/train/train.py +1740 -0
.gitattributes CHANGED
@@ -1,35 +1,16 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
 
 
 
4
  *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
  *.joblib filter=lfs diff=lfs merge=lfs -text
 
 
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
14
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__
2
+ .ipynb_checkpoints
3
+ .streamlit
4
+ wandb/
5
+ *.egg-info/
6
+ jax_cache/
CITATION.cff ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YAML 1.2
2
+ ---
3
+ abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
4
+ authors:
5
+ -
6
+ family-names: Dayma
7
+ given-names: Boris
8
+ -
9
+ family-names: Patil
10
+ given-names: Suraj
11
+ -
12
+ family-names: Cuenca
13
+ given-names: Pedro
14
+ -
15
+ family-names: Saifullah
16
+ given-names: Khalid
17
+ -
18
+ family-names: Abraham
19
+ given-names: Tanishq
20
+ -
21
+ family-names: "Lê Khắc"
22
+ given-names: "Phúc"
23
+ -
24
+ family-names: Melas
25
+ given-names: Luke
26
+ -
27
+ family-names: Ghosh
28
+ given-names: Ritobrata
29
+ cff-version: "1.1.0"
30
+ date-released: 2021-07-29
31
+ identifiers:
32
+ keywords:
33
+ - dalle
34
+ - "text-to-image generation"
35
+ - transformer
36
+ - "zero-shot"
37
+ - JAX
38
+ license: "Apache-2.0"
39
+ doi: 10.5281/zenodo.5146400
40
+ message: "If you use this project, please cite it using these metadata."
41
+ repository-code: "https://github.com/borisdayma/dalle-mini"
42
+ title: "DALL·E Mini"
43
+ version: "v0.1-alpha"
44
+ ...
Docker/Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
2
+
3
+ RUN apt-get update && apt-get install -y \
4
+ git \
5
+ python3 \
6
+ python3-pip \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
10
+ && pip install -q \
11
+ git+https://github.com/borisdayma/dalle-mini.git \
12
+ git+https://github.com/patil-suraj/vqgan-jax.git
13
+
14
+ RUN pip install jupyter
15
+
16
+ WORKDIR /workspace
17
+
Docker/README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running Dalle-mini With Docker
2
+
3
+ This folder contains the Dockerfile needed to build a Docker image that can easily run Dalle-mini.
4
+
5
+ ## Inference
6
+
7
+ Steps to run inference with Dalle-mini are as follows:
8
+
9
+ 1. Build the docker image with ```dalle-mini/Docker/build_docker.sh```
10
+ 2. Run the container with ```dalle-mini/run_docker_image.sh```
11
+ 3. Navigate to ```/workspace/tools/inference/``` and run ```run_infer_notebook.sh```
12
+ 4. Click the Jupyter Notebook link and run through the notebook.
13
+
14
+ ### Inference Video Tutorial
15
+
16
+ Alteratively check out a video tutorial on how to run Dalle-mini on [Linux](https://www.youtube.com/watch?v=eWpzLIa6v9E) and [Windows](https://www.youtube.com/watch?v=OqEuEe-xSKk)
Docker/build_docker.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ docker build . -t dalle-mini:latest
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021 The DALL·E mini Authors
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .PHONY: style
2
+
3
+ style:
4
+ black .
5
+ isort .
README.md ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DALL·E Mini
2
+
3
+ <a href="https://www.craiyon.com/"><img src="https://www.craiyon.com/thumbnail.png" width="300"></a>
4
+
5
+ ## How to use it?
6
+
7
+ You can use the model on [🖍️ craiyon](https://www.craiyon.com/)
8
+
9
+ ## How does it work?
10
+
11
+ Refer to our reports:
12
+
13
+ * [DALL·E mini - Generate Images from Any Text Prompt](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini-Generate-images-from-any-text-prompt--VmlldzoyMDE4NDAy)
14
+ * [DALL·E mini - Explained](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mini-Explained-with-Demo--Vmlldzo4NjIxODA)
15
+ * [DALL·E mega - Training Journal](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-Mega-Training-Journal--VmlldzoxODMxMDI2)
16
+
17
+ ## Development
18
+
19
+ ### Dependencies Installation
20
+
21
+ For inference only, use `pip install dalle-mini`.
22
+
23
+ For development, clone the repo and use `pip install -e ".[dev]"`.
24
+ Before making a PR, check style with `make style`.
25
+
26
+ You can experiment with the pipeline step by step through our [`inference pipeline notebook`](tools/inference/inference_pipeline.ipynb)
27
+
28
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb)
29
+
30
+ ### Training of DALL·E mini
31
+
32
+ Use [`tools/train/train.py`](tools/train/train.py).
33
+
34
+ You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
35
+
36
+ ## FAQ
37
+
38
+ ### Where to find the latest models?
39
+
40
+ Trained models are on 🤗 Model Hub:
41
+
42
+ * [VQGAN-f16-16384](https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384) for encoding/decoding images
43
+ * [DALL·E mini](https://huggingface.co/dalle-mini/dalle-mini) or [DALL·E mega](https://huggingface.co/dalle-mini/dalle-mega) for generating images from a text prompt
44
+
45
+ ### Where does the logo come from?
46
+
47
+ The "armchair in the shape of an avocado" was used by OpenAI when releasing DALL·E to illustrate the model's capabilities. Having successful predictions on this prompt represents a big milestone for us.
48
+
49
+ ## Contributing
50
+
51
+ Join the community on the [LAION Discord](https://discord.gg/xBPBXfcFHd).
52
+ Any contribution is welcome, from reporting issues to proposing fixes/improvements or testing the model with cool prompts!
53
+
54
+ You can also use these great projects from the community:
55
+
56
+ * spin off your own app with [DALL-E Playground repository](https://github.com/saharmor/dalle-playground) (thanks [Sahar](https://twitter.com/theaievangelist))
57
+
58
+ * try [DALL·E Flow](https://github.com/jina-ai/dalle-flow) project for generating, diffusion, and upscaling in a Human-in-the-Loop workflow (thanks [Han Xiao](https://github.com/hanxiao))
59
+
60
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jina-ai/dalle-flow/blob/main/client.ipynb)
61
+
62
+ * run on [Replicate](https://replicate.com/borisdayma/dalle-mini), in the browser or via API
63
+
64
+ ## Acknowledgements
65
+
66
+ * 🤗 Hugging Face for organizing [the FLAX/JAX community week](https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects)
67
+ * Google [TPU Research Cloud (TRC) program](https://sites.research.google/trc/) for providing computing resources
68
+ * [Weights & Biases](https://wandb.com/) for providing the infrastructure for experiment tracking and model management
69
+
70
+ ## Authors & Contributors
71
+
72
+ DALL·E mini was initially developed by:
73
+
74
+ * [Boris Dayma](https://github.com/borisdayma)
75
+ * [Suraj Patil](https://github.com/patil-suraj)
76
+ * [Pedro Cuenca](https://github.com/pcuenca)
77
+ * [Khalid Saifullah](https://github.com/khalidsaifullaah)
78
+ * [Tanishq Abraham](https://github.com/tmabraham)
79
+ * [Phúc Lê Khắc](https://github.com/lkhphuc)
80
+ * [Luke Melas](https://github.com/lukemelas)
81
+ * [Ritobrata Ghosh](https://github.com/ghosh-r)
82
+
83
+ Many thanks to the people who helped make it better:
84
+
85
+ * the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
86
+ * [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer and always giving great suggestions
87
+ * [Phil Wang](https://github.com/lucidrains) has provided a lot of cool implementations of transformer variants and gives interesting insights with [x-transformers](https://github.com/lucidrains/x-transformers)
88
+ * [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
89
+ * the [Gradio team](https://gradio.app/) made an amazing UI for our app
90
+
91
+ ## Citing DALL·E mini
92
+
93
+ If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
94
+
95
+ ```text
96
+ @misc{Dayma_DALL·E_Mini_2021,
97
+ author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
98
+ doi = {10.5281/zenodo.5146400},
99
+ month = {7},
100
+ title = {DALL·E Mini},
101
+ url = {https://github.com/borisdayma/dalle-mini},
102
+ year = {2021}
103
+ }
104
+ ```
105
+
106
+ ## References
107
+
108
+ Original DALL·E from "[Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092)" with image quantization from "[Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)".
109
+
110
+ Image encoder from "[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841v2)".
111
+
112
+ Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461v1)" with implementation of a few variants:
113
+
114
+ * "[GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)"
115
+ * "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
116
+ * "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
117
+ * "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
118
+ * "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
119
+ * "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
120
+ * "[Sinkformers: Transformers with Doubly Stochastic Attention](https://arxiv.org/abs/2110.11773)"
121
+ * "[Foundation Transformers](https://arxiv.org/abs/2210.06423)
122
+
123
+ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
124
+
125
+ ### Citations
126
+
127
+ ```text
128
+ @misc{
129
+ title={Zero-Shot Text-to-Image Generation},
130
+ author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
131
+ year={2021},
132
+ eprint={2102.12092},
133
+ archivePrefix={arXiv},
134
+ primaryClass={cs.CV}
135
+ }
136
+ ```
137
+
138
+ ```text
139
+ @misc{
140
+ title={Learning Transferable Visual Models From Natural Language Supervision},
141
+ author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
142
+ year={2021},
143
+ eprint={2103.00020},
144
+ archivePrefix={arXiv},
145
+ primaryClass={cs.CV}
146
+ }
147
+ ```
148
+
149
+ ```text
150
+ @misc{
151
+ title={Taming Transformers for High-Resolution Image Synthesis},
152
+ author={Patrick Esser and Robin Rombach and Björn Ommer},
153
+ year={2021},
154
+ eprint={2012.09841},
155
+ archivePrefix={arXiv},
156
+ primaryClass={cs.CV}
157
+ }
158
+ ```
159
+
160
+ ```text
161
+ @misc{
162
+ title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
163
+ author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
164
+ year={2019},
165
+ eprint={1910.13461},
166
+ archivePrefix={arXiv},
167
+ primaryClass={cs.CL}
168
+ }
169
+ ```
170
+
171
+ ```text
172
+ @misc{
173
+ title={Scalable Second Order Optimization for Deep Learning},
174
+ author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
175
+ year={2021},
176
+ eprint={2002.09018},
177
+ archivePrefix={arXiv},
178
+ primaryClass={cs.LG}
179
+ }
180
+ ```
181
+
182
+ ```text
183
+ @misc{
184
+ title={GLU Variants Improve Transformer},
185
+ author={Noam Shazeer},
186
+ year={2020},
187
+ url={https://arxiv.org/abs/2002.05202}
188
+ }
189
+ ```
190
+
191
+ ```text
192
+ @misc{
193
+ title={DeepNet: Scaling transformers to 1,000 layers},
194
+ author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
195
+ year={2022},
196
+ eprint={2203.00555}
197
+ archivePrefix={arXiv},
198
+ primaryClass={cs.LG}
199
+ }
200
+ ```
201
+
202
+ ```text
203
+ @misc{
204
+ title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
205
+ author={Sam Shleifer and Jason Weston and Myle Ott},
206
+ year={2021},
207
+ eprint={2110.09456},
208
+ archivePrefix={arXiv},
209
+ primaryClass={cs.CL}
210
+ }
211
+ ```
212
+
213
+ ```text
214
+ @inproceedings{
215
+ title={Swin Transformer V2: Scaling Up Capacity and Resolution},
216
+ author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
217
+ booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
218
+ year={2022}
219
+ }
220
+ ```
221
+
222
+ ```text
223
+ @misc{
224
+ title = {CogView: Mastering Text-to-Image Generation via Transformers},
225
+ author = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
226
+ year = {2021},
227
+ eprint = {2105.13290},
228
+ archivePrefix = {arXiv},
229
+ primaryClass = {cs.CV}
230
+ }
231
+ ```
232
+
233
+ ```text
234
+ @misc{
235
+ title = {Root Mean Square Layer Normalization},
236
+ author = {Biao Zhang and Rico Sennrich},
237
+ year = {2019},
238
+ eprint = {1910.07467},
239
+ archivePrefix = {arXiv},
240
+ primaryClass = {cs.LG}
241
+ }
242
+ ```
243
+
244
+ ```text
245
+ @misc{
246
+ title = {Sinkformers: Transformers with Doubly Stochastic Attention},
247
+ url = {https://arxiv.org/abs/2110.11773},
248
+ author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
249
+ publisher = {arXiv},
250
+ year = {2021},
251
+ }
252
+ ```
253
+
254
+ ```text
255
+ @misc{
256
+ title = {Smooth activations and reproducibility in deep networks},
257
+ url = {https://arxiv.org/abs/2010.09931},
258
+ author = {Shamir, Gil I. and Lin, Dong and Coviello, Lorenzo},
259
+ publisher = {arXiv},
260
+ year = {2020},
261
+ }
262
+ ```
263
+
264
+ ```text
265
+ @misc{
266
+ title = {Foundation Transformers},
267
+ url = {https://arxiv.org/abs/2210.06423},
268
+ author = {Wang, Hongyu and Ma, Shuming and Huang, Shaohan and Dong, Li and Wang, Wenhui and Peng, Zhiliang and Wu, Yu and Bajaj, Payal and Singhal, Saksham and Benhaim, Alon and Patra, Barun and Liu, Zhun and Chaudhary, Vishrav and Song, Xia and Wei, Furu},
269
+ publisher = {arXiv},
270
+ year = {2022},
271
+ }
272
+ ```
app/gradio/__pycache__/backend.cpython-311.pyc ADDED
Binary file (2.1 kB). View file
 
app/gradio/app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import os
4
+
5
+ import gradio as gr
6
+ from backend import get_images_from_backend
7
+
8
+ block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
9
+ backend_url = os.environ["BACKEND_SERVER"] + "/generate"
10
+
11
+
12
+ def infer(prompt):
13
+ response = get_images_from_backend(prompt, backend_url)
14
+ return response["images"]
15
+
16
+
17
+ with block:
18
+ gr.Markdown("<h1><center>DALL·E mini</center></h1>")
19
+ gr.Markdown(
20
+ "DALL·E mini is an AI model that generates images from any prompt you give!"
21
+ )
22
+ with gr.Group():
23
+ with gr.Box():
24
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
25
+ text = gr.Textbox(
26
+ label="Enter your prompt", show_label=False, max_lines=1
27
+ ).style(
28
+ border=(True, False, True, True),
29
+ margin=False,
30
+ rounded=(True, False, False, True),
31
+ container=False,
32
+ )
33
+ btn = gr.Button("Run").style(
34
+ margin=False,
35
+ rounded=(False, True, True, False),
36
+ )
37
+ gallery = gr.Gallery(label="Generated images", show_label=False).style(
38
+ grid=[3], height="auto"
39
+ )
40
+ text.submit(infer, inputs=text, outputs=gallery)
41
+ btn.click(infer, inputs=text, outputs=gallery)
42
+
43
+ gr.Markdown(
44
+ """___
45
+ <p style='text-align: center'>
46
+ Created by <a href="https://twitter.com/borisdayma" target="_blank">Boris Dayma</a> et al. 2021-2022
47
+ <br/>
48
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini-Generate-images-from-any-text-prompt--VmlldzoyMDE4NDAy" target="_blank">Project Report</a>
49
+ </p>"""
50
+ )
51
+
52
+
53
+ block.launch(enable_queue=False)
app/gradio/backend.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Client requests to Dalle-Mini Backend server
2
+
3
+ import base64
4
+ from io import BytesIO
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+
10
+ class ServiceError(Exception):
11
+ def __init__(self, status_code):
12
+ self.status_code = status_code
13
+
14
+
15
+ def get_images_from_backend(prompt, backend_url):
16
+ r = requests.post(backend_url, json={"prompt": prompt})
17
+ if r.status_code == 200:
18
+ json = r.json()
19
+ images = json["images"]
20
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
21
+ version = json.get("version", "unknown")
22
+ return {"images": images, "version": version}
23
+ else:
24
+ raise ServiceError(r.status_code)
25
+
26
+
27
+ def get_model_version(url):
28
+ r = requests.get(url)
29
+ if r.status_code == 200:
30
+ version = r.json()["version"]
31
+ return version
32
+ else:
33
+ raise ServiceError(r.status_code)
app/streamlit/__pycache__/backend.cpython-311.pyc ADDED
Binary file (2.1 kB). View file
 
app/streamlit/app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import streamlit as st
5
+ from backend import ServiceError, get_images_from_backend
6
+
7
+ st.sidebar.markdown(
8
+ """
9
+ <style>
10
+ .aligncenter {
11
+ text-align: center;
12
+ }
13
+ </style>
14
+ <p class="aligncenter">
15
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
16
+ </p>
17
+ """,
18
+ unsafe_allow_html=True,
19
+ )
20
+ st.sidebar.markdown(
21
+ """
22
+ ___
23
+ <p style='text-align: center'>
24
+ DALL·E mini is an AI model that generates images from any prompt you give!
25
+ </p>
26
+
27
+ <p style='text-align: center'>
28
+ Created by Boris Dayma et al. 2021-2022
29
+ <br/>
30
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
31
+ </p>
32
+ """,
33
+ unsafe_allow_html=True,
34
+ )
35
+
36
+ st.header("DALL·E mini")
37
+ st.subheader("Generate images from text")
38
+
39
+ prompt = st.text_input("What do you want to see?")
40
+
41
+ DEBUG = False
42
+ if prompt != "":
43
+ container = st.empty()
44
+ container.markdown(
45
+ f"""
46
+ <style> p {{ margin:0 }} div {{ margin:0 }} </style>
47
+ <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
48
+ <div class="stAlert">
49
+ <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
50
+ <div class="st-b7">
51
+ <div class="css-whx05o e13vu3m50">
52
+ <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
53
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
54
+ Generating predictions for: <b>{prompt}</b>
55
+ </div>
56
+ </div>
57
+ </div>
58
+ </div>
59
+ </div>
60
+ </div>
61
+ <small><i>Predictions may take up to 5mn under high load. Please stand by.</i></small>
62
+ """,
63
+ unsafe_allow_html=True,
64
+ )
65
+
66
+ try:
67
+ backend_url = st.secrets["BACKEND_SERVER"] + "/generate"
68
+ response = get_images_from_backend(prompt, backend_url)
69
+ selected = response["images"]
70
+ version = response["version"]
71
+
72
+ margin = 0.1 # for better position of zoom in arrow
73
+ n_columns = 3
74
+ cols = st.columns([1] + [margin, 1] * (n_columns - 1))
75
+ for i, img in enumerate(selected):
76
+ cols[(i % n_columns) * 2].image(img)
77
+ container.markdown(f"**{prompt}**")
78
+
79
+ # st.sidebar.markdown(
80
+ # f"<small><center>{version}</center></small>", unsafe_allow_html=True
81
+ # )
82
+
83
+ # st.markdown(
84
+ # f"""
85
+ # These results have been obtained using model `{version}` from [an ongoing training run](https://wandb.ai/dalle-mini/dalle-mini/runs/mheh9e55).
86
+ # """
87
+ # )
88
+
89
+ st.button("Again!", key="again_button")
90
+
91
+ except ServiceError as error:
92
+ container.text(f"Service unavailable, status: {error.status_code}")
93
+ except KeyError:
94
+ if DEBUG:
95
+ container.markdown(
96
+ """
97
+ **Error: BACKEND_SERVER unset**
98
+
99
+ Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
100
+ ```
101
+ BACKEND_SERVER="<server url>"
102
+ ```
103
+ """
104
+ )
105
+ else:
106
+ container.markdown(
107
+ "Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
108
+ )
app/streamlit/backend.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Client requests to Dalle-Mini Backend server
2
+
3
+ import base64
4
+ from io import BytesIO
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+
10
+ class ServiceError(Exception):
11
+ def __init__(self, status_code):
12
+ self.status_code = status_code
13
+
14
+
15
+ def get_images_from_backend(prompt, backend_url):
16
+ r = requests.post(backend_url, json={"prompt": prompt})
17
+ if r.status_code == 200:
18
+ json = r.json()
19
+ images = json["images"]
20
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
21
+ version = json.get("version", "unknown")
22
+ return {"images": images, "version": version}
23
+ else:
24
+ raise ServiceError(r.status_code)
25
+
26
+
27
+ def get_model_version(url):
28
+ r = requests.get(url)
29
+ if r.status_code == 200:
30
+ version = r.json()["version"]
31
+ return version
32
+ else:
33
+ raise ServiceError(r.status_code)
app/streamlit/img/loading.gif ADDED
img/logo.png ADDED
pyproject.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [tool.isort]
2
+ profile = "black"
run_docker_image.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # This script is used to run the docker image. Change or remove GPU flag if you dont have nvidia-docker or the needed GPUs
4
+ docker run --rm --name dallemini -it -p 8888:8888 --gpus all -v "${PWD}":/workspace dalle-mini:latest
setup.cfg ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [metadata]
2
+ name = dalle-mini
3
+ version = attr: dalle_mini.__version__
4
+ author = Boris Dayma et al.
5
+ author_email = boris.dayma@gmail.com
6
+ description = DALL·E mini - Generate images from a text prompt
7
+ long_description = file: README.md
8
+ long_description_content_type = text/markdown
9
+ url = https://github.com/borisdayma/dalle-mini
10
+ project_urls =
11
+ Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
12
+ classifiers =
13
+ Programming Language :: Python :: 3
14
+ License :: OSI Approved :: Apache Software License
15
+ Operating System :: OS Independent
16
+ Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Development Status :: 3 - Alpha
18
+ Intended Audience :: Developers
19
+
20
+ [options]
21
+ package_dir =
22
+ =src
23
+ packages = find:
24
+ python_requires = >=3.6
25
+ install_requires =
26
+ transformers==4.25.1
27
+ einops
28
+ unidecode
29
+ ftfy
30
+ emoji
31
+ pillow
32
+ jax==0.3.25
33
+ flax==0.6.3
34
+ orbax==0.0.23
35
+ wandb
36
+
37
+ [options.extras_require]
38
+ dev =
39
+ tqdm
40
+ optax
41
+ braceexpand
42
+ datasets[streaming]
43
+ black[jupyter]
44
+ isort
45
+
46
+ [options.packages.find]
47
+ where = src
setup.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ if __name__ == "__main__":
4
+ setup()
src/dalle_mini/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __version__ = "0.1.5"
2
+
3
+ from .model import DalleBart, DalleBartProcessor
src/dalle_mini/data.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+ from pathlib import Path
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from braceexpand import braceexpand
10
+ from datasets import Dataset, load_dataset
11
+
12
+ from .model.text import TextNormalizer
13
+
14
+
15
+ @dataclass
16
+ class Dataset:
17
+ dataset_repo_or_path: str
18
+ train_file: str = None
19
+ validation_file: str = None
20
+ streaming: bool = True
21
+ use_auth_token: bool = False
22
+ text_column: str = "caption"
23
+ encoding_column: str = "encoding"
24
+ max_train_samples: int = None
25
+ max_eval_samples: int = None
26
+ preprocessing_num_workers: int = None
27
+ overwrite_cache: bool = False
28
+ do_train: bool = False
29
+ do_eval: bool = True
30
+ seed_dataset: int = None
31
+ shard_by_host: bool = False
32
+ blank_caption_prob: float = 0.0
33
+ clip_score_column: str = "clip_score"
34
+ min_clip_score: float = None
35
+ max_clip_score: float = None
36
+ filter_column: str = None
37
+ filter_value: str = None
38
+ multi_eval_ds: bool = False
39
+ train_dataset: Dataset = field(init=False)
40
+ eval_dataset: Dataset = field(init=False)
41
+ other_eval_datasets: list = field(init=False)
42
+ rng_dataset: jnp.ndarray = field(init=False)
43
+ multi_hosts: bool = field(init=False)
44
+
45
+ def __post_init__(self):
46
+ if self.seed_dataset is None:
47
+ # create a random seed
48
+ self.seed_dataset = random.randint(0, 2**32 - 1)
49
+ # set numpy rng
50
+ self.np_rng = np.random.default_rng(self.seed_dataset)
51
+ self.multi_hosts = jax.process_count() > 1
52
+ # feed blank captions only in streaming mode for now
53
+ # otherwise dataset could be cached with same blanked captions
54
+ if self.blank_caption_prob:
55
+ assert (
56
+ self.streaming is True
57
+ ), "blank_caption_prob can only be used in streaming mode"
58
+ # define data_files
59
+ if self.train_file is not None or self.validation_file is not None:
60
+ # accept braceexpand notation
61
+ for k in ["train_file", "validation_file"]:
62
+ f = getattr(self, k)
63
+ if isinstance(f, str):
64
+ setattr(self, k, list(braceexpand(f)))
65
+ # for list of files, split training data shards by host
66
+ if (
67
+ isinstance(self.train_file, list)
68
+ and self.multi_hosts
69
+ and self.shard_by_host
70
+ ):
71
+ self.train_file = self.train_file[
72
+ jax.process_index() :: jax.process_count()
73
+ ]
74
+ data_files = {
75
+ "train": self.train_file,
76
+ "validation": self.validation_file,
77
+ }
78
+ else:
79
+ data_files = None
80
+
81
+ # multiple validation datasets
82
+ if self.multi_eval_ds:
83
+ assert Path(
84
+ self.dataset_repo_or_path
85
+ ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
86
+ data_files = {
87
+ split.name: [str(f) for f in split.glob("*.parquet")]
88
+ for split in Path(self.dataset_repo_or_path).glob("*")
89
+ }
90
+ # rename "valid" to "validation" if present for consistency
91
+ if "valid" in data_files:
92
+ data_files["validation"] = data_files["valid"]
93
+ del data_files["valid"]
94
+ self.dataset_repo_or_path = "parquet"
95
+
96
+ # load dataset
97
+ dataset = load_dataset(
98
+ self.dataset_repo_or_path,
99
+ data_files=data_files,
100
+ streaming=self.streaming,
101
+ use_auth_token=self.use_auth_token,
102
+ )
103
+ if self.do_train:
104
+ if "train" not in dataset:
105
+ raise ValueError("Training requires a training dataset")
106
+ self.train_dataset = dataset["train"]
107
+ if self.max_train_samples is not None:
108
+ self.train_dataset = (
109
+ self.train_dataset.take(self.max_train_samples)
110
+ if self.streaming
111
+ else self.train_dataset.select(range(self.max_train_samples))
112
+ )
113
+ if self.do_eval:
114
+ if "validation" not in dataset:
115
+ raise ValueError("Evaluating requires a validation dataset")
116
+ self.eval_dataset = dataset["validation"]
117
+ if self.max_eval_samples is not None:
118
+ self.eval_dataset = (
119
+ self.eval_dataset.take(self.max_eval_samples)
120
+ if self.streaming
121
+ else self.eval_dataset.select(range(self.max_eval_samples))
122
+ )
123
+ # other eval datasets
124
+ other_eval_splits = dataset.keys() - {"train", "validation"}
125
+ self.other_eval_datasets = {
126
+ split: dataset[split] for split in other_eval_splits
127
+ }
128
+
129
+ def preprocess(self, tokenizer, config):
130
+ # get required config variables
131
+ decoder_start_token_id = config.decoder_start_token_id
132
+ normalize_text = config.normalize_text
133
+ max_length = config.max_text_length
134
+
135
+ if self.streaming:
136
+ # we need to shuffle early in streaming mode
137
+ if hasattr(self, "train_dataset"):
138
+ self.train_dataset = self.train_dataset.shuffle(
139
+ buffer_size=5000, seed=self.seed_dataset
140
+ )
141
+ else:
142
+ self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
143
+
144
+ # filter data
145
+ partial_filter_function = partial(
146
+ filter_function,
147
+ filter_column=self.filter_column,
148
+ filter_value=self.filter_value,
149
+ clip_score_column=self.clip_score_column,
150
+ min_clip_score=self.min_clip_score,
151
+ max_clip_score=self.max_clip_score,
152
+ )
153
+ for ds in ["train_dataset", "eval_dataset"]:
154
+ if hasattr(self, ds):
155
+ setattr(
156
+ self,
157
+ ds,
158
+ (
159
+ getattr(self, ds).filter(partial_filter_function)
160
+ if self.streaming
161
+ else getattr(self, ds).filter(
162
+ partial_filter_function,
163
+ num_proc=self.preprocessing_num_workers,
164
+ load_from_cache_file=not self.overwrite_cache,
165
+ desc="Filtering datasets",
166
+ )
167
+ ),
168
+ )
169
+ if hasattr(self, "other_eval_datasets"):
170
+ self.other_eval_datasets = {
171
+ split: (
172
+ ds.filter(partial_filter_function)
173
+ if self.streaming
174
+ else ds.filter(
175
+ partial_filter_function,
176
+ num_proc=self.preprocessing_num_workers,
177
+ load_from_cache_file=not self.overwrite_cache,
178
+ desc="Filtering datasets",
179
+ )
180
+ )
181
+ for split, ds in self.other_eval_datasets.items()
182
+ }
183
+
184
+ # normalize text
185
+ if normalize_text:
186
+ text_normalizer = TextNormalizer()
187
+ partial_normalize_function = partial(
188
+ normalize_function,
189
+ text_column=self.text_column,
190
+ text_normalizer=text_normalizer,
191
+ )
192
+ for ds in ["train_dataset", "eval_dataset"]:
193
+ if hasattr(self, ds):
194
+ setattr(
195
+ self,
196
+ ds,
197
+ (
198
+ getattr(self, ds).map(partial_normalize_function)
199
+ if self.streaming
200
+ else getattr(self, ds).map(
201
+ partial_normalize_function,
202
+ num_proc=self.preprocessing_num_workers,
203
+ load_from_cache_file=not self.overwrite_cache,
204
+ desc="Normalizing datasets",
205
+ )
206
+ ),
207
+ )
208
+ if hasattr(self, "other_eval_datasets"):
209
+ self.other_eval_datasets = {
210
+ split: (
211
+ ds.map(partial_normalize_function)
212
+ if self.streaming
213
+ else ds.map(
214
+ partial_normalize_function,
215
+ num_proc=self.preprocessing_num_workers,
216
+ load_from_cache_file=not self.overwrite_cache,
217
+ desc="Normalizing datasets",
218
+ )
219
+ )
220
+ for split, ds in self.other_eval_datasets.items()
221
+ }
222
+
223
+ # blank captions
224
+ if self.blank_caption_prob:
225
+ partial_blank_caption_function = partial(
226
+ blank_caption_function,
227
+ text_column=self.text_column,
228
+ blank_caption_prob=self.blank_caption_prob,
229
+ rng=self.np_rng,
230
+ )
231
+ if hasattr(self, "train_dataset"):
232
+ self.train_dataset = (
233
+ self.train_dataset.map(partial_blank_caption_function)
234
+ if self.streaming
235
+ else self.train_dataset.map(
236
+ partial_blank_caption_function,
237
+ num_proc=None
238
+ if self.seed_dataset
239
+ else self.preprocessing_num_workers,
240
+ load_from_cache_file=False,
241
+ desc="Blanking some captions",
242
+ )
243
+ )
244
+
245
+ # preprocess
246
+ partial_preprocess_function = partial(
247
+ preprocess_function,
248
+ tokenizer=tokenizer,
249
+ text_column=self.text_column,
250
+ encoding_column=self.encoding_column,
251
+ max_length=max_length,
252
+ decoder_start_token_id=decoder_start_token_id,
253
+ )
254
+ for ds in ["train_dataset", "eval_dataset"]:
255
+ if hasattr(self, ds):
256
+ setattr(
257
+ self,
258
+ ds,
259
+ (
260
+ getattr(self, ds).map(
261
+ partial_preprocess_function,
262
+ batched=True,
263
+ remove_columns=[
264
+ self.text_column,
265
+ self.encoding_column,
266
+ ],
267
+ )
268
+ if self.streaming
269
+ else getattr(self, ds).map(
270
+ partial_preprocess_function,
271
+ batched=True,
272
+ remove_columns=getattr(ds, "column_names"),
273
+ num_proc=self.preprocessing_num_workers,
274
+ load_from_cache_file=not self.overwrite_cache,
275
+ desc="Preprocessing datasets",
276
+ )
277
+ ),
278
+ )
279
+ if hasattr(self, "other_eval_datasets"):
280
+ self.other_eval_datasets = {
281
+ split: (
282
+ ds.map(
283
+ partial_preprocess_function,
284
+ batched=True,
285
+ remove_columns=[
286
+ self.text_column,
287
+ self.encoding_column,
288
+ ],
289
+ )
290
+ if self.streaming
291
+ else ds.map(
292
+ partial_preprocess_function,
293
+ batched=True,
294
+ remove_columns=getattr(ds, "column_names"),
295
+ num_proc=self.preprocessing_num_workers,
296
+ load_from_cache_file=not self.overwrite_cache,
297
+ desc="Preprocessing datasets",
298
+ )
299
+ )
300
+ for split, ds in self.other_eval_datasets.items()
301
+ }
302
+
303
+ def dataloader(self, split, batch_size, epoch=None):
304
+ def _dataloader_datasets_non_streaming(
305
+ dataset: Dataset,
306
+ rng: jax.random.PRNGKey = None,
307
+ ):
308
+ """
309
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
310
+ Shuffle batches if rng is set.
311
+ """
312
+ steps_per_epoch = len(dataset) // batch_size
313
+
314
+ if rng is not None:
315
+ batch_idx = jax.random.permutation(rng, len(dataset))
316
+ else:
317
+ batch_idx = jnp.arange(len(dataset))
318
+
319
+ batch_idx = batch_idx[
320
+ : steps_per_epoch * batch_size
321
+ ] # Skip incomplete batch.
322
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
323
+
324
+ for idx in batch_idx:
325
+ batch = dataset[idx]
326
+ batch = {k: jnp.array(v) for k, v in batch.items()}
327
+ yield batch
328
+
329
+ def _dataloader_datasets_streaming(
330
+ dataset: Dataset,
331
+ epoch: int,
332
+ ):
333
+ keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
334
+ batch = {k: [] for k in keys}
335
+ first_loop = True # stop after one loop in some cases
336
+ while (self.multi_hosts and split == "train") or first_loop:
337
+ # in multi-host, we run forever (no epoch) as hosts need to stop
338
+ # at the same time and training data may not be split equally
339
+ # For validation data we put the entire batch on each host and then
340
+ # keep only the one specific to each host (could be improved but not necessary)
341
+ if epoch is not None:
342
+ assert split == "train"
343
+ # reshuffle training data at each epoch
344
+ dataset.set_epoch(epoch)
345
+ epoch += 1
346
+ for item in dataset:
347
+ for k in keys:
348
+ batch[k].append(item[k])
349
+ if len(batch[keys[0]]) == batch_size:
350
+ batch = {k: jnp.array(v) for k, v in batch.items()}
351
+ yield batch
352
+ batch = {k: [] for k in keys}
353
+ first_loop = False
354
+
355
+ if split == "train":
356
+ ds = self.train_dataset
357
+ elif split == "eval":
358
+ ds = self.eval_dataset
359
+ else:
360
+ ds = self.other_eval_datasets[split]
361
+
362
+ if self.streaming:
363
+ return _dataloader_datasets_streaming(ds, epoch)
364
+ else:
365
+ if split == "train":
366
+ self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
367
+ return _dataloader_datasets_non_streaming(ds, input_rng)
368
+
369
+ @property
370
+ def length(self):
371
+ len_train_dataset, len_eval_dataset = None, None
372
+ if self.streaming:
373
+ # we don't know the length, let's just assume max_samples if defined
374
+ if self.max_train_samples is not None:
375
+ len_train_dataset = self.max_train_samples
376
+ if self.max_eval_samples is not None:
377
+ len_eval_dataset = self.max_eval_samples
378
+ else:
379
+ len_train_dataset = (
380
+ len(self.train_dataset) if hasattr(self, "train_dataset") else None
381
+ )
382
+ len_eval_dataset = (
383
+ len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
384
+ )
385
+ return len_train_dataset, len_eval_dataset
386
+
387
+
388
+ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
389
+ """
390
+ Shift input ids one token to the right.
391
+ """
392
+ shifted_input_ids = np.zeros(input_ids.shape)
393
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
394
+ shifted_input_ids[:, 0] = decoder_start_token_id
395
+ return shifted_input_ids
396
+
397
+
398
+ def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
399
+ if (
400
+ blank_caption_prob
401
+ and (rng.random() if rng is not None else np.random.random())
402
+ < blank_caption_prob
403
+ ):
404
+ example[text_column] = ""
405
+ return example
406
+
407
+
408
+ def normalize_function(example, text_column, text_normalizer):
409
+ example[text_column] = text_normalizer(example[text_column])
410
+ return example
411
+
412
+
413
+ def filter_function(
414
+ example,
415
+ min_clip_score,
416
+ max_clip_score,
417
+ clip_score_column,
418
+ filter_column,
419
+ filter_value,
420
+ ):
421
+ if min_clip_score is not None and example[clip_score_column] < min_clip_score:
422
+ return False
423
+ if max_clip_score is not None and example[clip_score_column] > max_clip_score:
424
+ return False
425
+ if filter_column is not None and example[filter_column] != filter_value:
426
+ return False
427
+ return True
428
+
429
+
430
+ def preprocess_function(
431
+ examples,
432
+ tokenizer,
433
+ text_column,
434
+ encoding_column,
435
+ max_length,
436
+ decoder_start_token_id,
437
+ ):
438
+ inputs = examples[text_column]
439
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
440
+ model_inputs = tokenizer(
441
+ inputs,
442
+ max_length=max_length,
443
+ padding="max_length",
444
+ truncation=True,
445
+ return_tensors="np",
446
+ )
447
+
448
+ # set up targets
449
+ # Note: labels correspond to our target indices
450
+ # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
451
+ labels = examples[encoding_column]
452
+ labels = np.asarray(labels)
453
+
454
+ # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
455
+ model_inputs["labels"] = labels
456
+
457
+ # In our case, this prepends the bos token and removes the last one
458
+ decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
459
+ model_inputs["decoder_input_ids"] = decoder_input_ids
460
+
461
+ return model_inputs
src/dalle_mini/model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .configuration import DalleBartConfig
2
+ from .modeling import DalleBart
3
+ from .partitions import set_partitions
4
+ from .processor import DalleBartProcessor
5
+ from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/configuration.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model configuration """
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ from .utils import PretrainedFromWandbMixin
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
27
+ model_type = "dallebart"
28
+ keys_to_ignore_at_inference = ["past_key_values"]
29
+ attribute_map = {
30
+ "num_attention_heads": "encoder_attention_heads",
31
+ "hidden_size": "d_model",
32
+ }
33
+
34
+ def __init__(
35
+ self,
36
+ normalize_text=False,
37
+ encoder_vocab_size=50264,
38
+ image_vocab_size=16384, # encoded image token space
39
+ image_length=256, # number of encoded tokens
40
+ max_text_length=64, # max number of text tokens
41
+ encoder_layers=12,
42
+ encoder_ffn_dim=4096,
43
+ encoder_attention_heads=16,
44
+ decoder_layers=12,
45
+ decoder_ffn_dim=4096,
46
+ decoder_attention_heads=16,
47
+ activation_function="gelu",
48
+ d_model=1024,
49
+ dropout=0.1,
50
+ attention_dropout=0.0,
51
+ activation_dropout=0.0,
52
+ init_std=0.02,
53
+ scale_embedding=False,
54
+ gradient_checkpointing=True,
55
+ use_scan=None,
56
+ use_cache=True,
57
+ is_encoder_decoder=True,
58
+ forced_eos_token_id=None,
59
+ tie_word_embeddings=False, # different modalities and sizes
60
+ do_sample=True,
61
+ # transformer variants
62
+ use_bias=False, # use bias in attention and dense layers (except for lm_head)
63
+ ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
64
+ ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln), "subln"
65
+ use_head_scale=False, # used in NormFormer
66
+ use_cosine_attention=False, # used in Swin v2
67
+ tau_init=0.05, # used only in cosine attention (Swin v2)
68
+ use_absolute_position_embeddings=True, # default
69
+ use_swin_position_embeddings=False, # used in Swin v1/v2
70
+ use_deepnet_scaling=False, # used in Deepnet
71
+ use_subln_init=False,
72
+ use_glu=True, # "GLU Variants Improve Transformer"
73
+ use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
74
+ sinkhorn_iters=1, # used in SinkFormers
75
+ use_final_ln_encoder=True, # final layer normalization in encoder
76
+ use_final_ln_decoder=True, # final layer normalization in decoder
77
+ # parameters that should not be necessary but could affect results
78
+ force_ln_scale=False, # force scale in layernorm even when followed by dense layers
79
+ **kwargs,
80
+ ):
81
+ # text normalizer
82
+ self.normalize_text = normalize_text
83
+
84
+ # transformer variants
85
+ self.use_bias = use_bias
86
+ assert ln_type in [
87
+ "rmsnorm",
88
+ "layernorm",
89
+ ], "ln_type must be 'rmsnorm' or 'layernorm'"
90
+ self.ln_type = ln_type
91
+ if ln_positions == "deepnet":
92
+ ln_positions = "postln"
93
+ assert ln_positions in [
94
+ "normformer",
95
+ "swinv2",
96
+ "cogview",
97
+ "postln",
98
+ "preln",
99
+ "subln",
100
+ ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln', 'subln'"
101
+ self.use_head_scale = use_head_scale
102
+ assert use_alibi is False, "use_alibi is not supported yet"
103
+ self.ln_positions = ln_positions
104
+ self.use_cosine_attention = use_cosine_attention
105
+ self.tau_init = tau_init
106
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
107
+ self.use_swin_position_embeddings = use_swin_position_embeddings
108
+ self.use_deepnet_scaling = use_deepnet_scaling
109
+ self.use_subln_init = use_subln_init
110
+ self.use_glu = use_glu
111
+ self.use_alibi = use_alibi
112
+ self.sinkhorn_iters = sinkhorn_iters
113
+ if ln_positions == "postln":
114
+ assert (
115
+ use_final_ln_encoder
116
+ ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
117
+ assert (
118
+ use_final_ln_decoder
119
+ ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
120
+ self.use_final_ln_encoder = use_final_ln_encoder
121
+ self.use_final_ln_decoder = use_final_ln_decoder
122
+ self.force_ln_scale = force_ln_scale
123
+
124
+ # common parameters
125
+ self.encoder_vocab_size = encoder_vocab_size
126
+ self.image_vocab_size = image_vocab_size
127
+ self.image_length = image_length
128
+ self.max_text_length = max_text_length
129
+ self.d_model = d_model
130
+ self.encoder_ffn_dim = encoder_ffn_dim
131
+ self.encoder_layers = encoder_layers
132
+ self.encoder_attention_heads = encoder_attention_heads
133
+ self.decoder_ffn_dim = decoder_ffn_dim
134
+ self.decoder_layers = decoder_layers
135
+ self.decoder_attention_heads = decoder_attention_heads
136
+ self.dropout = dropout
137
+ self.attention_dropout = attention_dropout
138
+ self.activation_dropout = activation_dropout
139
+ self.activation_function = activation_function
140
+ self.init_std = init_std
141
+ self.use_cache = use_cache
142
+ self.gradient_checkpointing = gradient_checkpointing
143
+ # all layers are the same in most configurations
144
+ self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
145
+ assert not (
146
+ self.use_scan and ln_positions == "swinv2"
147
+ ), "scan cannot be used with 'swinv2'"
148
+ self.scale_embedding = (
149
+ scale_embedding # scale factor will be sqrt(d_model) if True
150
+ )
151
+
152
+ # special token id's are appended to vocab if not provided
153
+ decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
154
+ bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
155
+ pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
156
+ eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)
157
+
158
+ # we generate to image_length + 1 (for bos) by default
159
+ min_length = kwargs.pop("min_length", image_length + 1)
160
+ max_length = kwargs.pop("max_length", image_length + 1)
161
+
162
+ super().__init__(
163
+ # args required in parent class
164
+ is_encoder_decoder=is_encoder_decoder,
165
+ tie_word_embeddings=tie_word_embeddings,
166
+ forced_eos_token_id=forced_eos_token_id,
167
+ decoder_start_token_id=decoder_start_token_id,
168
+ bos_token_id=bos_token_id,
169
+ pad_token_id=pad_token_id,
170
+ eos_token_id=eos_token_id,
171
+ min_length=min_length,
172
+ max_length=max_length,
173
+ do_sample=do_sample,
174
+ **kwargs,
175
+ )
176
+
177
+ # ensure backward compatibility for BART CNN models
178
+ if self.forced_bos_token_id is None and kwargs.get(
179
+ "force_bos_token_to_be_generated", False
180
+ ):
181
+ self.forced_bos_token_id = self.bos_token_id
182
+ warnings.warn(
183
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
184
+ "The config can simply be saved and uploaded again to be fixed."
185
+ )
src/dalle_mini/model/modeling.py ADDED
@@ -0,0 +1,1953 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model. """
16
+
17
+ import math
18
+ from functools import partial
19
+ from typing import Any, Dict, Optional, Tuple
20
+
21
+ import flax
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from einops import rearrange
26
+ from flax.core.frozen_dict import unfreeze
27
+ from flax.linen import combine_masks, make_causal_mask
28
+ from flax.linen import partitioning as nn_partitioning
29
+ from flax.linen.linear import PrecisionLike
30
+ from flax.traverse_util import flatten_dict, unflatten_dict
31
+ from jax import custom_jvp, lax
32
+ from jax.random import PRNGKey
33
+ from transformers.modeling_flax_outputs import (
34
+ FlaxBaseModelOutput,
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxCausalLMOutputWithCrossAttentions,
37
+ FlaxSeq2SeqLMOutput,
38
+ )
39
+ from transformers.modeling_flax_utils import ACT2FN
40
+ from transformers.models.bart.modeling_flax_bart import (
41
+ FlaxBartAttention,
42
+ FlaxBartForConditionalGeneration,
43
+ FlaxBartForConditionalGenerationModule,
44
+ FlaxBartModule,
45
+ )
46
+ from transformers.utils import ModelOutput, logging
47
+
48
+ from .configuration import DalleBartConfig
49
+ from .utils import PretrainedFromWandbMixin
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ remat = nn_partitioning.remat
54
+
55
+
56
+ def smelu(beta: Any = 1.0):
57
+ """
58
+ Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
59
+ https://arxiv.org/abs/2202.06499
60
+ """
61
+
62
+ @custom_jvp
63
+ @jax.jit
64
+ def _smelu(x: Any) -> Any:
65
+ x = jnp.where(x <= -beta, 0.0, x)
66
+ return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))
67
+
68
+ _smelu.defjvps(
69
+ lambda g, ans, x: lax.select(
70
+ x == -beta,
71
+ lax.full_like(g, 0),
72
+ lax.select(x == beta, lax.full_like(g, 1), g),
73
+ )
74
+ )
75
+ return _smelu
76
+
77
+
78
+ ACT2FN.update({"smelu": smelu()})
79
+
80
+
81
+ # deepnet initialization
82
+ def deepnet_init(init_std, gain=1):
83
+ init = jax.nn.initializers.normal(init_std)
84
+
85
+ def _init(*args, **kwargs):
86
+ return gain * init(*args, **kwargs)
87
+
88
+ return _init
89
+
90
+
91
+ # deepnet gain
92
+ deepnet_gain = {
93
+ "encoder": {
94
+ "alpha": lambda config: 0.81
95
+ * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
96
+ "beta": lambda config: 0.87
97
+ * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
98
+ },
99
+ "decoder": {
100
+ "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
101
+ "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
102
+ },
103
+ }
104
+
105
+ # subln gain
106
+ subln_gain = {
107
+ "encoder": lambda config: math.sqrt(
108
+ 1.0
109
+ / 3.0
110
+ * math.log(3 * config.decoder_layers)
111
+ * math.log(2 * config.encoder_layers)
112
+ ),
113
+ "decoder": lambda config: math.sqrt(math.log(3 * config.decoder_layers)),
114
+ }
115
+
116
+
117
+ class RMSNorm(nn.Module):
118
+ """
119
+ From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
120
+
121
+ Adapted from flax.linen.LayerNorm
122
+ """
123
+
124
+ epsilon: float = 1e-6
125
+ dtype: Any = jnp.float32
126
+ param_dtype: Any = jnp.float32
127
+ use_scale: bool = True
128
+ scale_init: Any = jax.nn.initializers.ones
129
+
130
+ @nn.compact
131
+ def __call__(self, x):
132
+ reduction_axes = (-1,)
133
+ feature_axes = (-1,)
134
+
135
+ rms_sq = self._compute_rms_sq(x, reduction_axes)
136
+
137
+ return self._normalize(
138
+ self,
139
+ x,
140
+ rms_sq,
141
+ reduction_axes,
142
+ feature_axes,
143
+ self.dtype,
144
+ self.param_dtype,
145
+ self.epsilon,
146
+ self.use_scale,
147
+ self.scale_init,
148
+ )
149
+
150
+ def _compute_rms_sq(self, x, axes):
151
+ x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
152
+ rms_sq = jnp.mean(jax.lax.square(x), axes)
153
+ return rms_sq
154
+
155
+ def _normalize(
156
+ self,
157
+ mdl,
158
+ x,
159
+ rms_sq,
160
+ reduction_axes,
161
+ feature_axes,
162
+ dtype,
163
+ param_dtype,
164
+ epsilon,
165
+ use_scale,
166
+ scale_init,
167
+ ):
168
+ reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
169
+ feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
170
+ stats_shape = list(x.shape)
171
+ for axis in reduction_axes:
172
+ stats_shape[axis] = 1
173
+ rms_sq = rms_sq.reshape(stats_shape)
174
+ feature_shape = [1] * x.ndim
175
+ reduced_feature_shape = []
176
+ for ax in feature_axes:
177
+ feature_shape[ax] = x.shape[ax]
178
+ reduced_feature_shape.append(x.shape[ax])
179
+ mul = lax.rsqrt(rms_sq + epsilon)
180
+ if use_scale:
181
+ scale = mdl.param(
182
+ "scale", scale_init, reduced_feature_shape, param_dtype
183
+ ).reshape(feature_shape)
184
+ mul *= scale
185
+ y = mul * x
186
+ return jnp.asarray(y, dtype)
187
+
188
+
189
+ def norm(type, *args, **kwargs):
190
+ if type == "rmsnorm":
191
+ return RMSNorm(*args, **kwargs)
192
+ elif type == "layernorm":
193
+ return nn.LayerNorm(*args, **kwargs)
194
+ else:
195
+ raise ValueError(f"Unknown norm type {type}")
196
+
197
+
198
+ def dot_product_attention_weights(
199
+ query: Any,
200
+ key: Any,
201
+ bias: Optional[Any] = None,
202
+ mask: Optional[Any] = None,
203
+ embed_pos: Optional[Any] = None,
204
+ broadcast_dropout: bool = True,
205
+ dropout_rng: Optional[PRNGKey] = None,
206
+ dropout_rate: float = 0.0,
207
+ deterministic: bool = False,
208
+ dtype: Any = jnp.float32,
209
+ precision: PrecisionLike = None,
210
+ sinkhorn_iters: int = 1,
211
+ is_encoder: bool = False,
212
+ tau=None,
213
+ ):
214
+ """
215
+ Computes dot-product attention weights given query and key.
216
+ mask is included into the bias.
217
+
218
+ Adapted from flax.linen.attention.dot_product_attention_weights"
219
+ """
220
+ assert query.ndim == key.ndim, "q, k must have same rank."
221
+ assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
222
+ assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
223
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
224
+
225
+ # attn weight shape is (batch..., num_heads, q_length, kv_length)
226
+ attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
227
+
228
+ # divide by tau (used in Swin v2)
229
+ if tau is not None:
230
+ attn_weights = attn_weights / tau
231
+ else:
232
+ depth = query.shape[-1]
233
+ attn_weights = attn_weights / jnp.sqrt(depth).astype(dtype)
234
+
235
+ # apply attention bias: masking, dropout, proximity bias, etc.
236
+ if bias is not None:
237
+ attn_weights = attn_weights + bias
238
+
239
+ # add relative position
240
+ if embed_pos is not None:
241
+ attn_weights = attn_weights + embed_pos
242
+
243
+ # normalize the attention weights
244
+ if not is_encoder or sinkhorn_iters == 1:
245
+ # sinkhorn does not work for causal (leaks info of future tokens into past)
246
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
247
+ else:
248
+ # adapted from https://github.com/lucidrains/sinkhorn-transformer
249
+ for i in range(sinkhorn_iters):
250
+ # when causal, some attn_weights have been set to -inf through bias
251
+ if i % 2 == 0:
252
+ attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
253
+ else:
254
+ attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
255
+ if mask is not None:
256
+ attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
257
+ attn_weights = jnp.exp(attn_weights).astype(dtype)
258
+
259
+ # apply attention dropout
260
+ if not deterministic and dropout_rate > 0.0:
261
+ keep_prob = 1.0 - dropout_rate
262
+ if broadcast_dropout:
263
+ # dropout is broadcast across the batch + head dimensions
264
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
265
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
266
+ else:
267
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
268
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
269
+ keep_prob, dtype=dtype
270
+ )
271
+ attn_weights = attn_weights * multiplier
272
+
273
+ return attn_weights
274
+
275
+
276
+ class FlaxBartAttention(FlaxBartAttention):
277
+ """
278
+ Edits:
279
+ - causal mask is used only in decoder and considers image_length
280
+ - scale attention heads per NormFormer paper
281
+ """
282
+
283
+ is_encoder: bool = False
284
+ is_cross_attention: bool = False
285
+ q_length: int = None
286
+ k_length: int = None
287
+
288
+ def setup(self) -> None:
289
+ self.head_dim = self.embed_dim // self.num_heads
290
+ if self.head_dim * self.num_heads != self.embed_dim:
291
+ raise ValueError(
292
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
293
+ f" and `num_heads`: {self.num_heads})."
294
+ )
295
+
296
+ dense = partial(
297
+ nn.Dense,
298
+ self.embed_dim,
299
+ use_bias=self.bias,
300
+ dtype=self.dtype,
301
+ )
302
+
303
+ if self.config.use_deepnet_scaling:
304
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
305
+ self.config
306
+ )
307
+ elif self.config.use_subln_init and not self.is_cross_attention:
308
+ gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
309
+
310
+ self.q_proj = dense(
311
+ kernel_init=jax.nn.initializers.normal(self.config.init_std)
312
+ )
313
+ self.k_proj = dense(
314
+ kernel_init=jax.nn.initializers.normal(self.config.init_std)
315
+ )
316
+ self.v_proj = dense(
317
+ kernel_init=deepnet_init(self.config.init_std, gain)
318
+ if (
319
+ self.config.use_deepnet_scaling
320
+ or (self.config.use_subln_init and not self.is_cross_attention)
321
+ )
322
+ else jax.nn.initializers.normal(self.config.init_std)
323
+ )
324
+ self.out_proj = dense(
325
+ kernel_init=deepnet_init(self.config.init_std, gain)
326
+ if (
327
+ self.config.use_deepnet_scaling
328
+ or (self.config.use_subln_init and not self.is_cross_attention)
329
+ )
330
+ else jax.nn.initializers.normal(self.config.init_std)
331
+ )
332
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
333
+
334
+ if self.config.use_head_scale:
335
+ self.head_scale = self.param(
336
+ "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
337
+ )
338
+
339
+ if self.config.use_cosine_attention:
340
+ # TODO: try using a learnt scale, somehow it immediately diverges in my experiments
341
+ self.tau = self.config.tau_init
342
+
343
+ if self.config.use_swin_position_embeddings:
344
+ self.rel_bias = nn.Embed(
345
+ self.q_length,
346
+ self.k_length * self.num_heads,
347
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
348
+ )
349
+
350
+ if self.causal:
351
+ # used only in decoder
352
+ self.causal_mask = make_causal_mask(
353
+ jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
354
+ )
355
+
356
+ if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
357
+ self.mid_layernorm = norm(
358
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
359
+ )
360
+
361
+ def __call__(
362
+ self,
363
+ hidden_states: jnp.ndarray,
364
+ key_value_states: Optional[jnp.ndarray] = None,
365
+ attention_mask: Optional[jnp.ndarray] = None,
366
+ init_cache: bool = False,
367
+ deterministic: bool = True,
368
+ ) -> Tuple[jnp.ndarray]:
369
+ """Input shape: Batch x Time x Channel"""
370
+
371
+ # if key_value_states are provided this layer is used as a cross-attention layer
372
+ # for the decoder
373
+ is_cross_attention = key_value_states is not None
374
+ batch_size = hidden_states.shape[0]
375
+
376
+ # get query proj
377
+ query_states = self.q_proj(hidden_states)
378
+ # get key, value proj
379
+ if is_cross_attention:
380
+ # cross_attentions
381
+ key_states = self.k_proj(key_value_states)
382
+ value_states = self.v_proj(key_value_states)
383
+ else:
384
+ # self_attention
385
+ key_states = self.k_proj(hidden_states)
386
+ value_states = self.v_proj(hidden_states)
387
+
388
+ query_states = self._split_heads(query_states)
389
+ key_states = self._split_heads(key_states)
390
+ value_states = self._split_heads(value_states)
391
+
392
+ # handle cache prepare causal attention mask
393
+ if self.causal:
394
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
395
+ if self.has_variable("cache", "cached_key"):
396
+ mask_shift = self.variables["cache"]["cache_index"]
397
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
398
+ causal_mask = lax.dynamic_slice(
399
+ self.causal_mask,
400
+ (0, 0, mask_shift, 0),
401
+ (1, 1, query_length, max_decoder_length),
402
+ )
403
+ else:
404
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
405
+ causal_mask = jnp.broadcast_to(
406
+ causal_mask, (batch_size,) + causal_mask.shape[1:]
407
+ )
408
+
409
+ # combine masks if needed
410
+ if attention_mask is not None and self.causal:
411
+ attention_mask = jnp.broadcast_to(
412
+ jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
413
+ )
414
+ attention_mask = combine_masks(attention_mask, causal_mask)
415
+ elif self.causal:
416
+ attention_mask = causal_mask
417
+ elif attention_mask is not None:
418
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
419
+
420
+ # During fast autoregressive decoding, we feed one position at a time,
421
+ # and cache the keys and values step by step.
422
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
423
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
424
+ key_states, value_states, query_states, attention_mask
425
+ )
426
+
427
+ # Convert the boolean attention mask to an attention bias.
428
+ if attention_mask is not None:
429
+ # attention mask in the form of attention bias
430
+ attention_bias = lax.select(
431
+ attention_mask > 0,
432
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
433
+ jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
434
+ )
435
+ else:
436
+ attention_bias = None
437
+
438
+ dropout_rng = None
439
+ if not deterministic and self.dropout > 0.0:
440
+ dropout_rng = self.make_rng("dropout")
441
+
442
+ if self.config.use_cosine_attention:
443
+ # normalize q and k
444
+ query_states = query_states / (
445
+ jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
446
+ )
447
+ key_states = key_states / (
448
+ jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
449
+ )
450
+
451
+ # relative position embeddings
452
+ if self.config.use_swin_position_embeddings:
453
+ position_ids = jnp.arange(self.q_length)
454
+ embed_pos = self.rel_bias(position_ids)
455
+ embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
456
+ else:
457
+ embed_pos = None
458
+
459
+ tau = self.tau if self.config.use_cosine_attention else None
460
+ attn_weights = dot_product_attention_weights(
461
+ query_states,
462
+ key_states,
463
+ bias=attention_bias,
464
+ mask=attention_mask,
465
+ embed_pos=embed_pos,
466
+ dropout_rng=dropout_rng,
467
+ dropout_rate=self.dropout,
468
+ broadcast_dropout=True,
469
+ deterministic=deterministic,
470
+ dtype=self.dtype,
471
+ precision=None,
472
+ sinkhorn_iters=self.config.sinkhorn_iters,
473
+ is_encoder=self.is_encoder,
474
+ tau=tau,
475
+ )
476
+
477
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
478
+ if self.config.use_head_scale:
479
+ # per Normformer
480
+ attn_output = attn_output * self.head_scale
481
+ attn_output = self._merge_heads(attn_output)
482
+
483
+ if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
484
+ attn_output = self.mid_layernorm(attn_output)
485
+
486
+ attn_output = self.out_proj(attn_output)
487
+
488
+ return attn_output, attn_weights
489
+
490
+
491
+ class GLU(nn.Module):
492
+ """From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202"""
493
+
494
+ config: DalleBartConfig
495
+ ffn_dim: int
496
+ embed_dim: int
497
+ dtype: jnp.dtype = jnp.float32
498
+ is_encoder: bool = False
499
+
500
+ @nn.compact
501
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
502
+ if self.config.use_deepnet_scaling:
503
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
504
+ self.config
505
+ )
506
+ elif self.config.use_subln_init:
507
+ gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
508
+
509
+ if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
510
+ x = norm(
511
+ self.config.ln_type,
512
+ dtype=self.dtype,
513
+ epsilon=1e-05,
514
+ use_scale=self.config.force_ln_scale,
515
+ )(x)
516
+ w = nn.Dense(
517
+ self.ffn_dim,
518
+ dtype=self.dtype,
519
+ use_bias=self.config.use_bias,
520
+ kernel_init=deepnet_init(self.config.init_std, gain)
521
+ if (self.config.use_deepnet_scaling or self.config.use_subln_init)
522
+ else jax.nn.initializers.normal(self.config.init_std),
523
+ )(x)
524
+ w = ACT2FN[self.config.activation_function](w)
525
+ v = nn.Dense(
526
+ self.ffn_dim,
527
+ dtype=self.dtype,
528
+ use_bias=self.config.use_bias,
529
+ kernel_init=deepnet_init(self.config.init_std, gain)
530
+ if (self.config.use_deepnet_scaling or self.config.use_subln_init)
531
+ else jax.nn.initializers.normal(self.config.init_std),
532
+ )(x)
533
+ x = w * v
534
+ if self.config.ln_positions in ["normformer", "subln"]:
535
+ x = norm(
536
+ self.config.ln_type,
537
+ dtype=self.dtype,
538
+ epsilon=1e-05,
539
+ use_scale=self.config.force_ln_scale,
540
+ )(x)
541
+ x = nn.Dropout(rate=self.config.activation_dropout)(
542
+ x, deterministic=deterministic
543
+ )
544
+
545
+ x = nn.Dense(
546
+ self.embed_dim,
547
+ dtype=self.dtype,
548
+ use_bias=self.config.use_bias,
549
+ kernel_init=deepnet_init(self.config.init_std, gain)
550
+ if (self.config.use_deepnet_scaling or self.config.use_subln_init)
551
+ else jax.nn.initializers.normal(self.config.init_std),
552
+ )(x)
553
+ if self.config.ln_positions in ["swinv2", "cogview"]:
554
+ x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
555
+ x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
556
+ return x
557
+
558
+
559
+ class FFN(nn.Module):
560
+ """Simple FFN layer"""
561
+
562
+ config: DalleBartConfig
563
+ ffn_dim: int
564
+ embed_dim: int
565
+ dtype: jnp.dtype = jnp.float32
566
+ is_encoder: bool = False
567
+
568
+ @nn.compact
569
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
570
+ if self.config.use_deepnet_scaling:
571
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
572
+ self.config
573
+ )
574
+ elif self.config.use_subln_init:
575
+ gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)
576
+ if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
577
+ x = norm(
578
+ self.config.ln_type,
579
+ dtype=self.dtype,
580
+ epsilon=1e-05,
581
+ use_scale=self.config.force_ln_scale,
582
+ )(x)
583
+ x = nn.Dense(
584
+ self.ffn_dim,
585
+ dtype=self.dtype,
586
+ use_bias=self.config.use_bias,
587
+ kernel_init=deepnet_init(self.config.init_std, gain)
588
+ if (self.config.use_deepnet_scaling or self.config.use_subln_init)
589
+ else jax.nn.initializers.normal(self.config.init_std),
590
+ )(x)
591
+ x = ACT2FN[self.config.activation_function](x)
592
+ if self.config.ln_positions in ["normformer", "subln"]:
593
+ x = norm(
594
+ self.config.ln_type,
595
+ dtype=self.dtype,
596
+ epsilon=1e-05,
597
+ use_scale=self.config.force_ln_scale,
598
+ )(x)
599
+ x = nn.Dropout(rate=self.config.activation_dropout)(
600
+ x, deterministic=deterministic
601
+ )
602
+ x = nn.Dense(
603
+ self.embed_dim,
604
+ dtype=self.dtype,
605
+ use_bias=self.config.use_bias,
606
+ kernel_init=deepnet_init(self.config.init_std, gain)
607
+ if (self.config.use_deepnet_scaling or self.config.use_subln_init)
608
+ else jax.nn.initializers.normal(self.config.init_std),
609
+ )(x)
610
+ if self.config.ln_positions in ["swinv2", "cogview"]:
611
+ x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
612
+ x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
613
+ return x
614
+
615
+
616
+ class FlaxBartEncoderLayer(nn.Module):
617
+ """
618
+ Edits:
619
+ - no bias
620
+ - use custom FlaxBartAttention
621
+ """
622
+
623
+ config: DalleBartConfig
624
+ dtype: jnp.dtype = jnp.float32
625
+ add_norm: bool = False
626
+ use_scale: bool = True
627
+
628
+ @nn.compact
629
+ def __call__(
630
+ self,
631
+ hidden_states: jnp.ndarray,
632
+ attention_mask: jnp.ndarray,
633
+ output_attentions: bool = True,
634
+ deterministic: bool = True,
635
+ ) -> Tuple[jnp.ndarray]:
636
+ if self.config.use_scan:
637
+ hidden_states = hidden_states[0]
638
+
639
+ res_gain = (
640
+ deepnet_gain["encoder"]["alpha"](self.config)
641
+ if self.config.use_deepnet_scaling
642
+ else 1
643
+ )
644
+
645
+ embed_dim = self.config.d_model
646
+ residual = hidden_states
647
+ if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
648
+ hidden_states = norm(
649
+ self.config.ln_type,
650
+ dtype=self.dtype,
651
+ epsilon=1e-05,
652
+ use_scale=self.config.force_ln_scale,
653
+ )(hidden_states)
654
+ hidden_states, attn_weights = FlaxBartAttention(
655
+ config=self.config,
656
+ embed_dim=embed_dim,
657
+ num_heads=self.config.encoder_attention_heads,
658
+ dropout=self.config.attention_dropout,
659
+ bias=self.config.use_bias,
660
+ dtype=self.dtype,
661
+ is_encoder=True,
662
+ is_cross_attention=False,
663
+ q_length=self.config.max_text_length,
664
+ k_length=self.config.max_text_length,
665
+ )(hidden_states=hidden_states, attention_mask=attention_mask)
666
+
667
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
668
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
669
+ hidden_states
670
+ )
671
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
672
+ hidden_states, deterministic=deterministic
673
+ )
674
+ hidden_states = residual * res_gain + hidden_states
675
+ if self.config.ln_positions in ["postln"]:
676
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
677
+ hidden_states
678
+ )
679
+
680
+ residual = hidden_states
681
+ ff_block = (
682
+ GLU(
683
+ config=self.config,
684
+ ffn_dim=self.config.encoder_ffn_dim,
685
+ embed_dim=embed_dim,
686
+ dtype=self.dtype,
687
+ is_encoder=True,
688
+ )
689
+ if self.config.use_glu
690
+ else FFN(
691
+ config=self.config,
692
+ ffn_dim=self.config.encoder_ffn_dim,
693
+ embed_dim=embed_dim,
694
+ dtype=self.dtype,
695
+ is_encoder=True,
696
+ )
697
+ )
698
+ hidden_states = ff_block(hidden_states, deterministic=deterministic)
699
+ hidden_states = residual * res_gain + hidden_states
700
+ if self.add_norm:
701
+ use_scale = self.use_scale or self.config.force_ln_scale
702
+ hidden_states = norm(
703
+ self.config.ln_type,
704
+ dtype=self.dtype,
705
+ epsilon=1e-05,
706
+ use_scale=use_scale,
707
+ )(hidden_states)
708
+
709
+ outputs = (hidden_states,)
710
+
711
+ if output_attentions:
712
+ outputs += (attn_weights,)
713
+
714
+ if self.config.use_scan:
715
+ outputs = (outputs, None)
716
+
717
+ return outputs
718
+
719
+
720
+ class FlaxBartDecoderLayer(nn.Module):
721
+ """
722
+ Edits:
723
+ - no bias
724
+ - use custom FlaxBartAttention
725
+ """
726
+
727
+ config: DalleBartConfig
728
+ dtype: jnp.dtype = jnp.float32
729
+ add_norm: bool = False
730
+ use_scale: bool = True
731
+
732
+ @nn.compact
733
+ def __call__(
734
+ self,
735
+ hidden_states: jnp.ndarray,
736
+ attention_mask: jnp.ndarray,
737
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
738
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
739
+ init_cache: bool = False,
740
+ output_attentions: bool = True,
741
+ deterministic: bool = True,
742
+ ) -> Tuple[jnp.ndarray]:
743
+ if self.config.use_scan:
744
+ hidden_states = hidden_states[0]
745
+
746
+ res_gain = (
747
+ deepnet_gain["decoder"]["alpha"](self.config)
748
+ if self.config.use_deepnet_scaling
749
+ else 1
750
+ )
751
+
752
+ embed_dim = self.config.d_model
753
+ residual = hidden_states
754
+
755
+ # Self Attention
756
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
757
+ hidden_states = norm(
758
+ self.config.ln_type,
759
+ dtype=self.dtype,
760
+ epsilon=1e-05,
761
+ use_scale=self.config.force_ln_scale,
762
+ )(hidden_states)
763
+ hidden_states, attn_weights = FlaxBartAttention(
764
+ config=self.config,
765
+ embed_dim=embed_dim,
766
+ num_heads=self.config.decoder_attention_heads,
767
+ dropout=self.config.attention_dropout,
768
+ causal=True,
769
+ bias=self.config.use_bias,
770
+ dtype=self.dtype,
771
+ is_encoder=False,
772
+ is_cross_attention=False,
773
+ q_length=self.config.image_length,
774
+ k_length=self.config.image_length,
775
+ )(
776
+ hidden_states=hidden_states,
777
+ attention_mask=attention_mask,
778
+ init_cache=init_cache,
779
+ )
780
+
781
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
782
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
783
+ hidden_states
784
+ )
785
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
786
+ hidden_states, deterministic=deterministic
787
+ )
788
+ hidden_states = residual * res_gain + hidden_states
789
+ if self.config.ln_positions in ["postln"]:
790
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
791
+ hidden_states
792
+ )
793
+
794
+ # Cross Attention
795
+ cross_attn_weights = None
796
+ if encoder_hidden_states is not None:
797
+ residual = hidden_states
798
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
799
+ hidden_states = norm(
800
+ self.config.ln_type,
801
+ dtype=self.dtype,
802
+ epsilon=1e-05,
803
+ use_scale=self.config.force_ln_scale,
804
+ )(hidden_states)
805
+ hidden_states, cross_attn_weights = FlaxBartAttention(
806
+ config=self.config,
807
+ embed_dim=embed_dim,
808
+ num_heads=self.config.decoder_attention_heads,
809
+ dropout=self.config.attention_dropout,
810
+ bias=self.config.use_bias,
811
+ dtype=self.dtype,
812
+ is_encoder=False,
813
+ is_cross_attention=True,
814
+ q_length=self.config.image_length,
815
+ k_length=self.config.max_text_length,
816
+ )(
817
+ hidden_states=hidden_states,
818
+ key_value_states=encoder_hidden_states,
819
+ attention_mask=encoder_attention_mask,
820
+ )
821
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
822
+ hidden_states = norm(
823
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
824
+ )(hidden_states)
825
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
826
+ hidden_states, deterministic=deterministic
827
+ )
828
+ hidden_states = residual * res_gain + hidden_states
829
+ if self.config.ln_positions in ["postln"]:
830
+ hidden_states = norm(
831
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
832
+ )(hidden_states)
833
+
834
+ # Feed forward
835
+ residual = hidden_states
836
+ ff_block = (
837
+ GLU(
838
+ config=self.config,
839
+ ffn_dim=self.config.decoder_ffn_dim,
840
+ embed_dim=embed_dim,
841
+ dtype=self.dtype,
842
+ is_encoder=False,
843
+ )
844
+ if self.config.use_glu
845
+ else FFN(
846
+ config=self.config,
847
+ ffn_dim=self.config.decoder_ffn_dim,
848
+ embed_dim=embed_dim,
849
+ dtype=self.dtype,
850
+ is_encoder=False,
851
+ )
852
+ )
853
+ hidden_states = ff_block(hidden_states, deterministic=deterministic)
854
+ hidden_states = residual * res_gain + hidden_states
855
+ if self.add_norm:
856
+ use_scale = self.use_scale or self.config.force_ln_scale
857
+ hidden_states = norm(
858
+ self.config.ln_type,
859
+ dtype=self.dtype,
860
+ epsilon=1e-05,
861
+ use_scale=use_scale,
862
+ )(hidden_states)
863
+
864
+ outputs = (hidden_states,)
865
+
866
+ if output_attentions:
867
+ outputs += (attn_weights, cross_attn_weights)
868
+
869
+ if self.config.use_scan:
870
+ outputs = (outputs, None)
871
+
872
+ return outputs
873
+
874
+
875
+ class FlaxBartEncoderLayerCollection(nn.Module):
876
+ config: DalleBartConfig
877
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
878
+ """
879
+ Edits:
880
+ - use custom FlaxBartEncoderLayer
881
+ - allow Gradient Checkpointing (nn.remat)
882
+ """
883
+
884
+ @nn.compact
885
+ def __call__(
886
+ self,
887
+ hidden_states,
888
+ attention_mask,
889
+ deterministic: bool = True,
890
+ output_attentions: bool = False,
891
+ output_hidden_states: bool = False,
892
+ return_dict: bool = True,
893
+ ):
894
+ all_hidden_states = () if output_hidden_states else None
895
+ all_self_attns = () if output_attentions else None
896
+
897
+ n_layers = self.config.encoder_layers
898
+ layer = (
899
+ remat(
900
+ FlaxBartEncoderLayer,
901
+ static_argnums=(2, 3),
902
+ prevent_cse=not self.config.use_scan,
903
+ )
904
+ if self.config.gradient_checkpointing
905
+ else FlaxBartEncoderLayer
906
+ )
907
+
908
+ if self.config.use_scan:
909
+ # all blocks are the same so we use nn.scan
910
+ assert not output_attentions, "cannot scan with output_attentions"
911
+ assert not output_hidden_states, "cannot scan with output_hidden_states"
912
+ hidden_states = (hidden_states,)
913
+ # we use a scale on all norms (even last layer) to allow scanning
914
+ hidden_states, _ = nn.scan(
915
+ layer,
916
+ variable_axes={"params": 0, "cache": 0},
917
+ split_rngs={"params": True, "dropout": True},
918
+ in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
919
+ length=n_layers,
920
+ )(
921
+ self.config,
922
+ dtype=self.dtype,
923
+ add_norm=self.config.ln_positions == "postln",
924
+ name="FlaxBartEncoderLayers",
925
+ )(
926
+ hidden_states,
927
+ attention_mask,
928
+ output_attentions,
929
+ deterministic,
930
+ )
931
+ hidden_states = hidden_states[0]
932
+ else:
933
+ for i in range(n_layers):
934
+ if output_hidden_states:
935
+ all_hidden_states += (hidden_states,)
936
+ # final layernorm on the output of the last layer
937
+ # or every 6 layers for Swin v2
938
+ add_norm = self.config.ln_positions == "postln" or (
939
+ self.config.ln_positions == "swinv2"
940
+ and ((i + 1) % 6 == 0)
941
+ and (i != n_layers - 1)
942
+ )
943
+ # we don't need to scale the norm for the last layer
944
+ use_scale = i != n_layers - 1
945
+ layer_outputs = layer(
946
+ self.config,
947
+ dtype=self.dtype,
948
+ add_norm=add_norm,
949
+ use_scale=use_scale,
950
+ name=f"FlaxBartEncoderLayer_{i}",
951
+ )(
952
+ hidden_states,
953
+ attention_mask,
954
+ output_attentions,
955
+ deterministic,
956
+ )
957
+ hidden_states = layer_outputs[0]
958
+ if output_attentions:
959
+ all_self_attns += (layer_outputs[1],)
960
+
961
+ # add hidden states from the last layer
962
+ if output_hidden_states:
963
+ all_hidden_states += (hidden_states,)
964
+
965
+ outputs = [
966
+ hidden_states,
967
+ all_hidden_states,
968
+ all_self_attns,
969
+ ]
970
+
971
+ if not return_dict:
972
+ return tuple(v for v in outputs if v is not None)
973
+
974
+ return FlaxBaseModelOutput(
975
+ last_hidden_state=hidden_states,
976
+ hidden_states=all_hidden_states,
977
+ attentions=all_self_attns,
978
+ )
979
+
980
+
981
+ class FlaxBartDecoderLayerCollection(nn.Module):
982
+ config: DalleBartConfig
983
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
984
+ """
985
+ Edits:
986
+ - use custom FlaxBartDecoderLayer
987
+ - allow Gradient Checkpointing (nn.remat)
988
+ """
989
+
990
+ @nn.compact
991
+ def __call__(
992
+ self,
993
+ hidden_states,
994
+ attention_mask,
995
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
996
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
997
+ deterministic: bool = True,
998
+ init_cache: bool = False,
999
+ output_attentions: bool = False,
1000
+ output_hidden_states: bool = False,
1001
+ return_dict: bool = True,
1002
+ ):
1003
+ # decoder layers
1004
+ all_hidden_states = () if output_hidden_states else None
1005
+ all_self_attns = () if output_attentions else None
1006
+ all_cross_attentions = (
1007
+ () if (output_attentions and encoder_hidden_states is not None) else None
1008
+ )
1009
+
1010
+ n_layers = self.config.decoder_layers
1011
+ layer = (
1012
+ remat(
1013
+ FlaxBartDecoderLayer,
1014
+ static_argnums=(4, 5, 6),
1015
+ prevent_cse=not self.config.use_scan,
1016
+ )
1017
+ if self.config.gradient_checkpointing
1018
+ else FlaxBartDecoderLayer
1019
+ )
1020
+
1021
+ if self.config.use_scan:
1022
+ # all blocks are the same so we use nn.scan
1023
+ assert not output_attentions, "cannot scan with output_attentions"
1024
+ assert not output_hidden_states, "cannot scan with output_hidden_states"
1025
+ hidden_states = (hidden_states,)
1026
+ # we use a scale on all norms (even last layer) to allow scanning
1027
+ hidden_states, _ = nn.scan(
1028
+ layer,
1029
+ variable_axes={"params": 0, "cache": 0},
1030
+ split_rngs={"params": True, "dropout": True},
1031
+ in_axes=(
1032
+ nn.broadcast,
1033
+ nn.broadcast,
1034
+ nn.broadcast,
1035
+ nn.broadcast,
1036
+ nn.broadcast,
1037
+ nn.broadcast,
1038
+ ),
1039
+ length=n_layers,
1040
+ )(
1041
+ self.config,
1042
+ dtype=self.dtype,
1043
+ add_norm=self.config.ln_positions == "postln",
1044
+ name="FlaxBartDecoderLayers",
1045
+ )(
1046
+ hidden_states,
1047
+ attention_mask,
1048
+ encoder_hidden_states,
1049
+ encoder_attention_mask,
1050
+ init_cache,
1051
+ output_attentions,
1052
+ deterministic,
1053
+ )
1054
+ hidden_states = hidden_states[0]
1055
+
1056
+ else:
1057
+ for i in range(n_layers):
1058
+ if output_hidden_states:
1059
+ all_hidden_states += (hidden_states,)
1060
+ # final layernorm on the output of the last layer
1061
+ # or every 6 layers for Swin v2
1062
+ add_norm = self.config.ln_positions == "postln" or (
1063
+ self.config.ln_positions == "swinv2"
1064
+ and ((i + 1) % 6 == 0)
1065
+ and (i != n_layers - 1)
1066
+ )
1067
+ # we don't need to scale the norm for the last layer
1068
+ use_scale = i != n_layers - 1
1069
+ layer_outputs = layer(
1070
+ self.config,
1071
+ dtype=self.dtype,
1072
+ add_norm=add_norm,
1073
+ use_scale=use_scale,
1074
+ name=f"FlaxBartDecoderLayer_{i}",
1075
+ )(
1076
+ hidden_states,
1077
+ attention_mask,
1078
+ encoder_hidden_states,
1079
+ encoder_attention_mask,
1080
+ init_cache,
1081
+ output_attentions,
1082
+ deterministic,
1083
+ )
1084
+
1085
+ hidden_states = layer_outputs[0]
1086
+ if output_attentions:
1087
+ all_self_attns += (layer_outputs[1],)
1088
+
1089
+ if encoder_hidden_states is not None:
1090
+ all_cross_attentions += (layer_outputs[2],)
1091
+
1092
+ # add hidden states from the last decoder layer
1093
+ if output_hidden_states:
1094
+ all_hidden_states += (hidden_states,)
1095
+
1096
+ outputs = [
1097
+ hidden_states,
1098
+ all_hidden_states,
1099
+ all_self_attns,
1100
+ all_cross_attentions,
1101
+ ]
1102
+
1103
+ if not return_dict:
1104
+ return tuple(v for v in outputs if v is not None)
1105
+
1106
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
1107
+ last_hidden_state=hidden_states,
1108
+ hidden_states=all_hidden_states,
1109
+ attentions=all_self_attns,
1110
+ cross_attentions=all_cross_attentions,
1111
+ )
1112
+
1113
+
1114
+ class FlaxBartEncoder(nn.Module):
1115
+ config: DalleBartConfig
1116
+ embed_tokens: nn.Embed
1117
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1118
+ """
1119
+ Edits:
1120
+ - offset set to 0 (no padding token)
1121
+ - use max_text_length instead of max_position_embeddings
1122
+ - use custom FlaxBartEncoderLayerCollection
1123
+ - embed_tokens cannot be None (issue at compile time)
1124
+ """
1125
+
1126
+ def setup(self):
1127
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
1128
+
1129
+ embed_dim = self.config.d_model
1130
+ self.padding_idx = self.config.pad_token_id
1131
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
1132
+
1133
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
1134
+ # and adjust num_embeddings appropriately. Other models don't have this hack
1135
+ self.offset = 0
1136
+ if self.config.use_absolute_position_embeddings:
1137
+ self.embed_positions = nn.Embed(
1138
+ self.config.max_text_length + self.offset, # image length for BOS
1139
+ embed_dim,
1140
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1141
+ )
1142
+ self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
1143
+ self.layernorm_embedding = norm(
1144
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
1145
+ )
1146
+
1147
+ # postln is already applied in every layer
1148
+ if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
1149
+ self.final_ln = norm(
1150
+ self.config.ln_type,
1151
+ dtype=self.dtype,
1152
+ epsilon=1e-05,
1153
+ use_scale=self.config.force_ln_scale,
1154
+ )
1155
+ else:
1156
+ self.final_ln = None
1157
+
1158
+ def __call__(
1159
+ self,
1160
+ input_ids,
1161
+ attention_mask,
1162
+ position_ids,
1163
+ output_attentions: bool = False,
1164
+ output_hidden_states: bool = False,
1165
+ return_dict: bool = True,
1166
+ deterministic: bool = True,
1167
+ ):
1168
+ input_shape = input_ids.shape
1169
+ input_ids = input_ids.reshape(-1, input_shape[-1])
1170
+
1171
+ hidden_states = self.embed_tokens(input_ids) * self.embed_scale
1172
+
1173
+ if self.config.use_absolute_position_embeddings:
1174
+ embed_pos = self.embed_positions(position_ids + self.offset)
1175
+ hidden_states = hidden_states + embed_pos
1176
+
1177
+ hidden_states = self.layernorm_embedding(hidden_states)
1178
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
1179
+
1180
+ outputs = self.layers(
1181
+ hidden_states,
1182
+ attention_mask,
1183
+ deterministic=deterministic,
1184
+ output_attentions=output_attentions,
1185
+ output_hidden_states=output_hidden_states,
1186
+ return_dict=return_dict,
1187
+ )
1188
+
1189
+ if self.final_ln is None:
1190
+ final_output = outputs[0]
1191
+ else:
1192
+ final_output = self.final_ln(outputs[0])
1193
+
1194
+ if not return_dict:
1195
+ return (final_output,) + outputs[1:]
1196
+
1197
+ return FlaxBaseModelOutput(
1198
+ last_hidden_state=final_output,
1199
+ hidden_states=outputs.hidden_states,
1200
+ attentions=outputs.attentions,
1201
+ )
1202
+
1203
+
1204
+ class FlaxBartDecoder(nn.Module):
1205
+ config: DalleBartConfig
1206
+ embed_tokens: nn.Embed
1207
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1208
+ """
1209
+ Edits:
1210
+ - offset set to 0 (no padding token)
1211
+ - use image_length instead of max_position_embeddings
1212
+ - use custom FlaxBartDecoderLayerCollection
1213
+ - embed_tokens cannot be None (issue at compile time)
1214
+ """
1215
+
1216
+ def setup(self):
1217
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
1218
+
1219
+ embed_dim = self.config.d_model
1220
+ self.padding_idx = self.config.pad_token_id
1221
+ self.embed_scale = (
1222
+ math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
1223
+ )
1224
+
1225
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
1226
+ # and adjust num_embeddings appropriately. Other models don't have this hack
1227
+ self.offset = 0
1228
+ if self.config.use_absolute_position_embeddings:
1229
+ self.embed_positions = nn.Embed(
1230
+ self.config.image_length + self.offset, # image length for BOS
1231
+ embed_dim,
1232
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1233
+ )
1234
+
1235
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
1236
+ self.layernorm_embedding = norm(
1237
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
1238
+ )
1239
+
1240
+ # postln is already applied in every layer
1241
+ if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
1242
+ self.final_ln = norm(
1243
+ self.config.ln_type,
1244
+ dtype=self.dtype,
1245
+ epsilon=1e-05,
1246
+ use_scale=self.config.force_ln_scale,
1247
+ )
1248
+
1249
+ def __call__(
1250
+ self,
1251
+ input_ids,
1252
+ attention_mask,
1253
+ position_ids,
1254
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1255
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1256
+ init_cache: bool = False,
1257
+ output_attentions: bool = False,
1258
+ output_hidden_states: bool = False,
1259
+ return_dict: bool = True,
1260
+ deterministic: bool = True,
1261
+ ):
1262
+ input_shape = input_ids.shape
1263
+ input_ids = input_ids.reshape(-1, input_shape[-1])
1264
+
1265
+ hidden_states = self.embed_tokens(input_ids) * self.embed_scale
1266
+
1267
+ if self.config.use_absolute_position_embeddings:
1268
+ embed_pos = self.embed_positions(position_ids + self.offset)
1269
+ hidden_states = hidden_states + embed_pos
1270
+
1271
+ hidden_states = self.layernorm_embedding(hidden_states)
1272
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
1273
+
1274
+ outputs = self.layers(
1275
+ hidden_states,
1276
+ attention_mask,
1277
+ encoder_hidden_states,
1278
+ encoder_attention_mask,
1279
+ deterministic=deterministic,
1280
+ init_cache=init_cache,
1281
+ output_attentions=output_attentions,
1282
+ output_hidden_states=output_hidden_states,
1283
+ return_dict=return_dict,
1284
+ )
1285
+
1286
+ if self.final_ln is None:
1287
+ final_output = outputs[0]
1288
+ else:
1289
+ final_output = self.final_ln(outputs[0])
1290
+
1291
+ if not return_dict:
1292
+ return (final_output,) + outputs[1:]
1293
+
1294
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
1295
+ last_hidden_state=final_output,
1296
+ hidden_states=outputs.hidden_states,
1297
+ attentions=outputs.attentions,
1298
+ cross_attentions=outputs.cross_attentions,
1299
+ )
1300
+
1301
+
1302
+ class FlaxBartModule(FlaxBartModule):
1303
+ """
1304
+ Edits
1305
+ - use custom FlaxBartEncoder & FlaxBartDecoder
1306
+ - use separate embeddings for Encoder & Decoder
1307
+ """
1308
+
1309
+ def setup(self):
1310
+ encoder_embed_tokens = nn.Embed(
1311
+ self.config.encoder_vocab_size,
1312
+ self.config.d_model,
1313
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1314
+ )
1315
+ decoder_embed_tokens = nn.Embed(
1316
+ self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
1317
+ self.config.d_model,
1318
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1319
+ )
1320
+
1321
+ self.encoder = FlaxBartEncoder(
1322
+ self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens
1323
+ )
1324
+ self.decoder = FlaxBartDecoder(
1325
+ self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens
1326
+ )
1327
+
1328
+
1329
+ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
1330
+ """
1331
+ Edits:
1332
+ - no bias
1333
+ - lm_head set to image_vocab_size + 1 (for BOS)
1334
+ - uses custom FlaxBartModule
1335
+ """
1336
+
1337
+ def setup(self):
1338
+ self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
1339
+ self.lm_head = nn.Dense(
1340
+ self.config.image_vocab_size
1341
+ + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
1342
+ use_bias=False,
1343
+ dtype=self.dtype,
1344
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
1345
+ )
1346
+
1347
+ def __call__(
1348
+ self,
1349
+ input_ids,
1350
+ attention_mask,
1351
+ decoder_input_ids,
1352
+ decoder_attention_mask,
1353
+ position_ids,
1354
+ decoder_position_ids,
1355
+ output_attentions: bool = False,
1356
+ output_hidden_states: bool = False,
1357
+ return_dict: bool = True,
1358
+ deterministic: bool = True,
1359
+ ):
1360
+ outputs = self.model(
1361
+ input_ids=input_ids,
1362
+ attention_mask=attention_mask,
1363
+ decoder_input_ids=decoder_input_ids,
1364
+ decoder_attention_mask=decoder_attention_mask,
1365
+ position_ids=position_ids,
1366
+ decoder_position_ids=decoder_position_ids,
1367
+ output_attentions=output_attentions,
1368
+ output_hidden_states=output_hidden_states,
1369
+ return_dict=return_dict,
1370
+ deterministic=deterministic,
1371
+ )
1372
+
1373
+ hidden_states = outputs[0]
1374
+
1375
+ if self.config.tie_word_embeddings:
1376
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
1377
+ lm_logits = self.lm_head.apply(
1378
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
1379
+ )
1380
+ else:
1381
+ lm_logits = self.lm_head(hidden_states)
1382
+
1383
+ if not return_dict:
1384
+ output = (lm_logits,) + outputs[1:]
1385
+ return output
1386
+
1387
+ return FlaxSeq2SeqLMOutput(
1388
+ logits=lm_logits,
1389
+ decoder_hidden_states=outputs.decoder_hidden_states,
1390
+ decoder_attentions=outputs.decoder_attentions,
1391
+ cross_attentions=outputs.cross_attentions,
1392
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1393
+ encoder_hidden_states=outputs.encoder_hidden_states,
1394
+ encoder_attentions=outputs.encoder_attentions,
1395
+ )
1396
+
1397
+
1398
+ @flax.struct.dataclass
1399
+ class SampleState:
1400
+ cur_len: jnp.ndarray
1401
+ sequences: jnp.ndarray
1402
+ running_token: jnp.ndarray
1403
+ is_sent_finished: jnp.ndarray
1404
+ prng_key: jnp.ndarray
1405
+ model_kwargs: Dict[str, jnp.ndarray]
1406
+ model_kwargs_uncond: Dict[str, jnp.ndarray]
1407
+
1408
+
1409
+ @flax.struct.dataclass
1410
+ class FlaxSampleOutput(ModelOutput):
1411
+ """
1412
+ Flax Base class for outputs of decoder-only generation models using sampling.
1413
+
1414
+
1415
+ Args:
1416
+ sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
1417
+ The generated sequences.
1418
+ """
1419
+
1420
+ sequences: jnp.ndarray = None
1421
+
1422
+
1423
+ class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
1424
+ """
1425
+ Edits:
1426
+ - renamed from FlaxBartForConditionalGeneration
1427
+ - uses custom FlaxBartForConditionalGenerationModule
1428
+ - no bias in decode method
1429
+ - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
1430
+ related to position embedding during model.generate()
1431
+ - custom generate method to allow super conditions
1432
+ - num_params property
1433
+ - unscan function
1434
+ """
1435
+
1436
+ module_class = FlaxBartForConditionalGenerationModule
1437
+ config_class = DalleBartConfig
1438
+
1439
+ def num_params(self, params=None):
1440
+ if params is None:
1441
+ params = self.params
1442
+ num_params = jax.tree_util.tree_map(
1443
+ lambda param: param.size, flatten_dict(unfreeze(params))
1444
+ ).values()
1445
+ return sum(list(num_params))
1446
+
1447
+ def unscan(self, params):
1448
+ if self.config.use_scan:
1449
+ self.config.use_scan = False
1450
+ params = flatten_dict(params)
1451
+ scanned_keys = [k for k in params.keys() if "layers" in k]
1452
+ for k in scanned_keys:
1453
+ v = params[k]
1454
+ name_idx = k.index("layers") + 1
1455
+ for i in range(len(v)):
1456
+ new_k = (
1457
+ *k[:name_idx],
1458
+ f"{k[name_idx][:-1]}_{i}",
1459
+ *k[name_idx + 1 :],
1460
+ )
1461
+ params[new_k] = v[i]
1462
+ del params[k]
1463
+ params = unflatten_dict(params)
1464
+ return params
1465
+
1466
+ def decode(
1467
+ self,
1468
+ decoder_input_ids,
1469
+ encoder_outputs,
1470
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1471
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1472
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1473
+ past_key_values: dict = None,
1474
+ output_attentions: Optional[bool] = None,
1475
+ output_hidden_states: Optional[bool] = None,
1476
+ return_dict: Optional[bool] = None,
1477
+ train: bool = False,
1478
+ params: dict = None,
1479
+ dropout_rng: PRNGKey = None,
1480
+ ):
1481
+ output_attentions = (
1482
+ output_attentions
1483
+ if output_attentions is not None
1484
+ else self.config.output_attentions
1485
+ )
1486
+ output_hidden_states = (
1487
+ output_hidden_states
1488
+ if output_hidden_states is not None
1489
+ else self.config.output_hidden_states
1490
+ )
1491
+ return_dict = (
1492
+ return_dict if return_dict is not None else self.config.return_dict
1493
+ )
1494
+
1495
+ encoder_hidden_states = encoder_outputs[0]
1496
+ if encoder_attention_mask is None:
1497
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
1498
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
1499
+
1500
+ batch_size, sequence_length = decoder_input_ids.shape
1501
+ if decoder_attention_mask is None:
1502
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1503
+
1504
+ if decoder_position_ids is None:
1505
+ if past_key_values is not None:
1506
+ raise ValueError(
1507
+ "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
1508
+ )
1509
+
1510
+ decoder_position_ids = jnp.broadcast_to(
1511
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1512
+ )
1513
+
1514
+ # Handle any PRNG if needed
1515
+ rngs = {}
1516
+ if dropout_rng is not None:
1517
+ rngs["dropout"] = dropout_rng
1518
+
1519
+ inputs = {"params": params or self.params}
1520
+
1521
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
1522
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
1523
+ # it can be changed by FlaxBartAttention module
1524
+ if past_key_values:
1525
+ inputs["cache"] = past_key_values
1526
+ mutable = ["cache"]
1527
+ else:
1528
+ mutable = False
1529
+
1530
+ def _decoder_forward(
1531
+ module,
1532
+ decoder_input_ids,
1533
+ decoder_attention_mask,
1534
+ decoder_position_ids,
1535
+ **kwargs,
1536
+ ):
1537
+ decoder_module = module._get_decoder_module()
1538
+ outputs = decoder_module(
1539
+ decoder_input_ids,
1540
+ decoder_attention_mask,
1541
+ decoder_position_ids,
1542
+ **kwargs,
1543
+ )
1544
+ hidden_states = outputs[0]
1545
+
1546
+ if self.config.tie_word_embeddings:
1547
+ shared_embedding = module.model.variables["params"]["shared"][
1548
+ "embedding"
1549
+ ]
1550
+ lm_logits = module.lm_head.apply(
1551
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
1552
+ )
1553
+ else:
1554
+ lm_logits = module.lm_head(hidden_states)
1555
+
1556
+ return lm_logits, outputs
1557
+
1558
+ outputs = self.module.apply(
1559
+ inputs,
1560
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1561
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1562
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1563
+ encoder_hidden_states=encoder_hidden_states,
1564
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
1565
+ output_attentions=output_attentions,
1566
+ output_hidden_states=output_hidden_states,
1567
+ return_dict=return_dict,
1568
+ deterministic=not train,
1569
+ rngs=rngs,
1570
+ mutable=mutable,
1571
+ method=_decoder_forward,
1572
+ )
1573
+
1574
+ if past_key_values is None:
1575
+ lm_logits, decoder_outputs = outputs
1576
+ else:
1577
+ (lm_logits, decoder_outputs), past = outputs
1578
+
1579
+ if return_dict:
1580
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1581
+ logits=lm_logits,
1582
+ hidden_states=decoder_outputs.hidden_states,
1583
+ attentions=decoder_outputs.attentions,
1584
+ cross_attentions=decoder_outputs.cross_attentions,
1585
+ )
1586
+ else:
1587
+ outputs = (lm_logits,) + decoder_outputs[1:]
1588
+
1589
+ # add updated cache to model output
1590
+ if past_key_values is not None and return_dict:
1591
+ outputs["past_key_values"] = unfreeze(past["cache"])
1592
+ return outputs
1593
+ elif past_key_values is not None and not return_dict:
1594
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1595
+
1596
+ return outputs
1597
+
1598
+ def prepare_inputs_for_generation(
1599
+ self,
1600
+ decoder_input_ids,
1601
+ max_length,
1602
+ attention_mask: Optional[jnp.DeviceArray] = None,
1603
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
1604
+ encoder_outputs=None,
1605
+ **kwargs,
1606
+ ):
1607
+ # initializing the cache
1608
+ batch_size, seq_length = decoder_input_ids.shape
1609
+
1610
+ past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
1611
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1612
+ # But since the decoder uses a causal mask, those positions are masked anyways.
1613
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1614
+ extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
1615
+ if decoder_attention_mask is not None:
1616
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
1617
+ extended_attention_mask = lax.dynamic_update_slice(
1618
+ extended_attention_mask, decoder_attention_mask, (0, 0)
1619
+ )
1620
+ else:
1621
+ position_ids = jnp.broadcast_to(
1622
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
1623
+ )
1624
+
1625
+ return {
1626
+ "past_key_values": past_key_values,
1627
+ "encoder_outputs": encoder_outputs,
1628
+ "encoder_attention_mask": attention_mask,
1629
+ "decoder_attention_mask": extended_attention_mask,
1630
+ "decoder_position_ids": position_ids,
1631
+ }
1632
+
1633
+ def generate(
1634
+ self,
1635
+ input_ids: jnp.ndarray,
1636
+ attention_mask: Optional[jnp.ndarray] = None,
1637
+ max_length: Optional[int] = None,
1638
+ pad_token_id: Optional[int] = None,
1639
+ bos_token_id: Optional[int] = None,
1640
+ eos_token_id: Optional[int] = None,
1641
+ decoder_start_token_id: Optional[int] = None,
1642
+ do_sample: Optional[bool] = None,
1643
+ prng_key: Optional[jnp.ndarray] = None,
1644
+ top_k: Optional[int] = None,
1645
+ top_p: Optional[float] = None,
1646
+ temperature: Optional[float] = None,
1647
+ num_beams: Optional[int] = None,
1648
+ no_repeat_ngram_size: Optional[int] = None,
1649
+ min_length: Optional[int] = None,
1650
+ forced_bos_token_id: Optional[int] = None,
1651
+ forced_eos_token_id: Optional[int] = None,
1652
+ length_penalty: Optional[float] = None,
1653
+ early_stopping: Optional[bool] = None,
1654
+ trace: bool = True,
1655
+ params: Optional[Dict[str, jnp.ndarray]] = None,
1656
+ condition_scale: Optional[float] = 1.0,
1657
+ input_ids_uncond: Optional[jnp.ndarray] = None,
1658
+ attention_mask_uncond: Optional[jnp.ndarray] = None,
1659
+ **model_kwargs,
1660
+ ):
1661
+ """Edit: Allow super conditioning."""
1662
+
1663
+ # set init values
1664
+ max_length = max_length if max_length is not None else self.config.max_length
1665
+ bos_token_id = (
1666
+ bos_token_id if bos_token_id is not None else self.config.bos_token_id
1667
+ )
1668
+ pad_token_id = (
1669
+ pad_token_id if pad_token_id is not None else self.config.pad_token_id
1670
+ )
1671
+ eos_token_id = (
1672
+ eos_token_id if eos_token_id is not None else self.config.eos_token_id
1673
+ )
1674
+ decoder_start_token_id = (
1675
+ decoder_start_token_id
1676
+ if decoder_start_token_id
1677
+ else self.config.decoder_start_token_id
1678
+ )
1679
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
1680
+
1681
+ if decoder_start_token_id is None and self.config.is_encoder_decoder:
1682
+ raise ValueError(
1683
+ "`decoder_start_token_id` has to be defined for encoder-decoder generation."
1684
+ )
1685
+
1686
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
1687
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
1688
+
1689
+ if self.config.is_encoder_decoder:
1690
+ # add encoder_outputs to model_kwargs
1691
+ if model_kwargs.get("encoder_outputs") is None:
1692
+ model_kwargs_input = dict(model_kwargs)
1693
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
1694
+ input_ids,
1695
+ params,
1696
+ {"attention_mask": attention_mask, **model_kwargs_input},
1697
+ )
1698
+ if condition_scale != 1.0:
1699
+ assert (
1700
+ input_ids_uncond is not None
1701
+ ), "`input_ids_uncond` has to be defined for super conditioning."
1702
+ assert (
1703
+ do_sample is True
1704
+ ), "`do_sample` has to be True for super conditioning."
1705
+ assert (
1706
+ num_beams == 1
1707
+ ), "`num_beams` has to be 1 for super conditioning."
1708
+ model_kwargs_uncond = (
1709
+ self._prepare_encoder_decoder_kwargs_for_generation(
1710
+ input_ids_uncond,
1711
+ params,
1712
+ {
1713
+ "attention_mask": attention_mask_uncond,
1714
+ **model_kwargs_input,
1715
+ },
1716
+ )
1717
+ )
1718
+ else:
1719
+ model_kwargs_uncond = None
1720
+ # prepare decoder_input_ids for generation
1721
+ input_ids = (
1722
+ jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
1723
+ )
1724
+
1725
+ if not do_sample and num_beams == 1:
1726
+ logits_processor = self._get_logits_processor(
1727
+ no_repeat_ngram_size,
1728
+ min_length,
1729
+ max_length,
1730
+ eos_token_id,
1731
+ forced_bos_token_id,
1732
+ forced_eos_token_id,
1733
+ )
1734
+ return self._greedy_search(
1735
+ input_ids,
1736
+ max_length,
1737
+ pad_token_id,
1738
+ eos_token_id,
1739
+ logits_processor=logits_processor,
1740
+ trace=trace,
1741
+ params=params,
1742
+ model_kwargs=model_kwargs,
1743
+ )
1744
+ elif do_sample and num_beams == 1:
1745
+ logits_warper = self._get_logits_warper(
1746
+ top_k=top_k, top_p=top_p, temperature=temperature
1747
+ )
1748
+ logits_processor = self._get_logits_processor(
1749
+ no_repeat_ngram_size,
1750
+ min_length,
1751
+ max_length,
1752
+ eos_token_id,
1753
+ forced_bos_token_id,
1754
+ forced_eos_token_id,
1755
+ )
1756
+ return self._sample(
1757
+ input_ids,
1758
+ max_length,
1759
+ pad_token_id,
1760
+ eos_token_id,
1761
+ prng_key,
1762
+ logits_warper=logits_warper,
1763
+ logits_processor=logits_processor,
1764
+ trace=trace,
1765
+ params=params,
1766
+ model_kwargs=model_kwargs,
1767
+ condition_scale=condition_scale,
1768
+ model_kwargs_uncond=model_kwargs_uncond,
1769
+ )
1770
+ elif not do_sample and num_beams > 1:
1771
+ # broadcast input_ids & encoder_outputs
1772
+ input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
1773
+
1774
+ if "encoder_outputs" in model_kwargs:
1775
+ model_kwargs["encoder_outputs"][
1776
+ "last_hidden_state"
1777
+ ] = self._expand_to_num_beams(
1778
+ model_kwargs["encoder_outputs"]["last_hidden_state"],
1779
+ num_beams=num_beams,
1780
+ )
1781
+
1782
+ if "attention_mask" in model_kwargs:
1783
+ model_kwargs["attention_mask"] = self._expand_to_num_beams(
1784
+ model_kwargs["attention_mask"], num_beams=num_beams
1785
+ )
1786
+
1787
+ logits_processor = self._get_logits_processor(
1788
+ no_repeat_ngram_size,
1789
+ min_length,
1790
+ max_length,
1791
+ eos_token_id,
1792
+ forced_bos_token_id,
1793
+ forced_eos_token_id,
1794
+ )
1795
+
1796
+ return self._beam_search(
1797
+ input_ids,
1798
+ max_length,
1799
+ pad_token_id,
1800
+ eos_token_id,
1801
+ length_penalty=length_penalty,
1802
+ early_stopping=early_stopping,
1803
+ logits_processor=logits_processor,
1804
+ trace=trace,
1805
+ params=params,
1806
+ model_kwargs=model_kwargs,
1807
+ )
1808
+ else:
1809
+ raise NotImplementedError("`Beam sampling is currently not implemented.")
1810
+
1811
+ def _sample(
1812
+ self,
1813
+ input_ids: None,
1814
+ max_length: Optional[int] = None,
1815
+ pad_token_id: Optional[int] = None,
1816
+ eos_token_id: Optional[int] = None,
1817
+ prng_key: Optional[jnp.ndarray] = None,
1818
+ logits_processor=None,
1819
+ logits_warper=None,
1820
+ trace: bool = True,
1821
+ params: Optional[Dict[str, jnp.ndarray]] = None,
1822
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
1823
+ condition_scale: float = 1.0,
1824
+ model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
1825
+ ):
1826
+ # init values
1827
+ max_length = max_length if max_length is not None else self.config.max_length
1828
+ pad_token_id = (
1829
+ pad_token_id if pad_token_id is not None else self.config.pad_token_id
1830
+ )
1831
+ eos_token_id = (
1832
+ eos_token_id if eos_token_id is not None else self.config.eos_token_id
1833
+ )
1834
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
1835
+
1836
+ batch_size, cur_len = input_ids.shape
1837
+
1838
+ eos_token_id = jnp.array(eos_token_id)
1839
+ pad_token_id = jnp.array(pad_token_id)
1840
+ cur_len = jnp.array(cur_len)
1841
+
1842
+ # per batch-item holding current token in loop.
1843
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
1844
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
1845
+
1846
+ # per batch-item state bit indicating if sentence has finished.
1847
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
1848
+
1849
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1850
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1851
+ model = self.decode if self.config.is_encoder_decoder else self
1852
+
1853
+ # initialize model specific kwargs
1854
+ model_kwargs = self.prepare_inputs_for_generation(
1855
+ input_ids, max_length, **model_kwargs
1856
+ )
1857
+ if condition_scale != 1.0:
1858
+ model_kwargs_uncond = self.prepare_inputs_for_generation(
1859
+ input_ids, max_length, **model_kwargs_uncond
1860
+ )
1861
+
1862
+ # initialize state
1863
+ state = SampleState(
1864
+ cur_len=cur_len,
1865
+ sequences=sequences,
1866
+ running_token=input_ids,
1867
+ is_sent_finished=is_sent_finished,
1868
+ prng_key=prng_key,
1869
+ model_kwargs=model_kwargs,
1870
+ model_kwargs_uncond=model_kwargs_uncond,
1871
+ )
1872
+
1873
+ def sample_search_cond_fn(state):
1874
+ """state termination condition fn."""
1875
+ has_reached_max_length = state.cur_len == max_length
1876
+ all_sequence_finished = jnp.all(state.is_sent_finished)
1877
+ finish_generation = jnp.logical_or(
1878
+ has_reached_max_length, all_sequence_finished
1879
+ )
1880
+ return ~finish_generation
1881
+
1882
+ def sample_search_body_fn(state):
1883
+ """state update fn."""
1884
+ prng_key, prng_key_next = jax.random.split(state.prng_key)
1885
+ model_outputs = model(
1886
+ state.running_token, params=params, **state.model_kwargs
1887
+ )
1888
+
1889
+ logits = model_outputs.logits[:, -1]
1890
+
1891
+ # perform super conditioning
1892
+ # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
1893
+ if condition_scale != 1.0:
1894
+ model_outputs_uncond = model(
1895
+ state.running_token, params=params, **state.model_kwargs_uncond
1896
+ )
1897
+ logits_uncond = model_outputs_uncond.logits[:, -1]
1898
+ logits = logits_uncond + condition_scale * (logits - logits_uncond)
1899
+ else:
1900
+ model_outputs_uncond = None
1901
+
1902
+ # apply min_length, ...
1903
+ logits = logits_processor(state.sequences, logits, state.cur_len)
1904
+ # apply top_k, top_k, temperature
1905
+ logits = logits_warper(logits, logits, state.cur_len)
1906
+
1907
+ next_token = jax.random.categorical(prng_key, logits, axis=-1)
1908
+
1909
+ next_is_sent_finished = state.is_sent_finished | (
1910
+ next_token == eos_token_id
1911
+ )
1912
+ next_token = (
1913
+ next_token * ~next_is_sent_finished
1914
+ + pad_token_id * next_is_sent_finished
1915
+ )
1916
+ next_token = next_token[:, None]
1917
+
1918
+ next_sequences = lax.dynamic_update_slice(
1919
+ state.sequences, next_token, (0, state.cur_len)
1920
+ )
1921
+ next_model_kwargs = self.update_inputs_for_generation(
1922
+ model_outputs, state.model_kwargs
1923
+ )
1924
+ next_model_kwargs_uncond = (
1925
+ self.update_inputs_for_generation(
1926
+ model_outputs_uncond, state.model_kwargs_uncond
1927
+ )
1928
+ if condition_scale != 1.0
1929
+ else None
1930
+ )
1931
+
1932
+ return SampleState(
1933
+ cur_len=state.cur_len + 1,
1934
+ sequences=next_sequences,
1935
+ running_token=next_token,
1936
+ is_sent_finished=next_is_sent_finished,
1937
+ model_kwargs=next_model_kwargs,
1938
+ model_kwargs_uncond=next_model_kwargs_uncond,
1939
+ prng_key=prng_key_next,
1940
+ )
1941
+
1942
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
1943
+ if input_ids.shape[1] > 1:
1944
+ state = sample_search_body_fn(state)
1945
+
1946
+ if not trace:
1947
+ state = self._run_loop_in_debug(
1948
+ sample_search_cond_fn, sample_search_body_fn, state
1949
+ )
1950
+ else:
1951
+ state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
1952
+
1953
+ return FlaxSampleOutput(sequences=state.sequences)
src/dalle_mini/model/partitions.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from flax.core.frozen_dict import freeze
4
+ from flax.traverse_util import flatten_dict, unflatten_dict
5
+ from jax.experimental import PartitionSpec as P
6
+
7
+ # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
8
+ # Sentinels
9
+ _unmatched = object()
10
+
11
+ # For specifying empty leaf dict `{}`
12
+ empty_dict = object()
13
+
14
+
15
+ def _match(qs, ks):
16
+ """Return True if regexes in qs match any window of strings in tuple ks."""
17
+ # compile regexes and force complete match
18
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
19
+ for i in range(len(ks) - len(qs) + 1):
20
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
21
+ if matches and all(matches):
22
+ return True
23
+ return False
24
+
25
+
26
+ def _replacement_rules(rules):
27
+ def replace(key, val):
28
+ for rule, replacement in rules:
29
+ if _match(rule, key):
30
+ return replacement
31
+ return val
32
+
33
+ return replace
34
+
35
+
36
+ def _get_partition_rules():
37
+ return [
38
+ # embeddings
39
+ (("embed_positions", "embedding"), P("mp", None)),
40
+ (("embed_tokens", "embedding"), P("mp", None)),
41
+ (("rel_bias", "embedding"), P(None, "mp")),
42
+ # attention
43
+ (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
44
+ (("out_proj", "kernel"), P("mp", None)),
45
+ # FFN
46
+ (("Dense_0", "kernel"), P(None, "mp")),
47
+ (("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
48
+ (("GLU.*", "Dense_2", "kernel"), P("mp", None)),
49
+ (("FFN.*", "Dense_1", "kernel"), P("mp", None)),
50
+ # layer norms
51
+ (("(bias|scale)",), None),
52
+ (("lm_head", "kernel"), P(None, "mp")),
53
+ # head scale and tau
54
+ (("(head_scale|tau)",), None),
55
+ ]
56
+
57
+
58
+ def set_partitions(in_dict, use_scan):
59
+ rules = _get_partition_rules()
60
+ replace = _replacement_rules(rules)
61
+ initd = {k: _unmatched for k in flatten_dict(in_dict)}
62
+ result = {k: replace(k, v) for k, v in initd.items()}
63
+ for k, v in result.items():
64
+ if v == _unmatched:
65
+ print(f"Unmatched -> {k}")
66
+ l = list(result.keys())
67
+ if use_scan:
68
+ # add None dimension to layers
69
+ result = {
70
+ k: (P(*(None,) + v) if v is not None else None)
71
+ if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
72
+ else v
73
+ for k, v in result.items()
74
+ }
75
+ assert _unmatched not in result.values(), "Incomplete partition spec."
76
+ return freeze(unflatten_dict(result))
src/dalle_mini/model/processor.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart processor """
2
+
3
+ from typing import List
4
+
5
+ import jax.numpy as jnp
6
+
7
+ from .configuration import DalleBartConfig
8
+ from .text import TextNormalizer
9
+ from .tokenizer import DalleBartTokenizer
10
+ from .utils import PretrainedFromWandbMixin
11
+
12
+
13
+ class DalleBartProcessorBase:
14
+ def __init__(
15
+ self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int
16
+ ):
17
+ self.tokenizer = tokenizer
18
+ self.normalize_text = normalize_text
19
+ self.max_text_length = max_text_length
20
+ if normalize_text:
21
+ self.text_processor = TextNormalizer()
22
+ # create unconditional tokens
23
+ uncond = self.tokenizer(
24
+ "",
25
+ return_tensors="jax",
26
+ padding="max_length",
27
+ truncation=True,
28
+ max_length=self.max_text_length,
29
+ ).data
30
+ self.input_ids_uncond = uncond["input_ids"]
31
+ self.attention_mask_uncond = uncond["attention_mask"]
32
+
33
+ def __call__(self, text: List[str] = None):
34
+ # check that text is not a string
35
+ assert not isinstance(text, str), "text must be a list of strings"
36
+
37
+ if self.normalize_text:
38
+ text = [self.text_processor(t) for t in text]
39
+ res = self.tokenizer(
40
+ text,
41
+ return_tensors="jax",
42
+ padding="max_length",
43
+ truncation=True,
44
+ max_length=self.max_text_length,
45
+ ).data
46
+ # tokens used only with super conditioning
47
+ n = len(text)
48
+ res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
49
+ res["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, n, axis=0)
50
+ return res
51
+
52
+ @classmethod
53
+ def from_pretrained(cls, *args, **kwargs):
54
+ tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
55
+ config = DalleBartConfig.from_pretrained(*args, **kwargs)
56
+ return cls(tokenizer, config.normalize_text, config.max_text_length)
57
+
58
+
59
+ class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
60
+ pass
src/dalle_mini/model/text.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for processing text.
3
+ """
4
+
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
+ from pathlib import Path
10
+
11
+ import emoji
12
+ import ftfy
13
+ from huggingface_hub import hf_hub_download
14
+ from unidecode import unidecode
15
+
16
+ # based on wiki word occurrence
17
+ person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
18
+ temp_token = "xtokx" # avoid repeating chars
19
+
20
+
21
+ class HashtagProcessor:
22
+ # Adapted from wordninja library
23
+ # We use our wikipedia word count + a good heuristic to make it work
24
+ def __init__(self):
25
+ wiki_word_frequency = hf_hub_download(
26
+ "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
27
+ )
28
+ self._word_cost = (
29
+ l.split()[0]
30
+ for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
31
+ )
32
+ self._word_cost = {
33
+ str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
34
+ }
35
+ self._max_word = max(len(x) for x in self._word_cost.keys())
36
+ self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")
37
+
38
+ def __call__(self, s):
39
+ """Uses dynamic programming to infer the location of spaces in a string without spaces."""
40
+ l = [self._split(x) for x in self._SPLIT_RE.split(s)]
41
+ return " ".join([item for sublist in l for item in sublist])
42
+
43
+ def _split(self, s):
44
+ # Find the best match for the i first characters, assuming cost has
45
+ # been built for the i-1 first characters.
46
+ # Returns a pair (match_cost, match_length).
47
+ def best_match(i):
48
+ candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
49
+ return min(
50
+ (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
51
+ for k, c in candidates
52
+ )
53
+
54
+ # Build the cost array
55
+ cost = [0]
56
+ for i in range(1, len(s) + 1):
57
+ c, k = best_match(i)
58
+ cost.append(c)
59
+
60
+ # Backtrack to recover the minimal-cost string.
61
+ out = []
62
+ i = len(s)
63
+ while i > 0:
64
+ c, k = best_match(i)
65
+ assert c == cost[i]
66
+ newToken = True
67
+ if not s[i - k : i] == "'": # ignore a lone apostrophe
68
+ if len(out) > 0:
69
+ # re-attach split 's and split digits
70
+ if out[-1] == "'s" or (
71
+ s[i - 1].isdigit() and out[-1][0].isdigit()
72
+ ): # digit followed by digit
73
+ out[-1] = (
74
+ s[i - k : i] + out[-1]
75
+ ) # combine current token with previous token
76
+ newToken = False
77
+
78
+ if newToken:
79
+ out.append(s[i - k : i])
80
+
81
+ i -= k
82
+
83
+ return reversed(out)
84
+
85
+
86
+ def replace_person_token(t):
87
+ "Used for CC12M"
88
+ t = re.sub("<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
89
+ while "<person>" in t:
90
+ t = t.replace(
91
+ "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
92
+ )
93
+ return t
94
+
95
+
96
+ def fix_html(t):
97
+ # from OpenAI CLIP
98
+ return html.unescape(html.unescape(t))
99
+
100
+
101
+ def replace_punctuation_with_commas(t):
102
+ return re.sub("[()[\].,|:;?!=+~\-\/{}]", ",", t)
103
+
104
+
105
+ def simplify_quotes(t):
106
+ return re.sub("""['"`]""", ' " ', t)
107
+
108
+
109
+ def merge_quotes(t):
110
+ return re.sub('(\s*"+\s*)+', ' " ', t)
111
+
112
+
113
+ def remove_comma_numbers(t):
114
+ def _f(t):
115
+ return re.sub("(\d),(\d{3})", r"\1\2", t)
116
+
117
+ return _f(_f(t))
118
+
119
+
120
+ def pre_process_dot_numbers(t):
121
+ return re.sub("(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
122
+
123
+
124
+ def post_process_dot_numbers(t):
125
+ return re.sub(f"{temp_token}dot{temp_token}", ".", t)
126
+
127
+
128
+ def pre_process_quotes(t):
129
+ # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
130
+ return re.sub(
131
+ r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", rf"{temp_token}quote{temp_token}", t
132
+ )
133
+
134
+
135
+ def post_process_quotes(t):
136
+ return re.sub(f"{temp_token}quote{temp_token}", "'", t)
137
+
138
+
139
+ def pre_process_dates(t):
140
+ return re.sub("(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
141
+
142
+
143
+ def post_process_dates(t):
144
+ return re.sub(f"{temp_token}slash{temp_token}", "/", t)
145
+
146
+
147
+ def merge_commas(t):
148
+ return re.sub("(\s*,+\s*)+", ", ", t)
149
+
150
+
151
+ def add_space_after_commas(t):
152
+ return re.sub(",", ", ", t)
153
+
154
+
155
+ def handle_special_chars(t):
156
+ "Handle special characters"
157
+ # replace "-" with a space when between words without space
158
+ t = re.sub("(\w)-(\w)", r"\1 \2", t)
159
+ # always add space around some characters
160
+ return re.sub("([%&\/$*])", r" \1 ", t)
161
+
162
+
163
+ def expand_hashtags(t, hashtag_processor):
164
+ "Remove # and try to split words"
165
+ return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
166
+
167
+
168
+ _re_ignore_chars = r"[_#\\]"
169
+
170
+
171
+ def ignore_chars(t):
172
+ "Ignore useless characters"
173
+ return re.sub(_re_ignore_chars, " ", t)
174
+
175
+
176
+ def remove_extra_spaces(t):
177
+ "Remove extra spaces (including \t and \n)"
178
+ return re.sub("\s+", " ", t)
179
+
180
+
181
+ def remove_repeating_chars(t):
182
+ "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
183
+ return re.sub(r"(\D)(\1{3,})", r"\1", t)
184
+
185
+
186
+ def remove_urls(t):
187
+ return re.sub(r"http\S+", "", t)
188
+
189
+
190
+ def remove_html_tags(t):
191
+ return re.sub("<[^<]+?>", " ", t)
192
+
193
+
194
+ def remove_first_last_commas(t):
195
+ t = t.strip()
196
+ t = t[:-1] if t and t[-1] == "," else t
197
+ t = t[1:] if t and t[0] == "," else t
198
+ return t.strip()
199
+
200
+
201
+ def remove_wiki_ref(t):
202
+ t = re.sub(r"\A\s*\[\d+\]", "", t)
203
+ return re.sub(r"\[\d+\]\s*\Z", "", t)
204
+
205
+
206
+ class TextNormalizer:
207
+ "Normalize text"
208
+
209
+ def __init__(self):
210
+ self._hashtag_processor = HashtagProcessor()
211
+
212
+ def __call__(self, t):
213
+ # fix some characters
214
+ t = ftfy.fix_text(t)
215
+ # fix html
216
+ t = fix_html(t)
217
+ # decode emojis (would be removed by unidecode)
218
+ t = emoji.demojize(t)
219
+ # decode and simplify text: see unidecode library
220
+ t = unidecode(t)
221
+ # lower case
222
+ t = t.lower()
223
+ # replace <PERSON> (for CC12M)
224
+ t = replace_person_token(t)
225
+ # remove wiki reference (for WIT)
226
+ t = remove_wiki_ref(t)
227
+ # remove html tags
228
+ t = remove_html_tags(t)
229
+ # remove urls
230
+ t = remove_urls(t)
231
+ # remove commas in numbers
232
+ t = remove_comma_numbers(t)
233
+ # handle dots in numbers and quotes - Part 1
234
+ t = pre_process_dot_numbers(t)
235
+ t = pre_process_quotes(t)
236
+ t = pre_process_dates(t)
237
+ # handle special characters
238
+ t = handle_special_chars(t)
239
+ # handle hashtags
240
+ t = expand_hashtags(t, self._hashtag_processor)
241
+ # ignore useless characters
242
+ t = ignore_chars(t)
243
+ # simplify quotes
244
+ t = simplify_quotes(t)
245
+ # all punctuation becomes commas
246
+ t = replace_punctuation_with_commas(t)
247
+ # handle dots in numbers and quotes - Part 2
248
+ t = post_process_dot_numbers(t)
249
+ t = post_process_quotes(t)
250
+ t = post_process_dates(t)
251
+ # handle repeating characters
252
+ t = remove_repeating_chars(t)
253
+ # merge quotes
254
+ t = merge_quotes(t)
255
+ # merge commas
256
+ t = merge_commas(t)
257
+ # remove multiple spaces
258
+ t = remove_extra_spaces(t)
259
+ # remove first and last comma
260
+ t = remove_first_last_commas(t)
261
+ # always start with a space
262
+ return f" {t}"
src/dalle_mini/model/tokenizer.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart tokenizer """
2
+ from transformers import BartTokenizerFast
3
+
4
+ from .utils import PretrainedFromWandbMixin
5
+
6
+
7
+ class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):
8
+ pass
src/dalle_mini/model/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import wandb
6
+
7
+
8
+ class PretrainedFromWandbMixin:
9
+ @classmethod
10
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
11
+ """
12
+ Initializes from a wandb artifact or delegates loading to the superclass.
13
+ """
14
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
15
+ if ":" in pretrained_model_name_or_path and not os.path.isdir(
16
+ pretrained_model_name_or_path
17
+ ):
18
+ # wandb artifact
19
+ if wandb.run is not None:
20
+ artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
21
+ else:
22
+ artifact = wandb.Api().artifact(pretrained_model_name_or_path)
23
+ pretrained_model_name_or_path = artifact.download(tmp_dir)
24
+
25
+ return super(PretrainedFromWandbMixin, cls).from_pretrained(
26
+ pretrained_model_name_or_path, *model_args, **kwargs
27
+ )
tools/dataset/encode_dataset.ipynb ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Pre-encoding a dataset for DALLE·mini"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "ba7b31e6",
14
+ "metadata": {},
15
+ "source": [
16
+ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
+ "\n",
18
+ "Adapt it to your own dataset and image encoder.\n",
19
+ "\n",
20
+ "At the end you should have a dataset of pairs:\n",
21
+ "* a caption defined as a string\n",
22
+ "* an encoded image defined as a list of int."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "3b59489e",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "from tqdm.notebook import tqdm\n",
33
+ "\n",
34
+ "import torchvision.transforms as T\n",
35
+ "\n",
36
+ "import webdataset as wds\n",
37
+ "\n",
38
+ "import jax\n",
39
+ "import braceexpand\n",
40
+ "from pathlib import Path"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "c7c4c1e6",
46
+ "metadata": {},
47
+ "source": [
48
+ "## Configuration Parameters"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "1265dbfe",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
59
+ "encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
60
+ "\n",
61
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
62
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
63
+ " \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
64
+ ")\n",
65
+ "\n",
66
+ "# good defaults for a TPU v3-8\n",
67
+ "batch_size = 128 # Per device\n",
68
+ "num_workers = 8 # For parallel processing\n",
69
+ "total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
70
+ "save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "['XXX/shard-0000.tar',\n",
83
+ " 'XXX/shard-0001.tar',\n",
84
+ " 'XXX/shard-0002.tar',\n",
85
+ " 'XXX/shard-0003.tar',\n",
86
+ " 'XXX/shard-0004.tar',\n",
87
+ " 'XXX/shard-0005.tar',\n",
88
+ " 'XXX/shard-0006.tar',\n",
89
+ " 'XXX/shard-0007.tar',\n",
90
+ " 'XXX/shard-0008.tar']"
91
+ ]
92
+ },
93
+ "execution_count": 5,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "shards = list(\n",
100
+ " braceexpand.braceexpand(shards)\n",
101
+ ") # better display for tqdm with known length"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "id": "75dba8e2",
107
+ "metadata": {},
108
+ "source": [
109
+ "## Load data"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "a1e8fb95",
115
+ "metadata": {},
116
+ "source": [
117
+ "We load data using `webdataset`."
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "id": "9ef5de9e",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "ds = (\n",
128
+ " wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
129
+ " .decode(\"rgb\", handler=wds.warn_and_continue)\n",
130
+ " .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
131
+ " .batched(total_bs) # load in batch per worker (faster)\n",
132
+ ")"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "id": "90981824",
138
+ "metadata": {},
139
+ "source": [
140
+ "Note:\n",
141
+ "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
142
+ "* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
143
+ "* you can also filter out some items using `select`."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "129c377d",
149
+ "metadata": {},
150
+ "source": [
151
+ "We can now inspect our data."
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "8cac98cb",
158
+ "metadata": {
159
+ "scrolled": true
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "%%time\n",
164
+ "images, captions = next(iter(ds))"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "cd268fbf",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "images.shape"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "5acfc4d8",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "captions[:10]"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "id": "c24693c0",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "T.ToPILImage()(images[0].permute(2, 0, 1))"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "3059ffb1",
200
+ "metadata": {},
201
+ "source": [
202
+ "Finally we create our dataloader."
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "c227c551",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "dl = (\n",
213
+ " wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
214
+ ") # avoid partial batch at the end of each worker"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "id": "a354472b",
220
+ "metadata": {},
221
+ "source": [
222
+ "## Image encoder\n",
223
+ "\n",
224
+ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "47a8b818",
231
+ "metadata": {
232
+ "scrolled": true
233
+ },
234
+ "outputs": [],
235
+ "source": [
236
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
237
+ "from flax.jax_utils import replicate\n",
238
+ "\n",
239
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
240
+ "vqgan_params = replicate(vqgan.params)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "62ad01c3",
246
+ "metadata": {},
247
+ "source": [
248
+ "## Encoding"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "id": "20357f74",
254
+ "metadata": {},
255
+ "source": [
256
+ "Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "322a4619",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "from flax.training.common_utils import shard\n",
267
+ "from functools import partial\n",
268
+ "\n",
269
+ "\n",
270
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
271
+ "def p_encode(batch, params):\n",
272
+ " # Not sure if we should `replicate` params, does not seem to have any effect\n",
273
+ " _, indices = vqgan.encode(batch, params=params)\n",
274
+ " return indices"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "ff6c10d4",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "import pandas as pd\n",
285
+ "\n",
286
+ "\n",
287
+ "def encode_dataset(dataloader, output_dir, save_frequency):\n",
288
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
289
+ " all_captions = []\n",
290
+ " all_encoding = []\n",
291
+ " n_file = 1\n",
292
+ " for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
293
+ " images = images.numpy()\n",
294
+ " n = len(images) // 8 * 8\n",
295
+ " if n != len(images):\n",
296
+ " # get the max number of images we can (multiple of 8)\n",
297
+ " print(f\"Different sizes {n} vs {len(images)}\")\n",
298
+ " images = images[:n]\n",
299
+ " captions = captions[:n]\n",
300
+ " if not len(captions):\n",
301
+ " print(f\"No images/captions in batch...\")\n",
302
+ " continue\n",
303
+ " images = shard(images)\n",
304
+ " encoded = p_encode(images, vqgan_params)\n",
305
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
306
+ " all_captions.extend(captions)\n",
307
+ " all_encoding.extend(encoded.tolist())\n",
308
+ "\n",
309
+ " # save files\n",
310
+ " if (idx + 1) % save_frequency == 0:\n",
311
+ " print(f\"Saving file {n_file}\")\n",
312
+ " batch_df = pd.DataFrame.from_dict(\n",
313
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
314
+ " )\n",
315
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
316
+ " all_captions = []\n",
317
+ " all_encoding = []\n",
318
+ " n_file += 1\n",
319
+ "\n",
320
+ " if len(all_captions):\n",
321
+ " print(f\"Saving final file {n_file}\")\n",
322
+ " batch_df = pd.DataFrame.from_dict(\n",
323
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
324
+ " )\n",
325
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "7704863d",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "markdown",
340
+ "id": "8953dd84",
341
+ "metadata": {},
342
+ "source": [
343
+ "----"
344
+ ]
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "interpreter": {
349
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
350
+ },
351
+ "kernelspec": {
352
+ "display_name": "Python 3 (ipykernel)",
353
+ "language": "python",
354
+ "name": "python3"
355
+ },
356
+ "language_info": {
357
+ "codemirror_mode": {
358
+ "name": "ipython",
359
+ "version": 3
360
+ },
361
+ "file_extension": ".py",
362
+ "mimetype": "text/x-python",
363
+ "name": "python",
364
+ "nbconvert_exporter": "python",
365
+ "pygments_lexer": "ipython3",
366
+ "version": "3.9.7"
367
+ }
368
+ },
369
+ "nbformat": 4,
370
+ "nbformat_minor": 5
371
+ }
tools/inference/inference_pipeline.ipynb ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "view-in-github",
7
+ "colab_type": "text"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "118UKH5bWCGa"
17
+ },
18
+ "source": [
19
+ "# DALL·E mini - Inference pipeline\n",
20
+ "\n",
21
+ "*Generate images from a text prompt*\n",
22
+ "\n",
23
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
+ "\n",
25
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
+ "\n",
27
+ "Just want to play? Use directly [the app](https://www.craiyon.com/).\n",
28
+ "\n",
29
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {
35
+ "id": "dS8LbaonYm3a"
36
+ },
37
+ "source": [
38
+ "## 🛠️ Installation and set-up"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "uzjAM2GBYpZX"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "# Required only for colab environments + GPU\n",
50
+ "!pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
51
+ "\n",
52
+ "# Install required libraries\n",
53
+ "!pip install -q dalle-mini\n",
54
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {
60
+ "id": "ozHzTkyv8cqU"
61
+ },
62
+ "source": [
63
+ "We load required models:\n",
64
+ "* DALL·E mini for text to encoded images\n",
65
+ "* VQGAN for decoding images\n",
66
+ "* CLIP for scoring predictions"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {
73
+ "id": "K6CxW2o42f-w"
74
+ },
75
+ "outputs": [],
76
+ "source": [
77
+ "# Model references\n",
78
+ "\n",
79
+ "# dalle-mega\n",
80
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
81
+ "DALLE_COMMIT_ID = None\n",
82
+ "\n",
83
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
84
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
85
+ "\n",
86
+ "# VQGAN model\n",
87
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
88
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {
95
+ "id": "Yv-aR3t4Oe5v"
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "import jax\n",
100
+ "import jax.numpy as jnp\n",
101
+ "\n",
102
+ "# check how many devices are available\n",
103
+ "jax.local_device_count()"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {
110
+ "id": "92zYmvsQ38vL"
111
+ },
112
+ "outputs": [],
113
+ "source": [
114
+ "# Load models & tokenizer\n",
115
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
116
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
117
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
118
+ "\n",
119
+ "# Load dalle-mini\n",
120
+ "model, params = DalleBart.from_pretrained(\n",
121
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
122
+ ")\n",
123
+ "\n",
124
+ "# Load VQGAN\n",
125
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
126
+ " VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
127
+ ")"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "metadata": {
133
+ "id": "o_vH2X1tDtzA"
134
+ },
135
+ "source": [
136
+ "Model parameters are replicated on each device for faster inference."
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "metadata": {
143
+ "id": "wtvLoM48EeVw"
144
+ },
145
+ "outputs": [],
146
+ "source": [
147
+ "from flax.jax_utils import replicate\n",
148
+ "\n",
149
+ "params = replicate(params)\n",
150
+ "vqgan_params = replicate(vqgan_params)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "metadata": {
156
+ "id": "0A9AHQIgZ_qw"
157
+ },
158
+ "source": [
159
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "id": "sOtoOmYsSYPz"
167
+ },
168
+ "outputs": [],
169
+ "source": [
170
+ "from functools import partial\n",
171
+ "\n",
172
+ "\n",
173
+ "# model inference\n",
174
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
175
+ "def p_generate(\n",
176
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
177
+ "):\n",
178
+ " return model.generate(\n",
179
+ " **tokenized_prompt,\n",
180
+ " prng_key=key,\n",
181
+ " params=params,\n",
182
+ " top_k=top_k,\n",
183
+ " top_p=top_p,\n",
184
+ " temperature=temperature,\n",
185
+ " condition_scale=condition_scale,\n",
186
+ " )\n",
187
+ "\n",
188
+ "\n",
189
+ "# decode image\n",
190
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
191
+ "def p_decode(indices, params):\n",
192
+ " return vqgan.decode_code(indices, params=params)"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {
198
+ "id": "HmVN6IBwapBA"
199
+ },
200
+ "source": [
201
+ "Keys are passed to the model on each device to generate unique inference per device."
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {
208
+ "id": "4CTXmlUkThhX"
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "import random\n",
213
+ "\n",
214
+ "# create a random key\n",
215
+ "seed = random.randint(0, 2**32 - 1)\n",
216
+ "key = jax.random.PRNGKey(seed)"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "metadata": {
222
+ "id": "BrnVyCo81pij"
223
+ },
224
+ "source": [
225
+ "## 🖍 Text Prompt"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "metadata": {
231
+ "id": "rsmj0Aj5OQox"
232
+ },
233
+ "source": [
234
+ "Our model requires processing prompts."
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "metadata": {
241
+ "id": "YjjhUychOVxm"
242
+ },
243
+ "outputs": [],
244
+ "source": [
245
+ "from dalle_mini import DalleBartProcessor\n",
246
+ "\n",
247
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {
253
+ "id": "BQ7fymSPyvF_"
254
+ },
255
+ "source": [
256
+ "Let's define some text prompts."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "metadata": {
263
+ "id": "x_0vI9ge1oKr"
264
+ },
265
+ "outputs": [],
266
+ "source": [
267
+ "prompts = [\n",
268
+ " \"sunset over a lake in the mountains\",\n",
269
+ " \"the Eiffel tower landing on the moon\",\n",
270
+ "]"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "markdown",
275
+ "metadata": {
276
+ "id": "XlZUG3SCLnGE"
277
+ },
278
+ "source": [
279
+ "Note: we could use the same prompt multiple times for faster inference."
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "metadata": {
286
+ "id": "VKjEZGjtO49k"
287
+ },
288
+ "outputs": [],
289
+ "source": [
290
+ "tokenized_prompts = processor(prompts)"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "markdown",
295
+ "metadata": {
296
+ "id": "-CEJBnuJOe5z"
297
+ },
298
+ "source": [
299
+ "Finally we replicate the prompts onto each device."
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "metadata": {
306
+ "id": "lQePgju5Oe5z"
307
+ },
308
+ "outputs": [],
309
+ "source": [
310
+ "tokenized_prompt = replicate(tokenized_prompts)"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "markdown",
315
+ "metadata": {
316
+ "id": "phQ9bhjRkgAZ"
317
+ },
318
+ "source": [
319
+ "## 🎨 Generate images\n",
320
+ "\n",
321
+ "We generate images using dalle-mini model and decode them with the VQGAN."
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": null,
327
+ "metadata": {
328
+ "id": "d0wVkXpKqnHA"
329
+ },
330
+ "outputs": [],
331
+ "source": [
332
+ "# number of predictions per prompt\n",
333
+ "n_predictions = 8\n",
334
+ "\n",
335
+ "# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
336
+ "gen_top_k = None\n",
337
+ "gen_top_p = None\n",
338
+ "temperature = None\n",
339
+ "cond_scale = 10.0"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "metadata": {
346
+ "id": "SDjEx9JxR3v8"
347
+ },
348
+ "outputs": [],
349
+ "source": [
350
+ "from flax.training.common_utils import shard_prng_key\n",
351
+ "import numpy as np\n",
352
+ "from PIL import Image\n",
353
+ "from tqdm.notebook import trange\n",
354
+ "\n",
355
+ "print(f\"Prompts: {prompts}\\n\")\n",
356
+ "# generate images\n",
357
+ "images = []\n",
358
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
359
+ " # get a new key\n",
360
+ " key, subkey = jax.random.split(key)\n",
361
+ " # generate images\n",
362
+ " encoded_images = p_generate(\n",
363
+ " tokenized_prompt,\n",
364
+ " shard_prng_key(subkey),\n",
365
+ " params,\n",
366
+ " gen_top_k,\n",
367
+ " gen_top_p,\n",
368
+ " temperature,\n",
369
+ " cond_scale,\n",
370
+ " )\n",
371
+ " # remove BOS\n",
372
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
373
+ " # decode images\n",
374
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
375
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
376
+ " for decoded_img in decoded_images:\n",
377
+ " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
378
+ " images.append(img)\n",
379
+ " display(img)\n",
380
+ " print()"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "markdown",
385
+ "metadata": {
386
+ "id": "tw02wG9zGmyB"
387
+ },
388
+ "source": [
389
+ "## 🏅 Optional: Rank images by CLIP score\n",
390
+ "\n",
391
+ "We can rank images according to CLIP.\n",
392
+ "\n",
393
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": null,
399
+ "metadata": {
400
+ "id": "RGjlIW_f6GA0"
401
+ },
402
+ "outputs": [],
403
+ "source": [
404
+ "# CLIP model\n",
405
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
406
+ "CLIP_COMMIT_ID = None\n",
407
+ "\n",
408
+ "# Load CLIP\n",
409
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
410
+ " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
411
+ ")\n",
412
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
413
+ "clip_params = replicate(clip_params)\n",
414
+ "\n",
415
+ "\n",
416
+ "# score images\n",
417
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
418
+ "def p_clip(inputs, params):\n",
419
+ " logits = clip(params=params, **inputs).logits_per_image\n",
420
+ " return logits"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "metadata": {
427
+ "id": "FoLXpjCmGpju"
428
+ },
429
+ "outputs": [],
430
+ "source": [
431
+ "from flax.training.common_utils import shard\n",
432
+ "\n",
433
+ "# get clip scores\n",
434
+ "clip_inputs = clip_processor(\n",
435
+ " text=prompts * jax.device_count(),\n",
436
+ " images=images,\n",
437
+ " return_tensors=\"np\",\n",
438
+ " padding=\"max_length\",\n",
439
+ " max_length=77,\n",
440
+ " truncation=True,\n",
441
+ ").data\n",
442
+ "logits = p_clip(shard(clip_inputs), clip_params)\n",
443
+ "\n",
444
+ "# organize scores per prompt\n",
445
+ "p = len(prompts)\n",
446
+ "logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "markdown",
451
+ "metadata": {
452
+ "id": "4AAWRm70LgED"
453
+ },
454
+ "source": [
455
+ "Let's now display images ranked by CLIP score."
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": null,
461
+ "metadata": {
462
+ "id": "zsgxxubLLkIu"
463
+ },
464
+ "outputs": [],
465
+ "source": [
466
+ "for i, prompt in enumerate(prompts):\n",
467
+ " print(f\"Prompt: {prompt}\\n\")\n",
468
+ " for idx in logits[i].argsort()[::-1]:\n",
469
+ " display(images[idx * p + i])\n",
470
+ " print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
471
+ " print()"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "metadata": {
477
+ "id": "oZT9i3jCjir0"
478
+ },
479
+ "source": [
480
+ "## 🪄 Optional: Save your Generated Images as W&B Tables\n",
481
+ "\n",
482
+ "W&B Tables is an interactive 2D grid with support to rich media logging. Use this to save the generated images on W&B dashboard and share with the world."
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "code",
487
+ "execution_count": null,
488
+ "metadata": {
489
+ "id": "-pSiv6Vwjkn0"
490
+ },
491
+ "outputs": [],
492
+ "source": [
493
+ "import wandb\n",
494
+ "\n",
495
+ "# Initialize a W&B run.\n",
496
+ "project = \"dalle-mini-tables-colab\"\n",
497
+ "run = wandb.init(project=project)\n",
498
+ "\n",
499
+ "# Initialize an empty W&B Tables.\n",
500
+ "columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
501
+ "gen_table = wandb.Table(columns=columns)\n",
502
+ "\n",
503
+ "# Add data to the table.\n",
504
+ "for i, prompt in enumerate(prompts):\n",
505
+ " # If CLIP scores exist, sort the Images\n",
506
+ " if logits is not None:\n",
507
+ " idxs = logits[i].argsort()[::-1]\n",
508
+ " tmp_imgs = images[i :: len(prompts)]\n",
509
+ " tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
510
+ " else:\n",
511
+ " tmp_imgs = images[i :: len(prompts)]\n",
512
+ "\n",
513
+ " # Add the data to the table.\n",
514
+ " gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
515
+ "\n",
516
+ "# Log the Table to W&B dashboard.\n",
517
+ "wandb.log({\"Generated Images\": gen_table})\n",
518
+ "\n",
519
+ "# Close the W&B run.\n",
520
+ "run.finish()"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "markdown",
525
+ "metadata": {
526
+ "id": "Ck2ZnHwVjnRd"
527
+ },
528
+ "source": [
529
+ "Click on the link above to check out your generated images."
530
+ ]
531
+ }
532
+ ],
533
+ "metadata": {
534
+ "accelerator": "GPU",
535
+ "colab": {
536
+ "machine_shape": "hm",
537
+ "name": "DALL·E mini - Inference pipeline.ipynb",
538
+ "provenance": [],
539
+ "gpuType": "A100",
540
+ "include_colab_link": true
541
+ },
542
+ "kernelspec": {
543
+ "display_name": "Python 3",
544
+ "name": "python3"
545
+ },
546
+ "language_info": {
547
+ "codemirror_mode": {
548
+ "name": "ipython",
549
+ "version": 3
550
+ },
551
+ "file_extension": ".py",
552
+ "mimetype": "text/x-python",
553
+ "name": "python",
554
+ "nbconvert_exporter": "python",
555
+ "pygments_lexer": "ipython3",
556
+ "version": "3.9.7"
557
+ }
558
+ },
559
+ "nbformat": 4,
560
+ "nbformat_minor": 0
561
+ }
tools/inference/run_infer_notebook.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #!/bin/bash
2
+ jupyter notebook --ip 0.0.0.0 --no-browser --allow-root
tools/train/config/mega/config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 2048,
7
+ "decoder_attention_heads": 32,
8
+ "decoder_ffn_dim": 4096,
9
+ "decoder_layerdrop": 0.0,
10
+ "decoder_layers": 24,
11
+ "decoder_start_token_id": 16384,
12
+ "do_sample": true,
13
+ "dropout": 0.0,
14
+ "encoder_attention_heads": 32,
15
+ "encoder_ffn_dim": 4096,
16
+ "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 24,
18
+ "encoder_vocab_size": 50272,
19
+ "eos_token_id": 16385,
20
+ "force_ln_scale": false,
21
+ "gradient_checkpointing": false,
22
+ "image_length": 256,
23
+ "image_vocab_size": 16415,
24
+ "init_std": 0.01,
25
+ "is_encoder_decoder": true,
26
+ "ln_positions": "normformer",
27
+ "ln_type": "layernorm",
28
+ "max_length": 257,
29
+ "max_text_length": 64,
30
+ "min_length": 257,
31
+ "model_type": "dallebart",
32
+ "normalize_text": true,
33
+ "pad_token_id": 16385,
34
+ "scale_embedding": false,
35
+ "sinkhorn_iters": 1,
36
+ "tau_init": 0.05,
37
+ "tie_word_embeddings": false,
38
+ "use_absolute_position_embeddings": true,
39
+ "use_alibi": false,
40
+ "use_bias": false,
41
+ "use_cache": true,
42
+ "use_cosine_attention": false,
43
+ "use_deepnet_scaling": false,
44
+ "use_final_ln_decoder": true,
45
+ "use_final_ln_encoder": true,
46
+ "use_glu": true,
47
+ "use_head_scale": false,
48
+ "use_swin_position_embeddings": false
49
+ }
tools/train/config/micro/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 256,
7
+ "decoder_attention_heads": 2,
8
+ "decoder_ffn_dim": 256,
9
+ "decoder_layerdrop": 0.0,
10
+ "decoder_layers": 2,
11
+ "decoder_start_token_id": 16384,
12
+ "dropout": 0.0,
13
+ "encoder_attention_heads": 2,
14
+ "encoder_ffn_dim": 256,
15
+ "encoder_layerdrop": 0.0,
16
+ "encoder_layers": 2,
17
+ "encoder_vocab_size": 50264,
18
+ "eos_token_id": 16385,
19
+ "image_length": 256,
20
+ "image_vocab_size": 16391,
21
+ "init_std": 0.02,
22
+ "is_encoder_decoder": true,
23
+ "max_text_length": 64,
24
+ "model_type": "dallebart",
25
+ "normalize_text": true,
26
+ "pad_token_id": 16385,
27
+ "scale_embedding": false,
28
+ "tie_word_embeddings": false,
29
+ "use_cache": true
30
+ }
tools/train/config/mini/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 1024,
7
+ "decoder_attention_heads": 16,
8
+ "decoder_ffn_dim": 4096,
9
+ "decoder_layers": 12,
10
+ "decoder_start_token_id": 16384,
11
+ "dropout": 0.0,
12
+ "encoder_attention_heads": 16,
13
+ "encoder_ffn_dim": 4096,
14
+ "encoder_layers": 12,
15
+ "encoder_vocab_size": 50264,
16
+ "eos_token_id": 16385,
17
+ "gradient_checkpointing": false,
18
+ "image_length": 256,
19
+ "image_vocab_size": 16391,
20
+ "init_std": 0.02,
21
+ "is_encoder_decoder": true,
22
+ "max_text_length": 64,
23
+ "model_type": "dallebart",
24
+ "normalize_text": true,
25
+ "pad_token_id": 16385,
26
+ "scale_embedding": false,
27
+ "tie_word_embeddings": false,
28
+ "use_cache": true
29
+ }
tools/train/config/mini_glu/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 1024,
7
+ "decoder_attention_heads": 16,
8
+ "decoder_ffn_dim": 2730,
9
+ "decoder_layers": 12,
10
+ "decoder_start_token_id": 16384,
11
+ "dropout": 0.0,
12
+ "encoder_attention_heads": 16,
13
+ "encoder_ffn_dim": 2730,
14
+ "encoder_layers": 12,
15
+ "encoder_vocab_size": 50300,
16
+ "eos_token_id": 16385,
17
+ "gradient_checkpointing": false,
18
+ "image_length": 256,
19
+ "image_vocab_size": 16400,
20
+ "init_std": 0.02,
21
+ "is_encoder_decoder": true,
22
+ "max_text_length": 64,
23
+ "model_type": "dallebart",
24
+ "normalize_text": true,
25
+ "pad_token_id": 16385,
26
+ "scale_embedding": false,
27
+ "tie_word_embeddings": false,
28
+ "use_scan": false,
29
+ "use_cache": true
30
+ }
tools/train/embeddings_retrain_preparation.ipynb ADDED
@@ -0,0 +1,1218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "118UKH5bWCGa"
7
+ },
8
+ "source": [
9
+ "# DALL·E mini - Embedding Retrain Preparation"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "markdown",
14
+ "metadata": {},
15
+ "source": [
16
+ "We'll start with the dalle-mini model for faster experimentation."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 1,
22
+ "metadata": {
23
+ "id": "K6CxW2o42f-w"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
28
+ "DALLE_COMMIT_ID = None\n",
29
+ "\n",
30
+ "# # dalle-mega\n",
31
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
32
+ "# DALLE_COMMIT_ID = None"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 2,
38
+ "metadata": {
39
+ "id": "Yv-aR3t4Oe5v"
40
+ },
41
+ "outputs": [
42
+ {
43
+ "data": {
44
+ "text/plain": [
45
+ "8"
46
+ ]
47
+ },
48
+ "execution_count": 2,
49
+ "metadata": {},
50
+ "output_type": "execute_result"
51
+ }
52
+ ],
53
+ "source": [
54
+ "import jax\n",
55
+ "import jax.numpy as jnp\n",
56
+ "\n",
57
+ "# check how many devices are available\n",
58
+ "jax.local_device_count()"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {},
64
+ "source": [
65
+ "## Load model"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {},
71
+ "source": [
72
+ "We load the model twice to keep a copy of the original parameters."
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 3,
78
+ "metadata": {
79
+ "id": "92zYmvsQ38vL"
80
+ },
81
+ "outputs": [
82
+ {
83
+ "name": "stderr",
84
+ "output_type": "stream",
85
+ "text": [
86
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n",
87
+ "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @ 0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n",
88
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n",
89
+ "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @ 0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n"
90
+ ]
91
+ }
92
+ ],
93
+ "source": [
94
+ "# Load model\n",
95
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
96
+ "\n",
97
+ "model, params = DalleBart.from_pretrained(\n",
98
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
99
+ ")\n",
100
+ "\n",
101
+ "_, params_original = DalleBart.from_pretrained(\n",
102
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
103
+ ")"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "metadata": {},
109
+ "source": [
110
+ "## Model surgery: remove layers to be retrained"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "markdown",
115
+ "metadata": {},
116
+ "source": [
117
+ "Let's take a look at the params tree."
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 4,
123
+ "metadata": {},
124
+ "outputs": [
125
+ {
126
+ "data": {
127
+ "text/plain": [
128
+ "437833712"
129
+ ]
130
+ },
131
+ "execution_count": 4,
132
+ "metadata": {},
133
+ "output_type": "execute_result"
134
+ }
135
+ ],
136
+ "source": [
137
+ "sum(x.size for x in jax.tree_leaves(params))"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 5,
143
+ "metadata": {
144
+ "scrolled": true
145
+ },
146
+ "outputs": [
147
+ {
148
+ "name": "stdout",
149
+ "output_type": "stream",
150
+ "text": [
151
+ "{\n",
152
+ " \"lm_head\": {\n",
153
+ " \"kernel\": [\n",
154
+ " 1024,\n",
155
+ " 16385\n",
156
+ " ]\n",
157
+ " },\n",
158
+ " \"model\": {\n",
159
+ " \"decoder\": {\n",
160
+ " \"embed_positions\": {\n",
161
+ " \"embedding\": [\n",
162
+ " 256,\n",
163
+ " 1024\n",
164
+ " ]\n",
165
+ " },\n",
166
+ " \"embed_tokens\": {\n",
167
+ " \"embedding\": [\n",
168
+ " 16385,\n",
169
+ " 1024\n",
170
+ " ]\n",
171
+ " },\n",
172
+ " \"final_ln\": {\n",
173
+ " \"bias\": [\n",
174
+ " 1024\n",
175
+ " ]\n",
176
+ " },\n",
177
+ " \"layernorm_embedding\": {\n",
178
+ " \"bias\": [\n",
179
+ " 1024\n",
180
+ " ],\n",
181
+ " \"scale\": [\n",
182
+ " 1024\n",
183
+ " ]\n",
184
+ " },\n",
185
+ " \"layers\": {\n",
186
+ " \"FlaxBartDecoderLayers\": {\n",
187
+ " \"FlaxBartAttention_0\": {\n",
188
+ " \"k_proj\": {\n",
189
+ " \"kernel\": [\n",
190
+ " 12,\n",
191
+ " 1024,\n",
192
+ " 1024\n",
193
+ " ]\n",
194
+ " },\n",
195
+ " \"out_proj\": {\n",
196
+ " \"kernel\": [\n",
197
+ " 12,\n",
198
+ " 1024,\n",
199
+ " 1024\n",
200
+ " ]\n",
201
+ " },\n",
202
+ " \"q_proj\": {\n",
203
+ " \"kernel\": [\n",
204
+ " 12,\n",
205
+ " 1024,\n",
206
+ " 1024\n",
207
+ " ]\n",
208
+ " },\n",
209
+ " \"v_proj\": {\n",
210
+ " \"kernel\": [\n",
211
+ " 12,\n",
212
+ " 1024,\n",
213
+ " 1024\n",
214
+ " ]\n",
215
+ " }\n",
216
+ " },\n",
217
+ " \"FlaxBartAttention_1\": {\n",
218
+ " \"k_proj\": {\n",
219
+ " \"kernel\": [\n",
220
+ " 12,\n",
221
+ " 1024,\n",
222
+ " 1024\n",
223
+ " ]\n",
224
+ " },\n",
225
+ " \"out_proj\": {\n",
226
+ " \"kernel\": [\n",
227
+ " 12,\n",
228
+ " 1024,\n",
229
+ " 1024\n",
230
+ " ]\n",
231
+ " },\n",
232
+ " \"q_proj\": {\n",
233
+ " \"kernel\": [\n",
234
+ " 12,\n",
235
+ " 1024,\n",
236
+ " 1024\n",
237
+ " ]\n",
238
+ " },\n",
239
+ " \"v_proj\": {\n",
240
+ " \"kernel\": [\n",
241
+ " 12,\n",
242
+ " 1024,\n",
243
+ " 1024\n",
244
+ " ]\n",
245
+ " }\n",
246
+ " },\n",
247
+ " \"GLU_0\": {\n",
248
+ " \"Dense_0\": {\n",
249
+ " \"kernel\": [\n",
250
+ " 12,\n",
251
+ " 1024,\n",
252
+ " 2730\n",
253
+ " ]\n",
254
+ " },\n",
255
+ " \"Dense_1\": {\n",
256
+ " \"kernel\": [\n",
257
+ " 12,\n",
258
+ " 1024,\n",
259
+ " 2730\n",
260
+ " ]\n",
261
+ " },\n",
262
+ " \"Dense_2\": {\n",
263
+ " \"kernel\": [\n",
264
+ " 12,\n",
265
+ " 2730,\n",
266
+ " 1024\n",
267
+ " ]\n",
268
+ " },\n",
269
+ " \"LayerNorm_0\": {\n",
270
+ " \"bias\": [\n",
271
+ " 12,\n",
272
+ " 1024\n",
273
+ " ]\n",
274
+ " },\n",
275
+ " \"LayerNorm_1\": {\n",
276
+ " \"bias\": [\n",
277
+ " 12,\n",
278
+ " 2730\n",
279
+ " ]\n",
280
+ " }\n",
281
+ " },\n",
282
+ " \"LayerNorm_0\": {\n",
283
+ " \"bias\": [\n",
284
+ " 12,\n",
285
+ " 1024\n",
286
+ " ]\n",
287
+ " },\n",
288
+ " \"LayerNorm_1\": {\n",
289
+ " \"bias\": [\n",
290
+ " 12,\n",
291
+ " 1024\n",
292
+ " ],\n",
293
+ " \"scale\": [\n",
294
+ " 12,\n",
295
+ " 1024\n",
296
+ " ]\n",
297
+ " },\n",
298
+ " \"LayerNorm_2\": {\n",
299
+ " \"bias\": [\n",
300
+ " 12,\n",
301
+ " 1024\n",
302
+ " ]\n",
303
+ " },\n",
304
+ " \"LayerNorm_3\": {\n",
305
+ " \"bias\": [\n",
306
+ " 12,\n",
307
+ " 1024\n",
308
+ " ],\n",
309
+ " \"scale\": [\n",
310
+ " 12,\n",
311
+ " 1024\n",
312
+ " ]\n",
313
+ " }\n",
314
+ " }\n",
315
+ " }\n",
316
+ " },\n",
317
+ " \"encoder\": {\n",
318
+ " \"embed_positions\": {\n",
319
+ " \"embedding\": [\n",
320
+ " 64,\n",
321
+ " 1024\n",
322
+ " ]\n",
323
+ " },\n",
324
+ " \"embed_tokens\": {\n",
325
+ " \"embedding\": [\n",
326
+ " 50264,\n",
327
+ " 1024\n",
328
+ " ]\n",
329
+ " },\n",
330
+ " \"final_ln\": {\n",
331
+ " \"bias\": [\n",
332
+ " 1024\n",
333
+ " ]\n",
334
+ " },\n",
335
+ " \"layernorm_embedding\": {\n",
336
+ " \"bias\": [\n",
337
+ " 1024\n",
338
+ " ],\n",
339
+ " \"scale\": [\n",
340
+ " 1024\n",
341
+ " ]\n",
342
+ " },\n",
343
+ " \"layers\": {\n",
344
+ " \"FlaxBartEncoderLayers\": {\n",
345
+ " \"FlaxBartAttention_0\": {\n",
346
+ " \"k_proj\": {\n",
347
+ " \"kernel\": [\n",
348
+ " 12,\n",
349
+ " 1024,\n",
350
+ " 1024\n",
351
+ " ]\n",
352
+ " },\n",
353
+ " \"out_proj\": {\n",
354
+ " \"kernel\": [\n",
355
+ " 12,\n",
356
+ " 1024,\n",
357
+ " 1024\n",
358
+ " ]\n",
359
+ " },\n",
360
+ " \"q_proj\": {\n",
361
+ " \"kernel\": [\n",
362
+ " 12,\n",
363
+ " 1024,\n",
364
+ " 1024\n",
365
+ " ]\n",
366
+ " },\n",
367
+ " \"v_proj\": {\n",
368
+ " \"kernel\": [\n",
369
+ " 12,\n",
370
+ " 1024,\n",
371
+ " 1024\n",
372
+ " ]\n",
373
+ " }\n",
374
+ " },\n",
375
+ " \"GLU_0\": {\n",
376
+ " \"Dense_0\": {\n",
377
+ " \"kernel\": [\n",
378
+ " 12,\n",
379
+ " 1024,\n",
380
+ " 2730\n",
381
+ " ]\n",
382
+ " },\n",
383
+ " \"Dense_1\": {\n",
384
+ " \"kernel\": [\n",
385
+ " 12,\n",
386
+ " 1024,\n",
387
+ " 2730\n",
388
+ " ]\n",
389
+ " },\n",
390
+ " \"Dense_2\": {\n",
391
+ " \"kernel\": [\n",
392
+ " 12,\n",
393
+ " 2730,\n",
394
+ " 1024\n",
395
+ " ]\n",
396
+ " },\n",
397
+ " \"LayerNorm_0\": {\n",
398
+ " \"bias\": [\n",
399
+ " 12,\n",
400
+ " 1024\n",
401
+ " ]\n",
402
+ " },\n",
403
+ " \"LayerNorm_1\": {\n",
404
+ " \"bias\": [\n",
405
+ " 12,\n",
406
+ " 2730\n",
407
+ " ]\n",
408
+ " }\n",
409
+ " },\n",
410
+ " \"LayerNorm_0\": {\n",
411
+ " \"bias\": [\n",
412
+ " 12,\n",
413
+ " 1024\n",
414
+ " ]\n",
415
+ " },\n",
416
+ " \"LayerNorm_1\": {\n",
417
+ " \"bias\": [\n",
418
+ " 12,\n",
419
+ " 1024\n",
420
+ " ],\n",
421
+ " \"scale\": [\n",
422
+ " 12,\n",
423
+ " 1024\n",
424
+ " ]\n",
425
+ " }\n",
426
+ " }\n",
427
+ " }\n",
428
+ " }\n",
429
+ " }\n",
430
+ "}\n"
431
+ ]
432
+ }
433
+ ],
434
+ "source": [
435
+ "import json\n",
436
+ "\n",
437
+ "tree = jax.tree_map(lambda x: x.shape, params)\n",
438
+ "print(json.dumps(tree, indent=2))"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "metadata": {},
444
+ "source": [
445
+ "We will remove or reinitialize:\n",
446
+ "- `lm_head`\n",
447
+ "- `model.decoder.embed_positions`\n",
448
+ "- `model.decoder.embed_tokens`\n",
449
+ "- `model.decoder.final_ln`\n",
450
+ "- `model.decoder.layernorm_embedding`"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": 6,
456
+ "metadata": {},
457
+ "outputs": [],
458
+ "source": [
459
+ "del params[\"lm_head\"]\n",
460
+ "for layer in [\"embed_positions\", \"embed_tokens\", \"final_ln\", \"layernorm_embedding\"]:\n",
461
+ " del params[\"model\"][\"decoder\"][layer]"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": 7,
467
+ "metadata": {},
468
+ "outputs": [
469
+ {
470
+ "data": {
471
+ "text/plain": [
472
+ "{'model': {'decoder': {'layers': {'FlaxBartDecoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n",
473
+ " 1024,\n",
474
+ " 1024)},\n",
475
+ " 'out_proj': {'kernel': (12, 1024, 1024)},\n",
476
+ " 'q_proj': {'kernel': (12, 1024, 1024)},\n",
477
+ " 'v_proj': {'kernel': (12, 1024, 1024)}},\n",
478
+ " 'FlaxBartAttention_1': {'k_proj': {'kernel': (12, 1024, 1024)},\n",
479
+ " 'out_proj': {'kernel': (12, 1024, 1024)},\n",
480
+ " 'q_proj': {'kernel': (12, 1024, 1024)},\n",
481
+ " 'v_proj': {'kernel': (12, 1024, 1024)}},\n",
482
+ " 'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n",
483
+ " 'Dense_1': {'kernel': (12, 1024, 2730)},\n",
484
+ " 'Dense_2': {'kernel': (12, 2730, 1024)},\n",
485
+ " 'LayerNorm_0': {'bias': (12, 1024)},\n",
486
+ " 'LayerNorm_1': {'bias': (12, 2730)}},\n",
487
+ " 'LayerNorm_0': {'bias': (12, 1024)},\n",
488
+ " 'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)},\n",
489
+ " 'LayerNorm_2': {'bias': (12, 1024)},\n",
490
+ " 'LayerNorm_3': {'bias': (12, 1024), 'scale': (12, 1024)}}}},\n",
491
+ " 'encoder': {'embed_positions': {'embedding': (64, 1024)},\n",
492
+ " 'embed_tokens': {'embedding': (50264, 1024)},\n",
493
+ " 'final_ln': {'bias': (1024,)},\n",
494
+ " 'layernorm_embedding': {'bias': (1024,), 'scale': (1024,)},\n",
495
+ " 'layers': {'FlaxBartEncoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n",
496
+ " 1024,\n",
497
+ " 1024)},\n",
498
+ " 'out_proj': {'kernel': (12, 1024, 1024)},\n",
499
+ " 'q_proj': {'kernel': (12, 1024, 1024)},\n",
500
+ " 'v_proj': {'kernel': (12, 1024, 1024)}},\n",
501
+ " 'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n",
502
+ " 'Dense_1': {'kernel': (12, 1024, 2730)},\n",
503
+ " 'Dense_2': {'kernel': (12, 2730, 1024)},\n",
504
+ " 'LayerNorm_0': {'bias': (12, 1024)},\n",
505
+ " 'LayerNorm_1': {'bias': (12, 2730)}},\n",
506
+ " 'LayerNorm_0': {'bias': (12, 1024)},\n",
507
+ " 'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)}}}}}}"
508
+ ]
509
+ },
510
+ "execution_count": 7,
511
+ "metadata": {},
512
+ "output_type": "execute_result"
513
+ }
514
+ ],
515
+ "source": [
516
+ "jax.tree_map(lambda x: x.shape, params)"
517
+ ]
518
+ },
519
+ {
520
+ "cell_type": "code",
521
+ "execution_count": 8,
522
+ "metadata": {},
523
+ "outputs": [
524
+ {
525
+ "data": {
526
+ "text/plain": [
527
+ "404012016"
528
+ ]
529
+ },
530
+ "execution_count": 8,
531
+ "metadata": {},
532
+ "output_type": "execute_result"
533
+ }
534
+ ],
535
+ "source": [
536
+ "sum(x.size for x in jax.tree_leaves(params))"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "markdown",
541
+ "metadata": {},
542
+ "source": [
543
+ "## Reinitialize layers"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "markdown",
548
+ "metadata": {},
549
+ "source": [
550
+ "We save a checkpoint and reload it again. It does not automatically reinitialize the missing keys, but it sets `_missing_keys` appropriately so we can initialize them later. We could do the same by simply setting that property ourselves, but I'll refrain from doing so because it's a private implementation detail."
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": 9,
556
+ "metadata": {},
557
+ "outputs": [],
558
+ "source": [
559
+ "trimmed_checkpoint = \"mini-trimmed\""
560
+ ]
561
+ },
562
+ {
563
+ "cell_type": "code",
564
+ "execution_count": 10,
565
+ "metadata": {},
566
+ "outputs": [
567
+ {
568
+ "name": "stderr",
569
+ "output_type": "stream",
570
+ "text": [
571
+ "tcmalloc: large alloc 1610424320 bytes == 0x5632d11c8000 @ 0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n",
572
+ "tcmalloc: large alloc 3231449088 bytes == 0x56333119a000 @ 0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n"
573
+ ]
574
+ }
575
+ ],
576
+ "source": [
577
+ "model.save_pretrained(trimmed_checkpoint, params=params)"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 11,
583
+ "metadata": {},
584
+ "outputs": [
585
+ {
586
+ "name": "stderr",
587
+ "output_type": "stream",
588
+ "text": [
589
+ "The checkpoint mini-trimmed is missing required keys: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}. Make sure to call model.init_weights to initialize the missing weights.\n",
590
+ "Some weights of DalleBart were not initialized from the model checkpoint at mini-trimmed and are newly initialized: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}\n",
591
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
592
+ ]
593
+ }
594
+ ],
595
+ "source": [
596
+ "model, params = DalleBart.from_pretrained(\n",
597
+ " trimmed_checkpoint, revision=None, dtype=jnp.float16, _do_init=False\n",
598
+ ")"
599
+ ]
600
+ },
601
+ {
602
+ "cell_type": "code",
603
+ "execution_count": 12,
604
+ "metadata": {},
605
+ "outputs": [
606
+ {
607
+ "data": {
608
+ "text/plain": [
609
+ "{('lm_head', 'kernel'),\n",
610
+ " ('model', 'decoder', 'embed_positions', 'embedding'),\n",
611
+ " ('model', 'decoder', 'embed_tokens', 'embedding'),\n",
612
+ " ('model', 'decoder', 'final_ln', 'bias'),\n",
613
+ " ('model', 'decoder', 'layernorm_embedding', 'bias'),\n",
614
+ " ('model', 'decoder', 'layernorm_embedding', 'scale')}"
615
+ ]
616
+ },
617
+ "execution_count": 12,
618
+ "metadata": {},
619
+ "output_type": "execute_result"
620
+ }
621
+ ],
622
+ "source": [
623
+ "model._missing_keys"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "execution_count": 13,
629
+ "metadata": {},
630
+ "outputs": [],
631
+ "source": [
632
+ "params_reinit = model.init_weights(model.key, model.input_shape, params=params)"
633
+ ]
634
+ },
635
+ {
636
+ "cell_type": "markdown",
637
+ "metadata": {},
638
+ "source": [
639
+ "### Verification"
640
+ ]
641
+ },
642
+ {
643
+ "cell_type": "markdown",
644
+ "metadata": {},
645
+ "source": [
646
+ "The structure should be the same as the original `params` dict. Re-initialized layers should have different parameters, but existing layers should be the same."
647
+ ]
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": 14,
652
+ "metadata": {
653
+ "scrolled": true
654
+ },
655
+ "outputs": [
656
+ {
657
+ "data": {
658
+ "text/plain": [
659
+ "FrozenDict({\n",
660
+ " lm_head: {\n",
661
+ " kernel: (1024, 16385),\n",
662
+ " },\n",
663
+ " model: {\n",
664
+ " decoder: {\n",
665
+ " embed_positions: {\n",
666
+ " embedding: (256, 1024),\n",
667
+ " },\n",
668
+ " embed_tokens: {\n",
669
+ " embedding: (16385, 1024),\n",
670
+ " },\n",
671
+ " final_ln: {\n",
672
+ " bias: (1024,),\n",
673
+ " },\n",
674
+ " layernorm_embedding: {\n",
675
+ " bias: (1024,),\n",
676
+ " scale: (1024,),\n",
677
+ " },\n",
678
+ " layers: {\n",
679
+ " FlaxBartDecoderLayers: {\n",
680
+ " FlaxBartAttention_0: {\n",
681
+ " k_proj: {\n",
682
+ " kernel: (12, 1024, 1024),\n",
683
+ " },\n",
684
+ " out_proj: {\n",
685
+ " kernel: (12, 1024, 1024),\n",
686
+ " },\n",
687
+ " q_proj: {\n",
688
+ " kernel: (12, 1024, 1024),\n",
689
+ " },\n",
690
+ " v_proj: {\n",
691
+ " kernel: (12, 1024, 1024),\n",
692
+ " },\n",
693
+ " },\n",
694
+ " FlaxBartAttention_1: {\n",
695
+ " k_proj: {\n",
696
+ " kernel: (12, 1024, 1024),\n",
697
+ " },\n",
698
+ " out_proj: {\n",
699
+ " kernel: (12, 1024, 1024),\n",
700
+ " },\n",
701
+ " q_proj: {\n",
702
+ " kernel: (12, 1024, 1024),\n",
703
+ " },\n",
704
+ " v_proj: {\n",
705
+ " kernel: (12, 1024, 1024),\n",
706
+ " },\n",
707
+ " },\n",
708
+ " GLU_0: {\n",
709
+ " Dense_0: {\n",
710
+ " kernel: (12, 1024, 2730),\n",
711
+ " },\n",
712
+ " Dense_1: {\n",
713
+ " kernel: (12, 1024, 2730),\n",
714
+ " },\n",
715
+ " Dense_2: {\n",
716
+ " kernel: (12, 2730, 1024),\n",
717
+ " },\n",
718
+ " LayerNorm_0: {\n",
719
+ " bias: (12, 1024),\n",
720
+ " },\n",
721
+ " LayerNorm_1: {\n",
722
+ " bias: (12, 2730),\n",
723
+ " },\n",
724
+ " },\n",
725
+ " LayerNorm_0: {\n",
726
+ " bias: (12, 1024),\n",
727
+ " },\n",
728
+ " LayerNorm_1: {\n",
729
+ " bias: (12, 1024),\n",
730
+ " scale: (12, 1024),\n",
731
+ " },\n",
732
+ " LayerNorm_2: {\n",
733
+ " bias: (12, 1024),\n",
734
+ " },\n",
735
+ " LayerNorm_3: {\n",
736
+ " bias: (12, 1024),\n",
737
+ " scale: (12, 1024),\n",
738
+ " },\n",
739
+ " },\n",
740
+ " },\n",
741
+ " },\n",
742
+ " encoder: {\n",
743
+ " embed_positions: {\n",
744
+ " embedding: (64, 1024),\n",
745
+ " },\n",
746
+ " embed_tokens: {\n",
747
+ " embedding: (50264, 1024),\n",
748
+ " },\n",
749
+ " final_ln: {\n",
750
+ " bias: (1024,),\n",
751
+ " },\n",
752
+ " layernorm_embedding: {\n",
753
+ " bias: (1024,),\n",
754
+ " scale: (1024,),\n",
755
+ " },\n",
756
+ " layers: {\n",
757
+ " FlaxBartEncoderLayers: {\n",
758
+ " FlaxBartAttention_0: {\n",
759
+ " k_proj: {\n",
760
+ " kernel: (12, 1024, 1024),\n",
761
+ " },\n",
762
+ " out_proj: {\n",
763
+ " kernel: (12, 1024, 1024),\n",
764
+ " },\n",
765
+ " q_proj: {\n",
766
+ " kernel: (12, 1024, 1024),\n",
767
+ " },\n",
768
+ " v_proj: {\n",
769
+ " kernel: (12, 1024, 1024),\n",
770
+ " },\n",
771
+ " },\n",
772
+ " GLU_0: {\n",
773
+ " Dense_0: {\n",
774
+ " kernel: (12, 1024, 2730),\n",
775
+ " },\n",
776
+ " Dense_1: {\n",
777
+ " kernel: (12, 1024, 2730),\n",
778
+ " },\n",
779
+ " Dense_2: {\n",
780
+ " kernel: (12, 2730, 1024),\n",
781
+ " },\n",
782
+ " LayerNorm_0: {\n",
783
+ " bias: (12, 1024),\n",
784
+ " },\n",
785
+ " LayerNorm_1: {\n",
786
+ " bias: (12, 2730),\n",
787
+ " },\n",
788
+ " },\n",
789
+ " LayerNorm_0: {\n",
790
+ " bias: (12, 1024),\n",
791
+ " },\n",
792
+ " LayerNorm_1: {\n",
793
+ " bias: (12, 1024),\n",
794
+ " scale: (12, 1024),\n",
795
+ " },\n",
796
+ " },\n",
797
+ " },\n",
798
+ " },\n",
799
+ " },\n",
800
+ "})"
801
+ ]
802
+ },
803
+ "execution_count": 14,
804
+ "metadata": {},
805
+ "output_type": "execute_result"
806
+ }
807
+ ],
808
+ "source": [
809
+ "jax.tree_map(lambda x: x.shape, params_reinit)"
810
+ ]
811
+ },
812
+ {
813
+ "cell_type": "code",
814
+ "execution_count": 15,
815
+ "metadata": {},
816
+ "outputs": [
817
+ {
818
+ "data": {
819
+ "text/plain": [
820
+ "FrozenDict({\n",
821
+ " embedding: DeviceArray([[ 0.00582082, -0.04113895, 0.00918633, ..., -0.00530822,\n",
822
+ " 0.01297319, 0.02720674],\n",
823
+ " [ 0.03540739, 0.03676804, -0.02924041, ..., 0.00163185,\n",
824
+ " -0.01938273, -0.02105987],\n",
825
+ " [ 0.00478452, -0.03438002, -0.0024974 , ..., -0.03892584,\n",
826
+ " 0.01721252, 0.02605445],\n",
827
+ " ...,\n",
828
+ " [ 0.02495495, 0.00559381, -0.01588043, ..., 0.01393714,\n",
829
+ " -0.01824111, -0.02007291],\n",
830
+ " [ 0.00983252, -0.00180564, -0.01686333, ..., -0.01001718,\n",
831
+ " 0.01886345, -0.00393983],\n",
832
+ " [-0.03589988, -0.00455565, 0.00076276, ..., -0.02145007,\n",
833
+ " -0.00180798, -0.0133148 ]], dtype=float32),\n",
834
+ "})"
835
+ ]
836
+ },
837
+ "execution_count": 15,
838
+ "metadata": {},
839
+ "output_type": "execute_result"
840
+ }
841
+ ],
842
+ "source": [
843
+ "params_reinit[\"model\"][\"decoder\"][\"embed_positions\"]"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": 16,
849
+ "metadata": {},
850
+ "outputs": [
851
+ {
852
+ "data": {
853
+ "text/plain": [
854
+ "(DeviceArray(-0.09320386, dtype=float32),\n",
855
+ " DeviceArray(0.08769083, dtype=float32))"
856
+ ]
857
+ },
858
+ "execution_count": 16,
859
+ "metadata": {},
860
+ "output_type": "execute_result"
861
+ }
862
+ ],
863
+ "source": [
864
+ "embedding_new = params_reinit[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n",
865
+ "embedding_new.min(), embedding_new.max()"
866
+ ]
867
+ },
868
+ {
869
+ "cell_type": "code",
870
+ "execution_count": 17,
871
+ "metadata": {},
872
+ "outputs": [
873
+ {
874
+ "data": {
875
+ "text/plain": [
876
+ "{'embedding': DeviceArray([[ 0.03459017, -0.0065838 , -0.11748601, ..., -0.01451578,\n",
877
+ " -0.03927238, -0.00266367],\n",
878
+ " [-0.03116009, 0.00438436, 0.02691377, ..., -0.02886203,\n",
879
+ " -0.01095741, -0.02649871],\n",
880
+ " [-0.03568491, -0.0086962 , 0.01851564, ..., -0.04736514,\n",
881
+ " 0.05310551, -0.01648099],\n",
882
+ " ...,\n",
883
+ " [-0.02454913, 0.03746822, -0.02269235, ..., 0.03377315,\n",
884
+ " 0.003004 , 0.04975331],\n",
885
+ " [-0.05145862, 0.04472217, 0.11103845, ..., 0.04581303,\n",
886
+ " 0.02850476, 0.00554514],\n",
887
+ " [-0.01037806, 0.00281054, -0.0485299 , ..., -0.03325456,\n",
888
+ " -0.0058979 , 0.01733843]], dtype=float32)}"
889
+ ]
890
+ },
891
+ "execution_count": 17,
892
+ "metadata": {},
893
+ "output_type": "execute_result"
894
+ }
895
+ ],
896
+ "source": [
897
+ "params_original[\"model\"][\"decoder\"][\"embed_positions\"]"
898
+ ]
899
+ },
900
+ {
901
+ "cell_type": "code",
902
+ "execution_count": 18,
903
+ "metadata": {},
904
+ "outputs": [
905
+ {
906
+ "data": {
907
+ "text/plain": [
908
+ "(DeviceArray(-0.25866088, dtype=float32),\n",
909
+ " DeviceArray(0.08769083, dtype=float32))"
910
+ ]
911
+ },
912
+ "execution_count": 18,
913
+ "metadata": {},
914
+ "output_type": "execute_result"
915
+ }
916
+ ],
917
+ "source": [
918
+ "embedding_original = params_original[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n",
919
+ "embedding_original.min(), embedding_new.max()"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "code",
924
+ "execution_count": 19,
925
+ "metadata": {},
926
+ "outputs": [],
927
+ "source": [
928
+ "assert jnp.allclose(embedding_new, embedding_original).item() == False"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "code",
933
+ "execution_count": 20,
934
+ "metadata": {},
935
+ "outputs": [],
936
+ "source": [
937
+ "lm_head_original = params_original[\"lm_head\"][\"kernel\"]\n",
938
+ "lm_head_reinit = params_reinit[\"lm_head\"][\"kernel\"]\n",
939
+ "assert jnp.allclose(lm_head_reinit, lm_head_original).item() == False"
940
+ ]
941
+ },
942
+ {
943
+ "cell_type": "code",
944
+ "execution_count": 21,
945
+ "metadata": {},
946
+ "outputs": [],
947
+ "source": [
948
+ "assert jnp.allclose(\n",
949
+ " params_reinit[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n",
950
+ " \"FlaxBartAttention_0\"\n",
951
+ " ][\"k_proj\"][\"kernel\"],\n",
952
+ " params_original[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n",
953
+ " \"FlaxBartAttention_0\"\n",
954
+ " ][\"k_proj\"][\"kernel\"],\n",
955
+ ").item()"
956
+ ]
957
+ },
958
+ {
959
+ "cell_type": "markdown",
960
+ "metadata": {},
961
+ "source": [
962
+ "## Save checkpoint for retrain"
963
+ ]
964
+ },
965
+ {
966
+ "cell_type": "markdown",
967
+ "metadata": {},
968
+ "source": [
969
+ "Finally, we save the resulting model to retrain those layers."
970
+ ]
971
+ },
972
+ {
973
+ "cell_type": "code",
974
+ "execution_count": 22,
975
+ "metadata": {},
976
+ "outputs": [],
977
+ "source": [
978
+ "checkpoint_dir = \"mini-reinit\""
979
+ ]
980
+ },
981
+ {
982
+ "cell_type": "code",
983
+ "execution_count": 23,
984
+ "metadata": {},
985
+ "outputs": [
986
+ {
987
+ "name": "stderr",
988
+ "output_type": "stream",
989
+ "text": [
990
+ "tcmalloc: large alloc 3367796736 bytes == 0x5633f235a000 @ 0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n"
991
+ ]
992
+ }
993
+ ],
994
+ "source": [
995
+ "model.save_pretrained(checkpoint_dir, params=params_reinit)"
996
+ ]
997
+ },
998
+ {
999
+ "cell_type": "markdown",
1000
+ "metadata": {},
1001
+ "source": [
1002
+ "### Upload checkpoint to W&B"
1003
+ ]
1004
+ },
1005
+ {
1006
+ "cell_type": "code",
1007
+ "execution_count": 24,
1008
+ "metadata": {},
1009
+ "outputs": [],
1010
+ "source": [
1011
+ "import wandb\n",
1012
+ "from pathlib import Path"
1013
+ ]
1014
+ },
1015
+ {
1016
+ "cell_type": "code",
1017
+ "execution_count": 25,
1018
+ "metadata": {},
1019
+ "outputs": [
1020
+ {
1021
+ "name": "stderr",
1022
+ "output_type": "stream",
1023
+ "text": [
1024
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mpcuenq\u001b[0m (\u001b[33mdalle-mini\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
1025
+ ]
1026
+ },
1027
+ {
1028
+ "data": {
1029
+ "text/html": [
1030
+ "Tracking run with wandb version 0.12.21"
1031
+ ],
1032
+ "text/plain": [
1033
+ "<IPython.core.display.HTML object>"
1034
+ ]
1035
+ },
1036
+ "metadata": {},
1037
+ "output_type": "display_data"
1038
+ },
1039
+ {
1040
+ "data": {
1041
+ "text/html": [
1042
+ "Run data is saved locally in <code>/home/pedro/code/dalle-mini/dalle-mini/tools/train/wandb/run-20220722_105625-2v9szi3q</code>"
1043
+ ],
1044
+ "text/plain": [
1045
+ "<IPython.core.display.HTML object>"
1046
+ ]
1047
+ },
1048
+ "metadata": {},
1049
+ "output_type": "display_data"
1050
+ },
1051
+ {
1052
+ "data": {
1053
+ "text/html": [
1054
+ "Syncing run <strong><a href=\"https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q\" target=\"_blank\">astral-durian-2957</a></strong> to <a href=\"https://wandb.ai/dalle-mini/dalle-mini\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://wandb.me/run\" target=\"_blank\">docs</a>)<br/>"
1055
+ ],
1056
+ "text/plain": [
1057
+ "<IPython.core.display.HTML object>"
1058
+ ]
1059
+ },
1060
+ "metadata": {},
1061
+ "output_type": "display_data"
1062
+ },
1063
+ {
1064
+ "data": {
1065
+ "text/html": [
1066
+ "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src=\"https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q?jupyter=true\" style=\"border:none;width:100%;height:420px;display:none;\"></iframe>"
1067
+ ],
1068
+ "text/plain": [
1069
+ "<wandb.sdk.wandb_run.Run at 0x7f959a563b80>"
1070
+ ]
1071
+ },
1072
+ "execution_count": 25,
1073
+ "metadata": {},
1074
+ "output_type": "execute_result"
1075
+ }
1076
+ ],
1077
+ "source": [
1078
+ "wandb.init(\n",
1079
+ " entity=\"dalle-mini\",\n",
1080
+ " project=\"dalle-mini\",\n",
1081
+ " job_type=\"Seq2Seq\",\n",
1082
+ ")"
1083
+ ]
1084
+ },
1085
+ {
1086
+ "cell_type": "code",
1087
+ "execution_count": 26,
1088
+ "metadata": {},
1089
+ "outputs": [],
1090
+ "source": [
1091
+ "artifact = wandb.Artifact(\n",
1092
+ " name=f\"model-{wandb.run.id}\",\n",
1093
+ " type=\"DalleBart_model\",\n",
1094
+ " metadata={\"embeddings\": \"reset\"},\n",
1095
+ ")\n",
1096
+ "\n",
1097
+ "for filename in [\"config.json\", \"flax_model.msgpack\"]:\n",
1098
+ " artifact.add_file(f\"{Path(checkpoint_dir) / filename}\")"
1099
+ ]
1100
+ },
1101
+ {
1102
+ "cell_type": "code",
1103
+ "execution_count": 27,
1104
+ "metadata": {},
1105
+ "outputs": [
1106
+ {
1107
+ "data": {
1108
+ "text/plain": [
1109
+ "<wandb.sdk.wandb_artifacts.Artifact at 0x7f95984c3fd0>"
1110
+ ]
1111
+ },
1112
+ "execution_count": 27,
1113
+ "metadata": {},
1114
+ "output_type": "execute_result"
1115
+ }
1116
+ ],
1117
+ "source": [
1118
+ "wandb.run.log_artifact(artifact)"
1119
+ ]
1120
+ },
1121
+ {
1122
+ "cell_type": "code",
1123
+ "execution_count": 28,
1124
+ "metadata": {},
1125
+ "outputs": [
1126
+ {
1127
+ "data": {
1128
+ "text/html": [
1129
+ "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
1130
+ ],
1131
+ "text/plain": [
1132
+ "<IPython.core.display.HTML object>"
1133
+ ]
1134
+ },
1135
+ "metadata": {},
1136
+ "output_type": "display_data"
1137
+ },
1138
+ {
1139
+ "data": {
1140
+ "application/vnd.jupyter.widget-view+json": {
1141
+ "model_id": "",
1142
+ "version_major": 2,
1143
+ "version_minor": 0
1144
+ },
1145
+ "text/plain": [
1146
+ "VBox(children=(Label(value='1670.207 MB of 1670.207 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.…"
1147
+ ]
1148
+ },
1149
+ "metadata": {},
1150
+ "output_type": "display_data"
1151
+ },
1152
+ {
1153
+ "data": {
1154
+ "text/html": [
1155
+ "Synced <strong style=\"color:#cdcd00\">astral-durian-2957</strong>: <a href=\"https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q\" target=\"_blank\">https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q</a><br/>Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
1156
+ ],
1157
+ "text/plain": [
1158
+ "<IPython.core.display.HTML object>"
1159
+ ]
1160
+ },
1161
+ "metadata": {},
1162
+ "output_type": "display_data"
1163
+ },
1164
+ {
1165
+ "data": {
1166
+ "text/html": [
1167
+ "Find logs at: <code>./wandb/run-20220722_105625-2v9szi3q/logs</code>"
1168
+ ],
1169
+ "text/plain": [
1170
+ "<IPython.core.display.HTML object>"
1171
+ ]
1172
+ },
1173
+ "metadata": {},
1174
+ "output_type": "display_data"
1175
+ }
1176
+ ],
1177
+ "source": [
1178
+ "wandb.finish()"
1179
+ ]
1180
+ },
1181
+ {
1182
+ "cell_type": "markdown",
1183
+ "metadata": {},
1184
+ "source": [
1185
+ "----"
1186
+ ]
1187
+ }
1188
+ ],
1189
+ "metadata": {
1190
+ "accelerator": "GPU",
1191
+ "colab": {
1192
+ "collapsed_sections": [],
1193
+ "include_colab_link": true,
1194
+ "machine_shape": "hm",
1195
+ "name": "DALL·E mini - Inference pipeline.ipynb",
1196
+ "provenance": []
1197
+ },
1198
+ "kernelspec": {
1199
+ "display_name": "Python 3 (ipykernel)",
1200
+ "language": "python",
1201
+ "name": "python3"
1202
+ },
1203
+ "language_info": {
1204
+ "codemirror_mode": {
1205
+ "name": "ipython",
1206
+ "version": 3
1207
+ },
1208
+ "file_extension": ".py",
1209
+ "mimetype": "text/x-python",
1210
+ "name": "python",
1211
+ "nbconvert_exporter": "python",
1212
+ "pygments_lexer": "ipython3",
1213
+ "version": "3.9.12"
1214
+ }
1215
+ },
1216
+ "nbformat": 4,
1217
+ "nbformat_minor": 1
1218
+ }
tools/train/scalable_shampoo/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Notes
2
+
3
+ Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax).
4
+
5
+ Imports have been modified to be relative.
6
+
7
+ This will eventually be replaced with `optax-shampoo` package.
tools/train/scalable_shampoo/distributed_shampoo.py ADDED
@@ -0,0 +1,2452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of distributed Shampoo optimizer from:
17
+ #
18
+ # Scalable Second Order Optimization for Deep Learning
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
20
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
21
+ #
22
+ # This implementation moves computation of inverse pth root back to the
23
+ # accelerator (if higher precision is available).
24
+ #
25
+ # Authors: Rohan Anil (rohananil at google dot com)
26
+ # Vineet Gupta (vineet at google dot com)
27
+ # James Lottes (jlottes at google dot com)
28
+ # Anudhyan Boral (anudhyan at google dot com)
29
+ #
30
+ """Distributed Shampoo Implementation."""
31
+
32
+ import enum
33
+ import functools
34
+ import itertools
35
+ from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
36
+
37
+ import chex
38
+ import jax
39
+ import jax.numpy as jnp
40
+ import numpy as np
41
+ import optax
42
+ from absl import logging
43
+ from flax import struct
44
+ from jax import lax
45
+ from jax.experimental import pjit
46
+ from jax.experimental.sparse import linalg
47
+
48
+ from .quantization_utils import QuantizedValue
49
+ from .symmetric_matrices import symmetric_matrices
50
+
51
+ # Dtype for inverse-pth root routine
52
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
53
+ # jax_enable_x64 for this to work, otherwise it will default to float32.
54
+ _MAT_INV_PTH_ROOT_DTYPE = jnp.float64
55
+
56
+
57
+ @struct.dataclass
58
+ class TrainingMetrics:
59
+ inverse_pth_root_errors: chex.Array # Error for inverse-pth roots.
60
+ # TODO(rohananil): Add more important metrics to track during training.
61
+
62
+
63
+ # Per parameter optimizer state used in data-parallel training.
64
+ class ParameterStats(NamedTuple):
65
+ """State associated to each parameter of the model being trained."""
66
+
67
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
68
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
69
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
70
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
71
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
72
+ training_metrics: TrainingMetrics # Metrics (optional for training).
73
+
74
+
75
+ # For training extremely large model; We keep a global state with a concatenated
76
+ # statistics and preconditioner states for all vars. This is so that we can
77
+ # annotate the leading axis to be sharded to save memory at the cost of
78
+ # communication.
79
+ @struct.dataclass
80
+ class GlobalShardedParameterStats:
81
+ statistics: chex.Array # Statistics
82
+ preconditioners: chex.Array # Preconditioners
83
+ exponents: chex.Array # exponents
84
+
85
+
86
+ # These are per-parameter local states; All statistics here mirror the parameter
87
+ # Thus the sharding is copied over from the param specification.
88
+ @struct.dataclass
89
+ class LocalShardedParameterStats:
90
+ """State associated to each parameter of the model being trained."""
91
+
92
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
93
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
94
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
95
+ training_metrics: TrainingMetrics # Metrics (optional for training).
96
+ index_start: np.int32 = struct.field(
97
+ pytree_node=False
98
+ ) # Index into global statistics array
99
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
100
+
101
+
102
+ def init_training_metrics(num_statistics):
103
+ # Since the downstream apis expect a jnp.array - we create a dummy one if
104
+ # num_statistics=0.
105
+ if not num_statistics:
106
+ return TrainingMetrics(jnp.array(0, jnp.float32))
107
+ else:
108
+ return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
109
+
110
+
111
+ def init_training_metrics_shapes(num_statistics):
112
+ # Since the downstream apis expect a jnp.array - we create a dummy one if
113
+ # num_statistics=0.
114
+ if not num_statistics:
115
+ return TrainingMetrics([[], jnp.float32])
116
+ else:
117
+ return TrainingMetrics([[num_statistics], jnp.float32])
118
+
119
+
120
+ def init_training_metrics_pspec():
121
+ return TrainingMetrics(pjit.PartitionSpec())
122
+
123
+
124
+ class ShardedShampooStats(NamedTuple):
125
+ """Shampoo state in sharded mode."""
126
+
127
+ global_stats: Any
128
+ local_stats: Any
129
+
130
+
131
+ class ShampooState(NamedTuple):
132
+ count: chex.Array
133
+ stats: Any
134
+
135
+
136
+ class InitFnState(NamedTuple):
137
+ init_fn: Any
138
+ pspec_fn: Any
139
+ shape_and_dtype_fn: Any
140
+
141
+
142
+ class GraftingType(enum.IntEnum):
143
+ SGD = 1
144
+ ADAGRAD = 2
145
+ RMSPROP = 3
146
+ RMSPROP_NORMALIZED = 4
147
+ SQRT_N = 5
148
+ ADAGRAD_NORMALIZED = 6
149
+
150
+
151
+ class PreconditionerType(enum.IntEnum):
152
+ # Default, computes preconditioner for each dim
153
+ ALL = 1
154
+ # One sided Shampoo, in this cases only on input dim.
155
+ # Assumes last dim is always the output dim and everything else input dim.
156
+ INPUT = 2
157
+
158
+
159
+ def power_iteration(
160
+ matrix,
161
+ num_iters=100,
162
+ error_tolerance=1e-6,
163
+ precision=lax.Precision.HIGHEST,
164
+ ):
165
+ r"""Power iteration algorithm.
166
+
167
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
168
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
169
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
170
+
171
+ References:
172
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
173
+
174
+ Args:
175
+ matrix: the symmetric PSD matrix.
176
+ num_iters: Number of iterations.
177
+ error_tolerance: Iterative exit condition.
178
+ precision: precision XLA related flag, the available options are: a)
179
+ lax.Precision.DEFAULT (better step time, but not precise) b)
180
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
181
+ (best possible precision, slowest)
182
+
183
+ Returns:
184
+ eigen vector, eigen value
185
+ """
186
+ matrix_size = matrix.shape[-1]
187
+
188
+ def _iter_condition(state):
189
+ i, unused_v, unused_s, unused_s_v, run_step = state
190
+ return jnp.logical_and(i < num_iters, run_step)
191
+
192
+ def _iter_body(state):
193
+ """One step of power iteration."""
194
+ i, new_v, s, s_v, unused_run_step = state
195
+ new_v = new_v / jnp.linalg.norm(new_v)
196
+
197
+ s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
198
+ s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
199
+ return (
200
+ i + 1,
201
+ s_v,
202
+ s_new,
203
+ s_v,
204
+ jnp.greater(jnp.abs(s_new - s), error_tolerance),
205
+ )
206
+
207
+ # Figure out how to use step as seed for random.
208
+ v_0 = (
209
+ np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
210
+ )
211
+
212
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
213
+ _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
214
+ v_out = v_out / jnp.linalg.norm(v_out)
215
+ return v_out, s_out
216
+
217
+
218
+ def mat_power(
219
+ mat_m,
220
+ p,
221
+ precision=lax.Precision.HIGHEST,
222
+ ):
223
+ """A simple matrix power method. M^p where p can be TracedValue."""
224
+ power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
225
+
226
+ def _iter_condition(state):
227
+ i, _, _ = state
228
+ return i > 0
229
+
230
+ def _iter_body(state):
231
+ i, power, mat = state
232
+
233
+ power = jax.lax.cond(
234
+ i % 2 == 1,
235
+ lambda: jnp.matmul(mat, power, precision=precision),
236
+ lambda: power,
237
+ )
238
+ i //= 2
239
+ mat = jnp.matmul(mat, mat, precision=precision)
240
+ return i, power, mat
241
+
242
+ _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
243
+ return result
244
+
245
+
246
+ def _pth_root_difference(w, alpha, beta, p):
247
+ """Computes (w+alpha)^(-1/p)-(w+beta)^(-1/p)."""
248
+
249
+ a = w + alpha
250
+ b = w + beta
251
+ a_minus_b = alpha - beta
252
+ exp = -1 / p
253
+
254
+ def _stable_subtract(b, a_minus_b):
255
+ # Mathematically identical to the target expression, with (w+beta)^(-1/p)
256
+ # term factored out and w cancellation in the subtraction.
257
+ return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b))
258
+
259
+ return jnp.where(
260
+ # Choose the branch with the best log1p approximation.
261
+ jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
262
+ -_stable_subtract(a, -a_minus_b),
263
+ _stable_subtract(b, a_minus_b),
264
+ )
265
+
266
+
267
+ def matrix_inverse_pth_root(
268
+ matrix,
269
+ p,
270
+ num_iters=100,
271
+ ridge_epsilon=1e-6,
272
+ error_tolerance=1e-6,
273
+ precision=lax.Precision.HIGHEST,
274
+ relative_matrix_epsilon=True,
275
+ lobpcg_topk_precondition=0,
276
+ lobpcg_max_iter=0,
277
+ ):
278
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
279
+
280
+ This function uses the Coupled newton iterations algorithm for
281
+ the computation of a matrix's inverse pth root.
282
+
283
+
284
+ References:
285
+ [Functions of Matrices, Theory and Computation,
286
+ Nicholas J Higham, Pg 184, Eq 7.18](
287
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
288
+
289
+ Args:
290
+ matrix: the symmetric PSD matrix whose power it to be computed
291
+ p: exponent, for p a positive integer.
292
+ num_iters: Maximum number of iterations.
293
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
294
+ error_tolerance: Error indicator, useful for early termination.
295
+ precision: precision XLA related flag, the available options are: a)
296
+ lax.Precision.DEFAULT (better step time, but not precise) b)
297
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
298
+ (best possible precision, slowest)
299
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
300
+ value when computing inverse-pth root.
301
+ lobpcg_topk_precondition: If nonzero, specifies the number of top
302
+ eigenvectors to subtract out before performing LOBPCG. Note this makes
303
+ relative_matrix_epsilon essentially free.
304
+ lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
305
+ `lobpcg_topk_precondition`.
306
+
307
+ Returns:
308
+ matrix^(-1/p) and the error
309
+ """
310
+
311
+ # If the input is not square, materialize it from the concatenated form.
312
+ if matrix.shape[0] != matrix.shape[1]:
313
+ matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)
314
+
315
+ assert matrix.shape[0] == matrix.shape[1]
316
+
317
+ # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
318
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
319
+ # jax_enable_x64 for this to work.
320
+ matrix_size = matrix.shape[0]
321
+ orig_dtype = matrix.dtype
322
+ matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
323
+ alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
324
+ identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
325
+ original_matrix = matrix
326
+
327
+ if lobpcg_topk_precondition > 0:
328
+ # TODO(vladf): reuse previous top-k as the initial search directions
329
+ pad_shape = (matrix_size - lobpcg_topk_precondition, lobpcg_topk_precondition)
330
+ search_dirs = jnp.concatenate(
331
+ (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0
332
+ )
333
+ eigvals, eigvecs, actual_iters = linalg.lobpcg_standard(
334
+ matrix,
335
+ search_dirs,
336
+ lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter,
337
+ )
338
+ del actual_iters # TODO(vladf): return diagnostics dictionary
339
+
340
+ # The minimal eigenvalue among top-k becomes the maximal one in the whole
341
+ # matrix after deflation.
342
+ max_ev = jnp.min(eigvals)
343
+ deflation = eigvals - max_ev
344
+ scaled_vecs = eigvecs * jnp.sqrt(deflation)
345
+
346
+ # Deflate out top eigenvectors to reduce matrix condition number.
347
+ matrix -= scaled_vecs.dot(scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
348
+
349
+ # Only use power iteration if lobpcg wasn't already used to derive the
350
+ # top eigenvalue.
351
+ elif relative_matrix_epsilon:
352
+ _, max_ev = power_iteration(
353
+ matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
354
+ )
355
+ eigvals, eigvecs = None, None # Unused but required by pytype.
356
+
357
+ # Use absolute matrix epsilon scaling otherwise.
358
+ else:
359
+ max_ev = 1.0
360
+ eigvals, eigvecs = None, None # Unused but required by pytype.
361
+
362
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
363
+
364
+ def _iter_condition(state):
365
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
366
+ error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
367
+ return jnp.logical_and(i < num_iters, error_above_threshold)
368
+
369
+ def _iter_body(state):
370
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
371
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
372
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
373
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
374
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
375
+ # sometimes error increases after an iteration before decreasing and
376
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
377
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)
378
+
379
+ if matrix_size == 1:
380
+ resultant_mat_h = (matrix + ridge_epsilon) ** alpha
381
+ error = jnp.array(0, jnp.float32)
382
+ else:
383
+ damped_matrix = matrix + ridge_epsilon * identity
384
+
385
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
386
+ new_mat_m_0 = damped_matrix * z
387
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
388
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
389
+ init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
390
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
391
+ _iter_condition, _iter_body, init_state
392
+ )
393
+ error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
394
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
395
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
396
+ resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
397
+
398
+ if lobpcg_topk_precondition > 0:
399
+ # Since we deflated the top eigenvectors prior to p-th root inverse,
400
+ # the resultant matrix has larger eigenvalues associated with those
401
+ # same eigenvectors, which we need to now re-deflate.
402
+ #
403
+ # Note that _pth_root_difference returns positive values for this
404
+ # particular argument ordering as min(eigvals) <= eigvals for the
405
+ # jnp.sqrt below.
406
+ pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p)
407
+ scaled_vecs = eigvecs * jnp.sqrt(pth_diff)
408
+ resultant_mat_h = (
409
+ resultant_mat_h.astype(scaled_vecs.dtype)
410
+ - scaled_vecs.dot(scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
411
+ ).astype(orig_dtype)
412
+ mat_m = jnp.matmul(
413
+ mat_power(resultant_mat_h, p),
414
+ original_matrix,
415
+ precision=jax.lax.Precision.HIGHEST,
416
+ )
417
+ error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
418
+
419
+ return resultant_mat_h, error
420
+
421
+
422
+ def merge_small_dims(shape_to_merge, max_dim):
423
+ """Merge small dimensions.
424
+
425
+ If there are some small dimensions, we collapse them:
426
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
427
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
428
+
429
+ Args:
430
+ shape_to_merge: Shape to merge small dimensions.
431
+ max_dim: Maximal dimension of output shape used in merging.
432
+
433
+ Returns:
434
+ Merged shape.
435
+ """
436
+ if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
437
+ return [1]
438
+
439
+ resulting_shape = []
440
+ product = 1
441
+ for d in shape_to_merge:
442
+ if product * d <= max_dim:
443
+ product *= d
444
+ else:
445
+ if product > 1:
446
+ resulting_shape.append(product)
447
+ product = d
448
+ if product > 1:
449
+ resulting_shape.append(product)
450
+ return resulting_shape
451
+
452
+
453
+ def pad_square_matrix(mat, max_size):
454
+ """Pad a square matrix up to max_size.
455
+
456
+ Args:
457
+ mat: a matrix to pad.
458
+ max_size: matrix size requested.
459
+
460
+ Returns:
461
+ Given M returns [[M, 0], [0, I]]
462
+ """
463
+ rows, cols = mat.shape
464
+ if rows != cols:
465
+ raise ValueError(
466
+ f"Must have rows == cols, instead got rows={rows}, cols={cols}"
467
+ )
468
+ if cols > max_size:
469
+ raise ValueError(
470
+ f"Must have cols <= max_size. Instead got cols={cols}, max_size={max_size}."
471
+ )
472
+ if rows == max_size:
473
+ return mat
474
+ pad_size = max_size - rows
475
+
476
+ zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
477
+ zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
478
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
479
+ mat = jnp.concatenate([mat, zs1], 1)
480
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
481
+ return mat
482
+
483
+
484
+ def make_sliced_padding(
485
+ symmetric_block_size,
486
+ num_blocks,
487
+ starting_block,
488
+ dtype,
489
+ ):
490
+ """Returns padding for symmetric block matrix.
491
+
492
+ Specifically, the padding is given concatenated rectangular matrices
493
+ representing the lower-triangular rows below the starting block. For example,
494
+ if we want to pad the symmetric matrix
495
+
496
+ M = [[A, B^T]
497
+ [B, C]],
498
+
499
+ the desired output (in terms of the full matrix) with num_blocks = 4 is
500
+
501
+ M_padded = [[A, B^T, 0, 0]
502
+ [B, C, 0, 0]
503
+ [0, 0, I, 0]
504
+ 0, 0, 0, I].
505
+
506
+ We would represent M as the block matrix mat = [A, B, C]. In this form, the
507
+ additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
508
+ triangular parts in the third and fourth rows).
509
+
510
+ Args:
511
+ symmetric_block_size: The size of each block.
512
+ num_blocks: The total number of blocks.
513
+ starting_block: The block where to start the padding.
514
+ dtype: The type to use for the blocks.
515
+ """
516
+ if starting_block == num_blocks:
517
+ return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)
518
+
519
+ blocks = []
520
+ for i in range(starting_block, num_blocks):
521
+ blocks.append(
522
+ jnp.zeros(
523
+ shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
524
+ )
525
+ )
526
+ blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
527
+ return jnp.concatenate(blocks, axis=-1)
528
+
529
+
530
+ def pad_block_symmetric_matrix(
531
+ mat,
532
+ symmetric_block_size,
533
+ max_num_blocks,
534
+ ):
535
+ """Returns the padded blocked symmetric matrix.
536
+
537
+ The size of the padded matrix will be:
538
+ [symmetric_block_size, symmetric_block_size * max_num_blocks]
539
+
540
+ The input matrix can either:
541
+ - Be square with size less or equal to symmetric_block_size. In this case,
542
+ mat will first be padded to a square matrix of size symmetric_block_size,
543
+ and then be padded again up to the full size of the blocked matrix.
544
+ - Be a rectangle with number of rows equal to block size.
545
+ In this case, number of columns must be a multiple of number of rows, and
546
+ the ratio must correspond to a block representation of a symmetric matrix.
547
+ That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
548
+ number of block rows represented by the matrix.
549
+
550
+ Args:
551
+ mat: The input block matrix.
552
+ symmetric_block_size: The size of blocks.
553
+ max_num_blocks: The largest number of blocks to pad to.
554
+ """
555
+ rows, cols = mat.shape
556
+ if rows > symmetric_block_size:
557
+ raise ValueError(
558
+ "Must have rows <= symmetric_block_size. Instead got "
559
+ f"rows={rows}, symmetric_block_size={symmetric_block_size}."
560
+ )
561
+ if rows > cols:
562
+ raise ValueError(
563
+ f"Must have rows <= cols, instead got rows={rows}, cols={cols}."
564
+ )
565
+ if cols > symmetric_block_size * max_num_blocks:
566
+ raise ValueError(
567
+ "Must have cols <= symmetric_block_size * max_num_blocks "
568
+ f"Instead got cols={cols}, "
569
+ f"symmetric_block_size={symmetric_block_size}, "
570
+ f"max_num_blocks={max_num_blocks}."
571
+ )
572
+ if rows < symmetric_block_size:
573
+ mat = pad_square_matrix(mat, max_size=symmetric_block_size)
574
+ # Update rows and cols after possibly padding in pad_square_matrix.
575
+ rows, cols = mat.shape
576
+ assert rows == symmetric_block_size
577
+ assert cols % rows == 0
578
+ filled_blocks = cols // rows
579
+ padding_blocks = make_sliced_padding(
580
+ symmetric_block_size=symmetric_block_size,
581
+ num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
582
+ starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
583
+ dtype=mat.dtype,
584
+ )
585
+ return jnp.concatenate([mat, padding_blocks], axis=-1)
586
+
587
+
588
+ def pad_vector(vec, max_size):
589
+ """Pad a vector to a max_size.
590
+
591
+ Args:
592
+ vec: a vector to pad.
593
+ max_size: matrix size requested.
594
+
595
+ Returns:
596
+ Given V returns [V, 0]
597
+ """
598
+ size = vec.shape[0]
599
+ assert size <= max_size
600
+ if size == max_size:
601
+ return vec
602
+ pad_size = max_size - size
603
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
604
+ return jnp.concatenate([vec, zs1], 0)
605
+
606
+
607
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
608
+ """Avoids wasteful buffer allocation with XLA."""
609
+
610
+ def _iter_body(unused_state):
611
+ results = compute_fn(*args, **kwargs)
612
+ return tuple([False] + list(results))
613
+
614
+ def _iter_condition(state):
615
+ return state[0]
616
+
617
+ results = jax.lax.while_loop(
618
+ _iter_condition, _iter_body, tuple([predicate] + init_state)
619
+ )
620
+ return tuple(results[1:])
621
+
622
+
623
+ class BlockPartitioner:
624
+ """Partitions a tensor into smaller tensors."""
625
+
626
+ def __init__(self, param, block_size):
627
+ self._shape = param.shape
628
+ self._splits = []
629
+ split_sizes = []
630
+ # We split params into smaller blocks. Here we store the metadata to make
631
+ # that split.
632
+ for i, d in enumerate(param.shape):
633
+ if 0 < block_size < d:
634
+ # d-1, otherwise split appends a 0-size array.
635
+ nsplit = (d - 1) // block_size
636
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
637
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
638
+ sizes[-1] = d - indices[-1]
639
+ self._splits.append((i, indices))
640
+ split_sizes.append(sizes)
641
+ else:
642
+ split_sizes.append(np.array([d], dtype=np.int32))
643
+ self._split_sizes = split_sizes
644
+
645
+ def split_sizes(self):
646
+ return self._split_sizes
647
+
648
+ def partition(self, tensor):
649
+ """Partition tensor into blocks."""
650
+
651
+ assert tensor.shape == self._shape
652
+ tensors = [tensor]
653
+ for i, indices in self._splits:
654
+ tensors_local = []
655
+ for t in tensors:
656
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
657
+ tensors = tensors_local
658
+ return tensors
659
+
660
+ def merge_partitions(self, partitions):
661
+ """Merge partitions back to original shape."""
662
+
663
+ for i, indices in reversed(self._splits):
664
+ n = len(indices) + 1
665
+ partial_merged_tensors = []
666
+ ind = 0
667
+ while ind < len(partitions):
668
+ partial_merged_tensors.append(
669
+ jnp.concatenate(partitions[ind : ind + n], axis=i)
670
+ )
671
+ ind += n
672
+ partitions = partial_merged_tensors
673
+ assert len(partitions) == 1
674
+ return partitions[0]
675
+
676
+
677
+ def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None):
678
+ """Updated statistics via weighted average with new Gram matrix.
679
+
680
+ Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose
681
+ columns are the flattened slices of the tensor `g` along the given `axis`.
682
+ (So, `old_stats` and the returned matrix have dimensions n x n where
683
+ n = `g.shape[axis]`).
684
+
685
+ Args:
686
+ old_stats: Old statistics.
687
+ g: Gradient tensor.
688
+ axis: Axis along which to slice `g`.
689
+ w1: Scalar weight for old statistics.
690
+ w2: Scalar weight for new Gram matrix.
691
+ precision: Optional precision XLA related flag, the available options are:
692
+ a) lax.Precision.DEFAULT (better step time, but not precise)
693
+ b) lax.Precision.HIGH (increased precision, slower)
694
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
695
+
696
+ Returns:
697
+ Weighted average of old and new statistics.
698
+ """
699
+ axes = [i for i in range(g.ndim) if i != axis]
700
+ gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision)
701
+ return w1 * old_stats + w2 * gram_matrix
702
+
703
+
704
+ class Preconditioner:
705
+ """Compute statistics/shape from gradients for preconditioning."""
706
+
707
+ def __init__(
708
+ self,
709
+ param,
710
+ block_size,
711
+ merge_small_dims_block_size,
712
+ best_effort_shape_interpretation,
713
+ preconditioner_type=PreconditionerType.ALL,
714
+ ):
715
+ """Initializes the preconditioner.
716
+
717
+ Args:
718
+ param: parameter to precondition.
719
+ block_size: Block size used to split param.
720
+ merge_small_dims_block_size: Block size for merging dims.
721
+ best_effort_shape_interpretation: Whether to collapse/merge dims together.
722
+ preconditioner_type: Type of preconditioner to use.
723
+ """
724
+ self._original_shape = param.shape
725
+ self._transformed_shape = param.shape
726
+ if best_effort_shape_interpretation:
727
+ self._transformed_shape = merge_small_dims(
728
+ self._original_shape, merge_small_dims_block_size
729
+ )
730
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
731
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
732
+ self._preconditioner_type = preconditioner_type
733
+
734
+ def updated_statistics_from_grad(
735
+ self,
736
+ stats,
737
+ grad,
738
+ w1,
739
+ w2,
740
+ to_float=None,
741
+ from_float=None,
742
+ precision=None,
743
+ ):
744
+ """Update statistics from gradients.
745
+
746
+ Args:
747
+ stats: Old statistics or its Cholesky factor if `cholesky` is True.
748
+ grad: Gradient to compute statistics from.
749
+ w1: Weight for old statistics.
750
+ w2: Weight for new statistics.
751
+ to_float: Optional function for converting stats to floating point.
752
+ from_float: Optional function for converting from floating point.
753
+ precision: Optional precision XLA related flag, the available options are:
754
+ a) lax.Precision.DEFAULT (better step time, but not precise)
755
+ b) lax.Precision.HIGH (increased precision, slower)
756
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
757
+
758
+ Returns:
759
+ A list of updated gradient statistics for each partition.
760
+ """
761
+ to_float = to_float if to_float is not None else (lambda x: x)
762
+ from_float = from_float if from_float is not None else (lambda x: x)
763
+ update = functools.partial(gram_weighted_update, precision=precision)
764
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
765
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
766
+ new_stats = []
767
+ index = 0
768
+ for g in partitioned_grads:
769
+ should_preconditioned_dims = self.should_precondition_dims()
770
+ num_preconditioners = sum(should_preconditioned_dims)
771
+ for axis in range(num_preconditioners):
772
+ new_stat = update(to_float(stats[index]), g, axis, w1, w2)
773
+ new_stats.append(from_float(new_stat))
774
+ index += 1
775
+ return new_stats
776
+
777
+ def should_precondition_dims(self):
778
+ """A vector containing indicator indicating if the dim is preconditioned."""
779
+ split_sizes = self._partitioner.split_sizes()
780
+ rank = len(split_sizes)
781
+ if self._preconditioner_type == PreconditionerType.ALL or rank <= 1:
782
+ return [True] * rank
783
+ else:
784
+ return [True] * (rank - 1) + [False]
785
+
786
+ def shapes_for_preconditioners(self):
787
+ """Returns shape from statistics."""
788
+ split_sizes = self._partitioner.split_sizes()
789
+ rank = len(split_sizes)
790
+ # We ignore preconditioner types if rank == 1
791
+ preconditioner_shapes = []
792
+ for t in itertools.product(*split_sizes):
793
+ if self._preconditioner_type == PreconditionerType.ALL or rank <= 1:
794
+ preconditioner_shapes.extend([[d, d] for d in t])
795
+ else:
796
+ preconditioner_shapes.extend([[d, d] for d in t[:-1]])
797
+ return preconditioner_shapes
798
+
799
+ def exponent_for_preconditioner(self):
800
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
801
+ should_preconditioned_dims = self.should_precondition_dims()
802
+ num_preconditioners = sum(should_preconditioned_dims)
803
+ return 2 * num_preconditioners
804
+
805
+ def preconditioned_grad(self, grad, preconditioners):
806
+ """Precondition the gradient.
807
+
808
+ Args:
809
+ grad: A gradient tensor to precondition.
810
+ preconditioners: A list of preconditioners to apply.
811
+
812
+ Returns:
813
+ A preconditioned gradient.
814
+ """
815
+
816
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
817
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
818
+ preconditioned_partitioned_grads = []
819
+ for i, g in enumerate(partitioned_grads):
820
+ should_preconditioned_dims = self.should_precondition_dims()
821
+ num_preconditioners = sum(should_preconditioned_dims)
822
+ preconditioners_for_grad = preconditioners[
823
+ i * num_preconditioners : (i + 1) * num_preconditioners
824
+ ]
825
+ precond_g = g
826
+ rank = len(g.shape)
827
+ for j, precondition in enumerate(should_preconditioned_dims):
828
+ if precondition:
829
+ precond_g = jnp.tensordot(
830
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]]
831
+ )
832
+ else:
833
+ precond_g = jnp.transpose(precond_g, axes=(*range(1, rank), 0))
834
+ preconditioned_partitioned_grads.append(precond_g)
835
+ merged_grad = self._partitioner.merge_partitions(
836
+ preconditioned_partitioned_grads
837
+ )
838
+ return jnp.reshape(merged_grad, self._original_shape)
839
+
840
+
841
+ def _convert_to_parameter_stats(global_stats, local_stat, convert_statistics=True):
842
+ """Creates parameter stats from sharded stats."""
843
+ index_start = int(local_stat.index_start)
844
+ index_end = int(len(local_stat.sizes)) + index_start
845
+ statistics = global_stats.statistics[index_start:index_end, :, :]
846
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
847
+ new_statistics = []
848
+ new_preconditioners = []
849
+ for i, size in enumerate(local_stat.sizes):
850
+ new_statistics.append(statistics[i][:size, :size])
851
+ new_preconditioners.append(preconditioners[i][:size, :size])
852
+ if not convert_statistics:
853
+ new_statistics = None
854
+ return ParameterStats(
855
+ local_stat.diagonal_statistics,
856
+ new_statistics,
857
+ new_preconditioners,
858
+ local_stat.diagonal_momentum,
859
+ local_stat.momentum,
860
+ local_stat.training_metrics,
861
+ )
862
+
863
+
864
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
865
+ """Creates sharded stats from paramter stats."""
866
+ return LocalShardedParameterStats(
867
+ parameter_stats.diagonal_statistics,
868
+ parameter_stats.diagonal_momentum,
869
+ parameter_stats.momentum,
870
+ parameter_stats.training_metrics,
871
+ local_stats.index_start,
872
+ local_stats.sizes,
873
+ )
874
+
875
+
876
+ def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
877
+ """Adds errors back into local statistics."""
878
+ new_local_stats = []
879
+ for local_stat in local_stats:
880
+ if local_stat.sizes:
881
+ index_start = int(local_stat.index_start)
882
+ index_end = int(len(local_stat.sizes)) + index_start
883
+ per_stat_error = errors[index_start:index_end]
884
+ else:
885
+ per_stat_error = jnp.array(0, jnp.float32)
886
+ if local_stat.sizes:
887
+ per_stat_error = jnp.where(
888
+ jnp.logical_and(
889
+ per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
890
+ ),
891
+ per_stat_error,
892
+ local_stat.training_metrics.inverse_pth_root_errors,
893
+ )
894
+ new_local_stats.append(
895
+ LocalShardedParameterStats(
896
+ local_stat.diagonal_statistics,
897
+ local_stat.diagonal_momentum,
898
+ local_stat.momentum,
899
+ TrainingMetrics(per_stat_error),
900
+ local_stat.index_start,
901
+ local_stat.sizes,
902
+ )
903
+ )
904
+ return new_local_stats
905
+
906
+
907
+ def batch(x, num_devices):
908
+ """Batch `x` so that so that leading axis is num_devices."""
909
+ n = len(x)
910
+ b = int(n / num_devices)
911
+ return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
912
+
913
+
914
+ def unbatch(batched_values):
915
+ """Unbatch values across leading axis and return a list of elements."""
916
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
917
+ results = []
918
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
919
+ v_array = jnp.squeeze(v_array)
920
+ # b2 = batches (number of preconditioner computation) per core.
921
+ if b2 > 1:
922
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
923
+ results.append(jnp.squeeze(v))
924
+ else:
925
+ results.append(v_array)
926
+ return results
927
+
928
+
929
+ def distributed_shampoo(
930
+ learning_rate,
931
+ block_size,
932
+ beta1=0.9,
933
+ beta2=0.999,
934
+ diagonal_epsilon=1e-10,
935
+ matrix_epsilon=1e-6,
936
+ weight_decay=0.0,
937
+ start_preconditioning_step=5,
938
+ preconditioning_compute_steps=1,
939
+ statistics_compute_steps=1,
940
+ best_effort_shape_interpretation=True,
941
+ graft_type=GraftingType.SGD,
942
+ nesterov=True,
943
+ exponent_override=0,
944
+ # Pass pmap 'batch axis name' in pmap mode.
945
+ batch_axis_name=None,
946
+ ### Only set following 3 params in pjit/spmd mode.
947
+ ### WARNING: Experimental
948
+ statistics_partition_spec=None,
949
+ preconditioner_partition_spec=None,
950
+ num_devices_for_pjit=None,
951
+ shard_optimizer_states=False,
952
+ ###
953
+ ### Experimental memory reduction mode
954
+ best_effort_memory_usage_reduction=False,
955
+ ###
956
+ inverse_failure_threshold=0.1,
957
+ moving_average_for_momentum=False,
958
+ skip_preconditioning_dim_size_gt=4096,
959
+ clip_by_scaled_gradient_norm=None,
960
+ precision=lax.Precision.HIGHEST,
961
+ tensordot_precision=None,
962
+ relative_matrix_epsilon=True,
963
+ merge_small_dims_block_size=4096,
964
+ lobpcg_topk_precondition=0,
965
+ lobpcg_max_iter=0,
966
+ precondtioner_type=PreconditionerType.ALL,
967
+ skip_preconditioning_rank_lt=1,
968
+ decoupled_learning_rate=True,
969
+ decoupled_weight_decay=False,
970
+ ):
971
+ """Distributed Shampoo optimizer.
972
+
973
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
974
+ variant of full-matrix Adagrad), that provides significant convergence and
975
+ wall-clock time improvements compared to conventional first-order methods,
976
+ and that has been shown to scale to large state-of-the-art deep learning
977
+ models.
978
+
979
+ References:
980
+ Scalable Second Order Optimization for Deep Learning,
981
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
982
+
983
+ Preprint: https://arxiv.org/abs/2002.09018
984
+
985
+ Args:
986
+ learning_rate: the step size used to update the parameters.
987
+ block_size: Block size for large layers (if > 0). Preconditioning compute
988
+ operation is cubic in the dimension of the tensor. Block size allows us to
989
+ chunk the layers into sub-layers of maximal dimension dictated by this
990
+ value. Use 128 as default (increase if you have compute budget).
991
+ beta1: momentum parameter.
992
+ beta2: second moment averaging parameter.
993
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
994
+ to AdaGrad is enabled).
995
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
996
+ root. If you are running in f32 precision for inverse pth root
997
+ (recommended today) this can go upto 1e-6. If you have latest hardware
998
+ with native f64 precision, set this upto 1e-12.
999
+ weight_decay: Weight decay for regularization.
1000
+ start_preconditioning_step: When to start Shampoo update before which
1001
+ diagonal update is used. This is because we dont have enough information
1002
+ to do stable inverse.
1003
+ preconditioning_compute_steps: How often to compute preconditioner.
1004
+ Performance tuning params for controlling memory and compute requirements.
1005
+ Ideally set this and statistics_compute_steps params to 1.
1006
+ statistics_compute_steps: How often to compute statistics.
1007
+ best_effort_shape_interpretation: If there are some small dimensions,
1008
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
1009
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
1010
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
1011
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
1012
+ where SGD/AdaGrad is already well tuned.
1013
+ nesterov: Nesterov momentum.
1014
+ exponent_override: Override the exponent used in matrix inverse.
1015
+ batch_axis_name: labeled axis over pmap for data-parallel training the
1016
+ optimizer used for.
1017
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
1018
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
1019
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
1020
+ shard_optimizer_states: Shard optimizer states to save memory in model
1021
+ parallel training.
1022
+ best_effort_memory_usage_reduction: Best effort memory usage reduction. -
1023
+ diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
1024
+ statistics, preconditioners -> jnp.int16 + diagonals
1025
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
1026
+ determine that using this threshold.
1027
+ moving_average_for_momentum: Whether to use moving average for momentum
1028
+ instead of exponential moving average.
1029
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
1030
+ greater than this value.
1031
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
1032
+ using RMSProp Grafting).
1033
+ precision: precision XLA related flag, the available options are: a)
1034
+ lax.Precision.DEFAULT (better step time, but not precise) b)
1035
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
1036
+ (best possible precision, slowest)
1037
+ tensordot_precision: Optional precision to use for the tensordot operation
1038
+ when computing statistics (e.g., G Gᵀ). Same options as `precision` above.
1039
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
1040
+ value when computing inverse-pth root.
1041
+ merge_small_dims_block_size: Used as the maximum block size
1042
+ to merge the shapes.
1043
+ lobpcg_topk_precondition: If nonzero, specifies the number of top
1044
+ eigenvectors to subtract out before performing LOBPCG. Note this makes
1045
+ relative_matrix_epsilon essentially free.
1046
+ lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to
1047
+ `lobpcg_topk_precondition`.
1048
+ precondtioner_type: Preconditioner type to select all, left only or right
1049
+ only preconditioners.
1050
+ skip_preconditioning_rank_lt: Skips preconditioning for parameters with
1051
+ rank less than this value.
1052
+ decoupled_learning_rate: If True, use decoupled learning rate, otherwise
1053
+ couple it with preconditioned gradient computation. (Default True)
1054
+ decoupled_weight_decay: If True, use decoupled weight decay, otherwise
1055
+ couple with weight decay. (Default False)
1056
+ Returns:
1057
+ a GradientTransformation.
1058
+ """
1059
+
1060
+ def _graft_type_has_diagonal_statistics():
1061
+ """Returns True if using diagonal firt order method for grafting."""
1062
+ return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N
1063
+
1064
+ def quantized_dtype_for_momentum_buffers(var):
1065
+ return (
1066
+ jnp.int8
1067
+ if best_effort_memory_usage_reduction and len(var.shape) > 1
1068
+ else jnp.float32
1069
+ )
1070
+
1071
+ # Preconditioner and statistics are both stores as int16 in this mode.
1072
+ # We take out the diagonal to make quantization easier.
1073
+ def quantized_dtype_for_second_moment_statistics_buffers():
1074
+ return (
1075
+ jnp.int16
1076
+ if best_effort_memory_usage_reduction and batch_axis_name
1077
+ else jnp.float32
1078
+ )
1079
+
1080
+ # Preconditioner and statistics are both stores as int16 in this mode.
1081
+ # We take out the diagonal to make quantization easier.
1082
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
1083
+ return (
1084
+ jnp.int16
1085
+ if best_effort_memory_usage_reduction and batch_axis_name
1086
+ else jnp.float32
1087
+ )
1088
+
1089
+ def _to_float(maybe_quantized):
1090
+ if isinstance(maybe_quantized, QuantizedValue):
1091
+ return maybe_quantized.to_float()
1092
+ else:
1093
+ return maybe_quantized
1094
+
1095
+ def _maybe_quantize_statistics(statistics_list):
1096
+ return _maybe_quantize_matrices_with_dtype(
1097
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers()
1098
+ )
1099
+
1100
+ def _maybe_quantize_preconditioners(statistics_list):
1101
+ return _maybe_quantize_matrices_with_dtype(
1102
+ statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers()
1103
+ )
1104
+
1105
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
1106
+ if quantized_dtype != jnp.float32:
1107
+ return [
1108
+ QuantizedValue.from_float_value(
1109
+ s, quantized_dtype, extract_diagonal=True
1110
+ )
1111
+ for s in statistics_list
1112
+ ]
1113
+ else:
1114
+ return statistics_list
1115
+
1116
+ def _maybe_dequantize_preconditioners(preconditioner_list):
1117
+ return _maybe_dequantize_matrices_with_dtype(
1118
+ preconditioner_list,
1119
+ quantized_dtype_for_second_moment_preconditioner_buffers(),
1120
+ )
1121
+
1122
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
1123
+ if quantized_dtype != jnp.float32:
1124
+ return [s.to_float() for s in statistics_list]
1125
+ else:
1126
+ return statistics_list
1127
+
1128
+ def _quantize_diagonal_statistics(diagonal_statistics):
1129
+ return QuantizedValue.from_float_value(diagonal_statistics, jnp.float32)
1130
+
1131
+ def _quantize_momentum(momentum_statistics):
1132
+ return QuantizedValue.from_float_value(
1133
+ momentum_statistics,
1134
+ quantized_dtype_for_momentum_buffers(momentum_statistics),
1135
+ )
1136
+
1137
+ def preconditioner_from_params(param):
1138
+ """Returns a Preconditioner object for given param."""
1139
+ return Preconditioner(
1140
+ param,
1141
+ block_size,
1142
+ merge_small_dims_block_size,
1143
+ best_effort_shape_interpretation,
1144
+ precondtioner_type,
1145
+ )
1146
+
1147
+ def sharded_init_fn(params):
1148
+ """Returns optimizer state (for PJIT mode).
1149
+
1150
+ Args:
1151
+ params: the parameters that should be updated.
1152
+ """
1153
+ params_flat, treedef = jax.tree_flatten(params)
1154
+ # Find max size to pad to.
1155
+ max_size = 0
1156
+ for param in params_flat:
1157
+ preconditioner = preconditioner_from_params(param)
1158
+ if not _skip_preconditioning(param):
1159
+ shapes = preconditioner.shapes_for_preconditioners()
1160
+ sizes = [s[0] for s in shapes]
1161
+ max_size = max(max(sizes), max_size)
1162
+
1163
+ padded_statistics = []
1164
+ padded_preconditioners = []
1165
+ local_stats_flat = []
1166
+ exponents = []
1167
+ for param in params_flat:
1168
+ preconditioner = preconditioner_from_params(param)
1169
+ shapes = preconditioner.shapes_for_preconditioners()
1170
+ sizes = []
1171
+
1172
+ statistics = []
1173
+ preconditioners = []
1174
+ index_start = len(padded_statistics)
1175
+ if not _skip_preconditioning(param):
1176
+ sizes = [s[0] for s in shapes]
1177
+ shapes = preconditioner.shapes_for_preconditioners()
1178
+ statistics = [
1179
+ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
1180
+ for s in shapes
1181
+ ]
1182
+ preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
1183
+ padded_statistics.extend(statistics)
1184
+ padded_preconditioners.extend(preconditioners)
1185
+ exponent = (
1186
+ preconditioner.exponent_for_preconditioner()
1187
+ if exponent_override == 0
1188
+ else exponent_override
1189
+ )
1190
+ exponents.extend([exponent] * len(shapes))
1191
+
1192
+ diagonal_statistics = _quantize_diagonal_statistics(jnp.zeros_like(param))
1193
+ diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))
1194
+ momentum = _quantize_momentum(jnp.zeros_like(param))
1195
+
1196
+ local_stats_flat.append(
1197
+ LocalShardedParameterStats(
1198
+ diagonal_statistics,
1199
+ diagonal_momentum,
1200
+ momentum,
1201
+ init_training_metrics(len(sizes)),
1202
+ index_start,
1203
+ sizes,
1204
+ )
1205
+ )
1206
+
1207
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1208
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
1209
+ if max_size == 0:
1210
+ to_pad = num_devices_for_pjit
1211
+ max_size = block_size
1212
+ stat_dtype = jnp.float32
1213
+ else:
1214
+ stat_dtype = padded_statistics[0].dtype
1215
+ # Pad the statistics and preconditioner matrices to be a multiple of
1216
+ # num devices.
1217
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1218
+ # is split on.
1219
+ padded_statistics.extend(
1220
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1221
+ )
1222
+ padded_preconditioners.extend(
1223
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1224
+ )
1225
+ exponents.extend([1 for _ in range(to_pad)])
1226
+ global_stats = GlobalShardedParameterStats(
1227
+ jnp.stack(padded_statistics),
1228
+ jnp.stack(padded_preconditioners),
1229
+ jnp.stack(exponents),
1230
+ )
1231
+ return ShampooState(
1232
+ count=jnp.zeros([], jnp.int32),
1233
+ stats=ShardedShampooStats(global_stats, local_stats),
1234
+ )
1235
+
1236
+ def _max_statistics_size_from_params(params):
1237
+ max_size = 0
1238
+ for param in params:
1239
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1240
+ preconditioner = preconditioner_from_params(param_clone)
1241
+ if not _skip_preconditioning(param):
1242
+ shapes = preconditioner.shapes_for_preconditioners()
1243
+ sizes = [s[0] for s in shapes]
1244
+ max_size = max(max(sizes), max_size)
1245
+ return max_size
1246
+
1247
+ def _remove_leading_sharding_annotation(pspec):
1248
+ """Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
1249
+ # None and PSpec(None) are valid PSpecs.
1250
+ if pspec and len(pspec) > 1:
1251
+ return pjit.PartitionSpec(*pspec[1:])
1252
+ else:
1253
+ return []
1254
+
1255
+ def sharded_init_partition_spec_fn(
1256
+ params, params_partition_spec, partition_spec_for_statistics
1257
+ ):
1258
+ """Returns a parallel state tree with PartitionSpec associated with state.
1259
+
1260
+
1261
+ Args:
1262
+ params: A pytree with params.
1263
+ params_partition_spec: A pytree with PartitionSpec for params.
1264
+ partition_spec_for_statistics: PartitionSpec for the statistics.
1265
+ """
1266
+ # Parallel lists of spec, and params.
1267
+ param_pspec_flat, _ = jax.tree_flatten(
1268
+ params_partition_spec, is_leaf=lambda x: x is None
1269
+ )
1270
+ params_flat, treedef = jax.tree_flatten(params)
1271
+ assert param_pspec_flat
1272
+ assert params_flat
1273
+ # Step is replicated across cores.
1274
+ # None means cores.
1275
+ local_stats_flat = []
1276
+ num_statistics = 0
1277
+ for param, param_pspec in zip(params_flat, param_pspec_flat):
1278
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1279
+ preconditioner = preconditioner_from_params(param_clone)
1280
+ shapes = preconditioner.shapes_for_preconditioners()
1281
+ sizes = []
1282
+
1283
+ index_start = num_statistics
1284
+ if not _skip_preconditioning(param):
1285
+ sizes = [s[0] for s in shapes]
1286
+ shapes = preconditioner.shapes_for_preconditioners()
1287
+ num_statistics += len(shapes)
1288
+
1289
+ qdtype = quantized_dtype_for_momentum_buffers(param)
1290
+ m1_pspec = param_pspec
1291
+ m2_pspec = param_pspec
1292
+ m1_scale_pspec = []
1293
+ m2_scale_pspec = []
1294
+ if qdtype != jnp.float32:
1295
+ m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
1296
+ m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
1297
+
1298
+ local_stats_flat.append(
1299
+ LocalShardedParameterStats(
1300
+ QuantizedValue(
1301
+ param_pspec, [], [], jnp.float32, False, list(param.shape)
1302
+ ),
1303
+ QuantizedValue(
1304
+ m1_pspec, [], m1_scale_pspec, qdtype, False, list(param.shape)
1305
+ ),
1306
+ QuantizedValue(
1307
+ m2_pspec, [], m2_scale_pspec, qdtype, False, list(param.shape)
1308
+ ),
1309
+ init_training_metrics_pspec(),
1310
+ index_start,
1311
+ sizes,
1312
+ )
1313
+ )
1314
+
1315
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1316
+ global_stats = GlobalShardedParameterStats(
1317
+ partition_spec_for_statistics,
1318
+ partition_spec_for_statistics,
1319
+ pjit.PartitionSpec(),
1320
+ )
1321
+ count_pspec = pjit.PartitionSpec()
1322
+ return ShampooState(
1323
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
1324
+ )
1325
+
1326
+ def sharded_init_shape_and_dtype_fn(params):
1327
+ """Returns a parallel state tree with shape, dtype associated with state.
1328
+
1329
+
1330
+ Args:
1331
+ params: A pytree with params.
1332
+ """
1333
+ # Parallel lists of spec, and params.
1334
+ params_flat, treedef = jax.tree_flatten(params)
1335
+ assert params_flat
1336
+ # Step is replicated across cores.
1337
+ # None means cores.
1338
+ local_stats_flat = []
1339
+ num_statistics = 0
1340
+ for param in params_flat:
1341
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1342
+ preconditioner = preconditioner_from_params(param_clone)
1343
+ shapes = preconditioner.shapes_for_preconditioners()
1344
+ sizes = []
1345
+
1346
+ index_start = num_statistics
1347
+ if not _skip_preconditioning(param):
1348
+ sizes = [s[0] for s in shapes]
1349
+ shapes = preconditioner.shapes_for_preconditioners()
1350
+ num_statistics += len(shapes)
1351
+
1352
+ qdtype = quantized_dtype_for_momentum_buffers(param)
1353
+ m1_shape_and_dtype = [list(param.shape), param.dtype]
1354
+ m2_shape_and_dtype = [list(param.shape), param.dtype]
1355
+ m1_scale_shape_and_dtype = []
1356
+ m2_scale_shape_and_dtype = []
1357
+ if qdtype != jnp.float32:
1358
+ m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1359
+ m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1360
+
1361
+ diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
1362
+ local_stats_flat.append(
1363
+ LocalShardedParameterStats(
1364
+ QuantizedValue(
1365
+ diagonal_statistics_shape_and_dtype,
1366
+ [],
1367
+ [],
1368
+ jnp.float32,
1369
+ False,
1370
+ list(param.shape),
1371
+ ),
1372
+ QuantizedValue(
1373
+ m1_shape_and_dtype,
1374
+ [],
1375
+ m1_scale_shape_and_dtype,
1376
+ qdtype,
1377
+ False,
1378
+ list(param.shape),
1379
+ ),
1380
+ QuantizedValue(
1381
+ m2_shape_and_dtype,
1382
+ [],
1383
+ m2_scale_shape_and_dtype,
1384
+ qdtype,
1385
+ False,
1386
+ list(param.shape),
1387
+ ),
1388
+ init_training_metrics_shapes(len(sizes)),
1389
+ index_start,
1390
+ sizes,
1391
+ )
1392
+ )
1393
+
1394
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1395
+ max_statistics_size = _max_statistics_size_from_params(params_flat)
1396
+ to_pad = -num_statistics % num_devices_for_pjit
1397
+ num_statistics += to_pad
1398
+ if num_statistics == 0:
1399
+ num_statistics = num_devices_for_pjit
1400
+ max_statistics_size = block_size
1401
+ statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
1402
+ global_stats = GlobalShardedParameterStats(
1403
+ [statistics_shape, jnp.float32],
1404
+ [statistics_shape, jnp.float32],
1405
+ [[num_statistics], jnp.int32],
1406
+ )
1407
+ return ShampooState(
1408
+ count=[[], jnp.float32],
1409
+ stats=ShardedShampooStats(global_stats, local_stats),
1410
+ )
1411
+
1412
+ def sharded_update_fn(grads, state, params):
1413
+ """Transform the input gradient and update all statistics in sharded mode.
1414
+
1415
+ Args:
1416
+ grads: the gradient tensors for the parameters.
1417
+ state: a named tuple containing the state of the optimizer
1418
+ params: the parameters that should be updated.
1419
+
1420
+ Returns:
1421
+ A tuple containing the new parameters and the new optimizer state.
1422
+ """
1423
+ params_flat, treedef = jax.tree_flatten(params)
1424
+ grads_flat = treedef.flatten_up_to(grads)
1425
+
1426
+ global_stats = state.stats.global_stats
1427
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
1428
+ stats_flat = [
1429
+ _convert_to_parameter_stats(global_stats, local_stat)
1430
+ for local_stat in local_stats_flat
1431
+ ]
1432
+ new_stats_flat = jax.tree_map(
1433
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
1434
+ grads_flat,
1435
+ stats_flat,
1436
+ params_flat,
1437
+ )
1438
+
1439
+ outputs = jax.tree_map(
1440
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
1441
+ grads_flat,
1442
+ new_stats_flat,
1443
+ params_flat,
1444
+ )
1445
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1446
+
1447
+ updates = jax.tree_unflatten(treedef, updates_flat)
1448
+ # Create new local_stats
1449
+ new_local_stats_flat = [
1450
+ _convert_from_parameter_stats(new_stat, local_stat)
1451
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
1452
+ ]
1453
+
1454
+ max_size = global_stats.statistics.shape[1]
1455
+ new_padded_statistics = []
1456
+ for stat in new_stats_flat:
1457
+ new_padded_statistics.extend(
1458
+ [pad_square_matrix(stat, max_size) for stat in stat.statistics]
1459
+ )
1460
+
1461
+ # Create global stats
1462
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
1463
+ # stack/pad can be obviated away.
1464
+ # Pad the statistics and preconditioner matrices to be a multiple of
1465
+ # num devices.
1466
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1467
+ # is split on.
1468
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
1469
+ if not new_padded_statistics:
1470
+ to_pad = num_devices_for_pjit
1471
+ stat_dtype = jnp.float32
1472
+ else:
1473
+ stat_dtype = new_padded_statistics[0].dtype
1474
+
1475
+ new_padded_statistics.extend(
1476
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1477
+ )
1478
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
1479
+ new_stacked_padded_statistics = pjit.with_sharding_constraint(
1480
+ new_stacked_padded_statistics, statistics_partition_spec
1481
+ )
1482
+
1483
+ def _internal_inverse_pth_root_all():
1484
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1485
+ new_stacked_padded_statistics,
1486
+ global_stats.exponents,
1487
+ statistics_partition_spec,
1488
+ )
1489
+ return preconditioners, errors
1490
+
1491
+ if preconditioning_compute_steps == 1:
1492
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
1493
+ else:
1494
+ # Passing statistics instead of preconditioners as they are similarly
1495
+ # shaped tensors. Note statistics will be ignored as we are passing in
1496
+ # a large init value for error.
1497
+ preconditioners_init = new_stacked_padded_statistics
1498
+ n = new_stacked_padded_statistics.shape[0]
1499
+ errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
1500
+ init_state = [preconditioners_init, errors_init]
1501
+ perform_step = state.count % preconditioning_compute_steps == 0
1502
+ new_preconditioners, errors = efficient_cond(
1503
+ perform_step, _internal_inverse_pth_root_all, init_state
1504
+ )
1505
+
1506
+ new_local_stats_flat = _add_error_into_local_stats(
1507
+ new_local_stats_flat, errors, inverse_failure_threshold
1508
+ )
1509
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
1510
+ errors = errors.reshape((-1, 1, 1))
1511
+ predicate = jnp.logical_or(
1512
+ jnp.isnan(errors), errors >= inverse_failure_threshold
1513
+ ).astype(new_preconditioners.dtype)
1514
+ # TODO(rohananil): Check for numerical instabilities.
1515
+ new_conditional_preconditioners = (
1516
+ predicate * global_stats.preconditioners
1517
+ + (1.0 - predicate) * new_preconditioners
1518
+ )
1519
+ new_global_stats = GlobalShardedParameterStats(
1520
+ new_stacked_padded_statistics,
1521
+ new_conditional_preconditioners,
1522
+ global_stats.exponents,
1523
+ )
1524
+ new_shampoo_state = ShampooState(
1525
+ count=state.count + 1,
1526
+ stats=ShardedShampooStats(new_global_stats, new_local_stats),
1527
+ )
1528
+ return updates, new_shampoo_state
1529
+
1530
+ def init_fn(params):
1531
+ """Initialise the optimiser's state."""
1532
+
1533
+ def _init(param):
1534
+ preconditioner = preconditioner_from_params(param)
1535
+ statistics = []
1536
+ preconditioners = []
1537
+ if not _skip_preconditioning(param):
1538
+ shapes = preconditioner.shapes_for_preconditioners()
1539
+ statistics = [
1540
+ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
1541
+ ]
1542
+ preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
1543
+
1544
+ diagonal_statistics = []
1545
+ if _graft_type_has_diagonal_statistics():
1546
+ diagonal_statistics = jnp.zeros_like(param)
1547
+
1548
+ diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))
1549
+ momentum = _quantize_momentum(jnp.zeros_like(param))
1550
+
1551
+ return ParameterStats(
1552
+ _quantize_diagonal_statistics(diagonal_statistics),
1553
+ _maybe_quantize_statistics(statistics),
1554
+ _maybe_quantize_preconditioners(preconditioners),
1555
+ diagonal_momentum,
1556
+ momentum,
1557
+ init_training_metrics(len(statistics)),
1558
+ )
1559
+
1560
+ return ShampooState(
1561
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
1562
+ )
1563
+
1564
+ def _skip_preconditioning(param):
1565
+ return len(param.shape) < skip_preconditioning_rank_lt or any(
1566
+ [s > skip_preconditioning_dim_size_gt for s in param.shape]
1567
+ )
1568
+
1569
+ def _compute_stats(grad, state, param, step):
1570
+ """Compute per-parameter statistics."""
1571
+ preconditioner = preconditioner_from_params(param)
1572
+ new_statistics = [[]] * len(state.statistics)
1573
+ w1 = beta2
1574
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1575
+ if not _skip_preconditioning(param):
1576
+
1577
+ def compute_updated_statistics():
1578
+ return preconditioner.updated_statistics_from_grad(
1579
+ state.statistics,
1580
+ grad,
1581
+ w1=w1,
1582
+ w2=w2,
1583
+ to_float=_to_float,
1584
+ from_float=lambda x: _maybe_quantize_statistics([x])[0],
1585
+ precision=tensordot_precision,
1586
+ )
1587
+
1588
+ if statistics_compute_steps > 1:
1589
+ perform_step = step % statistics_compute_steps == 0
1590
+ init_state = state.statistics
1591
+ new_statistics = list(
1592
+ efficient_cond(perform_step, compute_updated_statistics, init_state)
1593
+ )
1594
+ else:
1595
+ new_statistics = compute_updated_statistics()
1596
+ return ParameterStats(
1597
+ state.diagonal_statistics,
1598
+ new_statistics,
1599
+ state.preconditioners,
1600
+ state.diagonal_momentum,
1601
+ state.momentum,
1602
+ state.training_metrics,
1603
+ )
1604
+
1605
+ mi_pth_root = functools.partial(
1606
+ matrix_inverse_pth_root,
1607
+ ridge_epsilon=matrix_epsilon,
1608
+ precision=precision,
1609
+ relative_matrix_epsilon=relative_matrix_epsilon,
1610
+ lobpcg_topk_precondition=lobpcg_topk_precondition,
1611
+ lobpcg_max_iter=lobpcg_max_iter,
1612
+ )
1613
+
1614
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1615
+ return jax.vmap(mi_pth_root)(xs, ps)
1616
+
1617
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1618
+ def _quantized_to_float(qx, qd, qb):
1619
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1620
+ return qv.to_float()
1621
+
1622
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1623
+ v = _quantized_to_float(qx, qd, qb)
1624
+ preconditioner, error = mi_pth_root(v, p)
1625
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1626
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1627
+
1628
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1629
+
1630
+ def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
1631
+ # Partition the concatenated statistics matrix across all cores.
1632
+ pspec_for_partition = preconditioner_partition_spec
1633
+ partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
1634
+ if preconditioner_partition_spec:
1635
+ partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
1636
+ else:
1637
+ partitioned_ps_spec = None
1638
+ partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)
1639
+ # Run matrix inverse pth root on each shard.
1640
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1641
+ partitioned_xs, partitioned_ps
1642
+ )
1643
+ # Reshard output to have the same PSpec as input. This is required to avoid
1644
+ # vmap seeing the full set of statistics.
1645
+ partitioned_preconditioners = pjit.with_sharding_constraint(
1646
+ partitioned_preconditioners, pspec_for_partition
1647
+ )
1648
+ # Recombine the outputs at each core.
1649
+ preconditioners = pjit.with_sharding_constraint(
1650
+ partitioned_preconditioners, statistics_partition_spec
1651
+ )
1652
+ errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec())
1653
+ return preconditioners, errors
1654
+
1655
+ def _pmap_compute_preconditioners(
1656
+ states,
1657
+ step,
1658
+ statistics,
1659
+ num_statistics_per_state,
1660
+ original_shapes,
1661
+ exponents,
1662
+ max_size,
1663
+ prev_preconditioners,
1664
+ ):
1665
+ """Computes preconditioners for given statistics in states in PMAP mode.
1666
+
1667
+ Args:
1668
+ states: A list of optimizer states.
1669
+ step: Current step number
1670
+ statistics: A list of statistics for all variables (for every dim)
1671
+ num_statistics_per_state: Number of statistis per state to reconstruct
1672
+ output states.
1673
+ original_shapes: A list of shapes of the statistics.
1674
+ exponents: Exponent power to use for inverse-pth roots.
1675
+ max_size: Maximum dim of the statistics to pad.
1676
+ prev_preconditioners: Previously available preconditioner.
1677
+
1678
+ Returns:
1679
+ New optimizer states after computing the preconditioner.
1680
+ """
1681
+ if batch_axis_name:
1682
+ num_devices = lax.psum(1, batch_axis_name)
1683
+ else:
1684
+ num_devices = 1
1685
+ num_statistics = len(statistics)
1686
+ # Pad statistics and exponents to next multiple of num_devices.
1687
+ packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
1688
+ to_pad = -num_statistics % num_devices
1689
+ packed_statistics.extend(
1690
+ [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
1691
+ )
1692
+ exponents.extend([1 for _ in range(to_pad)])
1693
+
1694
+ if not packed_statistics:
1695
+ return states
1696
+
1697
+ all_statistics = batch(packed_statistics, num_devices)
1698
+ all_exponents = batch(exponents, num_devices)
1699
+
1700
+ def _internal_inverse_pth_root_all():
1701
+ if batch_axis_name:
1702
+ current_replica = lax.axis_index(batch_axis_name)
1703
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1704
+ all_statistics[current_replica], all_exponents[current_replica]
1705
+ )
1706
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1707
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1708
+ preconditioners_flat = unbatch(preconditioners)
1709
+ errors_flat = unbatch(errors)
1710
+ else:
1711
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1712
+ all_statistics[0], all_exponents[0]
1713
+ )
1714
+ preconditioners_flat = unbatch(jnp.stack([preconditioners]))
1715
+ errors_flat = unbatch(jnp.stack([errors]))
1716
+
1717
+ return preconditioners_flat, errors_flat
1718
+
1719
+ if preconditioning_compute_steps == 1:
1720
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1721
+ else:
1722
+ # Passing statistics instead of preconditioners as they are similarly
1723
+ # shaped tensors. Note statistics will be ignored as we are passing in
1724
+ # a large init value for error.
1725
+ preconditioners_init = packed_statistics
1726
+ errors_init = [inverse_failure_threshold] * len(packed_statistics)
1727
+ init_state = [preconditioners_init, errors_init]
1728
+ perform_step = step % preconditioning_compute_steps == 0
1729
+ preconditioners_flat, errors_flat = efficient_cond(
1730
+ perform_step, _internal_inverse_pth_root_all, init_state
1731
+ )
1732
+
1733
+ def _skip(error):
1734
+ condition = jnp.logical_or(
1735
+ jnp.isnan(error), error >= inverse_failure_threshold
1736
+ )
1737
+ return condition.astype(error.dtype)
1738
+
1739
+ def _select_preconditioner(error, new_p, old_p):
1740
+ return lax.cond(
1741
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1742
+ )
1743
+
1744
+ new_preconditioners_flat = []
1745
+ new_errors_flat = []
1746
+ for p, shape, prev_p, error in zip(
1747
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1748
+ ):
1749
+ new_preconditioners_flat.append(
1750
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1751
+ )
1752
+ new_errors_flat.append(error)
1753
+
1754
+ assert len(states) == len(num_statistics_per_state)
1755
+ assert len(new_preconditioners_flat) == num_statistics
1756
+ assert len(new_errors_flat) == num_statistics
1757
+
1758
+ # Add back empty preconditioners so we that we can set the optimizer state.
1759
+ preconditioners_for_states = []
1760
+ idx = 0
1761
+ errors_for_states = []
1762
+ for num_statistics, state in zip(num_statistics_per_state, states):
1763
+ if num_statistics == 0:
1764
+ preconditioners_for_states.append([])
1765
+ errors_for_states.append(jnp.array(0, jnp.float32))
1766
+ else:
1767
+ preconditioners_for_state = new_preconditioners_flat[
1768
+ idx : idx + num_statistics
1769
+ ]
1770
+ assert len(state.statistics) == len(preconditioners_for_state)
1771
+ preconditioners_for_states.append(preconditioners_for_state)
1772
+
1773
+ errors_for_state = jnp.stack(
1774
+ new_errors_flat[idx : idx + num_statistics]
1775
+ )
1776
+ assert len(state.statistics) == len(errors_for_state)
1777
+ errors_for_states.append(errors_for_state)
1778
+
1779
+ idx += num_statistics
1780
+ new_states = []
1781
+ for state, new_preconditioners, new_errors in zip(
1782
+ states, preconditioners_for_states, errors_for_states
1783
+ ):
1784
+ if state.statistics:
1785
+ new_errors = jnp.where(
1786
+ jnp.logical_and(
1787
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1788
+ ),
1789
+ new_errors,
1790
+ state.training_metrics.inverse_pth_root_errors,
1791
+ )
1792
+ new_training_metrics = TrainingMetrics(new_errors)
1793
+ new_states.append(
1794
+ ParameterStats(
1795
+ state.diagonal_statistics,
1796
+ state.statistics,
1797
+ new_preconditioners,
1798
+ state.diagonal_momentum,
1799
+ state.momentum,
1800
+ new_training_metrics,
1801
+ )
1802
+ )
1803
+
1804
+ return new_states
1805
+
1806
+ def _pmap_quantized_compute_preconditioners(
1807
+ states,
1808
+ step,
1809
+ statistics,
1810
+ num_statistics_per_state,
1811
+ original_shapes,
1812
+ exponents,
1813
+ max_size,
1814
+ prev_preconditioners,
1815
+ ):
1816
+ """Computes preconditioners for given statistics in states in PMAP mode.
1817
+
1818
+ For quantization, each statistic is represented by three values:
1819
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1820
+ without ever recreating the original matrix in f32.
1821
+
1822
+ Args:
1823
+ states: A list of optimizer states.
1824
+ step: Current step number
1825
+ statistics: A list of statistics for all variables (for every dim)
1826
+ num_statistics_per_state: Number of statistis per state to reconstruct
1827
+ output states.
1828
+ original_shapes: A list of shapes of the statistics.
1829
+ exponents: Exponent power to use for inverse-pth roots.
1830
+ max_size: Maximum dim of the statistics to pad.
1831
+ prev_preconditioners: Previously available preconditioner.
1832
+
1833
+ Returns:
1834
+ New optimizer states after computing the preconditioner.
1835
+ """
1836
+ num_devices = lax.psum(1, batch_axis_name)
1837
+ num_statistics = len(statistics)
1838
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1839
+ # Complexity here is around: shapes needing be statically shaped,
1840
+ # our custom quantization type requires a different type of packing.
1841
+
1842
+ # Parallel tensors:
1843
+ # quantized [dxd]
1844
+ # diagonals [d] f32
1845
+ # bucket_sizes [d] f32
1846
+ packed_quantized_statistics = [
1847
+ pad_square_matrix(stat.quantized, max_size) for stat in statistics
1848
+ ]
1849
+ packed_quantized_diagonals = [
1850
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1851
+ ]
1852
+ packed_quantized_bucket_sizes = [
1853
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1854
+ ]
1855
+
1856
+ to_pad = -num_statistics % num_devices
1857
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1858
+ quantized_eye = QuantizedValue.from_float_value(
1859
+ padded_eye, quantized_dtype, True
1860
+ )
1861
+ packed_quantized_statistics.extend(
1862
+ [quantized_eye.quantized for _ in range(to_pad)]
1863
+ )
1864
+ packed_quantized_diagonals.extend(
1865
+ [quantized_eye.diagonal for _ in range(to_pad)]
1866
+ )
1867
+ packed_quantized_bucket_sizes.extend(
1868
+ [quantized_eye.bucket_size for _ in range(to_pad)]
1869
+ )
1870
+ exponents.extend([1 for _ in range(to_pad)])
1871
+
1872
+ if not packed_quantized_statistics:
1873
+ return states
1874
+
1875
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1876
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1877
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices)
1878
+ all_exponents = batch(exponents, num_devices)
1879
+
1880
+ def _internal_inverse_pth_root_all():
1881
+ current_replica = lax.axis_index(batch_axis_name)
1882
+ (
1883
+ quantized_preconditioners,
1884
+ quantized_diagonals,
1885
+ quantized_bucket_sizes,
1886
+ errors,
1887
+ ) = _quantized_matrix_inverse_pth_root_vmap(
1888
+ all_quantized_statistics[current_replica],
1889
+ all_quantized_diagonals[current_replica],
1890
+ all_quantized_bucket_sizes[current_replica],
1891
+ all_exponents[current_replica],
1892
+ )
1893
+ quantized_preconditioners = jax.lax.all_gather(
1894
+ quantized_preconditioners, batch_axis_name
1895
+ )
1896
+ quantized_diagonals = jax.lax.all_gather(
1897
+ quantized_diagonals, batch_axis_name
1898
+ )
1899
+ quantized_bucket_sizes = jax.lax.all_gather(
1900
+ quantized_bucket_sizes, batch_axis_name
1901
+ )
1902
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1903
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1904
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1905
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1906
+ errors_flat = unbatch(errors)
1907
+ return (
1908
+ quantized_preconditioners_flat,
1909
+ quantized_diagonals_flat,
1910
+ quantized_bucket_sizes_flat,
1911
+ errors_flat,
1912
+ )
1913
+
1914
+ if preconditioning_compute_steps == 1:
1915
+ (
1916
+ quantized_preconditioners_flat,
1917
+ quantized_diagonals_flat,
1918
+ quantized_bucket_sizes_flat,
1919
+ errors_flat,
1920
+ ) = _internal_inverse_pth_root_all()
1921
+ else:
1922
+ # Passing statistics instead of preconditioners as they are similarly
1923
+ # shaped tensors. Note statistics will be ignored as we are passing in
1924
+ # a large init value for error.
1925
+ quantized_preconditioners_init = packed_quantized_statistics
1926
+ quantized_diagonals_init = packed_quantized_diagonals
1927
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1928
+ errors_init = [inverse_failure_threshold] * len(
1929
+ quantized_preconditioners_init
1930
+ )
1931
+ init_state = [
1932
+ quantized_preconditioners_init,
1933
+ quantized_diagonals_init,
1934
+ quantized_bucket_sizes_init,
1935
+ errors_init,
1936
+ ]
1937
+ perform_step = step % preconditioning_compute_steps == 0
1938
+ (
1939
+ quantized_preconditioners_flat,
1940
+ quantized_diagonals_flat,
1941
+ quantized_bucket_sizes_flat,
1942
+ errors_flat,
1943
+ ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state)
1944
+
1945
+ def _skip(error):
1946
+ condition = jnp.logical_or(
1947
+ jnp.isnan(error), error >= inverse_failure_threshold
1948
+ )
1949
+ return condition.astype(error.dtype)
1950
+
1951
+ def _select_preconditioner(error, new_p, old_p):
1952
+ return lax.cond(
1953
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1954
+ )
1955
+
1956
+ new_quantized_preconditioners_flat = []
1957
+ new_quantized_diagonals_flat = []
1958
+ new_quantized_bucket_sizes_flat = []
1959
+ new_errors_flat = []
1960
+ for p, d, b, shape, prev_p, error in zip(
1961
+ quantized_preconditioners_flat,
1962
+ quantized_diagonals_flat,
1963
+ quantized_bucket_sizes_flat,
1964
+ original_shapes,
1965
+ prev_preconditioners,
1966
+ errors_flat,
1967
+ ):
1968
+ new_quantized_preconditioners_flat.append(
1969
+ _select_preconditioner(
1970
+ error, p[: shape[0], : shape[1]], prev_p.quantized
1971
+ )
1972
+ )
1973
+ new_quantized_diagonals_flat.append(
1974
+ _select_preconditioner(error, d[: shape[0]], prev_p.diagonal)
1975
+ )
1976
+ new_quantized_bucket_sizes_flat.append(
1977
+ _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
1978
+ )
1979
+ new_errors_flat.append(error)
1980
+
1981
+ assert len(states) == len(num_statistics_per_state)
1982
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1983
+ assert len(new_quantized_diagonals_flat) == num_statistics
1984
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1985
+
1986
+ # Add back empty preconditioners so we that we can set the optimizer state.
1987
+ preconditioners_for_states = []
1988
+ errors_for_states = []
1989
+ idx = 0
1990
+ for num_statistics, state in zip(num_statistics_per_state, states):
1991
+ if num_statistics == 0:
1992
+ preconditioners_for_states.append([])
1993
+ errors_for_states.append(jnp.array(0, jnp.float32))
1994
+ else:
1995
+ quantized_preconditioners_for_state = (
1996
+ new_quantized_preconditioners_flat[idx : idx + num_statistics]
1997
+ )
1998
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1999
+ idx : idx + num_statistics
2000
+ ]
2001
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
2002
+ idx : idx + num_statistics
2003
+ ]
2004
+ errors_for_state = jnp.stack(
2005
+ new_errors_flat[idx : idx + num_statistics]
2006
+ )
2007
+
2008
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
2009
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
2010
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
2011
+ assert len(state.statistics) == len(errors_for_state)
2012
+
2013
+ quantized_preconditioners = []
2014
+ for qv, qd, qb in zip(
2015
+ quantized_preconditioners_for_state,
2016
+ quantized_diagonals_for_state,
2017
+ quantized_bucket_sizes_for_state,
2018
+ ):
2019
+ quantized_preconditioners.append(
2020
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
2021
+ )
2022
+ preconditioners_for_states.append(quantized_preconditioners)
2023
+ errors_for_states.append(errors_for_state)
2024
+ idx += num_statistics
2025
+ new_states = []
2026
+ for state, new_preconditioners, new_errors in zip(
2027
+ states, preconditioners_for_states, errors_for_states
2028
+ ):
2029
+ if state.statistics:
2030
+ new_errors = jnp.where(
2031
+ jnp.logical_and(
2032
+ new_errors > 0.0, new_errors != inverse_failure_threshold
2033
+ ),
2034
+ new_errors,
2035
+ state.training_metrics.inverse_pth_root_errors,
2036
+ )
2037
+ new_training_metrics = TrainingMetrics(new_errors)
2038
+ new_states.append(
2039
+ ParameterStats(
2040
+ state.diagonal_statistics,
2041
+ state.statistics,
2042
+ new_preconditioners,
2043
+ state.diagonal_momentum,
2044
+ state.momentum,
2045
+ new_training_metrics,
2046
+ )
2047
+ )
2048
+
2049
+ return new_states
2050
+
2051
+ def _pjit_compute_preconditioners(
2052
+ states,
2053
+ step,
2054
+ statistics,
2055
+ num_statistics_per_state,
2056
+ original_shapes,
2057
+ exponents,
2058
+ max_size,
2059
+ prev_preconditioners,
2060
+ ):
2061
+ """Computes preconditioners for given statistics in states in PJIT mode.
2062
+
2063
+ Args:
2064
+ states: A list of optimizer states.
2065
+ step: Current step number
2066
+ statistics: A list of statistics for all variables (for every dim)
2067
+ num_statistics_per_state: Number of statistis per state to reconstruct
2068
+ output states.
2069
+ original_shapes: A list of shapes of the statistics.
2070
+ exponents: Exponent power to use for inverse-pth roots.
2071
+ max_size: Maximum dim of the statistics to pad.
2072
+ prev_preconditioners: Previously available preconditioner.
2073
+
2074
+ Returns:
2075
+ New optimizer states after computing the preconditioner.
2076
+ """
2077
+ num_statistics = len(statistics)
2078
+ to_pad = -num_statistics % num_devices_for_pjit
2079
+ padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
2080
+ padded_statistics.extend(
2081
+ [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
2082
+ )
2083
+ exponents.extend([1 for _ in range(to_pad)])
2084
+ all_statistics = jnp.stack(padded_statistics)
2085
+ all_exponents = jnp.stack(exponents)
2086
+
2087
+ def _internal_inverse_pth_root_all():
2088
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
2089
+ all_statistics, all_exponents
2090
+ )
2091
+ b1 = preconditioners.shape[0]
2092
+
2093
+ def split(batched_values):
2094
+ return [
2095
+ jnp.squeeze(v)
2096
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
2097
+ ]
2098
+
2099
+ return split(preconditioners), split(errors)
2100
+
2101
+ if preconditioning_compute_steps == 1:
2102
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
2103
+ else:
2104
+ # Passing statistics instead of preconditioners as they are similarly
2105
+ # shaped tensors. Note statistics will be ignored as we are passing in
2106
+ # a large init value for error.
2107
+ preconditioners_init = padded_statistics
2108
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
2109
+ init_state = [preconditioners_init, errors_init]
2110
+ perform_step = step % preconditioning_compute_steps == 0
2111
+ preconditioners_flat, errors_flat = efficient_cond(
2112
+ perform_step, _internal_inverse_pth_root_all, init_state
2113
+ )
2114
+
2115
+ def _skip(error):
2116
+ condition = jnp.logical_or(
2117
+ jnp.isnan(error), error >= inverse_failure_threshold
2118
+ )
2119
+ return condition.astype(error.dtype)
2120
+
2121
+ def _select_preconditioner(error, new_p, old_p):
2122
+ return lax.cond(
2123
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
2124
+ )
2125
+
2126
+ new_preconditioners_flat = []
2127
+ new_errors_flat = []
2128
+ for p, shape, prev_p, error in zip(
2129
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
2130
+ ):
2131
+ new_preconditioners_flat.append(
2132
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
2133
+ )
2134
+ new_errors_flat.append(error)
2135
+
2136
+ assert len(states) == len(num_statistics_per_state)
2137
+ assert len(new_preconditioners_flat) == num_statistics
2138
+
2139
+ # Add back empty preconditioners so we that we can set the optimizer state.
2140
+ preconditioners_for_states = []
2141
+ errors_for_states = []
2142
+ idx = 0
2143
+ for num_statistics, state in zip(num_statistics_per_state, states):
2144
+ if num_statistics == 0:
2145
+ preconditioners_for_states.append([])
2146
+ errors_for_states.append(jnp.array(0, jnp.float32))
2147
+ else:
2148
+ preconditioners_for_state = new_preconditioners_flat[
2149
+ idx : idx + num_statistics
2150
+ ]
2151
+ assert len(state.statistics) == len(preconditioners_for_state)
2152
+ preconditioners_for_states.append(preconditioners_for_state)
2153
+
2154
+ errors_for_state = jnp.stack(
2155
+ new_errors_flat[idx : idx + num_statistics]
2156
+ )
2157
+ assert len(state.statistics) == len(errors_for_state)
2158
+ errors_for_states.append(errors_for_state)
2159
+ idx += num_statistics
2160
+
2161
+ new_states = []
2162
+ for state, new_preconditioners, new_errors in zip(
2163
+ states, preconditioners_for_states, errors_for_states
2164
+ ):
2165
+ if state.statistics:
2166
+ new_errors = jnp.where(
2167
+ jnp.logical_and(
2168
+ new_errors > 0.0, new_errors != inverse_failure_threshold
2169
+ ),
2170
+ new_errors,
2171
+ state.training_metrics.inverse_pth_root_errors,
2172
+ )
2173
+ new_training_metrics = TrainingMetrics(new_errors)
2174
+ new_states.append(
2175
+ ParameterStats(
2176
+ state.diagonal_statistics,
2177
+ state.statistics,
2178
+ new_preconditioners,
2179
+ state.diagonal_momentum,
2180
+ state.momentum,
2181
+ new_training_metrics,
2182
+ )
2183
+ )
2184
+
2185
+ return new_states
2186
+
2187
+ def _compute_preconditioners(states, params, step):
2188
+ """Computes preconditioners for given statistics in states.
2189
+
2190
+ Args:
2191
+ states: A list of optimizer states.
2192
+ params: A list of params.
2193
+ step: Current step number
2194
+
2195
+ Returns:
2196
+ New optimizer states after computing the preconditioner.
2197
+ """
2198
+ statistics = []
2199
+ num_statistics_per_state = []
2200
+ original_shapes = []
2201
+ exponents = []
2202
+ max_size = 0
2203
+ prev_preconditioners = []
2204
+
2205
+ for state, param in zip(states, params):
2206
+ num_statistics = len(state.statistics)
2207
+ num_statistics_per_state.append(num_statistics)
2208
+ original_shapes_for_state = []
2209
+ if num_statistics > 0:
2210
+ preconditioner = preconditioner_from_params(param)
2211
+ for statistic in state.statistics:
2212
+ exponents.append(
2213
+ preconditioner.exponent_for_preconditioner()
2214
+ if exponent_override == 0
2215
+ else exponent_override
2216
+ )
2217
+ original_shapes_for_state.append(statistic.shape)
2218
+ max_size = max(max_size, statistic.shape[0])
2219
+
2220
+ statistics.extend(state.statistics)
2221
+ prev_preconditioners.extend(state.preconditioners)
2222
+ original_shapes.extend(original_shapes_for_state)
2223
+
2224
+ if not shard_optimizer_states:
2225
+ # Quantization is only enabled if batch_axis_name is not set.
2226
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
2227
+
2228
+ if quantized_dtype == jnp.float32:
2229
+ return _pmap_compute_preconditioners(
2230
+ states,
2231
+ step,
2232
+ statistics,
2233
+ num_statistics_per_state,
2234
+ original_shapes,
2235
+ exponents,
2236
+ max_size,
2237
+ prev_preconditioners,
2238
+ )
2239
+ else:
2240
+ return _pmap_quantized_compute_preconditioners(
2241
+ states,
2242
+ step,
2243
+ statistics,
2244
+ num_statistics_per_state,
2245
+ original_shapes,
2246
+ exponents,
2247
+ max_size,
2248
+ prev_preconditioners,
2249
+ )
2250
+
2251
+ else:
2252
+ return _pjit_compute_preconditioners(
2253
+ states,
2254
+ step,
2255
+ statistics,
2256
+ num_statistics_per_state,
2257
+ original_shapes,
2258
+ exponents,
2259
+ max_size,
2260
+ prev_preconditioners,
2261
+ )
2262
+
2263
+ def _transform_grad(grad, state, param, step):
2264
+ """Transform per-parameter gradients."""
2265
+ preconditioner = preconditioner_from_params(param)
2266
+ sgd_update = grad
2267
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
2268
+
2269
+ if (
2270
+ graft_type == GraftingType.ADAGRAD
2271
+ or graft_type == GraftingType.ADAGRAD_NORMALIZED
2272
+ ):
2273
+ scaled_grad = grad
2274
+ if graft_type == GraftingType.ADAGRAD_NORMALIZED:
2275
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
2276
+
2277
+ new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
2278
+ scaled_grad
2279
+ )
2280
+ adagrad_update = scaled_grad / (
2281
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2282
+ )
2283
+ grafting_update = adagrad_update
2284
+ elif (
2285
+ graft_type == GraftingType.RMSPROP
2286
+ or graft_type == GraftingType.RMSPROP_NORMALIZED
2287
+ ):
2288
+ scaled_grad = grad
2289
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
2290
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
2291
+
2292
+ w1 = beta2
2293
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
2294
+
2295
+ new_diagonal_statistics = (
2296
+ w1 * state.diagonal_statistics.to_float() + w2 * jnp.square(scaled_grad)
2297
+ )
2298
+ rmsprop_update = scaled_grad / (
2299
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2300
+ )
2301
+
2302
+ if clip_by_scaled_gradient_norm:
2303
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
2304
+ jnp.sqrt(float(rmsprop_update.size))
2305
+ )
2306
+ clipping_denom = jnp.maximum(
2307
+ 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
2308
+ )
2309
+ rmsprop_update /= clipping_denom
2310
+
2311
+ grafting_update = rmsprop_update
2312
+ elif graft_type == GraftingType.SGD:
2313
+ grafting_update = sgd_update
2314
+ else:
2315
+ grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update)
2316
+
2317
+ lr = learning_rate
2318
+ if callable(learning_rate):
2319
+ lr = learning_rate(step)
2320
+
2321
+ preconditioner_multiplier = lr if not decoupled_learning_rate else 1.0
2322
+ grafting_update = grafting_update * preconditioner_multiplier
2323
+
2324
+ precond_grad = grad
2325
+ if not _skip_preconditioning(param):
2326
+ precond_grad = preconditioner.preconditioned_grad(
2327
+ precond_grad, _maybe_dequantize_preconditioners(state.preconditioners)
2328
+ )
2329
+ else:
2330
+ precond_grad = grafting_update
2331
+
2332
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
2333
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
2334
+
2335
+ multiplier = grafting_update_norm / (precond_grad_norm + 1e-16)
2336
+ shampoo_update = precond_grad * multiplier
2337
+
2338
+ shampoo_update_with_wd = shampoo_update
2339
+ grafting_update_with_wd = grafting_update
2340
+
2341
+ if weight_decay != 0 and not decoupled_weight_decay:
2342
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
2343
+ grafting_update_with_wd = grafting_update + weight_decay * param
2344
+
2345
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
2346
+
2347
+ shampoo_update_with_wd_momentum = (
2348
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
2349
+ )
2350
+
2351
+ grafting_update_with_wd_momentum = (
2352
+ state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
2353
+ )
2354
+
2355
+ run_shampoo = (step >= start_preconditioning_step).astype(
2356
+ grafting_update_with_wd_momentum.dtype
2357
+ )
2358
+
2359
+ momentum_update = (
2360
+ run_shampoo * shampoo_update_with_wd_momentum
2361
+ + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
2362
+ )
2363
+
2364
+ wd_update = (
2365
+ run_shampoo * shampoo_update_with_wd
2366
+ + (1.0 - run_shampoo) * grafting_update_with_wd
2367
+ )
2368
+
2369
+ nesterov_momentum_update = momentum_update
2370
+
2371
+ if nesterov:
2372
+ nesterov_momentum_update = w * wd_update + beta1 * momentum_update
2373
+
2374
+ if weight_decay != 0 and decoupled_weight_decay:
2375
+ nesterov_momentum_update = (
2376
+ nesterov_momentum_update + lr * weight_decay * param
2377
+ )
2378
+
2379
+ momentum_multiplier = lr if decoupled_learning_rate else 1.0
2380
+ transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update
2381
+
2382
+ new_diagonal_momentum = grafting_update_with_wd_momentum
2383
+ new_momentum = shampoo_update_with_wd_momentum
2384
+
2385
+ param_stats = ParameterStats(
2386
+ _quantize_diagonal_statistics(new_diagonal_statistics),
2387
+ state.statistics,
2388
+ state.preconditioners,
2389
+ _quantize_momentum(new_diagonal_momentum),
2390
+ _quantize_momentum(new_momentum),
2391
+ state.training_metrics,
2392
+ )
2393
+
2394
+ return transformed_update, param_stats
2395
+
2396
+ def update_fn(grads, state, params):
2397
+ """Transform the input gradient and update all statistics.
2398
+
2399
+ Args:
2400
+ grads: the gradient tensors for the parameters
2401
+ and any custom gradients for preconditioners.
2402
+ state: a named tuple containing the state of the optimizer
2403
+ params: the parameters that should be updated.
2404
+
2405
+ Returns:
2406
+ A tuple containing the new parameters and the new optimizer state.
2407
+ """
2408
+ params_flat, treedef = jax.tree_flatten(params)
2409
+ stats_flat = treedef.flatten_up_to(state.stats)
2410
+ grads_flat = treedef.flatten_up_to(grads)
2411
+ stats_grads = grads_flat
2412
+
2413
+ new_stats_flat = jax.tree_map(
2414
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
2415
+ stats_grads,
2416
+ stats_flat,
2417
+ params_flat,
2418
+ )
2419
+
2420
+ new_stats_flat = _compute_preconditioners(
2421
+ new_stats_flat, params_flat, state.count
2422
+ )
2423
+ outputs = jax.tree_map(
2424
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
2425
+ grads_flat,
2426
+ new_stats_flat,
2427
+ params_flat,
2428
+ )
2429
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
2430
+
2431
+ updates = jax.tree_unflatten(treedef, updates_flat)
2432
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
2433
+
2434
+ new_state = ShampooState(count=state.count + 1, stats=new_stats)
2435
+ return updates, new_state
2436
+
2437
+ if shard_optimizer_states:
2438
+ # Hijacks the init_fn signature so we can return an OptState with
2439
+ # appropriate init_fns.
2440
+ opt_init_fn = sharded_init_fn
2441
+
2442
+ def _init_fns(unused_params):
2443
+ return InitFnState(
2444
+ init_fn=opt_init_fn,
2445
+ pspec_fn=sharded_init_partition_spec_fn,
2446
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
2447
+ )
2448
+
2449
+ opt_update_fn = sharded_update_fn
2450
+ return optax.GradientTransformation(_init_fns, opt_update_fn)
2451
+ else:
2452
+ return optax.GradientTransformation(init_fn, update_fn)
tools/train/scalable_shampoo/quantization_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Helper routines for quantization."""
17
+
18
+ from typing import Any
19
+
20
+ import chex
21
+ import jax.numpy as jnp
22
+ from flax import struct
23
+
24
+
25
+ # pylint:disable=no-value-for-parameter
26
+ @struct.dataclass
27
+ class QuantizedValue:
28
+ """State associated with quantized value."""
29
+
30
+ quantized: chex.Array
31
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
32
+ bucket_size: chex.Array
33
+ quantized_dtype: jnp.dtype = struct.field(
34
+ pytree_node=False
35
+ ) # Dtype for the quantized value.
36
+ extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
37
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
38
+
39
+ @classmethod
40
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
41
+ if isinstance(fvalue, list) and not fvalue:
42
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
43
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
44
+ fvalue, quantized_dtype, extract_diagonal
45
+ )
46
+ return QuantizedValue(
47
+ quantized,
48
+ diagonal_fvalue,
49
+ bucket_size,
50
+ quantized_dtype,
51
+ extract_diagonal,
52
+ list(quantized.shape),
53
+ )
54
+
55
+ # Quantization is from Lingvo JAX optimizers.
56
+ # We extend it for int16 quantization of PSD matrices.
57
+ @classmethod
58
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
59
+ """Returns quantized value and the bucket."""
60
+ if quantized_dtype == jnp.float32:
61
+ return fvalue, [], []
62
+ elif quantized_dtype == jnp.bfloat16:
63
+ return fvalue.astype(jnp.bfloat16), [], []
64
+
65
+ float_dtype = fvalue.dtype
66
+ if quantized_dtype == jnp.int8:
67
+ # value -128 is not used.
68
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
69
+ elif quantized_dtype == jnp.int16:
70
+ # value -32768 is not used.
71
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
72
+ else:
73
+ raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
74
+ # max value is mapped to num_buckets
75
+
76
+ if extract_diagonal and fvalue.ndim != 2:
77
+ raise ValueError(
78
+ f"Input array {fvalue} must be 2D to work with extract_diagonal."
79
+ )
80
+
81
+ diagonal_fvalue = []
82
+ if extract_diagonal:
83
+ diagonal_fvalue = jnp.diag(fvalue)
84
+ # Remove the diagonal entries.
85
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
86
+
87
+ # TODO(rohananil): Extend this by making use of information about the blocks
88
+ # SM3 style which will be useful for diagonal statistics
89
+ # We first decide the scale.
90
+ if fvalue.ndim < 1:
91
+ raise ValueError(
92
+ f"Input array {fvalue} must have a strictly positive number of dimensions."
93
+ )
94
+
95
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
96
+ bucket_size = max_abs / num_buckets
97
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
98
+ # To avoid divide by 0.0
99
+ bs_nonzero = jnp.where(
100
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
101
+ )
102
+ ratio = fvalue / bs_nonzero
103
+ # We use rounding to remove bias.
104
+ quantized = jnp.round(ratio)
105
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
106
+
107
+ def to_float(self):
108
+ """Returns the float value."""
109
+ if isinstance(self.quantized, list) and not self.quantized:
110
+ return self.quantized
111
+
112
+ if self.quantized_dtype == jnp.float32:
113
+ return self.quantized
114
+
115
+ if self.quantized_dtype == jnp.bfloat16:
116
+ return self.quantized.astype(jnp.float32)
117
+
118
+ float_dtype = self.bucket_size.dtype
119
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
120
+ val = self.quantized.astype(float_dtype) * bucket_size
121
+ if self.extract_diagonal:
122
+ val += jnp.diag(self.diagonal)
123
+ return val
tools/train/scalable_shampoo/sm3.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of SM3 from:
17
+ #
18
+ # Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
20
+ #
21
+ # Author: Rohan Anil (rohananil at google dot com)
22
+ #
23
+
24
+ """SM3 Implementation."""
25
+
26
+ import functools
27
+ from typing import Any, NamedTuple
28
+
29
+ import chex
30
+ import jax
31
+ import jax.numpy as jnp
32
+ import optax
33
+
34
+ from .quantization_utils import QuantizedValue
35
+
36
+
37
+ class SM3State(NamedTuple):
38
+ count: chex.Array
39
+ stats: Any
40
+
41
+
42
+ # Per parameter optimizer state used in data-parallel training.
43
+ class ParameterStats(NamedTuple):
44
+ """State associated to each parameter of the model being trained."""
45
+
46
+ diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
47
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
48
+
49
+
50
+ def sm3(
51
+ learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
52
+ ):
53
+ """SM3 optimizer.
54
+
55
+ Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
56
+ Yoram Singer
57
+
58
+ https://arxiv.org/abs/1901.11150
59
+
60
+ Args:
61
+ learning_rate: the step size used to update the parameters.
62
+ beta1: momentum parameter.
63
+ beta2: second moment averaging parameter.
64
+ diagonal_epsilon: epsilon for sm3
65
+ normalize_grads: Whether to normalize grads. Author finds it useful when
66
+ grads are high variance.
67
+
68
+ Returns:
69
+ a GradientTransformation.
70
+ """
71
+
72
+ def _quantize_momentum(momentum_statistics):
73
+ return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)
74
+
75
+ def init_fn(params):
76
+ """Initialise the optimiser's state."""
77
+
78
+ def _init(param):
79
+ accumulators = [jnp.zeros([s]) for s in param.shape]
80
+ momentum = _quantize_momentum(jnp.zeros_like(param))
81
+ return ParameterStats(accumulators, momentum)
82
+
83
+ return SM3State(
84
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
85
+ )
86
+
87
+ def _get_expanded_shape(shape, i):
88
+ rank = len(shape)
89
+ # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
90
+ # For eg: i = 1 returns [1, N, 1].
91
+ return [1] * i + [shape[i]] + [1] * (rank - i - 1)
92
+
93
+ def _moving_averages(grad, accumulators):
94
+ w = (1.0 - beta2) if beta2 != 1.0 else 1.0
95
+ if grad.ndim < 2:
96
+ return beta2 * accumulators[0] + w * grad**2
97
+ else:
98
+ min_accumulator = functools.reduce(jnp.minimum, accumulators)
99
+ return beta2 * min_accumulator + w * grad**2
100
+
101
+ def _moving_averages_momentum(grad, momentum):
102
+ w = (1.0 - beta1) if beta1 != 1.0 else 1.0
103
+ return beta1 * momentum.to_float() + w * grad
104
+
105
+ def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
106
+ all_diagonal_statistics = []
107
+ for i in range(grad.ndim):
108
+ axes = list(range(i)) + list(range(i + 1, grad.ndim))
109
+ dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
110
+ all_diagonal_statistics.append(dim_diagonal_statistics)
111
+ if grad.ndim == 1:
112
+ all_diagonal_statistics[0] = updated_diagonal_statistics
113
+ return all_diagonal_statistics
114
+
115
+ def update_fn(updates, state, params=None):
116
+ del params
117
+ stats = state.stats
118
+ if normalize_grads:
119
+ updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
120
+ # Reshape all vectors into N-d tensors to compute min over them.
121
+ # [n], [m] -> [n, 1], [1, m]
122
+ expanded_diagonal_statistics = jax.tree_map(
123
+ lambda grad, state: [ # pylint:disable=g-long-lambda
124
+ jnp.reshape(
125
+ state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
126
+ )
127
+ for i in range(grad.ndim)
128
+ ],
129
+ updates,
130
+ stats,
131
+ )
132
+
133
+ # Compute new diagonal statistics
134
+ new_diagonal_statistics = jax.tree_map(
135
+ _moving_averages, updates, expanded_diagonal_statistics
136
+ )
137
+
138
+ # Compute preconditioners (1/sqrt(s)) where s is the statistics.
139
+ new_preconditioners = jax.tree_map(
140
+ lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
141
+ )
142
+ preconditioned_grads = jax.tree_map(
143
+ lambda g, p: g * p, updates, new_preconditioners
144
+ )
145
+
146
+ # Compute updated momentum (also handle quantization)
147
+ updated_momentum = jax.tree_map(
148
+ lambda preconditioned_grad, state: _moving_averages_momentum( # pylint:disable=g-long-lambda
149
+ preconditioned_grad, state.diagonal_momentum
150
+ ),
151
+ preconditioned_grads,
152
+ stats,
153
+ )
154
+
155
+ # Update diagonal statistics.
156
+ updated_diagonal_statistics = jax.tree_map(
157
+ _sketch_diagonal_statistics, updates, new_diagonal_statistics
158
+ )
159
+
160
+ # Update momentum.
161
+ new_sm3_stats = jax.tree_map(
162
+ lambda momentum, diagonal_stats: ParameterStats( # pylint:disable=g-long-lambda
163
+ diagonal_stats, _quantize_momentum(momentum)
164
+ ),
165
+ updated_momentum,
166
+ updated_diagonal_statistics,
167
+ )
168
+
169
+ lr = learning_rate
170
+ if callable(learning_rate):
171
+ lr = learning_rate(state.count)
172
+
173
+ new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
174
+ return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)
175
+
176
+ return optax.GradientTransformation(init_fn, update_fn)
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
17
+
18
+ import functools
19
+ from typing import Any, List, Optional, Sequence, Union
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+ from flax import struct
25
+ from jax import lax
26
+
27
+
28
+ @struct.dataclass
29
+ class SlicedSymmetricMatrix:
30
+ """A symmetric matrix represented by lower-triangular block row slices.
31
+
32
+ For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
33
+ by the block rows a and [b, c].
34
+
35
+ The matrix may be batched, in which case each entry of block_rows may have
36
+ dimension greater than 2. The last two dimensions represent the rows and cols.
37
+ """
38
+
39
+ block_rows: List[jnp.ndarray]
40
+
41
+
42
+ def product_with_transpose(
43
+ mat1,
44
+ mat2,
45
+ axes,
46
+ precision=lax.Precision.DEFAULT,
47
+ ):
48
+ """Returns mat1 * mat2^T for two matrices (possibly batched).
49
+
50
+ The rows and columns are the last two dimensions for each matrix.
51
+
52
+ Args:
53
+ mat1: First matrix.
54
+ mat2: Second matrix.
55
+ axes: The axes over which to apply the product.
56
+ precision: JAX precision to use for the multiplication.
57
+ """
58
+ return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
59
+
60
+
61
+ @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
62
+ def sliced_transposed_product(
63
+ mat,
64
+ block_size,
65
+ axes=(-1,),
66
+ precision=lax.Precision.DEFAULT,
67
+ ):
68
+ """Returns the blocked slices representing a symmetric contraction.
69
+
70
+ Specifically, the output is a contraction of the input mat with itself, in the
71
+ specified axes.
72
+
73
+ Args:
74
+ mat: The matrix for which we will compute a contraction with itself.
75
+ block_size: The size of row blocks to compute.
76
+ axes: Axes to use for the contraction.
77
+ precision: The precision to use in each computation.
78
+
79
+ Raises:
80
+ ValueError: Raised when the specified block size does not evenly divide
81
+ the number of rows of the input mat.
82
+ """
83
+ rank = len(mat.shape)
84
+
85
+ def _make_axis_positive(ax):
86
+ assert -rank <= ax < rank
87
+ return ax + rank if ax < 0 else ax
88
+
89
+ positive_axes = [_make_axis_positive(ax) for ax in axes]
90
+ assert len(positive_axes) == len(axes)
91
+ remaining_axes = set(range(rank)) - set(positive_axes)
92
+ assert len(remaining_axes) == 1
93
+ remaining_ax = remaining_axes.pop()
94
+
95
+ num_rows = mat.shape[remaining_ax]
96
+ if num_rows % block_size != 0:
97
+ raise ValueError(
98
+ "The row dimension must be divisible by block_size. "
99
+ f"Instead got row dimension={num_rows} and block_size={block_size}."
100
+ )
101
+
102
+ block_rows = []
103
+ for i in range(num_rows // block_size):
104
+ start_indices = [0] * rank
105
+ start_indices[remaining_ax] = i * block_size
106
+
107
+ slice_sizes = list(mat.shape)
108
+ slice_sizes[remaining_ax] = block_size
109
+
110
+ slice_sizes_full = list(mat.shape)
111
+ slice_sizes_full[remaining_ax] = (i + 1) * block_size
112
+
113
+ block_rows.append(
114
+ product_with_transpose(
115
+ lax.dynamic_slice(
116
+ mat, start_indices=start_indices, slice_sizes=slice_sizes
117
+ ),
118
+ lax.dynamic_slice(
119
+ mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
120
+ ),
121
+ axes=(axes, axes),
122
+ precision=precision,
123
+ )
124
+ )
125
+
126
+ return SlicedSymmetricMatrix(block_rows=block_rows)
127
+
128
+
129
+ @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
130
+ def sliced_transposed_product_concat(
131
+ mat,
132
+ block_size,
133
+ axes=(-1,),
134
+ precision=lax.Precision.DEFAULT,
135
+ ):
136
+ """Returns the concatenated slices representing mat*mat^T.
137
+
138
+ Args:
139
+ mat: The matrix for which we will compute mat*mat^T. It does not need to be
140
+ square, and may be batched.
141
+ block_size: The size of row blocks to compute.
142
+ axes: Axes to use for the contraction.
143
+ precision: The precision to use in each computation.
144
+
145
+ Raises:
146
+ ValueError: Raised when the specified block size does not evenly divide
147
+ the number of rows of the input mat.
148
+ """
149
+ sliced_symmetric_matrix = sliced_transposed_product(
150
+ mat=mat, block_size=block_size, axes=axes, precision=precision
151
+ )
152
+ return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
153
+
154
+
155
+ @jax.jit
156
+ def materialize_matrix(symmetric_matrix):
157
+ """Returns a materialized symmetric matrix.
158
+
159
+ Args:
160
+ symmetric_matrix: the matrix represented by lower-triangular block slices.
161
+ """
162
+ block_rows = symmetric_matrix.block_rows
163
+ block_size = block_rows[0].shape[-2]
164
+ num_blocks = len(block_rows)
165
+
166
+ # Slice the lower-triangular and diagonal blocks into blocks.
167
+ blocks = [
168
+ [
169
+ block_row[Ellipsis, i * block_size : (i + 1) * block_size]
170
+ for i in range(k + 1)
171
+ ]
172
+ for k, block_row in enumerate(block_rows)
173
+ ]
174
+
175
+ # Generate the (off-diagonal) upper-triangular blocks.
176
+ off_diags = [[] for _ in range(num_blocks - 1)]
177
+ for k, block_row in enumerate(block_rows[1:]):
178
+ for i in range(k + 1):
179
+ off_diags[i].append(
180
+ jnp.swapaxes(
181
+ a=block_row[Ellipsis, i * block_size : (i + 1) * block_size],
182
+ axis1=-1,
183
+ axis2=-2,
184
+ )
185
+ )
186
+
187
+ return jnp.block(
188
+ [row + row_t for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]]
189
+ )
190
+
191
+
192
+ @functools.partial(jax.jit, static_argnames="num_blocks")
193
+ def materialize_matrix_from_concat(
194
+ block_rows_concat,
195
+ num_blocks=None,
196
+ ):
197
+ """Returns a materialized symmetric matrix from concatenated slices.
198
+
199
+ Args:
200
+ block_rows_concat: The matrix represented as the concatenated
201
+ lower-triangular blocks.
202
+ num_blocks: The number of block-rows used to represent the symmetric matrix.
203
+ If not specified, it is inferred from the shape of block_rows_concat.
204
+ """
205
+ if num_blocks is None:
206
+ num_blocks = find_num_blocks(block_rows_concat)
207
+
208
+ block_size = block_rows_concat.shape[-2]
209
+
210
+ block_rows = [
211
+ block_rows_concat[
212
+ Ellipsis,
213
+ (k * (k + 1))
214
+ // 2
215
+ * block_size : (((k + 1) * (k + 2)) // 2 + 1)
216
+ * block_size,
217
+ ]
218
+ for k in range(num_blocks)
219
+ ]
220
+
221
+ return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
222
+
223
+
224
+ @functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
225
+ def update_sliced_rows(
226
+ symmetric_matrix,
227
+ mat,
228
+ alpha,
229
+ beta,
230
+ axes=(-1,),
231
+ ):
232
+ """Implements the blocked equivalent of SYRK.
233
+
234
+ Specifically, the symmetric matrix (represented using lower-triangular block
235
+ rows) is updated using the sliced product of mat.
236
+
237
+ Args:
238
+ symmetric_matrix: The symmetric matrix to update.
239
+ mat: The matrix to use for the update = mat * mat^T. The number of rows
240
+ should match that of symmetric_matrix.
241
+ alpha: The weight for the update.
242
+ beta: The weight for the original symmetric matrix.
243
+ axes: Axes to use for the contraction of the update.
244
+
245
+ Returns:
246
+ The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
247
+ """
248
+ block_size = symmetric_matrix.block_rows[0].shape[-2]
249
+ sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
250
+ return SlicedSymmetricMatrix(
251
+ block_rows=[
252
+ update * alpha + row * beta
253
+ for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
254
+ ]
255
+ )
256
+
257
+
258
+ def num_blocks_from_total_blocks(total_blocks):
259
+ """Returns the number of blocks (i.e.
260
+
261
+ block rows) from the total blocks.
262
+
263
+ This is the inverse of the function x -> x*(x+1)/2.
264
+
265
+ For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
266
+ total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.
267
+
268
+ Args:
269
+ total_blocks: The total blocks used to represent the matrix.
270
+ """
271
+ num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
272
+ if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
273
+ raise ValueError(
274
+ f"total_blocks={total_blocks} does not correspond to "
275
+ "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
276
+ )
277
+ return num_blocks
278
+
279
+
280
+ def find_num_blocks(block_rows_concat):
281
+ """Returns the number of (row) blocks representing the concatenated matrix.
282
+
283
+ For example, an input with dimensions [256, 2560] represents 10 square blocks,
284
+ which matches 4 lower-triangular block rows (1+2+3+4). So this function will
285
+ return 4.
286
+
287
+ Use ordinary numpy functions here so that the returned value is static.
288
+
289
+ Args:
290
+ block_rows_concat: The concatenated block array.
291
+
292
+ Raises:
293
+ ValueError: When the dimensions of the matrix do not correspond to a lower
294
+ triangular block representation.
295
+ """
296
+ # Compute the number of square blocks used to represent the matrix.
297
+ total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
298
+ # Determine the number of block rows by inverting y = x*(x+1)/2.
299
+ return num_blocks_from_total_blocks(total_blocks)
300
+
301
+
302
+ @functools.partial(jax.jit, static_argnames="block_size")
303
+ def slice_symmetric_matrix(
304
+ mat,
305
+ block_size,
306
+ ):
307
+ """Returns sliced row blocks.
308
+
309
+ Args:
310
+ mat: A symmetric matrix.
311
+ block_size: The size of the row slices.
312
+ """
313
+ num_rows = mat.shape[-2]
314
+ num_cols = mat.shape[-1]
315
+ if num_rows != num_cols:
316
+ raise ValueError("mat is not square.")
317
+ if num_rows % block_size != 0:
318
+ raise ValueError(
319
+ f"block size does not evenly divide rows. num_rows={num_rows}, block_size={block_size}"
320
+ )
321
+ return SlicedSymmetricMatrix(
322
+ block_rows=[
323
+ mat[
324
+ Ellipsis,
325
+ i * block_size : (i + 1) * block_size,
326
+ 0 : (i + 1) * block_size,
327
+ ]
328
+ for i in range(num_rows // block_size)
329
+ ]
330
+ )
331
+
332
+
333
+ @functools.partial(jax.jit, static_argnames="block_size")
334
+ def slice_symmetric_matrix_concat(
335
+ mat,
336
+ block_size,
337
+ ):
338
+ """Returns the concatenated sliced row blocks.
339
+
340
+ Args:
341
+ mat: A symmetric matrix.
342
+ block_size: The size of the row slices.
343
+ """
344
+ sliced_symmetric_matrix = slice_symmetric_matrix(mat=mat, block_size=block_size)
345
+ return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
346
+
347
+
348
+ def sliced_matrix_diag(mat):
349
+ """Returns the diagonal of the symmetric matrix.
350
+
351
+ Args:
352
+ mat: The symmetric matrix represented in concatenated block form.
353
+ """
354
+ rows, cols = mat.shape
355
+ total_blocks = cols // rows
356
+ num_blocks = num_blocks_from_total_blocks(total_blocks)
357
+ diags = []
358
+ for i in range(num_blocks):
359
+ last_index = rows * ((i + 2) * (i + 1)) // 2
360
+ first_index = last_index - rows
361
+ diags.append(jnp.diag(mat[Ellipsis, first_index:last_index]))
362
+ return jnp.concatenate(diags, axis=-1)
363
+
364
+
365
+ def diag_as_concat(diag, block_size):
366
+ """Returns the representation of a diagonal matrix in symmetric block form.
367
+
368
+ Args:
369
+ diag: The 1D array for the diagonals.
370
+ block_size: The size of blocks to use. Must divide the length of diag.
371
+ """
372
+ assert len(diag.shape) == 1 # diag must be 1D.
373
+ assert len(diag) % block_size == 0
374
+ num_diag_blocks = len(diag) // block_size
375
+ blocks = []
376
+ for i in range(num_diag_blocks):
377
+ blocks.append(jnp.zeros(shape=(block_size, block_size * i), dtype=diag.dtype))
378
+ blocks.append(jnp.diag(diag[i * block_size : (i + 1) * block_size]))
379
+ return jnp.concatenate(blocks, axis=-1)
380
+
381
+
382
+ def row_abs_maxes(mat):
383
+ """Returns the max of the absolute values of the rows of the full matrix.
384
+
385
+ For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
386
+ mat = [1, 6, 2] with block_size = 1. In this case the function returns the
387
+ aboslute row maxes of the original symmetric matrix, [6, 6].
388
+
389
+ Args:
390
+ mat: The symmetric matrix represented as the concatenated blocks.
391
+ """
392
+ rows, cols = mat.shape
393
+
394
+ # Find col and row max for each block.
395
+ col_maxes = []
396
+ row_maxes = []
397
+ for i in range(cols // rows):
398
+ block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows])
399
+ col_maxes.append(jnp.max(block, axis=1))
400
+ row_maxes.append(jnp.max(block, axis=0))
401
+
402
+ # global row max from block maxes.
403
+ num_blocks = num_blocks_from_total_blocks(cols // rows)
404
+ maxes = []
405
+ for i in range(num_blocks):
406
+ maxes.append(
407
+ jnp.concatenate(
408
+ row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
409
+ + [
410
+ col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
411
+ for j in range(i + 1, num_blocks)
412
+ ],
413
+ axis=-1,
414
+ )
415
+ )
416
+
417
+ return jnp.max(jnp.stack(maxes), axis=0)
418
+
419
+
420
+ def times_vector(mat, vec):
421
+ """Returns the symmetric block-concatenated matrix multiplied by a vector.
422
+
423
+ Specifically, each value in the vector is multiplied by a row of the full
424
+ matrix. That is, the vector is broadcast and multiplied element-wise. Note
425
+ this would be the transpose of full_mat * vec if full_mat represented the full
426
+ symmetric matrix.
427
+
428
+ Args:
429
+ mat: The symmetric matrix represented as the concatenated blocks.
430
+ vec: The vector, having the same dimension as the materialized matrix.
431
+ """
432
+ rows, cols = mat.shape
433
+ num_blocks = num_blocks_from_total_blocks(cols // rows)
434
+ multiplied = []
435
+ for i in range(num_blocks):
436
+ mat_block = mat[
437
+ Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2
438
+ ]
439
+ vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
440
+ multiplied.append(jnp.einsum("...ij,...i->ij", mat_block, vec_block))
441
+ return jnp.concatenate(multiplied, axis=-1)
tools/train/sweep.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ program: train.py
2
+ project: dalle-mini
3
+ method: random
4
+ metric:
5
+ name: eval/loss
6
+ goal: minimize
7
+ parameters:
8
+ optim:
9
+ value: distributed_shampoo
10
+ learning_rate:
11
+ distribution: log_uniform
12
+ # from exp(min) to exp(max)
13
+ min: -9.2
14
+ max: -6.9
15
+ tokenizer_name:
16
+ value: boris/dalle-mini-tokenizer
17
+ config_name:
18
+ value: ./config/mini
19
+ dtype:
20
+ value: bfloat16
21
+ dataset_repo_or_path:
22
+ value: ./data
23
+ per_device_train_batch_size:
24
+ value: 64
25
+ per_device_eval_batch_size:
26
+ value: 64
27
+ gradient_accumulation_steps:
28
+ value: 1
29
+ warmup_steps:
30
+ value: 1000
31
+ num_train_epochs:
32
+ value: 1
33
+ max_train_samples:
34
+ value: 1000000
35
+ logging_steps:
36
+ value: 40
37
+ eval_steps:
38
+ value: 200
39
+
40
+ command:
41
+ - python3
42
+ - ${program}
43
+ - "--streaming"
44
+ - "--output_dir"
45
+ - "./output"
46
+ - "--overwrite_output_dir"
47
+ - "--do_train"
48
+ - "--do_eval"
49
+ - ${args}
tools/train/train.py ADDED
@@ -0,0 +1,1740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training DALL·E Mini.
18
+ Script adapted from run_summarization_flax.py
19
+ """
20
+
21
+ import io
22
+ import logging
23
+ import os
24
+ import sys
25
+ import tempfile
26
+ import time
27
+ from dataclasses import asdict, dataclass, field
28
+ from functools import partial
29
+ from pathlib import Path
30
+ from typing import Any, Callable, NamedTuple, Optional
31
+
32
+ import datasets
33
+ import flax
34
+ import jax
35
+ import jax.numpy as jnp
36
+ import jaxlib
37
+ import numpy as np
38
+ import optax
39
+ import transformers
40
+ import wandb
41
+ from datasets import Dataset
42
+ from flax import core, struct, traverse_util
43
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
44
+ from flax.serialization import from_bytes, to_bytes
45
+ from flax.training.common_utils import onehot
46
+ from jax.experimental import PartitionSpec, maps
47
+ from jax.experimental.compilation_cache import compilation_cache as cc
48
+ from jax.experimental.pjit import pjit, with_sharding_constraint
49
+ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
50
+ from tqdm import tqdm
51
+ from transformers import HfArgumentParser
52
+
53
+ import dalle_mini
54
+ from dalle_mini.data import Dataset
55
+ from dalle_mini.model import (
56
+ DalleBart,
57
+ DalleBartConfig,
58
+ DalleBartTokenizer,
59
+ set_partitions,
60
+ )
61
+
62
+ try:
63
+ from google.cloud import storage
64
+ except:
65
+ storage = None
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+ cc.initialize_cache("jax_cache")
70
+
71
+
72
+ @dataclass
73
+ class ModelArguments:
74
+ """
75
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
76
+ """
77
+
78
+ model_name_or_path: Optional[str] = field(
79
+ default=None,
80
+ metadata={
81
+ "help": "The model checkpoint for weights initialization. "
82
+ "Don't set if you want to train a model from scratch. "
83
+ "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
84
+ },
85
+ )
86
+ config_name: Optional[str] = field(
87
+ default=None,
88
+ metadata={
89
+ "help": "Pretrained config name or path if not the same as model_name_or_path"
90
+ },
91
+ )
92
+ tokenizer_name: Optional[str] = field(
93
+ default=None,
94
+ metadata={
95
+ "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
96
+ },
97
+ )
98
+ dtype: Optional[str] = field(
99
+ default="float32",
100
+ metadata={
101
+ "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
102
+ },
103
+ )
104
+ restore_state: Optional[bool] = field(
105
+ default=False,
106
+ metadata={
107
+ "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
108
+ },
109
+ )
110
+ dropout: Optional[float] = field(
111
+ default=None,
112
+ metadata={"help": "Dropout rate. Overwrites config."},
113
+ )
114
+ activation_dropout: Optional[float] = field(
115
+ default=None,
116
+ metadata={"help": "Activation dropout rate. Overwrites config."},
117
+ )
118
+ attention_dropout: Optional[float] = field(
119
+ default=None,
120
+ metadata={"help": "Attention dropout rate. Overwrites config."},
121
+ )
122
+
123
+ def __post_init__(self):
124
+ if self.tokenizer_name is None:
125
+ self.tokenizer_name = self.model_name_or_path
126
+ assert (
127
+ self.tokenizer_name is not None
128
+ ), "Tokenizer name or model name/path needs to be specified"
129
+ if self.restore_state:
130
+ assert self.model_name_or_path is not None and (
131
+ "/model-" in self.model_name_or_path
132
+ ), "Restoring state only available with W&B artifact reference"
133
+
134
+ def get_metadata(self):
135
+ if self.model_name_or_path is not None and ":" in self.model_name_or_path:
136
+ if jax.process_index() == 0:
137
+ artifact = wandb.run.use_artifact(self.model_name_or_path)
138
+ else:
139
+ artifact = wandb.Api().artifact(self.model_name_or_path)
140
+ return artifact.metadata
141
+ else:
142
+ return dict()
143
+
144
+ def get_opt_state(self):
145
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
146
+ if self.restore_state is True:
147
+ # wandb artifact
148
+ state_artifact = self.model_name_or_path.replace(
149
+ "/model-", "/state-", 1
150
+ )
151
+ if jax.process_index() == 0:
152
+ artifact = wandb.run.use_artifact(state_artifact)
153
+ else:
154
+ artifact = wandb.Api().artifact(state_artifact)
155
+ if artifact.metadata.get("bucket_path"):
156
+ # we will read directly file contents
157
+ self.restore_state = artifact.metadata["bucket_path"]
158
+ else:
159
+ artifact_dir = artifact.download(tmp_dir)
160
+ self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
161
+
162
+ if self.restore_state.startswith("gs://"):
163
+ bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
164
+ bucket, blob_name = str(bucket_path).split("/", 1)
165
+ assert (
166
+ storage is not None
167
+ ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
168
+ client = storage.Client()
169
+ bucket = client.bucket(bucket)
170
+ blob = bucket.blob(blob_name)
171
+ return blob.download_as_bytes()
172
+
173
+ with Path(self.restore_state).open("rb") as f:
174
+ return f.read()
175
+
176
+
177
+ @dataclass
178
+ class DataTrainingArguments:
179
+ """
180
+ Arguments pertaining to what data we are going to input our model for training and eval.
181
+ """
182
+
183
+ text_column: Optional[str] = field(
184
+ default="caption",
185
+ metadata={
186
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."
187
+ },
188
+ )
189
+ encoding_column: Optional[str] = field(
190
+ default="encoding",
191
+ metadata={
192
+ "help": "The name of the column in the datasets containing the image encodings."
193
+ },
194
+ )
195
+ dataset_repo_or_path: str = field(
196
+ default=None,
197
+ metadata={"help": "The dataset repository containing encoded files."},
198
+ )
199
+ train_file: Optional[str] = field(
200
+ default=None,
201
+ metadata={
202
+ "help": "The input training data file (glob & braceexpand acceptable)."
203
+ },
204
+ )
205
+ validation_file: Optional[str] = field(
206
+ default=None,
207
+ metadata={
208
+ "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
209
+ },
210
+ )
211
+ # data loading should not be a bottleneck so we use "streaming" mode by default
212
+ streaming: Optional[bool] = field(
213
+ default=True,
214
+ metadata={"help": "Whether to stream the dataset."},
215
+ )
216
+ use_auth_token: Optional[bool] = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to use the authentication token for private datasets."
220
+ },
221
+ )
222
+ shard_by_host: Optional[bool] = field(
223
+ default=False,
224
+ metadata={
225
+ "help": "Whether to shard data files by host in multi-host environments."
226
+ },
227
+ )
228
+ blank_caption_prob: Optional[float] = field(
229
+ default=0.0,
230
+ metadata={
231
+ "help": "Probability of removing some captions for classifier-free guidance."
232
+ },
233
+ )
234
+ clip_score_column: Optional[str] = field(
235
+ default="clip_score",
236
+ metadata={"help": "Column that containts clip score for filtering."},
237
+ )
238
+ min_clip_score: Optional[float] = field(
239
+ default=None,
240
+ metadata={"help": "Minimum clip score required."},
241
+ )
242
+ max_clip_score: Optional[float] = field(
243
+ default=None,
244
+ metadata={"help": "Maximum clip score required."},
245
+ )
246
+ filter_column: Optional[str] = field(
247
+ default=None,
248
+ metadata={"help": "Column that containts classes to be filtered."},
249
+ )
250
+ filter_value: Optional[str] = field(
251
+ default=None,
252
+ metadata={"help": "Class value to be kept during filtering."},
253
+ )
254
+ multi_eval_ds: Optional[bool] = field(
255
+ default=False,
256
+ metadata={
257
+ "help": "Whether to look for multiple validation datasets (local support only)."
258
+ },
259
+ )
260
+ max_train_samples: Optional[int] = field(
261
+ default=None,
262
+ metadata={
263
+ "help": "For debugging purposes or quicker training, truncate the number of training examples."
264
+ },
265
+ )
266
+ max_eval_samples: Optional[int] = field(
267
+ default=None,
268
+ metadata={
269
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
270
+ },
271
+ )
272
+ preprocessing_num_workers: Optional[int] = field(
273
+ default=None,
274
+ metadata={
275
+ "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
276
+ },
277
+ )
278
+ overwrite_cache: bool = field(
279
+ default=False,
280
+ metadata={
281
+ "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
282
+ },
283
+ )
284
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
285
+ seed_dataset: int = field(
286
+ default=None,
287
+ metadata={
288
+ "help": "Random seed for the dataset that will be set at the beginning of training."
289
+ },
290
+ )
291
+
292
+ def __post_init__(self):
293
+ if self.dataset_repo_or_path is None:
294
+ raise ValueError("Need a dataset repository or path.")
295
+
296
+
297
+ @dataclass
298
+ class TrainingArguments:
299
+ """
300
+ Arguments pertaining to training parameters.
301
+ """
302
+
303
+ output_dir: str = field(
304
+ metadata={
305
+ "help": "The output directory where the model predictions and checkpoints will be written."
306
+ },
307
+ )
308
+ overwrite_output_dir: bool = field(
309
+ default=False,
310
+ metadata={
311
+ "help": (
312
+ "Overwrite the content of the output directory. "
313
+ "Use this to continue training if output_dir points to a checkpoint directory."
314
+ )
315
+ },
316
+ )
317
+
318
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
319
+ do_eval: bool = field(
320
+ default=False, metadata={"help": "Whether to run eval on the validation set."}
321
+ )
322
+
323
+ per_device_train_batch_size: int = field(
324
+ default=8,
325
+ metadata={"help": "Batch size per data parallel device for training."},
326
+ )
327
+ per_device_eval_batch_size: Optional[int] = field(
328
+ default=None,
329
+ metadata={
330
+ "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
331
+ },
332
+ )
333
+
334
+ gradient_accumulation_steps: int = field(
335
+ default=1,
336
+ metadata={
337
+ "help": "Number of updates steps to accumulate before performing an update pass."
338
+ },
339
+ )
340
+ gradient_checkpointing: bool = field(
341
+ default=False, metadata={"help": "Use gradient checkpointing."}
342
+ )
343
+
344
+ learning_rate: float = field(
345
+ default=5e-5, metadata={"help": "The initial learning rate."}
346
+ )
347
+ optim: str = field(
348
+ default="distributed_shampoo",
349
+ metadata={
350
+ "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
351
+ },
352
+ )
353
+ weight_decay: float = field(
354
+ default=0.0, metadata={"help": "Weight decay applied to parameters."}
355
+ )
356
+ beta1: float = field(
357
+ default=0.9,
358
+ metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
359
+ )
360
+ beta2: float = field(
361
+ default=0.999,
362
+ metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
363
+ )
364
+ adam_epsilon: float = field(
365
+ default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
366
+ )
367
+ max_grad_norm: float = field(
368
+ default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
369
+ )
370
+ block_size: int = field(
371
+ default=1024,
372
+ metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
373
+ )
374
+ preconditioning_compute_steps: int = field(
375
+ default=10, metadata={"help": "Number of steps to update preconditioner."}
376
+ )
377
+ skip_preconditioning_dim_size_gt: int = field(
378
+ default=4096,
379
+ metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
380
+ )
381
+ graft_type: str = field(
382
+ default="rmsprop_normalized",
383
+ metadata={
384
+ "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
385
+ },
386
+ )
387
+ nesterov: bool = field(
388
+ default=False,
389
+ metadata={"help": "Use Nesterov momentum for Distributed Shampoo."},
390
+ )
391
+ optim_quantized: bool = field(
392
+ default=False,
393
+ metadata={
394
+ "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
395
+ },
396
+ )
397
+ shard_shampoo_across: str = field(
398
+ default="dp",
399
+ metadata={
400
+ "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
401
+ },
402
+ )
403
+
404
+ num_train_epochs: int = field(
405
+ default=3, metadata={"help": "Total number of training epochs to perform."}
406
+ )
407
+
408
+ warmup_steps: int = field(
409
+ default=0, metadata={"help": "Linear warmup over warmup_steps."}
410
+ )
411
+ lr_decay: str = field(
412
+ default=None,
413
+ metadata={
414
+ "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
415
+ },
416
+ )
417
+ lr_transition_steps: int = field(
418
+ default=None,
419
+ metadata={
420
+ "help": "Number of transition steps associated with learning rate decay when using exponential decay."
421
+ },
422
+ )
423
+ lr_decay_rate: float = field(
424
+ default=None,
425
+ metadata={
426
+ "help": "Decay rate associated with learning rate when using exponential decay."
427
+ },
428
+ )
429
+ lr_staircase: bool = field(
430
+ default=False,
431
+ metadata={
432
+ "help": "Whether to use staircase or continuous learning rate when using exponential decay."
433
+ },
434
+ )
435
+ lr_offset: int = field(
436
+ default=0,
437
+ metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
438
+ )
439
+ logging_steps: int = field(
440
+ default=40, metadata={"help": "Log every X updates steps."}
441
+ )
442
+ eval_steps: int = field(
443
+ default=400, metadata={"help": "Run an evaluation every X steps."}
444
+ )
445
+ save_steps: int = field(
446
+ default=4000, metadata={"help": "Save checkpoint every X updates steps."}
447
+ )
448
+ log_model: bool = field(
449
+ default=False,
450
+ metadata={"help": "Log model to wandb at `save_steps` frequency."},
451
+ )
452
+ log_norm_steps: int = field(
453
+ default=True,
454
+ metadata={"help": "Log parameters and gradients norm at this frequency."},
455
+ )
456
+ log_histogram_steps: int = field(
457
+ default=False,
458
+ metadata={
459
+ "help": "Log parameters and gradients histograms at this frequency. Slows down training."
460
+ },
461
+ )
462
+
463
+ seed_model: int = field(
464
+ default=42,
465
+ metadata={
466
+ "help": "Random seed for the model that will be set at the beginning of training."
467
+ },
468
+ )
469
+
470
+ embeddings_only: bool = field(
471
+ default=False, metadata={"help": "Train only embedding layers."}
472
+ )
473
+ init_embeddings: bool = field(
474
+ default=False,
475
+ metadata={"help": "When training embedding layers, initialize them."},
476
+ )
477
+
478
+ wandb_entity: Optional[str] = field(
479
+ default=None,
480
+ metadata={"help": "The wandb entity to use (for teams)."},
481
+ )
482
+ wandb_project: str = field(
483
+ default="dalle-mini",
484
+ metadata={"help": "The name of the wandb project."},
485
+ )
486
+ wandb_job_type: str = field(
487
+ default="Seq2Seq",
488
+ metadata={"help": "The name of the wandb job type."},
489
+ )
490
+
491
+ assert_TPU_available: bool = field(
492
+ default=False,
493
+ metadata={"help": "Verify that TPU is not in use."},
494
+ )
495
+
496
+ use_vmap_trick: bool = field(
497
+ default=True,
498
+ metadata={"help": "Verify that TPU is not in use."},
499
+ )
500
+
501
+ mp_devices: Optional[int] = field(
502
+ default=1,
503
+ metadata={
504
+ "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
505
+ },
506
+ )
507
+
508
+ dp_devices: int = field(init=False)
509
+
510
+ def __post_init__(self):
511
+ if self.assert_TPU_available:
512
+ assert (
513
+ jax.local_device_count() == 8
514
+ ), "TPUs in use, please check running processes"
515
+ if self.output_dir.startswith("gs://"):
516
+ assert (
517
+ storage is not None
518
+ ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
519
+ assert self.optim in [
520
+ "distributed_shampoo",
521
+ "adam",
522
+ "adafactor",
523
+ ], f"Selected optimizer not supported: {self.optim}"
524
+ if self.optim == "adafactor" and self.weight_decay == 0:
525
+ self.weight_decay = None
526
+ assert self.graft_type in [
527
+ "rmsprop_normalized",
528
+ "rmsprop",
529
+ "adagrad",
530
+ "adagrad_normalized",
531
+ "sgd",
532
+ "sqrt_n",
533
+ ], f"Selected graft type not supported: {self.graft_type}"
534
+ assert self.lr_decay in [
535
+ None,
536
+ "linear",
537
+ "exponential",
538
+ ], f"Selected learning rate decay not supported: {self.lr_decay}"
539
+ if self.per_device_eval_batch_size is None:
540
+ self.per_device_eval_batch_size = self.per_device_train_batch_size
541
+ if self.log_norm_steps is True:
542
+ self.log_norm_steps = self.logging_steps
543
+ if not self.do_train:
544
+ self.num_train_epochs = 1
545
+ if (
546
+ os.path.exists(self.output_dir)
547
+ and os.listdir(self.output_dir)
548
+ and self.do_train
549
+ and not self.overwrite_output_dir
550
+ ):
551
+ raise ValueError(
552
+ f"Output directory ({self.output_dir}) already exists and is not empty."
553
+ "Use --overwrite_output_dir to overcome."
554
+ )
555
+ assert self.shard_shampoo_across in [
556
+ "dp",
557
+ "mp",
558
+ "2d",
559
+ ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
560
+ assert (
561
+ self.mp_devices > 0
562
+ ), f"Number of devices for model parallelism must be > 0"
563
+ assert (
564
+ jax.device_count() % self.mp_devices == 0
565
+ ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
566
+ self.dp_devices = jax.device_count() // self.mp_devices
567
+
568
+
569
+ def split_params(data):
570
+ """Split params between scanned and non-scanned"""
571
+ flat = traverse_util.flatten_dict(unfreeze(data))
572
+ split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
573
+ for k, v in flat.items():
574
+ if "FlaxBartEncoderLayers" in k:
575
+ split["scanned_encoder"][k] = v
576
+ elif "FlaxBartDecoderLayers" in k:
577
+ split["scanned_decoder"][k] = v
578
+ else:
579
+ split["standard"][k] = v
580
+ # remove empty keys
581
+ split = {k: v for k, v in split.items() if v}
582
+ for k, v in split.items():
583
+ split[k] = freeze(traverse_util.unflatten_dict(v))
584
+ return split
585
+
586
+
587
+ def unsplit_params(data):
588
+ flat = {}
589
+ for k in ["standard", "scanned_encoder", "scanned_decoder"]:
590
+ if k in data:
591
+ flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
592
+ return freeze(traverse_util.unflatten_dict(flat))
593
+
594
+
595
+ def trainable_params(data, embeddings_only):
596
+ """Keep only trainable parameters"""
597
+
598
+ if not embeddings_only:
599
+ return data
600
+
601
+ data = unfreeze(data)
602
+ trainable = {
603
+ "lm_head": data["lm_head"],
604
+ "model": {
605
+ "decoder": {
606
+ layer: data["model"]["decoder"][layer]
607
+ for layer in [
608
+ "embed_positions",
609
+ "embed_tokens",
610
+ "final_ln",
611
+ "layernorm_embedding",
612
+ ]
613
+ }
614
+ },
615
+ }
616
+ return freeze(trainable)
617
+
618
+
619
+ def init_embeddings(model, params):
620
+ """Reinitialize trainable embeddings"""
621
+ # Must match params in trainable_params() above
622
+ trainable_keypaths = [
623
+ "lm_head.kernel",
624
+ "model.decoder.embed_positions.embedding",
625
+ "model.decoder.embed_tokens.embedding",
626
+ "model.decoder.final_ln.bias",
627
+ "model.decoder.layernorm_embedding.bias",
628
+ "model.decoder.layernorm_embedding.scale",
629
+ ]
630
+
631
+ # Note: using private _missing_keys
632
+ init_keys = {tuple(k.split(".")) for k in trainable_keypaths}
633
+ model._missing_keys = init_keys
634
+ return model.init_weights(model.key, model.input_shape, params=params)
635
+
636
+
637
+ def main():
638
+ # See all possible arguments by passing the --help flag to this script.
639
+ parser = HfArgumentParser(
640
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
641
+ )
642
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
643
+ # If we pass only one argument to the script and it's the path to a json file,
644
+ # let's parse it to get our arguments.
645
+ model_args, data_args, training_args = parser.parse_json_file(
646
+ json_file=os.path.abspath(sys.argv[1])
647
+ )
648
+ else:
649
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
650
+
651
+ # check arguments
652
+ if training_args.mp_devices > jax.local_device_count():
653
+ assert (
654
+ data_args.seed_dataset is not None
655
+ ), "Seed dataset must be provided when model is split over multiple hosts"
656
+
657
+ # Make one log on every process with the configuration for debugging.
658
+ logging.basicConfig(
659
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
660
+ datefmt="%m/%d/%Y %H:%M:%S",
661
+ level=logging.INFO,
662
+ )
663
+ # Setup logging, we only want one process per machine to log things on the screen.
664
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
665
+ if jax.process_index() == 0:
666
+ datasets.utils.logging.set_verbosity_warning()
667
+ transformers.utils.logging.set_verbosity_info()
668
+ else:
669
+ datasets.utils.logging.set_verbosity_error()
670
+ transformers.utils.logging.set_verbosity_error()
671
+
672
+ # Set the verbosity to info of the Transformers logger (on main process only):
673
+ logger.info(f"Training/evaluation parameters {training_args}")
674
+
675
+ # Load dataset
676
+ dataset = Dataset(
677
+ **asdict(data_args),
678
+ do_train=training_args.do_train,
679
+ do_eval=training_args.do_eval,
680
+ )
681
+
682
+ logger.info(f"Local TPUs: {jax.local_device_count()}")
683
+ logger.info(f"Global TPUs: {jax.device_count()}")
684
+
685
+ # Set up wandb run
686
+ if jax.process_index() == 0:
687
+ wandb.init(
688
+ entity=training_args.wandb_entity,
689
+ project=training_args.wandb_project,
690
+ job_type=training_args.wandb_job_type,
691
+ config=parser.parse_args(),
692
+ )
693
+
694
+ # Set up our new model config
695
+ config_args = {
696
+ k: getattr(model_args, k)
697
+ for k in ["dropout", "activation_dropout", "attention_dropout"]
698
+ if getattr(model_args, k) is not None
699
+ }
700
+ config_args["gradient_checkpointing"] = training_args.gradient_checkpointing
701
+ if model_args.config_name:
702
+ config = DalleBartConfig.from_pretrained(model_args.config_name)
703
+ else:
704
+ config = None
705
+
706
+ # Load or create new model
707
+ if model_args.model_name_or_path:
708
+ model, params = DalleBart.from_pretrained(
709
+ model_args.model_name_or_path,
710
+ config=config,
711
+ seed=training_args.seed_model,
712
+ dtype=getattr(jnp, model_args.dtype),
713
+ _do_init=False,
714
+ )
715
+ if training_args.embeddings_only and training_args.init_embeddings:
716
+ params = init_embeddings(model, params)
717
+ else:
718
+ model = DalleBart(
719
+ config,
720
+ seed=training_args.seed_model,
721
+ dtype=getattr(jnp, model_args.dtype),
722
+ _do_init=False,
723
+ )
724
+ params = None
725
+ for k, v in config_args.items():
726
+ setattr(model.config, k, v)
727
+ params_shape = model.params_shape_tree
728
+
729
+ # get model metadata
730
+ model_metadata = model_args.get_metadata()
731
+
732
+ # get PartitionSpec for model params (required to be a dict)
733
+ param_spec = set_partitions(params_shape, model.config.use_scan)
734
+ params_shape = freeze(params_shape)
735
+ if params is not None:
736
+ params = freeze(params)
737
+
738
+ # Load tokenizer
739
+ tokenizer = DalleBartTokenizer.from_pretrained(
740
+ model_args.tokenizer_name, use_fast=True
741
+ )
742
+
743
+ # Preprocessing the datasets.
744
+ # We need to normalize and tokenize inputs and targets.
745
+ dataset.preprocess(tokenizer=tokenizer, config=model.config)
746
+
747
+ # Initialize our training
748
+ dropout_rng = jax.random.PRNGKey(training_args.seed_model)
749
+
750
+ # Store some constant
751
+ num_epochs = training_args.num_train_epochs
752
+ # batch size
753
+ batch_size_per_node_per_grad_step = (
754
+ training_args.per_device_train_batch_size
755
+ * jax.local_device_count()
756
+ // training_args.mp_devices
757
+ )
758
+ batch_size_per_node = (
759
+ batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
760
+ )
761
+ batch_size_per_step = batch_size_per_node * jax.process_count()
762
+ eval_batch_size_per_node = (
763
+ training_args.per_device_eval_batch_size
764
+ * jax.local_device_count()
765
+ // training_args.mp_devices
766
+ )
767
+ eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
768
+ len_train_dataset, len_eval_dataset = dataset.length
769
+ steps_per_epoch = (
770
+ len_train_dataset // batch_size_per_node
771
+ if len_train_dataset is not None
772
+ else None
773
+ )
774
+ num_train_steps = (
775
+ steps_per_epoch * num_epochs if steps_per_epoch is not None else None
776
+ )
777
+ num_params = model.num_params(params_shape)
778
+
779
+ logger.info("***** Running training *****")
780
+ logger.info(f" Num examples = {len_train_dataset}")
781
+ logger.info(f" Num Epochs = {num_epochs}")
782
+ logger.info(
783
+ f" Batch size per dp device = {training_args.per_device_train_batch_size}"
784
+ )
785
+ logger.info(f" Number of devices = {jax.device_count()}")
786
+ logger.info(
787
+ f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
788
+ )
789
+ logger.info(f" Batch size per update = {batch_size_per_step}")
790
+ logger.info(f" Model parameters = {num_params:,}")
791
+
792
+ # set up wandb run
793
+ if jax.process_index() == 0:
794
+ # set default x-axis as 'train/step'
795
+ wandb.define_metric("*", step_metric="train/step")
796
+
797
+ # add interesting config parameters
798
+ wandb.config.update(
799
+ {
800
+ "len_train_dataset": len_train_dataset,
801
+ "len_eval_dataset": len_eval_dataset,
802
+ "batch_size_per_step": batch_size_per_step,
803
+ "num_params": num_params,
804
+ "model_config": model.config.to_dict(),
805
+ "num_devices": jax.device_count(),
806
+ "versions": {
807
+ "jax": jax.__version__,
808
+ "jaxlib": jaxlib.__version__,
809
+ "flax": flax.__version__,
810
+ "transformers": transformers.__version__,
811
+ "datasets": datasets.__version__,
812
+ "wandb": wandb.__version__,
813
+ "dalle_mini": dalle_mini.__version__,
814
+ },
815
+ }
816
+ )
817
+
818
+ # Create learning rate schedule
819
+ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
820
+ """Create the learning rate function."""
821
+ warmup_fn = optax.linear_schedule(
822
+ init_value=0.0,
823
+ end_value=training_args.learning_rate,
824
+ transition_steps=training_args.warmup_steps + 1, # ensure not 0
825
+ )
826
+ last_boundary = training_args.warmup_steps
827
+ # offset step when resuming
828
+ if training_args.lr_offset:
829
+ warmup_fn = optax.join_schedules(
830
+ schedules=[optax.constant_schedule(0.0), warmup_fn],
831
+ boundaries=[training_args.lr_offset],
832
+ )
833
+ last_boundary += training_args.lr_offset
834
+ if training_args.lr_decay is None:
835
+ return warmup_fn
836
+ elif training_args.lr_decay == "linear":
837
+ assert (
838
+ num_train_steps is not None
839
+ ), "linear decay requires knowing the dataset length"
840
+ decay_fn = optax.linear_schedule(
841
+ init_value=training_args.learning_rate,
842
+ end_value=0,
843
+ transition_steps=num_train_steps - training_args.warmup_steps,
844
+ )
845
+ elif training_args.lr_decay == "exponential":
846
+ decay_fn = optax.exponential_decay(
847
+ init_value=training_args.learning_rate,
848
+ transition_steps=training_args.lr_transition_steps,
849
+ decay_rate=training_args.lr_decay_rate,
850
+ staircase=training_args.lr_staircase,
851
+ )
852
+ schedule_fn = optax.join_schedules(
853
+ schedules=[warmup_fn, decay_fn],
854
+ boundaries=[last_boundary],
855
+ )
856
+ return schedule_fn
857
+
858
+ learning_rate_fn = create_learning_rate_fn()
859
+
860
+ # create optimizer
861
+ trainable_params_shape = trainable_params(
862
+ params_shape, training_args.embeddings_only
863
+ )
864
+ if training_args.optim == "distributed_shampoo":
865
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
866
+ graft_type = {
867
+ "sgd": GraftingType.SGD,
868
+ "adagrad": GraftingType.ADAGRAD,
869
+ "rmsprop": GraftingType.RMSPROP,
870
+ "rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
871
+ "sqrt_n": GraftingType.SQRT_N,
872
+ "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
873
+ }[training_args.graft_type]
874
+ statistics_partition_spec = (
875
+ PartitionSpec(None, training_args.shard_shampoo_across, None)
876
+ if training_args.shard_shampoo_across != "2d"
877
+ else PartitionSpec(None, "dp", "mp")
878
+ )
879
+ opt = distributed_shampoo(
880
+ learning_rate_fn,
881
+ block_size=training_args.block_size,
882
+ beta1=training_args.beta1,
883
+ beta2=training_args.beta2,
884
+ diagonal_epsilon=1e-10,
885
+ matrix_epsilon=1e-6,
886
+ weight_decay=training_args.weight_decay,
887
+ start_preconditioning_step=max(
888
+ training_args.preconditioning_compute_steps + 1, 101
889
+ ),
890
+ preconditioning_compute_steps=training_args.preconditioning_compute_steps,
891
+ statistics_compute_steps=1,
892
+ best_effort_shape_interpretation=True,
893
+ graft_type=graft_type,
894
+ nesterov=training_args.nesterov,
895
+ exponent_override=0,
896
+ statistics_partition_spec=statistics_partition_spec,
897
+ preconditioner_partition_spec=PartitionSpec(
898
+ training_args.shard_shampoo_across, None, None
899
+ )
900
+ if training_args.shard_shampoo_across != "2d"
901
+ else PartitionSpec(
902
+ "mp" if training_args.mp_devices > training_args.dp_devices else "dp",
903
+ None,
904
+ None,
905
+ ),
906
+ num_devices_for_pjit=training_args.dp_devices,
907
+ shard_optimizer_states=True,
908
+ inverse_failure_threshold=0.1,
909
+ moving_average_for_momentum=True,
910
+ skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
911
+ clip_by_scaled_gradient_norm=None,
912
+ precision=jax.lax.Precision.HIGHEST,
913
+ best_effort_memory_usage_reduction=training_args.optim_quantized,
914
+ )
915
+ # get the real optimizer and helper functions
916
+ update_fn = opt.update
917
+
918
+ optimizer = {}
919
+ opt_fn = {}
920
+ for k, p in split_params(trainable_params_shape).items():
921
+ if "scanned" in k:
922
+ p = jax.eval_shape(
923
+ lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p
924
+ )
925
+ optimizer[k] = opt.init(p)
926
+ opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
927
+ optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn
928
+ )
929
+ optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn)
930
+
931
+ elif training_args.optim == "adam":
932
+ optimizer = optax.adamw(
933
+ learning_rate=learning_rate_fn,
934
+ b1=training_args.beta1,
935
+ b2=training_args.beta2,
936
+ eps=training_args.adam_epsilon,
937
+ weight_decay=training_args.weight_decay,
938
+ )
939
+ optimizer = {k: optimizer for k in split_params(trainable_params_shape)}
940
+
941
+ elif training_args.optim == "adafactor":
942
+ # We use the default parameters here to initialize adafactor,
943
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
944
+ optimizer = optax.adafactor(
945
+ learning_rate=learning_rate_fn,
946
+ clipping_threshold=training_args.max_grad_norm,
947
+ weight_decay_rate=training_args.weight_decay,
948
+ )
949
+ optimizer = {k: optimizer for k in split_params(trainable_params_shape)}
950
+
951
+ # get PartitionSpec for optimizer state
952
+ def get_opt_state_spec_and_shape():
953
+ # get opt_state shape without actual init
954
+ opt_state_shape = {}
955
+ for k, p in split_params(trainable_params_shape).items():
956
+ if "scanned" not in k:
957
+ opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
958
+ else:
959
+ opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)
960
+
961
+ if training_args.optim == "adafactor":
962
+ # factorized state must be replicated (rank different than params)
963
+ opt_state_spec = {k: None for k in split_params(trainable_params_shape)}
964
+
965
+ elif training_args.optim in ["adam", "distributed_shampoo"]:
966
+
967
+ def _opt_state_spec_per_leaf(x, spec):
968
+ if isinstance(x, FrozenDict):
969
+ # variables with same structure as params
970
+ return spec
971
+ else:
972
+ # other variables such as count
973
+ return None
974
+
975
+ split_spec = split_params(set_partitions(trainable_params_shape, False))
976
+ opt_state_spec = {}
977
+ for k, p in split_params(trainable_params_shape).items():
978
+ if "scanned" in k:
979
+ p = jax.eval_shape(
980
+ lambda x: jax.tree_util.tree_map(lambda y: y[0], x), p
981
+ )
982
+ if training_args.optim == "adam":
983
+ opt_state_spec[k] = jax.tree_util.tree_map(
984
+ partial(_opt_state_spec_per_leaf, spec=split_spec[k]),
985
+ opt_state_shape[k],
986
+ # return None spec for empty elements
987
+ is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
988
+ )
989
+ elif training_args.optim == "distributed_shampoo":
990
+ opt_state_spec[k] = opt_fn[k].pspec_fn(
991
+ p,
992
+ split_spec[k],
993
+ statistics_partition_spec,
994
+ )
995
+ # add dimension for scanned params
996
+ if "scanned" in k:
997
+ opt_state_spec[k] = jax.tree_util.tree_map(
998
+ lambda x: PartitionSpec(*(None,) + x)
999
+ if x is not None
1000
+ else None,
1001
+ opt_state_spec[k],
1002
+ is_leaf=lambda x: isinstance(x, PartitionSpec),
1003
+ )
1004
+
1005
+ else:
1006
+ raise NotImplementedError
1007
+ return freeze(opt_state_spec), freeze(opt_state_shape)
1008
+
1009
+ opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape()
1010
+
1011
+ # create a mesh
1012
+ mesh_shape = (training_args.dp_devices, training_args.mp_devices)
1013
+ devices = np.asarray(jax.devices()).reshape(*mesh_shape)
1014
+ mesh = maps.Mesh(devices, ("dp", "mp"))
1015
+ logger.info(f" Mesh shape: {mesh_shape}")
1016
+
1017
+ # define TrainState
1018
+ class TrainState(struct.PyTreeNode):
1019
+ step: int
1020
+ params: core.FrozenDict[str, Any]
1021
+ opt_state: optax.OptState
1022
+ apply_fn: Callable = struct.field(pytree_node=False)
1023
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
1024
+ dropout_rng: jnp.ndarray = None
1025
+ epoch: int = 0
1026
+ train_time: float = 0.0 # total time the model trained
1027
+ train_samples: int = 0 # number of samples seen
1028
+
1029
+ def apply_gradients(self, *, grads, **kwargs):
1030
+ grads = split_params(trainable_params(grads, training_args.embeddings_only))
1031
+ params = split_params(
1032
+ trainable_params(self.params, training_args.embeddings_only)
1033
+ )
1034
+ opt_state = {}
1035
+ # we loop over keys: "standard", "scanned_encoder", "scanned_decoder"
1036
+ for k, param in params.items():
1037
+ update_fn = self.tx[k].update
1038
+ if "scanned" in k:
1039
+ update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0))
1040
+ updates, new_opt_state = update_fn(grads[k], self.opt_state[k], param)
1041
+ params[k] = optax.apply_updates(param, updates)
1042
+ opt_state[k] = new_opt_state
1043
+ params = unsplit_params(params)
1044
+ # merge with non-trainable params
1045
+ params, new_params = traverse_util.flatten_dict(
1046
+ unfreeze(self.params)
1047
+ ), traverse_util.flatten_dict(unfreeze(params))
1048
+ params.update(new_params)
1049
+ params = freeze(traverse_util.unflatten_dict(params))
1050
+
1051
+ return self.replace(
1052
+ step=self.step + 1,
1053
+ params=params,
1054
+ opt_state=freeze(opt_state),
1055
+ **kwargs,
1056
+ )
1057
+
1058
+ @classmethod
1059
+ def create(cls, *, apply_fn, params, tx, **kwargs):
1060
+ opt_state = {}
1061
+ for k, p in split_params(
1062
+ trainable_params(params, training_args.embeddings_only)
1063
+ ).items():
1064
+ init_fn = tx[k].init
1065
+ if "scanned" in k:
1066
+ init_fn = jax.vmap(init_fn)
1067
+ opt_state[k] = init_fn(p)
1068
+ return cls(
1069
+ step=0,
1070
+ apply_fn=apply_fn,
1071
+ params=params,
1072
+ tx=tx,
1073
+ opt_state=freeze(opt_state),
1074
+ **kwargs,
1075
+ )
1076
+
1077
+ # define state spec
1078
+ state_spec = TrainState(
1079
+ params=param_spec,
1080
+ opt_state=opt_state_spec,
1081
+ dropout_rng=None,
1082
+ step=None,
1083
+ epoch=None,
1084
+ train_time=None,
1085
+ train_samples=None,
1086
+ apply_fn=model.__call__,
1087
+ tx=optimizer,
1088
+ )
1089
+
1090
+ # init params if not available yet
1091
+ def maybe_init_params(params):
1092
+ if params is not None:
1093
+ # model params are correctly loaded
1094
+ return params
1095
+ else:
1096
+ # params have not been initialized yet
1097
+ return model.init_weights(model.key, model.input_shape)
1098
+
1099
+ with mesh:
1100
+ logger.info(" Creating state")
1101
+
1102
+ # restore metadata
1103
+ attr_state = {}
1104
+ keys = ["train_time", "train_samples"]
1105
+ if model_args.restore_state:
1106
+ keys += ["step", "epoch"]
1107
+ attr_state = {k: v for k, v in model_metadata.items() if k in keys}
1108
+
1109
+ if not model_args.restore_state:
1110
+
1111
+ def init_state(params):
1112
+ return TrainState.create(
1113
+ apply_fn=model.__call__,
1114
+ tx=optimizer,
1115
+ params=maybe_init_params(params),
1116
+ dropout_rng=dropout_rng,
1117
+ **attr_state,
1118
+ )
1119
+
1120
+ state = pjit(
1121
+ init_state,
1122
+ in_axis_resources=(param_spec,)
1123
+ if model_args.model_name_or_path
1124
+ else None,
1125
+ out_axis_resources=state_spec,
1126
+ donate_argnums=(0,),
1127
+ )(params)
1128
+
1129
+ else:
1130
+ # load opt_state
1131
+ opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
1132
+
1133
+ def restore_state(params, opt_state):
1134
+ return TrainState(
1135
+ apply_fn=model.__call__,
1136
+ tx=optimizer,
1137
+ params=params,
1138
+ opt_state=opt_state,
1139
+ dropout_rng=dropout_rng,
1140
+ **attr_state,
1141
+ )
1142
+
1143
+ state = pjit(
1144
+ restore_state,
1145
+ in_axis_resources=(
1146
+ param_spec,
1147
+ opt_state_spec,
1148
+ ),
1149
+ out_axis_resources=state_spec,
1150
+ donate_argnums=(0, 1),
1151
+ )(params, opt_state)
1152
+
1153
+ # remove opt_state from CPU
1154
+ del opt_state
1155
+
1156
+ # free CPU memory
1157
+ del params, opt_state_spec, opt_state_shape
1158
+
1159
+ # define batch specs
1160
+ batch_spec = PartitionSpec("dp")
1161
+ grad_batch_spec = PartitionSpec(None, "dp")
1162
+
1163
+ # define loss
1164
+ def loss_fn(logits, labels):
1165
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
1166
+ loss = loss.mean()
1167
+ return loss
1168
+
1169
+ # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
1170
+ # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
1171
+ use_vmap_trick = training_args.use_vmap_trick
1172
+
1173
+ # make grad_param_spec for vmap
1174
+ if use_vmap_trick:
1175
+ grad_param_spec = jax.tree_util.tree_map(
1176
+ lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))),
1177
+ param_spec,
1178
+ )
1179
+
1180
+ # Define gradient update step fn
1181
+ def train_step(state, batch, train_time):
1182
+ # get a minibatch (one gradient accumulation slice)
1183
+ def get_minibatch(batch, grad_idx):
1184
+ return jax.tree_util.tree_map(
1185
+ lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
1186
+ batch,
1187
+ )
1188
+
1189
+ def compute_loss(params, minibatch, dropout_rng):
1190
+ # minibatch has dim (batch_size, ...)
1191
+ minibatch, labels = minibatch.pop("labels")
1192
+ logits = state.apply_fn(
1193
+ **minibatch, params=params, dropout_rng=dropout_rng, train=True
1194
+ )[0]
1195
+ return loss_fn(logits, labels)
1196
+
1197
+ grad_fn = jax.value_and_grad(compute_loss)
1198
+
1199
+ def loss_and_grad(grad_idx, dropout_rng):
1200
+ # minibatch at grad_idx for gradient accumulation (None otherwise)
1201
+ minibatch = (
1202
+ get_minibatch(batch, grad_idx) if grad_idx is not None else batch
1203
+ )
1204
+ # ensure it is sharded properly
1205
+ minibatch = with_sharding_constraint(minibatch, batch_spec)
1206
+ # only 1 single rng per grad step, let us handle larger batch size (not sure why)
1207
+ dropout_rng, _ = jax.random.split(dropout_rng)
1208
+
1209
+ if use_vmap_trick:
1210
+ # "vmap trick", calculate loss and grads independently per dp_device
1211
+ loss, grads = jax.vmap(
1212
+ grad_fn, in_axes=(None, 0, None), out_axes=(0, 0)
1213
+ )(state.params, minibatch, dropout_rng)
1214
+ # ensure they are sharded correctly
1215
+ loss = with_sharding_constraint(loss, batch_spec)
1216
+ grads = with_sharding_constraint(grads, grad_param_spec)
1217
+ # average across all devices
1218
+ # Note: we could average per device only after gradient accumulation, right before params update
1219
+ loss, grads = jax.tree_util.tree_map(
1220
+ lambda x: jnp.mean(x, axis=0), (loss, grads)
1221
+ )
1222
+ else:
1223
+ # "vmap trick" does not work in multi-hosts and requires too much hbm
1224
+ loss, grads = grad_fn(state.params, minibatch, dropout_rng)
1225
+ # ensure grads are sharded
1226
+ grads = with_sharding_constraint(grads, param_spec)
1227
+ # return loss and grads
1228
+ return loss, grads, dropout_rng
1229
+
1230
+ if training_args.gradient_accumulation_steps == 1:
1231
+ loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
1232
+ else:
1233
+ # create initial state for cumul_minibatch_step loop
1234
+ init_minibatch_step = (
1235
+ 0.0,
1236
+ with_sharding_constraint(
1237
+ jax.tree_util.tree_map(jnp.zeros_like, state.params), param_spec
1238
+ ),
1239
+ state.dropout_rng,
1240
+ )
1241
+
1242
+ # accumulate gradients
1243
+ def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
1244
+ cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout
1245
+ loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
1246
+ cumul_loss, cumul_grads = jax.tree_util.tree_map(
1247
+ jnp.add, (cumul_loss, cumul_grads), (loss, grads)
1248
+ )
1249
+ cumul_grads = with_sharding_constraint(cumul_grads, param_spec)
1250
+ return cumul_loss, cumul_grads, dropout_rng
1251
+
1252
+ # loop over gradients
1253
+ loss, grads, dropout_rng = jax.lax.fori_loop(
1254
+ 0,
1255
+ training_args.gradient_accumulation_steps,
1256
+ cumul_minibatch_step,
1257
+ init_minibatch_step,
1258
+ )
1259
+ grads = with_sharding_constraint(grads, param_spec)
1260
+ # sum -> mean
1261
+ loss, grads = jax.tree_util.tree_map(
1262
+ lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
1263
+ )
1264
+
1265
+ grads = with_sharding_constraint(grads, param_spec)
1266
+
1267
+ # update state
1268
+ state = state.apply_gradients(
1269
+ grads=grads,
1270
+ dropout_rng=dropout_rng,
1271
+ train_time=train_time,
1272
+ train_samples=state.train_samples + batch_size_per_step,
1273
+ )
1274
+
1275
+ metrics = {
1276
+ "loss": loss,
1277
+ "learning_rate": learning_rate_fn(state.step),
1278
+ }
1279
+
1280
+ def maybe_fn(fn, val, zeros, freq):
1281
+ """Call fn only if it is a logging step"""
1282
+ return jax.lax.cond(
1283
+ state.step % freq == 0,
1284
+ fn,
1285
+ lambda _: zeros,
1286
+ val,
1287
+ )
1288
+
1289
+ # log additional metrics
1290
+ params = trainable_params(state.params, training_args.embeddings_only)
1291
+ grads = trainable_params(grads, training_args.embeddings_only)
1292
+ if training_args.log_norm_steps:
1293
+ zeros_norm = jax.tree_util.tree_map(lambda _: jnp.float32(0), params)
1294
+
1295
+ def norm(val):
1296
+ return jax.tree_util.tree_map(lambda x: jnp.linalg.norm(x), val)
1297
+
1298
+ gradients_norm = maybe_fn(
1299
+ norm, grads, zeros_norm, training_args.log_norm_steps
1300
+ )
1301
+ params_norm = maybe_fn(
1302
+ norm, params, zeros_norm, training_args.log_norm_steps
1303
+ )
1304
+
1305
+ metrics.update(
1306
+ {
1307
+ "gradients_norm": gradients_norm,
1308
+ "params_norm": params_norm,
1309
+ }
1310
+ )
1311
+
1312
+ if training_args.log_histogram_steps:
1313
+ zeros_hist = jax.tree_util.tree_map(
1314
+ lambda _: jnp.histogram(jnp.zeros(1), density=True), params
1315
+ )
1316
+
1317
+ def histogram(val):
1318
+ return jax.tree_util.tree_map(
1319
+ lambda x: jnp.histogram(x, density=True), val
1320
+ )
1321
+
1322
+ gradients_hist = maybe_fn(
1323
+ histogram, grads, zeros_hist, training_args.log_histogram_steps
1324
+ )
1325
+ params_hist = maybe_fn(
1326
+ histogram, params, zeros_hist, training_args.log_histogram_steps
1327
+ )
1328
+
1329
+ metrics.update(
1330
+ {
1331
+ "params_hist": params_hist,
1332
+ "gradients_hist": gradients_hist,
1333
+ }
1334
+ )
1335
+
1336
+ return state, metrics
1337
+
1338
+ # Define eval fn
1339
+ eval_model = (
1340
+ model
1341
+ if model_args.dtype == "float32"
1342
+ else DalleBart(
1343
+ model.config,
1344
+ seed=training_args.seed_model,
1345
+ dtype=jnp.float32,
1346
+ _do_init=False,
1347
+ )
1348
+ )
1349
+
1350
+ def eval_step(state, batch):
1351
+ def compute_eval_loss(batch):
1352
+ batch, labels = batch.pop("labels")
1353
+ logits = eval_model(**batch, params=state.params, train=False)[0]
1354
+ return loss_fn(logits, labels)
1355
+
1356
+ if use_vmap_trick:
1357
+ loss = jax.vmap(compute_eval_loss)(batch)
1358
+ # ensure they are sharded correctly
1359
+ loss = with_sharding_constraint(loss, batch_spec)
1360
+ # average across all devices
1361
+ loss = jnp.mean(loss)
1362
+ else:
1363
+ loss = compute_eval_loss(batch)
1364
+
1365
+ return loss
1366
+
1367
+ # Create parallel version of the train and eval step
1368
+ p_train_step = pjit(
1369
+ train_step,
1370
+ in_axis_resources=(
1371
+ state_spec,
1372
+ grad_batch_spec
1373
+ if training_args.gradient_accumulation_steps > 1
1374
+ else batch_spec,
1375
+ None,
1376
+ ),
1377
+ out_axis_resources=(state_spec, None),
1378
+ donate_argnums=(0,),
1379
+ )
1380
+ p_eval_step = pjit(
1381
+ eval_step,
1382
+ in_axis_resources=(state_spec, batch_spec),
1383
+ out_axis_resources=None,
1384
+ )
1385
+
1386
+ # define metrics logger
1387
+ class MetricsLogger:
1388
+ def __init__(self, step):
1389
+ # keep state
1390
+ self.state_dict = {}
1391
+ # estimate speed
1392
+ self.step = step
1393
+ self.time = time.perf_counter()
1394
+ self.offset_time = 0.0
1395
+
1396
+ def update_state_metrics(self, state):
1397
+ """Update internal state metrics (logged at each call to be used as x-axis)"""
1398
+ self.state_dict = {
1399
+ f'train/{k.split("_")[-1]}': state[k]
1400
+ for k in ["step", "epoch", "train_time", "train_samples"]
1401
+ }
1402
+ # timing metrics
1403
+ new_step = int(state["step"])
1404
+ new_time = time.perf_counter()
1405
+ if new_step > self.step:
1406
+ # remove time for eval & save
1407
+ delta_time = new_time - self.time - self.offset_time
1408
+ self.offset_time = 0
1409
+ time_per_step = delta_time / (new_step - self.step)
1410
+ self.step = new_step
1411
+ self.time = new_time
1412
+ self.log_time("train_per_step", time_per_step, offset=False)
1413
+ self.log_time("train_per_log", delta_time, offset=False)
1414
+
1415
+ def log_time(self, key, duration, offset=True):
1416
+ if jax.process_index() == 0:
1417
+ wandb.log({f"time/{key}": duration, **self.state_dict})
1418
+ if offset:
1419
+ self.offset_time += duration
1420
+
1421
+ def log(self, metrics, prefix=None):
1422
+ if jax.process_index() == 0:
1423
+ log_metrics = {}
1424
+ for k, v in metrics.items():
1425
+ if "_norm" in k:
1426
+ if self.step % training_args.log_norm_steps == 0:
1427
+ log_metrics[f"{k}/"] = unfreeze(v)
1428
+ elif "_hist" in k:
1429
+ if self.step % training_args.log_histogram_steps == 0:
1430
+ v = jax.tree_util.tree_map(
1431
+ lambda x: jax.device_get(x), unfreeze(v)
1432
+ )
1433
+ v = jax.tree_util.tree_map(
1434
+ lambda x: wandb.Histogram(np_histogram=x),
1435
+ v,
1436
+ is_leaf=lambda x: isinstance(x, tuple),
1437
+ )
1438
+ log_metrics[f"{k}/"] = v
1439
+ else:
1440
+ if prefix is not None:
1441
+ k = f"{prefix}/{k}"
1442
+ log_metrics[k] = v
1443
+ wandb.log({**log_metrics, **self.state_dict})
1444
+
1445
+ # keep local copy of state
1446
+ local_state = {
1447
+ k: jax.device_get(getattr(state, k)).item()
1448
+ for k in ["step", "epoch", "train_time", "train_samples"]
1449
+ }
1450
+ # init variables
1451
+ start_time = time.perf_counter() - local_state["train_time"]
1452
+ train_metrics = None
1453
+ evaluation_ran = False
1454
+ save_model_ran = False
1455
+ metrics_logger = MetricsLogger(local_state["step"])
1456
+ epochs = tqdm(
1457
+ range(local_state["epoch"], num_epochs),
1458
+ desc=f"Epoch ... (1/{num_epochs})",
1459
+ position=0,
1460
+ disable=jax.process_index() > 0,
1461
+ )
1462
+
1463
+ def run_evaluation():
1464
+ # ======================== Evaluating ==============================
1465
+ if training_args.do_eval:
1466
+ start_eval_time = time.perf_counter()
1467
+ # get validation datasets
1468
+ val_datasets = list(
1469
+ dataset.other_eval_datasets.keys()
1470
+ if hasattr(dataset, "other_eval_datasets")
1471
+ else []
1472
+ )
1473
+ val_datasets += ["eval"]
1474
+ for val_dataset in val_datasets:
1475
+ eval_loader = dataset.dataloader(
1476
+ val_dataset,
1477
+ eval_batch_size_per_step
1478
+ * max(1, training_args.mp_devices // jax.local_device_count()),
1479
+ )
1480
+ eval_steps = (
1481
+ len_eval_dataset // eval_batch_size_per_step
1482
+ if len_eval_dataset is not None
1483
+ else None
1484
+ )
1485
+ eval_loss = []
1486
+ for batch in tqdm(
1487
+ eval_loader,
1488
+ desc="Evaluating...",
1489
+ position=2,
1490
+ leave=False,
1491
+ total=eval_steps,
1492
+ disable=jax.process_index() > 0,
1493
+ ):
1494
+ # need to keep only eval_batch_size_per_node items relevant to the node
1495
+ batch = jax.tree_util.tree_map(
1496
+ lambda x: x.reshape(
1497
+ (jax.process_count(), eval_batch_size_per_node)
1498
+ + x.shape[1:]
1499
+ ),
1500
+ batch,
1501
+ )
1502
+ batch = jax.tree_util.tree_map(
1503
+ lambda x: x[jax.process_index()], batch
1504
+ )
1505
+
1506
+ # add dp dimension when using "vmap trick"
1507
+ if use_vmap_trick:
1508
+ bs_shape = (
1509
+ jax.local_device_count() // training_args.mp_devices,
1510
+ training_args.per_device_eval_batch_size,
1511
+ )
1512
+ batch = jax.tree_util.tree_map(
1513
+ lambda x: x.reshape(bs_shape + x.shape[1:]), batch
1514
+ )
1515
+
1516
+ # freeze batch to pass safely to jax transforms
1517
+ batch = freeze(batch)
1518
+ # accumulate losses async
1519
+ eval_loss.append(p_eval_step(state, batch))
1520
+
1521
+ # get the mean of the loss
1522
+ eval_loss = jnp.stack(eval_loss)
1523
+ eval_loss = jnp.mean(eval_loss)
1524
+ eval_metrics = {"loss": eval_loss}
1525
+
1526
+ # log metrics
1527
+ metrics_logger.log(eval_metrics, prefix=val_dataset)
1528
+
1529
+ # Print metrics and update progress bar
1530
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})"
1531
+ epochs.write(desc)
1532
+ epochs.desc = desc
1533
+
1534
+ # log time
1535
+ metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
1536
+
1537
+ return eval_metrics
1538
+
1539
+ def run_save_model(state, eval_metrics=None):
1540
+ if jax.process_index() == 0:
1541
+ start_save_time = time.perf_counter()
1542
+ output_dir = training_args.output_dir
1543
+ use_bucket = output_dir.startswith("gs://")
1544
+ if use_bucket:
1545
+ bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
1546
+ bucket, dir_path = str(bucket_path).split("/", 1)
1547
+ tmp_dir = tempfile.TemporaryDirectory()
1548
+ output_dir = tmp_dir.name
1549
+
1550
+ # save model
1551
+ params = jax.device_get(state.params)
1552
+ model.save_pretrained(
1553
+ output_dir,
1554
+ params=params,
1555
+ )
1556
+
1557
+ # save tokenizer
1558
+ tokenizer.save_pretrained(output_dir)
1559
+
1560
+ # copy to bucket
1561
+ if use_bucket:
1562
+ client = storage.Client()
1563
+ bucket = client.bucket(bucket)
1564
+ for filename in Path(output_dir).glob("*"):
1565
+ blob_name = str(Path(dir_path) / "model" / filename.name)
1566
+ blob = bucket.blob(blob_name)
1567
+ blob.upload_from_filename(str(filename))
1568
+ tmp_dir.cleanup()
1569
+
1570
+ # save state
1571
+ opt_state = jax.device_get(state.opt_state)
1572
+ if use_bucket:
1573
+ blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
1574
+ blob = bucket.blob(blob_name)
1575
+ blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
1576
+ else:
1577
+ with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
1578
+ f.write(to_bytes(opt_state))
1579
+
1580
+ # save to W&B
1581
+ if training_args.log_model:
1582
+ # save some space
1583
+ c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
1584
+ c.cleanup(wandb.util.from_human_size("20GB"))
1585
+
1586
+ metadata = {
1587
+ k: jax.device_get(getattr(state, k)).item()
1588
+ for k in ["step", "epoch", "train_time", "train_samples"]
1589
+ }
1590
+ metadata["num_params"] = num_params
1591
+ if eval_metrics is not None:
1592
+ metadata["eval"] = eval_metrics
1593
+
1594
+ # create model artifact
1595
+ if use_bucket:
1596
+ metadata["bucket_path"] = f"gs://{bucket_path}/model"
1597
+ artifact = wandb.Artifact(
1598
+ name=f"model-{wandb.run.id}",
1599
+ type="DalleBart_model",
1600
+ metadata=metadata,
1601
+ )
1602
+ if use_bucket:
1603
+ artifact.add_reference(metadata["bucket_path"])
1604
+ else:
1605
+ for filename in [
1606
+ "config.json",
1607
+ "flax_model.msgpack",
1608
+ "merges.txt",
1609
+ "special_tokens_map.json",
1610
+ "tokenizer.json",
1611
+ "tokenizer_config.json",
1612
+ "vocab.json",
1613
+ ]:
1614
+ artifact.add_file(
1615
+ f"{Path(training_args.output_dir) / filename}"
1616
+ )
1617
+ wandb.run.log_artifact(artifact)
1618
+
1619
+ # create state artifact
1620
+ if use_bucket:
1621
+ metadata["bucket_path"] = f"gs://{bucket_path}/state"
1622
+ artifact_state = wandb.Artifact(
1623
+ name=f"state-{wandb.run.id}",
1624
+ type="DalleBart_state",
1625
+ metadata=metadata,
1626
+ )
1627
+ if use_bucket:
1628
+ artifact_state.add_reference(metadata["bucket_path"])
1629
+ else:
1630
+ artifact_state.add_file(
1631
+ f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1632
+ )
1633
+ wandb.run.log_artifact(artifact_state)
1634
+ metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
1635
+
1636
+ logger.info(" Ready to start training")
1637
+ with mesh:
1638
+ for epoch in epochs:
1639
+ state = state.replace(epoch=epoch)
1640
+ local_state["epoch"] = epoch
1641
+ # ======================== Training ================================
1642
+ metrics_logger.update_state_metrics(local_state)
1643
+ metrics_logger.log({})
1644
+
1645
+ if training_args.do_train:
1646
+ # load data - may be replicated on multiple nodes
1647
+ node_groups = max(
1648
+ 1, training_args.mp_devices // jax.local_device_count()
1649
+ )
1650
+ loader_bs = batch_size_per_node * node_groups
1651
+ train_loader = dataset.dataloader(
1652
+ "train",
1653
+ loader_bs,
1654
+ epoch,
1655
+ )
1656
+ # train
1657
+ for batch in tqdm(
1658
+ train_loader,
1659
+ desc="Training...",
1660
+ position=1,
1661
+ leave=False,
1662
+ total=steps_per_epoch,
1663
+ disable=jax.process_index() > 0,
1664
+ ):
1665
+ # calculate delta time (we have a lag of one step but it's ok)
1666
+ train_time = time.perf_counter() - start_time
1667
+
1668
+ # reset control variables
1669
+ evaluation_ran = False
1670
+ save_model_ran = False
1671
+
1672
+ # set correct shape to batch
1673
+ # - add grad_step dim if gradient_accumulation_steps > 1
1674
+ bs_shape = (
1675
+ (batch_size_per_node_per_grad_step * node_groups,)
1676
+ if not use_vmap_trick
1677
+ else (
1678
+ jax.local_device_count()
1679
+ * node_groups
1680
+ // training_args.mp_devices, # local dp devices
1681
+ training_args.per_device_train_batch_size,
1682
+ )
1683
+ )
1684
+ if training_args.gradient_accumulation_steps > 1:
1685
+ # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1686
+ # to avoid any data redistribution when sharding
1687
+ bs_shape = (
1688
+ training_args.gradient_accumulation_steps,
1689
+ ) + bs_shape
1690
+
1691
+ # reshape batch
1692
+ batch = jax.tree_util.tree_map(
1693
+ lambda x: x.reshape(bs_shape + x.shape[1:]),
1694
+ batch,
1695
+ )
1696
+ # freeze batch to pass safely to jax transforms
1697
+ batch = freeze(batch)
1698
+
1699
+ # train step
1700
+ state, train_metrics = p_train_step(state, batch, train_time)
1701
+ local_state["step"] += 1
1702
+ local_state["train_time"] = train_time
1703
+ local_state["train_samples"] += batch_size_per_step
1704
+
1705
+ if (
1706
+ local_state["step"] % training_args.logging_steps == 0
1707
+ and jax.process_index() == 0
1708
+ ):
1709
+ metrics_logger.update_state_metrics(local_state)
1710
+ metrics_logger.log(train_metrics, prefix="train")
1711
+
1712
+ eval_metrics = None
1713
+ if local_state["step"] % training_args.eval_steps == 0:
1714
+ eval_metrics = run_evaluation()
1715
+ evaluation_ran = True
1716
+
1717
+ if local_state["step"] % training_args.save_steps == 0:
1718
+ run_save_model(state, eval_metrics)
1719
+ save_model_ran = True
1720
+
1721
+ # log final train metrics
1722
+ if train_metrics is not None:
1723
+ metrics_logger.update_state_metrics(local_state)
1724
+ metrics_logger.log(train_metrics, prefix="train")
1725
+
1726
+ epochs.write(
1727
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1728
+ )
1729
+
1730
+ # Final evaluation at the end of each epoch
1731
+ if not evaluation_ran:
1732
+ eval_metrics = run_evaluation()
1733
+
1734
+ # save checkpoint after each epoch
1735
+ if not save_model_ran:
1736
+ run_save_model(state, eval_metrics)
1737
+
1738
+
1739
+ if __name__ == "__main__":
1740
+ main()