Spaces:
Runtime error
Runtime error
FritsLyneborg
committed on
Commit
·
6742988
1
Parent(s):
d2606d5
AI upload
Browse files- .gitignore +6 -0
- CITATION.cff +44 -0
- LICENSE +201 -0
- Makefile +5 -0
- app/gradio/app_gradio.py +179 -0
- app/gradio/requirements.txt +4 -0
- app/streamlit/app.py +117 -0
- app/streamlit/img/loading.gif +0 -0
- img/logo.png +0 -0
- pyproject.toml +2 -0
- setup.cfg +46 -0
- setup.py +4 -0
- src/dalle_mini/__init__.py +3 -0
- src/dalle_mini/data.py +378 -0
- src/dalle_mini/model/__init__.py +5 -0
- src/dalle_mini/model/configuration.py +176 -0
- src/dalle_mini/model/modeling.py +2093 -0
- src/dalle_mini/model/partitions.py +67 -0
- src/dalle_mini/model/processor.py +58 -0
- src/dalle_mini/model/text.py +262 -0
- src/dalle_mini/model/tokenizer.py +8 -0
- src/dalle_mini/model/utils.py +27 -0
- tools/dataset/encode_dataset.ipynb +371 -0
- tools/inference/inference_pipeline.ipynb +479 -0
- tools/train/config/medium/config.json +31 -0
- tools/train/config/mega/config.json +30 -0
- tools/train/config/micro/config.json +30 -0
- tools/train/config/mini/config.json +29 -0
- tools/train/config/mini_glu/config.json +29 -0
- tools/train/scalable_shampoo/README.md +7 -0
- tools/train/scalable_shampoo/distributed_shampoo.py +2267 -0
- tools/train/scalable_shampoo/quantization_utils.py +124 -0
- tools/train/scalable_shampoo/sm3.py +176 -0
- tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py +442 -0
- tools/train/sweep.yaml +49 -0
- tools/train/train.py +1436 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.ipynb_checkpoints
|
3 |
+
.streamlit
|
4 |
+
wandb/
|
5 |
+
*.egg-info/
|
6 |
+
jax_cache/
|
CITATION.cff
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# YAML 1.2
|
2 |
+
---
|
3 |
+
abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
|
4 |
+
authors:
|
5 |
+
-
|
6 |
+
family-names: Dayma
|
7 |
+
given-names: Boris
|
8 |
+
-
|
9 |
+
family-names: Patil
|
10 |
+
given-names: Suraj
|
11 |
+
-
|
12 |
+
family-names: Cuenca
|
13 |
+
given-names: Pedro
|
14 |
+
-
|
15 |
+
family-names: Saifullah
|
16 |
+
given-names: Khalid
|
17 |
+
-
|
18 |
+
family-names: Abraham
|
19 |
+
given-names: Tanishq
|
20 |
+
-
|
21 |
+
family-names: "Lê Khắc"
|
22 |
+
given-names: "Phúc"
|
23 |
+
-
|
24 |
+
family-names: Melas
|
25 |
+
given-names: Luke
|
26 |
+
-
|
27 |
+
family-names: Ghosh
|
28 |
+
given-names: Ritobrata
|
29 |
+
cff-version: "1.1.0"
|
30 |
+
date-released: 2021-07-29
|
31 |
+
identifiers:
|
32 |
+
keywords:
|
33 |
+
- dalle
|
34 |
+
- "text-to-image generation"
|
35 |
+
- transformer
|
36 |
+
- "zero-shot"
|
37 |
+
- JAX
|
38 |
+
license: "Apache-2.0"
|
39 |
+
doi: 10.5281/zenodo.5146400
|
40 |
+
message: "If you use this project, please cite it using these metadata."
|
41 |
+
repository-code: "https://github.com/borisdayma/dalle-mini"
|
42 |
+
title: "DALL·E Mini"
|
43 |
+
version: "v0.1-alpha"
|
44 |
+
...
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2021 The DALL·E mini Authors
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
Makefile
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: style
|
2 |
+
|
3 |
+
style:
|
4 |
+
black .
|
5 |
+
isort .
|
app/gradio/app_gradio.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# Uncomment to run on cpu
|
5 |
+
# import os
|
6 |
+
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
7 |
+
|
8 |
+
import random
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import jax
|
12 |
+
import numpy as np
|
13 |
+
from flax.jax_utils import replicate
|
14 |
+
from flax.training.common_utils import shard
|
15 |
+
from PIL import Image, ImageDraw, ImageFont
|
16 |
+
|
17 |
+
# ## CLIP Scoring
|
18 |
+
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
|
19 |
+
from vqgan_jax.modeling_flax_vqgan import VQModel
|
20 |
+
|
21 |
+
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
22 |
+
|
23 |
+
DALLE_REPO = "flax-community/dalle-mini"
|
24 |
+
DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
|
25 |
+
|
26 |
+
VQGAN_REPO = "flax-community/vqgan_f16_16384"
|
27 |
+
VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
|
28 |
+
|
29 |
+
tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
|
30 |
+
model = CustomFlaxBartForConditionalGeneration.from_pretrained(
|
31 |
+
DALLE_REPO, revision=DALLE_COMMIT_ID
|
32 |
+
)
|
33 |
+
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
|
34 |
+
|
35 |
+
|
36 |
+
def captioned_strip(images, caption=None, rows=1):
    """Tile equally sized PIL images into one RGB image, optionally captioned.

    Images are laid out column by column: image ``i`` goes to column
    ``i // rows`` and row ``i % rows``. The size of the first image is used
    for every cell.

    Args:
        images: Non-empty sequence of PIL images (``images[0]`` is read for
            the cell size, so an empty sequence raises ``IndexError``).
        caption: Optional text drawn in a 48-pixel band above the grid.
        rows: Number of rows in the grid.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.
    """
    increased_h = 0 if caption is None else 48
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))

    if caption is not None:
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
            )
        except OSError:
            # The Liberation font is not installed everywhere; a missing font
            # should degrade the caption, not abort image generation.
            font = ImageFont.load_default()
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img
|
50 |
+
|
51 |
+
|
52 |
+
def custom_to_pil(x):
    """Convert an array of values in [0, 1] into an RGB PIL image.

    Values outside [0, 1] are clipped, the result is scaled by 255 and cast
    to ``uint8`` before being handed to ``Image.fromarray``. Images that do
    not come out in RGB mode are converted to RGB.
    """
    clipped = np.clip(x, 0.0, 1.0)
    pixels = (255 * clipped).astype(np.uint8)
    image = Image.fromarray(pixels)
    return image if image.mode == "RGB" else image.convert("RGB")
|
59 |
+
|
60 |
+
|
61 |
+
def generate(input, rng, params):
    """Run ``model.generate`` in sampling mode for one shard of inputs.

    Args:
        input: Mapping of tokenizer outputs, expanded as keyword arguments.
        rng: PRNG key forwarded as ``prng_key``.
        params: Model parameters forwarded as ``params``.

    Returns:
        Whatever ``model.generate`` returns for these arguments.
    """
    # Fixed generation settings: sampling (no beam search), at most 257
    # tokens, and token id 50000 used for both EOS and padding.
    sampling_options = dict(
        max_length=257,
        num_beams=1,
        do_sample=True,
        prng_key=rng,
        eos_token_id=50000,
        pad_token_id=50000,
        params=params,
    )
    return model.generate(**input, **sampling_options)
|
72 |
+
|
73 |
+
|
74 |
+
def get_images(indices, params):
    """Decode code indices with the VQGAN model using the given parameters."""
    decoded = vqgan.decode_code(indices, params=params)
    return decoded
|
76 |
+
|
77 |
+
|
78 |
+
p_generate = jax.pmap(generate, "batch")
|
79 |
+
p_get_images = jax.pmap(get_images, "batch")
|
80 |
+
|
81 |
+
bart_params = replicate(model.params)
|
82 |
+
vqgan_params = replicate(vqgan.params)
|
83 |
+
|
84 |
+
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
85 |
+
print("Initialize FlaxCLIPModel")
|
86 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
87 |
+
print("Initialize CLIPProcessor")
|
88 |
+
|
89 |
+
|
90 |
+
def hallucinate(prompt, num_images=64):
    """Generate candidate images for a text prompt.

    The prompt is replicated once per local device, tokenized, sharded, and
    then sampled repeatedly; every round yields one image per local device.

    Args:
        prompt: Text prompt to generate images for.
        num_images: Requested number of images. The actual count is
            ``num_images`` rounded down to a multiple of the local device
            count.

    Returns:
        A list of PIL images.
    """
    # Use the local device count throughout: `shard` and `pmap` operate on
    # the devices attached to this host, so the replicated prompt, the split
    # PRNG keys and the loop count must all agree with it.
    n_devices = jax.local_device_count()

    prompt = [prompt] * n_devices
    inputs = tokenizer(
        prompt,
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=128,
    ).data
    inputs = shard(inputs)

    all_images = []
    for _ in range(num_images // n_devices):
        # Integer bound: random.randint rejects float arguments such as 1e7
        # on Python 3.12+.
        key = random.randint(0, 10**7)
        rng = jax.random.PRNGKey(key)
        rngs = jax.random.split(rng, n_devices)
        indices = p_generate(inputs, rngs, bart_params).sequences
        # Drop the first generated token of every sequence
        # (presumably the start token -- confirm against the model config).
        indices = indices[:, :, 1:]

        images = p_get_images(indices, vqgan_params)
        images = np.squeeze(np.asarray(images), 1)
        all_images.extend(custom_to_pil(image) for image in images)
    return all_images
|
114 |
+
|
115 |
+
|
116 |
+
def clip_top_k(prompt, images, k=8):
    """Return the ``k`` images that CLIP scores highest for ``prompt``.

    Args:
        prompt: Text the images are scored against.
        images: Sequence of images accepted by the CLIP processor.
        k: Number of images to keep; expected to be non-negative. ``k == 0``
            yields an empty list.

    Returns:
        Up to ``k`` images from ``images``, best match first.
    """
    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
    outputs = clip(**inputs)
    logits = np.array(outputs.logits_per_text[0])
    # Sort descending first and then truncate: slicing with [-k:] would
    # return every image when k is 0.
    best_first = logits.argsort()[::-1]
    return [images[idx] for idx in best_first[:k]]
|
122 |
+
|
123 |
+
|
124 |
+
def compose_predictions(images, caption=None):
    """Lay out images side by side in a single row, optionally captioned.

    This is the one-row case of ``captioned_strip``: with ``rows=1`` that
    function produces a canvas ``len(images) * w`` wide and pastes image
    ``i`` at x offset ``i * w``, which is the layout this function has
    always produced. Delegating removes the duplicated drawing code.

    Args:
        images: Non-empty sequence of equally sized PIL images.
        caption: Optional text drawn above the row.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    return captioned_strip(images, caption=caption, rows=1)
|
138 |
+
|
139 |
+
|
140 |
+
def top_k_predictions(prompt, num_candidates=32, k=8):
    """Generate ``num_candidates`` images and keep the ``k`` best per CLIP."""
    candidates = hallucinate(prompt, num_images=num_candidates)
    return clip_top_k(prompt, candidates, k=k)
|
144 |
+
|
145 |
+
|
146 |
+
def run_inference(prompt, num_images=32, num_preds=8):
    """Produce the title HTML and the image strip for a prompt.

    Args:
        prompt: Text entered by the user.
        num_images: Number of candidate images to generate.
        num_preds: Number of best candidates to show.

    Returns:
        A ``(title_html, image)`` tuple: the prompt in bold as an HTML
        snippet, and the selected images tiled into one PIL image.
    """
    import html  # local import: only needed to escape the title below

    images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
    predictions = captioned_strip(images)
    # The prompt is user-supplied text rendered as HTML, so escape it to
    # keep markup in the prompt from being interpreted by the browser.
    output_title = f"""
    <b>{html.escape(prompt)}</b>
    """
    return (output_title, predictions)
|
153 |
+
|
154 |
+
|
155 |
+
outputs = [
|
156 |
+
gr.outputs.HTML(label=""), # To be used as title
|
157 |
+
gr.outputs.Image(label=""),
|
158 |
+
]
|
159 |
+
|
160 |
+
description = """
|
161 |
+
DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
|
162 |
+
"""
|
163 |
+
gr.Interface(
|
164 |
+
run_inference,
|
165 |
+
inputs=[gr.inputs.Textbox(label="What do you want to see?")],
|
166 |
+
outputs=outputs,
|
167 |
+
title="DALL·E mini",
|
168 |
+
description=description,
|
169 |
+
article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
|
170 |
+
layout="vertical",
|
171 |
+
theme="huggingface",
|
172 |
+
examples=[
|
173 |
+
["an armchair in the shape of an avocado"],
|
174 |
+
["snowy mountains by the sea"],
|
175 |
+
],
|
176 |
+
allow_flagging=False,
|
177 |
+
live=False,
|
178 |
+
# server_port=8999
|
179 |
+
).launch(share=True)
|
app/gradio/requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Requirements for huggingface spaces
|
2 |
+
gradio>=2.2.3
|
3 |
+
flax
|
4 |
+
transformers
|
app/streamlit/app.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import base64
|
5 |
+
from io import BytesIO
|
6 |
+
|
7 |
+
import requests
|
8 |
+
import streamlit as st
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
class ServiceError(Exception):
    """Raised when the backend answers with a non-200 HTTP status.

    Attributes are documented inline; the status code is also passed to
    ``Exception`` so ``args`` and ``str()`` are set explicitly.
    """

    def __init__(self, status_code):
        # Initialize the base class explicitly instead of relying on
        # BaseException.__new__ having captured the constructor arguments.
        super().__init__(status_code)
        # HTTP status code returned by the backend.
        self.status_code = status_code
|
15 |
+
|
16 |
+
|
17 |
+
def get_images_from_backend(prompt, backend_url, timeout=None):
    """Request images for ``prompt`` from the backend and decode them.

    Args:
        prompt: Text prompt sent as the ``"prompt"`` field of the JSON body.
        backend_url: URL the request is POSTed to.
        timeout: Optional timeout in seconds forwarded to ``requests.post``.
            The default ``None`` keeps the previous behavior of waiting
            indefinitely.

    Returns:
        A list of PIL images decoded from the base64 strings under the
        ``"images"`` key of the JSON response.

    Raises:
        ServiceError: If the response status code is not 200.
    """
    r = requests.post(backend_url, json={"prompt": prompt}, timeout=timeout)
    if r.status_code != 200:
        raise ServiceError(r.status_code)
    encoded_images = r.json()["images"]
    return [Image.open(BytesIO(base64.b64decode(img))) for img in encoded_images]
|
25 |
+
|
26 |
+
|
27 |
+
st.sidebar.markdown(
|
28 |
+
"""
|
29 |
+
<style>
|
30 |
+
.aligncenter {
|
31 |
+
text-align: center;
|
32 |
+
}
|
33 |
+
</style>
|
34 |
+
<p class="aligncenter">
|
35 |
+
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
|
36 |
+
</p>
|
37 |
+
""",
|
38 |
+
unsafe_allow_html=True,
|
39 |
+
)
|
40 |
+
st.sidebar.markdown(
|
41 |
+
"""
|
42 |
+
___
|
43 |
+
<p style='text-align: center'>
|
44 |
+
DALL·E mini is an AI model that generates images from any prompt you give!
|
45 |
+
</p>
|
46 |
+
|
47 |
+
<p style='text-align: center'>
|
48 |
+
Created by Boris Dayma et al. 2021
|
49 |
+
<br/>
|
50 |
+
<a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
|
51 |
+
</p>
|
52 |
+
""",
|
53 |
+
unsafe_allow_html=True,
|
54 |
+
)
|
55 |
+
|
56 |
+
st.header("DALL·E mini")
|
57 |
+
st.subheader("Generate images from text")
|
58 |
+
|
59 |
+
prompt = st.text_input("What do you want to see?")
|
60 |
+
|
61 |
+
DEBUG = False
|
62 |
+
if prompt != "":
|
63 |
+
container = st.empty()
|
64 |
+
container.markdown(
|
65 |
+
f"""
|
66 |
+
<style> p {{ margin:0 }} div {{ margin:0 }} </style>
|
67 |
+
<div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
|
68 |
+
<div class="stAlert">
|
69 |
+
<div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
|
70 |
+
<div class="st-b7">
|
71 |
+
<div class="css-whx05o e13vu3m50">
|
72 |
+
<div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
|
73 |
+
<img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
|
74 |
+
Generating predictions for: <b>{prompt}</b>
|
75 |
+
</div>
|
76 |
+
</div>
|
77 |
+
</div>
|
78 |
+
</div>
|
79 |
+
</div>
|
80 |
+
</div>
|
81 |
+
<small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
|
82 |
+
""",
|
83 |
+
unsafe_allow_html=True,
|
84 |
+
)
|
85 |
+
|
86 |
+
try:
|
87 |
+
backend_url = st.secrets["BACKEND_SERVER"]
|
88 |
+
print(f"Getting selections: {prompt}")
|
89 |
+
selected = get_images_from_backend(prompt, backend_url)
|
90 |
+
|
91 |
+
margin = 0.1 # for better position of zoom in arrow
|
92 |
+
n_columns = 3
|
93 |
+
cols = st.columns([1] + [margin, 1] * (n_columns - 1))
|
94 |
+
for i, img in enumerate(selected):
|
95 |
+
cols[(i % n_columns) * 2].image(img)
|
96 |
+
container.markdown(f"**{prompt}**")
|
97 |
+
|
98 |
+
st.button("Again!", key="again_button")
|
99 |
+
|
100 |
+
except ServiceError as error:
|
101 |
+
container.text(f"Service unavailable, status: {error.status_code}")
|
102 |
+
except KeyError:
|
103 |
+
if DEBUG:
|
104 |
+
container.markdown(
|
105 |
+
"""
|
106 |
+
**Error: BACKEND_SERVER unset**
|
107 |
+
|
108 |
+
Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
|
109 |
+
```
|
110 |
+
BACKEND_SERVER="<server url>"
|
111 |
+
```
|
112 |
+
"""
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
container.markdown(
|
116 |
+
"Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
|
117 |
+
)
|
app/streamlit/img/loading.gif
ADDED
img/logo.png
ADDED
pyproject.toml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[tool.isort]
|
2 |
+
profile = "black"
|
setup.cfg
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[metadata]
|
2 |
+
name = dalle-mini
|
3 |
+
version = attr: dalle_mini.__version__
|
4 |
+
author = Boris Dayma et al.
|
5 |
+
author_email = boris.dayma@gmail.com
|
6 |
+
description = DALL·E mini - Generate images from a text prompt
|
7 |
+
long_description = file: README.md
|
8 |
+
long_description_content_type = text/markdown
|
9 |
+
url = https://github.com/borisdayma/dalle-mini
|
10 |
+
project_urls =
|
11 |
+
Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
|
12 |
+
classifiers =
|
13 |
+
Programming Language :: Python :: 3
|
14 |
+
License :: OSI Approved :: Apache Software License
|
15 |
+
Operating System :: OS Independent
|
16 |
+
Topic :: Scientific/Engineering :: Artificial Intelligence
|
17 |
+
Development Status :: 3 - Alpha
|
18 |
+
Intended Audience :: Developers
|
19 |
+
|
20 |
+
[options]
|
21 |
+
package_dir =
|
22 |
+
=src
|
23 |
+
packages = find:
|
24 |
+
python_requires = >=3.6
|
25 |
+
install_requires =
|
26 |
+
transformers
|
27 |
+
einops
|
28 |
+
unidecode
|
29 |
+
ftfy
|
30 |
+
emoji
|
31 |
+
pillow
|
32 |
+
jax
|
33 |
+
flax
|
34 |
+
wandb
|
35 |
+
|
36 |
+
[options.extras_require]
|
37 |
+
dev =
|
38 |
+
tqdm
|
39 |
+
optax
|
40 |
+
braceexpand
|
41 |
+
datasets[streaming]
|
42 |
+
black[jupyter]
|
43 |
+
isort
|
44 |
+
|
45 |
+
[options.packages.find]
|
46 |
+
where = src
|
setup.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup
|
2 |
+
|
3 |
+
if __name__ == "__main__":
|
4 |
+
setup()
|
src/dalle_mini/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = "0.0.4"
|
2 |
+
|
3 |
+
from .model import DalleBart, DalleBartProcessor
|
src/dalle_mini/data.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import numpy as np
|
8 |
+
from braceexpand import braceexpand
|
9 |
+
from datasets import Dataset, load_dataset
|
10 |
+
|
11 |
+
from .model.text import TextNormalizer
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
class Dataset:
    """Load, filter, tokenize and batch a caption / image-encoding dataset.

    The instance loads the raw splits with ``load_dataset`` at construction
    time (``__post_init__``), ``preprocess`` turns them into model inputs and
    ``dataloader`` yields batches as dictionaries of ``jnp`` arrays.

    ``train_dataset`` / ``eval_dataset`` are only set when ``do_train`` /
    ``do_eval`` are true; the rest of the class relies on ``hasattr`` to know
    which splits exist.
    """

    # Hub repository name or local path forwarded to `load_dataset`.
    dataset_repo_or_path: str
    # Optional data files; a string is expanded with braceexpand notation.
    train_file: str = None
    validation_file: str = None
    streaming: bool = True
    use_auth_token: bool = False
    # Column holding the caption text and column holding the target encoding.
    text_column: str = "caption"
    encoding_column: str = "encoding"
    # Optional caps on the number of samples kept per split.
    max_train_samples: int = None
    max_eval_samples: int = None
    # Only used by the non-streaming `filter` / `map` calls.
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    # Seed for shuffling; a random one is drawn when left to None.
    seed_dataset: int = None
    # In multi-host runs, give each host its own slice of the training files.
    shard_by_host: bool = False
    # Probability of replacing a training caption by an empty string.
    blank_caption_prob: float = 0.0
    # Optional filtering on a score column and/or an arbitrary column value.
    clip_score_column: str = "clip_score"
    min_clip_score: float = None
    max_clip_score: float = None
    filter_column: str = None
    filter_value: str = None
    # Populated after construction (not accepted by __init__).
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        """Resolve data files, load the dataset and select the requested splits.

        Raises:
            AssertionError: if ``blank_caption_prob`` is set outside streaming mode.
            ValueError: if a requested split is missing from the loaded dataset.
        """
        if self.seed_dataset is None:
            # create a random seed
            self.seed_dataset = random.randint(0, 2**32 - 1)
        self.multi_hosts = jax.process_count() > 1
        # feed blank captions only in streaming mode for now
        # otherwise dataset could be cached with same blanked captions
        if self.blank_caption_prob:
            assert (
                self.streaming is True
            ), "blank_caption_prob can only be used in streaming mode"
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            # accept braceexpand notation
            for k in ["train_file", "validation_file"]:
                f = getattr(self, k)
                if isinstance(f, str):
                    setattr(self, k, list(braceexpand(f)))
            # for list of files, split training data shards by host
            if (
                isinstance(self.train_file, list)
                and self.multi_hosts
                and self.shard_by_host
            ):
                self.train_file = self.train_file[
                    jax.process_index() :: jax.process_count()
                ]
            # only pass the splits that were actually provided, so that
            # `load_dataset` never receives a `None` file list
            data_files = {
                split: files
                for split, files in (
                    ("train", self.train_file),
                    ("validation", self.validation_file),
                )
                if files is not None
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                # streaming datasets are truncated with `take`,
                # regular ones with `select`
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )

    def preprocess(self, tokenizer, config):
        """Filter, normalize, blank and tokenize the available splits in place.

        Args:
            tokenizer: callable tokenizer forwarded to ``preprocess_function``.
            config: object providing ``decoder_start_token_id``,
                ``normalize_text`` and ``max_text_length``.
        """
        # get required config variables
        decoder_start_token_id = config.decoder_start_token_id
        normalize_text = config.normalize_text
        max_length = config.max_text_length

        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        # filter data
        partial_filter_function = partial(
            filter_function,
            filter_column=self.filter_column,
            filter_value=self.filter_value,
            clip_score_column=self.clip_score_column,
            min_clip_score=self.min_clip_score,
            max_clip_score=self.max_clip_score,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                split_data = getattr(self, ds)
                if self.streaming:
                    split_data = split_data.filter(partial_filter_function)
                else:
                    split_data = split_data.filter(
                        partial_filter_function,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Filtering datasets",
                    )
                setattr(self, ds, split_data)

        # normalize text
        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    split_data = getattr(self, ds)
                    if self.streaming:
                        split_data = split_data.map(partial_normalize_function)
                    else:
                        split_data = split_data.map(
                            partial_normalize_function,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Normalizing datasets",
                        )
                    setattr(self, ds, split_data)

        # blank captions (training split only)
        if self.blank_caption_prob:
            partial_blank_caption_function = partial(
                blank_caption_function,
                text_column=self.text_column,
                blank_caption_prob=self.blank_caption_prob,
            )
            if hasattr(self, "train_dataset"):
                self.train_dataset = (
                    self.train_dataset.map(partial_blank_caption_function)
                    if self.streaming
                    else self.train_dataset.map(
                        partial_blank_caption_function,
                        num_proc=self.preprocessing_num_workers,
                        # never reuse a cache: blanking must stay random
                        load_from_cache_file=False,
                        desc="Blanking some captions",
                    )
                )

        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_length=max_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                split_data = getattr(self, ds)
                if self.streaming:
                    split_data = split_data.map(
                        partial_preprocess_function,
                        batched=True,
                        remove_columns=[
                            self.text_column,
                            self.encoding_column,
                        ],
                    )
                else:
                    split_data = split_data.map(
                        partial_preprocess_function,
                        batched=True,
                        # `ds` is only the attribute name: the column names
                        # must be read from the dataset object itself
                        remove_columns=split_data.column_names,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                setattr(self, ds, split_data)

    def dataloader(self, split, batch_size, epoch=None):
        """Return a generator of batches for ``split``.

        Args:
            split: ``"train"`` or ``"eval"``.
            batch_size: number of items per yielded batch.
            epoch: streaming mode only; when given (training split only) it is
                passed to ``set_epoch`` so the data is reshuffled per epoch.

        Returns:
            A generator yielding dictionaries of ``jnp`` arrays.

        Raises:
            ValueError: if ``split`` is neither ``"train"`` nor ``"eval"``.
        """

        def _dataloader_datasets_non_streaming(dataset, rng=None):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Shuffle batches if rng is set.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(dataset, epoch):
            """Accumulate streamed items into batches of `batch_size`."""
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True  # stop after one loop in some cases
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host, we run forever (no epoch) as hosts need to stop
                # at the same time and training data may not be split equally
                # For validation data we put the entire batch on each host and then
                # keep only the one specific to each host (could be improved but not necessary)
                if epoch is not None:
                    assert split == "train"
                    # reshuffle training data at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)

        # Only the training split is shuffled; evaluation keeps dataset order,
        # so no rng is passed in that case.
        input_rng = None
        if split == "train":
            self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
        return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        """Return ``(train_length, eval_length)``; an entry is None when unknown."""
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset
|
307 |
+
|
308 |
+
|
309 |
+
def shift_tokens_right(input_ids: np.ndarray, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    The first column is filled with `decoder_start_token_id` and the last
    token of each row is dropped, so the output has the same shape as the input.

    Args:
        input_ids: 2D array-like of token ids (rows are sequences).
        decoder_start_token_id: id written in the first position of every row.

    Returns:
        A new array with the same shape and dtype as `input_ids`; the input
        is left untouched.
    """
    input_ids = np.asarray(input_ids)
    # zeros_like keeps the integer dtype of the ids (np.zeros would silently
    # turn them into float64)
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
|
317 |
+
|
318 |
+
|
319 |
+
def blank_caption_function(example, text_column, blank_caption_prob):
    """Randomly replace the caption of `example` by an empty string.

    With probability `blank_caption_prob` the value stored under
    `text_column` is overwritten in place with "". A falsy probability
    leaves the example untouched without drawing a random number.
    Returns the (possibly modified) example.
    """
    if not blank_caption_prob:
        return example
    draw = np.random.rand()
    if draw < blank_caption_prob:
        example[text_column] = ""
    return example
|
323 |
+
|
324 |
+
|
325 |
+
def normalize_function(example, text_column, text_normalizer):
    """Apply `text_normalizer` to the `text_column` entry of `example`.

    The example is updated in place and returned.
    """
    normalized_text = text_normalizer(example[text_column])
    example[text_column] = normalized_text
    return example
|
328 |
+
|
329 |
+
|
330 |
+
def filter_function(
    example,
    min_clip_score,
    max_clip_score,
    clip_score_column,
    filter_column,
    filter_value,
):
    """Tell whether `example` passes the configured filters.

    An example is rejected when its `clip_score_column` value is below
    `min_clip_score` or above `max_clip_score`, or when its `filter_column`
    value differs from `filter_value`. Any criterion whose bound / column is
    None is skipped, and the corresponding column is not read at all.
    Returns True when the example should be kept.
    """
    # Score bounds: only look the score up when at least one bound is set.
    if min_clip_score is not None or max_clip_score is not None:
        score = example[clip_score_column]
        if min_clip_score is not None and score < min_clip_score:
            return False
        if max_clip_score is not None and score > max_clip_score:
            return False

    # Exact-match filter on an arbitrary column.
    column_mismatch = (
        filter_column is not None and example[filter_column] != filter_value
    )
    return not column_mismatch
|
345 |
+
|
346 |
+
|
347 |
+
def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_length,
    decoder_start_token_id,
):
    """Turn a batch of raw examples into model inputs.

    The captions found under `text_column` are tokenized (padded / truncated
    to `max_length`), the values under `encoding_column` become the `labels`,
    and `decoder_input_ids` are the labels shifted one position to the right
    with `decoder_start_token_id` in front.
    Returns the tokenizer output enriched with these two entries.
    """
    # Fixed-size padding ("max_length") because jitted functions need
    # inputs of constant length.
    encoded = tokenizer(
        examples[text_column],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # Targets: the labels are the target indices; they are needed, in addition
    # to the decoder_input_ids, by the compute_loss function.
    targets = np.asarray(examples[encoding_column])
    encoded["labels"] = targets

    # Decoder inputs: same tokens with bos prepended and the last one removed.
    encoded["decoder_input_ids"] = shift_tokens_right(
        targets, decoder_start_token_id
    )

    return encoded
|
src/dalle_mini/model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .configuration import DalleBartConfig
|
2 |
+
from .modeling import DalleBart
|
3 |
+
from .partitions import set_partitions
|
4 |
+
from .processor import DalleBartProcessor
|
5 |
+
from .tokenizer import DalleBartTokenizer
|
src/dalle_mini/model/configuration.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" DalleBart model configuration """
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
from transformers.configuration_utils import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
from .utils import PretrainedFromWandbMixin
|
22 |
+
|
23 |
+
logger = logging.get_logger(__name__)
|
24 |
+
|
25 |
+
|
26 |
+
class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
    """Configuration of the DalleBart model.

    Stores the text normalization flag, the transformer-variant switches and
    the usual encoder/decoder hyper-parameters as attributes, validates the
    variant options, then forwards the token ids and generation lengths to
    the parent configuration class.

    When not supplied through ``kwargs``, ``decoder_start_token_id``,
    ``bos_token_id``, ``pad_token_id`` and ``eos_token_id`` all default to
    ``image_vocab_size``, and ``min_length`` / ``max_length`` default to
    ``image_length + 1``.

    Raises:
        AssertionError: on an unknown ``ln_type`` / ``ln_positions``, when
            ``use_alibi`` is set, or when a final layer norm is disabled
            while ``ln_positions`` is ``"postln"``.
    """

    model_type = "dallebart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_attention_heads": "encoder_attention_heads",
        "hidden_size": "d_model",
    }

    def __init__(
        self,
        normalize_text=False,
        encoder_vocab_size=50264,
        image_vocab_size=16384,  # encoded image token space
        image_length=256,  # number of encoded tokens
        max_text_length=64,  # max number of text tokens
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        scale_embedding=False,
        gradient_checkpointing=False,
        use_cache=True,
        is_encoder_decoder=True,
        forced_eos_token_id=None,
        tie_word_embeddings=False,  # different modalities and sizes
        do_sample=True,
        # transformer variants
        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
        ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
        ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
        use_head_scale=False,  # used in NormFormer
        use_cosine_attention=False,  # used in Swin v2
        tau_init=0.05,  # used only in cosine attention (Swin v2)
        use_absolute_position_embeddings=True,  # default
        use_swin_position_embeddings=False,  # used in Swin v1/v2
        use_deepnet_scaling=False,  # used in Deepnet
        use_glu=False,  # "GLU Variants Improve Transformer"
        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
        sinkhorn_iters=1,  # used in SinkFormers
        use_final_ln_encoder=True,  # final layer normalization in encoder
        use_final_ln_decoder=True,  # final layer normalization in decoder
        # parameters that should not be necessary but could affect results
        force_ln_scale=False,  # force scale in layernorm even when followed by dense layers
        **kwargs,
    ):
        # text normalizer
        self.normalize_text = normalize_text

        # transformer variants
        self.use_bias = use_bias
        assert ln_type in [
            "rmsnorm",
            "layernorm",
        ], "ln_type must be 'rmsnorm' or 'layernorm'"
        self.ln_type = ln_type
        # "deepnet" is accepted as an alias of "postln"
        if ln_positions == "deepnet":
            ln_positions = "postln"
        assert ln_positions in [
            "normformer",
            "swinv2",
            "cogview",
            "postln",
            "preln",
        ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
        self.use_head_scale = use_head_scale
        assert use_alibi is False, "use_alibi is not supported yet"
        self.ln_positions = ln_positions
        self.use_cosine_attention = use_cosine_attention
        self.tau_init = tau_init
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.use_swin_position_embeddings = use_swin_position_embeddings
        self.use_deepnet_scaling = use_deepnet_scaling
        self.use_glu = use_glu
        self.use_alibi = use_alibi
        self.sinkhorn_iters = sinkhorn_iters
        if ln_positions == "postln":
            assert (
                use_final_ln_encoder
            ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
            assert (
                use_final_ln_decoder
            ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
        self.use_final_ln_encoder = use_final_ln_encoder
        self.use_final_ln_decoder = use_final_ln_decoder
        self.force_ln_scale = force_ln_scale

        # common parameters
        self.encoder_vocab_size = encoder_vocab_size
        self.image_vocab_size = image_vocab_size
        self.image_length = image_length
        self.max_text_length = max_text_length
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.use_cache = use_cache
        self.gradient_checkpointing = gradient_checkpointing
        self.scale_embedding = (
            scale_embedding  # scale factor will be sqrt(d_model) if True
        )

        # special token id's are appended to vocab if not provided
        decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
        bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
        pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
        eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)

        # we generate to image_length + 1 (for bos) by default
        min_length = kwargs.pop("min_length", image_length + 1)
        max_length = kwargs.pop("max_length", image_length + 1)

        super().__init__(
            # args required in parent class
            is_encoder_decoder=is_encoder_decoder,
            tie_word_embeddings=tie_word_embeddings,
            forced_eos_token_id=forced_eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            min_length=min_length,
            max_length=max_length,
            do_sample=do_sample,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get(
            "force_bos_token_to_be_generated", False
        ):
            self.forced_bos_token_id = self.bos_token_id
            # the two literals are concatenated: keep a separating space
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )
|
src/dalle_mini/model/modeling.py
ADDED
@@ -0,0 +1,2093 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" DalleBart model. """
|
16 |
+
|
17 |
+
import math
|
18 |
+
import os
|
19 |
+
from functools import partial
|
20 |
+
from pickle import UnpicklingError
|
21 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import flax
|
24 |
+
import flax.linen as nn
|
25 |
+
import jax
|
26 |
+
import jax.numpy as jnp
|
27 |
+
import msgpack.exceptions
|
28 |
+
from einops import rearrange
|
29 |
+
from flax.core.frozen_dict import unfreeze
|
30 |
+
from flax.linen import combine_masks, make_causal_mask
|
31 |
+
from flax.linen import partitioning as nn_partitioning
|
32 |
+
from flax.linen.linear import PrecisionLike
|
33 |
+
from flax.serialization import from_bytes
|
34 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
35 |
+
from jax import custom_jvp, lax
|
36 |
+
from jax.random import PRNGKey
|
37 |
+
from transformers.configuration_utils import PretrainedConfig
|
38 |
+
from transformers.file_utils import (
|
39 |
+
FLAX_WEIGHTS_NAME,
|
40 |
+
WEIGHTS_NAME,
|
41 |
+
cached_path,
|
42 |
+
hf_bucket_url,
|
43 |
+
is_offline_mode,
|
44 |
+
is_remote_url,
|
45 |
+
)
|
46 |
+
from transformers.generation_flax_utils import FlaxSampleOutput
|
47 |
+
from transformers.modeling_flax_outputs import (
|
48 |
+
FlaxBaseModelOutput,
|
49 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
50 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
51 |
+
FlaxSeq2SeqLMOutput,
|
52 |
+
)
|
53 |
+
from transformers.modeling_flax_utils import ACT2FN
|
54 |
+
from transformers.models.bart.modeling_flax_bart import (
|
55 |
+
FlaxBartAttention,
|
56 |
+
FlaxBartForConditionalGeneration,
|
57 |
+
FlaxBartForConditionalGenerationModule,
|
58 |
+
FlaxBartModule,
|
59 |
+
FlaxBartPreTrainedModel,
|
60 |
+
)
|
61 |
+
from transformers.utils import logging
|
62 |
+
|
63 |
+
from .configuration import DalleBartConfig
|
64 |
+
from .utils import PretrainedFromWandbMixin
|
65 |
+
|
66 |
+
# module-level logger obtained from the transformers logging utilities
logger = logging.get_logger(__name__)

# short alias for the remat transform from flax's partitioning module;
# the layer collections below wrap their layers with it when
# `config.gradient_checkpointing` is set
remat = nn_partitioning.remat
|
69 |
+
|
70 |
+
|
71 |
+
def smelu(beta: Any = 1.0):
    """
    Build a Smooth reLU (SmeLU) activation with half-width `beta`.

    Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
    https://arxiv.org/abs/2202.06499

    The returned function computes, element-wise:
        0                           if x <= -beta
        (x + beta)**2 / (4 * beta)  if -beta < x < beta
        x                           if x >= beta

    Args:
        beta: half-width of the quadratic transition region (must be > 0,
            it is used as a divisor).

    Returns:
        A jitted callable `x -> smelu(x)` carrying a custom JVP rule.
    """

    @custom_jvp
    @jax.jit
    def _smelu(x: Any) -> Any:
        quadratic = jnp.square(x + beta) / (4 * beta)
        # Select the zero branch explicitly. Zeroing `x` first and then
        # falling through the quadratic (as was done before) evaluates to
        # beta / 4 instead of 0 for every x <= -beta.
        return jnp.where(x <= -beta, 0.0, jnp.where(x >= beta, x, quadratic))

    # d/dx smelu(x) is 0 on the left, (x + beta) / (2 * beta) in the
    # transition region and 1 on the right, i.e. the clipped linear ramp.
    _smelu.defjvps(
        lambda g, ans, x: g * jnp.clip((x + beta) / (2 * beta), 0.0, 1.0)
    )
    return _smelu
|
91 |
+
|
92 |
+
|
93 |
+
# Register an activation *instance* (default beta), not the factory: entries of
# ACT2FN are called directly on tensors (`ACT2FN[name](x)`), and calling the
# factory with a tensor would return a function instead of an array.
ACT2FN.update({"smelu": smelu()})
|
94 |
+
|
95 |
+
# deepnet initialization
|
96 |
+
def deepnet_init(gain=1):
    """
    Return an initializer producing Glorot-normal weights multiplied by `gain`.

    The returned callable forwards all its arguments to
    `jax.nn.initializers.glorot_normal()` and scales the result.
    """
    glorot = jax.nn.initializers.glorot_normal()

    def scaled_glorot(*args, **kwargs):
        weights = glorot(*args, **kwargs)
        return gain * weights

    return scaled_glorot
|
103 |
+
|
104 |
+
|
105 |
+
# deepnet gain
|
106 |
+
def _deepnet_encoder_alpha(config):
    """Encoder "alpha" gain (used in this file as the residual multiplier)."""
    return 0.81 * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625


def _deepnet_encoder_beta(config):
    """Encoder "beta" gain (used in this file as the kernel-init gain)."""
    return 0.87 * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625


def _deepnet_decoder_alpha(config):
    """Decoder "alpha" gain (used in this file as the residual multiplier)."""
    return (3 * config.decoder_layers) ** 0.25


def _deepnet_decoder_beta(config):
    """Decoder "beta" gain (used in this file as the kernel-init gain)."""
    return (12 * config.decoder_layers) ** -0.25


# maps "encoder"/"decoder" -> "alpha"/"beta" -> callable taking the model config
deepnet_gain = {
    "encoder": {
        "alpha": _deepnet_encoder_alpha,
        "beta": _deepnet_encoder_beta,
    },
    "decoder": {
        "alpha": _deepnet_decoder_alpha,
        "beta": _deepnet_decoder_beta,
    },
}
|
118 |
+
|
119 |
+
|
120 |
+
class RMSNorm(nn.Module):
    """
    Root-mean-square normalization over the last axis with an optional learned scale.

    From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467

    Adapted from flax.linen.LayerNorm
    """

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32
    use_scale: bool = True
    scale_init: Any = jax.nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        """Normalize `x` by the RMS of its last axis and cast to `self.dtype`."""
        last_axis = (-1,)
        mean_square = self._compute_rms_sq(x, last_axis)
        return self._normalize(
            self,
            x,
            mean_square,
            last_axis,
            last_axis,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_scale,
            self.scale_init,
        )

    def _compute_rms_sq(self, x, axes):
        """Return mean(x**2) over `axes`, computed in at least float32."""
        compute_dtype = jnp.promote_types(jnp.float32, jnp.result_type(x))
        return jnp.mean(jax.lax.square(jnp.asarray(x, compute_dtype)), axes)

    def _normalize(
        self,
        mdl,
        x,
        rms_sq,
        reduction_axes,
        feature_axes,
        dtype,
        param_dtype,
        epsilon,
        use_scale,
        scale_init,
    ):
        """
        Multiply `x` by rsqrt(rms_sq + epsilon) and, if `use_scale`, by the
        "scale" parameter registered on `mdl`; the result is cast to `dtype`.
        """
        rank = x.ndim
        reduction_axes = nn.normalization._canonicalize_axes(rank, reduction_axes)
        feature_axes = nn.normalization._canonicalize_axes(rank, feature_axes)

        # statistics keep the input rank, with reduced axes collapsed to 1
        stats_shape = [
            1 if axis in reduction_axes else size for axis, size in enumerate(x.shape)
        ]
        # the scale broadcasts against x: full size on feature axes, 1 elsewhere
        feature_shape = [
            size if axis in feature_axes else 1 for axis, size in enumerate(x.shape)
        ]
        reduced_feature_shape = [x.shape[axis] for axis in feature_axes]

        inv_rms = lax.rsqrt(rms_sq.reshape(stats_shape) + epsilon)
        if use_scale:
            scale = mdl.param("scale", scale_init, reduced_feature_shape, param_dtype)
            inv_rms = inv_rms * scale.reshape(feature_shape)
        return jnp.asarray(inv_rms * x, dtype)
|
190 |
+
|
191 |
+
|
192 |
+
def norm(type, *args, **kwargs):
    """
    Instantiate the normalization module selected by `type`.

    "layernorm" builds `nn.LayerNorm`, "rmsnorm" builds `RMSNorm`; the remaining
    positional and keyword arguments are forwarded to the constructor.

    Raises:
        ValueError: if `type` is neither "rmsnorm" nor "layernorm".
    """
    if type == "layernorm":
        return nn.LayerNorm(*args, **kwargs)
    if type == "rmsnorm":
        return RMSNorm(*args, **kwargs)
    raise ValueError(f"Unknown norm type {type}")
|
199 |
+
|
200 |
+
|
201 |
+
def dot_product_attention_weights(
    query: Any,
    key: Any,
    bias: Optional[Any] = None,
    mask: Optional[Any] = None,
    embed_pos: Optional[Any] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Any = jnp.float32,
    precision: PrecisionLike = None,
    sinkhorn_iters: int = 1,
    is_encoder: bool = False,
):
    """
    Computes dot-product attention weights given query and key.
    mask is included into the bias.

    Adapted from flax.linen.attention.dot_product_attention_weights

    Args:
        query: array laid out as `[..., q_length, num_heads, depth]`.
        key: array laid out as `[..., kv_length, num_heads, depth]`.
        bias: optional additive term for the logits.
        mask: optional boolean mask, re-applied after every Sinkhorn step.
        embed_pos: optional additive position term for the logits.
        broadcast_dropout: share one dropout mask across batch and heads.
        dropout_rng: PRNG key used when dropout is active.
        dropout_rate: probability of dropping an attention weight.
        deterministic: disables dropout when True.
        dtype: dtype of the returned weights.
        precision: forwarded to `jnp.einsum`.
        sinkhorn_iters: number of alternating normalization steps (encoder only).
        is_encoder: Sinkhorn normalization is only considered when True.

    Returns:
        Attention weights laid out as `[..., num_heads, q_length, kv_length]`.
    """
    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # scaled dot product between every query and key position
    head_depth = query.shape[-1]
    scaled_query = query / jnp.sqrt(head_depth).astype(dtype)
    logits = jnp.einsum(
        "...qhd,...khd->...hqk", scaled_query, key, precision=precision
    )

    # additive terms: attention bias (masking etc.) first, then relative positions
    for additive_term in (bias, embed_pos):
        if additive_term is not None:
            logits = logits + additive_term

    if is_encoder and sinkhorn_iters != 1:
        # adapted from https://github.com/lucidrains/sinkhorn-transformer
        # sinkhorn does not work for causal (leaks info of future tokens into past),
        # which is why it is restricted to the encoder
        for step in range(sinkhorn_iters):
            # alternate between normalizing over keys and over queries
            norm_axis = -1 if step % 2 == 0 else -2
            logits = logits - jax.nn.logsumexp(logits, axis=norm_axis, keepdims=True)
            if mask is not None:
                logits = jnp.where(mask, logits, -jnp.inf)
        attn_weights = jnp.exp(logits).astype(dtype)
    else:
        attn_weights = jax.nn.softmax(logits).astype(dtype)

    # attention dropout
    if deterministic or not dropout_rate > 0.0:
        return attn_weights

    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
        # dropout is broadcast across the batch + head dimensions
        keep_shape = (1,) * (key.ndim - 2) + attn_weights.shape[-2:]
    else:
        keep_shape = attn_weights.shape
    keep = jax.random.bernoulli(dropout_rng, keep_prob, keep_shape)
    rescale = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
    return attn_weights * rescale
|
272 |
+
|
273 |
+
|
274 |
+
class FlaxBartAttention(FlaxBartAttention):
    """
    Attention module extending the transformers `FlaxBartAttention` of the same name.

    Edits:
    - causal mask is used only in decoder and considers image_length
    - scale attention heads per NormFormer paper
    """

    # selects the "encoder" or "decoder" entry of `deepnet_gain` and is forwarded
    # to `dot_product_attention_weights`
    is_encoder: bool = False
    # number of rows of the relative position embedding table (only read when
    # `config.use_swin_position_embeddings` is set)
    q_length: int = None
    # number of key positions covered by each relative position embedding row
    k_length: int = None

    def setup(self) -> None:
        """Create projections and the optional head scale, tau, relative bias and causal mask."""
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # all four projections map to `embed_dim` and share bias/dtype settings
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
        )

        # init gain, only used for v_proj and out_proj when deepnet scaling is enabled
        gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
            self.config
        )

        self.q_proj = dense(
            kernel_init=deepnet_init()
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.k_proj = dense(
            kernel_init=deepnet_init()
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.v_proj = dense(
            kernel_init=deepnet_init(gain)
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.out_proj = dense(
            kernel_init=deepnet_init(gain)
            if self.config.use_deepnet_scaling
            else jax.nn.initializers.normal(self.config.init_std)
        )
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.config.use_head_scale:
            # one learned multiplier per head, initialized to 1
            self.head_scale = self.param(
                "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
            )

        if self.config.use_cosine_attention:
            # one learned tau per head, initialized to `config.tau_init`
            self.tau = self.param(
                "tau",
                jax.nn.initializers.constant(self.config.tau_init),
                (1, self.num_heads, 1, 1),
            )

        if self.config.use_swin_position_embeddings:
            # one embedding row per query position, holding k_length * num_heads values
            self.rel_bias = nn.Embed(
                self.q_length,
                self.k_length * self.num_heads,
                embedding_init=deepnet_init()
                if self.config.use_deepnet_scaling
                else jax.nn.initializers.normal(self.config.init_std),
            )

        if self.causal:
            # used only in decoder
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
            )

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: inputs the queries are projected from.
            key_value_states: when given, keys and values are projected from it
                (cross-attention) instead of from `hidden_states`.
            attention_mask: optional mask, combined with the causal mask when
                `self.causal` is set.
            init_cache: when True (and causal), the key/value cache is used.
            deterministic: disables attention dropout when True.

        Returns:
            Tuple `(attn_output, attn_weights)`.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle the cache and prepare the causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # with a cache, slice the rows of the causal mask starting at the
                # current cache index, over the full cached key length
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(
                causal_mask, (batch_size,) + causal_mask.shape[1:]
            )

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
            )
        else:
            attention_bias = None

        # a dropout rng is only drawn when dropout is actually active
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.config.use_cosine_attention:
            # normalize q and k
            query_states = query_states / (
                jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
            )
            key_states = key_states / (
                jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
            )

        # relative position embeddings
        if self.config.use_swin_position_embeddings:
            # NOTE(review): the table is looked up for all `q_length` positions,
            # independent of the actual query length - confirm this broadcasts
            # as intended during cached, one-token-at-a-time decoding
            position_ids = jnp.arange(self.q_length)
            embed_pos = self.rel_bias(position_ids)
            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
        else:
            embed_pos = None

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            mask=attention_mask,
            embed_pos=embed_pos,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
            sinkhorn_iters=self.config.sinkhorn_iters,
            is_encoder=self.is_encoder,
        )
        if self.config.use_cosine_attention:
            # divide by tau
            # NOTE(review): this is applied to the already normalized weights
            # returned above, not to the logits - confirm this is intended
            attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        if self.config.use_head_scale:
            # per Normformer
            attn_output = attn_output * self.head_scale
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
478 |
+
|
479 |
+
|
480 |
+
class GLU(nn.Module):
    """
    Gated feed-forward block: `Dense(act(Dense(x)) * Dense(x))` with optional norms and dropout.

    From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        cfg = self.config
        side = "encoder" if self.is_encoder else "decoder"
        gain = deepnet_gain[side]["beta"](cfg)

        def kernel_init():
            # a fresh initializer per Dense layer, as selected by the config
            if cfg.use_deepnet_scaling:
                return deepnet_init(gain)
            return jax.nn.initializers.normal(cfg.init_std)

        dense = partial(nn.Dense, dtype=self.dtype, use_bias=cfg.use_bias)

        # NOTE: submodules are created in a fixed order so that flax assigns
        # them the same automatic names.
        if cfg.ln_positions in ("normformer", "cogview", "preln"):
            pre_norm = norm(
                cfg.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=cfg.force_ln_scale,
            )
            x = pre_norm(x)

        # gate branch (activated) and value branch (linear) read the same input
        gate = dense(self.ffn_dim, kernel_init=kernel_init())(x)
        gate = ACT2FN[cfg.activation_function](gate)
        value = dense(self.ffn_dim, kernel_init=kernel_init())(x)
        x = gate * value

        if cfg.ln_positions == "normformer":
            mid_norm = norm(
                cfg.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=cfg.force_ln_scale,
            )
            x = mid_norm(x)
        x = nn.Dropout(rate=cfg.activation_dropout)(x, deterministic=deterministic)

        # project back to the embedding dimension
        x = dense(self.embed_dim, kernel_init=kernel_init())(x)
        if cfg.ln_positions in ("swinv2", "cogview"):
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        return nn.Dropout(rate=cfg.dropout)(x, deterministic=deterministic)
|
544 |
+
|
545 |
+
|
546 |
+
class FFN(nn.Module):
    """Simple FFN layer: Dense -> activation -> Dense, with optional norms and dropout."""

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        cfg = self.config
        side = "encoder" if self.is_encoder else "decoder"
        gain = deepnet_gain[side]["beta"](cfg)

        def kernel_init():
            # a fresh initializer per Dense layer, as selected by the config
            if cfg.use_deepnet_scaling:
                return deepnet_init(gain)
            return jax.nn.initializers.normal(cfg.init_std)

        dense = partial(nn.Dense, dtype=self.dtype, use_bias=cfg.use_bias)

        # NOTE: submodules are created in a fixed order so that flax assigns
        # them the same automatic names.
        if cfg.ln_positions in ("normformer", "cogview", "preln"):
            pre_norm = norm(
                cfg.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=cfg.force_ln_scale,
            )
            x = pre_norm(x)

        # expand to the feed-forward dimension and apply the activation
        x = dense(self.ffn_dim, kernel_init=kernel_init())(x)
        x = ACT2FN[cfg.activation_function](x)

        if cfg.ln_positions == "normformer":
            mid_norm = norm(
                cfg.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=cfg.force_ln_scale,
            )
            x = mid_norm(x)
        x = nn.Dropout(rate=cfg.activation_dropout)(x, deterministic=deterministic)

        # project back to the embedding dimension
        x = dense(self.embed_dim, kernel_init=kernel_init())(x)
        if cfg.ln_positions in ("swinv2", "cogview"):
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        return nn.Dropout(rate=cfg.dropout)(x, deterministic=deterministic)
|
599 |
+
|
600 |
+
|
601 |
+
class FlaxBartEncoderLayer(nn.Module):
    """
    Encoder layer: self-attention followed by a feed-forward block (GLU or FFN),
    each wrapped in a residual connection scaled by the deepnet "alpha" gain
    when `config.use_deepnet_scaling` is set.

    Edits:
    - bias is controlled by `config.use_bias`
    - use custom FlaxBartAttention
    - norm placement is selected by `config.ln_positions`
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    # adds a norm on the layer output (set by the layer collection)
    add_norm: bool = False
    # whether that output norm has a learned scale
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Args:
            hidden_states: layer input.
            attention_mask: mask forwarded to the self-attention.
            output_attentions: append the attention weights to the output tuple.
            deterministic: disables every dropout of the layer when True.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights)`.
        """

        res_gain = (
            deepnet_gain["encoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=True,
            q_length=self.config.max_text_length,
            k_length=self.config.max_text_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            # forward the flag: the attention defaults to deterministic=True, so
            # without it `attention_dropout` is never applied during training
            deterministic=deterministic,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # Feed forward
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm or self.config.ln_positions in ["postln"]:
            # the output norm is scaled if requested by the collection, when
            # using post-LN, or when scales are forced by the config
            use_scale = (
                self.use_scale
                or self.config.ln_positions == "postln"
                or self.config.force_ln_scale
            )
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
701 |
+
|
702 |
+
|
703 |
+
class FlaxBartDecoderLayer(nn.Module):
    """
    Decoder layer: causal self-attention, optional cross-attention over the
    encoder output, then a feed-forward block (GLU or FFN). Each sub-block is
    wrapped in a residual connection scaled by the deepnet "alpha" gain when
    `config.use_deepnet_scaling` is set.

    Edits:
    - bias is controlled by `config.use_bias`
    - use custom FlaxBartAttention
    - norm placement is selected by `config.ln_positions`
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    # adds a norm on the layer output (set by the layer collection)
    add_norm: bool = False
    # whether that output norm has a learned scale
    use_scale: bool = False

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Args:
            hidden_states: layer input.
            attention_mask: mask forwarded to the causal self-attention.
            encoder_hidden_states: when given, a cross-attention block attends to it.
            encoder_attention_mask: mask forwarded to the cross-attention.
            init_cache: forwarded to the self-attention to set up its cache.
            output_attentions: append self- and cross-attention weights to the output.
            deterministic: disables every dropout of the layer when True.

        Returns:
            `(hidden_states,)` or `(hidden_states, attn_weights, cross_attn_weights)`;
            `cross_attn_weights` is None without `encoder_hidden_states`.
        """

        res_gain = (
            deepnet_gain["decoder"]["alpha"](self.config)
            if self.config.use_deepnet_scaling
            else 1
        )

        embed_dim = self.config.d_model
        residual = hidden_states

        # Self Attention
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=False,
            q_length=self.config.image_length,
            k_length=self.config.image_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
            # forward the flag: the attention defaults to deterministic=True, so
            # without it `attention_dropout` is never applied during training
            deterministic=deterministic,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )
        hidden_states = nn.Dropout(rate=self.config.dropout)(
            hidden_states, deterministic=deterministic
        )
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
                hidden_states
            )

        # Cross Attention
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            if self.config.ln_positions in ["normformer", "cogview", "preln"]:
                hidden_states = norm(
                    self.config.ln_type,
                    dtype=self.dtype,
                    epsilon=1e-05,
                    use_scale=self.config.force_ln_scale,
                )(hidden_states)
            hidden_states, cross_attn_weights = FlaxBartAttention(
                config=self.config,
                embed_dim=embed_dim,
                num_heads=self.config.decoder_attention_heads,
                dropout=self.config.attention_dropout,
                bias=self.config.use_bias,
                dtype=self.dtype,
                is_encoder=False,
                q_length=self.config.image_length,
                k_length=self.config.max_text_length,
            )(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                # same as for the self-attention: enable attention dropout in training
                deterministic=deterministic,
            )
            if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)
            hidden_states = nn.Dropout(rate=self.config.dropout)(
                hidden_states, deterministic=deterministic
            )
            hidden_states = residual * res_gain + hidden_states
            if self.config.ln_positions in ["postln"]:
                hidden_states = norm(
                    self.config.ln_type, dtype=self.dtype, epsilon=1e-05
                )(hidden_states)

        # Feed forward
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm or self.config.ln_positions in ["postln"]:
            # the output norm is scaled if requested by the collection, when
            # using post-LN, or when scales are forced by the config
            use_scale = (
                self.use_scale
                or self.config.ln_positions == "postln"
                or self.config.force_ln_scale
            )
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights, cross_attn_weights)

        return outputs
|
853 |
+
|
854 |
+
|
855 |
+
class FlaxBartEncoderLayerCollection(nn.Module):
    """
    Stack of `config.encoder_layers` encoder layers applied sequentially.

    Edits:
    - use custom FlaxBartEncoderLayer
    - allow Gradient Checkpointing (nn.remat)
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Args:
            hidden_states: input of the first layer.
            attention_mask: mask forwarded to every layer.
            deterministic: disables dropout in the layers when True.
            output_attentions: collect the attention weights of every layer.
            output_hidden_states: collect the input of every layer plus the final output.
            return_dict: return a `FlaxBaseModelOutput` instead of a tuple.

        Returns:
            `FlaxBaseModelOutput`, or a tuple of its non-None entries when
            `return_dict` is False.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        n_layers = self.config.encoder_layers
        # NOTE(review): static_argnums presumably designates `output_attentions`
        # and `deterministic`; whether indices count `self` depends on the flax
        # version - confirm against the pinned version
        layer = (
            remat(FlaxBartEncoderLayer, static_argnums=(2, 3))
            if self.config.gradient_checkpointing
            else FlaxBartEncoderLayer
        )
        for i in range(n_layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # final layernorm on the output of the last layer
            # or every 6 layers for Swin v2
            add_norm = (
                self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
            ) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
            # we don't need to scale the norm for the last layer
            use_scale = i != n_layers - 1
            layer_outputs = layer(
                self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
            )(
                hidden_states,
                attention_mask,
                output_attentions,
                deterministic,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # add hidden states from the last layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = [
            hidden_states,
            all_hidden_states,
            all_self_attns,
        ]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
923 |
+
|
924 |
+
|
925 |
+
class FlaxBartDecoderLayerCollection(nn.Module):
    """Run `config.decoder_layers` decoder layers one after another.

    Edits:
    - use custom FlaxBartDecoderLayer
    - allow Gradient Checkpointing (nn.remat)
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Apply every decoder layer to `hidden_states`.

        Optionally collects the input of each layer (plus the final output),
        the self-attention weights and the cross-attention weights.
        Returns a `FlaxBaseModelOutputWithPastAndCrossAttentions` when
        `return_dict` is true, otherwise a tuple of the non-None results.
        """
        # cross-attentions are only tracked when attentions are requested
        # and there is an encoder output to attend to
        track_cross = output_attentions and encoder_hidden_states is not None
        hidden_trace = () if output_hidden_states else None
        self_attn_trace = () if output_attentions else None
        cross_attn_trace = () if track_cross else None

        n_layers = self.config.decoder_layers
        if self.config.gradient_checkpointing:
            layer_cls = remat(FlaxBartDecoderLayer, static_argnums=(4, 5, 6))
        else:
            layer_cls = FlaxBartDecoderLayer

        for idx in range(n_layers):
            if output_hidden_states:
                hidden_trace = hidden_trace + (hidden_states,)

            # final layernorm on the output of the last layer
            # or every 6 layers for Swin v2
            add_norm = (
                self.config.ln_positions == "swinv2" and ((idx + 1) % 6 == 0)
            ) or (self.config.use_final_ln_decoder and (idx == n_layers - 1))
            # we don't need to scale the norm for the last layer
            use_scale = idx != n_layers - 1

            block = layer_cls(
                self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
            )
            # arguments are positional so that remat's static_argnums line up
            layer_outputs = block(
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                self_attn_trace = self_attn_trace + (layer_outputs[1],)
                if encoder_hidden_states is not None:
                    cross_attn_trace = cross_attn_trace + (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            hidden_trace = hidden_trace + (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                hidden_states=hidden_trace,
                attentions=self_attn_trace,
                cross_attentions=cross_attn_trace,
            )

        candidates = (hidden_states, hidden_trace, self_attn_trace, cross_attn_trace)
        return tuple(item for item in candidates if item is not None)
|
1009 |
+
|
1010 |
+
|
1011 |
+
class FlaxBartEncoder(nn.Module):
    """Embed the input tokens and run them through the encoder layer stack.

    Edits:
    - offset set to 0 (no padding token)
    - use max_text_length instead of max_position_embeddings
    - use custom FlaxBartEncoderLayerCollection
    - embed_tokens cannot be None (issue at compile time)
    """

    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        embed_dim = cfg.d_model
        self.padding_idx = cfg.pad_token_id
        if cfg.scale_embedding:
            self.embed_scale = math.sqrt(embed_dim)
        else:
            self.embed_scale = 1.0

        # Upstream Bart shifts the position ids by 2 when padding_idx is
        # specified; that hack is disabled here by using a zero offset.
        self.offset = 0
        if cfg.use_absolute_position_embeddings:
            # one learned position vector per text position
            self.embed_positions = nn.Embed(
                cfg.max_text_length + self.offset,
                embed_dim,
                embedding_init=jax.nn.initializers.normal(cfg.init_std),
            )
        self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(
            self.config.ln_type, dtype=self.dtype, epsilon=1e-05
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode `input_ids`.

        Returns a `FlaxBaseModelOutput` when `return_dict` is true, otherwise
        whatever the layer collection returned.
        """
        # collapse any leading dimensions into a single batch dimension
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])
        hidden_states = self.embed_tokens(flat_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            hidden_states = hidden_states + self.embed_positions(
                position_ids + self.offset
            )

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return outputs
|
1083 |
+
|
1084 |
+
|
1085 |
+
class FlaxBartDecoder(nn.Module):
    """Embed the decoder tokens and run them through the decoder layer stack.

    Edits:
    - offset set to 0 (no padding token)
    - use image_length instead of max_position_embeddings
    - use custom FlaxBartDecoderLayerCollection
    - embed_tokens cannot be None (issue at compile time)
    """

    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        embed_dim = cfg.d_model
        self.padding_idx = cfg.pad_token_id
        if cfg.scale_embedding:
            self.embed_scale = math.sqrt(cfg.d_model)
        else:
            self.embed_scale = 1.0

        # Upstream Bart shifts the position ids by 2 when padding_idx is
        # specified; that hack is disabled here by using a zero offset.
        self.offset = 0
        if cfg.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                cfg.image_length + self.offset,  # image length for BOS
                embed_dim,
                embedding_init=jax.nn.initializers.normal(cfg.init_std),
            )

        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(
            self.config.ln_type, dtype=self.dtype, epsilon=1e-05
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Decode `input_ids`, optionally attending to `encoder_hidden_states`.

        Returns a `FlaxBaseModelOutputWithPastAndCrossAttentions` when
        `return_dict` is true, otherwise whatever the layer collection
        returned.
        """
        # collapse any leading dimensions into a single batch dimension
        flat_ids = input_ids.reshape(-1, input_ids.shape[-1])
        hidden_states = self.embed_tokens(flat_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            hidden_states = hidden_states + self.embed_positions(
                position_ids + self.offset
            )

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return FlaxBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )
        return outputs
|
1167 |
+
|
1168 |
+
|
1169 |
+
class FlaxBartModule(FlaxBartModule):
    """Encoder/decoder pair with independent token embeddings.

    Edits
    - use custom FlaxBartEncoder & FlaxBartDecoder
    - use separate embeddings for Encoder & Decoder
    """

    def setup(self):
        cfg = self.config

        # text side: one embedding row per encoder vocabulary entry
        text_embedding = nn.Embed(
            cfg.encoder_vocab_size,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )
        # image side: image vocab size + 1 for BOS
        image_embedding = nn.Embed(
            cfg.image_vocab_size + 1,
            cfg.d_model,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )

        self.encoder = FlaxBartEncoder(
            self.config, dtype=self.dtype, embed_tokens=text_embedding
        )
        self.decoder = FlaxBartDecoder(
            self.config, dtype=self.dtype, embed_tokens=image_embedding
        )
|
1194 |
+
|
1195 |
+
|
1196 |
+
class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
    """
    Edits:
    - added num_params property
    - config_class replaced to DalleBartConfig
    - __init__ accepts abstract_init which does uses parameter shape to initialize the model
    - init weights on CPU with `load_on_cpu`
    - restore weights on CPU with custom `from_pretrained`
    """

    config_class = DalleBartConfig

    def __init__(
        self,
        config: DalleBartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        abstract_init: bool = False,
        load_on_cpu: bool = False,
        init_weights: bool = True,
        **kwargs,
    ):
        """Build the Flax module and, unless disabled, its parameters.

        Args:
            config: model configuration; must not be None.
            input_shape: shape of the dummy input used to initialize weights.
            seed: seed of the PRNG key stored on the model.
            dtype: dtype of the computation.
            abstract_init: only compute parameter shapes/dtypes (see
                `init_weights`).
            load_on_cpu: run weight initialization on the CPU backend.
            init_weights: when False, no parameters are created here.
            **kwargs: forwarded to `module_class`.

        Raises:
            ValueError: if `config` or the built module is None.
        """
        # adapted from HuggingFace FlaxPreTrainedModel
        # validate the config before it is handed to the module constructor
        if config is None:
            raise ValueError("config cannot be None")

        module = self.module_class(config=config, dtype=dtype, **kwargs)

        if module is None:
            raise ValueError("module cannot be None")

        # Those are private to be exposed as typed property on derived classes.
        self._config = config
        self._module = module

        # Those are public as their type is generic to every derived classes.
        self.key = PRNGKey(seed)
        self.dtype = dtype

        if init_weights:
            # get shape of params only
            random_params = self.init_weights(
                self.key,
                input_shape,
                abstract_init=abstract_init,
                load_on_cpu=load_on_cpu,
            )

            # save required_params as set
            self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
            self.params = random_params

    def init_weights(
        self, rng=None, input_shape=(1, 1), abstract_init=False, load_on_cpu=False
    ):
        """Create the parameters, or only their shapes when `abstract_init`.

        Args:
            rng: PRNG key; defaults to `self.key`.
            input_shape: shape of the dummy input.
            abstract_init: evaluate the initializer with `jax.eval_shape` so
                that only shape and dtype are produced.
            load_on_cpu: jit the parent initializer for the CPU backend.
        """
        if rng is None:
            rng = self.key
        init_fn = super().init_weights
        if load_on_cpu:
            init_fn = jax.jit(init_fn, static_argnums=(1,), backend="cpu")
        if abstract_init:
            # only set shape and dtype, load parameters separately
            init_fn = partial(init_fn, input_shape=input_shape)
            params = jax.eval_shape(init_fn, rng)
        else:
            params = init_fn(rng, input_shape)
        return params

    @property
    def num_params(self):
        """Total number of scalar entries across all parameter leaves."""
        # plain sum over the flattened leaves; avoids the deprecated
        # `jax.tree_map` alias
        flat_params = flatten_dict(unfreeze(self.params))
        return sum(param.size for param in flat_params.values())

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        dtype: jnp.dtype = jnp.float32,
        *model_args,
        **kwargs,
    ):
        """Instantiate the model and fill it with weights from a checkpoint.

        The checkpoint may be a local directory, a local file, a remote URL
        or a hub identifier. Parameters missing from the checkpoint keep
        their freshly initialized values; parameters unknown to the model
        are dropped; shape mismatches raise unless
        `ignore_mismatched_sizes=True` is passed.

        Raises:
            EnvironmentError: if the weights file cannot be located or
                deserialized.
            OSError: if the file looks like a git-lfs pointer.
            ValueError: on a shape mismatch without
                `ignore_mismatched_sizes`.
        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pipeline = kwargs.pop("_from_pipeline", None)
        from_auto_class = kwargs.pop("_from_auto", False)

        user_agent = {
            "file_type": "model",
            "framework": "flax",
            "from_auto_class": from_auto_class,
        }
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = (
                config if config is not None else pretrained_model_name_or_path
            )
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                _from_auto=from_auto_class,
                _from_pipeline=from_pipeline,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Add the dtype to model_kwargs
        model_kwargs["dtype"] = dtype

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                if from_pt and os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                ):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(
                        pretrained_model_name_or_path, WEIGHTS_NAME
                    )
                elif os.path.isfile(
                    os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
                ):
                    # Load from a Flax checkpoint
                    archive_file = os.path.join(
                        pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
                    )
                else:
                    raise EnvironmentError(
                        f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
                        f"{pretrained_model_name_or_path} or `from_pt` set to False"
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
                pretrained_model_name_or_path
            ):
                archive_file = pretrained_model_name_or_path
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
                    revision=revision,
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                )
            except EnvironmentError as err:
                logger.error(err)
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
                    f"  (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
                )
                # keep the original failure attached as the explicit cause
                raise EnvironmentError(msg) from err

            if resolved_archive_file == archive_file:
                logger.info(f"loading weights file {archive_file}")
            else:
                logger.info(
                    f"loading weights file {archive_file} from cache at {resolved_archive_file}"
                )
        else:
            resolved_archive_file = None

        # init random models
        model = cls(config, *model_args, **model_kwargs)

        with open(resolved_archive_file, "rb") as state_f:
            try:
                state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                try:
                    # a text file starting with "version" is a git-lfs pointer
                    with open(resolved_archive_file) as f:
                        if f.read().startswith("version"):
                            raise OSError(
                                "You seem to have cloned a repository without having git-lfs installed. Please install "
                                "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                                "you cloned."
                            )
                        else:
                            raise ValueError from e
                except (UnicodeDecodeError, ValueError):
                    raise EnvironmentError(
                        f"Unable to convert {archive_file} to Flax deserializable object. "
                    )

        # if model is base model only use model_prefix key
        if (
            cls.base_model_prefix not in dict(model.params)
            and cls.base_model_prefix in state
        ):
            state = state[cls.base_model_prefix]

        # if model is head model and we are loading weights from base model
        # we initialize new params dict with base_model_prefix
        if (
            cls.base_model_prefix in dict(model.params)
            and cls.base_model_prefix not in state
        ):
            state = {cls.base_model_prefix: state}

        # flatten dicts
        state = flatten_dict(state)

        random_state = flatten_dict(unfreeze(model.params))

        missing_keys = model.required_params - set(state.keys())
        unexpected_keys = set(state.keys()) - model.required_params

        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
        for key in state.keys():
            if key in random_state and state[key].shape != random_state[key].shape:
                if ignore_mismatched_sizes:
                    mismatched_keys.append(
                        (key, state[key].shape, random_state[key].shape)
                    )
                    state[key] = random_state[key]
                else:
                    raise ValueError(
                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
                        "model."
                    )

        # add missing keys as random parameters
        for missing_key in missing_keys:
            state[missing_key] = random_state[missing_key]

        # remove unexpected keys to not be saved again
        for unexpected_key in unexpected_keys:
            del state[unexpected_key]

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.info(
                f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
            )

        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )

        # set correct parameters
        model.params = unflatten_dict(state)

        return model
|
1508 |
+
|
1509 |
+
|
1510 |
+
class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Encoder/decoder model followed by a bias-free projection to image logits.

    Edits:
    - no bias
    - lm_head set to image_vocab_size + 1 (for BOS)
    - uses custom FlaxBartModule
    """

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
        n_logits = self.config.image_vocab_size + 1
        self.lm_head = nn.Dense(
            n_logits,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run the full model and project decoder states to logits.

        Returns a `FlaxSeq2SeqLMOutput` when `return_dict` is true, otherwise
        a tuple whose first element is the logits followed by the remaining
        model outputs.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        hidden_states = outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(hidden_states)
        else:
            # NOTE(review): this reads a "shared" embedding table, but
            # FlaxBartModule.setup in this file only builds separate
            # encoder/decoder embeddings - confirm tie_word_embeddings is
            # never enabled for this model.
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_embedding.T}}, hidden_states
            )

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=lm_logits,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        return (lm_logits,) + outputs[1:]
|
1577 |
+
|
1578 |
+
|
1579 |
+
@flax.struct.dataclass
class SampleState:
    """Bundle of values declared as a `flax.struct.dataclass`.

    NOTE(review): the code that produces and consumes this state is not
    visible from here; the per-field notes below are inferred from the field
    names only and should be confirmed against the sampling loop.
    """

    cur_len: jnp.ndarray  # presumably the current generated length
    sequences: jnp.ndarray  # presumably the tokens generated so far
    running_token: jnp.ndarray  # presumably the most recent token(s)
    is_sent_finished: jnp.ndarray  # presumably a per-sequence finished flag
    prng_key: jnp.ndarray  # presumably the sampling PRNG key
    model_kwargs: Dict[str, jnp.ndarray]  # presumably conditioned model inputs
    model_kwargs_uncond: Dict[str, jnp.ndarray]  # presumably unconditioned inputs
|
1588 |
+
|
1589 |
+
|
1590 |
+
class DalleBart(
|
1591 |
+
PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
|
1592 |
+
):
|
1593 |
+
"""
|
1594 |
+
Edits:
|
1595 |
+
- renamed from FlaxBartForConditionalGeneration
|
1596 |
+
- uses custom FlaxBartPreTrainedModel
|
1597 |
+
- uses custom FlaxBartForConditionalGenerationModule
|
1598 |
+
- no bias in decode method
|
1599 |
+
- custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
|
1600 |
+
related to position embedding during model.generate()
|
1601 |
+
- custom generate method to allow super conditions
|
1602 |
+
"""
|
1603 |
+
|
1604 |
+
module_class = FlaxBartForConditionalGenerationModule
|
1605 |
+
|
1606 |
+
def decode(
    self,
    decoder_input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    """Run the decoder and the logits projection on `decoder_input_ids`.

    Args:
        decoder_input_ids: decoder token ids, shape (batch, sequence).
        encoder_outputs: encoder result; element 0 is used as the encoder
            hidden states.
        encoder_attention_mask: defaults to all ones over the encoder
            sequence.
        decoder_attention_mask: defaults to all ones over the decoder
            sequence.
        decoder_position_ids: defaults to 0..sequence-1; must be given
            explicitly when `past_key_values` is passed.
        past_key_values: initialized cache; when non-empty the cache
            collection is made mutable and the updated cache is returned.
        output_attentions, output_hidden_states, return_dict: default to the
            corresponding config values when None.
        train: enables dropout (non-deterministic mode).
        params: parameters to use instead of `self.params`.
        dropout_rng: PRNG key for dropout.

    Returns:
        A `FlaxCausalLMOutputWithCrossAttentions` (with `past_key_values`
        added when a cache was used) if `return_dict`, otherwise a tuple
        starting with the logits.

    Raises:
        ValueError: if `past_key_values` is passed without
            `decoder_position_ids`.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.return_dict
    )

    encoder_hidden_states = encoder_outputs[0]
    if encoder_attention_mask is None:
        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

    batch_size, sequence_length = decoder_input_ids.shape
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones((batch_size, sequence_length))

    if decoder_position_ids is None:
        if past_key_values is not None:
            raise ValueError(
                "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
            )

        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    inputs = {"params": params or self.params}

    # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
    # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
    # it can be changed by FlaxBartAttention module
    #
    # A single flag decides every cache-related branch below, so that an
    # empty `past_key_values` is consistently treated as "no cache" instead
    # of disabling `mutable` while still expecting a mutated cache back.
    use_cache = bool(past_key_values)
    if use_cache:
        inputs["cache"] = past_key_values
        mutable = ["cache"]
    else:
        mutable = False

    def _decoder_forward(
        module,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        **kwargs,
    ):
        # runs inside `module.apply`: decoder pass followed by the lm head
        decoder_module = module._get_decoder_module()
        outputs = decoder_module(
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        )
        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = module.model.variables["params"]["shared"][
                "embedding"
            ]
            lm_logits = module.lm_head.apply(
                {"params": {"kernel": shared_embedding.T}}, hidden_states
            )
        else:
            lm_logits = module.lm_head(hidden_states)

        return lm_logits, outputs

    outputs = self.module.apply(
        inputs,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    # with a mutable collection, `apply` returns (result, mutated_variables)
    if use_cache:
        (lm_logits, decoder_outputs), past = outputs
    else:
        lm_logits, decoder_outputs = outputs

    if return_dict:
        outputs = FlaxCausalLMOutputWithCrossAttentions(
            logits=lm_logits,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )
    else:
        outputs = (lm_logits,) + decoder_outputs[1:]

    # add updated cache to model output
    if use_cache:
        updated_cache = unfreeze(past["cache"])
        if return_dict:
            outputs["past_key_values"] = updated_cache
        else:
            outputs = outputs[:1] + (updated_cache,) + outputs[1:]

    return outputs
|
1737 |
+
|
1738 |
+
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    max_length,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    encoder_outputs=None,
    **kwargs,
):
    """Build the keyword arguments fed to the decoder at each generation step.

    Args:
        decoder_input_ids: array of shape ``(batch_size, seq_length)``; only its
            shape is read here.
        max_length: maximum generated length; the cache and the decoder
            attention mask are sized ``max_length - 1``.
        attention_mask: returned unchanged as ``encoder_attention_mask``.
        decoder_attention_mask: optional mask written into the start of the
            all-ones mask and used to derive the position ids.
        encoder_outputs: forwarded to ``self.init_cache`` and returned as-is.
        **kwargs: accepted and ignored.

    Returns:
        dict with keys ``past_key_values``, ``encoder_outputs``,
        ``encoder_attention_mask``, ``decoder_attention_mask`` and
        ``decoder_position_ids``.
    """
    # NOTE: annotations use `jnp.ndarray` (as `generate` does) because
    # `jnp.DeviceArray` is evaluated at definition time and is not available
    # in every JAX release.

    # initializing the cache
    batch_size, seq_length = decoder_input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
    # Note that usually one would have to put 0's in the attention_mask for
    # x > input_ids.shape[-1] and x < cache_length.
    # But since the decoder uses a causal mask, those positions are masked anyways.
    # Thus we can create a single static attention_mask here, which is more
    # efficient for compilation.
    extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
    if decoder_attention_mask is not None:
        # positions count only the attended tokens seen so far
        position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(
            extended_attention_mask, decoder_attention_mask, (0, 0)
        )
    else:
        position_ids = jnp.broadcast_to(
            jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
        )

    return {
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "encoder_attention_mask": attention_mask,
        "decoder_attention_mask": extended_attention_mask,
        "decoder_position_ids": position_ids,
    }
|
1772 |
+
|
1773 |
+
def generate(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    do_sample: Optional[bool] = None,
    prng_key: Optional[jnp.ndarray] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    num_beams: Optional[int] = None,
    no_repeat_ngram_size: Optional[int] = None,
    min_length: Optional[int] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    trace: bool = True,
    params: Optional[Dict[str, jnp.ndarray]] = None,
    condition_scale: Optional[float] = 1.0,
    input_ids_uncond: Optional[jnp.ndarray] = None,
    attention_mask_uncond: Optional[jnp.ndarray] = None,
    **model_kwargs,
):
    """Generate sequences; edited to allow super conditioning.

    Arguments left as ``None`` fall back to the matching ``self.config``
    attribute (``max_length``, ``bos_token_id``, ``pad_token_id``,
    ``eos_token_id``, ``decoder_start_token_id``, ``do_sample``, ``num_beams``).
    Dispatches to greedy search, sampling or beam search depending on
    ``do_sample`` and ``num_beams``.

    Super conditioning is enabled with ``condition_scale != 1.0``. It requires
    ``input_ids_uncond``, ``do_sample=True`` and ``num_beams == 1``, and the
    encoder outputs must be computed here (not passed in ``model_kwargs``).

    Raises:
        ValueError: if no decoder start token is available for an
            encoder-decoder model, or if super conditioning is requested but
            no unconditional encoder inputs could be prepared.
        NotImplementedError: for beam sampling (``do_sample`` with
            ``num_beams > 1``).
    """

    # set init values
    max_length = max_length if max_length is not None else self.config.max_length
    bos_token_id = (
        bos_token_id if bos_token_id is not None else self.config.bos_token_id
    )
    pad_token_id = (
        pad_token_id if pad_token_id is not None else self.config.pad_token_id
    )
    eos_token_id = (
        eos_token_id if eos_token_id is not None else self.config.eos_token_id
    )
    # `is not None` (not truthiness) so that an explicit token id of 0 is honored
    decoder_start_token_id = (
        decoder_start_token_id
        if decoder_start_token_id is not None
        else self.config.decoder_start_token_id
    )
    prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

    if decoder_start_token_id is None and self.config.is_encoder_decoder:
        raise ValueError(
            "`decoder_start_token_id` has to be defined for encoder-decoder generation."
        )

    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_beams = num_beams if num_beams is not None else self.config.num_beams

    # Always bound, so the sampling branch works when `encoder_outputs` is
    # precomputed or the model is not an encoder-decoder.
    model_kwargs_uncond = None

    if self.config.is_encoder_decoder:
        # add encoder_outputs to model_kwargs
        if model_kwargs.get("encoder_outputs") is None:
            model_kwargs_input = dict(model_kwargs)
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                input_ids,
                params,
                {"attention_mask": attention_mask, **model_kwargs_input},
            )
            if condition_scale != 1.0:
                assert (
                    input_ids_uncond is not None
                ), "`input_ids_uncond` has to be defined for super conditioning."
                assert (
                    do_sample is True
                ), "`do_sample` has to be True for super conditioning."
                assert (
                    num_beams == 1
                ), "`num_beams` has to be 1 for super conditioning."
                model_kwargs_uncond = (
                    self._prepare_encoder_decoder_kwargs_for_generation(
                        input_ids_uncond,
                        params,
                        {
                            "attention_mask": attention_mask_uncond,
                            **model_kwargs_input,
                        },
                    )
                )
        # prepare decoder_input_ids for generation
        input_ids = (
            jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
        )

    if not do_sample and num_beams == 1:
        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size,
            min_length,
            max_length,
            eos_token_id,
            forced_bos_token_id,
            forced_eos_token_id,
        )
        return self._greedy_search(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
        )
    elif do_sample and num_beams == 1:
        if condition_scale != 1.0 and model_kwargs_uncond is None:
            # previously surfaced as an UnboundLocalError further down
            raise ValueError(
                "Super conditioning requires the unconditional encoder outputs to be "
                "computed by `generate`; do not pass precomputed `encoder_outputs`."
            )
        logits_warper = self._get_logits_warper(
            top_k=top_k, top_p=top_p, temperature=temperature
        )
        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size,
            min_length,
            max_length,
            eos_token_id,
            forced_bos_token_id,
            forced_eos_token_id,
        )
        return self._sample(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            prng_key,
            logits_warper=logits_warper,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
            condition_scale=condition_scale,
            model_kwargs_uncond=model_kwargs_uncond,
        )
    elif not do_sample and num_beams > 1:
        # broadcast input_ids & encoder_outputs
        input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"][
                "last_hidden_state"
            ] = self._expand_to_num_beams(
                model_kwargs["encoder_outputs"]["last_hidden_state"],
                num_beams=num_beams,
            )

        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = self._expand_to_num_beams(
                model_kwargs["attention_mask"], num_beams=num_beams
            )

        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size,
            min_length,
            max_length,
            eos_token_id,
            forced_bos_token_id,
            forced_eos_token_id,
        )

        return self._beam_search(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
        )
    else:
        raise NotImplementedError("`Beam sampling is currently not implemented.")
|
1950 |
+
|
1951 |
+
def _sample(
    self,
    input_ids: jnp.ndarray,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    prng_key: Optional[jnp.ndarray] = None,
    logits_processor=None,
    logits_warper=None,
    trace: bool = True,
    params: Optional[Dict[str, jnp.ndarray]] = None,
    model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    condition_scale: float = 1.0,
    model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
):
    """Sample token sequences, optionally with super conditioning.

    Args:
        input_ids: ``(batch_size, cur_len)`` array of start tokens.
        max_length, pad_token_id, eos_token_id: fall back to ``self.config``
            when ``None``.
        prng_key: sampling key; defaults to ``jax.random.PRNGKey(0)``.
        logits_processor: callable ``(sequences, logits, cur_len) -> logits``;
            skipped when ``None``.
        logits_warper: callable with the same signature; skipped when ``None``.
        trace: when False the loop runs through ``self._run_loop_in_debug``
            instead of ``lax.while_loop``.
        params: forwarded to the model call.
        model_kwargs: extra inputs for ``prepare_inputs_for_generation``;
            ``None`` is treated as empty.
        condition_scale: when different from 1.0, logits are computed as
            ``uncond + condition_scale * (cond - uncond)``.
        model_kwargs_uncond: inputs of the unconditional pass; required when
            ``condition_scale != 1.0``.

    Returns:
        ``FlaxSampleOutput`` holding the ``(batch_size, max_length)`` sequences.
    """
    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = (
        pad_token_id if pad_token_id is not None else self.config.pad_token_id
    )
    eos_token_id = (
        eos_token_id if eos_token_id is not None else self.config.eos_token_id
    )
    prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
    # the declared default is None, which cannot be unpacked with `**`
    model_kwargs = {} if model_kwargs is None else model_kwargs

    batch_size, cur_len = input_ids.shape

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch-item holding current token in loop.
    sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
    sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

    # per batch-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)

    # For Seq2Seq generation, we only need to use the decoder instead of the
    # whole model in generation loop and pass it the `encoder_outputs`, which
    # are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(
        input_ids, max_length, **model_kwargs
    )
    if condition_scale != 1.0:
        model_kwargs_uncond = self.prepare_inputs_for_generation(
            input_ids, max_length, **model_kwargs_uncond
        )

    # initialize state
    state = SampleState(
        cur_len=cur_len,
        sequences=sequences,
        running_token=input_ids,
        is_sent_finished=is_sent_finished,
        prng_key=prng_key,
        model_kwargs=model_kwargs,
        model_kwargs_uncond=model_kwargs_uncond,
    )

    def sample_search_cond_fn(state):
        """state termination condition fn."""
        has_reached_max_length = state.cur_len == max_length
        all_sequence_finished = jnp.all(state.is_sent_finished)
        finish_generation = jnp.logical_or(
            has_reached_max_length, all_sequence_finished
        )
        return ~finish_generation

    def sample_search_body_fn(state):
        """state update fn."""
        prng_key, prng_key_next = jax.random.split(state.prng_key)
        model_outputs = model(
            state.running_token, params=params, **state.model_kwargs
        )

        logits = model_outputs.logits[:, -1]

        # perform super conditioning
        # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
        if condition_scale != 1.0:
            model_outputs_uncond = model(
                state.running_token, params=params, **state.model_kwargs_uncond
            )
            logits_uncond = model_outputs_uncond.logits[:, -1]
            logits = logits_uncond + condition_scale * (logits - logits_uncond)
        else:
            model_outputs_uncond = None

        # apply min_length, ...
        if logits_processor is not None:
            logits = logits_processor(state.sequences, logits, state.cur_len)
        # apply top_k, top_p, temperature
        # (the warper is handed the logits in place of input ids, as before)
        if logits_warper is not None:
            logits = logits_warper(logits, logits, state.cur_len)

        next_token = jax.random.categorical(prng_key, logits, axis=-1)

        next_is_sent_finished = state.is_sent_finished | (
            next_token == eos_token_id
        )
        # finished rows emit the pad token from now on
        next_token = (
            next_token * ~next_is_sent_finished
            + pad_token_id * next_is_sent_finished
        )
        next_token = next_token[:, None]

        next_sequences = lax.dynamic_update_slice(
            state.sequences, next_token, (0, state.cur_len)
        )
        next_model_kwargs = self.update_inputs_for_generation(
            model_outputs, state.model_kwargs
        )
        next_model_kwargs_uncond = (
            self.update_inputs_for_generation(
                model_outputs_uncond, state.model_kwargs_uncond
            )
            if condition_scale != 1.0
            else None
        )

        return SampleState(
            cur_len=state.cur_len + 1,
            sequences=next_sequences,
            running_token=next_token,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
            model_kwargs_uncond=next_model_kwargs_uncond,
            prng_key=prng_key_next,
        )

    # The very first prompt often has sequence length > 1, so run outside of
    # `lax.while_loop` to comply with TPU
    if input_ids.shape[1] > 1:
        state = sample_search_body_fn(state)

    if not trace:
        state = self._run_loop_in_debug(
            sample_search_cond_fn, sample_search_body_fn, state
        )
    else:
        state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)

    return FlaxSampleOutput(sequences=state.sequences)
|
src/dalle_mini/model/partitions.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from flax.core.frozen_dict import freeze
|
4 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
5 |
+
from jax.experimental import PartitionSpec as P
|
6 |
+
|
7 |
+
# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
|
8 |
+
# Sentinels
|
9 |
+
_unmatched = object()
|
10 |
+
|
11 |
+
# For specifying empty leaf dict `{}`
|
12 |
+
empty_dict = object()
|
13 |
+
|
14 |
+
|
15 |
+
def _match(qs, ks):
|
16 |
+
"""Return True if regexes in qs match any window of strings in tuple ks."""
|
17 |
+
# compile regexes and force complete match
|
18 |
+
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
|
19 |
+
for i in range(len(ks) - len(qs) + 1):
|
20 |
+
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
|
21 |
+
if matches and all(matches):
|
22 |
+
return True
|
23 |
+
return False
|
24 |
+
|
25 |
+
|
26 |
+
def _replacement_rules(rules):
|
27 |
+
def replace(key, val):
|
28 |
+
for rule, replacement in rules:
|
29 |
+
if _match(rule, key):
|
30 |
+
return replacement
|
31 |
+
return val
|
32 |
+
|
33 |
+
return replace
|
34 |
+
|
35 |
+
|
36 |
+
def _get_partition_rules():
|
37 |
+
return [
|
38 |
+
# embeddings
|
39 |
+
(("embed_positions", "embedding"), P("mp", None)),
|
40 |
+
(("embed_tokens", "embedding"), P("mp", None)),
|
41 |
+
(("rel_bias", "embedding"), P(None, "mp")),
|
42 |
+
# attention
|
43 |
+
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
44 |
+
(("out_proj", "kernel"), P("mp", None)),
|
45 |
+
# FFN
|
46 |
+
(("Dense_0", "kernel"), P(None, "mp")),
|
47 |
+
(("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
|
48 |
+
(("GLU.*", "Dense_2", "kernel"), P("mp", None)),
|
49 |
+
(("FFN.*", "Dense_1", "kernel"), P("mp", None)),
|
50 |
+
# layer norms
|
51 |
+
(("(bias|scale)",), None),
|
52 |
+
(("lm_head", "kernel"), P(None, "mp")),
|
53 |
+
# head scale and tau
|
54 |
+
(("(head_scale|tau)",), None),
|
55 |
+
]
|
56 |
+
|
57 |
+
|
58 |
+
def set_partitions(in_dict):
    """Map every leaf of ``in_dict`` to its partition spec.

    The nested dict is flattened, each key path is run through the partition
    rules, and the result is unflattened and frozen.

    Raises:
        AssertionError: if any key path is matched by no rule; the offending
            paths are printed first.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {k: replace(k, _unmatched) for k in flatten_dict(in_dict)}
    # identity, not equality: specs may define their own `__eq__`
    missing = [k for k, v in result.items() if v is _unmatched]
    for k in missing:
        print(f"Unmatched -> {k}")
    if missing:
        # explicit raise so the check also runs under `python -O`
        raise AssertionError("Incomplete partition spec.")
    return freeze(unflatten_dict(result))
|
src/dalle_mini/model/processor.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" DalleBart processor """
|
2 |
+
|
3 |
+
import jax.numpy as jnp
|
4 |
+
|
5 |
+
from .configuration import DalleBartConfig
|
6 |
+
from .text import TextNormalizer
|
7 |
+
from .tokenizer import DalleBartTokenizer
|
8 |
+
from .utils import PretrainedFromWandbMixin
|
9 |
+
|
10 |
+
|
11 |
+
class DalleBartProcessorBase:
    """Turn a batch of prompts into tokenized model inputs.

    Also precomputes the tokens of the empty prompt, which are attached to
    every batch for use with super conditioning.
    """

    def __init__(
        self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int
    ):
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        if normalize_text:
            # only built when needed (the attribute is absent otherwise)
            self.text_processor = TextNormalizer()
        # create unconditional tokens
        uncond = self.tokenizer(
            "",
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data
        self.input_ids_uncond = uncond["input_ids"]
        self.attention_mask_uncond = uncond["attention_mask"]

    def __call__(self, text: list = None):
        """Tokenize a list of prompts.

        Args:
            text: list of strings (a bare ``str`` is rejected).

        Returns:
            The tokenizer's data dict, plus ``input_ids_uncond`` and
            ``attention_mask_uncond`` repeated once per prompt.
        """
        # check that text is not a string
        assert not isinstance(text, str), "text must be a list of strings"

        if self.normalize_text:
            text = [self.text_processor(t) for t in text]
        res = self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        ).data
        # tokens used only with super conditioning
        n = len(text)
        res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
        res["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, n, axis=0)
        return res

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a processor from the tokenizer and config loaded with the same arguments."""
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)
|
55 |
+
|
56 |
+
|
57 |
+
class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    """DalleBart processor whose ``from_pretrained`` goes through ``PretrainedFromWandbMixin`` first."""
|
src/dalle_mini/model/text.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for processing text.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import html
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import emoji
|
12 |
+
import ftfy
|
13 |
+
from huggingface_hub import hf_hub_download
|
14 |
+
from unidecode import unidecode
|
15 |
+
|
16 |
+
# based on wiki word occurrence
person_token = [
    ("a person", 282265),
    ("someone", 121194),
    ("somebody", 12219),
]
# avoid repeating chars
temp_token = "xtokx"
|
19 |
+
|
20 |
+
|
21 |
+
class HashtagProcessor:
    """Insert spaces into run-together text such as hashtags.

    Adapted from the wordninja library. We use our wikipedia word count + a
    good heuristic to make it work: a word's cost is the log of its 1-based
    line number in the frequency file.
    """

    def __init__(self):
        wiki_word_frequency = hf_hub_download(
            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
        )
        lines = Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
        # the word is the first whitespace-separated field; blank lines are
        # skipped instead of raising IndexError
        words = (fields[0] for fields in map(str.split, lines) if fields)
        self._word_cost = {
            str(word): math.log(float(rank + 1)) for rank, word in enumerate(words)
        }
        self._max_word = max(len(word) for word in self._word_cost)
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        pieces = [self._split(chunk) for chunk in self._SPLIT_RE.split(s)]
        return " ".join(token for piece in pieces for token in piece)

    def _split(self, s):
        """Return an iterator over the minimal-cost tokens of ``s``."""

        # Find the best match for the i first characters, assuming cost has
        # been built for the i-1 first characters.
        # Returns a pair (match_cost, match_length).
        def best_match(i):
            candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
            return min(
                (c + self._word_cost.get(s[i - k - 1 : i].lower(), math.inf), k + 1)
                for k, c in candidates
            )

        # Build the cost array
        cost = [0]
        for i in range(1, len(s) + 1):
            c, _ = best_match(i)
            cost.append(c)

        # Backtrack to recover the minimal-cost string.
        out = []
        i = len(s)
        while i > 0:
            c, k = best_match(i)
            assert c == cost[i]
            token = s[i - k : i]
            merged = False
            # a lone apostrophe is never merged into the previous token
            if token != "'" and out:
                # re-attach split 's and split digits (digit followed by digit)
                if out[-1] == "'s" or (s[i - 1].isdigit() and out[-1][0].isdigit()):
                    out[-1] = token + out[-1]
                    merged = True
            if not merged:
                out.append(token)
            i -= k

        return reversed(out)
|
84 |
+
|
85 |
+
|
86 |
+
def replace_person_token(t):
    """Replace ``<person>`` placeholders with natural wording (used for CC12M).

    Runs of placeholders joined by commas, spaces or "and" become " people ".
    Each remaining placeholder becomes a phrase drawn at random from
    ``person_token``, weighted by its count.
    """
    # raw string: `\s` is not a valid escape in a normal string literal
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    while "<person>" in t:
        t = t.replace(
            "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
        )
    return t
|
94 |
+
|
95 |
+
|
96 |
+
def fix_html(t):
    """Unescape HTML entities twice, so doubly-escaped text is decoded too (from OpenAI CLIP)."""
    for _ in range(2):
        t = html.unescape(t)
    return t
|
99 |
+
|
100 |
+
|
101 |
+
def replace_punctuation_with_commas(t):
    """Replace each of ``()[].,|:;?!=+~-/{}`` with a comma."""
    # raw string with the inner "[" escaped: same character set as before,
    # without invalid string escapes
    return re.sub(r"[()\[\].,|:;?!=+~\-/{}]", ",", t)
|
103 |
+
|
104 |
+
|
105 |
+
def simplify_quotes(t):
    """Replace every single quote, double quote or backtick with a space-padded double quote."""
    quote_chars = r"['\"`]"
    return re.sub(quote_chars, ' " ', t)
|
107 |
+
|
108 |
+
|
109 |
+
def merge_quotes(t):
    """Collapse runs of double quotes and surrounding whitespace into one ``' " '``."""
    # raw string: `\s` is not a valid escape in a normal string literal
    return re.sub(r'(\s*"+\s*)+', ' " ', t)
|
111 |
+
|
112 |
+
|
113 |
+
def remove_comma_numbers(t):
    """Remove thousands separators from numbers, e.g. ``1,000,000`` -> ``1000000``."""
    # Two passes: matches cannot overlap, so one pass leaves every other
    # separator of a long number in place.
    for _ in range(2):
        t = re.sub(r"(\d),(\d{3})", r"\1\2", t)
    return t
|
118 |
+
|
119 |
+
|
120 |
+
def pre_process_dot_numbers(t):
    """Protect a dot between two word characters by replacing it with a temporary marker."""
    # raw pattern: `\w` and `\.` are not valid escapes in a normal string literal
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
|
122 |
+
|
123 |
+
|
124 |
+
def post_process_dot_numbers(t):
    """Restore the dots protected by ``pre_process_dot_numbers``."""
    # the marker is literal text, so a plain replace avoids treating it as a regex
    return t.replace(f"{temp_token}dot{temp_token}", ".")
|
126 |
+
|
127 |
+
|
128 |
+
def pre_process_quotes(t):
    """Protect apostrophes of contractions by replacing them with a temporary marker."""
    # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
    # (the duplicated `ll` alternative and unused capture groups were dropped)
    return re.sub(
        r"'(?=(?:[stdm]|ll|re|ve)\b)", rf"{temp_token}quote{temp_token}", t
    )
|
133 |
+
|
134 |
+
|
135 |
+
def post_process_quotes(t):
    """Restore the apostrophes protected by ``pre_process_quotes``."""
    # the marker is literal text, so a plain replace avoids treating it as a regex
    return t.replace(f"{temp_token}quote{temp_token}", "'")
|
137 |
+
|
138 |
+
|
139 |
+
def pre_process_dates(t):
    """Protect a slash between two digits by replacing it with a temporary marker."""
    # raw pattern: `\d` is not a valid escape in a normal string literal
    return re.sub(r"(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
|
141 |
+
|
142 |
+
|
143 |
+
def post_process_dates(t):
    """Restore the slashes protected by ``pre_process_dates``."""
    # the marker is literal text, so a plain replace avoids treating it as a regex
    return t.replace(f"{temp_token}slash{temp_token}", "/")
|
145 |
+
|
146 |
+
|
147 |
+
def merge_commas(t):
    """Collapse runs of commas and surrounding whitespace into a single ``", "``."""
    # raw string: `\s` is not a valid escape in a normal string literal
    return re.sub(r"(\s*,+\s*)+", ", ", t)
|
149 |
+
|
150 |
+
|
151 |
+
def add_space_after_commas(t):
    """Append a space to every comma."""
    return ", ".join(t.split(","))
|
153 |
+
|
154 |
+
|
155 |
+
def handle_special_chars(t):
    """Handle special characters."""
    # replace "-" with a space when between words without space
    t = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always add space around some characters: % & / $ *
    return re.sub(r"([%&/$*])", r" \1 ", t)
|
161 |
+
|
162 |
+
|
163 |
+
def expand_hashtags(t, hashtag_processor):
    """Remove # and try to split words.

    Each ``#word`` is replaced by ``hashtag_processor(word)``.
    """
    # raw string: `\w` is not a valid escape in a normal string literal
    return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
|
166 |
+
|
167 |
+
|
168 |
+
_re_ignore_chars = r"[_#\\]"
|
169 |
+
|
170 |
+
|
171 |
+
def ignore_chars(t):
|
172 |
+
"Ignore useless characters"
|
173 |
+
return re.sub(_re_ignore_chars, " ", t)
|
174 |
+
|
175 |
+
|
176 |
+
def remove_extra_spaces(t):
    """Collapse every run of whitespace (including tabs and newlines) into one space."""
    # raw string: `\s` is not a valid escape in a normal string literal
    return re.sub(r"\s+", " ", t)
|
179 |
+
|
180 |
+
|
181 |
+
def remove_repeating_chars(t):
    """Reduce a non-digit character repeated 4+ times to a single instance.

    Three repeats are kept because of roman numerals such as 'VIII'.
    """
    return re.sub(r"(\D)\1{3,}", lambda m: m.group(1), t)
|
184 |
+
|
185 |
+
|
186 |
+
def remove_urls(t):
    """Delete every ``http`` prefix together with the non-space characters that follow it."""
    url_pattern = re.compile(r"http\S+")
    return url_pattern.sub("", t)
|
188 |
+
|
189 |
+
|
190 |
+
def remove_html_tags(t):
    """Delete ``<...>`` spans (shortest match, not containing another ``<``)."""
    tag_pattern = re.compile("<[^<]+?>")
    return tag_pattern.sub("", t)
|
192 |
+
|
193 |
+
|
194 |
+
def remove_first_last_commas(t):
    """Strip whitespace, then at most one trailing and one leading comma, then strip again."""
    t = t.strip()
    if t.endswith(","):
        t = t[:-1]
    if t.startswith(","):
        t = t[1:]
    return t.strip()
|
199 |
+
|
200 |
+
|
201 |
+
def remove_wiki_ref(t):
    """Remove a ``[number]`` reference at the very start and at the very end of the text."""
    without_leading = re.sub(r"\A\s*\[\d+\]", "", t)
    without_trailing = re.sub(r"\[\d+\]\s*\Z", "", without_leading)
    return without_trailing
|
204 |
+
|
205 |
+
|
206 |
+
class TextNormalizer:
    """Normalize text by running it through a fixed sequence of cleaning steps."""

    def __init__(self):
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t):
        """Return the normalized form of ``t``, always prefixed with a single space."""
        # The order of the steps matters: the pre_/post_ pairs protect
        # characters while punctuation is rewritten in between.
        steps = (
            ftfy.fix_text,  # fix some characters
            fix_html,  # fix html
            emoji.demojize,  # decode emojis (would be removed by unidecode)
            unidecode,  # decode and simplify text: see unidecode library
            lambda s: s.lower(),  # lower case
            replace_person_token,  # replace <PERSON> (for CC12M)
            remove_wiki_ref,  # remove wiki reference (for WIT)
            remove_html_tags,  # remove html tags
            remove_urls,  # remove urls
            remove_comma_numbers,  # remove commas in numbers
            # handle dots in numbers and quotes - Part 1
            pre_process_dot_numbers,
            pre_process_quotes,
            pre_process_dates,
            handle_special_chars,  # handle special characters
            lambda s: expand_hashtags(s, self._hashtag_processor),  # handle hashtags
            ignore_chars,  # ignore useless characters
            simplify_quotes,  # simplify quotes
            replace_punctuation_with_commas,  # all punctuation becomes commas
            # handle dots in numbers and quotes - Part 2
            post_process_dot_numbers,
            post_process_quotes,
            post_process_dates,
            remove_repeating_chars,  # handle repeating characters
            merge_quotes,  # merge quotes
            merge_commas,  # merge commas
            remove_extra_spaces,  # remove multiple spaces
            remove_first_last_commas,  # remove first and last comma
        )
        for step in steps:
            t = step(t)
        # always start with a space
        return f" {t}"
|
src/dalle_mini/model/tokenizer.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" DalleBart tokenizer """
|
2 |
+
from transformers import BartTokenizerFast
|
3 |
+
|
4 |
+
from .utils import PretrainedFromWandbMixin
|
5 |
+
|
6 |
+
|
7 |
+
class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):
    """``BartTokenizerFast`` whose ``from_pretrained`` goes through ``PretrainedFromWandbMixin`` first."""
|
src/dalle_mini/model/utils.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import wandb
|
6 |
+
|
7 |
+
|
8 |
+
class PretrainedFromWandbMixin:
    """Mixin that lets ``from_pretrained`` accept a wandb artifact reference."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.

        A value containing ":" that is not an existing directory is treated as
        a wandb artifact. It is downloaded into a temporary directory, which
        is deleted once the superclass has finished loading from it. Any other
        value is passed to the superclass untouched.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            # test the string form so os.PathLike values do not raise TypeError
            name = str(pretrained_model_name_or_path)
            if ":" in name and not os.path.isdir(name):
                # wandb artifact
                if wandb.run is not None:
                    artifact = wandb.run.use_artifact(name)
                else:
                    artifact = wandb.Api().artifact(name)
                pretrained_model_name_or_path = artifact.download(tmp_dir)

            return super(PretrainedFromWandbMixin, cls).from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
|
tools/dataset/encode_dataset.ipynb
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "d0b72877",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Pre-encoding a dataset for DALL·E mini"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "markdown",
|
13 |
+
"id": "ba7b31e6",
|
14 |
+
"metadata": {},
|
15 |
+
"source": [
|
16 |
+
"This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
|
17 |
+
"\n",
|
18 |
+
"Adapt it to your own dataset and image encoder.\n",
|
19 |
+
"\n",
|
20 |
+
"At the end you should have a dataset of pairs:\n",
|
21 |
+
"* a caption defined as a string\n",
|
22 |
+
"* an encoded image defined as a list of int."
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": null,
|
28 |
+
"id": "3b59489e",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"from tqdm.notebook import tqdm\n",
|
33 |
+
"\n",
|
34 |
+
"import torchvision.transforms as T\n",
|
35 |
+
"\n",
|
36 |
+
"import webdataset as wds\n",
|
37 |
+
"\n",
|
38 |
+
"import jax\n",
|
39 |
+
"import braceexpand\n",
|
40 |
+
"from pathlib import Path"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "markdown",
|
45 |
+
"id": "c7c4c1e6",
|
46 |
+
"metadata": {},
|
47 |
+
"source": [
|
48 |
+
"## Configuration Parameters"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 3,
|
54 |
+
"id": "1265dbfe",
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
|
59 |
+
"encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
|
60 |
+
"\n",
|
61 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
|
62 |
+
" \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
|
63 |
+
" \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
|
64 |
+
")\n",
|
65 |
+
"\n",
|
66 |
+
"# good defaults for a TPU v3-8\n",
|
67 |
+
"batch_size = 128 # Per device\n",
|
68 |
+
"num_workers = 8 # For parallel processing\n",
|
69 |
+
"total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
|
70 |
+
"save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "code",
|
75 |
+
"execution_count": 5,
|
76 |
+
"id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
|
77 |
+
"metadata": {},
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"data": {
|
81 |
+
"text/plain": [
|
82 |
+
"['XXX/shard-0000.tar',\n",
|
83 |
+
" 'XXX/shard-0001.tar',\n",
|
84 |
+
" 'XXX/shard-0002.tar',\n",
|
85 |
+
" 'XXX/shard-0003.tar',\n",
|
86 |
+
" 'XXX/shard-0004.tar',\n",
|
87 |
+
" 'XXX/shard-0005.tar',\n",
|
88 |
+
" 'XXX/shard-0006.tar',\n",
|
89 |
+
" 'XXX/shard-0007.tar',\n",
|
90 |
+
" 'XXX/shard-0008.tar']"
|
91 |
+
]
|
92 |
+
},
|
93 |
+
"execution_count": 5,
|
94 |
+
"metadata": {},
|
95 |
+
"output_type": "execute_result"
|
96 |
+
}
|
97 |
+
],
|
98 |
+
"source": [
|
99 |
+
"shards = list(\n",
|
100 |
+
" braceexpand.braceexpand(shards)\n",
|
101 |
+
") # better display for tqdm with known length"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "markdown",
|
106 |
+
"id": "75dba8e2",
|
107 |
+
"metadata": {},
|
108 |
+
"source": [
|
109 |
+
"## Load data"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "markdown",
|
114 |
+
"id": "a1e8fb95",
|
115 |
+
"metadata": {},
|
116 |
+
"source": [
|
117 |
+
"We load data using `webdataset`."
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"id": "9ef5de9e",
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"ds = (\n",
|
128 |
+
" wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
|
129 |
+
" .decode(\"rgb\", handler=wds.warn_and_continue)\n",
|
130 |
+
" .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
|
131 |
+
" .batched(total_bs) # load in batch per worker (faster)\n",
|
132 |
+
")"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "markdown",
|
137 |
+
"id": "90981824",
|
138 |
+
"metadata": {},
|
139 |
+
"source": [
|
140 |
+
"Note:\n",
|
141 |
+
"* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
|
142 |
+
"* you may need to resize images in your pipeline (with `map_dict` for example); we assume they are already 256x256.\n",
|
143 |
+
"* you can also filter out some items using `select`."
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"id": "129c377d",
|
149 |
+
"metadata": {},
|
150 |
+
"source": [
|
151 |
+
"We can now inspect our data."
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": null,
|
157 |
+
"id": "8cac98cb",
|
158 |
+
"metadata": {
|
159 |
+
"scrolled": true
|
160 |
+
},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"%%time\n",
|
164 |
+
"images, captions = next(iter(ds))"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": null,
|
170 |
+
"id": "cd268fbf",
|
171 |
+
"metadata": {},
|
172 |
+
"outputs": [],
|
173 |
+
"source": [
|
174 |
+
"images.shape"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "code",
|
179 |
+
"execution_count": null,
|
180 |
+
"id": "5acfc4d8",
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [],
|
183 |
+
"source": [
|
184 |
+
"captions[:10]"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"execution_count": null,
|
190 |
+
"id": "c24693c0",
|
191 |
+
"metadata": {},
|
192 |
+
"outputs": [],
|
193 |
+
"source": [
|
194 |
+
"T.ToPILImage()(images[0].permute(2, 0, 1))"
|
195 |
+
]
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"cell_type": "markdown",
|
199 |
+
"id": "3059ffb1",
|
200 |
+
"metadata": {},
|
201 |
+
"source": [
|
202 |
+
"Finally we create our dataloader."
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"id": "c227c551",
|
209 |
+
"metadata": {},
|
210 |
+
"outputs": [],
|
211 |
+
"source": [
|
212 |
+
"dl = (\n",
|
213 |
+
" wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
|
214 |
+
") # avoid partial batch at the end of each worker"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "markdown",
|
219 |
+
"id": "a354472b",
|
220 |
+
"metadata": {},
|
221 |
+
"source": [
|
222 |
+
"## Image encoder\n",
|
223 |
+
"\n",
|
224 |
+
"We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": null,
|
230 |
+
"id": "47a8b818",
|
231 |
+
"metadata": {
|
232 |
+
"scrolled": true
|
233 |
+
},
|
234 |
+
"outputs": [],
|
235 |
+
"source": [
|
236 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
237 |
+
"from flax.jax_utils import replicate\n",
|
238 |
+
"\n",
|
239 |
+
"vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
|
240 |
+
"vqgan_params = replicate(vqgan.params)"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "markdown",
|
245 |
+
"id": "62ad01c3",
|
246 |
+
"metadata": {},
|
247 |
+
"source": [
|
248 |
+
"## Encoding"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "markdown",
|
253 |
+
"id": "20357f74",
|
254 |
+
"metadata": {},
|
255 |
+
"source": [
|
256 |
+
"Encoding is straightforward: `shard` distributes each batch across devices and `pmap` runs the encoder on them in parallel."
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "code",
|
261 |
+
"execution_count": null,
|
262 |
+
"id": "322a4619",
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [],
|
265 |
+
"source": [
|
266 |
+
"from flax.training.common_utils import shard\n",
|
267 |
+
"from functools import partial\n",
|
268 |
+
"\n",
|
269 |
+
"\n",
|
270 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
271 |
+
"def p_encode(batch, params):\n",
|
272 |
+
" # Not sure if we should `replicate` params, does not seem to have any effect\n",
|
273 |
+
" _, indices = vqgan.encode(batch, params=params)\n",
|
274 |
+
" return indices"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "code",
|
279 |
+
"execution_count": null,
|
280 |
+
"id": "ff6c10d4",
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [],
|
283 |
+
"source": [
|
284 |
+
"import pandas as pd\n",
|
285 |
+
"\n",
|
286 |
+
"\n",
|
287 |
+
"def encode_dataset(dataloader, output_dir, save_frequency):\n",
|
288 |
+
" output_dir.mkdir(parents=True, exist_ok=True)\n",
|
289 |
+
" all_captions = []\n",
|
290 |
+
" all_encoding = []\n",
|
291 |
+
" n_file = 1\n",
|
292 |
+
" for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
|
293 |
+
" images = images.numpy()\n",
|
294 |
+
" n = len(images) // 8 * 8\n",
|
295 |
+
" if n != len(images):\n",
|
296 |
+
" # get the max number of images we can (multiple of 8)\n",
|
297 |
+
" print(f\"Different sizes {n} vs {len(images)}\")\n",
|
298 |
+
" images = images[:n]\n",
|
299 |
+
" captions = captions[:n]\n",
|
300 |
+
" if not len(captions):\n",
|
301 |
+
" print(f\"No images/captions in batch...\")\n",
|
302 |
+
" continue\n",
|
303 |
+
" images = shard(images)\n",
|
304 |
+
" encoded = p_encode(images, vqgan_params)\n",
|
305 |
+
" encoded = encoded.reshape(-1, encoded.shape[-1])\n",
|
306 |
+
" all_captions.extend(captions)\n",
|
307 |
+
" all_encoding.extend(encoded.tolist())\n",
|
308 |
+
"\n",
|
309 |
+
" # save files\n",
|
310 |
+
" if (idx + 1) % save_frequency == 0:\n",
|
311 |
+
" print(f\"Saving file {n_file}\")\n",
|
312 |
+
" batch_df = pd.DataFrame.from_dict(\n",
|
313 |
+
" {\"caption\": all_captions, \"encoding\": all_encoding}\n",
|
314 |
+
" )\n",
|
315 |
+
" batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
|
316 |
+
" all_captions = []\n",
|
317 |
+
" all_encoding = []\n",
|
318 |
+
" n_file += 1\n",
|
319 |
+
"\n",
|
320 |
+
" if len(all_captions):\n",
|
321 |
+
" print(f\"Saving final file {n_file}\")\n",
|
322 |
+
" batch_df = pd.DataFrame.from_dict(\n",
|
323 |
+
" {\"caption\": all_captions, \"encoding\": all_encoding}\n",
|
324 |
+
" )\n",
|
325 |
+
" batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"cell_type": "code",
|
330 |
+
"execution_count": null,
|
331 |
+
"id": "7704863d",
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [],
|
334 |
+
"source": [
|
335 |
+
"encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
|
336 |
+
]
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "markdown",
|
340 |
+
"id": "8953dd84",
|
341 |
+
"metadata": {},
|
342 |
+
"source": [
|
343 |
+
"----"
|
344 |
+
]
|
345 |
+
}
|
346 |
+
],
|
347 |
+
"metadata": {
|
348 |
+
"interpreter": {
|
349 |
+
"hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
|
350 |
+
},
|
351 |
+
"kernelspec": {
|
352 |
+
"display_name": "Python 3 (ipykernel)",
|
353 |
+
"language": "python",
|
354 |
+
"name": "python3"
|
355 |
+
},
|
356 |
+
"language_info": {
|
357 |
+
"codemirror_mode": {
|
358 |
+
"name": "ipython",
|
359 |
+
"version": 3
|
360 |
+
},
|
361 |
+
"file_extension": ".py",
|
362 |
+
"mimetype": "text/x-python",
|
363 |
+
"name": "python",
|
364 |
+
"nbconvert_exporter": "python",
|
365 |
+
"pygments_lexer": "ipython3",
|
366 |
+
"version": "3.9.7"
|
367 |
+
}
|
368 |
+
},
|
369 |
+
"nbformat": 4,
|
370 |
+
"nbformat_minor": 5
|
371 |
+
}
|
tools/inference/inference_pipeline.ipynb
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "118UKH5bWCGa"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# DALL·E mini - Inference pipeline\n",
|
20 |
+
"\n",
|
21 |
+
"*Generate images from a text prompt*\n",
|
22 |
+
"\n",
|
23 |
+
"<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
|
24 |
+
"\n",
|
25 |
+
"This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
|
26 |
+
"\n",
|
27 |
+
"Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
|
28 |
+
"\n",
|
29 |
+
"For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "markdown",
|
34 |
+
"metadata": {
|
35 |
+
"id": "dS8LbaonYm3a"
|
36 |
+
},
|
37 |
+
"source": [
|
38 |
+
"## 🛠️ Installation and set-up"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {
|
45 |
+
"id": "uzjAM2GBYpZX"
|
46 |
+
},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"# Install required libraries\n",
|
50 |
+
"!pip install -q git+https://github.com/huggingface/transformers.git\n",
|
51 |
+
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
52 |
+
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "markdown",
|
57 |
+
"metadata": {
|
58 |
+
"id": "ozHzTkyv8cqU"
|
59 |
+
},
|
60 |
+
"source": [
|
61 |
+
"We load required models:\n",
|
62 |
+
"* DALL·E mini for generating encoded images from text\n",
|
63 |
+
"* VQGAN for decoding images\n",
|
64 |
+
"* CLIP for scoring predictions"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"metadata": {
|
71 |
+
"id": "K6CxW2o42f-w"
|
72 |
+
},
|
73 |
+
"outputs": [],
|
74 |
+
"source": [
|
75 |
+
"# Model references\n",
|
76 |
+
"\n",
|
77 |
+
"# dalle-mini\n",
|
78 |
+
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-3f0lem84:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
|
79 |
+
"DALLE_COMMIT_ID = None\n",
|
80 |
+
"\n",
|
81 |
+
"# VQGAN model\n",
|
82 |
+
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
|
83 |
+
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
|
84 |
+
"\n",
|
85 |
+
"# CLIP model\n",
|
86 |
+
"CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
|
87 |
+
"CLIP_COMMIT_ID = None"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"metadata": {
|
94 |
+
"id": "Yv-aR3t4Oe5v"
|
95 |
+
},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"import jax\n",
|
99 |
+
"import jax.numpy as jnp\n",
|
100 |
+
"\n",
|
101 |
+
"# check how many devices are available\n",
|
102 |
+
"jax.local_device_count()"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": null,
|
108 |
+
"metadata": {
|
109 |
+
"id": "HWnQrQuXOe5w"
|
110 |
+
},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"# type used for computation - use bfloat16 on TPUs\n",
|
114 |
+
"dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
|
115 |
+
"\n",
|
116 |
+
"# TODO: fix issue with bfloat16\n",
|
117 |
+
"dtype = jnp.float32"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {
|
124 |
+
"id": "92zYmvsQ38vL"
|
125 |
+
},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"# Load models & tokenizer\n",
|
129 |
+
"from dalle_mini import DalleBart, DalleBartProcessor\n",
|
130 |
+
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
|
131 |
+
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
|
132 |
+
"\n",
|
133 |
+
"# Load dalle-mini\n",
|
134 |
+
"model = DalleBart.from_pretrained(\n",
|
135 |
+
" DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
|
136 |
+
")\n",
|
137 |
+
"\n",
|
138 |
+
"# Load VQGAN\n",
|
139 |
+
"vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
|
140 |
+
"\n",
|
141 |
+
"# Load CLIP\n",
|
142 |
+
"clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
|
143 |
+
"clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"metadata": {
|
149 |
+
"id": "o_vH2X1tDtzA"
|
150 |
+
},
|
151 |
+
"source": [
|
152 |
+
"Model parameters are replicated on each device for faster inference."
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": null,
|
158 |
+
"metadata": {
|
159 |
+
"id": "wtvLoM48EeVw"
|
160 |
+
},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"from flax.jax_utils import replicate\n",
|
164 |
+
"\n",
|
165 |
+
"# convert model parameters for inference if requested\n",
|
166 |
+
"if dtype == jnp.bfloat16:\n",
|
167 |
+
" model.params = model.to_bf16(model.params)\n",
|
168 |
+
"\n",
|
169 |
+
"model._params = replicate(model.params)\n",
|
170 |
+
"vqgan._params = replicate(vqgan.params)\n",
|
171 |
+
"clip._params = replicate(clip.params)"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "markdown",
|
176 |
+
"metadata": {
|
177 |
+
"id": "0A9AHQIgZ_qw"
|
178 |
+
},
|
179 |
+
"source": [
|
180 |
+
"Model functions are compiled and parallelized to take advantage of multiple devices."
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": null,
|
186 |
+
"metadata": {
|
187 |
+
"id": "sOtoOmYsSYPz"
|
188 |
+
},
|
189 |
+
"outputs": [],
|
190 |
+
"source": [
|
191 |
+
"from functools import partial\n",
|
192 |
+
"\n",
|
193 |
+
"# model inference\n",
|
194 |
+
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
|
195 |
+
"def p_generate(\n",
|
196 |
+
" tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
|
197 |
+
"):\n",
|
198 |
+
" return model.generate(\n",
|
199 |
+
" **tokenized_prompt,\n",
|
200 |
+
" prng_key=key,\n",
|
201 |
+
" params=params,\n",
|
202 |
+
" top_k=top_k,\n",
|
203 |
+
" top_p=top_p,\n",
|
204 |
+
" temperature=temperature,\n",
|
205 |
+
" condition_scale=condition_scale,\n",
|
206 |
+
" )\n",
|
207 |
+
"\n",
|
208 |
+
"\n",
|
209 |
+
"# decode images\n",
|
210 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
211 |
+
"def p_decode(indices, params):\n",
|
212 |
+
" return vqgan.decode_code(indices, params=params)\n",
|
213 |
+
"\n",
|
214 |
+
"\n",
|
215 |
+
"# score images\n",
|
216 |
+
"@partial(jax.pmap, axis_name=\"batch\")\n",
|
217 |
+
"def p_clip(inputs, params):\n",
|
218 |
+
" logits = clip(params=params, **inputs).logits_per_image\n",
|
219 |
+
" return logits"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
{
|
223 |
+
"cell_type": "markdown",
|
224 |
+
"metadata": {
|
225 |
+
"id": "HmVN6IBwapBA"
|
226 |
+
},
|
227 |
+
"source": [
|
228 |
+
"A PRNG key is passed to the model on each device so that every device generates different samples."
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "code",
|
233 |
+
"execution_count": null,
|
234 |
+
"metadata": {
|
235 |
+
"id": "4CTXmlUkThhX"
|
236 |
+
},
|
237 |
+
"outputs": [],
|
238 |
+
"source": [
|
239 |
+
"import random\n",
|
240 |
+
"\n",
|
241 |
+
"# create a random key\n",
|
242 |
+
"seed = random.randint(0, 2**32 - 1)\n",
|
243 |
+
"key = jax.random.PRNGKey(seed)"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "markdown",
|
248 |
+
"metadata": {
|
249 |
+
"id": "BrnVyCo81pij"
|
250 |
+
},
|
251 |
+
"source": [
|
252 |
+
"## 🖍 Text Prompt"
|
253 |
+
]
|
254 |
+
},
|
255 |
+
{
|
256 |
+
"cell_type": "markdown",
|
257 |
+
"metadata": {
|
258 |
+
"id": "rsmj0Aj5OQox"
|
259 |
+
},
|
260 |
+
"source": [
|
261 |
+
"Prompts must be processed with the model's processor before they are passed to the model."
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "code",
|
266 |
+
"execution_count": null,
|
267 |
+
"metadata": {
|
268 |
+
"id": "YjjhUychOVxm"
|
269 |
+
},
|
270 |
+
"outputs": [],
|
271 |
+
"source": [
|
272 |
+
"from dalle_mini import DalleBartProcessor\n",
|
273 |
+
"\n",
|
274 |
+
"processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
|
275 |
+
]
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"cell_type": "markdown",
|
279 |
+
"metadata": {
|
280 |
+
"id": "BQ7fymSPyvF_"
|
281 |
+
},
|
282 |
+
"source": [
|
283 |
+
"Let's define a text prompt."
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": null,
|
289 |
+
"metadata": {
|
290 |
+
"id": "x_0vI9ge1oKr"
|
291 |
+
},
|
292 |
+
"outputs": [],
|
293 |
+
"source": [
|
294 |
+
"prompt = \"sunset over the lake in the mountains\""
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"cell_type": "code",
|
299 |
+
"execution_count": null,
|
300 |
+
"metadata": {
|
301 |
+
"id": "VKjEZGjtO49k"
|
302 |
+
},
|
303 |
+
"outputs": [],
|
304 |
+
"source": [
|
305 |
+
"tokenized_prompt = processor([prompt])"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "markdown",
|
310 |
+
"metadata": {
|
311 |
+
"id": "-CEJBnuJOe5z"
|
312 |
+
},
|
313 |
+
"source": [
|
314 |
+
"Finally we replicate it onto each device."
|
315 |
+
]
|
316 |
+
},
|
317 |
+
{
|
318 |
+
"cell_type": "code",
|
319 |
+
"execution_count": null,
|
320 |
+
"metadata": {
|
321 |
+
"id": "lQePgju5Oe5z"
|
322 |
+
},
|
323 |
+
"outputs": [],
|
324 |
+
"source": [
|
325 |
+
"tokenized_prompt = replicate(tokenized_prompt)"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"cell_type": "markdown",
|
330 |
+
"metadata": {
|
331 |
+
"id": "phQ9bhjRkgAZ"
|
332 |
+
},
|
333 |
+
"source": [
|
334 |
+
"## 🎨 Generate images\n",
|
335 |
+
"\n",
|
336 |
+
"We generate encoded images with the DALL·E mini model and decode them with the VQGAN."
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": null,
|
342 |
+
"metadata": {
|
343 |
+
"id": "d0wVkXpKqnHA"
|
344 |
+
},
|
345 |
+
"outputs": [],
|
346 |
+
"source": [
|
347 |
+
"# number of predictions\n",
|
348 |
+
"n_predictions = 32\n",
|
349 |
+
"\n",
|
350 |
+
"# We can customize top_k/top_p used for generating samples\n",
|
351 |
+
"gen_top_k = None\n",
|
352 |
+
"gen_top_p = None\n",
|
353 |
+
"temperature = 0.85\n",
|
354 |
+
"cond_scale = 3.0"
|
355 |
+
]
|
356 |
+
},
|
357 |
+
{
|
358 |
+
"cell_type": "code",
|
359 |
+
"execution_count": null,
|
360 |
+
"metadata": {
|
361 |
+
"id": "SDjEx9JxR3v8"
|
362 |
+
},
|
363 |
+
"outputs": [],
|
364 |
+
"source": [
|
365 |
+
"from flax.training.common_utils import shard_prng_key\n",
|
366 |
+
"import numpy as np\n",
|
367 |
+
"from PIL import Image\n",
|
368 |
+
"from tqdm.notebook import trange\n",
|
369 |
+
"\n",
|
370 |
+
"# generate images\n",
|
371 |
+
"images = []\n",
|
372 |
+
"for i in trange(n_predictions // jax.device_count()):\n",
|
373 |
+
" # get a new key\n",
|
374 |
+
" key, subkey = jax.random.split(key)\n",
|
375 |
+
" # generate images\n",
|
376 |
+
" encoded_images = p_generate(\n",
|
377 |
+
" tokenized_prompt,\n",
|
378 |
+
" shard_prng_key(subkey),\n",
|
379 |
+
" model.params,\n",
|
380 |
+
" gen_top_k,\n",
|
381 |
+
" gen_top_p,\n",
|
382 |
+
" temperature,\n",
|
383 |
+
" cond_scale,\n",
|
384 |
+
" )\n",
|
385 |
+
" # remove BOS\n",
|
386 |
+
" encoded_images = encoded_images.sequences[..., 1:]\n",
|
387 |
+
" # decode images\n",
|
388 |
+
" decoded_images = p_decode(encoded_images, vqgan.params)\n",
|
389 |
+
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
|
390 |
+
" for img in decoded_images:\n",
|
391 |
+
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
|
392 |
+
]
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "markdown",
|
396 |
+
"metadata": {
|
397 |
+
"id": "tw02wG9zGmyB"
|
398 |
+
},
|
399 |
+
"source": [
|
400 |
+
"Let's calculate their score with CLIP."
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"execution_count": null,
|
406 |
+
"metadata": {
|
407 |
+
"id": "FoLXpjCmGpju"
|
408 |
+
},
|
409 |
+
"outputs": [],
|
410 |
+
"source": [
|
411 |
+
"from flax.training.common_utils import shard\n",
|
412 |
+
"\n",
|
413 |
+
"# get clip scores\n",
|
414 |
+
"clip_inputs = clip_processor(\n",
|
415 |
+
" text=[prompt] * jax.device_count(),\n",
|
416 |
+
" images=images,\n",
|
417 |
+
" return_tensors=\"np\",\n",
|
418 |
+
" padding=\"max_length\",\n",
|
419 |
+
" max_length=77,\n",
|
420 |
+
" truncation=True,\n",
|
421 |
+
").data\n",
|
422 |
+
"logits = p_clip(shard(clip_inputs), clip.params)\n",
|
423 |
+
"logits = logits.squeeze().flatten()"
|
424 |
+
]
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"cell_type": "markdown",
|
428 |
+
"metadata": {
|
429 |
+
"id": "4AAWRm70LgED"
|
430 |
+
},
|
431 |
+
"source": [
|
432 |
+
"Let's display images ranked by CLIP score."
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": null,
|
438 |
+
"metadata": {
|
439 |
+
"id": "zsgxxubLLkIu"
|
440 |
+
},
|
441 |
+
"outputs": [],
|
442 |
+
"source": [
|
443 |
+
"print(f\"Prompt: {prompt}\\n\")\n",
|
444 |
+
"for idx in logits.argsort()[::-1]:\n",
|
445 |
+
" display(images[idx])\n",
|
446 |
+
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
447 |
+
]
|
448 |
+
}
|
449 |
+
],
|
450 |
+
"metadata": {
|
451 |
+
"accelerator": "GPU",
|
452 |
+
"colab": {
|
453 |
+
"collapsed_sections": [],
|
454 |
+
"include_colab_link": true,
|
455 |
+
"machine_shape": "hm",
|
456 |
+
"name": "DALL·E mini - Inference pipeline.ipynb",
|
457 |
+
"provenance": []
|
458 |
+
},
|
459 |
+
"kernelspec": {
|
460 |
+
"display_name": "Python 3 (ipykernel)",
|
461 |
+
"language": "python",
|
462 |
+
"name": "python3"
|
463 |
+
},
|
464 |
+
"language_info": {
|
465 |
+
"codemirror_mode": {
|
466 |
+
"name": "ipython",
|
467 |
+
"version": 3
|
468 |
+
},
|
469 |
+
"file_extension": ".py",
|
470 |
+
"mimetype": "text/x-python",
|
471 |
+
"name": "python",
|
472 |
+
"nbconvert_exporter": "python",
|
473 |
+
"pygments_lexer": "ipython3",
|
474 |
+
"version": "3.9.7"
|
475 |
+
}
|
476 |
+
},
|
477 |
+
"nbformat": 4,
|
478 |
+
"nbformat_minor": 0
|
479 |
+
}
|
tools/train/config/medium/config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.0,
|
3 |
+
"activation_function": "gelu",
|
4 |
+
"attention_dropout": 0.0,
|
5 |
+
"bos_token_id": 16385,
|
6 |
+
"d_model": 1408,
|
7 |
+
"decoder_attention_heads": 16,
|
8 |
+
"decoder_ffn_dim": 4096,
|
9 |
+
"decoder_layerdrop": 0.0,
|
10 |
+
"decoder_layers": 14,
|
11 |
+
"decoder_start_token_id": 16384,
|
12 |
+
"dropout": 0.0,
|
13 |
+
"encoder_attention_heads": 16,
|
14 |
+
"encoder_ffn_dim": 4096,
|
15 |
+
"encoder_layerdrop": 0.0,
|
16 |
+
"encoder_layers": 14,
|
17 |
+
"encoder_vocab_size": 50264,
|
18 |
+
"eos_token_id": 16385,
|
19 |
+
"gradient_checkpointing": false,
|
20 |
+
"image_length": 256,
|
21 |
+
"image_vocab_size": 16384,
|
22 |
+
"init_std": 0.01,
|
23 |
+
"is_encoder_decoder": true,
|
24 |
+
"max_text_length": 64,
|
25 |
+
"model_type": "dallebart",
|
26 |
+
"normalize_text": true,
|
27 |
+
"pad_token_id": 16385,
|
28 |
+
"scale_embedding": false,
|
29 |
+
"tie_word_embeddings": false,
|
30 |
+
"use_cache": true
|
31 |
+
}
|
tools/train/config/mega/config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.0,
|
3 |
+
"activation_function": "gelu",
|
4 |
+
"attention_dropout": 0.0,
|
5 |
+
"bos_token_id": 16385,
|
6 |
+
"d_model": 2048,
|
7 |
+
"decoder_attention_heads": 32,
|
8 |
+
"decoder_ffn_dim": 8192,
|
9 |
+
"decoder_layerdrop": 0.0,
|
10 |
+
"decoder_layers": 24,
|
11 |
+
"decoder_start_token_id": 16384,
|
12 |
+
"dropout": 0.0,
|
13 |
+
"encoder_attention_heads": 32,
|
14 |
+
"encoder_ffn_dim": 8192,
|
15 |
+
"encoder_layerdrop": 0.0,
|
16 |
+
"encoder_layers": 24,
|
17 |
+
"encoder_vocab_size": 50264,
|
18 |
+
"eos_token_id": 16385,
|
19 |
+
"image_length": 256,
|
20 |
+
"image_vocab_size": 16391,
|
21 |
+
"init_std": 0.01,
|
22 |
+
"is_encoder_decoder": true,
|
23 |
+
"max_text_length": 64,
|
24 |
+
"model_type": "dallebart",
|
25 |
+
"normalize_text": true,
|
26 |
+
"pad_token_id": 16385,
|
27 |
+
"scale_embedding": false,
|
28 |
+
"tie_word_embeddings": false,
|
29 |
+
"use_cache": true
|
30 |
+
}
|
tools/train/config/micro/config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.0,
|
3 |
+
"activation_function": "gelu",
|
4 |
+
"attention_dropout": 0.0,
|
5 |
+
"bos_token_id": 16385,
|
6 |
+
"d_model": 256,
|
7 |
+
"decoder_attention_heads": 2,
|
8 |
+
"decoder_ffn_dim": 256,
|
9 |
+
"decoder_layerdrop": 0.0,
|
10 |
+
"decoder_layers": 2,
|
11 |
+
"decoder_start_token_id": 16384,
|
12 |
+
"dropout": 0.0,
|
13 |
+
"encoder_attention_heads": 2,
|
14 |
+
"encoder_ffn_dim": 256,
|
15 |
+
"encoder_layerdrop": 0.0,
|
16 |
+
"encoder_layers": 2,
|
17 |
+
"encoder_vocab_size": 50264,
|
18 |
+
"eos_token_id": 16385,
|
19 |
+
"image_length": 256,
|
20 |
+
"image_vocab_size": 16391,
|
21 |
+
"init_std": 0.02,
|
22 |
+
"is_encoder_decoder": true,
|
23 |
+
"max_text_length": 64,
|
24 |
+
"model_type": "dallebart",
|
25 |
+
"normalize_text": true,
|
26 |
+
"pad_token_id": 16385,
|
27 |
+
"scale_embedding": false,
|
28 |
+
"tie_word_embeddings": false,
|
29 |
+
"use_cache": true
|
30 |
+
}
|
tools/train/config/mini/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.0,
|
3 |
+
"activation_function": "gelu",
|
4 |
+
"attention_dropout": 0.0,
|
5 |
+
"bos_token_id": 16385,
|
6 |
+
"d_model": 1024,
|
7 |
+
"decoder_attention_heads": 16,
|
8 |
+
"decoder_ffn_dim": 4096,
|
9 |
+
"decoder_layers": 12,
|
10 |
+
"decoder_start_token_id": 16384,
|
11 |
+
"dropout": 0.0,
|
12 |
+
"encoder_attention_heads": 16,
|
13 |
+
"encoder_ffn_dim": 4096,
|
14 |
+
"encoder_layers": 12,
|
15 |
+
"encoder_vocab_size": 50264,
|
16 |
+
"eos_token_id": 16385,
|
17 |
+
"gradient_checkpointing": false,
|
18 |
+
"image_length": 256,
|
19 |
+
"image_vocab_size": 16384,
|
20 |
+
"init_std": 0.02,
|
21 |
+
"is_encoder_decoder": true,
|
22 |
+
"max_text_length": 64,
|
23 |
+
"model_type": "dallebart",
|
24 |
+
"normalize_text": true,
|
25 |
+
"pad_token_id": 16385,
|
26 |
+
"scale_embedding": false,
|
27 |
+
"tie_word_embeddings": false,
|
28 |
+
"use_cache": true
|
29 |
+
}
|
tools/train/config/mini_glu/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"activation_dropout": 0.0,
|
3 |
+
"activation_function": "gelu",
|
4 |
+
"attention_dropout": 0.0,
|
5 |
+
"bos_token_id": 16385,
|
6 |
+
"d_model": 1024,
|
7 |
+
"decoder_attention_heads": 16,
|
8 |
+
"decoder_ffn_dim": 2730,
|
9 |
+
"decoder_layers": 12,
|
10 |
+
"decoder_start_token_id": 16384,
|
11 |
+
"dropout": 0.0,
|
12 |
+
"encoder_attention_heads": 16,
|
13 |
+
"encoder_ffn_dim": 2730,
|
14 |
+
"encoder_layers": 12,
|
15 |
+
"encoder_vocab_size": 50264,
|
16 |
+
"eos_token_id": 16385,
|
17 |
+
"gradient_checkpointing": false,
|
18 |
+
"image_length": 256,
|
19 |
+
"image_vocab_size": 16384,
|
20 |
+
"init_std": 0.02,
|
21 |
+
"is_encoder_decoder": true,
|
22 |
+
"max_text_length": 64,
|
23 |
+
"model_type": "dallebart",
|
24 |
+
"normalize_text": true,
|
25 |
+
"pad_token_id": 16385,
|
26 |
+
"scale_embedding": false,
|
27 |
+
"tie_word_embeddings": false,
|
28 |
+
"use_cache": true
|
29 |
+
}
|
tools/train/scalable_shampoo/README.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Notes
|
2 |
+
|
3 |
+
Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax).
|
4 |
+
|
5 |
+
Imports have been modified to be relative.
|
6 |
+
|
7 |
+
This will eventually be replaced with the `optax-shampoo` package.
|
tools/train/scalable_shampoo/distributed_shampoo.py
ADDED
@@ -0,0 +1,2267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# An implementation of distributed Shampoo optimizer from:
|
17 |
+
#
|
18 |
+
# Scalable Second Order Optimization for Deep Learning
|
19 |
+
# Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
|
20 |
+
# Preprint Paper: https://arxiv.org/abs/2002.09018
|
21 |
+
#
|
22 |
+
# This implementation moves computation of inverse pth root back to the
|
23 |
+
# accelerator (if higher precision is available).
|
24 |
+
#
|
25 |
+
# Authors: Rohan Anil (rohananil at google dot com)
|
26 |
+
# & Vineet Gupta (vineet at google dot com)
|
27 |
+
#
|
28 |
+
"""Distributed Shampoo Implementation."""
|
29 |
+
|
30 |
+
import enum
|
31 |
+
import functools
|
32 |
+
import itertools
|
33 |
+
from typing import Any, List, NamedTuple, Tuple
|
34 |
+
|
35 |
+
import chex
|
36 |
+
import jax
|
37 |
+
import jax.experimental.pjit as pjit
|
38 |
+
import jax.numpy as jnp
|
39 |
+
import numpy as np
|
40 |
+
import optax
|
41 |
+
from flax import struct
|
42 |
+
from jax import lax
|
43 |
+
|
44 |
+
from .quantization_utils import QuantizedValue
|
45 |
+
from .symmetric_matrices import symmetric_matrices
|
46 |
+
|
47 |
+
# Dtype used by the inverse-pth root routine.
# float64 requires hardware that supports it and the jax flag
# jax_enable_x64; without that flag it will default to float32.
_MAT_INV_PTH_ROOT_DTYPE = jnp.float64
|
51 |
+
|
52 |
+
|
53 |
+
@struct.dataclass
class TrainingMetrics:
    """Metrics recorded alongside the optimizer state during training."""

    inverse_pth_root_errors: chex.Array  # Error for inverse-pth roots.
    # TODO(rohananil): Add more important metrics to track during training.
|
57 |
+
|
58 |
+
|
59 |
+
# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    statistics: List[Any]  # Statistics (QuantizedValue, chex.Array)
    preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
|
69 |
+
|
70 |
+
|
71 |
+
# For training extremely large models, we keep a global state with the
# concatenated statistics and preconditioner states for all vars. This is so
# that we can annotate the leading axis to be sharded to save memory at the
# cost of communication.
@struct.dataclass
class GlobalShardedParameterStats:
    """Global (concatenated over all vars) optimizer state for sharded mode."""

    statistics: chex.Array  # Statistics
    preconditioners: chex.Array  # Preconditioners
    exponents: chex.Array  # exponents
|
80 |
+
|
81 |
+
|
82 |
+
# These are per-parameter local states. All statistics here mirror the
# parameter, thus the sharding is copied over from the param specification.
@struct.dataclass
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained."""

    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    training_metrics: TrainingMetrics  # Metrics (optional for training).
    # The two fields below are declared with pytree_node=False.
    index_start: np.int32 = struct.field(
        pytree_node=False
    )  # Index into global statistics array
    sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
|
96 |
+
|
97 |
+
|
98 |
+
def init_training_metrics(num_statistics):
    """Build a zero-initialized `TrainingMetrics`.

    Args:
        num_statistics: Number of statistics to track; a falsy value (0 or
            None) yields a single dummy entry.

    Returns:
        A `TrainingMetrics` whose error array is float32 zeros of length
        `max(num_statistics, 1)`.
    """
    # Since the downstream apis expect a jnp.array - we create a dummy one if
    # num_statistics=0.
    n = 1 if not num_statistics else num_statistics
    return TrainingMetrics(jnp.zeros([n], jnp.float32))
|
103 |
+
|
104 |
+
|
105 |
+
def init_training_metrics_shapes(num_statistics):
    """Describe the shape and dtype of the metrics from `init_training_metrics`.

    Args:
        num_statistics: Number of statistics to track; a falsy value (0 or
            None) yields a single dummy entry.

    Returns:
        A `TrainingMetrics` holding `[[n], jnp.float32]` (shape, dtype) in
        place of an array.
    """
    # Since the downstream apis expect a jnp.array - we create a dummy one if
    # num_statistics=0.
    n = 1 if not num_statistics else num_statistics
    return TrainingMetrics([[n], jnp.float32])
|
110 |
+
|
111 |
+
|
112 |
+
def init_training_metrics_pspec():
    """Return a `TrainingMetrics` holding an empty `PartitionSpec`."""
    return TrainingMetrics(pjit.PartitionSpec())
|
114 |
+
|
115 |
+
|
116 |
+
class ShardedShampooStats(NamedTuple):
    """Shampoo state in sharded mode."""

    global_stats: Any
    local_stats: Any
|
121 |
+
|
122 |
+
|
123 |
+
class ShampooState(NamedTuple):
    """Optimizer state: a `count` array plus the `stats` payload."""

    count: chex.Array
    stats: Any
|
126 |
+
|
127 |
+
|
128 |
+
class InitFnState(NamedTuple):
    """Bundle of three callables: init, partition-spec and shape/dtype."""

    init_fn: Any
    pspec_fn: Any
    shape_and_dtype_fn: Any
|
132 |
+
|
133 |
+
|
134 |
+
class GraftingType(enum.IntEnum):
    """Integer identifiers for the available grafting types."""

    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
    SQRT_N = 5
    ADAGRAD_NORMALIZED = 6
|
141 |
+
|
142 |
+
|
143 |
+
def power_iteration(
    matrix,
    num_iters=100,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
):
    r"""Power iteration algorithm.

    The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
    a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
    of `A`, and a vector v, which is the corresponding eigenvector of `A`.

    References:
        [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

    Args:
        matrix: the symmetric PSD matrix.
        num_iters: Number of iterations.
        error_tolerance: Iterative exit condition.
        precision: precision XLA related flag, the available options are: a)
            lax.Precision.DEFAULT (better step time, but not precise) b)
            lax.Precision.HIGH (increased precision, slower) c)
            lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
        A tuple `(eigenvector, eigenvalue)`; the eigenvector is normalized to
        unit norm before being returned.
    """
    matrix_size = matrix.shape[-1]

    def _iter_condition(state):
        # Continue while under the iteration budget and the eigenvalue
        # estimate is still moving by more than `error_tolerance`.
        i, unused_v, unused_s, unused_s_v, run_step = state
        return jnp.logical_and(i < num_iters, run_step)

    def _iter_body(state):
        """One step of power iteration."""
        i, new_v, s, s_v, unused_run_step = state
        new_v = new_v / jnp.linalg.norm(new_v)

        s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
        s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
        return (
            i + 1,
            s_v,
            s_new,
            s_v,
            jnp.greater(jnp.abs(s_new - s), error_tolerance),
        )

    # TODO: figure out how to use the step as the random seed. For now the
    # start vector is drawn from a fixed seed.
    v_0 = (
        np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
    )

    init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
    _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
    v_out = v_out / jnp.linalg.norm(v_out)
    return v_out, s_out
|
200 |
+
|
201 |
+
|
202 |
+
def mat_power(
    mat_m,
    p,
    precision=lax.Precision.HIGHEST,
):
    """A simple matrix power method. M^p where p can be TracedValue.

    Uses exponentiation by squaring inside a `lax.while_loop`, so `p` does not
    have to be a Python constant.

    Args:
        mat_m: Square matrix to raise to a power.
        p: Non-negative integer exponent (may be a traced value).
        precision: XLA precision flag passed to every `jnp.matmul`.

    Returns:
        `mat_m` raised to the power `p`; the identity matrix when `p` is 0.
    """
    power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)

    def _iter_condition(state):
        i, _, _ = state
        return i > 0

    def _iter_body(state):
        i, power, mat = state

        # Fold the current squared factor into the result only when the
        # lowest remaining bit of the exponent is set.
        power = jax.lax.cond(
            i % 2 == 1,
            lambda: jnp.matmul(mat, power, precision=precision),
            lambda: power,
        )
        i //= 2
        mat = jnp.matmul(mat, mat, precision=precision)
        return i, power, mat

    _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
    return result
|
228 |
+
|
229 |
+
|
230 |
+
def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses the Coupled newton iterations algorithm for
    the computation of a matrix's inverse pth root.

    References:
        [Functions of Matrices, Theory and Computation,
         Nicholas J Higham, Pg 184, Eq 7.18](
         https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

    Args:
        matrix: the symmetric PSD matrix whose power is to be computed
        p: exponent, for p a positive integer.
        num_iters: Maximum number of iterations.
        ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
        error_tolerance: Error indicator, useful for early termination.
        precision: precision XLA related flag, the available options are: a)
            lax.Precision.DEFAULT (better step time, but not precise) b)
            lax.Precision.HIGH (increased precision, slower) c)
            lax.Precision.HIGHEST (best possible precision, slowest)

    Returns:
        A tuple `(matrix^(-1/p), error)`. The root is cast back to the dtype
        of the input; `error` is a float32 scalar (zero for 1x1 input).
    """

    # If the input is not square, materialize it from the concatenated form.
    if matrix.shape[0] != matrix.shape[1]:
        matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)

    assert matrix.shape[0] == matrix.shape[1]

    # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
    # Switch to f64 if you have hardware that supports it. Enable the jax flag
    # jax_enable_x64 for this to work.
    matrix_size = matrix.shape[0]
    orig_dtype = matrix.dtype
    matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
    alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
    identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
    _, max_ev = power_iteration(
        matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
    )
    # Scale the ridge term by the largest eigenvalue estimate (floored at 1e-6).
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)

    def _iter_condition(state):
        (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
        error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
        return jnp.logical_and(i < num_iters, error_above_threshold)

    def _iter_body(state):
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
        new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        # sometimes error increases after an iteration before decreasing and
        # converging. 1.2 factor is used to bound the maximal allowed increase.
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)

    if matrix_size == 1:
        resultant_mat_h = (matrix + ridge_epsilon) ** alpha
        # Keep the error a float32 array so both branches return the same type.
        error = jnp.array(0, jnp.float32)
    else:
        damped_matrix = matrix + ridge_epsilon * identity

        z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
        new_mat_m_0 = damped_matrix * z
        new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
        new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
        init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
        _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
            _iter_condition, _iter_body, init_state
        )
        error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
        # If the last step made the error grow past the allowed bound, fall
        # back to the iterate from the step before it.
        is_converged = jnp.asarray(convergence, old_mat_h.dtype)
        resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
        resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
    return resultant_mat_h, error
|
317 |
+
|
318 |
+
|
319 |
+
def merge_small_dims(shape_to_merge, max_dim):
    """Merge small dimensions.

    If there are some small dimensions, we collapse them:
    e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048] if max_dim = 1024

    Args:
        shape_to_merge: Shape to merge small dimensions.
        max_dim: Maximal dimension of output shape used in merging.

    Returns:
        Merged shape. A non-empty shape made only of ones collapses to [1];
        an empty shape yields an empty list.
    """
    if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
        return [1]

    resulting_shape = []
    product = 1
    for d in shape_to_merge:
        if product * d <= max_dim:
            product *= d
        else:
            # Flush the accumulated product (size-1 products are dropped) and
            # start a new group from the current dimension.
            if product > 1:
                resulting_shape.append(product)
            product = d
    if product > 1:
        resulting_shape.append(product)
    return resulting_shape
|
348 |
+
|
349 |
+
|
350 |
+
def pad_square_matrix(mat, max_size):
    """Pad a square matrix up to max_size.

    Args:
        mat: a matrix to pad.
        max_size: matrix size requested.

    Returns:
        Given M returns [[M, 0], [0, I]]. `mat` itself is returned when it is
        already `max_size` x `max_size`.

    Raises:
        ValueError: if `mat` is not square or is larger than `max_size`.
    """
    rows, cols = mat.shape
    if rows != cols:
        raise ValueError(
            "Must have rows == cols, instead got " f"rows={rows}, cols={cols}"
        )
    if cols > max_size:
        raise ValueError(
            "Must have cols <= max_size. Instead got "
            f"cols={cols}, max_size={max_size}."
        )
    if rows == max_size:
        return mat
    pad_size = max_size - rows

    # Assemble [[mat, 0], [0, I]] from the four blocks.
    zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
    zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
    eye = jnp.eye(pad_size, dtype=mat.dtype)
    mat = jnp.concatenate([mat, zs1], 1)
    mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
    return mat
|
380 |
+
|
381 |
+
|
382 |
+
def make_sliced_padding(
    symmetric_block_size,
    num_blocks,
    starting_block,
    dtype,
):
    """Returns padding for symmetric block matrix.

    Specifically, the padding is given concatenated rectangular matrices
    representing the lower-triangular rows below the starting block. For example,
    if we want to pad the symmetric matrix

    M = [[A, B^T]
         [B, C]],

    the desired output (in terms of the full matrix) with num_blocks = 4 is

    M_padded = [[A, B^T, 0, 0]
                [B, C,   0, 0]
                [0, 0,   I, 0]
                [0, 0,   0, I]].

    We would represent M as the block matrix mat = [A, B, C]. In this form, the
    additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
    triangular parts in the third and fourth rows).

    Args:
        symmetric_block_size: The size of each block.
        num_blocks: The total number of blocks.
        starting_block: The block where to start the padding.
        dtype: The type to use for the blocks.

    Returns:
        An array with `symmetric_block_size` rows holding the concatenated
        padding blocks; it has zero columns when `starting_block` equals
        `num_blocks`.
    """
    if starting_block == num_blocks:
        return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)

    blocks = []
    for i in range(starting_block, num_blocks):
        # Block row i contributes i zero blocks followed by one identity block.
        blocks.append(
            jnp.zeros(
                shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
            )
        )
        blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
    return jnp.concatenate(blocks, axis=-1)
|
426 |
+
|
427 |
+
|
428 |
+
def pad_block_symmetric_matrix(
    mat,
    symmetric_block_size,
    max_num_blocks,
):
    """Returns the padded blocked symmetric matrix.

    The size of the padded matrix will be:
      [symmetric_block_size, symmetric_block_size * max_num_blocks]

    The input matrix can either:
      - Be square with size less or equal to symmetric_block_size. In this case,
        mat will first be padded to a square matrix of size symmetric_block_size,
        and then be padded again up to the full size of the blocked matrix.
      - Be a rectangle with number of rows equal to block size.
        In this case, number of columns must be a multiple of number of rows, and
        the ratio must correspond to a block representation of a symmetric matrix.
        That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
        number of block rows represented by the matrix.

    Args:
        mat: The input block matrix.
        symmetric_block_size: The size of blocks.
        max_num_blocks: The largest number of blocks to pad to.

    Returns:
        `mat` (padded to the block size if needed) concatenated along the last
        axis with the padding blocks from `make_sliced_padding`.

    Raises:
        ValueError: if `mat` has more rows than `symmetric_block_size`, more
            rows than columns, or more columns than
            `symmetric_block_size * max_num_blocks`.
    """
    rows, cols = mat.shape
    if rows > symmetric_block_size:
        raise ValueError(
            "Must have rows <= symmetric_block_size. Instead got "
            f"rows={rows}, symmetric_block_size={symmetric_block_size}."
        )
    if rows > cols:
        raise ValueError(
            "Must have rows <= cols, instead got " f"rows={rows}, cols={cols}."
        )
    if cols > symmetric_block_size * max_num_blocks:
        raise ValueError(
            "Must have cols <= symmetric_block_size * max_num_blocks "
            f"Instead got cols={cols}, "
            f"symmetric_block_size={symmetric_block_size}, "
            f"max_num_blocks={max_num_blocks}."
        )
    if rows < symmetric_block_size:
        mat = pad_square_matrix(mat, max_size=symmetric_block_size)
    # Update rows and cols after possibly padding in pad_square_matrix.
    rows, cols = mat.shape
    assert rows == symmetric_block_size
    assert cols % rows == 0
    filled_blocks = cols // rows
    padding_blocks = make_sliced_padding(
        symmetric_block_size=symmetric_block_size,
        num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
        starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
        dtype=mat.dtype,
    )
    return jnp.concatenate([mat, padding_blocks], axis=-1)
|
484 |
+
|
485 |
+
|
486 |
+
def pad_vector(vec, max_size):
    """Zero-pad a vector at its tail so that it holds `max_size` entries.

    Args:
      vec: the vector to extend.
      max_size: requested length of the result; must not be smaller than the
        current length of `vec`.

    Returns:
      `vec` itself when it already has `max_size` entries, otherwise [vec, 0]
      where the appended zeros use the dtype of `vec`.
    """
    current = vec.shape[0]
    assert current <= max_size
    missing = max_size - current
    if missing == 0:
        # Nothing to append; hand back the input untouched.
        return vec
    tail = jnp.zeros([missing], dtype=vec.dtype)
    return jnp.concatenate([vec, tail], 0)
|
503 |
+
|
504 |
+
|
505 |
+
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
    """Conditionally runs `compute_fn`, avoiding wasteful XLA buffer allocation.

    The conditional is expressed as a `jax.lax.while_loop` whose carried state
    starts with `predicate`. The loop body calls `compute_fn` and clears that
    flag, so the body runs at most once.

    Args:
      predicate: scalar boolean; `compute_fn` is evaluated only when it is true.
      compute_fn: callable returning a sequence of values with the same
        structure, shapes and dtypes as `init_state`.
      init_state: sequence (list or tuple) of values returned unchanged when
        `predicate` is false.
      *args: positional arguments forwarded to `compute_fn`.
      **kwargs: keyword arguments forwarded to `compute_fn`.

    Returns:
      A tuple holding either the results of `compute_fn` or the entries of
      `init_state`.
    """

    def _iter_body(unused_state):
        results = compute_fn(*args, **kwargs)
        # Clearing the leading flag stops the loop after this single iteration.
        return (False, *results)

    def _iter_condition(state):
        return state[0]

    # Unpacking (instead of list concatenation) also accepts a tuple
    # `init_state`, which previously failed with a TypeError.
    results = jax.lax.while_loop(
        _iter_condition, _iter_body, (predicate, *init_state)
    )
    return tuple(results[1:])
|
519 |
+
|
520 |
+
|
521 |
+
class BlockPartitioner:
    """Cuts a tensor into blocks of bounded extent and glues them back."""

    def __init__(self, param, block_size):
        """Precomputes the split metadata for tensors shaped like `param`.

        Args:
          param: tensor whose shape defines the partitioning.
          block_size: maximal block extent per axis. An axis is split only when
            0 < block_size < axis length; otherwise it is kept whole.
        """
        self._shape = param.shape
        self._splits = []
        per_axis_sizes = []
        for axis, dim in enumerate(param.shape):
            if not 0 < block_size < dim:
                # Axis stays in one piece.
                per_axis_sizes.append(np.array([dim], dtype=np.int32))
                continue
            # dim - 1, otherwise split appends a 0-size array when dim is an
            # exact multiple of block_size.
            num_cuts = (dim - 1) // block_size
            cut_points = (np.arange(num_cuts, dtype=np.int32) + 1) * block_size
            block_sizes = np.ones(num_cuts + 1, dtype=np.int32) * block_size
            # The last block holds whatever remains after the final cut.
            block_sizes[-1] = dim - cut_points[-1]
            self._splits.append((axis, cut_points))
            per_axis_sizes.append(block_sizes)
        self._num_splits = len(per_axis_sizes)
        # One square [d, d] shape per axis of every block, in block order.
        self._preconditioner_shapes = [
            [size, size]
            for combo in itertools.product(*per_axis_sizes)
            for size in combo
        ]

    def shapes_for_preconditioners(self):
        """Returns the list of [d, d] shapes computed at construction time."""
        return self._preconditioner_shapes

    def num_splits(self):
        """Returns the number of per-axis size entries (one per tensor axis)."""
        return self._num_splits

    def partition(self, tensor):
        """Partition tensor into blocks."""
        assert tensor.shape == self._shape
        blocks = [tensor]
        for axis, cut_points in self._splits:
            # Split every block obtained so far along the next blocked axis.
            blocks = [
                piece
                for block in blocks
                for piece in jnp.split(
                    block, indices_or_sections=cut_points, axis=axis
                )
            ]
        return blocks

    def merge_partitions(self, partitions):
        """Merge partitions back to original shape."""
        # Undo the splits in the reverse of the order they were applied.
        for axis, cut_points in reversed(self._splits):
            group = len(cut_points) + 1
            partitions = [
                jnp.concatenate(partitions[start : start + group], axis=axis)
                for start in range(0, len(partitions), group)
            ]
        assert len(partitions) == 1
        return partitions[0]
|
579 |
+
|
580 |
+
|
581 |
+
class Preconditioner:
    """Derives per-block statistics and preconditioned gradients for a param."""

    def __init__(self, param, block_size, best_effort_shape_interpretation):
        """Sets up the reshaping and block partitioning for `param`.

        Args:
          param: parameter tensor; only its shape is retained.
          block_size: block size handed to `merge_small_dims` and to the
            `BlockPartitioner`.
          best_effort_shape_interpretation: when true, the shape is first
            transformed by `merge_small_dims` before partitioning.
        """
        original_shape = param.shape
        if best_effort_shape_interpretation:
            transformed_shape = merge_small_dims(original_shape, block_size)
        else:
            transformed_shape = original_shape
        self._original_shape = original_shape
        self._transformed_shape = transformed_shape
        self._partitioner = BlockPartitioner(
            jnp.reshape(param, transformed_shape), block_size
        )

    def statistics_from_grad(self, grad):
        """Compute statistics from gradients.

        Args:
          grad: Gradient to compute statistics from.

        Returns:
          A flat list with one statistic per axis of every partition, obtained
          by contracting the block with itself over all other axes.
        """
        blocks = self._partitioner.partition(
            jnp.reshape(grad, self._transformed_shape)
        )
        stats = []
        for block in blocks:
            ndim = len(block.shape)
            for axis in range(ndim):
                # Contract over every axis except the one being kept.
                contracted = [a for a in range(ndim) if a != axis]
                stats.append(
                    jnp.tensordot(block, block, axes=(contracted, contracted))
                )
        return stats

    def shapes_for_preconditioners(self):
        """Returns the preconditioner shapes reported by the partitioner."""
        return self._partitioner.shapes_for_preconditioners()

    def exponent_for_preconditioner(self):
        """Returns exponent to use for inverse-pth root M^{-1/p}."""
        return 2 * len(self._transformed_shape)

    def preconditioned_grad(self, grad, preconditioners):
        """Precondition the gradient.

        Args:
          grad: A gradient tensor to precondition.
          preconditioners: A flat list of preconditioners, laid out as
            consecutive groups of `num_splits()` entries per partition.

        Returns:
          A preconditioned gradient with the original shape of the parameter.
        """
        blocks = self._partitioner.partition(
            jnp.reshape(grad, self._transformed_shape)
        )
        per_block = self._partitioner.num_splits()
        preconditioned_blocks = []
        for index, block in enumerate(blocks):
            start = index * per_block
            block_preconditioners = preconditioners[start : start + per_block]
            result = block
            for axis in range(len(block.shape)):
                # Each contraction consumes the leading axis and appends the
                # preconditioned one at the end, so after a full pass the axes
                # are back in their original order.
                result = jnp.tensordot(
                    result, block_preconditioners[axis], axes=[[0], [0]]
                )
            preconditioned_blocks.append(result)
        merged = self._partitioner.merge_partitions(preconditioned_blocks)
        return jnp.reshape(merged, self._original_shape)
|
652 |
+
|
653 |
+
|
654 |
+
def _convert_to_parameter_stats(global_stats, local_stat):
    """Builds per-parameter stats from the stacked (sharded) global stats.

    Args:
      global_stats: object whose `statistics` and `preconditioners` are stacked
        along the leading axis.
      local_stat: per-parameter sharded stats providing `index_start`, `sizes`
        and the diagonal/momentum/metrics fields that are passed through.

    Returns:
      A `ParameterStats` whose statistics and preconditioners are the rows
      belonging to this parameter, each cropped to its own [size, size].
    """
    first = int(local_stat.index_start)
    last = first + int(len(local_stat.sizes))
    window = slice(first, last)
    statistics = global_stats.statistics[window, :, :]
    preconditioners = global_stats.preconditioners[window, :, :]
    # Crop each matrix back to its own top-left [size, size] block.
    new_statistics = [
        statistics[k][:size, :size] for k, size in enumerate(local_stat.sizes)
    ]
    new_preconditioners = [
        preconditioners[k][:size, :size]
        for k, size in enumerate(local_stat.sizes)
    ]
    return ParameterStats(
        local_stat.diagonal_statistics,
        new_statistics,
        new_preconditioners,
        local_stat.diagonal_momentum,
        local_stat.momentum,
        local_stat.training_metrics,
    )
|
673 |
+
|
674 |
+
|
675 |
+
def _convert_from_parameter_stats(parameter_stats, local_stats):
    """Creates sharded stats from parameter stats.

    Args:
      parameter_stats: source of the diagonal statistics, diagonal momentum,
        momentum and training metrics.
      local_stats: previous sharded stats; its `index_start` and `sizes` are
        carried over unchanged.

    Returns:
      A new `LocalShardedParameterStats`.
    """
    carried_fields = (
        parameter_stats.diagonal_statistics,
        parameter_stats.diagonal_momentum,
        parameter_stats.momentum,
        parameter_stats.training_metrics,
    )
    return LocalShardedParameterStats(
        *carried_fields, local_stats.index_start, local_stats.sizes
    )
|
685 |
+
|
686 |
+
|
687 |
+
def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
    """Writes freshly computed errors back into the local statistics.

    Args:
      local_stats: iterable of `LocalShardedParameterStats`.
      errors: array of errors indexed like the stacked global statistics.
      inverse_failure_threshold: error value treated as "no usable result".

    Returns:
      A list of new `LocalShardedParameterStats`. For entries with non-empty
      `sizes`, an error is taken over only when it is > 0 and differs from
      `inverse_failure_threshold`; otherwise the previously stored
      `inverse_pth_root_errors` value is kept.
    """
    updated = []
    for stat in local_stats:
        begin = int(stat.index_start)
        stop = begin + int(len(stat.sizes))
        stat_errors = errors[begin:stop]
        if stat.sizes:
            usable = jnp.logical_and(
                stat_errors > 0.0, stat_errors != inverse_failure_threshold
            )
            stat_errors = jnp.where(
                usable,
                stat_errors,
                stat.training_metrics.inverse_pth_root_errors,
            )
        updated.append(
            LocalShardedParameterStats(
                stat.diagonal_statistics,
                stat.diagonal_momentum,
                stat.momentum,
                TrainingMetrics(stat_errors),
                stat.index_start,
                stat.sizes,
            )
        )
    return updated
|
713 |
+
|
714 |
+
|
715 |
+
def batch(x, num_devices):
    """Batch `x` so that the leading axis is num_devices.

    Args:
      x: non-empty list of equally shaped arrays.
      num_devices: number of groups to form; must evenly divide `len(x)`.

    Returns:
      An array of shape [num_devices, len(x) // num_devices, ...].

    Raises:
      ValueError: if `x` is empty or its length is not a multiple of
        `num_devices`.
    """
    n = len(x)
    # Evaluate the remainder first so that num_devices == 0 still surfaces as
    # ZeroDivisionError, exactly as the division did before.
    if n % num_devices != 0 or n == 0:
        raise ValueError(
            "Length of x must be a non-zero multiple of num_devices. Instead "
            f"got len(x)={n}, num_devices={num_devices}."
        )
    # Integer division avoids the float round trip of int(n / num_devices).
    b = n // num_devices
    return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
|
720 |
+
|
721 |
+
|
722 |
+
def unbatch(batched_values):
    """Unbatch values across the two leading axes and return a list of elements.

    Args:
      batched_values: array of shape [b1, b2, ...], where b2 is the number of
        batches (preconditioner computations) per core.

    Returns:
      A list of b1 * b2 arrays, each of shape `batched_values.shape[2:]`,
      ordered with the first axis varying slowest.
    """
    b1, b2 = batched_values.shape[0], batched_values.shape[1]
    values = jnp.asarray(batched_values)
    # Index the two batch axes explicitly. An axis-less jnp.squeeze would also
    # drop size-1 dimensions that belong to the elements themselves (e.g. it
    # would turn 1x1 matrices into scalars).
    return [values[i, j] for i in range(b1) for j in range(b2)]
|
735 |
+
|
736 |
+
|
737 |
+
def distributed_shampoo(
|
738 |
+
learning_rate,
|
739 |
+
block_size,
|
740 |
+
beta1=0.9,
|
741 |
+
beta2=0.999,
|
742 |
+
diagonal_epsilon=1e-10,
|
743 |
+
matrix_epsilon=1e-6,
|
744 |
+
weight_decay=0.0,
|
745 |
+
start_preconditioning_step=5,
|
746 |
+
preconditioning_compute_steps=1,
|
747 |
+
statistics_compute_steps=1,
|
748 |
+
best_effort_shape_interpretation=True,
|
749 |
+
graft_type=GraftingType.SGD,
|
750 |
+
nesterov=True,
|
751 |
+
exponent_override=0,
|
752 |
+
# Pass pmap 'batch axis name' in pmap mode.
|
753 |
+
batch_axis_name=None,
|
754 |
+
### Only set following 3 params in pjit/spmd mode.
|
755 |
+
### WARNING: Experimental
|
756 |
+
statistics_partition_spec=None,
|
757 |
+
preconditioner_partition_spec=None,
|
758 |
+
num_devices_for_pjit=None,
|
759 |
+
shard_optimizer_states=False,
|
760 |
+
###
|
761 |
+
### Experimental memory reduction mode
|
762 |
+
best_effort_memory_usage_reduction=False,
|
763 |
+
###
|
764 |
+
inverse_failure_threshold=0.1,
|
765 |
+
moving_average_for_momentum=False,
|
766 |
+
skip_preconditioning_dim_size_gt=4096,
|
767 |
+
clip_by_scaled_gradient_norm=None,
|
768 |
+
precision=lax.Precision.HIGHEST,
|
769 |
+
):
|
770 |
+
"""Distributed Shampoo optimizer.
|
771 |
+
|
772 |
+
Distributed Shampoo is a second-order preconditioned method (concretely, a
|
773 |
+
variant of full-matrix Adagrad), that provides significant convergence and
|
774 |
+
wall-clock time improvements compared to conventional first-order methods,
|
775 |
+
and that has been shown to scale to large state-of-the-art deep learning
|
776 |
+
models.
|
777 |
+
|
778 |
+
References:
|
779 |
+
Scalable Second Order Optimization for Deep Learning,
|
780 |
+
Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
|
781 |
+
|
782 |
+
Preprint: https://arxiv.org/abs/2002.09018
|
783 |
+
|
784 |
+
Args:
|
785 |
+
learning_rate: the step size used to update the parameters.
|
786 |
+
block_size: Block size for large layers (if > 0). Preconditioning compute
|
787 |
+
operation is cubic in the dimension of the tensor. Block size allows us to
|
788 |
+
chunk the layers into sub-layers of maximal dimension dictated by this
|
789 |
+
value. Use 128 as default (increase if you have compute budget).
|
790 |
+
beta1: momentum parameter.
|
791 |
+
beta2: second moment averaging parameter.
|
792 |
+
diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
|
793 |
+
to AdaGrad is enabled).
|
794 |
+
matrix_epsilon: epsilon to add to statistics before computing inverse pth
|
795 |
+
root. If you are running in f32 precision for inverse pth root
|
796 |
+
(recommended today) this can go upto 1e-6. If you have latest hardware
|
797 |
+
with native f64 precision, set this upto 1e-12.
|
798 |
+
weight_decay: Weight decay for regularization.
|
799 |
+
start_preconditioning_step: When to start Shampoo update before which
|
800 |
+
diagonal update is used. This is because we dont have enough information
|
801 |
+
to do stable inverse.
|
802 |
+
preconditioning_compute_steps: How often to compute preconditioner.
|
803 |
+
Performance tuning params for controlling memory and compute requirements.
|
804 |
+
Ideally set this and statistics_compute_steps params to 1.
|
805 |
+
statistics_compute_steps: How often to compute statistics.
|
806 |
+
best_effort_shape_interpretation: If there are some small dimensions,
|
807 |
+
collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
|
808 |
+
block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
|
809 |
+
graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
|
810 |
+
optimizer. This allows us to plugin the Shampoo optimizer into settings
|
811 |
+
where SGD/AdaGrad is already well tuned.
|
812 |
+
nesterov: Nesterov momentum.
|
813 |
+
exponent_override: Override the exponent used in matrix inverse.
|
814 |
+
batch_axis_name: labeled axis over pmap for data-parallel training the
|
815 |
+
optimizer used for.
|
816 |
+
statistics_partition_spec: PartitionSpec to be used in sharded mode.
|
817 |
+
preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
|
818 |
+
num_devices_for_pjit: Number of devices to parallelize over when using pjit.
|
819 |
+
shard_optimizer_states: Shard optimizer states to save memory in model
|
820 |
+
parallel training.
|
821 |
+
best_effort_memory_usage_reduction: Best effort memory usage reduction. -
|
822 |
+
diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
|
823 |
+
statistics, preconditioners -> jnp.int16 + diagonals
|
824 |
+
inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
|
825 |
+
determine that using this threshold.
|
826 |
+
moving_average_for_momentum: Whether to use moving average for momentum
|
827 |
+
instead of exponential moving average.
|
828 |
+
skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
|
829 |
+
greater than this value.
|
830 |
+
clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
|
831 |
+
using RMSProp Grafting).
|
832 |
+
precision: precision XLA related flag, the available options are: a)
|
833 |
+
lax.Precision.DEFAULT (better step time, but not precise) b)
|
834 |
+
lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
|
835 |
+
(best possible precision, slowest)
|
836 |
+
|
837 |
+
Returns:
|
838 |
+
a GradientTransformation.
|
839 |
+
"""
|
840 |
+
|
841 |
+
def _graft_type_has_diagonal_statistics():
|
842 |
+
"""Returns True if using diagonal firt order method for grafting."""
|
843 |
+
return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N
|
844 |
+
|
845 |
+
def _graft_type_has_diagonal_momentum_states():
|
846 |
+
"""Returns False if using SQRT_N for grafting."""
|
847 |
+
return graft_type != GraftingType.SQRT_N
|
848 |
+
|
849 |
+
def quantized_dtype_for_momentum_buffers():
|
850 |
+
return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
|
851 |
+
|
852 |
+
# TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
|
853 |
+
def quantized_dtype_for_diagonal_statistics_buffers():
|
854 |
+
return jnp.float32
|
855 |
+
|
856 |
+
# Preconditioner and statistics are both stored as int16 in this mode.
|
857 |
+
# We take out the diagonal to make quantization easier.
|
858 |
+
def quantized_dtype_for_second_moment_statistics_buffers():
|
859 |
+
return (
|
860 |
+
jnp.int16
|
861 |
+
if best_effort_memory_usage_reduction and batch_axis_name
|
862 |
+
else jnp.float32
|
863 |
+
)
|
864 |
+
|
865 |
+
# Preconditioner and statistics are both stored as int16 in this mode.
|
866 |
+
# We take out the diagonal to make quantization easier.
|
867 |
+
def quantized_dtype_for_second_moment_preconditioner_buffers():
|
868 |
+
return (
|
869 |
+
jnp.int16
|
870 |
+
if best_effort_memory_usage_reduction and batch_axis_name
|
871 |
+
else jnp.float32
|
872 |
+
)
|
873 |
+
|
874 |
+
def _to_float(maybe_quantized):
|
875 |
+
if isinstance(maybe_quantized, QuantizedValue):
|
876 |
+
return maybe_quantized.to_float()
|
877 |
+
else:
|
878 |
+
return maybe_quantized
|
879 |
+
|
880 |
+
def _maybe_quantize_statistics(statistics_list):
|
881 |
+
return _maybe_quantize_matrices_with_dtype(
|
882 |
+
statistics_list, quantized_dtype_for_second_moment_statistics_buffers()
|
883 |
+
)
|
884 |
+
|
885 |
+
def _maybe_quantize_preconditioners(statistics_list):
|
886 |
+
return _maybe_quantize_matrices_with_dtype(
|
887 |
+
statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers()
|
888 |
+
)
|
889 |
+
|
890 |
+
def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
|
891 |
+
if quantized_dtype != jnp.float32:
|
892 |
+
return [
|
893 |
+
QuantizedValue.from_float_value(
|
894 |
+
s, quantized_dtype, extract_diagonal=True
|
895 |
+
)
|
896 |
+
for s in statistics_list
|
897 |
+
]
|
898 |
+
else:
|
899 |
+
return statistics_list
|
900 |
+
|
901 |
+
def _maybe_dequantize_preconditioners(preconditioner_list):
|
902 |
+
return _maybe_dequantize_matrices_with_dtype(
|
903 |
+
preconditioner_list,
|
904 |
+
quantized_dtype_for_second_moment_preconditioner_buffers(),
|
905 |
+
)
|
906 |
+
|
907 |
+
def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
|
908 |
+
if quantized_dtype != jnp.float32:
|
909 |
+
return [s.to_float() for s in statistics_list]
|
910 |
+
else:
|
911 |
+
return statistics_list
|
912 |
+
|
913 |
+
def _quantize_diagonal_statistics(diagonal_statistics):
|
914 |
+
return QuantizedValue.from_float_value(
|
915 |
+
diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers()
|
916 |
+
)
|
917 |
+
|
918 |
+
def _quantize_momentum(momentum_statistics):
|
919 |
+
return QuantizedValue.from_float_value(
|
920 |
+
momentum_statistics, quantized_dtype_for_momentum_buffers()
|
921 |
+
)
|
922 |
+
|
923 |
+
def sharded_init_fn(params):
|
924 |
+
"""Returns optimizer state (for PJIT mode).
|
925 |
+
|
926 |
+
Args:
|
927 |
+
params: the parameters that should be updated.
|
928 |
+
"""
|
929 |
+
params_flat, treedef = jax.tree_flatten(params)
|
930 |
+
# Find max size to pad to.
|
931 |
+
max_size = 0
|
932 |
+
for param in params_flat:
|
933 |
+
preconditioner = Preconditioner(
|
934 |
+
param, block_size, best_effort_shape_interpretation
|
935 |
+
)
|
936 |
+
if not _skip_preconditioning(param):
|
937 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
938 |
+
sizes = [s[0] for s in shapes]
|
939 |
+
max_size = max(max(sizes), max_size)
|
940 |
+
|
941 |
+
padded_statistics = []
|
942 |
+
padded_preconditioners = []
|
943 |
+
local_stats_flat = []
|
944 |
+
exponents = []
|
945 |
+
for param in params_flat:
|
946 |
+
preconditioner = Preconditioner(
|
947 |
+
param, block_size, best_effort_shape_interpretation
|
948 |
+
)
|
949 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
950 |
+
sizes = []
|
951 |
+
|
952 |
+
statistics = []
|
953 |
+
preconditioners = []
|
954 |
+
index_start = len(padded_statistics)
|
955 |
+
if not _skip_preconditioning(param):
|
956 |
+
sizes = [s[0] for s in shapes]
|
957 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
958 |
+
statistics = [
|
959 |
+
matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
|
960 |
+
for s in shapes
|
961 |
+
]
|
962 |
+
preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
|
963 |
+
padded_statistics.extend(statistics)
|
964 |
+
padded_preconditioners.extend(preconditioners)
|
965 |
+
exponent = (
|
966 |
+
preconditioner.exponent_for_preconditioner()
|
967 |
+
if exponent_override == 0
|
968 |
+
else exponent_override
|
969 |
+
)
|
970 |
+
exponents.extend([exponent] * len(shapes))
|
971 |
+
|
972 |
+
diagonal_statistics = []
|
973 |
+
if _graft_type_has_diagonal_statistics():
|
974 |
+
diagonal_statistics = jnp.zeros_like(param)
|
975 |
+
|
976 |
+
diagonal_momentum = _quantize_momentum([])
|
977 |
+
momentum = _quantize_momentum(jnp.zeros_like(param))
|
978 |
+
if _graft_type_has_diagonal_momentum_states():
|
979 |
+
diagonal_momentum = _quantize_momentum((jnp.zeros_like(param)))
|
980 |
+
|
981 |
+
local_stats_flat.append(
|
982 |
+
LocalShardedParameterStats(
|
983 |
+
_quantize_diagonal_statistics(diagonal_statistics),
|
984 |
+
diagonal_momentum,
|
985 |
+
momentum,
|
986 |
+
init_training_metrics(len(sizes)),
|
987 |
+
index_start,
|
988 |
+
sizes,
|
989 |
+
)
|
990 |
+
)
|
991 |
+
|
992 |
+
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
993 |
+
to_pad = -len(padded_statistics) % num_devices_for_pjit
|
994 |
+
if max_size == 0:
|
995 |
+
to_pad = num_devices_for_pjit
|
996 |
+
max_size = block_size
|
997 |
+
stat_dtype = jnp.float32
|
998 |
+
else:
|
999 |
+
stat_dtype = padded_statistics[0].dtype
|
1000 |
+
# Pad the statistics and preconditioner matrices to be a multiple of
|
1001 |
+
# num devices.
|
1002 |
+
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
|
1003 |
+
# is split on.
|
1004 |
+
padded_statistics.extend(
|
1005 |
+
[jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
|
1006 |
+
)
|
1007 |
+
padded_preconditioners.extend(
|
1008 |
+
[jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
|
1009 |
+
)
|
1010 |
+
exponents.extend([1 for _ in range(to_pad)])
|
1011 |
+
global_stats = GlobalShardedParameterStats(
|
1012 |
+
jnp.stack(padded_statistics),
|
1013 |
+
jnp.stack(padded_preconditioners),
|
1014 |
+
jnp.stack(exponents),
|
1015 |
+
)
|
1016 |
+
return ShampooState(
|
1017 |
+
count=jnp.zeros([], jnp.int32),
|
1018 |
+
stats=ShardedShampooStats(global_stats, local_stats),
|
1019 |
+
)
|
1020 |
+
|
1021 |
+
def _max_statistics_size_from_params(params):
|
1022 |
+
max_size = 0
|
1023 |
+
for param in params:
|
1024 |
+
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
|
1025 |
+
preconditioner = Preconditioner(
|
1026 |
+
param_clone, block_size, best_effort_shape_interpretation
|
1027 |
+
)
|
1028 |
+
if not _skip_preconditioning(param):
|
1029 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1030 |
+
sizes = [s[0] for s in shapes]
|
1031 |
+
max_size = max(max(sizes), max_size)
|
1032 |
+
return max_size
|
1033 |
+
|
1034 |
+
def _remove_leading_sharding_annotation(pspec):
|
1035 |
+
"""Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
|
1036 |
+
# None and PSpec(None) are valid PSpecs.
|
1037 |
+
if pspec and len(pspec) > 1:
|
1038 |
+
return pjit.PartitionSpec(*pspec[1:])
|
1039 |
+
else:
|
1040 |
+
return []
|
1041 |
+
|
1042 |
+
def sharded_init_partition_spec_fn(
|
1043 |
+
params, params_partition_spec, partition_spec_for_statistics
|
1044 |
+
):
|
1045 |
+
"""Returns a parallel state tree with PartitionSpec associated with state.
|
1046 |
+
|
1047 |
+
|
1048 |
+
Args:
|
1049 |
+
params: A pytree with params.
|
1050 |
+
params_partition_spec: A pytree with PartitionSpec for params.
|
1051 |
+
partition_spec_for_statistics: PartitionSpec for the statistics.
|
1052 |
+
"""
|
1053 |
+
# Parallel lists of spec, and params.
|
1054 |
+
param_pspec_flat, _ = jax.tree_flatten(
|
1055 |
+
params_partition_spec, is_leaf=lambda x: x is None
|
1056 |
+
)
|
1057 |
+
params_flat, treedef = jax.tree_flatten(params)
|
1058 |
+
assert param_pspec_flat
|
1059 |
+
assert params_flat
|
1060 |
+
# Step is replicated across cores.
|
1061 |
+
# None means cores.
|
1062 |
+
local_stats_flat = []
|
1063 |
+
num_statistics = 0
|
1064 |
+
for param, param_pspec in zip(params_flat, param_pspec_flat):
|
1065 |
+
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
|
1066 |
+
preconditioner = Preconditioner(
|
1067 |
+
param_clone, block_size, best_effort_shape_interpretation
|
1068 |
+
)
|
1069 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1070 |
+
sizes = []
|
1071 |
+
|
1072 |
+
index_start = num_statistics
|
1073 |
+
if not _skip_preconditioning(param):
|
1074 |
+
sizes = [s[0] for s in shapes]
|
1075 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1076 |
+
num_statistics += len(shapes)
|
1077 |
+
|
1078 |
+
diagonal_statistics_pspec = []
|
1079 |
+
diagonal_statistics_scale_pspec = []
|
1080 |
+
if _graft_type_has_diagonal_statistics():
|
1081 |
+
# Identically shaped param.
|
1082 |
+
diagonal_statistics_pspec = param_pspec
|
1083 |
+
if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
|
1084 |
+
diagonal_statistics_scale_pspec = (
|
1085 |
+
_remove_leading_sharding_annotation(param_pspec)
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
m1_pspec = []
|
1089 |
+
m1_scale_pspec = []
|
1090 |
+
if _graft_type_has_diagonal_momentum_states():
|
1091 |
+
m1_pspec = param_pspec
|
1092 |
+
if quantized_dtype_for_momentum_buffers() != jnp.float32:
|
1093 |
+
m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
|
1094 |
+
|
1095 |
+
m2_pspec = param_pspec
|
1096 |
+
m2_scale_pspec = []
|
1097 |
+
if quantized_dtype_for_momentum_buffers() != jnp.float32:
|
1098 |
+
m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
|
1099 |
+
|
1100 |
+
local_stats_flat.append(
|
1101 |
+
LocalShardedParameterStats(
|
1102 |
+
QuantizedValue(
|
1103 |
+
diagonal_statistics_pspec,
|
1104 |
+
[],
|
1105 |
+
diagonal_statistics_scale_pspec,
|
1106 |
+
quantized_dtype_for_diagonal_statistics_buffers(),
|
1107 |
+
False,
|
1108 |
+
list(param.shape),
|
1109 |
+
),
|
1110 |
+
QuantizedValue(
|
1111 |
+
m1_pspec,
|
1112 |
+
[],
|
1113 |
+
m1_scale_pspec,
|
1114 |
+
quantized_dtype_for_momentum_buffers(),
|
1115 |
+
False,
|
1116 |
+
list(param.shape),
|
1117 |
+
),
|
1118 |
+
QuantizedValue(
|
1119 |
+
m2_pspec,
|
1120 |
+
[],
|
1121 |
+
m2_scale_pspec,
|
1122 |
+
quantized_dtype_for_momentum_buffers(),
|
1123 |
+
False,
|
1124 |
+
list(param.shape),
|
1125 |
+
),
|
1126 |
+
init_training_metrics_pspec(),
|
1127 |
+
index_start,
|
1128 |
+
sizes,
|
1129 |
+
)
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
1133 |
+
global_stats = GlobalShardedParameterStats(
|
1134 |
+
partition_spec_for_statistics,
|
1135 |
+
partition_spec_for_statistics,
|
1136 |
+
pjit.PartitionSpec(),
|
1137 |
+
)
|
1138 |
+
count_pspec = pjit.PartitionSpec()
|
1139 |
+
return ShampooState(
|
1140 |
+
count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
|
1141 |
+
)
|
1142 |
+
|
1143 |
+
def sharded_init_shape_and_dtype_fn(params):
|
1144 |
+
"""Returns a parallel state tree with shape, dtype associated with state.
|
1145 |
+
|
1146 |
+
|
1147 |
+
Args:
|
1148 |
+
params: A pytree with params.
|
1149 |
+
"""
|
1150 |
+
# Parallel lists of spec, and params.
|
1151 |
+
params_flat, treedef = jax.tree_flatten(params)
|
1152 |
+
assert params_flat
|
1153 |
+
# Step is replicated across cores.
|
1154 |
+
# None means cores.
|
1155 |
+
local_stats_flat = []
|
1156 |
+
num_statistics = 0
|
1157 |
+
for param in params_flat:
|
1158 |
+
param_clone = jnp.zeros(param.shape, dtype=param.dtype)
|
1159 |
+
preconditioner = Preconditioner(
|
1160 |
+
param_clone, block_size, best_effort_shape_interpretation
|
1161 |
+
)
|
1162 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1163 |
+
sizes = []
|
1164 |
+
|
1165 |
+
index_start = num_statistics
|
1166 |
+
if not _skip_preconditioning(param):
|
1167 |
+
sizes = [s[0] for s in shapes]
|
1168 |
+
shapes = preconditioner.shapes_for_preconditioners()
|
1169 |
+
num_statistics += len(shapes)
|
1170 |
+
|
1171 |
+
diagonal_statistics_shape_and_dtype = []
|
1172 |
+
diagonal_statistics_scale_shape_and_dtype = []
|
1173 |
+
if _graft_type_has_diagonal_statistics():
|
1174 |
+
diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
|
1175 |
+
qdtype = quantized_dtype_for_diagonal_statistics_buffers()
|
1176 |
+
if qdtype != jnp.float32:
|
1177 |
+
diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
|
1178 |
+
diagonal_statistics_scale_shape_and_dtype = [
|
1179 |
+
list(param.shape)[1:],
|
1180 |
+
param.dtype,
|
1181 |
+
]
|
1182 |
+
|
1183 |
+
qdtype = quantized_dtype_for_momentum_buffers()
|
1184 |
+
m1_shape_and_dtype = []
|
1185 |
+
m1_scale_shape_and_dtype = []
|
1186 |
+
if _graft_type_has_diagonal_momentum_states():
|
1187 |
+
m1_shape_and_dtype = [list(param.shape), qdtype]
|
1188 |
+
if quantized_dtype_for_momentum_buffers() != jnp.float32:
|
1189 |
+
m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
|
1190 |
+
|
1191 |
+
m2_shape_and_dtype = [list(param.shape), param.dtype]
|
1192 |
+
m2_scale_shape_and_dtype = []
|
1193 |
+
if qdtype != jnp.float32:
|
1194 |
+
m2_shape_and_dtype = [list(param.shape), qdtype]
|
1195 |
+
m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
|
1196 |
+
|
1197 |
+
local_stats_flat.append(
|
1198 |
+
LocalShardedParameterStats(
|
1199 |
+
QuantizedValue(
|
1200 |
+
diagonal_statistics_shape_and_dtype,
|
1201 |
+
[],
|
1202 |
+
diagonal_statistics_scale_shape_and_dtype,
|
1203 |
+
quantized_dtype_for_diagonal_statistics_buffers(),
|
1204 |
+
False,
|
1205 |
+
list(param.shape),
|
1206 |
+
),
|
1207 |
+
QuantizedValue(
|
1208 |
+
m1_shape_and_dtype,
|
1209 |
+
[],
|
1210 |
+
m1_scale_shape_and_dtype,
|
1211 |
+
quantized_dtype_for_momentum_buffers(),
|
1212 |
+
False,
|
1213 |
+
list(param.shape),
|
1214 |
+
),
|
1215 |
+
QuantizedValue(
|
1216 |
+
m2_shape_and_dtype,
|
1217 |
+
[],
|
1218 |
+
m2_scale_shape_and_dtype,
|
1219 |
+
quantized_dtype_for_momentum_buffers(),
|
1220 |
+
False,
|
1221 |
+
list(param.shape),
|
1222 |
+
),
|
1223 |
+
init_training_metrics_shapes(len(sizes)),
|
1224 |
+
index_start,
|
1225 |
+
sizes,
|
1226 |
+
)
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
local_stats = jax.tree_unflatten(treedef, local_stats_flat)
|
1230 |
+
max_statistics_size = _max_statistics_size_from_params(params_flat)
|
1231 |
+
to_pad = -num_statistics % num_devices_for_pjit
|
1232 |
+
num_statistics += to_pad
|
1233 |
+
if num_statistics == 0:
|
1234 |
+
num_statistics = num_devices_for_pjit
|
1235 |
+
max_statistics_size = block_size
|
1236 |
+
statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
|
1237 |
+
global_stats = GlobalShardedParameterStats(
|
1238 |
+
[statistics_shape, jnp.float32],
|
1239 |
+
[statistics_shape, jnp.float32],
|
1240 |
+
[[num_statistics], jnp.int32],
|
1241 |
+
)
|
1242 |
+
return ShampooState(
|
1243 |
+
count=[[], jnp.float32],
|
1244 |
+
stats=ShardedShampooStats(global_stats, local_stats),
|
1245 |
+
)
|
1246 |
+
|
1247 |
+
def sharded_update_fn(grads, state, params):
    """Transform the input gradient and update all statistics in sharded mode.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple `(updates, new_shampoo_state)`: the transformed gradients, laid
      out like `params`, and the optimizer state with `count` advanced by one.
    """
    params_flat, treedef = jax.tree_flatten(params)
    grads_flat = treedef.flatten_up_to(grads)

    global_stats = state.stats.global_stats
    local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
    # Rebuild per-parameter stats from the stacked global buffers so that the
    # per-parameter helpers below can be used as-is.
    stats_flat = [
        _convert_to_parameter_stats(global_stats, local_stat)
        for local_stat in local_stats_flat
    ]
    new_stats_flat = jax.tree_multimap(
        lambda g, s, p: _compute_stats(g, s, p, state.count),
        grads_flat,
        stats_flat,
        params_flat,
    )

    outputs = jax.tree_multimap(
        lambda g, s, p: _transform_grad(g, s, p, state.count),
        grads_flat,
        new_stats_flat,
        params_flat,
    )
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_unflatten(treedef, updates_flat)
    # Create new local_stats
    new_local_stats_flat = [
        _convert_from_parameter_stats(new_stat, local_stat)
        for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
    ]

    max_size = global_stats.statistics.shape[1]
    new_padded_statistics = []
    for param_stats in new_stats_flat:
        new_padded_statistics.extend(
            [pad_square_matrix(stat, max_size) for stat in param_stats.statistics]
        )

    # Create global stats
    # TODO(rohananil): Preconditioner is not updated every step, so cost of
    # stack/pad can be obviated away.
    # Pad the statistics and preconditioner matrices to be a multiple of
    # num devices.
    # TODO(rohananil): Relax to only the size of the mesh axis where the dim
    # is split on.
    if new_padded_statistics:
        to_pad = -len(new_padded_statistics) % num_devices_for_pjit
        pad_dtype = new_padded_statistics[0].dtype
    else:
        # No parameter carries statistics. The sharded init path still
        # allocates `num_devices_for_pjit` float32 slots for this case, so
        # mirror that here instead of handing an empty list to jnp.stack.
        to_pad = num_devices_for_pjit
        pad_dtype = jnp.float32
    new_padded_statistics.extend(
        [jnp.eye(max_size, dtype=pad_dtype) for _ in range(to_pad)]
    )
    new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
    new_stacked_padded_statistics = pjit.with_sharding_constraint(
        new_stacked_padded_statistics, statistics_partition_spec
    )

    def _internal_inverse_pth_root_all():
        preconditioners, errors = _matrix_inverse_pth_root_pjit(
            new_stacked_padded_statistics,
            global_stats.exponents,
            statistics_partition_spec,
        )
        return preconditioners, errors

    if preconditioning_compute_steps == 1:
        new_preconditioners, errors = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = new_stacked_padded_statistics
        n = new_stacked_padded_statistics.shape[0]
        errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
        init_state = [preconditioners_init, errors_init]
        perform_step = state.count % preconditioning_compute_steps == 0
        new_preconditioners, errors = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state
        )

    new_local_stats_flat = _add_error_into_local_stats(
        new_local_stats_flat, errors, inverse_failure_threshold
    )
    new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
    # Reshape so the per-matrix flag broadcasts over both matrix dimensions.
    errors = errors.reshape((-1, 1, 1))
    # 1.0 where the inverse root failed (NaN or error at/above the threshold),
    # 0.0 otherwise; used to blend the old and new preconditioners.
    predicate = jnp.logical_or(
        jnp.isnan(errors), errors >= inverse_failure_threshold
    ).astype(new_preconditioners.dtype)
    # TODO(rohananil): Check for numerical instabilities.
    new_conditional_preconditioners = (
        predicate * global_stats.preconditioners
        + (1.0 - predicate) * new_preconditioners
    )
    new_global_stats = GlobalShardedParameterStats(
        new_stacked_padded_statistics,
        new_conditional_preconditioners,
        global_stats.exponents,
    )
    new_shampoo_state = ShampooState(
        count=state.count + 1,
        stats=ShardedShampooStats(new_global_stats, new_local_stats),
    )
    return updates, new_shampoo_state
|
1361 |
+
|
1362 |
+
def init_fn(params):
    """Build the initial optimizer state for `params`.

    Args:
      params: pytree of parameters; one `ParameterStats` is created per leaf.

    Returns:
      A `ShampooState` whose `count` is a scalar int32 zero and whose `stats`
      mirrors the structure of `params`.
    """

    def _init(param):
        # Constructed unconditionally, exactly as before, even when the
        # parameter ends up being skipped.
        preconditioner = Preconditioner(
            param, block_size, best_effort_shape_interpretation
        )

        if _skip_preconditioning(param):
            statistics, preconditioners = [], []
        else:
            dims = [s[0] for s in preconditioner.shapes_for_preconditioners()]
            # Statistics start as epsilon * I, preconditioners as I.
            statistics = [
                matrix_epsilon * jnp.eye(d, dtype=jnp.float32) for d in dims
            ]
            preconditioners = [jnp.eye(d, dtype=jnp.float32) for d in dims]

        diagonal_statistics = (
            jnp.zeros_like(param) if _graft_type_has_diagonal_statistics() else []
        )

        momentum = _quantize_momentum(jnp.zeros_like(param))
        # The grafting momentum buffer is only materialised when the graft
        # type needs it; otherwise an empty placeholder is quantized.
        diagonal_momentum = _quantize_momentum(
            jnp.zeros_like(param)
            if _graft_type_has_diagonal_momentum_states()
            else []
        )

        return ParameterStats(
            _quantize_diagonal_statistics(diagonal_statistics),
            _maybe_quantize_statistics(statistics),
            _maybe_quantize_preconditioners(preconditioners),
            diagonal_momentum,
            momentum,
            init_training_metrics(len(statistics)),
        )

    initial_stats = jax.tree_map(_init, params)
    return ShampooState(count=jnp.zeros([], jnp.int32), stats=initial_stats)
|
1399 |
+
|
1400 |
+
def _skip_preconditioning(param):
    """Return True when `param` should not receive a Shampoo preconditioner.

    A parameter is skipped when it has no dimensions, or when any dimension
    exceeds `skip_preconditioning_dim_size_gt`.
    """
    shape = param.shape
    if len(shape) < 1:
        return True
    return any(dim > skip_preconditioning_dim_size_gt for dim in shape)
|
1404 |
+
|
1405 |
+
def _compute_stats(grad, state, param, step):
    """Fold the current gradient into one parameter's statistics.

    Args:
      grad: gradient for `param`.
      state: the `ParameterStats` currently held for `param`.
      param: the parameter itself (used for its shape and the skip check).
      step: current step count; statistics are only refreshed on steps
        divisible by `statistics_compute_steps` when that value exceeds 1.

    Returns:
      A `ParameterStats` identical to `state` except for `statistics`.
    """
    preconditioner = Preconditioner(
        param, block_size, best_effort_shape_interpretation
    )
    decay = beta2
    # With beta2 == 1.0 the new term is added unscaled; otherwise it is
    # weighted by (1 - beta2).
    gain = beta2 if beta2 == 1.0 else (1.0 - beta2)

    def _accumulate():
        fresh = preconditioner.statistics_from_grad(grad)
        blended = [
            decay * _to_float(accumulated) + gain * current
            for current, accumulated in zip(fresh, state.statistics)
        ]
        return _maybe_quantize_statistics(blended)

    if _skip_preconditioning(param):
        updated_statistics = [[]] * len(state.statistics)
    elif statistics_compute_steps > 1:
        should_update = step % statistics_compute_steps == 0
        updated_statistics = list(
            efficient_cond(should_update, _accumulate, state.statistics)
        )
    else:
        updated_statistics = _accumulate()

    return ParameterStats(
        state.diagonal_statistics,
        updated_statistics,
        state.preconditioners,
        state.diagonal_momentum,
        state.momentum,
        state.training_metrics,
    )
|
1440 |
+
|
1441 |
+
def _matrix_inverse_pth_root_vmap(xs, ps):
    """Apply `matrix_inverse_pth_root` to each (matrix, exponent) pair.

    `xs` and `ps` are mapped together over their leading axis by `jax.vmap`;
    `matrix_epsilon` and `precision` are passed to every call.
    """

    def _single_root(x, p):
        return matrix_inverse_pth_root(
            x, p, ridge_epsilon=matrix_epsilon, precision=precision
        )

    return jax.vmap(_single_root)(xs, ps)
|
1446 |
+
|
1447 |
+
def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
    """Vmapped inverse p-th root over quantized statistics.

    Args:
      qxs: stacked quantized matrices.
      qds: stacked diagonals that go with `qxs`.
      qbs: stacked bucket sizes that go with `qxs`.
      ps: stacked exponents.

    Returns:
      A tuple of stacked (quantized preconditioners, diagonals, bucket sizes,
      errors); the preconditioners are re-quantized to the dtype of the input.
    """

    def _root_of_quantized(qx, qd, qb, p):
        # Dequantize, take the inverse root in float, then quantize the
        # result back to the storage dtype of the incoming matrix.
        dense = QuantizedValue(
            qx, qd, qb, qx.dtype, True, list(qx.shape)
        ).to_float()
        root, err = matrix_inverse_pth_root(
            dense, p, ridge_epsilon=matrix_epsilon, precision=precision
        )
        packed = QuantizedValue.from_float_value(root, qx.dtype, True)
        return packed.quantized, packed.diagonal, packed.bucket_size, err

    return jax.vmap(_root_of_quantized)(qxs, qds, qbs, ps)
|
1461 |
+
|
1462 |
+
def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
    """Compute inverse p-th roots under pjit sharding constraints.

    Args:
      xs: stacked statistics matrices.
      ps: stacked exponents, one per matrix in `xs`.
      statistics_partition_spec: sharding constraint applied to the returned
        preconditioners.

    Returns:
      A tuple `(preconditioners, errors)`; `errors` is constrained with an
      empty `PartitionSpec`.
    """
    constrain = pjit.with_sharding_constraint
    compute_spec = preconditioner_partition_spec

    # Partition the concatenated statistics matrix across all cores.
    sharded_xs = constrain(xs, compute_spec)
    sharded_ps = constrain(ps, pjit.PartitionSpec(compute_spec[0]))

    # Run matrix inverse pth root on each shard.
    sharded_roots, sharded_errors = _matrix_inverse_pth_root_vmap(
        sharded_xs, sharded_ps
    )

    # Reshard output to have the same PSpec as input. This is required to avoid
    # vmap seeing the full set of statistics.
    sharded_roots = constrain(sharded_roots, compute_spec)

    # Recombine the outputs at each core.
    preconditioners = constrain(sharded_roots, statistics_partition_spec)
    errors = constrain(sharded_errors, pjit.PartitionSpec())
    return preconditioners, errors
|
1484 |
+
|
1485 |
+
def _pmap_compute_preconditioners(
    states,
    step,
    statistics,
    num_statistics_per_state,
    original_shapes,
    exponents,
    max_size,
    prev_preconditioners,
):
    """Recompute float preconditioners across the devices of the pmap axis.

    Args:
      states: A list of optimizer states.
      step: Current step number.
      statistics: A list of statistics for all variables (for every dim).
      num_statistics_per_state: Number of statistics per state, used to
        regroup the flat results into per-state lists.
      original_shapes: A list of shapes of the statistics.
      exponents: Exponent power to use for inverse-pth roots. Padding entries
        are appended to this list in place.
      max_size: Maximum dim of the statistics to pad.
      prev_preconditioners: Previously available preconditioners, kept for any
        matrix whose new inverse root is rejected.

    Returns:
      New optimizer states after computing the preconditioner; `states` itself
      when there are no statistics.
    """
    device_count = lax.psum(1, batch_axis_name)
    num_statistics = len(statistics)

    # Bring every statistic to (max_size, max_size), then top up with identity
    # matrices (exponent 1) until the count is a multiple of the device count.
    padding_count = -num_statistics % device_count
    packed_statistics = [pad_square_matrix(s, max_size) for s in statistics]
    packed_statistics += [
        jnp.eye(max_size, dtype=packed_statistics[0].dtype)
        for _ in range(padding_count)
    ]
    exponents.extend([1] * padding_count)

    if not packed_statistics:
        return states

    statistics_per_device = batch(packed_statistics, device_count)
    exponents_per_device = batch(exponents, device_count)

    def _inverse_roots_on_all_devices():
        # Each replica handles its own slice; results are then all-gathered
        # and flattened back into per-matrix lists.
        replica = lax.axis_index(batch_axis_name)
        local_roots, local_errors = _matrix_inverse_pth_root_vmap(
            statistics_per_device[replica], exponents_per_device[replica]
        )
        gathered_roots = jax.lax.all_gather(local_roots, batch_axis_name)
        gathered_errors = jax.lax.all_gather(local_errors, batch_axis_name)
        return unbatch(gathered_roots), unbatch(gathered_errors)

    if preconditioning_compute_steps != 1:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        fallback = [
            packed_statistics,
            [inverse_failure_threshold] * len(packed_statistics),
        ]
        should_recompute = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            should_recompute, _inverse_roots_on_all_devices, fallback
        )
    else:
        preconditioners_flat, errors_flat = _inverse_roots_on_all_devices()

    def _rejected(error):
        failed = jnp.logical_or(
            jnp.isnan(error), error >= inverse_failure_threshold
        )
        return failed.astype(error.dtype)

    def _choose(error, candidate, previous):
        return lax.cond(
            _rejected(error), lambda _: previous, lambda _: candidate, operand=None
        )

    # zip stops at the shortest input, which drops the identity padding.
    chosen = [
        (_choose(err, root[: shape[0], : shape[1]], prev), err)
        for root, shape, prev, err in zip(
            preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
        )
    ]
    new_preconditioners_flat = [picked for picked, _ in chosen]
    new_errors_flat = [err for _, err in chosen]

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics
    assert len(new_errors_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    errors_for_states = []
    offset = 0
    for count, state in zip(num_statistics_per_state, states):
        if count == 0:
            preconditioners_for_states.append([])
            errors_for_states.append([])
            continue

        window = slice(offset, offset + count)
        state_preconditioners = new_preconditioners_flat[window]
        assert len(state.statistics) == len(state_preconditioners)
        preconditioners_for_states.append(state_preconditioners)

        state_errors = jnp.stack(new_errors_flat[window])
        assert len(state.statistics) == len(state_errors)
        errors_for_states.append(state_errors)

        offset += count

    new_states = []
    for state, state_preconditioners, state_errors in zip(
        states, preconditioners_for_states, errors_for_states
    ):
        if state.statistics:
            # Record only errors that are positive and not the failure
            # placeholder; otherwise keep the previously stored metric.
            informative = jnp.logical_and(
                state_errors > 0.0, state_errors != inverse_failure_threshold
            )
            state_errors = jnp.where(
                informative,
                state_errors,
                state.training_metrics.inverse_pth_root_errors,
            )
        new_states.append(
            ParameterStats(
                state.diagonal_statistics,
                state.statistics,
                state_preconditioners,
                state.diagonal_momentum,
                state.momentum,
                TrainingMetrics(state_errors),
            )
        )

    return new_states
|
1624 |
+
|
1625 |
+
def _pmap_quantized_compute_preconditioners(
    states,
    step,
    statistics,
    num_statistics_per_state,
    original_shapes,
    exponents,
    max_size,
    prev_preconditioners,
):
    """Computes preconditioners for given statistics in states in PMAP mode.

    For quantization, each statistic is represented by three values:
    quantized matrix, diagonal, and bucket sizes. The three are padded,
    batched and gathered as parallel lists; dequantization happens inside
    `_quantized_matrix_inverse_pth_root_vmap`, one matrix at a time.

    Args:
      states: A list of optimizer states.
      step: Current step number
      statistics: A list of quantized statistics for all variables (for every
        dim); each entry exposes `quantized`, `diagonal` and `bucket_size`.
      num_statistics_per_state: Number of statistics per state to reconstruct
        output states.
      original_shapes: A list of shapes of the statistics.
      exponents: Exponent power to use for inverse-pth roots. Padding entries
        are appended to this list in place.
      max_size: Maximum dim of the statistics to pad.
      prev_preconditioners: Previously available preconditioner.

    Returns:
      New optimizer states after computing the preconditioner; `states` itself
      when there are no statistics.
    """
    # Nothing to precondition. Return before building the quantized identity
    # padding, which would otherwise be constructed from a `max_size` that may
    # still be 0.
    if not statistics:
        return states

    num_devices = lax.psum(1, batch_axis_name)
    num_statistics = len(statistics)
    quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
    # Complexity here is around: shapes needing be statically shaped,
    # our custom quantization type requires a different type of packing.

    # Parallel tensors:
    # quantized [dxd]
    # diagonals [d] f32
    # bucket_sizes [d] f32
    packed_quantized_statistics = [
        pad_square_matrix(stat.quantized, max_size) for stat in statistics
    ]
    packed_quantized_diagonals = [
        pad_vector(stat.diagonal, max_size) for stat in statistics
    ]
    packed_quantized_bucket_sizes = [
        pad_vector(stat.bucket_size, max_size) for stat in statistics
    ]

    # Pad with quantized identities (exponent 1) so that the number of
    # matrices is a multiple of the device count.
    to_pad = -num_statistics % num_devices
    padded_eye = jnp.eye(max_size, dtype=jnp.float32)
    quantized_eye = QuantizedValue.from_float_value(
        padded_eye, quantized_dtype, True
    )
    packed_quantized_statistics.extend(
        [quantized_eye.quantized for _ in range(to_pad)]
    )
    packed_quantized_diagonals.extend(
        [quantized_eye.diagonal for _ in range(to_pad)]
    )
    packed_quantized_bucket_sizes.extend(
        [quantized_eye.bucket_size for _ in range(to_pad)]
    )
    exponents.extend([1 for _ in range(to_pad)])

    all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
    all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
    all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices)
    all_exponents = batch(exponents, num_devices)

    def _internal_inverse_pth_root_all():
        # Each replica computes its own slice, then all slices are gathered
        # and flattened back into per-matrix lists.
        current_replica = lax.axis_index(batch_axis_name)
        (
            quantized_preconditioners,
            quantized_diagonals,
            quantized_bucket_sizes,
            errors,
        ) = _quantized_matrix_inverse_pth_root_vmap(
            all_quantized_statistics[current_replica],
            all_quantized_diagonals[current_replica],
            all_quantized_bucket_sizes[current_replica],
            all_exponents[current_replica],
        )
        quantized_preconditioners = jax.lax.all_gather(
            quantized_preconditioners, batch_axis_name
        )
        quantized_diagonals = jax.lax.all_gather(
            quantized_diagonals, batch_axis_name
        )
        quantized_bucket_sizes = jax.lax.all_gather(
            quantized_bucket_sizes, batch_axis_name
        )
        errors = jax.lax.all_gather(errors, batch_axis_name)
        quantized_preconditioners_flat = unbatch(quantized_preconditioners)
        quantized_diagonals_flat = unbatch(quantized_diagonals)
        quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
        errors_flat = unbatch(errors)
        return (
            quantized_preconditioners_flat,
            quantized_diagonals_flat,
            quantized_bucket_sizes_flat,
            errors_flat,
        )

    if preconditioning_compute_steps == 1:
        (
            quantized_preconditioners_flat,
            quantized_diagonals_flat,
            quantized_bucket_sizes_flat,
            errors_flat,
        ) = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        quantized_preconditioners_init = packed_quantized_statistics
        quantized_diagonals_init = packed_quantized_diagonals
        quantized_bucket_sizes_init = packed_quantized_bucket_sizes
        errors_init = [inverse_failure_threshold] * len(
            quantized_preconditioners_init
        )
        init_state = [
            quantized_preconditioners_init,
            quantized_diagonals_init,
            quantized_bucket_sizes_init,
            errors_init,
        ]
        perform_step = step % preconditioning_compute_steps == 0
        (
            quantized_preconditioners_flat,
            quantized_diagonals_flat,
            quantized_bucket_sizes_flat,
            errors_flat,
        ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state)

    def _skip(error):
        condition = jnp.logical_or(
            jnp.isnan(error), error >= inverse_failure_threshold
        )
        return condition.astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
        return lax.cond(
            _skip(error), lambda _: old_p, lambda _: new_p, operand=None
        )

    new_quantized_preconditioners_flat = []
    new_quantized_diagonals_flat = []
    new_quantized_bucket_sizes_flat = []
    new_errors_flat = []
    # zip stops at the shortest input, which drops the identity padding.
    for p, d, b, shape, prev_p, error in zip(
        quantized_preconditioners_flat,
        quantized_diagonals_flat,
        quantized_bucket_sizes_flat,
        original_shapes,
        prev_preconditioners,
        errors_flat,
    ):
        new_quantized_preconditioners_flat.append(
            _select_preconditioner(
                error, p[: shape[0], : shape[1]], prev_p.quantized
            )
        )
        new_quantized_diagonals_flat.append(
            _select_preconditioner(error, d[: shape[0]], prev_p.diagonal)
        )
        new_quantized_bucket_sizes_flat.append(
            _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
        )
        new_errors_flat.append(error)

    assert len(states) == len(num_statistics_per_state)
    assert len(new_quantized_preconditioners_flat) == num_statistics
    assert len(new_quantized_diagonals_flat) == num_statistics
    assert len(new_quantized_bucket_sizes_flat) == num_statistics
    assert len(new_errors_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    errors_for_states = []
    idx = 0
    for state_count, state in zip(num_statistics_per_state, states):
        if state_count == 0:
            preconditioners_for_states.append([])
            errors_for_states.append([])
        else:
            quantized_preconditioners_for_state = (
                new_quantized_preconditioners_flat[idx : idx + state_count]
            )
            quantized_diagonals_for_state = new_quantized_diagonals_flat[
                idx : idx + state_count
            ]
            quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
                idx : idx + state_count
            ]
            errors_for_state = jnp.stack(
                new_errors_flat[idx : idx + state_count]
            )

            assert len(state.statistics) == len(quantized_preconditioners_for_state)
            assert len(state.statistics) == len(quantized_diagonals_for_state)
            assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
            assert len(state.statistics) == len(errors_for_state)

            # Reassemble the three parallel pieces into QuantizedValue objects.
            quantized_preconditioners = []
            for qv, qd, qb in zip(
                quantized_preconditioners_for_state,
                quantized_diagonals_for_state,
                quantized_bucket_sizes_for_state,
            ):
                quantized_preconditioners.append(
                    QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
                )
            preconditioners_for_states.append(quantized_preconditioners)
            errors_for_states.append(errors_for_state)
            idx += state_count

    new_states = []
    for state, new_preconditioners, new_errors in zip(
        states, preconditioners_for_states, errors_for_states
    ):
        if state.statistics:
            # Record only errors that are positive and not the failure
            # placeholder; otherwise keep the previously stored metric.
            new_errors = jnp.where(
                jnp.logical_and(
                    new_errors > 0.0, new_errors != inverse_failure_threshold
                ),
                new_errors,
                state.training_metrics.inverse_pth_root_errors,
            )
        new_training_metrics = TrainingMetrics(new_errors)
        new_states.append(
            ParameterStats(
                state.diagonal_statistics,
                state.statistics,
                new_preconditioners,
                state.diagonal_momentum,
                state.momentum,
                new_training_metrics,
            )
        )

    return new_states
|
1869 |
+
|
1870 |
+
def _pjit_compute_preconditioners(
    states,
    step,
    statistics,
    num_statistics_per_state,
    original_shapes,
    exponents,
    max_size,
    prev_preconditioners,
):
    """Computes preconditioners for given statistics in states in PJIT mode.

    Args:
      states: A list of optimizer states.
      step: Current step number
      statistics: A list of statistics for all variables (for every dim)
      num_statistics_per_state: Number of statistics per state to reconstruct
        output states.
      original_shapes: A list of shapes of the statistics.
      exponents: Exponent power to use for inverse-pth roots. Padding entries
        are appended to this list in place.
      max_size: Maximum dim of the statistics to pad.
      prev_preconditioners: Previously available preconditioner.

    Returns:
      New optimizer states after computing the preconditioner; `states` itself
      when there are no statistics.
    """
    # Nothing to precondition. Match the pmap variants and return early, since
    # jnp.stack below cannot take an empty list.
    if not statistics:
        return states

    num_statistics = len(statistics)
    # Pad with identity matrices (exponent 1) so that the number of matrices
    # is a multiple of the number of pjit devices.
    to_pad = -num_statistics % num_devices_for_pjit
    padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
    padded_statistics.extend(
        [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
    )
    # NOTE: this extends the caller's list in place.
    exponents.extend([1 for _ in range(to_pad)])
    all_statistics = jnp.stack(padded_statistics)
    all_exponents = jnp.stack(exponents)

    def _internal_inverse_pth_root_all():
        preconditioners, errors = _matrix_inverse_pth_root_pjit(
            all_statistics, all_exponents
        )
        b1 = preconditioners.shape[0]

        def split(batched_values):
            # Drop only the leading batch axis. A bare squeeze would also
            # collapse 1x1 matrices to scalars and break the 2-D slicing
            # applied to each preconditioner below.
            return [
                jnp.squeeze(v, axis=0)
                for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
            ]

        return split(preconditioners), split(errors)

    if preconditioning_compute_steps == 1:
        preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
    else:
        # Passing statistics instead of preconditioners as they are similarly
        # shaped tensors. Note statistics will be ignored as we are passing in
        # a large init value for error.
        preconditioners_init = padded_statistics
        errors_init = [inverse_failure_threshold] * len(padded_statistics)
        init_state = [preconditioners_init, errors_init]
        perform_step = step % preconditioning_compute_steps == 0
        preconditioners_flat, errors_flat = efficient_cond(
            perform_step, _internal_inverse_pth_root_all, init_state
        )

    def _skip(error):
        condition = jnp.logical_or(
            jnp.isnan(error), error >= inverse_failure_threshold
        )
        return condition.astype(error.dtype)

    def _select_preconditioner(error, new_p, old_p):
        return lax.cond(
            _skip(error), lambda _: old_p, lambda _: new_p, operand=None
        )

    new_preconditioners_flat = []
    new_errors_flat = []
    # zip stops at the shortest input, which drops the identity padding.
    for p, shape, prev_p, error in zip(
        preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
    ):
        new_preconditioners_flat.append(
            _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
        )
        new_errors_flat.append(error)

    assert len(states) == len(num_statistics_per_state)
    assert len(new_preconditioners_flat) == num_statistics

    # Add back empty preconditioners so that we can set the optimizer state.
    preconditioners_for_states = []
    errors_for_states = []
    idx = 0
    for state_count, state in zip(num_statistics_per_state, states):
        if state_count == 0:
            preconditioners_for_states.append([])
            errors_for_states.append([])
        else:
            preconditioners_for_state = new_preconditioners_flat[
                idx : idx + state_count
            ]
            assert len(state.statistics) == len(preconditioners_for_state)
            preconditioners_for_states.append(preconditioners_for_state)

            errors_for_state = jnp.stack(
                new_errors_flat[idx : idx + state_count]
            )
            assert len(state.statistics) == len(errors_for_state)
            errors_for_states.append(errors_for_state)
            idx += state_count

    new_states = []
    for state, new_preconditioners, new_errors in zip(
        states, preconditioners_for_states, errors_for_states
    ):
        if state.statistics:
            # Record only errors that are positive and not the failure
            # placeholder; otherwise keep the previously stored metric.
            new_errors = jnp.where(
                jnp.logical_and(
                    new_errors > 0.0, new_errors != inverse_failure_threshold
                ),
                new_errors,
                state.training_metrics.inverse_pth_root_errors,
            )
        new_training_metrics = TrainingMetrics(new_errors)
        new_states.append(
            ParameterStats(
                state.diagonal_statistics,
                state.statistics,
                new_preconditioners,
                state.diagonal_momentum,
                state.momentum,
                new_training_metrics,
            )
        )

    return new_states
|
2005 |
+
|
2006 |
+
def _compute_preconditioners(states, params, step):
    """Collects all statistics from `states` and recomputes their preconditioners.

    The per-state statistics are flattened into single lists (together with
    their shapes, inverse-root exponents and previous preconditioners) and
    handed to one of the batched backends.

    Args:
      states: A list of optimizer states.
      params: A list of params.
      step: Current step number

    Returns:
      New optimizer states after computing the preconditioner.
    """
    statistics = []
    prev_preconditioners = []
    original_shapes = []
    exponents = []
    num_statistics_per_state = []
    max_size = 0

    for state, param in zip(states, params):
        stat_count = len(state.statistics)
        num_statistics_per_state.append(stat_count)
        if not stat_count:
            # Nothing to precondition for this parameter.
            continue

        preconditioner = Preconditioner(
            param, block_size, best_effort_shape_interpretation
        )
        shapes_for_state = []
        for statistic in state.statistics:
            # An exponent_override of 0 means "use the preconditioner's own
            # exponent"; any other value replaces it.
            if exponent_override == 0:
                exponent = preconditioner.exponent_for_preconditioner()
            else:
                exponent = exponent_override
            exponents.append(exponent)
            shapes_for_state.append(statistic.shape)
            max_size = max(max_size, statistic.shape[0])

        statistics.extend(state.statistics)
        prev_preconditioners.extend(state.preconditioners)
        original_shapes.extend(shapes_for_state)

    # All three backends take exactly the same positional arguments.
    backend_args = (
        states,
        step,
        statistics,
        num_statistics_per_state,
        original_shapes,
        exponents,
        max_size,
        prev_preconditioners,
    )

    if not batch_axis_name:
        return _pjit_compute_preconditioners(*backend_args)

    # The pmap backends (and therefore the quantized variant) are only
    # reachable when batch_axis_name is set.
    quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
    if quantized_dtype == jnp.float32:
        return _pmap_compute_preconditioners(*backend_args)
    return _pmap_quantized_compute_preconditioners(*backend_args)
|
2083 |
+
|
2084 |
+
def _transform_grad(grad, state, param, step):
    """Builds the update for one parameter and the state that goes with it.

    The grafting update (selected by `graft_type`) supplies the step
    magnitude: the preconditioned gradient is rescaled to the grafting
    update's norm. Until `step` reaches `start_preconditioning_step` the
    grafting update (with weight decay and momentum) is used on its own.

    Args:
      grad: Gradient for this parameter.
      state: Optimizer state for this parameter.
      param: The parameter; used to build the Preconditioner and for weight
        decay.
      step: Current step number.

    Returns:
      A pair `(transformed_update, param_stats)`.
    """
    preconditioner = Preconditioner(
        param, block_size, best_effort_shape_interpretation
    )

    def _unit_norm(g):
        # Small constant guards against division by a zero norm.
        return g / (jnp.linalg.norm(g) + 1e-16)

    # Default: diagonal statistics are carried over unchanged (SGD / sign grafting).
    diag_stats = state.diagonal_statistics.to_float()

    if graft_type in (GraftingType.ADAGRAD, GraftingType.ADAGRAD_NORMALIZED):
        if graft_type == GraftingType.ADAGRAD_NORMALIZED:
            graft_input = _unit_norm(grad)
        else:
            graft_input = grad

        diag_stats = state.diagonal_statistics.to_float() + jnp.square(graft_input)
        graft_step = graft_input / (jnp.sqrt(diag_stats) + diagonal_epsilon)
    elif graft_type in (GraftingType.RMSPROP, GraftingType.RMSPROP_NORMALIZED):
        if graft_type == GraftingType.RMSPROP_NORMALIZED:
            graft_input = _unit_norm(grad)
        else:
            graft_input = grad

        old_weight = beta2
        new_weight = beta2 if beta2 == 1.0 else (1.0 - beta2)
        diag_stats = (
            old_weight * state.diagonal_statistics.to_float()
            + new_weight * jnp.square(graft_input)
        )
        graft_step = graft_input / (jnp.sqrt(diag_stats) + diagonal_epsilon)

        if clip_by_scaled_gradient_norm:
            # Norm divided by sqrt(number of elements) of the RMSProp step.
            rms = jnp.linalg.norm(graft_step) / (
                jnp.sqrt(float(graft_step.size))
            )
            graft_step = graft_step / jnp.maximum(
                1.0, rms / clip_by_scaled_gradient_norm
            )
    elif graft_type == GraftingType.SGD:
        graft_step = grad
    else:
        graft_step = jnp.ones_like(grad) * jnp.sign(grad)

    if _skip_preconditioning(param):
        direction = graft_step
    else:
        direction = preconditioner.preconditioned_grad(
            grad, _maybe_dequantize_preconditioners(state.preconditioners)
        )

    # Graft: keep the preconditioned direction, take the grafting norm.
    graft_norm = jnp.linalg.norm(graft_step)
    direction_norm = jnp.linalg.norm(direction)
    shampoo_step = direction * (graft_norm / (direction_norm + 1e-16))

    if weight_decay != 0:
        shampoo_step_wd = shampoo_step + weight_decay * param
        graft_step_wd = graft_step + weight_decay * param
    else:
        shampoo_step_wd = shampoo_step
        graft_step_wd = graft_step

    w = (1.0 - beta1) if moving_average_for_momentum else 1.0

    shampoo_momentum = state.momentum.to_float() * beta1 + w * shampoo_step_wd

    has_diagonal_momentum = _graft_type_has_diagonal_momentum_states()
    if has_diagonal_momentum:
        graft_momentum = (
            state.diagonal_momentum.to_float() * beta1 + w * graft_step_wd
        )
    else:
        # Share the momentum buffer
        graft_momentum = state.momentum.to_float() * beta1 + w * graft_step_wd

    # 1.0 once preconditioning has started, 0.0 before.
    run_shampoo = (step >= start_preconditioning_step).astype(graft_momentum.dtype)
    use_graft = 1.0 - run_shampoo

    momentum_update = run_shampoo * shampoo_momentum + use_graft * graft_momentum
    wd_update = run_shampoo * shampoo_step_wd + use_graft * graft_step_wd

    if nesterov:
        final_direction = w * wd_update + beta1 * momentum_update
    else:
        final_direction = momentum_update

    lr = learning_rate(step) if callable(learning_rate) else learning_rate
    transformed_update = -1.0 * lr * final_direction

    if has_diagonal_momentum:
        new_diagonal_momentum = graft_momentum
        new_momentum = shampoo_momentum
    else:
        # Single shared buffer: store the blended momentum, no diagonal one.
        new_diagonal_momentum = []
        new_momentum = momentum_update

    param_stats = ParameterStats(
        _quantize_diagonal_statistics(diag_stats),
        state.statistics,
        state.preconditioners,
        _quantize_momentum(new_diagonal_momentum),
        _quantize_momentum(new_momentum),
        state.training_metrics,
    )

    return transformed_update, param_stats
|
2216 |
+
|
2217 |
+
def update_fn(grads, state, params):
    """Transform the input gradient and update all statistics.

    Uses the `jax.tree_util` functions rather than the top-level
    `jax.tree_multimap` / `jax.tree_flatten` / `jax.tree_unflatten` aliases,
    which newer JAX releases no longer provide.

    Args:
      grads: the gradient tensors for the parameters.
      state: a named tuple containing the state of the optimizer
      params: the parameters that should be updated.

    Returns:
      A tuple `(updates, new_state)`: the transformed updates, with the same
      tree structure as `params`, and the new optimizer state with its step
      count incremented by one.
    """
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    # Flatten state and grads only down to the structure of `params`, so each
    # entry lines up with exactly one parameter.
    stats_flat = treedef.flatten_up_to(state.stats)
    grads_flat = treedef.flatten_up_to(grads)

    new_stats_flat = jax.tree_util.tree_map(
        lambda g, s, p: _compute_stats(g, s, p, state.count),
        grads_flat,
        stats_flat,
        params_flat,
    )
    new_stats_flat = _compute_preconditioners(
        new_stats_flat, params_flat, state.count
    )
    outputs = jax.tree_util.tree_map(
        lambda g, s, p: _transform_grad(g, s, p, state.count),
        grads_flat,
        new_stats_flat,
        params_flat,
    )
    # Each output is an (update, stats) pair; split them into two sequences.
    updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())

    updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
    new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)

    new_state = ShampooState(count=state.count + 1, stats=new_stats)
    return updates, new_state
|
2254 |
+
|
2255 |
+
if shard_optimizer_states:
|
2256 |
+
# Hijacks the init_fn signature so we can return an OptState with
|
2257 |
+
# appropriate init_fns.
|
2258 |
+
def _init_fns(unused_params):
|
2259 |
+
return InitFnState(
|
2260 |
+
init_fn=sharded_init_fn,
|
2261 |
+
pspec_fn=sharded_init_partition_spec_fn,
|
2262 |
+
shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
|
2263 |
+
)
|
2264 |
+
|
2265 |
+
return optax.GradientTransformation(_init_fns, sharded_update_fn)
|
2266 |
+
else:
|
2267 |
+
return optax.GradientTransformation(init_fn, update_fn)
|
tools/train/scalable_shampoo/quantization_utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Helper routines for quantization."""
|
17 |
+
|
18 |
+
from typing import Any
|
19 |
+
|
20 |
+
import chex
|
21 |
+
import jax.numpy as jnp
|
22 |
+
from flax import struct
|
23 |
+
|
24 |
+
|
25 |
+
# pylint:disable=no-value-for-parameter
|
26 |
+
@struct.dataclass
class QuantizedValue:
    """A tensor kept in a reduced-precision form, with what is needed to restore it."""

    quantized: chex.Array
    diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
    bucket_size: chex.Array
    quantized_dtype: jnp.dtype = struct.field(
        pytree_node=False
    )  # Dtype for the quantized value.
    extract_diagonal: bool = struct.field(pytree_node=False)  # In case it's centered.
    shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Builds a QuantizedValue from a float tensor (or from an empty list)."""
        # An empty list stands for "no value"; keep it as an empty record.
        if isinstance(fvalue, list) and not fvalue:
            return QuantizedValue(
                quantized=[],
                diagonal=[],
                bucket_size=[],
                quantized_dtype=quantized_dtype,
                extract_diagonal=extract_diagonal,
                shape=[],
            )

        pieces = QuantizedValue.quantize(fvalue, quantized_dtype, extract_diagonal)
        quantized, diagonal_fvalue, bucket_size = pieces
        return QuantizedValue(
            quantized=quantized,
            diagonal=diagonal_fvalue,
            bucket_size=bucket_size,
            quantized_dtype=quantized_dtype,
            extract_diagonal=extract_diagonal,
            shape=list(quantized.shape),
        )

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns quantized value and the bucket."""
        # float32 and bfloat16 are not bucketed: pass through or plain cast.
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        if quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        # The most negative integer (-128 / -32768) is not used.
        if quantized_dtype == jnp.int8:
            top_level = 127.0
        elif quantized_dtype == jnp.int16:
            top_level = 32767.0
        else:
            raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
        # max value is mapped to num_buckets
        num_buckets = jnp.array(top_level, dtype=float_dtype)

        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f"Input array {fvalue} must be 2D to work with extract_diagonal."
            )

        if extract_diagonal:
            # Keep the diagonal in float and quantize only the remainder.
            diagonal_fvalue = jnp.diag(fvalue)
            fvalue = fvalue - jnp.diag(diagonal_fvalue)
        else:
            diagonal_fvalue = []

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f"Input array {fvalue} must have a strictly positive number of "
                "dimensions."
            )

        # One bucket size per slice along axis 0's complement (max over axis 0).
        bucket_size = jnp.max(jnp.abs(fvalue), axis=0) / num_buckets
        expanded = bucket_size[None, ...]
        # To avoid divide by 0.0
        safe_divisor = jnp.where(expanded > 0.0, expanded, jnp.ones_like(expanded))
        # We use rounding to remove bias.
        levels = jnp.round(fvalue / safe_divisor)
        return levels.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value."""
        is_empty = isinstance(self.quantized, list) and not self.quantized
        if is_empty or self.quantized_dtype == jnp.float32:
            # Nothing was quantized; hand back the stored value as is.
            return self.quantized

        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        scale = self.bucket_size[None, ...]
        restored = self.quantized.astype(float_dtype) * scale
        if self.extract_diagonal:
            # Put back the diagonal that was removed before quantizing.
            restored = restored + jnp.diag(self.diagonal)
        return restored
|
tools/train/scalable_shampoo/sm3.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# An implementation of SM3 from:
|
17 |
+
#
|
18 |
+
# Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
|
19 |
+
# Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
|
20 |
+
#
|
21 |
+
# Author: Rohan Anil (rohananil at google dot com)
|
22 |
+
#
|
23 |
+
|
24 |
+
"""SM3 Implementation."""
|
25 |
+
|
26 |
+
import functools
|
27 |
+
from typing import Any, NamedTuple
|
28 |
+
|
29 |
+
import chex
|
30 |
+
import jax
|
31 |
+
import jax.numpy as jnp
|
32 |
+
import optax
|
33 |
+
|
34 |
+
from .quantization_utils import QuantizedValue
|
35 |
+
|
36 |
+
|
37 |
+
class SM3State(NamedTuple):
    """Optimizer state carried between calls of the SM3 update function."""

    # Step counter; created as an int32 scalar zero and incremented by one on
    # every update.
    count: chex.Array
    # Per-parameter statistics, one ParameterStats entry per parameter.
    stats: Any
|
40 |
+
|
41 |
+
|
42 |
+
# Per parameter optimizer state used in data-parallel training.
|
43 |
+
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""

    # Accumulator for diagonal preconditioner. In this file it is populated
    # with a list holding one 1-D accumulator per axis of the parameter.
    diagonal_statistics: chex.Array
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
|
48 |
+
|
49 |
+
|
50 |
+
def sm3(
    learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
):
    """SM3 optimizer.

    Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
      Yoram Singer

    https://arxiv.org/abs/1901.11150

    Tree traversal goes through `jax.tree_util.tree_map`, which accepts several
    trees, rather than the top-level `jax.tree_multimap` / `jax.tree_map`
    aliases that newer JAX releases no longer provide.

    Args:
      learning_rate: the step size used to update the parameters. May also be a
        callable, in which case it is called with the current step count.
      beta1: momentum parameter.
      beta2: second moment averaging parameter.
      diagonal_epsilon: epsilon for sm3
      normalize_grads: Whether to normalize grads. Author finds it useful when
        grads are high variance.

    Returns:
      a GradientTransformation.
    """

    def _quantize_momentum(momentum_statistics):
        # Momentum is stored as int8 to save memory.
        return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)

    def init_fn(params):
        """Initialise the optimiser's state."""

        def _init(param):
            # One 1-D accumulator per axis of the parameter.
            accumulators = [jnp.zeros([s]) for s in param.shape]
            momentum = _quantize_momentum(jnp.zeros_like(param))
            return ParameterStats(accumulators, momentum)

        return SM3State(
            count=jnp.zeros([], jnp.int32),
            stats=jax.tree_util.tree_map(_init, params),
        )

    def _get_expanded_shape(shape, i):
        rank = len(shape)
        # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
        # For eg: i = 1 returns [1, N, 1].
        return [1] * i + [shape[i]] + [1] * (rank - i - 1)

    def _moving_averages(grad, accumulators):
        w = (1.0 - beta2) if beta2 != 1.0 else 1.0
        if grad.ndim < 2:
            return beta2 * accumulators[0] + w * grad**2
        else:
            # Broadcasting minimum over the per-axis accumulators.
            min_accumulator = functools.reduce(jnp.minimum, accumulators)
            return beta2 * min_accumulator + w * grad**2

    def _moving_averages_momentum(grad, momentum):
        w = (1.0 - beta1) if beta1 != 1.0 else 1.0
        return beta1 * momentum.to_float() + w * grad

    def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
        # For every axis keep the max of the full statistics over all other axes.
        all_diagonal_statistics = []
        for i in range(grad.ndim):
            axes = list(range(i)) + list(range(i + 1, grad.ndim))
            dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
            all_diagonal_statistics.append(dim_diagonal_statistics)
        if grad.ndim == 1:
            all_diagonal_statistics[0] = updated_diagonal_statistics
        return all_diagonal_statistics

    def update_fn(updates, state, params=None):
        """Applies one SM3 step; `params` is accepted but not used."""
        del params
        stats = state.stats
        if normalize_grads:
            updates = jax.tree_util.tree_map(
                lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates
            )
        # Reshape all vectors into N-d tensors to compute min over them.
        # [n], [m] -> [n, 1], [1, m]
        expanded_diagonal_statistics = jax.tree_util.tree_map(
            lambda grad, param_stats: [  # pylint:disable=g-long-lambda
                jnp.reshape(
                    param_stats.diagonal_statistics[i],
                    _get_expanded_shape(grad.shape, i),
                )
                for i in range(grad.ndim)
            ],
            updates,
            stats,
        )

        # Compute new diagonal statistics
        new_diagonal_statistics = jax.tree_util.tree_map(
            _moving_averages, updates, expanded_diagonal_statistics
        )

        # Compute preconditioners (1/sqrt(s)) where s is the statistics.
        new_preconditioners = jax.tree_util.tree_map(
            lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
        )
        preconditioned_grads = jax.tree_util.tree_map(
            lambda g, p: g * p, updates, new_preconditioners
        )

        # Compute updated momentum (also handle quantization)
        updated_momentum = jax.tree_util.tree_map(
            lambda preconditioned_grad, param_stats: _moving_averages_momentum(  # pylint:disable=g-long-lambda
                preconditioned_grad, param_stats.diagonal_momentum
            ),
            preconditioned_grads,
            stats,
        )

        # Update diagonal statistics.
        updated_diagonal_statistics = jax.tree_util.tree_map(
            _sketch_diagonal_statistics, updates, new_diagonal_statistics
        )

        # Update momentum.
        new_sm3_stats = jax.tree_util.tree_map(
            lambda momentum, diagonal_stats: ParameterStats(  # pylint:disable=g-long-lambda
                diagonal_stats, _quantize_momentum(momentum)
            ),
            updated_momentum,
            updated_diagonal_statistics,
        )

        lr = learning_rate
        if callable(learning_rate):
            lr = learning_rate(state.count)

        new_updates = jax.tree_util.tree_map(lambda pg: -lr * pg, updated_momentum)
        return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)

    return optax.GradientTransformation(init_fn, update_fn)
|
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""JAX Ops for symmetric matrices used by the Shampoo optimizer."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
from typing import Any, List, Optional, Sequence, Union
|
20 |
+
|
21 |
+
import jax
|
22 |
+
import jax.numpy as jnp
|
23 |
+
import numpy as np
|
24 |
+
from flax import struct
|
25 |
+
from jax import lax
|
26 |
+
|
27 |
+
|
28 |
+
@struct.dataclass
class SlicedSymmetricMatrix:
    """A symmetric matrix represented by lower-triangular block row slices.

    For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
    by the block rows a and [b, c].

    The matrix may be batched, in which case each entry of block_rows may have
    dimension greater than 2. The last two dimensions represent the rows and cols.
    """

    # One entry per block row, ordered from the first block row downwards.
    block_rows: List[jnp.ndarray]
|
40 |
+
|
41 |
+
|
42 |
+
def product_with_transpose(
    mat1,
    mat2,
    axes,
    precision=lax.Precision.DEFAULT,
):
    """Contracts `mat1` with `mat2` over `axes`, i.e. mat1 * mat2^T for matrices.

    The rows and columns are the last two dimensions for each matrix.

    Args:
      mat1: First matrix.
      mat2: Second matrix.
      axes: The axes over which to apply the product; forwarded unchanged to
        `jnp.tensordot`.
      precision: JAX precision to use for the multiplication.

    Returns:
      The result of the `jnp.tensordot` contraction.
    """
    contracted = jnp.tensordot(mat1, mat2, axes=axes, precision=precision)
    return contracted
|
59 |
+
|
60 |
+
|
61 |
+
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the blocked slices representing a symmetric contraction.

    Specifically, the output is a contraction of the input mat with itself, in the
    specified axes.

    Args:
      mat: The matrix for which we will compute a contraction with itself.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction. Exactly one axis of `mat` must be
        left out; that axis is the row dimension that gets blocked.
      precision: The precision to use in each computation.

    Returns:
      A SlicedSymmetricMatrix whose i-th block row is the product of rows
      [i * block_size, (i + 1) * block_size) with rows [0, (i + 1) * block_size).

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    rank = len(mat.shape)

    # Normalise negative axes to their non-negative equivalents.
    contracted_axes = []
    for ax in axes:
        assert -rank <= ax < rank
        contracted_axes.append(ax % rank)
    assert len(contracted_axes) == len(axes)

    # The single axis not being contracted is the row axis.
    kept_axes = [dim for dim in range(rank) if dim not in contracted_axes]
    assert len(kept_axes) == 1
    row_ax = kept_axes[0]

    num_rows = mat.shape[row_ax]
    if num_rows % block_size != 0:
        raise ValueError(
            "The row dimension must be divisible by block_size. "
            f"Instead got row dimension={num_rows} and block_size={block_size}."
        )

    origin = [0] * rank

    def _block_row(i):
        # Rows of the i-th block ...
        block_start = list(origin)
        block_start[row_ax] = i * block_size
        block_shape = list(mat.shape)
        block_shape[row_ax] = block_size

        # ... against every row up to and including that block.
        prefix_shape = list(mat.shape)
        prefix_shape[row_ax] = (i + 1) * block_size

        return product_with_transpose(
            lax.dynamic_slice(
                mat, start_indices=block_start, slice_sizes=block_shape
            ),
            lax.dynamic_slice(mat, start_indices=origin, slice_sizes=prefix_shape),
            axes=(axes, axes),
            precision=precision,
        )

    return SlicedSymmetricMatrix(
        block_rows=[_block_row(i) for i in range(num_rows // block_size)]
    )
|
127 |
+
|
128 |
+
|
129 |
+
@functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
def sliced_transposed_product_concat(
    mat,
    block_size,
    axes=(-1,),
    precision=lax.Precision.DEFAULT,
):
    """Returns the concatenated slices representing mat*mat^T.

    Args:
      mat: The matrix for which we will compute mat*mat^T. It does not need to be
        square, and may be batched.
      block_size: The size of row blocks to compute.
      axes: Axes to use for the contraction.
      precision: The precision to use in each computation.

    Returns:
      The block rows produced by `sliced_transposed_product`, joined along the
      last axis into a single array.

    Raises:
      ValueError: Raised when the specified block size does not evenly divide
        the number of rows of the input mat.
    """
    block_rows = sliced_transposed_product(
        mat=mat, block_size=block_size, axes=axes, precision=precision
    ).block_rows
    return jnp.concatenate(block_rows, axis=-1)
|
153 |
+
|
154 |
+
|
155 |
+
@jax.jit
def materialize_matrix(symmetric_matrix):
    """Returns a materialized symmetric matrix.

    Args:
      symmetric_matrix: the matrix represented by lower-triangular block slices.

    Returns:
      The full matrix: the stored lower-triangular and diagonal blocks, with
      the upper triangle filled in by transposing the stored blocks.
    """
    block_rows = symmetric_matrix.block_rows
    block_size = block_rows[0].shape[-2]
    num_blocks = len(block_rows)

    def _block(block_row, col):
        # Column block `col` of one block row.
        return block_row[..., col * block_size : (col + 1) * block_size]

    # lower[r][c] is the stored block at block position (r, c), for c <= r.
    lower = [
        [_block(block_row, col) for col in range(row + 1)]
        for row, block_row in enumerate(block_rows)
    ]

    # Block (r, c) with c > r is the transpose of the stored block (c, r).
    full_rows = [
        lower[row]
        + [
            jnp.swapaxes(lower[col][row], -1, -2)
            for col in range(row + 1, num_blocks)
        ]
        for row in range(num_blocks)
    ]

    return jnp.block(full_rows)
|
190 |
+
|
191 |
+
|
192 |
+
@functools.partial(jax.jit, static_argnames=("num_blocks",))
def materialize_matrix_from_concat(
    block_rows_concat,
    num_blocks=None,
):
    """Returns a materialized symmetric matrix from concatenated slices.

    Args:
      block_rows_concat: The matrix represented as the concatenated
        lower-triangular blocks.
      num_blocks: The number of block-rows used to represent the symmetric matrix.
        If not specified, it is inferred from the shape of block_rows_concat.

    Returns:
      The full symmetric matrix, as produced by `materialize_matrix`.
    """
    if num_blocks is None:
        num_blocks = find_num_blocks(block_rows_concat)

    block_size = block_rows_concat.shape[-2]

    block_rows = []
    for k in range(num_blocks):
        # Block row k holds k + 1 blocks and follows the 1 + 2 + ... + k blocks
        # of the rows before it, so it spans blocks
        # [k(k+1)/2, (k+1)(k+2)/2) of the concatenation. The end bound must not
        # reach into the next block row.
        first_block = (k * (k + 1)) // 2
        end_block = ((k + 1) * (k + 2)) // 2
        block_rows.append(
            block_rows_concat[
                Ellipsis, first_block * block_size : end_block * block_size
            ]
        )

    return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
|
222 |
+
|
223 |
+
|
224 |
+
@functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
def update_sliced_rows(
    symmetric_matrix,
    mat,
    alpha,
    beta,
    axes=(-1,),
):
    """Blocked analogue of the SYRK update.

    The symmetric matrix (stored as lower-triangular block rows) is combined,
    block row by block row, with the sliced product of mat.

    Args:
      symmetric_matrix: The symmetric matrix to update.
      mat: The matrix to use for the update = mat * mat^T. The number of rows
        should match that of symmetric_matrix.
      alpha: The weight for the update.
      beta: The weight for the original symmetric matrix.
      axes: Axes to use for the contraction of the update.

    Returns:
      The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
    """
    current_rows = symmetric_matrix.block_rows
    # The block size is read off the first stored block row.
    product = sliced_transposed_product(
        mat=mat, block_size=current_rows[0].shape[-2], axes=axes
    )
    blended = []
    for new_row, old_row in zip(product.block_rows, current_rows):
        blended.append(new_row * alpha + old_row * beta)
    return SlicedSymmetricMatrix(block_rows=blended)
|
256 |
+
|
257 |
+
|
258 |
+
def num_blocks_from_total_blocks(total_blocks):
    """Returns the number of blocks (i.e. block rows) from the total blocks.

    This is the inverse of the function x -> x*(x+1)/2.

    For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
    total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.

    Args:
      total_blocks: The total blocks used to represent the matrix.

    Raises:
      ValueError: When total_blocks is not of the form x*(x+1)/2.
    """
    # Positive root of x**2 + x - 2 * total_blocks = 0, rounded to an integer.
    discriminant_root = np.sqrt(8 * total_blocks + 1)
    num_blocks = np.round((discriminant_root - 1) / 2).astype(np.int32)
    # Re-apply the forward map to reject inputs that are not triangular numbers.
    if total_blocks != (num_blocks * (num_blocks + 1)) / 2:
        raise ValueError(
            f"total_blocks={total_blocks} does not correspond to "
            "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
        )
    return num_blocks
|
278 |
+
|
279 |
+
|
280 |
+
def find_num_blocks(block_rows_concat):
    """Returns the number of (row) blocks representing the concatenated matrix.

    For example, an input with dimensions [256, 2560] represents 10 square blocks,
    which matches 4 lower-triangular block rows (1+2+3+4). So this function will
    return 4.

    Use ordinary numpy functions here so that the returned value is static.

    Args:
      block_rows_concat: The concatenated block array.

    Raises:
      ValueError: When the dimensions of the matrix do not correspond to a lower
      triangular block representation.
    """
    concat_width = block_rows_concat.shape[-1]
    block_size = block_rows_concat.shape[-2]
    # True division: a width that is not a whole number of square blocks gives
    # a fractional count, which the inversion of y = x*(x+1)/2 then rejects.
    return num_blocks_from_total_blocks(concat_width / block_size)
|
300 |
+
|
301 |
+
|
302 |
+
@functools.partial(jax.jit, static_argnames=("block_size"))
def slice_symmetric_matrix(
    mat,
    block_size,
):
    """Cuts a symmetric matrix into lower-triangular row blocks.

    Args:
      mat: A symmetric matrix.
      block_size: The size of the row slices.

    Raises:
      ValueError: When mat is not square, or block_size does not evenly divide
        its number of rows.
    """
    num_rows, num_cols = mat.shape[-2], mat.shape[-1]
    if num_rows != num_cols:
        raise ValueError("mat is not square.")
    if num_rows % block_size != 0:
        raise ValueError(
            "block size does not evenly divide rows. "
            f"num_rows={num_rows}, block_size={block_size}"
        )
    block_rows = []
    # `stop` is the exclusive end of each row block; the block keeps only the
    # columns up to and including its diagonal block.
    for stop in range(block_size, num_rows + 1, block_size):
        block_rows.append(mat[Ellipsis, stop - block_size : stop, 0:stop])
    return SlicedSymmetricMatrix(block_rows=block_rows)
|
332 |
+
|
333 |
+
|
334 |
+
@functools.partial(jax.jit, static_argnames=("block_size"))
def slice_symmetric_matrix_concat(
    mat,
    block_size,
):
    """Returns the sliced row blocks concatenated along the last axis.

    Args:
      mat: A symmetric matrix.
      block_size: The size of the row slices.
    """
    block_rows = slice_symmetric_matrix(mat=mat, block_size=block_size).block_rows
    return jnp.concatenate(block_rows, axis=-1)
|
347 |
+
|
348 |
+
|
349 |
+
def sliced_matrix_diag(mat):
    """Extracts the diagonal of a symmetric matrix stored in block form.

    Args:
      mat: The symmetric matrix represented in concatenated block form.
    """
    rows, cols = mat.shape
    num_blocks = num_blocks_from_total_blocks(cols // rows)
    # The diagonal block of block row i is the last of its i + 1 blocks, so it
    # ends at column rows * (i + 1) * (i + 2) / 2.
    ends = [rows * ((i + 2) * (i + 1)) // 2 for i in range(num_blocks)]
    return jnp.concatenate(
        [jnp.diag(mat[Ellipsis, end - rows : end]) for end in ends], axis=-1
    )
|
364 |
+
|
365 |
+
|
366 |
+
def diag_as_concat(diag, block_size):
    """Builds the symmetric block form of a diagonal matrix.

    Args:
      diag: The 1D array for the diagonals.
      block_size: The size of blocks to use. Must divide the length of diag.
    """
    assert len(diag.shape) == 1  # diag must be 1D.
    assert len(diag) % block_size == 0
    pieces = []
    for block_idx in range(len(diag) // block_size):
        start = block_size * block_idx
        # Zero padding standing in for the blocks left of the diagonal block.
        pieces.append(jnp.zeros(shape=(block_size, start), dtype=diag.dtype))
        # The diagonal block itself.
        pieces.append(jnp.diag(diag[start : start + block_size]))
    return jnp.concatenate(pieces, axis=-1)
|
381 |
+
|
382 |
+
|
383 |
+
def row_abs_maxes(mat):
    """Returns the max of the absolute values of the rows of the full matrix.

    For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
    mat = [1, 6, 2] with block_size = 1. In this case the function returns the
    absolute row maxes of the original symmetric matrix, [6, 6].

    Args:
      mat: The symmetric matrix represented as the concatenated blocks.
    """
    # mat must be 2D; `rows` doubles as the block size (blocks are square).
    rows, cols = mat.shape

    # Find col and row max for each block.
    # NOTE: `col_maxes` reduces over axis 1 (one value per row of the block),
    # `row_maxes` reduces over axis 0 (one value per column of the block).
    col_maxes = []
    row_maxes = []
    for i in range(cols // rows):
        block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows])
        col_maxes.append(jnp.max(block, axis=1))
        row_maxes.append(jnp.max(block, axis=0))

    # global row max from block maxes.
    num_blocks = num_blocks_from_total_blocks(cols // rows)
    maxes = []
    for i in range(num_blocks):
        # Each entry of `maxes` is a full-length vector made of one segment per
        # block row of the materialized matrix:
        #   * the stored blocks of block row i (flat indices i*(i+1)/2 up to
        #     (i+1)*(i+2)/2), reduced over axis 0;
        #   * for every later block row j > i, its stored block number i (flat
        #     index j*(j+1)/2 + i, written below as
        #     (j+1)*(j+2)/2 - (j-i+1)), reduced over axis 1.
        maxes.append(
            jnp.concatenate(
                row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
                + [
                    col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
                    for j in range(i + 1, num_blocks)
                ],
                axis=-1,
            )
        )

    # Element-wise max over the per-i vectors gives the overall row maxes.
    return jnp.max(jnp.stack(maxes), axis=0)
|
419 |
+
|
420 |
+
|
421 |
+
def times_vector(mat, vec):
    """Scales the rows of a symmetric block-concatenated matrix by a vector.

    For block row i, every entry in local row r is multiplied by
    vec[rows * i + r]; that is, the matching slice of the vector is broadcast
    across the columns of the block row and multiplied element-wise. The result
    keeps the concatenated block layout of the input.

    Args:
      mat: The symmetric matrix represented as the concatenated blocks.
      vec: The vector, having the same dimension as the materialized matrix.
    """
    rows, cols = mat.shape
    num_blocks = num_blocks_from_total_blocks(cols // rows)

    def scaled_block_row(i):
        # Block row i occupies blocks [i * (i + 1) / 2, (i + 1) * (i + 2) / 2).
        col_start = rows * ((i + 1) * i) // 2
        col_stop = rows * ((i + 1) * (i + 2)) // 2
        block_row = mat[Ellipsis, col_start:col_stop]
        scale = vec[Ellipsis, rows * i : rows * (i + 1)]
        return jnp.einsum("...ij,...i->ij", block_row, scale)

    return jnp.concatenate(
        [scaled_block_row(i) for i in range(num_blocks)], axis=-1
    )
|
tools/train/sweep.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
program: train.py
|
2 |
+
project: dalle-mini
|
3 |
+
method: random
|
4 |
+
metric:
|
5 |
+
name: eval/loss
|
6 |
+
goal: minimize
|
7 |
+
parameters:
|
8 |
+
optim:
|
9 |
+
value: distributed_shampoo
|
10 |
+
learning_rate:
|
11 |
+
distribution: log_uniform
|
12 |
+
# from exp(min) to exp(max)
|
13 |
+
min: -9.2
|
14 |
+
max: -6.9
|
15 |
+
tokenizer_name:
|
16 |
+
value: boris/dalle-mini-tokenizer
|
17 |
+
config_name:
|
18 |
+
value: ./config/mini
|
19 |
+
dtype:
|
20 |
+
value: bfloat16
|
21 |
+
dataset_repo_or_path:
|
22 |
+
value: ./data
|
23 |
+
per_device_train_batch_size:
|
24 |
+
value: 64
|
25 |
+
per_device_eval_batch_size:
|
26 |
+
value: 64
|
27 |
+
gradient_accumulation_steps:
|
28 |
+
value: 1
|
29 |
+
warmup_steps:
|
30 |
+
value: 1000
|
31 |
+
num_train_epochs:
|
32 |
+
value: 1
|
33 |
+
max_train_samples:
|
34 |
+
value: 1000000
|
35 |
+
logging_steps:
|
36 |
+
value: 40
|
37 |
+
eval_steps:
|
38 |
+
value: 200
|
39 |
+
|
40 |
+
command:
|
41 |
+
- python3
|
42 |
+
- ${program}
|
43 |
+
- "--streaming"
|
44 |
+
- "--output_dir"
|
45 |
+
- "./output"
|
46 |
+
- "--overwrite_output_dir"
|
47 |
+
- "--do_train"
|
48 |
+
- "--do_eval"
|
49 |
+
- ${args}
|
tools/train/train.py
ADDED
@@ -0,0 +1,1436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Training DALL·E Mini.
|
18 |
+
Script adapted from run_summarization_flax.py
|
19 |
+
"""
|
20 |
+
|
21 |
+
import io
|
22 |
+
import logging
|
23 |
+
import os
|
24 |
+
import sys
|
25 |
+
import tempfile
|
26 |
+
import time
|
27 |
+
from dataclasses import asdict, dataclass, field
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Any, Callable, NamedTuple, Optional
|
30 |
+
|
31 |
+
import datasets
|
32 |
+
import flax
|
33 |
+
import jax
|
34 |
+
import jax.numpy as jnp
|
35 |
+
import jaxlib
|
36 |
+
import numpy as np
|
37 |
+
import optax
|
38 |
+
import transformers
|
39 |
+
import wandb
|
40 |
+
from datasets import Dataset
|
41 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
42 |
+
from flax.serialization import from_bytes, to_bytes
|
43 |
+
from flax.training import train_state
|
44 |
+
from flax.training.common_utils import onehot
|
45 |
+
from jax.experimental import PartitionSpec, maps
|
46 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
47 |
+
from jax.experimental.pjit import pjit, with_sharding_constraint
|
48 |
+
from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
|
49 |
+
from tqdm import tqdm
|
50 |
+
from transformers import HfArgumentParser
|
51 |
+
|
52 |
+
import dalle_mini
|
53 |
+
from dalle_mini.data import Dataset
|
54 |
+
from dalle_mini.model import (
|
55 |
+
DalleBart,
|
56 |
+
DalleBartConfig,
|
57 |
+
DalleBartTokenizer,
|
58 |
+
set_partitions,
|
59 |
+
)
|
60 |
+
|
61 |
+
# google-cloud-storage is an optional dependency: it is only needed when a
# "gs://" path is used, and those code paths assert that `storage` is not None.
try:
    from google.cloud import storage
except ImportError:
    # Catch only the missing-package case; a bare `except:` would also swallow
    # KeyboardInterrupt / SystemExit raised during import.
    storage = None
|
65 |
+
|
66 |
+
# Initialize the (experimental) JAX compilation cache in ./jax_cache, capped at
# 10 * 2**30 bytes (10 GiB). NOTE: this runs at import time as a module-level
# side effect.
cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch. "
            "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name_or_path"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
        },
    )
    # NOTE(review): annotated as Optional[bool], but get_opt_state() also
    # accepts a string (local file path or "gs://" bucket path).
    restore_state: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
        },
    )

    def __post_init__(self):
        """Default the tokenizer to the model path and validate restore settings."""
        if self.tokenizer_name is None:
            self.tokenizer_name = self.model_name_or_path
            assert (
                self.tokenizer_name is not None
            ), "Tokenizer name or model name/path needs to be specified"
        if self.restore_state:
            # The state artifact name is derived from the model artifact name,
            # so the model reference must contain "/model-".
            assert self.model_name_or_path is not None and (
                "/model-" in self.model_name_or_path
            ), "Restoring state only available with W&B artifact reference"

    def get_metadata(self):
        """Return the metadata of the W&B model artifact, or {} when not restoring."""
        if not self.restore_state:
            return {}
        # Process 0 registers the artifact on the active run; other processes
        # fetch it through the public API.
        if jax.process_index() == 0:
            artifact = wandb.run.use_artifact(self.model_name_or_path)
        else:
            artifact = wandb.Api().artifact(self.model_name_or_path)
        return artifact.metadata

    def get_opt_state(self):
        """Return the serialized optimizer state (`opt_state.msgpack`) as bytes.

        The source is resolved from `restore_state`:
          * True: the W&B "state-" artifact matching the model artifact; read
            from the bucket named in its metadata when present, otherwise
            downloaded to a temporary directory.
          * a "gs://" path: read from `<path>/opt_state.msgpack` in the bucket.
          * any other string: opened as a local file path.

        `self.restore_state` is left untouched, so the method can be called
        more than once and never leaves a deleted temporary path on the instance.
        """
        source = self.restore_state
        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            if source is True:
                # wandb artifact
                state_artifact = self.model_name_or_path.replace(
                    "/model-", "/state-", 1
                )
                if jax.process_index() == 0:
                    artifact = wandb.run.use_artifact(state_artifact)
                else:
                    artifact = wandb.Api().artifact(state_artifact)
                if artifact.metadata.get("bucket_path"):
                    # we will read directly file contents
                    source = artifact.metadata["bucket_path"]
                else:
                    artifact_dir = artifact.download(tmp_dir)
                    source = str(Path(artifact_dir) / "opt_state.msgpack")

            if source.startswith("gs://"):
                # Strip the "gs://" scheme, then split "<bucket>/<blob path>".
                bucket_path = Path(source[5:]) / "opt_state.msgpack"
                bucket, blob_name = str(bucket_path).split("/", 1)
                assert (
                    storage is not None
                ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
                client = storage.Client()
                bucket = client.bucket(bucket)
                blob = bucket.blob(blob_name)
                return blob.download_as_bytes()

            # Read while the temporary directory (if used) still exists.
            with Path(source).open("rb") as f:
                return f.read()
|
162 |
+
|
163 |
+
|
164 |
+
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    `dataset_repo_or_path` is mandatory: `__post_init__` raises ValueError when it is None.
    """

    # Column holding the caption text.
    # NOTE(review): the help text still mentions summarization.
    text_column: Optional[str] = field(
        default="caption",
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    encoding_column: Optional[str] = field(
        default="encoding",
        metadata={
            "help": "The name of the column in the datasets containing the image encodings."
        },
    )
    # Required despite the None default (checked in __post_init__).
    dataset_repo_or_path: str = field(
        default=None,
        metadata={"help": "The dataset repository containing encoded files."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "The input training data file (glob & braceexpand acceptable)."
        },
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
        },
    )
    # data loading should not be a bottleneck so we use "streaming" mode by default
    streaming: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to stream the dataset."},
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to use the authentication token for private datasets."
        },
    )
    shard_by_host: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to shard data files by host in multi-host environments."
        },
    )
    blank_caption_prob: Optional[float] = field(
        default=0.0,
        metadata={
            "help": "Probability of removing some captions for classifier-free guidance."
        },
    )
    clip_score_column: Optional[str] = field(
        default="clip_score",
        metadata={"help": "Column that containts clip score for filtering."},
    )
    min_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Minimum clip score required."},
    )
    max_clip_score: Optional[float] = field(
        default=None,
        metadata={"help": "Maximum clip score required."},
    )
    filter_column: Optional[str] = field(
        default=None,
        metadata={"help": "Column that containts classes to be filtered."},
    )
    filter_value: Optional[str] = field(
        default=None,
        metadata={"help": "Class value to be kept during filtering."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={
            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
        },
    )
    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
    seed_dataset: int = field(
        default=None,
        metadata={
            "help": "Random seed for the dataset that will be set at the beginning of training."
        },
    )

    def __post_init__(self):
        # The dataset location has no usable default; fail early when missing.
        if self.dataset_repo_or_path is None:
            raise ValueError("Need a dataset repository or path.")
|
276 |
+
|
277 |
+
|
278 |
+
@dataclass
|
279 |
+
class TrainingArguments:
|
280 |
+
"""
|
281 |
+
Arguments pertaining to training parameters.
|
282 |
+
"""
|
283 |
+
|
284 |
+
output_dir: str = field(
|
285 |
+
metadata={
|
286 |
+
"help": "The output directory where the model predictions and checkpoints will be written."
|
287 |
+
},
|
288 |
+
)
|
289 |
+
overwrite_output_dir: bool = field(
|
290 |
+
default=False,
|
291 |
+
metadata={
|
292 |
+
"help": (
|
293 |
+
"Overwrite the content of the output directory. "
|
294 |
+
"Use this to continue training if output_dir points to a checkpoint directory."
|
295 |
+
)
|
296 |
+
},
|
297 |
+
)
|
298 |
+
|
299 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
300 |
+
do_eval: bool = field(
|
301 |
+
default=False, metadata={"help": "Whether to run eval on the validation set."}
|
302 |
+
)
|
303 |
+
|
304 |
+
per_device_train_batch_size: int = field(
|
305 |
+
default=8,
|
306 |
+
metadata={"help": "Batch size per data parallel device for training."},
|
307 |
+
)
|
308 |
+
per_device_eval_batch_size: Optional[int] = field(
|
309 |
+
default=None,
|
310 |
+
metadata={
|
311 |
+
"help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
|
312 |
+
},
|
313 |
+
)
|
314 |
+
|
315 |
+
gradient_accumulation_steps: int = field(
|
316 |
+
default=1,
|
317 |
+
metadata={
|
318 |
+
"help": "Number of updates steps to accumulate before performing an update pass."
|
319 |
+
},
|
320 |
+
)
|
321 |
+
gradient_checkpointing: bool = field(
|
322 |
+
default=False, metadata={"help": "Use gradient checkpointing."}
|
323 |
+
)
|
324 |
+
|
325 |
+
learning_rate: float = field(
|
326 |
+
default=5e-5, metadata={"help": "The initial learning rate."}
|
327 |
+
)
|
328 |
+
optim: str = field(
|
329 |
+
default="distributed_shampoo",
|
330 |
+
metadata={
|
331 |
+
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
332 |
+
},
|
333 |
+
)
|
334 |
+
beta1: float = field(
|
335 |
+
default=0.9,
|
336 |
+
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
337 |
+
)
|
338 |
+
beta2: float = field(
|
339 |
+
default=0.999,
|
340 |
+
metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
|
341 |
+
)
|
342 |
+
adam_epsilon: float = field(
|
343 |
+
default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
|
344 |
+
)
|
345 |
+
max_grad_norm: float = field(
|
346 |
+
default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
|
347 |
+
)
|
348 |
+
block_size: int = field(
|
349 |
+
default=1024,
|
350 |
+
metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
|
351 |
+
)
|
352 |
+
preconditioning_compute_steps: int = field(
|
353 |
+
default=10, metadata={"help": "Number of steps to update preconditioner."}
|
354 |
+
)
|
355 |
+
skip_preconditioning_dim_size_gt: int = field(
|
356 |
+
default=4096,
|
357 |
+
metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
|
358 |
+
)
|
359 |
+
graft_type: str = field(
|
360 |
+
default="rmsprop_normalized",
|
361 |
+
metadata={
|
362 |
+
"help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
|
363 |
+
},
|
364 |
+
)
|
365 |
+
optim_quantized: bool = field(
|
366 |
+
default=False,
|
367 |
+
metadata={
|
368 |
+
"help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
|
369 |
+
},
|
370 |
+
)
|
371 |
+
|
372 |
+
num_train_epochs: int = field(
|
373 |
+
default=3, metadata={"help": "Total number of training epochs to perform."}
|
374 |
+
)
|
375 |
+
|
376 |
+
warmup_steps: int = field(
|
377 |
+
default=0, metadata={"help": "Linear warmup over warmup_steps."}
|
378 |
+
)
|
379 |
+
lr_decay: str = field(
|
380 |
+
default=None,
|
381 |
+
metadata={
|
382 |
+
"help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
|
383 |
+
},
|
384 |
+
)
|
385 |
+
lr_transition_steps: int = field(
|
386 |
+
default=None,
|
387 |
+
metadata={
|
388 |
+
"help": "Number of transition steps associated with learning rate decay when using exponential decay."
|
389 |
+
},
|
390 |
+
)
|
391 |
+
lr_decay_rate: float = field(
|
392 |
+
default=None,
|
393 |
+
metadata={
|
394 |
+
"help": "Decay rate associated with learning rate when using exponential decay."
|
395 |
+
},
|
396 |
+
)
|
397 |
+
lr_staircase: bool = field(
|
398 |
+
default=False,
|
399 |
+
metadata={
|
400 |
+
"help": "Whether to use staircase or continuous learning rate when using exponential decay."
|
401 |
+
},
|
402 |
+
)
|
403 |
+
|
404 |
+
logging_steps: int = field(
|
405 |
+
default=40, metadata={"help": "Log every X updates steps."}
|
406 |
+
)
|
407 |
+
eval_steps: int = field(
|
408 |
+
default=400, metadata={"help": "Run an evaluation every X steps."}
|
409 |
+
)
|
410 |
+
save_steps: int = field(
|
411 |
+
default=4000, metadata={"help": "Save checkpoint every X updates steps."}
|
412 |
+
)
|
413 |
+
log_model: bool = field(
|
414 |
+
default=False,
|
415 |
+
metadata={"help": "Log model to wandb at `save_steps` frequency."},
|
416 |
+
)
|
417 |
+
log_norm_steps: int = field(
|
418 |
+
default=True,
|
419 |
+
metadata={"help": "Log parameters and gradients norm at this frequency."},
|
420 |
+
)
|
421 |
+
log_histogram_steps: int = field(
|
422 |
+
default=False,
|
423 |
+
metadata={
|
424 |
+
"help": "Log parameters and gradients histograms at this frequency. Slows down training."
|
425 |
+
},
|
426 |
+
)
|
427 |
+
|
428 |
+
seed_model: int = field(
|
429 |
+
default=42,
|
430 |
+
metadata={
|
431 |
+
"help": "Random seed for the model that will be set at the beginning of training."
|
432 |
+
},
|
433 |
+
)
|
434 |
+
|
435 |
+
wandb_entity: Optional[str] = field(
|
436 |
+
default=None,
|
437 |
+
metadata={"help": "The wandb entity to use (for teams)."},
|
438 |
+
)
|
439 |
+
wandb_project: str = field(
|
440 |
+
default="dalle-mini",
|
441 |
+
metadata={"help": "The name of the wandb project."},
|
442 |
+
)
|
443 |
+
wandb_job_type: str = field(
|
444 |
+
default="Seq2Seq",
|
445 |
+
metadata={"help": "The name of the wandb job type."},
|
446 |
+
)
|
447 |
+
|
448 |
+
assert_TPU_available: bool = field(
|
449 |
+
default=False,
|
450 |
+
metadata={"help": "Verify that TPU is not in use."},
|
451 |
+
)
|
452 |
+
|
453 |
+
mp_devices: Optional[int] = field(
|
454 |
+
default=1,
|
455 |
+
metadata={
|
456 |
+
"help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
|
457 |
+
},
|
458 |
+
)
|
459 |
+
|
460 |
+
dp_devices: int = field(init=False)
|
461 |
+
|
462 |
+
def __post_init__(self):
    """Validate the training arguments and derive dependent settings.

    Fills in `per_device_eval_batch_size` and `log_norm_steps` when they
    were left at their placeholder values, and sets `dp_devices` from the
    number of available devices and `mp_devices`.

    Raises:
        AssertionError: if an option holds an unsupported value or the
            device count is not compatible with `mp_devices`.
        ValueError: if the output directory exists and is not empty while
            training without `overwrite_output_dir`.
    """
    if self.assert_TPU_available:
        assert (
            jax.local_device_count() == 8
        ), "TPUs in use, please check running processes"

    # saving to a bucket needs the google storage client
    if self.output_dir.startswith("gs://"):
        assert (
            storage is not None
        ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'

    assert self.optim in [
        "distributed_shampoo",
        "adam",
        "adafactor",
    ], f"Selected optimizer not supported: {self.optim}"
    assert self.graft_type in [
        "rmsprop_normalized",
        "rmsprop",
        "adagrad",
        "adagrad_normalized",
        "sgd",
        "sqrt_n",
    ], f"Selected graft type not supported: {self.graft_type}"
    assert self.lr_decay in [
        None,
        "linear",
        "exponential",
    ], f"Selected learning rate decay not supported: {self.lr_decay}"

    # evaluation batch size defaults to the training batch size
    if self.per_device_eval_batch_size is None:
        self.per_device_eval_batch_size = self.per_device_train_batch_size
    # `True` means "log norms at the regular logging frequency"
    if self.log_norm_steps is True:
        self.log_norm_steps = self.logging_steps

    # refuse to write into a non-empty directory unless explicitly allowed
    if (
        os.path.exists(self.output_dir)
        and os.listdir(self.output_dir)
        and self.do_train
        and not self.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({self.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    assert (
        self.mp_devices > 0
    ), "Number of devices for model parallelism must be > 0"
    assert (
        jax.device_count() % self.mp_devices == 0
    ), f"Number of available devices ({jax.device_count()}) must be divisible by number of devices used for model parallelism ({self.mp_devices})."
    # devices not used for model parallelism are used for data parallelism
    self.dp_devices = jax.device_count() // self.mp_devices
|
510 |
+
|
511 |
+
|
512 |
+
class TrainState(train_state.TrainState):
    """Train state carrying the dropout rng and training progress counters."""

    # rng key handed to the model for dropout
    dropout_rng: jnp.ndarray = None
    # epoch counter
    epoch: int = 0
    # total time the model trained
    train_time: float = 0.0
    # number of samples seen
    train_samples: int = 0
|
517 |
+
|
518 |
+
|
519 |
+
def main():
|
520 |
+
# See all possible arguments by passing the --help flag to this script.
|
521 |
+
parser = HfArgumentParser(
|
522 |
+
(ModelArguments, DataTrainingArguments, TrainingArguments)
|
523 |
+
)
|
524 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
525 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
526 |
+
# let's parse it to get our arguments.
|
527 |
+
model_args, data_args, training_args = parser.parse_json_file(
|
528 |
+
json_file=os.path.abspath(sys.argv[1])
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
532 |
+
|
533 |
+
# Make one log on every process with the configuration for debugging.
|
534 |
+
logging.basicConfig(
|
535 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
536 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
537 |
+
level=logging.INFO,
|
538 |
+
)
|
539 |
+
# Setup logging, we only want one process per machine to log things on the screen.
|
540 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
541 |
+
if jax.process_index() == 0:
|
542 |
+
datasets.utils.logging.set_verbosity_warning()
|
543 |
+
transformers.utils.logging.set_verbosity_info()
|
544 |
+
else:
|
545 |
+
datasets.utils.logging.set_verbosity_error()
|
546 |
+
transformers.utils.logging.set_verbosity_error()
|
547 |
+
|
548 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
549 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
550 |
+
|
551 |
+
# Load dataset
|
552 |
+
dataset = Dataset(
|
553 |
+
**asdict(data_args),
|
554 |
+
do_train=training_args.do_train,
|
555 |
+
do_eval=training_args.do_eval,
|
556 |
+
)
|
557 |
+
|
558 |
+
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
559 |
+
logger.info(f"Global TPUs: {jax.device_count()}")
|
560 |
+
|
561 |
+
# Set up wandb run
|
562 |
+
if jax.process_index() == 0:
|
563 |
+
wandb.init(
|
564 |
+
entity=training_args.wandb_entity,
|
565 |
+
project=training_args.wandb_project,
|
566 |
+
job_type=training_args.wandb_job_type,
|
567 |
+
config=parser.parse_args(),
|
568 |
+
)
|
569 |
+
|
570 |
+
# Set up our new model config
|
571 |
+
if model_args.config_name:
|
572 |
+
config = DalleBartConfig.from_pretrained(model_args.config_name)
|
573 |
+
config.gradient_checkpointing = training_args.gradient_checkpointing
|
574 |
+
else:
|
575 |
+
config = None
|
576 |
+
|
577 |
+
# Load or create new model
|
578 |
+
if model_args.model_name_or_path:
|
579 |
+
model = DalleBart.from_pretrained(
|
580 |
+
model_args.model_name_or_path,
|
581 |
+
config=config,
|
582 |
+
seed=training_args.seed_model,
|
583 |
+
dtype=getattr(jnp, model_args.dtype),
|
584 |
+
abstract_init=True, # we overwrite them with loaded checkpoint
|
585 |
+
gradient_checkpointing=training_args.gradient_checkpointing,
|
586 |
+
)
|
587 |
+
else:
|
588 |
+
model = DalleBart(
|
589 |
+
config,
|
590 |
+
seed=training_args.seed_model,
|
591 |
+
dtype=getattr(jnp, model_args.dtype),
|
592 |
+
abstract_init=True,
|
593 |
+
)
|
594 |
+
|
595 |
+
# get model metadata
|
596 |
+
model_metadata = model_args.get_metadata()
|
597 |
+
|
598 |
+
# get PartitionSpec for model params (required to be a dict)
|
599 |
+
param_spec = set_partitions(model.params)
|
600 |
+
|
601 |
+
# convert params to frozen dict
|
602 |
+
model._params = freeze(model.params)
|
603 |
+
|
604 |
+
# Load tokenizer
|
605 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
606 |
+
model_args.tokenizer_name, use_fast=True
|
607 |
+
)
|
608 |
+
|
609 |
+
# Preprocessing the datasets.
|
610 |
+
# We need to normalize and tokenize inputs and targets.
|
611 |
+
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
612 |
+
|
613 |
+
# Initialize our training
|
614 |
+
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
|
615 |
+
|
616 |
+
# Store some constant
|
617 |
+
num_epochs = training_args.num_train_epochs
|
618 |
+
# batch size
|
619 |
+
batch_size_per_node_per_grad_step = (
|
620 |
+
training_args.per_device_train_batch_size
|
621 |
+
* jax.local_device_count()
|
622 |
+
// training_args.mp_devices
|
623 |
+
)
|
624 |
+
batch_size_per_node = (
|
625 |
+
batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
|
626 |
+
)
|
627 |
+
batch_size_per_step = batch_size_per_node * jax.process_count()
|
628 |
+
eval_batch_size_per_node = (
|
629 |
+
training_args.per_device_eval_batch_size
|
630 |
+
* jax.local_device_count()
|
631 |
+
// training_args.mp_devices
|
632 |
+
)
|
633 |
+
eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
|
634 |
+
len_train_dataset, len_eval_dataset = dataset.length
|
635 |
+
steps_per_epoch = (
|
636 |
+
len_train_dataset // batch_size_per_node
|
637 |
+
if len_train_dataset is not None
|
638 |
+
else None
|
639 |
+
)
|
640 |
+
num_train_steps = (
|
641 |
+
steps_per_epoch * num_epochs if steps_per_epoch is not None else None
|
642 |
+
)
|
643 |
+
num_params = model.num_params
|
644 |
+
|
645 |
+
logger.info("***** Running training *****")
|
646 |
+
logger.info(f" Num examples = {len_train_dataset}")
|
647 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
648 |
+
logger.info(
|
649 |
+
f" Batch size per dp device = {training_args.per_device_train_batch_size}"
|
650 |
+
)
|
651 |
+
logger.info(f" Number of devices = {jax.device_count()}")
|
652 |
+
logger.info(
|
653 |
+
f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
|
654 |
+
)
|
655 |
+
logger.info(f" Batch size per update = {batch_size_per_step}")
|
656 |
+
logger.info(f" Model parameters = {num_params:,}")
|
657 |
+
|
658 |
+
# set up wandb run
|
659 |
+
if jax.process_index() == 0:
|
660 |
+
# set default x-axis as 'train/step'
|
661 |
+
wandb.define_metric("*", step_metric="train/step")
|
662 |
+
|
663 |
+
# add interesting config parameters
|
664 |
+
wandb.config.update(
|
665 |
+
{
|
666 |
+
"len_train_dataset": len_train_dataset,
|
667 |
+
"len_eval_dataset": len_eval_dataset,
|
668 |
+
"batch_size_per_step": batch_size_per_step,
|
669 |
+
"num_params": num_params,
|
670 |
+
"model_config": model.config.to_dict(),
|
671 |
+
"num_devices": jax.device_count(),
|
672 |
+
"versions": {
|
673 |
+
"jax": jax.__version__,
|
674 |
+
"jaxlib": jaxlib.__version__,
|
675 |
+
"flax": flax.__version__,
|
676 |
+
"transformers": transformers.__version__,
|
677 |
+
"datasets": datasets.__version__,
|
678 |
+
"wandb": wandb.__version__,
|
679 |
+
"dalle_mini": dalle_mini.__version__,
|
680 |
+
},
|
681 |
+
}
|
682 |
+
)
|
683 |
+
|
684 |
+
# Create learning rate schedule
|
685 |
+
def create_learning_rate_fn() -> Callable[[int], jnp.ndarray]:
    """Create the learning rate function.

    The schedule is a linear warmup from 0 to `training_args.learning_rate`,
    optionally followed by a linear or exponential decay. When the model
    metadata holds a non-zero "step" (resumed run), the schedules are
    offset by that step.

    Returns:
        A schedule mapping a step to a learning rate.

    Raises:
        AssertionError: if linear decay is requested but the number of
            training steps is unknown.
        ValueError: if `training_args.lr_decay` is not a supported value.
    """
    resume_step = model_metadata.get("step", 0)

    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps + 1,  # ensure not 0
    )
    # offset step when resuming
    if resume_step:
        warmup_fn = optax.join_schedules(
            schedules=[optax.constant_schedule(0.0), warmup_fn],
            boundaries=[model_metadata["step"]],
        )

    if training_args.lr_decay is None:
        # warmup only, no decay afterwards
        return warmup_fn

    if training_args.lr_decay == "linear":
        assert (
            num_train_steps is not None
        ), "linear decay requires knowing the dataset length"
        decay_fn = optax.linear_schedule(
            init_value=training_args.learning_rate,
            end_value=0,
            transition_steps=num_train_steps - training_args.warmup_steps,
        )
    elif training_args.lr_decay == "exponential":
        decay_fn = optax.exponential_decay(
            init_value=training_args.learning_rate,
            transition_steps=training_args.lr_transition_steps,
            decay_rate=training_args.lr_decay_rate,
            staircase=training_args.lr_staircase,
        )
    else:
        # fail with a clear message instead of an unbound `decay_fn`
        raise ValueError(
            f"Selected learning rate decay not supported: {training_args.lr_decay}"
        )

    # decay starts once warmup is over
    return optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[resume_step + training_args.warmup_steps],
    )
|
721 |
+
|
722 |
+
learning_rate_fn = create_learning_rate_fn()
|
723 |
+
|
724 |
+
# create optimizer (distributed_shampoo, adam or adafactor)
|
725 |
+
if training_args.optim == "distributed_shampoo":
|
726 |
+
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
727 |
+
graft_type = {
|
728 |
+
"sgd": GraftingType.SGD,
|
729 |
+
"adagrad": GraftingType.ADAGRAD,
|
730 |
+
"rmsprop": GraftingType.RMSPROP,
|
731 |
+
"rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
|
732 |
+
"sqrt_n": GraftingType.SQRT_N,
|
733 |
+
"adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
|
734 |
+
}[training_args.graft_type]
|
735 |
+
optimizer = distributed_shampoo(
|
736 |
+
learning_rate_fn,
|
737 |
+
block_size=training_args.block_size,
|
738 |
+
beta1=training_args.beta1,
|
739 |
+
beta2=training_args.beta2,
|
740 |
+
diagonal_epsilon=1e-10,
|
741 |
+
matrix_epsilon=1e-6,
|
742 |
+
start_preconditioning_step=max(
|
743 |
+
training_args.preconditioning_compute_steps + 1, 101
|
744 |
+
),
|
745 |
+
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
746 |
+
statistics_compute_steps=1,
|
747 |
+
best_effort_shape_interpretation=True,
|
748 |
+
graft_type=graft_type,
|
749 |
+
nesterov=False,
|
750 |
+
exponent_override=0,
|
751 |
+
statistics_partition_spec=PartitionSpec(None, "dp", None),
|
752 |
+
preconditioner_partition_spec=PartitionSpec("dp", None, None),
|
753 |
+
num_devices_for_pjit=training_args.dp_devices,
|
754 |
+
shard_optimizer_states=True,
|
755 |
+
inverse_failure_threshold=0.1,
|
756 |
+
moving_average_for_momentum=True,
|
757 |
+
skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
|
758 |
+
clip_by_scaled_gradient_norm=None,
|
759 |
+
precision=jax.lax.Precision.HIGHEST,
|
760 |
+
best_effort_memory_usage_reduction=training_args.optim_quantized,
|
761 |
+
)
|
762 |
+
# get the real optimizer and helper functions
|
763 |
+
update_fn = optimizer.update
|
764 |
+
optimizer = optimizer.init(model.params)
|
765 |
+
opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
|
766 |
+
optimizer.pspec_fn, optimizer.shape_and_dtype_fn
|
767 |
+
)
|
768 |
+
optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
|
769 |
+
|
770 |
+
elif training_args.optim == "adam":
|
771 |
+
optimizer = optax.adamw(
|
772 |
+
learning_rate=learning_rate_fn,
|
773 |
+
b1=training_args.beta1,
|
774 |
+
b2=training_args.beta2,
|
775 |
+
eps=training_args.adam_epsilon,
|
776 |
+
)
|
777 |
+
elif training_args.optim == "adafactor":
|
778 |
+
# We use the default parameters here to initialize adafactor,
|
779 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
780 |
+
optimizer = optax.adafactor(
|
781 |
+
learning_rate=learning_rate_fn,
|
782 |
+
clipping_threshold=training_args.max_grad_norm,
|
783 |
+
)
|
784 |
+
|
785 |
+
# get PartitionSpec for optimizer state
|
786 |
+
def get_opt_state_spec_and_shape(param_spec):
    """Return the partition spec and the abstract shape of the optimizer state.

    Args:
        param_spec: partition spec of the model parameters.

    Returns:
        Tuple `(opt_state_spec, opt_state_shape)`.

    Raises:
        NotImplementedError: if the selected optimizer is not handled.
    """
    # evaluate shapes only, the optimizer state is not actually created
    state_shape = jax.eval_shape(optimizer.init, model.params)
    optim_name = training_args.optim

    if optim_name == "adam":

        def _stop_at(node):
            # empty elements are treated as leaves so they get a None spec
            return isinstance(node, (FrozenDict, optax.EmptyState))

        def _spec_for(node):
            # entries shaped like the params reuse their spec,
            # anything else (such as count) gets None
            return param_spec if isinstance(node, FrozenDict) else None

        state_spec = jax.tree_map(_spec_for, state_shape, is_leaf=_stop_at)

    elif optim_name == "adafactor":
        # factorized state must be replicated (rank different than params)
        state_spec = None

    elif optim_name == "distributed_shampoo":
        state_spec = opt_fn.pspec_fn(
            params=model.params,
            params_partition_spec=param_spec,
            partition_spec_for_statistics=PartitionSpec(None, "dp", None),
        )

    else:
        raise NotImplementedError

    return state_spec, state_shape
|
820 |
+
|
821 |
+
opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
|
822 |
+
|
823 |
+
# create a mesh
|
824 |
+
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
825 |
+
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
|
826 |
+
mesh = maps.Mesh(devices, ("dp", "mp"))
|
827 |
+
logger.info(f" Mesh shape: {mesh_shape}")
|
828 |
+
|
829 |
+
# define state spec
|
830 |
+
state_spec = TrainState(
|
831 |
+
params=param_spec,
|
832 |
+
opt_state=opt_state_spec,
|
833 |
+
dropout_rng=None,
|
834 |
+
step=None,
|
835 |
+
epoch=None,
|
836 |
+
train_time=None,
|
837 |
+
train_samples=None,
|
838 |
+
apply_fn=model.__call__,
|
839 |
+
tx=optimizer,
|
840 |
+
)
|
841 |
+
|
842 |
+
# init params if not available yet
|
843 |
+
def maybe_init_params(params):
    """Return `params` when a model was loaded, freshly initialised weights otherwise."""
    if not model_args.model_name_or_path:
        # no checkpoint given: params have not been initialized yet
        return model.init_weights()
    # params come from the loaded model
    return params
|
850 |
+
|
851 |
+
with mesh:
|
852 |
+
logger.info(" Creating state")
|
853 |
+
if not model_args.restore_state:
|
854 |
+
|
855 |
+
def init_state(params):
    """Create a new train state, initialising the params if none were loaded."""
    initial_params = maybe_init_params(params)
    return TrainState.create(
        apply_fn=model.__call__,
        tx=optimizer,
        params=initial_params,
        dropout_rng=dropout_rng,
    )
|
862 |
+
|
863 |
+
state = pjit(
|
864 |
+
init_state,
|
865 |
+
in_axis_resources=(param_spec,)
|
866 |
+
if model_args.model_name_or_path
|
867 |
+
else None,
|
868 |
+
out_axis_resources=state_spec,
|
869 |
+
donate_argnums=(0,),
|
870 |
+
)(model.params if model_args.model_name_or_path else None)
|
871 |
+
|
872 |
+
else:
|
873 |
+
# load opt_state
|
874 |
+
opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
|
875 |
+
|
876 |
+
# restore other attributes
|
877 |
+
attr_state = {
|
878 |
+
k: model_metadata[k]
|
879 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
880 |
+
}
|
881 |
+
|
882 |
+
def restore_state(params, opt_state):
    """Rebuild the train state from given params, optimizer state and saved counters."""
    return TrainState(
        params=params,
        opt_state=opt_state,
        dropout_rng=dropout_rng,
        apply_fn=model.__call__,
        tx=optimizer,
        # step, epoch, train_time and train_samples read from the model metadata
        **attr_state,
    )
|
891 |
+
|
892 |
+
state = pjit(
|
893 |
+
restore_state,
|
894 |
+
in_axis_resources=(
|
895 |
+
param_spec,
|
896 |
+
opt_state_spec,
|
897 |
+
),
|
898 |
+
out_axis_resources=state_spec,
|
899 |
+
donate_argnums=(0, 1),
|
900 |
+
)(model.params, opt_state)
|
901 |
+
|
902 |
+
# remove opt_state from CPU
|
903 |
+
del opt_state
|
904 |
+
|
905 |
+
# free CPU memory
|
906 |
+
del model._params, opt_state_spec, opt_state_shape
|
907 |
+
|
908 |
+
# define batch specs
|
909 |
+
batch_spec = PartitionSpec("dp")
|
910 |
+
grad_batch_spec = PartitionSpec(None, "dp")
|
911 |
+
|
912 |
+
# define loss
|
913 |
+
def loss_fn(logits, labels):
    """Return the mean softmax cross entropy between `logits` and `labels`."""
    # labels are expanded to one-hot over the last dimension of the logits
    targets = onehot(labels, logits.shape[-1])
    return optax.softmax_cross_entropy(logits, targets).mean()
|
917 |
+
|
918 |
+
# "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
|
919 |
+
# it also leads to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
|
920 |
+
use_vmap_trick = True
|
921 |
+
|
922 |
+
# make grad_param_spec for vmap
|
923 |
+
if use_vmap_trick:
|
924 |
+
grad_param_spec = jax.tree_map(
|
925 |
+
lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))),
|
926 |
+
param_spec,
|
927 |
+
)
|
928 |
+
|
929 |
+
# Define gradient update step fn
|
930 |
+
def train_step(state, batch, train_time):
    """Compute gradients on `batch`, apply them and return the new state with metrics.

    Args:
        state: current train state.
        batch: model inputs with a "labels" entry; with gradient accumulation
            its leading axis is indexed by the accumulation step.
        train_time: value stored as `train_time` in the updated state.

    Returns:
        Tuple `(state, metrics)`. `metrics` holds "loss" and "learning_rate",
        plus norm and histogram entries when those options are enabled.
    """
    accumulation_steps = training_args.gradient_accumulation_steps

    def compute_loss(params, minibatch, dropout_rng):
        # separate the labels from the model inputs
        inputs, labels = minibatch.pop("labels")
        outputs = state.apply_fn(
            **inputs, params=params, dropout_rng=dropout_rng, train=True
        )
        return loss_fn(outputs[0], labels)

    grad_fn = jax.value_and_grad(compute_loss)

    def select_minibatch(grad_idx):
        # without gradient accumulation the whole batch is used
        if grad_idx is None:
            return batch
        # otherwise take the slice of this accumulation step
        return jax.tree_map(
            lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
            batch,
        )

    def loss_and_grad(grad_idx, dropout_rng):
        # make sure the minibatch is sharded properly
        minibatch = with_sharding_constraint(select_minibatch(grad_idx), batch_spec)
        # one single rng per gradient step
        dropout_rng, _ = jax.random.split(dropout_rng)

        if use_vmap_trick:
            # "vmap trick": loss and grads are computed independently per dp device
            per_device_grad_fn = jax.vmap(
                grad_fn, in_axes=(None, 0, None), out_axes=(0, 0)
            )
            loss, grads = per_device_grad_fn(state.params, minibatch, dropout_rng)
            # make sure results are sharded correctly
            loss = with_sharding_constraint(loss, batch_spec)
            grads = with_sharding_constraint(grads, grad_param_spec)
            # average over the leading (device) axis
            loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads))
        else:
            loss, grads = grad_fn(state.params, minibatch, dropout_rng)

        grads = with_sharding_constraint(grads, param_spec)
        return loss, grads, dropout_rng

    if accumulation_steps == 1:
        loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
    else:
        # loop carry: summed loss, summed grads and the rng
        zero_grads = with_sharding_constraint(
            jax.tree_map(jnp.zeros_like, state.params), param_spec
        )
        initial_carry = (0.0, zero_grads, state.dropout_rng)

        def accumulate(grad_idx, carry):
            total_loss, total_grads, rng = carry
            step_loss, step_grads, rng = loss_and_grad(grad_idx, rng)
            total_loss, total_grads = jax.tree_map(
                jnp.add, (total_loss, total_grads), (step_loss, step_grads)
            )
            total_grads = with_sharding_constraint(total_grads, param_spec)
            return total_loss, total_grads, rng

        loss, grads, dropout_rng = jax.lax.fori_loop(
            0, accumulation_steps, accumulate, initial_carry
        )
        grads = with_sharding_constraint(grads, param_spec)
        # turn the sums into means
        loss, grads = jax.tree_map(lambda x: x / accumulation_steps, (loss, grads))

    grads = with_sharding_constraint(grads, param_spec)

    # apply the gradients and update the counters
    state = state.apply_gradients(
        grads=grads,
        dropout_rng=dropout_rng,
        train_time=train_time,
        train_samples=state.train_samples + batch_size_per_step,
    )

    metrics = {
        "loss": loss,
        "learning_rate": learning_rate_fn(state.step),
    }

    def maybe_fn(fn, val, zeros, freq):
        """Apply fn to val on steps that are multiples of freq, else return zeros."""
        return jax.lax.cond(state.step % freq == 0, fn, lambda _: zeros, val)

    norm_freq = training_args.log_norm_steps
    if norm_freq:
        zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)

        def norm(val):
            return jax.tree_map(jnp.linalg.norm, val)

        gradients_norm = maybe_fn(norm, grads, zeros_norm, norm_freq)
        params_norm = maybe_fn(norm, state.params, zeros_norm, norm_freq)
        metrics["gradients_norm"] = gradients_norm
        metrics["params_norm"] = params_norm

    hist_freq = training_args.log_histogram_steps
    if hist_freq:
        zeros_hist = jax.tree_map(
            lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
        )

        def histogram(val):
            return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)

        gradients_hist = maybe_fn(histogram, grads, zeros_hist, hist_freq)
        params_hist = maybe_fn(histogram, state.params, zeros_hist, hist_freq)
        metrics["params_hist"] = params_hist
        metrics["gradients_hist"] = gradients_hist

    return state, metrics
|
1080 |
+
|
1081 |
+
# Define eval fn
|
1082 |
+
def eval_step(state, batch):
    """Return the evaluation loss of `batch` for the params held in `state`."""

    def loss_for(inputs):
        # separate the labels from the model inputs
        inputs, labels = inputs.pop("labels")
        outputs = model(**inputs, params=state.params, train=False)
        return loss_fn(outputs[0], labels)

    if not use_vmap_trick:
        return loss_for(batch)

    # one loss per entry of the leading axis, kept sharded along the batch spec
    per_device_loss = with_sharding_constraint(jax.vmap(loss_for)(batch), batch_spec)
    # average across all devices
    return jnp.mean(per_device_loss)
|
1098 |
+
|
1099 |
+
# Create parallel version of the train and eval step
|
1100 |
+
p_train_step = pjit(
|
1101 |
+
train_step,
|
1102 |
+
in_axis_resources=(
|
1103 |
+
state_spec,
|
1104 |
+
grad_batch_spec
|
1105 |
+
if training_args.gradient_accumulation_steps > 1
|
1106 |
+
else batch_spec,
|
1107 |
+
None,
|
1108 |
+
),
|
1109 |
+
out_axis_resources=(state_spec, None),
|
1110 |
+
donate_argnums=(0,),
|
1111 |
+
)
|
1112 |
+
p_eval_step = pjit(
|
1113 |
+
eval_step,
|
1114 |
+
in_axis_resources=(state_spec, batch_spec),
|
1115 |
+
out_axis_resources=None,
|
1116 |
+
)
|
1117 |
+
|
1118 |
+
# define metrics logger
|
1119 |
+
class MetricsLogger:
    """Send training metrics and timings to wandb.

    Only process 0 writes to wandb, since the wandb run is initialised on
    that process only. Timing bookkeeping is kept on every process.
    """

    def __init__(self, step):
        """Start tracking from `step`, using the current time as reference."""
        # keep state
        self.state_dict = {}
        # estimate speed
        self.step = step
        self.time = time.perf_counter()
        self.offset_time = 0.0

    def update_state_metrics(self, state):
        """Update internal state metrics (logged at each call to be used as x-axis)

        Args:
            state: mapping with "step", "epoch", "train_time" and
                "train_samples" entries.
        """
        # keys become train/step, train/epoch, train/time and train/samples
        self.state_dict = {
            f'train/{k.split("_")[-1]}': state[k]
            for k in ["step", "epoch", "train_time", "train_samples"]
        }
        # timing metrics
        new_step = int(state["step"])
        new_time = time.perf_counter()
        if new_step > self.step:
            # remove time for eval & save
            delta_time = new_time - self.time - self.offset_time
            self.offset_time = 0
            time_per_step = delta_time / (new_step - self.step)
            self.step = new_step
            self.time = new_time
            self.log_time("train_per_step", time_per_step, offset=False)
            self.log_time("train_per_log", delta_time, offset=False)

    def log_time(self, key, duration, offset=True):
        """Log `duration` under `time/{key}`.

        Args:
            key: suffix of the logged metric name.
            duration: elapsed time to report.
            offset: when True, `duration` is also subtracted from the next
                training time measurement.
        """
        # wandb is only available on the main process
        if jax.process_index() == 0:
            wandb.log({f"time/{key}": duration, **self.state_dict})
        if offset:
            self.offset_time += duration

    def log(self, metrics, prefix=None):
        """Log `metrics` from the main process.

        Norm and histogram entries are only logged on steps matching their
        logging frequency; other entries are logged as `{prefix}/{key}`
        when `prefix` is given.
        """
        if jax.process_index() == 0:
            log_metrics = {}
            for k, v in metrics.items():
                if "_norm" in k:
                    if self.step % training_args.log_norm_steps == 0:
                        log_metrics[f"{k}/"] = unfreeze(v)
                elif "_hist" in k:
                    if self.step % training_args.log_histogram_steps == 0:
                        # fetch the values to host before building histograms
                        v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
                        # each histogram is a tuple, treated as a single leaf
                        v = jax.tree_map(
                            lambda x: wandb.Histogram(np_histogram=x),
                            v,
                            is_leaf=lambda x: isinstance(x, tuple),
                        )
                        log_metrics[f"{k}/"] = v
                else:
                    if prefix is not None:
                        k = f"{prefix}/{k}"
                    log_metrics[k] = v
            wandb.log({**log_metrics, **self.state_dict})
|
1173 |
+
|
1174 |
+
# keep local copy of state
|
1175 |
+
local_state = {
|
1176 |
+
k: jax.device_get(getattr(state, k)).item()
|
1177 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
1178 |
+
}
|
1179 |
+
# init variables
|
1180 |
+
start_time = time.perf_counter() - local_state["train_time"]
|
1181 |
+
train_metrics = None
|
1182 |
+
metrics_logger = MetricsLogger(local_state["step"])
|
1183 |
+
epochs = tqdm(
|
1184 |
+
range(local_state["epoch"], num_epochs),
|
1185 |
+
desc=f"Epoch ... (1/{num_epochs})",
|
1186 |
+
position=0,
|
1187 |
+
disable=jax.process_index() > 0,
|
1188 |
+
)
|
1189 |
+
|
1190 |
+
def run_evaluation():
    """Run evaluation over the eval set and log the mean loss.

    Returns:
        A dict with the evaluation "loss", or None when evaluation is disabled.
    """
    # ======================== Evaluating ==============================
    if not training_args.do_eval:
        return None

    start_eval_time = time.perf_counter()
    eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
    # number of batches, only known when the dataset length is known
    total_batches = (
        None
        if len_eval_dataset is None
        else len_eval_dataset // eval_batch_size_per_step
    )

    progress = tqdm(
        eval_loader,
        desc="Evaluating...",
        position=2,
        leave=False,
        total=total_batches,
        disable=jax.process_index() > 0,
    )
    losses = []
    for batch in progress:
        # split per process and keep only the items relevant to this node
        batch = jax.tree_map(
            lambda x: x.reshape(
                (jax.process_count(), eval_batch_size_per_node) + x.shape[1:]
            ),
            batch,
        )
        batch = jax.tree_map(lambda x: x[jax.process_index()], batch)

        if use_vmap_trick:
            # the "vmap trick" needs a leading dp dimension
            dp_shape = (
                jax.local_device_count() // training_args.mp_devices,
                training_args.per_device_eval_batch_size,
            )
            batch = jax.tree_map(lambda x: x.reshape(dp_shape + x.shape[1:]), batch)

        # batch is frozen to be passed safely to jax transforms,
        # losses are accumulated without waiting for their values
        losses.append(p_eval_step(state, freeze(batch)))

    # mean of the loss over all batches
    eval_metrics = {"loss": jnp.mean(jnp.stack(losses))}

    # log metrics
    metrics_logger.log(eval_metrics, prefix="eval")
    metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)

    # Print metrics and update progress bar
    desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
    epochs.write(desc)
    epochs.desc = desc

    return eval_metrics
|
1248 |
+
|
1249 |
+
def run_save_model(state, eval_metrics=None):
    """Checkpoint the model, tokenizer and optimizer state (process 0 only).

    Files are written to ``training_args.output_dir``. When that path starts
    with ``gs://`` the model and tokenizer are first written to a temporary
    directory and uploaded to
    ``<bucket path>/<wandb run id>/step_<state.step>/model``, and the
    optimizer state is uploaded to ``.../state/opt_state.msgpack``.
    When ``training_args.log_model`` is set, a model artifact and a state
    artifact are also logged to the current W&B run.

    Args:
        state: train state providing ``params``, ``opt_state`` and the
            ``step``, ``epoch``, ``train_time`` and ``train_samples``
            attributes that are stored as artifact metadata.
        eval_metrics: optional evaluation metrics added to the artifact
            metadata under the ``"eval"`` key.
    """
    if jax.process_index() == 0:

        start_save_time = time.perf_counter()
        output_dir = training_args.output_dir
        use_bucket = output_dir.startswith("gs://")
        tmp_dir = None
        if use_bucket:
            # strip the "gs://" scheme, then split "<bucket>/<path inside bucket>"
            bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
            bucket_name, dir_path = str(bucket_path).split("/", 1)
            tmp_dir = tempfile.TemporaryDirectory()
            output_dir = tmp_dir.name

        try:
            # save model
            params = jax.device_get(state.params)
            model.save_pretrained(
                output_dir,
                params=params,
            )

            # save tokenizer
            tokenizer.save_pretrained(output_dir)

            # copy to bucket
            if use_bucket:
                client = storage.Client()
                bucket = client.bucket(bucket_name)
                for filename in Path(output_dir).glob("*"):
                    blob_name = str(Path(dir_path) / "model" / filename.name)
                    blob = bucket.blob(blob_name)
                    blob.upload_from_filename(str(filename))
        finally:
            # always remove the temporary copy, even if saving or uploading failed
            if tmp_dir is not None:
                tmp_dir.cleanup()

        # save state
        opt_state = jax.device_get(state.opt_state)
        if use_bucket:
            # the optimizer state is serialized in memory and uploaded directly
            blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
            blob = bucket.blob(blob_name)
            blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
        else:
            with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
                f.write(to_bytes(opt_state))

        # save to W&B
        if training_args.log_model:
            # save some space
            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
            c.cleanup(wandb.util.from_human_size("20GB"))

            metadata = {
                k: jax.device_get(getattr(state, k)).item()
                for k in ["step", "epoch", "train_time", "train_samples"]
            }
            metadata["num_params"] = num_params
            if eval_metrics is not None:
                metadata["eval"] = eval_metrics

            # create model artifact
            if use_bucket:
                metadata["bucket_path"] = f"gs://{bucket_path}/model"
            artifact = wandb.Artifact(
                name=f"model-{wandb.run.id}",
                type="DalleBart_model",
                metadata=metadata,
            )
            if use_bucket:
                # reference the files already in the bucket instead of re-uploading
                artifact.add_reference(metadata["bucket_path"])
            else:
                for filename in [
                    "config.json",
                    "flax_model.msgpack",
                    "merges.txt",
                    "special_tokens_map.json",
                    "tokenizer.json",
                    "tokenizer_config.json",
                    "vocab.json",
                ]:
                    artifact.add_file(
                        f"{Path(training_args.output_dir) / filename}"
                    )
            wandb.run.log_artifact(artifact)

            # create state artifact
            if use_bucket:
                metadata["bucket_path"] = f"gs://{bucket_path}/state"
            artifact_state = wandb.Artifact(
                name=f"state-{wandb.run.id}",
                type="DalleBart_state",
                metadata=metadata,
            )
            if use_bucket:
                artifact_state.add_reference(metadata["bucket_path"])
            else:
                artifact_state.add_file(
                    f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
                )
            wandb.run.log_artifact(artifact_state)
        metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
|
1346 |
+
|
1347 |
+
# Main training loop: one pass over the train dataloader per epoch, with
# periodic logging, evaluation and checkpointing, followed by a final
# evaluation and checkpoint at the end of every epoch.
logger.info(" Ready to start training")
with mesh:
    for epoch in epochs:
        # `replace` returns a new state object, so the result must be rebound;
        # otherwise the epoch stored in the train state is never updated
        state = state.replace(epoch=epoch)
        local_state["epoch"] = epoch
        # ======================== Training ================================
        metrics_logger.update_state_metrics(local_state)
        metrics_logger.log({})

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = dataset.dataloader(
            "train",
            batch_size_per_node,
            epoch,
        )
        # train
        for batch in tqdm(
            train_loader,
            desc="Training...",
            position=1,
            leave=False,
            total=steps_per_epoch,
            disable=jax.process_index() > 0,
        ):
            # calculate delta time (we have a lag of one step but it's ok)
            train_time = time.perf_counter() - start_time

            # set correct shape to batch
            # - add grad_step dim if gradient_accumulation_steps > 1
            # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
            bs_shape = (
                (batch_size_per_node_per_grad_step,)
                if not use_vmap_trick
                else (
                    jax.local_device_count()
                    // training_args.mp_devices,  # local dp devices
                    training_args.per_device_train_batch_size,
                )
            )
            if training_args.gradient_accumulation_steps > 1:
                # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
                # to avoid any data redistribution when sharding
                bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape

            # reshape batch
            batch = jax.tree_map(
                lambda x: x.reshape(bs_shape + x.shape[1:]),
                batch,
            )
            # freeze batch to pass safely to jax transforms
            batch = freeze(batch)

            # train step
            state, train_metrics = p_train_step(state, batch, train_time)
            local_state["step"] += 1
            local_state["train_time"] = train_time
            local_state["train_samples"] += batch_size_per_step

            if (
                local_state["step"] % training_args.logging_steps == 0
                and jax.process_index() == 0
            ):
                metrics_logger.update_state_metrics(local_state)
                metrics_logger.log(train_metrics, prefix="train")

            # eval_metrics stays None unless an evaluation ran on this very step
            eval_metrics = None
            if local_state["step"] % training_args.eval_steps == 0:
                eval_metrics = run_evaluation()

            if local_state["step"] % training_args.save_steps == 0:
                run_save_model(state, eval_metrics)

        # log final train metrics
        if train_metrics is not None:
            # pass `local_state` like the other update_state_metrics calls in
            # this loop (the original passed `state` here)
            # NOTE(review): MetricsLogger is defined outside this view - confirm
            metrics_logger.update_state_metrics(local_state)
            metrics_logger.log(train_metrics, prefix="train")

            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
            )

        # Final evaluation
        eval_metrics = run_evaluation()

        # save checkpoint after each epoch
        run_save_model(state, eval_metrics)
|
1433 |
+
|
1434 |
+
|
1435 |
+
# Script entry point: run the training routine only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|