FritsLyneborg commited on
Commit
6742988
·
1 Parent(s): d2606d5
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__
2
+ .ipynb_checkpoints
3
+ .streamlit
4
+ wandb/
5
+ *.egg-info/
6
+ jax_cache/
CITATION.cff ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YAML 1.2
2
+ ---
3
+ abstract: "DALL·E mini is a JAX/Flax reimplementation of OpenAI's DALL·E that requires much smaller hardware resources. By simplifying the architecture and model memory requirements, as well as leveraging open-source code and pre-trained models, we were able to create a model that is 27 times smaller than the original DALL·E and train it on a single TPU v3-8 for only 3 days. DALL·E mini achieves impressive results, albeit of a lower quality than the original system. It can be used for exploration and further experimentation on commodity hardware."
4
+ authors:
5
+ -
6
+ family-names: Dayma
7
+ given-names: Boris
8
+ -
9
+ family-names: Patil
10
+ given-names: Suraj
11
+ -
12
+ family-names: Cuenca
13
+ given-names: Pedro
14
+ -
15
+ family-names: Saifullah
16
+ given-names: Khalid
17
+ -
18
+ family-names: Abraham
19
+ given-names: Tanishq
20
+ -
21
+ family-names: "Lê Khắc"
22
+ given-names: "Phúc"
23
+ -
24
+ family-names: Melas
25
+ given-names: Luke
26
+ -
27
+ family-names: Ghosh
28
+ given-names: Ritobrata
29
+ cff-version: "1.1.0"
30
+ date-released: 2021-07-29
31
+ identifiers:
32
+ keywords:
33
+ - dalle
34
+ - "text-to-image generation"
35
+ - transformer
36
+ - "zero-shot"
37
+ - JAX
38
+ license: "Apache-2.0"
39
+ doi: 10.5281/zenodo.5146400
40
+ message: "If you use this project, please cite it using these metadata."
41
+ repository-code: "https://github.com/borisdayma/dalle-mini"
42
+ title: "DALL·E Mini"
43
+ version: "v0.1-alpha"
44
+ ...
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2021 The DALL·E mini Authors
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .PHONY: style
2
+
3
+ style:
4
+ black .
5
+ isort .
app/gradio/app_gradio.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # Uncomment to run on cpu
5
+ # import os
6
+ # os.environ["JAX_PLATFORM_NAME"] = "cpu"
7
+
8
+ import random
9
+
10
+ import gradio as gr
11
+ import jax
12
+ import numpy as np
13
+ from flax.jax_utils import replicate
14
+ from flax.training.common_utils import shard
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ # ## CLIP Scoring
18
+ from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
19
+ from vqgan_jax.modeling_flax_vqgan import VQModel
20
+
21
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
22
+
23
+ DALLE_REPO = "flax-community/dalle-mini"
24
+ DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
25
+
26
+ VQGAN_REPO = "flax-community/vqgan_f16_16384"
27
+ VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
28
+
29
+ tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
30
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(
31
+ DALLE_REPO, revision=DALLE_COMMIT_ID
32
+ )
33
+ vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
34
+
35
+
36
+ def captioned_strip(images, caption=None, rows=1):
37
+ increased_h = 0 if caption is None else 48
38
+ w, h = images[0].size[0], images[0].size[1]
39
+ img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
40
+ for i, img_ in enumerate(images):
41
+ img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
42
+
43
+ if caption is not None:
44
+ draw = ImageDraw.Draw(img)
45
+ font = ImageFont.truetype(
46
+ "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
47
+ )
48
+ draw.text((20, 3), caption, (255, 255, 255), font=font)
49
+ return img
50
+
51
+
52
+ def custom_to_pil(x):
53
+ x = np.clip(x, 0.0, 1.0)
54
+ x = (255 * x).astype(np.uint8)
55
+ x = Image.fromarray(x)
56
+ if not x.mode == "RGB":
57
+ x = x.convert("RGB")
58
+ return x
59
+
60
+
61
+ def generate(input, rng, params):
62
+ return model.generate(
63
+ **input,
64
+ max_length=257,
65
+ num_beams=1,
66
+ do_sample=True,
67
+ prng_key=rng,
68
+ eos_token_id=50000,
69
+ pad_token_id=50000,
70
+ params=params,
71
+ )
72
+
73
+
74
+ def get_images(indices, params):
75
+ return vqgan.decode_code(indices, params=params)
76
+
77
+
78
+ p_generate = jax.pmap(generate, "batch")
79
+ p_get_images = jax.pmap(get_images, "batch")
80
+
81
+ bart_params = replicate(model.params)
82
+ vqgan_params = replicate(vqgan.params)
83
+
84
+ clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
85
+ print("Initialize FlaxCLIPModel")
86
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
87
+ print("Initialize CLIPProcessor")
88
+
89
+
90
+ def hallucinate(prompt, num_images=64):
91
+ prompt = [prompt] * jax.device_count()
92
+ inputs = tokenizer(
93
+ prompt,
94
+ return_tensors="jax",
95
+ padding="max_length",
96
+ truncation=True,
97
+ max_length=128,
98
+ ).data
99
+ inputs = shard(inputs)
100
+
101
+ all_images = []
102
+ for i in range(num_images // jax.device_count()):
103
+ key = random.randint(0, 1e7)
104
+ rng = jax.random.PRNGKey(key)
105
+ rngs = jax.random.split(rng, jax.local_device_count())
106
+ indices = p_generate(inputs, rngs, bart_params).sequences
107
+ indices = indices[:, :, 1:]
108
+
109
+ images = p_get_images(indices, vqgan_params)
110
+ images = np.squeeze(np.asarray(images), 1)
111
+ for image in images:
112
+ all_images.append(custom_to_pil(image))
113
+ return all_images
114
+
115
+
116
+ def clip_top_k(prompt, images, k=8):
117
+ inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
118
+ outputs = clip(**inputs)
119
+ logits = outputs.logits_per_text
120
+ scores = np.array(logits[0]).argsort()[-k:][::-1]
121
+ return [images[score] for score in scores]
122
+
123
+
124
+ def compose_predictions(images, caption=None):
125
+ increased_h = 0 if caption is None else 48
126
+ w, h = images[0].size[0], images[0].size[1]
127
+ img = Image.new("RGB", (len(images) * w, h + increased_h))
128
+ for i, img_ in enumerate(images):
129
+ img.paste(img_, (i * w, increased_h))
130
+
131
+ if caption is not None:
132
+ draw = ImageDraw.Draw(img)
133
+ font = ImageFont.truetype(
134
+ "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
135
+ )
136
+ draw.text((20, 3), caption, (255, 255, 255), font=font)
137
+ return img
138
+
139
+
140
+ def top_k_predictions(prompt, num_candidates=32, k=8):
141
+ images = hallucinate(prompt, num_images=num_candidates)
142
+ images = clip_top_k(prompt, images, k=k)
143
+ return images
144
+
145
+
146
+ def run_inference(prompt, num_images=32, num_preds=8):
147
+ images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
148
+ predictions = captioned_strip(images)
149
+ output_title = f"""
150
+ <b>{prompt}</b>
151
+ """
152
+ return (output_title, predictions)
153
+
154
+
155
+ outputs = [
156
+ gr.outputs.HTML(label=""), # To be used as title
157
+ gr.outputs.Image(label=""),
158
+ ]
159
+
160
+ description = """
161
+ DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
162
+ """
163
+ gr.Interface(
164
+ run_inference,
165
+ inputs=[gr.inputs.Textbox(label="What do you want to see?")],
166
+ outputs=outputs,
167
+ title="DALL·E mini",
168
+ description=description,
169
+ article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
170
+ layout="vertical",
171
+ theme="huggingface",
172
+ examples=[
173
+ ["an armchair in the shape of an avocado"],
174
+ ["snowy mountains by the sea"],
175
+ ],
176
+ allow_flagging=False,
177
+ live=False,
178
+ # server_port=8999
179
+ ).launch(share=True)
app/gradio/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Requirements for huggingface spaces
2
+ gradio>=2.2.3
3
+ flax
4
+ transformers
app/streamlit/app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import base64
5
+ from io import BytesIO
6
+
7
+ import requests
8
+ import streamlit as st
9
+ from PIL import Image
10
+
11
+
12
+ class ServiceError(Exception):
13
+ def __init__(self, status_code):
14
+ self.status_code = status_code
15
+
16
+
17
+ def get_images_from_backend(prompt, backend_url):
18
+ r = requests.post(backend_url, json={"prompt": prompt})
19
+ if r.status_code == 200:
20
+ images = r.json()["images"]
21
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
22
+ return images
23
+ else:
24
+ raise ServiceError(r.status_code)
25
+
26
+
27
+ st.sidebar.markdown(
28
+ """
29
+ <style>
30
+ .aligncenter {
31
+ text-align: center;
32
+ }
33
+ </style>
34
+ <p class="aligncenter">
35
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
36
+ </p>
37
+ """,
38
+ unsafe_allow_html=True,
39
+ )
40
+ st.sidebar.markdown(
41
+ """
42
+ ___
43
+ <p style='text-align: center'>
44
+ DALL·E mini is an AI model that generates images from any prompt you give!
45
+ </p>
46
+
47
+ <p style='text-align: center'>
48
+ Created by Boris Dayma et al. 2021
49
+ <br/>
50
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
51
+ </p>
52
+ """,
53
+ unsafe_allow_html=True,
54
+ )
55
+
56
+ st.header("DALL·E mini")
57
+ st.subheader("Generate images from text")
58
+
59
+ prompt = st.text_input("What do you want to see?")
60
+
61
+ DEBUG = False
62
+ if prompt != "":
63
+ container = st.empty()
64
+ container.markdown(
65
+ f"""
66
+ <style> p {{ margin:0 }} div {{ margin:0 }} </style>
67
+ <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
68
+ <div class="stAlert">
69
+ <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
70
+ <div class="st-b7">
71
+ <div class="css-whx05o e13vu3m50">
72
+ <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
73
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
74
+ Generating predictions for: <b>{prompt}</b>
75
+ </div>
76
+ </div>
77
+ </div>
78
+ </div>
79
+ </div>
80
+ </div>
81
+ <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
82
+ """,
83
+ unsafe_allow_html=True,
84
+ )
85
+
86
+ try:
87
+ backend_url = st.secrets["BACKEND_SERVER"]
88
+ print(f"Getting selections: {prompt}")
89
+ selected = get_images_from_backend(prompt, backend_url)
90
+
91
+ margin = 0.1 # for better position of zoom in arrow
92
+ n_columns = 3
93
+ cols = st.columns([1] + [margin, 1] * (n_columns - 1))
94
+ for i, img in enumerate(selected):
95
+ cols[(i % n_columns) * 2].image(img)
96
+ container.markdown(f"**{prompt}**")
97
+
98
+ st.button("Again!", key="again_button")
99
+
100
+ except ServiceError as error:
101
+ container.text(f"Service unavailable, status: {error.status_code}")
102
+ except KeyError:
103
+ if DEBUG:
104
+ container.markdown(
105
+ """
106
+ **Error: BACKEND_SERVER unset**
107
+
108
+ Please, create a file called `.streamlit/secrets.toml` inside the app's folder and include a line to configure the server URL:
109
+ ```
110
+ BACKEND_SERVER="<server url>"
111
+ ```
112
+ """
113
+ )
114
+ else:
115
+ container.markdown(
116
+ "Error -5, please try again or [report it](mailto:pcuenca-dalle@guenever.net)."
117
+ )
app/streamlit/img/loading.gif ADDED
img/logo.png ADDED
pyproject.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [tool.isort]
2
+ profile = "black"
setup.cfg ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [metadata]
2
+ name = dalle-mini
3
+ version = attr: dalle_mini.__version__
4
+ author = Boris Dayma et al.
5
+ author_email = boris.dayma@gmail.com
6
+ description = DALL·E mini - Generate images from a text prompt
7
+ long_description = file: README.md
8
+ long_description_content_type = text/markdown
9
+ url = https://github.com/borisdayma/dalle-mini
10
+ project_urls =
11
+ Bug Tracker = https://github.com/borisdayma/dalle-mini/issues
12
+ classifiers =
13
+ Programming Language :: Python :: 3
14
+ License :: OSI Approved :: Apache Software License
15
+ Operating System :: OS Independent
16
+ Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Development Status :: 3 - Alpha
18
+ Intended Audience :: Developers
19
+
20
+ [options]
21
+ package_dir =
22
+ =src
23
+ packages = find:
24
+ python_requires = >=3.6
25
+ install_requires =
26
+ transformers
27
+ einops
28
+ unidecode
29
+ ftfy
30
+ emoji
31
+ pillow
32
+ jax
33
+ flax
34
+ wandb
35
+
36
+ [options.extras_require]
37
+ dev =
38
+ tqdm
39
+ optax
40
+ braceexpand
41
+ datasets[streaming]
42
+ black[jupyter]
43
+ isort
44
+
45
+ [options.packages.find]
46
+ where = src
setup.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ if __name__ == "__main__":
4
+ setup()
src/dalle_mini/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __version__ = "0.0.4"
2
+
3
+ from .model import DalleBart, DalleBartProcessor
src/dalle_mini/data.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ from braceexpand import braceexpand
9
+ from datasets import Dataset, load_dataset
10
+
11
+ from .model.text import TextNormalizer
12
+
13
+
14
+ @dataclass
15
+ class Dataset:
16
+ dataset_repo_or_path: str
17
+ train_file: str = None
18
+ validation_file: str = None
19
+ streaming: bool = True
20
+ use_auth_token: bool = False
21
+ text_column: str = "caption"
22
+ encoding_column: str = "encoding"
23
+ max_train_samples: int = None
24
+ max_eval_samples: int = None
25
+ preprocessing_num_workers: int = None
26
+ overwrite_cache: bool = False
27
+ do_train: bool = False
28
+ do_eval: bool = True
29
+ seed_dataset: int = None
30
+ shard_by_host: bool = False
31
+ blank_caption_prob: float = 0.0
32
+ clip_score_column: str = "clip_score"
33
+ min_clip_score: float = None
34
+ max_clip_score: float = None
35
+ filter_column: str = None
36
+ filter_value: str = None
37
+ train_dataset: Dataset = field(init=False)
38
+ eval_dataset: Dataset = field(init=False)
39
+ rng_dataset: jnp.ndarray = field(init=False)
40
+ multi_hosts: bool = field(init=False)
41
+
42
+ def __post_init__(self):
43
+ if self.seed_dataset is None:
44
+ # create a random seed
45
+ self.seed_dataset = random.randint(0, 2**32 - 1)
46
+ self.multi_hosts = jax.process_count() > 1
47
+ # feed blank captions only in streaming mode for now
48
+ # otherwise dataset could be cached with same blanked captions
49
+ if self.blank_caption_prob:
50
+ assert (
51
+ self.streaming is True
52
+ ), "blank_caption_prob can only be used in streaming mode"
53
+ # define data_files
54
+ if self.train_file is not None or self.validation_file is not None:
55
+ # accept braceexpand notation
56
+ for k in ["train_file", "validation_file"]:
57
+ f = getattr(self, k)
58
+ if isinstance(f, str):
59
+ setattr(self, k, list(braceexpand(f)))
60
+ # for list of files, split training data shards by host
61
+ if (
62
+ isinstance(self.train_file, list)
63
+ and self.multi_hosts
64
+ and self.shard_by_host
65
+ ):
66
+ self.train_file = self.train_file[
67
+ jax.process_index() :: jax.process_count()
68
+ ]
69
+ data_files = {
70
+ "train": self.train_file,
71
+ "validation": self.validation_file,
72
+ }
73
+ else:
74
+ data_files = None
75
+
76
+ # load dataset
77
+ dataset = load_dataset(
78
+ self.dataset_repo_or_path,
79
+ data_files=data_files,
80
+ streaming=self.streaming,
81
+ use_auth_token=self.use_auth_token,
82
+ )
83
+ if self.do_train:
84
+ if "train" not in dataset:
85
+ raise ValueError("Training requires a training dataset")
86
+ self.train_dataset = dataset["train"]
87
+ if self.max_train_samples is not None:
88
+ self.train_dataset = (
89
+ self.train_dataset.take(self.max_train_samples)
90
+ if self.streaming
91
+ else self.train_dataset.select(range(self.max_train_samples))
92
+ )
93
+ if self.do_eval:
94
+ if "validation" not in dataset:
95
+ raise ValueError("Evaluating requires a validation dataset")
96
+ self.eval_dataset = dataset["validation"]
97
+ if self.max_eval_samples is not None:
98
+ self.eval_dataset = (
99
+ self.eval_dataset.take(self.max_eval_samples)
100
+ if self.streaming
101
+ else self.eval_dataset.select(range(self.max_eval_samples))
102
+ )
103
+
104
+ def preprocess(self, tokenizer, config):
105
+ # get required config variables
106
+ decoder_start_token_id = config.decoder_start_token_id
107
+ normalize_text = config.normalize_text
108
+ max_length = config.max_text_length
109
+
110
+ if self.streaming:
111
+ # we need to shuffle early in streaming mode
112
+ if hasattr(self, "train_dataset"):
113
+ self.train_dataset = self.train_dataset.shuffle(
114
+ buffer_size=5000, seed=self.seed_dataset
115
+ )
116
+ else:
117
+ self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
118
+
119
+ # filter data
120
+ partial_filter_function = partial(
121
+ filter_function,
122
+ filter_column=self.filter_column,
123
+ filter_value=self.filter_value,
124
+ clip_score_column=self.clip_score_column,
125
+ min_clip_score=self.min_clip_score,
126
+ max_clip_score=self.max_clip_score,
127
+ )
128
+ for ds in ["train_dataset", "eval_dataset"]:
129
+ if hasattr(self, ds):
130
+ setattr(
131
+ self,
132
+ ds,
133
+ (
134
+ getattr(self, ds).filter(partial_filter_function)
135
+ if self.streaming
136
+ else getattr(self, ds).filter(
137
+ partial_filter_function,
138
+ num_proc=self.preprocessing_num_workers,
139
+ load_from_cache_file=not self.overwrite_cache,
140
+ desc="Filtering datasets",
141
+ )
142
+ ),
143
+ )
144
+
145
+ # normalize text
146
+ if normalize_text:
147
+ text_normalizer = TextNormalizer()
148
+ partial_normalize_function = partial(
149
+ normalize_function,
150
+ text_column=self.text_column,
151
+ text_normalizer=text_normalizer,
152
+ )
153
+ for ds in ["train_dataset", "eval_dataset"]:
154
+ if hasattr(self, ds):
155
+ setattr(
156
+ self,
157
+ ds,
158
+ (
159
+ getattr(self, ds).map(partial_normalize_function)
160
+ if self.streaming
161
+ else getattr(self, ds).map(
162
+ partial_normalize_function,
163
+ num_proc=self.preprocessing_num_workers,
164
+ load_from_cache_file=not self.overwrite_cache,
165
+ desc="Normalizing datasets",
166
+ )
167
+ ),
168
+ )
169
+
170
+ # blank captions
171
+ if self.blank_caption_prob:
172
+ partial_blank_caption_function = partial(
173
+ blank_caption_function,
174
+ text_column=self.text_column,
175
+ blank_caption_prob=self.blank_caption_prob,
176
+ )
177
+ if hasattr(self, "train_dataset"):
178
+ self.train_dataset = (
179
+ self.train_dataset.map(partial_blank_caption_function)
180
+ if self.streaming
181
+ else self.train_dataset.map(
182
+ partial_blank_caption_function,
183
+ num_proc=self.preprocessing_num_workers,
184
+ load_from_cache_file=False,
185
+ desc="Blanking some captions",
186
+ )
187
+ )
188
+
189
+ # preprocess
190
+ partial_preprocess_function = partial(
191
+ preprocess_function,
192
+ tokenizer=tokenizer,
193
+ text_column=self.text_column,
194
+ encoding_column=self.encoding_column,
195
+ max_length=max_length,
196
+ decoder_start_token_id=decoder_start_token_id,
197
+ )
198
+ for ds in ["train_dataset", "eval_dataset"]:
199
+ if hasattr(self, ds):
200
+ setattr(
201
+ self,
202
+ ds,
203
+ (
204
+ getattr(self, ds).map(
205
+ partial_preprocess_function,
206
+ batched=True,
207
+ remove_columns=[
208
+ self.text_column,
209
+ self.encoding_column,
210
+ ],
211
+ )
212
+ if self.streaming
213
+ else getattr(self, ds).map(
214
+ partial_preprocess_function,
215
+ batched=True,
216
+ remove_columns=getattr(ds, "column_names"),
217
+ num_proc=self.preprocessing_num_workers,
218
+ load_from_cache_file=not self.overwrite_cache,
219
+ desc="Preprocessing datasets",
220
+ )
221
+ ),
222
+ )
223
+
224
+ def dataloader(self, split, batch_size, epoch=None):
225
+ def _dataloader_datasets_non_streaming(
226
+ dataset: Dataset,
227
+ rng: jax.random.PRNGKey = None,
228
+ ):
229
+ """
230
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
231
+ Shuffle batches if rng is set.
232
+ """
233
+ steps_per_epoch = len(dataset) // batch_size
234
+
235
+ if rng is not None:
236
+ batch_idx = jax.random.permutation(rng, len(dataset))
237
+ else:
238
+ batch_idx = jnp.arange(len(dataset))
239
+
240
+ batch_idx = batch_idx[
241
+ : steps_per_epoch * batch_size
242
+ ] # Skip incomplete batch.
243
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
244
+
245
+ for idx in batch_idx:
246
+ batch = dataset[idx]
247
+ batch = {k: jnp.array(v) for k, v in batch.items()}
248
+ yield batch
249
+
250
+ def _dataloader_datasets_streaming(
251
+ dataset: Dataset,
252
+ epoch: int,
253
+ ):
254
+ keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
255
+ batch = {k: [] for k in keys}
256
+ first_loop = True # stop after one loop in some cases
257
+ while (self.multi_hosts and split == "train") or first_loop:
258
+ # in multi-host, we run forever (no epoch) as hosts need to stop
259
+ # at the same time and training data may not be split equally
260
+ # For validation data we put the entire batch on each host and then
261
+ # keep only the one specific to each host (could be improved but not necessary)
262
+ if epoch is not None:
263
+ assert split == "train"
264
+ # reshuffle training data at each epoch
265
+ dataset.set_epoch(epoch)
266
+ epoch += 1
267
+ for item in dataset:
268
+ for k in keys:
269
+ batch[k].append(item[k])
270
+ if len(batch[keys[0]]) == batch_size:
271
+ batch = {k: jnp.array(v) for k, v in batch.items()}
272
+ yield batch
273
+ batch = {k: [] for k in keys}
274
+ first_loop = False
275
+
276
+ if split == "train":
277
+ ds = self.train_dataset
278
+ elif split == "eval":
279
+ ds = self.eval_dataset
280
+ else:
281
+ raise ValueError(f'split must be "train" or "eval", got {split}')
282
+
283
+ if self.streaming:
284
+ return _dataloader_datasets_streaming(ds, epoch)
285
+ else:
286
+ if split == "train":
287
+ self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
288
+ return _dataloader_datasets_non_streaming(ds, input_rng)
289
+
290
+ @property
291
+ def length(self):
292
+ len_train_dataset, len_eval_dataset = None, None
293
+ if self.streaming:
294
+ # we don't know the length, let's just assume max_samples if defined
295
+ if self.max_train_samples is not None:
296
+ len_train_dataset = self.max_train_samples
297
+ if self.max_eval_samples is not None:
298
+ len_eval_dataset = self.max_eval_samples
299
+ else:
300
+ len_train_dataset = (
301
+ len(self.train_dataset) if hasattr(self, "train_dataset") else None
302
+ )
303
+ len_eval_dataset = (
304
+ len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
305
+ )
306
+ return len_train_dataset, len_eval_dataset
307
+
308
+
309
+ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
310
+ """
311
+ Shift input ids one token to the right.
312
+ """
313
+ shifted_input_ids = np.zeros(input_ids.shape)
314
+ shifted_input_ids[:, 1:] = input_ids[:, :-1]
315
+ shifted_input_ids[:, 0] = decoder_start_token_id
316
+ return shifted_input_ids
317
+
318
+
319
+ def blank_caption_function(example, text_column, blank_caption_prob):
320
+ if blank_caption_prob and np.random.rand() < blank_caption_prob:
321
+ example[text_column] = ""
322
+ return example
323
+
324
+
325
+ def normalize_function(example, text_column, text_normalizer):
326
+ example[text_column] = text_normalizer(example[text_column])
327
+ return example
328
+
329
+
330
+ def filter_function(
331
+ example,
332
+ min_clip_score,
333
+ max_clip_score,
334
+ clip_score_column,
335
+ filter_column,
336
+ filter_value,
337
+ ):
338
+ if min_clip_score is not None and example[clip_score_column] < min_clip_score:
339
+ return False
340
+ if max_clip_score is not None and example[clip_score_column] > max_clip_score:
341
+ return False
342
+ if filter_column is not None and example[filter_column] != filter_value:
343
+ return False
344
+ return True
345
+
346
+
347
+ def preprocess_function(
348
+ examples,
349
+ tokenizer,
350
+ text_column,
351
+ encoding_column,
352
+ max_length,
353
+ decoder_start_token_id,
354
+ ):
355
+ inputs = examples[text_column]
356
+ # Setting padding="max_length" as we need fixed length inputs for jitted functions
357
+ model_inputs = tokenizer(
358
+ inputs,
359
+ max_length=max_length,
360
+ padding="max_length",
361
+ truncation=True,
362
+ return_tensors="np",
363
+ )
364
+
365
+ # set up targets
366
+ # Note: labels correspond to our target indices
367
+ # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
368
+ labels = examples[encoding_column]
369
+ labels = np.asarray(labels)
370
+
371
+ # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
372
+ model_inputs["labels"] = labels
373
+
374
+ # In our case, this prepends the bos token and removes the last one
375
+ decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
376
+ model_inputs["decoder_input_ids"] = decoder_input_ids
377
+
378
+ return model_inputs
src/dalle_mini/model/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .configuration import DalleBartConfig
2
+ from .modeling import DalleBart
3
+ from .partitions import set_partitions
4
+ from .processor import DalleBartProcessor
5
+ from .tokenizer import DalleBartTokenizer
src/dalle_mini/model/configuration.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model configuration """
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ from .utils import PretrainedFromWandbMixin
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
27
+ model_type = "dallebart"
28
+ keys_to_ignore_at_inference = ["past_key_values"]
29
+ attribute_map = {
30
+ "num_attention_heads": "encoder_attention_heads",
31
+ "hidden_size": "d_model",
32
+ }
33
+
34
+ def __init__(
35
+ self,
36
+ normalize_text=False,
37
+ encoder_vocab_size=50264,
38
+ image_vocab_size=16384, # encoded image token space
39
+ image_length=256, # number of encoded tokens
40
+ max_text_length=64, # max number of text tokens
41
+ encoder_layers=12,
42
+ encoder_ffn_dim=4096,
43
+ encoder_attention_heads=16,
44
+ decoder_layers=12,
45
+ decoder_ffn_dim=4096,
46
+ decoder_attention_heads=16,
47
+ activation_function="gelu",
48
+ d_model=1024,
49
+ dropout=0.1,
50
+ attention_dropout=0.0,
51
+ activation_dropout=0.0,
52
+ init_std=0.02,
53
+ scale_embedding=False,
54
+ gradient_checkpointing=False,
55
+ use_cache=True,
56
+ is_encoder_decoder=True,
57
+ forced_eos_token_id=None,
58
+ tie_word_embeddings=False, # different modalities and sizes
59
+ do_sample=True,
60
+ # transformer variants
61
+ use_bias=False, # use bias in attention and dense layers (except for lm_head)
62
+ ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
63
+ ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
64
+ use_head_scale=False, # used in NormFormer
65
+ use_cosine_attention=False, # used in Swin v2
66
+ tau_init=0.05, # used only in cosine attention (Swin v2)
67
+ use_absolute_position_embeddings=True, # default
68
+ use_swin_position_embeddings=False, # used in Swin v1/v2
69
+ use_deepnet_scaling=False, # used in Deepnet
70
+ use_glu=False, # "GLU Variants Improve Transformer"
71
+ use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
72
+ sinkhorn_iters=1, # used in SinkFormers
73
+ use_final_ln_encoder=True, # final layer normalization in encoder
74
+ use_final_ln_decoder=True, # final layer normalization in decoder
75
+ # parameters that should not be necessary but could affect results
76
+ force_ln_scale=False, # force scale in layernorm even when followed by dense layers
77
+ **kwargs,
78
+ ):
79
+ # text normalizer
80
+ self.normalize_text = normalize_text
81
+
82
+ # transformer variants
83
+ self.use_bias = use_bias
84
+ assert ln_type in [
85
+ "rmsnorm",
86
+ "layernorm",
87
+ ], "ln_type must be 'rmsnorm' or 'layernorm'"
88
+ self.ln_type = ln_type
89
+ if ln_positions == "deepnet":
90
+ ln_positions = "postln"
91
+ assert ln_positions in [
92
+ "normformer",
93
+ "swinv2",
94
+ "cogview",
95
+ "postln",
96
+ "preln",
97
+ ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
98
+ self.use_head_scale = use_head_scale
99
+ assert use_alibi is False, "use_alibi is not supported yet"
100
+ self.ln_positions = ln_positions
101
+ self.use_cosine_attention = use_cosine_attention
102
+ self.tau_init = tau_init
103
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
104
+ self.use_swin_position_embeddings = use_swin_position_embeddings
105
+ self.use_deepnet_scaling = use_deepnet_scaling
106
+ self.use_glu = use_glu
107
+ self.use_alibi = use_alibi
108
+ self.sinkhorn_iters = sinkhorn_iters
109
+ if ln_positions == "postln":
110
+ assert (
111
+ use_final_ln_encoder
112
+ ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
113
+ assert (
114
+ use_final_ln_decoder
115
+ ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
116
+ self.use_final_ln_encoder = use_final_ln_encoder
117
+ self.use_final_ln_decoder = use_final_ln_decoder
118
+ self.force_ln_scale = force_ln_scale
119
+
120
+ # common parameters
121
+ self.encoder_vocab_size = encoder_vocab_size
122
+ self.image_vocab_size = image_vocab_size
123
+ self.image_length = image_length
124
+ self.max_text_length = max_text_length
125
+ self.d_model = d_model
126
+ self.encoder_ffn_dim = encoder_ffn_dim
127
+ self.encoder_layers = encoder_layers
128
+ self.encoder_attention_heads = encoder_attention_heads
129
+ self.decoder_ffn_dim = decoder_ffn_dim
130
+ self.decoder_layers = decoder_layers
131
+ self.decoder_attention_heads = decoder_attention_heads
132
+ self.dropout = dropout
133
+ self.attention_dropout = attention_dropout
134
+ self.activation_dropout = activation_dropout
135
+ self.activation_function = activation_function
136
+ self.init_std = init_std
137
+ self.use_cache = use_cache
138
+ self.gradient_checkpointing = gradient_checkpointing
139
+ self.scale_embedding = (
140
+ scale_embedding # scale factor will be sqrt(d_model) if True
141
+ )
142
+
143
+ # special token id's are appended to vocab if not provided
144
+ decoder_start_token_id = kwargs.pop("decoder_start_token_id", image_vocab_size)
145
+ bos_token_id = kwargs.pop("bos_token_id", image_vocab_size)
146
+ pad_token_id = kwargs.pop("pad_token_id", image_vocab_size)
147
+ eos_token_id = kwargs.pop("eos_token_id", image_vocab_size)
148
+
149
+ # we generate to image_length + 1 (for bos) by default
150
+ min_length = kwargs.pop("min_length", image_length + 1)
151
+ max_length = kwargs.pop("max_length", image_length + 1)
152
+
153
+ super().__init__(
154
+ # args required in parent class
155
+ is_encoder_decoder=is_encoder_decoder,
156
+ tie_word_embeddings=tie_word_embeddings,
157
+ forced_eos_token_id=forced_eos_token_id,
158
+ decoder_start_token_id=decoder_start_token_id,
159
+ bos_token_id=bos_token_id,
160
+ pad_token_id=pad_token_id,
161
+ eos_token_id=eos_token_id,
162
+ min_length=min_length,
163
+ max_length=max_length,
164
+ do_sample=do_sample,
165
+ **kwargs,
166
+ )
167
+
168
+ # ensure backward compatibility for BART CNN models
169
+ if self.forced_bos_token_id is None and kwargs.get(
170
+ "force_bos_token_to_be_generated", False
171
+ ):
172
+ self.forced_bos_token_id = self.bos_token_id
173
+ warnings.warn(
174
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
175
+ "The config can simply be saved and uploaded again to be fixed."
176
+ )
src/dalle_mini/model/modeling.py ADDED
@@ -0,0 +1,2093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model. """
16
+
17
+ import math
18
+ import os
19
+ from functools import partial
20
+ from pickle import UnpicklingError
21
+ from typing import Any, Dict, Optional, Tuple, Union
22
+
23
+ import flax
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ import msgpack.exceptions
28
+ from einops import rearrange
29
+ from flax.core.frozen_dict import unfreeze
30
+ from flax.linen import combine_masks, make_causal_mask
31
+ from flax.linen import partitioning as nn_partitioning
32
+ from flax.linen.linear import PrecisionLike
33
+ from flax.serialization import from_bytes
34
+ from flax.traverse_util import flatten_dict, unflatten_dict
35
+ from jax import custom_jvp, lax
36
+ from jax.random import PRNGKey
37
+ from transformers.configuration_utils import PretrainedConfig
38
+ from transformers.file_utils import (
39
+ FLAX_WEIGHTS_NAME,
40
+ WEIGHTS_NAME,
41
+ cached_path,
42
+ hf_bucket_url,
43
+ is_offline_mode,
44
+ is_remote_url,
45
+ )
46
+ from transformers.generation_flax_utils import FlaxSampleOutput
47
+ from transformers.modeling_flax_outputs import (
48
+ FlaxBaseModelOutput,
49
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
50
+ FlaxCausalLMOutputWithCrossAttentions,
51
+ FlaxSeq2SeqLMOutput,
52
+ )
53
+ from transformers.modeling_flax_utils import ACT2FN
54
+ from transformers.models.bart.modeling_flax_bart import (
55
+ FlaxBartAttention,
56
+ FlaxBartForConditionalGeneration,
57
+ FlaxBartForConditionalGenerationModule,
58
+ FlaxBartModule,
59
+ FlaxBartPreTrainedModel,
60
+ )
61
+ from transformers.utils import logging
62
+
63
+ from .configuration import DalleBartConfig
64
+ from .utils import PretrainedFromWandbMixin
65
+
66
+ logger = logging.get_logger(__name__)
67
+
68
+ remat = nn_partitioning.remat
69
+
70
+
71
+ def smelu(beta: Any = 1.0):
72
+ """
73
+ Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
74
+ https://arxiv.org/abs/2202.06499
75
+ """
76
+
77
+ @custom_jvp
78
+ @jax.jit
79
+ def _smelu(x: Any) -> Any:
80
+ x = jnp.where(x <= -beta, 0.0, x)
81
+ return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))
82
+
83
+ _smelu.defjvps(
84
+ lambda g, ans, x: lax.select(
85
+ x == -beta,
86
+ lax.full_like(g, 0),
87
+ lax.select(x == beta, lax.full_like(g, 1), g),
88
+ )
89
+ )
90
+ return _smelu
91
+
92
+
93
+ ACT2FN.update({"smelu": smelu})
94
+
95
+ # deepnet initialization
96
+ def deepnet_init(gain=1):
97
+ init = jax.nn.initializers.glorot_normal()
98
+
99
+ def _init(*args, **kwargs):
100
+ return gain * init(*args, **kwargs)
101
+
102
+ return _init
103
+
104
+
105
+ # deepnet gain
106
+ deepnet_gain = {
107
+ "encoder": {
108
+ "alpha": lambda config: 0.81
109
+ * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
110
+ "beta": lambda config: 0.87
111
+ * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
112
+ },
113
+ "decoder": {
114
+ "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
115
+ "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
116
+ },
117
+ }
118
+
119
+
120
+ class RMSNorm(nn.Module):
121
+ """
122
+ From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
123
+
124
+ Adapted from flax.linen.LayerNorm
125
+ """
126
+
127
+ epsilon: float = 1e-6
128
+ dtype: Any = jnp.float32
129
+ param_dtype: Any = jnp.float32
130
+ use_scale: bool = True
131
+ scale_init: Any = jax.nn.initializers.ones
132
+
133
+ @nn.compact
134
+ def __call__(self, x):
135
+ reduction_axes = (-1,)
136
+ feature_axes = (-1,)
137
+
138
+ rms_sq = self._compute_rms_sq(x, reduction_axes)
139
+
140
+ return self._normalize(
141
+ self,
142
+ x,
143
+ rms_sq,
144
+ reduction_axes,
145
+ feature_axes,
146
+ self.dtype,
147
+ self.param_dtype,
148
+ self.epsilon,
149
+ self.use_scale,
150
+ self.scale_init,
151
+ )
152
+
153
+ def _compute_rms_sq(self, x, axes):
154
+ x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
155
+ rms_sq = jnp.mean(jax.lax.square(x), axes)
156
+ return rms_sq
157
+
158
+ def _normalize(
159
+ self,
160
+ mdl,
161
+ x,
162
+ rms_sq,
163
+ reduction_axes,
164
+ feature_axes,
165
+ dtype,
166
+ param_dtype,
167
+ epsilon,
168
+ use_scale,
169
+ scale_init,
170
+ ):
171
+ reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
172
+ feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
173
+ stats_shape = list(x.shape)
174
+ for axis in reduction_axes:
175
+ stats_shape[axis] = 1
176
+ rms_sq = rms_sq.reshape(stats_shape)
177
+ feature_shape = [1] * x.ndim
178
+ reduced_feature_shape = []
179
+ for ax in feature_axes:
180
+ feature_shape[ax] = x.shape[ax]
181
+ reduced_feature_shape.append(x.shape[ax])
182
+ mul = lax.rsqrt(rms_sq + epsilon)
183
+ if use_scale:
184
+ scale = mdl.param(
185
+ "scale", scale_init, reduced_feature_shape, param_dtype
186
+ ).reshape(feature_shape)
187
+ mul *= scale
188
+ y = mul * x
189
+ return jnp.asarray(y, dtype)
190
+
191
+
192
+ def norm(type, *args, **kwargs):
193
+ if type == "rmsnorm":
194
+ return RMSNorm(*args, **kwargs)
195
+ elif type == "layernorm":
196
+ return nn.LayerNorm(*args, **kwargs)
197
+ else:
198
+ raise ValueError(f"Unknown norm type {type}")
199
+
200
+
201
+ def dot_product_attention_weights(
202
+ query: Any,
203
+ key: Any,
204
+ bias: Optional[Any] = None,
205
+ mask: Optional[Any] = None,
206
+ embed_pos: Optional[Any] = None,
207
+ broadcast_dropout: bool = True,
208
+ dropout_rng: Optional[PRNGKey] = None,
209
+ dropout_rate: float = 0.0,
210
+ deterministic: bool = False,
211
+ dtype: Any = jnp.float32,
212
+ precision: PrecisionLike = None,
213
+ sinkhorn_iters: int = 1,
214
+ is_encoder: bool = False,
215
+ ):
216
+ """
217
+ Computes dot-product attention weights given query and key.
218
+ mask is included into the bias.
219
+
220
+ Adapted from flax.linen.attention.dot_product_attention_weights"
221
+ """
222
+ assert query.ndim == key.ndim, "q, k must have same rank."
223
+ assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
224
+ assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
225
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
226
+
227
+ # calculate attention matrix
228
+ depth = query.shape[-1]
229
+ query = query / jnp.sqrt(depth).astype(dtype)
230
+ # attn weight shape is (batch..., num_heads, q_length, kv_length)
231
+ attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
232
+
233
+ # apply attention bias: masking, dropout, proximity bias, etc.
234
+ if bias is not None:
235
+ attn_weights = attn_weights + bias
236
+
237
+ # add relative position
238
+ if embed_pos is not None:
239
+ attn_weights = attn_weights + embed_pos
240
+
241
+ # normalize the attention weights
242
+ if not is_encoder or sinkhorn_iters == 1:
243
+ # sinkhorn does not work for causal (leaks info of future tokens into past)
244
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
245
+ else:
246
+ # adapted from https://github.com/lucidrains/sinkhorn-transformer
247
+ for i in range(sinkhorn_iters):
248
+ # when causal, some attn_weights have been set to -inf through bias
249
+ if i % 2 == 0:
250
+ attn_weights -= jax.nn.logsumexp(attn_weights, axis=-1, keepdims=True)
251
+ else:
252
+ attn_weights -= jax.nn.logsumexp(attn_weights, axis=-2, keepdims=True)
253
+ if mask is not None:
254
+ attn_weights = jnp.where(mask, attn_weights, -jnp.inf)
255
+ attn_weights = jnp.exp(attn_weights).astype(dtype)
256
+
257
+ # apply attention dropout
258
+ if not deterministic and dropout_rate > 0.0:
259
+ keep_prob = 1.0 - dropout_rate
260
+ if broadcast_dropout:
261
+ # dropout is broadcast across the batch + head dimensions
262
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
263
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
264
+ else:
265
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
266
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
267
+ keep_prob, dtype=dtype
268
+ )
269
+ attn_weights = attn_weights * multiplier
270
+
271
+ return attn_weights
272
+
273
+
274
+ class FlaxBartAttention(FlaxBartAttention):
275
+ """
276
+ Edits:
277
+ - causal mask is used only in decoder and considers image_length
278
+ - scale attention heads per NormFormer paper
279
+ """
280
+
281
+ is_encoder: bool = False
282
+ q_length: int = None
283
+ k_length: int = None
284
+
285
+ def setup(self) -> None:
286
+ self.head_dim = self.embed_dim // self.num_heads
287
+ if self.head_dim * self.num_heads != self.embed_dim:
288
+ raise ValueError(
289
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
290
+ f" and `num_heads`: {self.num_heads})."
291
+ )
292
+
293
+ dense = partial(
294
+ nn.Dense,
295
+ self.embed_dim,
296
+ use_bias=self.bias,
297
+ dtype=self.dtype,
298
+ )
299
+
300
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
301
+ self.config
302
+ )
303
+
304
+ self.q_proj = dense(
305
+ kernel_init=deepnet_init()
306
+ if self.config.use_deepnet_scaling
307
+ else jax.nn.initializers.normal(self.config.init_std)
308
+ )
309
+ self.k_proj = dense(
310
+ kernel_init=deepnet_init()
311
+ if self.config.use_deepnet_scaling
312
+ else jax.nn.initializers.normal(self.config.init_std)
313
+ )
314
+ self.v_proj = dense(
315
+ kernel_init=deepnet_init(gain)
316
+ if self.config.use_deepnet_scaling
317
+ else jax.nn.initializers.normal(self.config.init_std)
318
+ )
319
+ self.out_proj = dense(
320
+ kernel_init=deepnet_init(gain)
321
+ if self.config.use_deepnet_scaling
322
+ else jax.nn.initializers.normal(self.config.init_std)
323
+ )
324
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
325
+
326
+ if self.config.use_head_scale:
327
+ self.head_scale = self.param(
328
+ "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
329
+ )
330
+
331
+ if self.config.use_cosine_attention:
332
+ self.tau = self.param(
333
+ "tau",
334
+ jax.nn.initializers.constant(self.config.tau_init),
335
+ (1, self.num_heads, 1, 1),
336
+ )
337
+
338
+ if self.config.use_swin_position_embeddings:
339
+ self.rel_bias = nn.Embed(
340
+ self.q_length,
341
+ self.k_length * self.num_heads,
342
+ embedding_init=deepnet_init()
343
+ if self.config.use_deepnet_scaling
344
+ else jax.nn.initializers.normal(self.config.init_std),
345
+ )
346
+
347
+ if self.causal:
348
+ # used only in decoder
349
+ self.causal_mask = make_causal_mask(
350
+ jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
351
+ )
352
+
353
+ def __call__(
354
+ self,
355
+ hidden_states: jnp.ndarray,
356
+ key_value_states: Optional[jnp.ndarray] = None,
357
+ attention_mask: Optional[jnp.ndarray] = None,
358
+ init_cache: bool = False,
359
+ deterministic: bool = True,
360
+ ) -> Tuple[jnp.ndarray]:
361
+ """Input shape: Batch x Time x Channel"""
362
+
363
+ # if key_value_states are provided this layer is used as a cross-attention layer
364
+ # for the decoder
365
+ is_cross_attention = key_value_states is not None
366
+ batch_size = hidden_states.shape[0]
367
+
368
+ # get query proj
369
+ query_states = self.q_proj(hidden_states)
370
+ # get key, value proj
371
+ if is_cross_attention:
372
+ # cross_attentions
373
+ key_states = self.k_proj(key_value_states)
374
+ value_states = self.v_proj(key_value_states)
375
+ else:
376
+ # self_attention
377
+ key_states = self.k_proj(hidden_states)
378
+ value_states = self.v_proj(hidden_states)
379
+
380
+ query_states = self._split_heads(query_states)
381
+ key_states = self._split_heads(key_states)
382
+ value_states = self._split_heads(value_states)
383
+
384
+ # handle cache prepare causal attention mask
385
+ if self.causal:
386
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
387
+ if self.has_variable("cache", "cached_key"):
388
+ mask_shift = self.variables["cache"]["cache_index"]
389
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
390
+ causal_mask = lax.dynamic_slice(
391
+ self.causal_mask,
392
+ (0, 0, mask_shift, 0),
393
+ (1, 1, query_length, max_decoder_length),
394
+ )
395
+ else:
396
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
397
+ causal_mask = jnp.broadcast_to(
398
+ causal_mask, (batch_size,) + causal_mask.shape[1:]
399
+ )
400
+
401
+ # combine masks if needed
402
+ if attention_mask is not None and self.causal:
403
+ attention_mask = jnp.broadcast_to(
404
+ jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
405
+ )
406
+ attention_mask = combine_masks(attention_mask, causal_mask)
407
+ elif self.causal:
408
+ attention_mask = causal_mask
409
+ elif attention_mask is not None:
410
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
411
+
412
+ # During fast autoregressive decoding, we feed one position at a time,
413
+ # and cache the keys and values step by step.
414
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
415
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
416
+ key_states, value_states, query_states, attention_mask
417
+ )
418
+
419
+ # Convert the boolean attention mask to an attention bias.
420
+ if attention_mask is not None:
421
+ # attention mask in the form of attention bias
422
+ attention_bias = lax.select(
423
+ attention_mask > 0,
424
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
425
+ jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
426
+ )
427
+ else:
428
+ attention_bias = None
429
+
430
+ dropout_rng = None
431
+ if not deterministic and self.dropout > 0.0:
432
+ dropout_rng = self.make_rng("dropout")
433
+
434
+ if self.config.use_cosine_attention:
435
+ # normalize q and k
436
+ query_states = query_states / (
437
+ jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
438
+ )
439
+ key_states = key_states / (
440
+ jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
441
+ )
442
+
443
+ # relative position embeddings
444
+ if self.config.use_swin_position_embeddings:
445
+ position_ids = jnp.arange(self.q_length)
446
+ embed_pos = self.rel_bias(position_ids)
447
+ embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
448
+ else:
449
+ embed_pos = None
450
+
451
+ attn_weights = dot_product_attention_weights(
452
+ query_states,
453
+ key_states,
454
+ bias=attention_bias,
455
+ mask=attention_mask,
456
+ embed_pos=embed_pos,
457
+ dropout_rng=dropout_rng,
458
+ dropout_rate=self.dropout,
459
+ broadcast_dropout=True,
460
+ deterministic=deterministic,
461
+ dtype=self.dtype,
462
+ precision=None,
463
+ sinkhorn_iters=self.config.sinkhorn_iters,
464
+ is_encoder=self.is_encoder,
465
+ )
466
+ if self.config.use_cosine_attention:
467
+ # divide by tau
468
+ attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
469
+
470
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
471
+ if self.config.use_head_scale:
472
+ # per Normformer
473
+ attn_output = attn_output * self.head_scale
474
+ attn_output = self._merge_heads(attn_output)
475
+ attn_output = self.out_proj(attn_output)
476
+
477
+ return attn_output, attn_weights
478
+
479
+
480
+ class GLU(nn.Module):
481
+ """From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202"""
482
+
483
+ config: DalleBartConfig
484
+ ffn_dim: int
485
+ embed_dim: int
486
+ dtype: jnp.dtype = jnp.float32
487
+ is_encoder: bool = False
488
+
489
+ @nn.compact
490
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
491
+
492
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
493
+ self.config
494
+ )
495
+
496
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
497
+ x = norm(
498
+ self.config.ln_type,
499
+ dtype=self.dtype,
500
+ epsilon=1e-05,
501
+ use_scale=self.config.force_ln_scale,
502
+ )(x)
503
+ w = nn.Dense(
504
+ self.ffn_dim,
505
+ dtype=self.dtype,
506
+ use_bias=self.config.use_bias,
507
+ kernel_init=deepnet_init(gain)
508
+ if self.config.use_deepnet_scaling
509
+ else jax.nn.initializers.normal(self.config.init_std),
510
+ )(x)
511
+ w = ACT2FN[self.config.activation_function](w)
512
+ v = nn.Dense(
513
+ self.ffn_dim,
514
+ dtype=self.dtype,
515
+ use_bias=self.config.use_bias,
516
+ kernel_init=deepnet_init(gain)
517
+ if self.config.use_deepnet_scaling
518
+ else jax.nn.initializers.normal(self.config.init_std),
519
+ )(x)
520
+ x = w * v
521
+ if self.config.ln_positions in ["normformer"]:
522
+ x = norm(
523
+ self.config.ln_type,
524
+ dtype=self.dtype,
525
+ epsilon=1e-05,
526
+ use_scale=self.config.force_ln_scale,
527
+ )(x)
528
+ x = nn.Dropout(rate=self.config.activation_dropout)(
529
+ x, deterministic=deterministic
530
+ )
531
+
532
+ x = nn.Dense(
533
+ self.embed_dim,
534
+ dtype=self.dtype,
535
+ use_bias=self.config.use_bias,
536
+ kernel_init=deepnet_init(gain)
537
+ if self.config.use_deepnet_scaling
538
+ else jax.nn.initializers.normal(self.config.init_std),
539
+ )(x)
540
+ if self.config.ln_positions in ["swinv2", "cogview"]:
541
+ x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
542
+ x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
543
+ return x
544
+
545
+
546
+ class FFN(nn.Module):
547
+ """Simple FFN layer"""
548
+
549
+ config: DalleBartConfig
550
+ ffn_dim: int
551
+ embed_dim: int
552
+ dtype: jnp.dtype = jnp.float32
553
+ is_encoder: bool = False
554
+
555
+ @nn.compact
556
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
557
+
558
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
559
+ self.config
560
+ )
561
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
562
+ x = norm(
563
+ self.config.ln_type,
564
+ dtype=self.dtype,
565
+ epsilon=1e-05,
566
+ use_scale=self.config.force_ln_scale,
567
+ )(x)
568
+ x = nn.Dense(
569
+ self.ffn_dim,
570
+ dtype=self.dtype,
571
+ use_bias=self.config.use_bias,
572
+ kernel_init=deepnet_init(gain)
573
+ if self.config.use_deepnet_scaling
574
+ else jax.nn.initializers.normal(self.config.init_std),
575
+ )(x)
576
+ x = ACT2FN[self.config.activation_function](x)
577
+ if self.config.ln_positions in ["normformer"]:
578
+ x = norm(
579
+ self.config.ln_type,
580
+ dtype=self.dtype,
581
+ epsilon=1e-05,
582
+ use_scale=self.config.force_ln_scale,
583
+ )(x)
584
+ x = nn.Dropout(rate=self.config.activation_dropout)(
585
+ x, deterministic=deterministic
586
+ )
587
+ x = nn.Dense(
588
+ self.embed_dim,
589
+ dtype=self.dtype,
590
+ use_bias=self.config.use_bias,
591
+ kernel_init=deepnet_init(gain)
592
+ if self.config.use_deepnet_scaling
593
+ else jax.nn.initializers.normal(self.config.init_std),
594
+ )(x)
595
+ if self.config.ln_positions in ["swinv2", "cogview"]:
596
+ x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
597
+ x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
598
+ return x
599
+
600
+
601
+ class FlaxBartEncoderLayer(nn.Module):
602
+ """
603
+ Edits:
604
+ - no bias
605
+ - use custom FlaxBartAttention
606
+ """
607
+
608
+ config: DalleBartConfig
609
+ dtype: jnp.dtype = jnp.float32
610
+ add_norm: bool = False
611
+ use_scale: bool = True
612
+
613
+ @nn.compact
614
+ def __call__(
615
+ self,
616
+ hidden_states: jnp.ndarray,
617
+ attention_mask: jnp.ndarray,
618
+ output_attentions: bool = True,
619
+ deterministic: bool = True,
620
+ ) -> Tuple[jnp.ndarray]:
621
+
622
+ res_gain = (
623
+ deepnet_gain["encoder"]["alpha"](self.config)
624
+ if self.config.use_deepnet_scaling
625
+ else 1
626
+ )
627
+
628
+ embed_dim = self.config.d_model
629
+ residual = hidden_states
630
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
631
+ hidden_states = norm(
632
+ self.config.ln_type,
633
+ dtype=self.dtype,
634
+ epsilon=1e-05,
635
+ use_scale=self.config.force_ln_scale,
636
+ )(hidden_states)
637
+ hidden_states, attn_weights = FlaxBartAttention(
638
+ config=self.config,
639
+ embed_dim=embed_dim,
640
+ num_heads=self.config.encoder_attention_heads,
641
+ dropout=self.config.attention_dropout,
642
+ bias=self.config.use_bias,
643
+ dtype=self.dtype,
644
+ is_encoder=True,
645
+ q_length=self.config.max_text_length,
646
+ k_length=self.config.max_text_length,
647
+ )(hidden_states=hidden_states, attention_mask=attention_mask)
648
+
649
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
650
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
651
+ hidden_states
652
+ )
653
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
654
+ hidden_states, deterministic=deterministic
655
+ )
656
+ hidden_states = residual * res_gain + hidden_states
657
+ if self.config.ln_positions in ["postln"]:
658
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
659
+ hidden_states
660
+ )
661
+
662
+ residual = hidden_states
663
+ ff_block = (
664
+ GLU(
665
+ config=self.config,
666
+ ffn_dim=self.config.encoder_ffn_dim,
667
+ embed_dim=embed_dim,
668
+ dtype=self.dtype,
669
+ is_encoder=True,
670
+ )
671
+ if self.config.use_glu
672
+ else FFN(
673
+ config=self.config,
674
+ ffn_dim=self.config.encoder_ffn_dim,
675
+ embed_dim=embed_dim,
676
+ dtype=self.dtype,
677
+ is_encoder=True,
678
+ )
679
+ )
680
+ hidden_states = ff_block(hidden_states, deterministic=deterministic)
681
+ hidden_states = residual * res_gain + hidden_states
682
+ if self.add_norm or self.config.ln_positions in ["postln"]:
683
+ use_scale = (
684
+ self.use_scale
685
+ or self.config.ln_positions == "postln"
686
+ or self.config.force_ln_scale
687
+ )
688
+ hidden_states = norm(
689
+ self.config.ln_type,
690
+ dtype=self.dtype,
691
+ epsilon=1e-05,
692
+ use_scale=use_scale,
693
+ )(hidden_states)
694
+
695
+ outputs = (hidden_states,)
696
+
697
+ if output_attentions:
698
+ outputs += (attn_weights,)
699
+
700
+ return outputs
701
+
702
+
703
+ class FlaxBartDecoderLayer(nn.Module):
704
+ """
705
+ Edits:
706
+ - no bias
707
+ - use custom FlaxBartAttention
708
+ """
709
+
710
+ config: DalleBartConfig
711
+ dtype: jnp.dtype = jnp.float32
712
+ add_norm: bool = False
713
+ use_scale: bool = False
714
+
715
+ @nn.compact
716
+ def __call__(
717
+ self,
718
+ hidden_states: jnp.ndarray,
719
+ attention_mask: jnp.ndarray,
720
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
721
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
722
+ init_cache: bool = False,
723
+ output_attentions: bool = True,
724
+ deterministic: bool = True,
725
+ ) -> Tuple[jnp.ndarray]:
726
+
727
+ res_gain = (
728
+ deepnet_gain["decoder"]["alpha"](self.config)
729
+ if self.config.use_deepnet_scaling
730
+ else 1
731
+ )
732
+
733
+ embed_dim = self.config.d_model
734
+ residual = hidden_states
735
+
736
+ # Self Attention
737
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
738
+ hidden_states = norm(
739
+ self.config.ln_type,
740
+ dtype=self.dtype,
741
+ epsilon=1e-05,
742
+ use_scale=self.config.force_ln_scale,
743
+ )(hidden_states)
744
+ hidden_states, attn_weights = FlaxBartAttention(
745
+ config=self.config,
746
+ embed_dim=embed_dim,
747
+ num_heads=self.config.decoder_attention_heads,
748
+ dropout=self.config.attention_dropout,
749
+ causal=True,
750
+ bias=self.config.use_bias,
751
+ dtype=self.dtype,
752
+ is_encoder=False,
753
+ q_length=self.config.image_length,
754
+ k_length=self.config.image_length,
755
+ )(
756
+ hidden_states=hidden_states,
757
+ attention_mask=attention_mask,
758
+ init_cache=init_cache,
759
+ )
760
+
761
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
762
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
763
+ hidden_states
764
+ )
765
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
766
+ hidden_states, deterministic=deterministic
767
+ )
768
+ hidden_states = residual * res_gain + hidden_states
769
+ if self.config.ln_positions in ["postln"]:
770
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
771
+ hidden_states
772
+ )
773
+
774
+ # Cross Attention
775
+ cross_attn_weights = None
776
+ if encoder_hidden_states is not None:
777
+ residual = hidden_states
778
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
779
+ hidden_states = norm(
780
+ self.config.ln_type,
781
+ dtype=self.dtype,
782
+ epsilon=1e-05,
783
+ use_scale=self.config.force_ln_scale,
784
+ )(hidden_states)
785
+ hidden_states, cross_attn_weights = FlaxBartAttention(
786
+ config=self.config,
787
+ embed_dim=embed_dim,
788
+ num_heads=self.config.decoder_attention_heads,
789
+ dropout=self.config.attention_dropout,
790
+ bias=self.config.use_bias,
791
+ dtype=self.dtype,
792
+ is_encoder=False,
793
+ q_length=self.config.image_length,
794
+ k_length=self.config.max_text_length,
795
+ )(
796
+ hidden_states=hidden_states,
797
+ key_value_states=encoder_hidden_states,
798
+ attention_mask=encoder_attention_mask,
799
+ )
800
+ if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
801
+ hidden_states = norm(
802
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
803
+ )(hidden_states)
804
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
805
+ hidden_states, deterministic=deterministic
806
+ )
807
+ hidden_states = residual * res_gain + hidden_states
808
+ if self.config.ln_positions in ["postln"]:
809
+ hidden_states = norm(
810
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
811
+ )(hidden_states)
812
+
813
+ # Feed forward
814
+ residual = hidden_states
815
+ ff_block = (
816
+ GLU(
817
+ config=self.config,
818
+ ffn_dim=self.config.decoder_ffn_dim,
819
+ embed_dim=embed_dim,
820
+ dtype=self.dtype,
821
+ is_encoder=False,
822
+ )
823
+ if self.config.use_glu
824
+ else FFN(
825
+ config=self.config,
826
+ ffn_dim=self.config.decoder_ffn_dim,
827
+ embed_dim=embed_dim,
828
+ dtype=self.dtype,
829
+ is_encoder=False,
830
+ )
831
+ )
832
+ hidden_states = ff_block(hidden_states, deterministic=deterministic)
833
+ hidden_states = residual * res_gain + hidden_states
834
+ if self.add_norm or self.config.ln_positions in ["postln"]:
835
+ use_scale = (
836
+ self.use_scale
837
+ or self.config.ln_positions == "postln"
838
+ or self.config.force_ln_scale
839
+ )
840
+ hidden_states = norm(
841
+ self.config.ln_type,
842
+ dtype=self.dtype,
843
+ epsilon=1e-05,
844
+ use_scale=use_scale,
845
+ )(hidden_states)
846
+
847
+ outputs = (hidden_states,)
848
+
849
+ if output_attentions:
850
+ outputs += (attn_weights, cross_attn_weights)
851
+
852
+ return outputs
853
+
854
+
855
+ class FlaxBartEncoderLayerCollection(nn.Module):
856
+ config: DalleBartConfig
857
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
858
+ """
859
+ Edits:
860
+ - use custom FlaxBartEncoderLayer
861
+ - allow Gradient Checkpointing (nn.remat)
862
+ """
863
+
864
+ @nn.compact
865
+ def __call__(
866
+ self,
867
+ hidden_states,
868
+ attention_mask,
869
+ deterministic: bool = True,
870
+ output_attentions: bool = False,
871
+ output_hidden_states: bool = False,
872
+ return_dict: bool = True,
873
+ ):
874
+ all_hidden_states = () if output_hidden_states else None
875
+ all_self_attns = () if output_attentions else None
876
+
877
+ n_layers = self.config.encoder_layers
878
+ layer = (
879
+ remat(FlaxBartEncoderLayer, static_argnums=(2, 3))
880
+ if self.config.gradient_checkpointing
881
+ else FlaxBartEncoderLayer
882
+ )
883
+ for i in range(n_layers):
884
+ if output_hidden_states:
885
+ all_hidden_states += (hidden_states,)
886
+ # final layernorm on the output of the last layer
887
+ # or every 6 layers for Swin v2
888
+ add_norm = (
889
+ self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
890
+ ) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
891
+ # we don't need to scale the norm for the last layer
892
+ use_scale = i != n_layers - 1
893
+ layer_outputs = layer(
894
+ self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
895
+ )(
896
+ hidden_states,
897
+ attention_mask,
898
+ output_attentions,
899
+ deterministic,
900
+ )
901
+ hidden_states = layer_outputs[0]
902
+ if output_attentions:
903
+ all_self_attns += (layer_outputs[1],)
904
+
905
+ # add hidden states from the last layer
906
+ if output_hidden_states:
907
+ all_hidden_states += (hidden_states,)
908
+
909
+ outputs = [
910
+ hidden_states,
911
+ all_hidden_states,
912
+ all_self_attns,
913
+ ]
914
+
915
+ if not return_dict:
916
+ return tuple(v for v in outputs if v is not None)
917
+
918
+ return FlaxBaseModelOutput(
919
+ last_hidden_state=hidden_states,
920
+ hidden_states=all_hidden_states,
921
+ attentions=all_self_attns,
922
+ )
923
+
924
+
925
+ class FlaxBartDecoderLayerCollection(nn.Module):
926
+ config: DalleBartConfig
927
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
928
+ """
929
+ Edits:
930
+ - use custom FlaxBartDecoderLayer
931
+ - allow Gradient Checkpointing (nn.remat)
932
+ """
933
+
934
+ @nn.compact
935
+ def __call__(
936
+ self,
937
+ hidden_states,
938
+ attention_mask,
939
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
940
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
941
+ deterministic: bool = True,
942
+ init_cache: bool = False,
943
+ output_attentions: bool = False,
944
+ output_hidden_states: bool = False,
945
+ return_dict: bool = True,
946
+ ):
947
+ # decoder layers
948
+ all_hidden_states = () if output_hidden_states else None
949
+ all_self_attns = () if output_attentions else None
950
+ all_cross_attentions = (
951
+ () if (output_attentions and encoder_hidden_states is not None) else None
952
+ )
953
+
954
+ n_layers = self.config.decoder_layers
955
+ layer = (
956
+ remat(FlaxBartDecoderLayer, static_argnums=(4, 5, 6))
957
+ if self.config.gradient_checkpointing
958
+ else FlaxBartDecoderLayer
959
+ )
960
+ for i in range(n_layers):
961
+ if output_hidden_states:
962
+ all_hidden_states += (hidden_states,)
963
+ # final layernorm on the output of the last layer
964
+ # or every 6 layers for Swin v2
965
+ add_norm = (
966
+ self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
967
+ ) or (self.config.use_final_ln_decoder and (i == n_layers - 1))
968
+ # we don't need to scale the norm for the last layer
969
+ use_scale = i != n_layers - 1
970
+ layer_outputs = layer(
971
+ self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
972
+ )(
973
+ hidden_states,
974
+ attention_mask,
975
+ encoder_hidden_states,
976
+ encoder_attention_mask,
977
+ init_cache,
978
+ output_attentions,
979
+ deterministic,
980
+ )
981
+
982
+ hidden_states = layer_outputs[0]
983
+ if output_attentions:
984
+ all_self_attns += (layer_outputs[1],)
985
+
986
+ if encoder_hidden_states is not None:
987
+ all_cross_attentions += (layer_outputs[2],)
988
+
989
+ # add hidden states from the last decoder layer
990
+ if output_hidden_states:
991
+ all_hidden_states += (hidden_states,)
992
+
993
+ outputs = [
994
+ hidden_states,
995
+ all_hidden_states,
996
+ all_self_attns,
997
+ all_cross_attentions,
998
+ ]
999
+
1000
+ if not return_dict:
1001
+ return tuple(v for v in outputs if v is not None)
1002
+
1003
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
1004
+ last_hidden_state=hidden_states,
1005
+ hidden_states=all_hidden_states,
1006
+ attentions=all_self_attns,
1007
+ cross_attentions=all_cross_attentions,
1008
+ )
1009
+
1010
+
1011
+ class FlaxBartEncoder(nn.Module):
1012
+ config: DalleBartConfig
1013
+ embed_tokens: nn.Embed
1014
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1015
+ """
1016
+ Edits:
1017
+ - offset set to 0 (no padding token)
1018
+ - use max_text_length instead of max_position_embeddings
1019
+ - use custom FlaxBartEncoderLayerCollection
1020
+ - embed_tokens cannot be None (issue at compile time)
1021
+ """
1022
+
1023
+ def setup(self):
1024
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
1025
+
1026
+ embed_dim = self.config.d_model
1027
+ self.padding_idx = self.config.pad_token_id
1028
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
1029
+
1030
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
1031
+ # and adjust num_embeddings appropriately. Other models don't have this hack
1032
+ self.offset = 0
1033
+ if self.config.use_absolute_position_embeddings:
1034
+ self.embed_positions = nn.Embed(
1035
+ self.config.max_text_length + self.offset, # image length for BOS
1036
+ embed_dim,
1037
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1038
+ )
1039
+ self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
1040
+ self.layernorm_embedding = norm(
1041
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
1042
+ )
1043
+
1044
+ def __call__(
1045
+ self,
1046
+ input_ids,
1047
+ attention_mask,
1048
+ position_ids,
1049
+ output_attentions: bool = False,
1050
+ output_hidden_states: bool = False,
1051
+ return_dict: bool = True,
1052
+ deterministic: bool = True,
1053
+ ):
1054
+ input_shape = input_ids.shape
1055
+ input_ids = input_ids.reshape(-1, input_shape[-1])
1056
+
1057
+ hidden_states = self.embed_tokens(input_ids) * self.embed_scale
1058
+
1059
+ if self.config.use_absolute_position_embeddings:
1060
+ embed_pos = self.embed_positions(position_ids + self.offset)
1061
+ hidden_states = hidden_states + embed_pos
1062
+
1063
+ hidden_states = self.layernorm_embedding(hidden_states)
1064
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
1065
+
1066
+ outputs = self.layers(
1067
+ hidden_states,
1068
+ attention_mask,
1069
+ deterministic=deterministic,
1070
+ output_attentions=output_attentions,
1071
+ output_hidden_states=output_hidden_states,
1072
+ return_dict=return_dict,
1073
+ )
1074
+
1075
+ if not return_dict:
1076
+ return outputs
1077
+
1078
+ return FlaxBaseModelOutput(
1079
+ last_hidden_state=outputs.last_hidden_state,
1080
+ hidden_states=outputs.hidden_states,
1081
+ attentions=outputs.attentions,
1082
+ )
1083
+
1084
+
1085
+ class FlaxBartDecoder(nn.Module):
1086
+ config: DalleBartConfig
1087
+ embed_tokens: nn.Embed
1088
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1089
+ """
1090
+ Edits:
1091
+ - offset set to 0 (no padding token)
1092
+ - use image_length instead of max_position_embeddings
1093
+ - use custom FlaxBartDecoderLayerCollection
1094
+ - embed_tokens cannot be None (issue at compile time)
1095
+ """
1096
+
1097
+ def setup(self):
1098
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
1099
+
1100
+ embed_dim = self.config.d_model
1101
+ self.padding_idx = self.config.pad_token_id
1102
+ self.embed_scale = (
1103
+ math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
1104
+ )
1105
+
1106
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
1107
+ # and adjust num_embeddings appropriately. Other models don't have this hack
1108
+ self.offset = 0
1109
+ if self.config.use_absolute_position_embeddings:
1110
+ self.embed_positions = nn.Embed(
1111
+ self.config.image_length + self.offset, # image length for BOS
1112
+ embed_dim,
1113
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1114
+ )
1115
+
1116
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
1117
+ self.layernorm_embedding = norm(
1118
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
1119
+ )
1120
+
1121
+ def __call__(
1122
+ self,
1123
+ input_ids,
1124
+ attention_mask,
1125
+ position_ids,
1126
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1127
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1128
+ init_cache: bool = False,
1129
+ output_attentions: bool = False,
1130
+ output_hidden_states: bool = False,
1131
+ return_dict: bool = True,
1132
+ deterministic: bool = True,
1133
+ ):
1134
+ input_shape = input_ids.shape
1135
+ input_ids = input_ids.reshape(-1, input_shape[-1])
1136
+
1137
+ hidden_states = self.embed_tokens(input_ids) * self.embed_scale
1138
+
1139
+ if self.config.use_absolute_position_embeddings:
1140
+ embed_pos = self.embed_positions(position_ids + self.offset)
1141
+ hidden_states = hidden_states + embed_pos
1142
+
1143
+ hidden_states = self.layernorm_embedding(hidden_states)
1144
+ hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
1145
+
1146
+ outputs = self.layers(
1147
+ hidden_states,
1148
+ attention_mask,
1149
+ encoder_hidden_states,
1150
+ encoder_attention_mask,
1151
+ deterministic=deterministic,
1152
+ init_cache=init_cache,
1153
+ output_attentions=output_attentions,
1154
+ output_hidden_states=output_hidden_states,
1155
+ return_dict=return_dict,
1156
+ )
1157
+
1158
+ if not return_dict:
1159
+ return outputs
1160
+
1161
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
1162
+ last_hidden_state=outputs.last_hidden_state,
1163
+ hidden_states=outputs.hidden_states,
1164
+ attentions=outputs.attentions,
1165
+ cross_attentions=outputs.cross_attentions,
1166
+ )
1167
+
1168
+
1169
+ class FlaxBartModule(FlaxBartModule):
1170
+ """
1171
+ Edits
1172
+ - use custom FlaxBartEncoder & FlaxBartDecoder
1173
+ - use separate embeddings for Encoder & Decoder
1174
+ """
1175
+
1176
+ def setup(self):
1177
+ encoder_embed_tokens = nn.Embed(
1178
+ self.config.encoder_vocab_size,
1179
+ self.config.d_model,
1180
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1181
+ )
1182
+ decoder_embed_tokens = nn.Embed(
1183
+ self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
1184
+ self.config.d_model,
1185
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
1186
+ )
1187
+
1188
+ self.encoder = FlaxBartEncoder(
1189
+ self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens
1190
+ )
1191
+ self.decoder = FlaxBartDecoder(
1192
+ self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens
1193
+ )
1194
+
1195
+
1196
+ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
1197
+ """
1198
+ Edits:
1199
+ - added num_params property
1200
+ - config_class replaced to DalleBartConfig
1201
+ - __init__ accepts abstract_init which does uses parameter shape to initialize the model
1202
+ - init weights on CPU with `load_on_cpu`
1203
+ - restore weights on CPU with custom `from_pretrained`
1204
+ """
1205
+
1206
+ config_class = DalleBartConfig
1207
+
1208
+ def __init__(
1209
+ self,
1210
+ config: DalleBartConfig,
1211
+ input_shape: Tuple[int] = (1, 1),
1212
+ seed: int = 0,
1213
+ dtype: jnp.dtype = jnp.float32,
1214
+ abstract_init: bool = False,
1215
+ load_on_cpu: bool = False,
1216
+ init_weights: bool = True,
1217
+ **kwargs,
1218
+ ):
1219
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
1220
+
1221
+ # adapted from HuggingFace FlaxPreTrainedModel
1222
+ if config is None:
1223
+ raise ValueError("config cannot be None")
1224
+
1225
+ if module is None:
1226
+ raise ValueError("module cannot be None")
1227
+
1228
+ # Those are private to be exposed as typed property on derived classes.
1229
+ self._config = config
1230
+ self._module = module
1231
+
1232
+ # Those are public as their type is generic to every derived classes.
1233
+ self.key = PRNGKey(seed)
1234
+ self.dtype = dtype
1235
+
1236
+ if init_weights:
1237
+ # get shape of params only
1238
+ random_params = self.init_weights(
1239
+ self.key,
1240
+ input_shape,
1241
+ abstract_init=abstract_init,
1242
+ load_on_cpu=load_on_cpu,
1243
+ )
1244
+
1245
+ # save required_params as set
1246
+ self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
1247
+ self.params = random_params
1248
+
1249
+ def init_weights(
1250
+ self, rng=None, input_shape=(1, 1), abstract_init=False, load_on_cpu=False
1251
+ ):
1252
+ if rng is None:
1253
+ rng = self.key
1254
+ init_fn = super().init_weights
1255
+ if load_on_cpu:
1256
+ init_fn = jax.jit(init_fn, static_argnums=(1,), backend="cpu")
1257
+ if abstract_init:
1258
+ # only set shape and dtype, load parameters separately
1259
+ init_fn = partial(init_fn, input_shape=input_shape)
1260
+ params = jax.eval_shape(init_fn, rng)
1261
+ else:
1262
+ params = init_fn(rng, input_shape)
1263
+ return params
1264
+
1265
+ @property
1266
+ def num_params(self):
1267
+ num_params = jax.tree_map(
1268
+ lambda param: param.size, flatten_dict(unfreeze(self.params))
1269
+ ).values()
1270
+ return sum(list(num_params))
1271
+
1272
+ @classmethod
1273
+ def from_pretrained(
1274
+ cls,
1275
+ pretrained_model_name_or_path: Union[str, os.PathLike],
1276
+ dtype: jnp.dtype = jnp.float32,
1277
+ *model_args,
1278
+ **kwargs,
1279
+ ):
1280
+ config = kwargs.pop("config", None)
1281
+ cache_dir = kwargs.pop("cache_dir", None)
1282
+ from_pt = kwargs.pop("from_pt", False)
1283
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1284
+ force_download = kwargs.pop("force_download", False)
1285
+ resume_download = kwargs.pop("resume_download", False)
1286
+ proxies = kwargs.pop("proxies", None)
1287
+ local_files_only = kwargs.pop("local_files_only", False)
1288
+ use_auth_token = kwargs.pop("use_auth_token", None)
1289
+ revision = kwargs.pop("revision", None)
1290
+ from_pipeline = kwargs.pop("_from_pipeline", None)
1291
+ from_auto_class = kwargs.pop("_from_auto", False)
1292
+
1293
+ user_agent = {
1294
+ "file_type": "model",
1295
+ "framework": "flax",
1296
+ "from_auto_class": from_auto_class,
1297
+ }
1298
+ if from_pipeline is not None:
1299
+ user_agent["using_pipeline"] = from_pipeline
1300
+
1301
+ if is_offline_mode() and not local_files_only:
1302
+ logger.info("Offline mode: forcing local_files_only=True")
1303
+ local_files_only = True
1304
+
1305
+ # Load config if we don't provide a configuration
1306
+ if not isinstance(config, PretrainedConfig):
1307
+ config_path = (
1308
+ config if config is not None else pretrained_model_name_or_path
1309
+ )
1310
+ config, model_kwargs = cls.config_class.from_pretrained(
1311
+ config_path,
1312
+ cache_dir=cache_dir,
1313
+ return_unused_kwargs=True,
1314
+ force_download=force_download,
1315
+ resume_download=resume_download,
1316
+ proxies=proxies,
1317
+ local_files_only=local_files_only,
1318
+ use_auth_token=use_auth_token,
1319
+ revision=revision,
1320
+ _from_auto=from_auto_class,
1321
+ _from_pipeline=from_pipeline,
1322
+ **kwargs,
1323
+ )
1324
+ else:
1325
+ model_kwargs = kwargs
1326
+
1327
+ # Add the dtype to model_kwargs
1328
+ model_kwargs["dtype"] = dtype
1329
+
1330
+ # Load model
1331
+ if pretrained_model_name_or_path is not None:
1332
+ if os.path.isdir(pretrained_model_name_or_path):
1333
+ if from_pt and os.path.isfile(
1334
+ os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
1335
+ ):
1336
+ # Load from a PyTorch checkpoint
1337
+ archive_file = os.path.join(
1338
+ pretrained_model_name_or_path, WEIGHTS_NAME
1339
+ )
1340
+ elif os.path.isfile(
1341
+ os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
1342
+ ):
1343
+ # Load from a Flax checkpoint
1344
+ archive_file = os.path.join(
1345
+ pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
1346
+ )
1347
+ else:
1348
+ raise EnvironmentError(
1349
+ f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
1350
+ f"{pretrained_model_name_or_path} or `from_pt` set to False"
1351
+ )
1352
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
1353
+ pretrained_model_name_or_path
1354
+ ):
1355
+ archive_file = pretrained_model_name_or_path
1356
+ else:
1357
+ archive_file = hf_bucket_url(
1358
+ pretrained_model_name_or_path,
1359
+ filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
1360
+ revision=revision,
1361
+ )
1362
+
1363
+ # redirect to the cache, if necessary
1364
+ try:
1365
+ resolved_archive_file = cached_path(
1366
+ archive_file,
1367
+ cache_dir=cache_dir,
1368
+ force_download=force_download,
1369
+ proxies=proxies,
1370
+ resume_download=resume_download,
1371
+ local_files_only=local_files_only,
1372
+ use_auth_token=use_auth_token,
1373
+ user_agent=user_agent,
1374
+ )
1375
+ except EnvironmentError as err:
1376
+ logger.error(err)
1377
+ msg = (
1378
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
1379
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n"
1380
+ f" (make sure '{pretrained_model_name_or_path}' is not a path to a local directory with something else, in that case)\n\n"
1381
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
1382
+ )
1383
+ raise EnvironmentError(msg)
1384
+
1385
+ if resolved_archive_file == archive_file:
1386
+ logger.info(f"loading weights file {archive_file}")
1387
+ else:
1388
+ logger.info(
1389
+ f"loading weights file {archive_file} from cache at {resolved_archive_file}"
1390
+ )
1391
+ else:
1392
+ resolved_archive_file = None
1393
+
1394
+ # init random models
1395
+ model = cls(config, *model_args, **model_kwargs)
1396
+
1397
+ with open(resolved_archive_file, "rb") as state_f:
1398
+ try:
1399
+ state = from_bytes(cls, state_f.read())
1400
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
1401
+ try:
1402
+ with open(resolved_archive_file) as f:
1403
+ if f.read().startswith("version"):
1404
+ raise OSError(
1405
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
1406
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
1407
+ "you cloned."
1408
+ )
1409
+ else:
1410
+ raise ValueError from e
1411
+ except (UnicodeDecodeError, ValueError):
1412
+ raise EnvironmentError(
1413
+ f"Unable to convert {archive_file} to Flax deserializable object. "
1414
+ )
1415
+
1416
+ # if model is base model only use model_prefix key
1417
+ if (
1418
+ cls.base_model_prefix not in dict(model.params)
1419
+ and cls.base_model_prefix in state
1420
+ ):
1421
+ state = state[cls.base_model_prefix]
1422
+
1423
+ # if model is head model and we are loading weights from base model
1424
+ # we initialize new params dict with base_model_prefix
1425
+ if (
1426
+ cls.base_model_prefix in dict(model.params)
1427
+ and cls.base_model_prefix not in state
1428
+ ):
1429
+ state = {cls.base_model_prefix: state}
1430
+
1431
+ # flatten dicts
1432
+ state = flatten_dict(state)
1433
+
1434
+ random_state = flatten_dict(unfreeze(model.params))
1435
+
1436
+ missing_keys = model.required_params - set(state.keys())
1437
+ unexpected_keys = set(state.keys()) - model.required_params
1438
+
1439
+ # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
1440
+ # matching the weights in the model.
1441
+ mismatched_keys = []
1442
+ for key in state.keys():
1443
+ if key in random_state and state[key].shape != random_state[key].shape:
1444
+ if ignore_mismatched_sizes:
1445
+ mismatched_keys.append(
1446
+ (key, state[key].shape, random_state[key].shape)
1447
+ )
1448
+ state[key] = random_state[key]
1449
+ else:
1450
+ raise ValueError(
1451
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
1452
+ f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
1453
+ "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
1454
+ "model."
1455
+ )
1456
+
1457
+ # add missing keys as random parameters
1458
+ for missing_key in missing_keys:
1459
+ state[missing_key] = random_state[missing_key]
1460
+
1461
+ # remove unexpected keys to not be saved again
1462
+ for unexpected_key in unexpected_keys:
1463
+ del state[unexpected_key]
1464
+
1465
+ if len(unexpected_keys) > 0:
1466
+ logger.warning(
1467
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
1468
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
1469
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
1470
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
1471
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
1472
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
1473
+ )
1474
+ else:
1475
+ logger.info(
1476
+ f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
1477
+ )
1478
+
1479
+ if len(missing_keys) > 0:
1480
+ logger.warning(
1481
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
1482
+ f"and are newly initialized: {missing_keys}\n"
1483
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1484
+ )
1485
+ elif len(mismatched_keys) == 0:
1486
+ logger.info(
1487
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
1488
+ f"If your task is similar to the task the model of the checkpoint was trained on, "
1489
+ f"you can already use {model.__class__.__name__} for predictions without further training."
1490
+ )
1491
+ if len(mismatched_keys) > 0:
1492
+ mismatched_warning = "\n".join(
1493
+ [
1494
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1495
+ for key, shape1, shape2 in mismatched_keys
1496
+ ]
1497
+ )
1498
+ logger.warning(
1499
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
1500
+ f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
1501
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1502
+ )
1503
+
1504
+ # set correct parameters
1505
+ model.params = unflatten_dict(state)
1506
+
1507
+ return model
1508
+
1509
+
1510
+ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
1511
+ """
1512
+ Edits:
1513
+ - no bias
1514
+ - lm_head set to image_vocab_size + 1 (for BOS)
1515
+ - uses custom FlaxBartModule
1516
+ """
1517
+
1518
+ def setup(self):
1519
+ self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
1520
+ self.lm_head = nn.Dense(
1521
+ self.config.image_vocab_size
1522
+ + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
1523
+ use_bias=False,
1524
+ dtype=self.dtype,
1525
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
1526
+ )
1527
+
1528
+ def __call__(
1529
+ self,
1530
+ input_ids,
1531
+ attention_mask,
1532
+ decoder_input_ids,
1533
+ decoder_attention_mask,
1534
+ position_ids,
1535
+ decoder_position_ids,
1536
+ output_attentions: bool = False,
1537
+ output_hidden_states: bool = False,
1538
+ return_dict: bool = True,
1539
+ deterministic: bool = True,
1540
+ ):
1541
+ outputs = self.model(
1542
+ input_ids=input_ids,
1543
+ attention_mask=attention_mask,
1544
+ decoder_input_ids=decoder_input_ids,
1545
+ decoder_attention_mask=decoder_attention_mask,
1546
+ position_ids=position_ids,
1547
+ decoder_position_ids=decoder_position_ids,
1548
+ output_attentions=output_attentions,
1549
+ output_hidden_states=output_hidden_states,
1550
+ return_dict=return_dict,
1551
+ deterministic=deterministic,
1552
+ )
1553
+
1554
+ hidden_states = outputs[0]
1555
+
1556
+ if self.config.tie_word_embeddings:
1557
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
1558
+ lm_logits = self.lm_head.apply(
1559
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
1560
+ )
1561
+ else:
1562
+ lm_logits = self.lm_head(hidden_states)
1563
+
1564
+ if not return_dict:
1565
+ output = (lm_logits,) + outputs[1:]
1566
+ return output
1567
+
1568
+ return FlaxSeq2SeqLMOutput(
1569
+ logits=lm_logits,
1570
+ decoder_hidden_states=outputs.decoder_hidden_states,
1571
+ decoder_attentions=outputs.decoder_attentions,
1572
+ cross_attentions=outputs.cross_attentions,
1573
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1574
+ encoder_hidden_states=outputs.encoder_hidden_states,
1575
+ encoder_attentions=outputs.encoder_attentions,
1576
+ )
1577
+
1578
+
1579
+ @flax.struct.dataclass
1580
+ class SampleState:
1581
+ cur_len: jnp.ndarray
1582
+ sequences: jnp.ndarray
1583
+ running_token: jnp.ndarray
1584
+ is_sent_finished: jnp.ndarray
1585
+ prng_key: jnp.ndarray
1586
+ model_kwargs: Dict[str, jnp.ndarray]
1587
+ model_kwargs_uncond: Dict[str, jnp.ndarray]
1588
+
1589
+
1590
+ class DalleBart(
1591
+ PretrainedFromWandbMixin, FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration
1592
+ ):
1593
+ """
1594
+ Edits:
1595
+ - renamed from FlaxBartForConditionalGeneration
1596
+ - uses custom FlaxBartPreTrainedModel
1597
+ - uses custom FlaxBartForConditionalGenerationModule
1598
+ - no bias in decode method
1599
+ - custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
1600
+ related to position embedding during model.generate()
1601
+ - custom generate method to allow super conditions
1602
+ """
1603
+
1604
+ module_class = FlaxBartForConditionalGenerationModule
1605
+
1606
+ def decode(
1607
+ self,
1608
+ decoder_input_ids,
1609
+ encoder_outputs,
1610
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1611
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
1612
+ decoder_position_ids: Optional[jnp.ndarray] = None,
1613
+ past_key_values: dict = None,
1614
+ output_attentions: Optional[bool] = None,
1615
+ output_hidden_states: Optional[bool] = None,
1616
+ return_dict: Optional[bool] = None,
1617
+ train: bool = False,
1618
+ params: dict = None,
1619
+ dropout_rng: PRNGKey = None,
1620
+ ):
1621
+ output_attentions = (
1622
+ output_attentions
1623
+ if output_attentions is not None
1624
+ else self.config.output_attentions
1625
+ )
1626
+ output_hidden_states = (
1627
+ output_hidden_states
1628
+ if output_hidden_states is not None
1629
+ else self.config.output_hidden_states
1630
+ )
1631
+ return_dict = (
1632
+ return_dict if return_dict is not None else self.config.return_dict
1633
+ )
1634
+
1635
+ encoder_hidden_states = encoder_outputs[0]
1636
+ if encoder_attention_mask is None:
1637
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
1638
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
1639
+
1640
+ batch_size, sequence_length = decoder_input_ids.shape
1641
+ if decoder_attention_mask is None:
1642
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
1643
+
1644
+ if decoder_position_ids is None:
1645
+ if past_key_values is not None:
1646
+ raise ValueError(
1647
+ "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
1648
+ )
1649
+
1650
+ decoder_position_ids = jnp.broadcast_to(
1651
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
1652
+ )
1653
+
1654
+ # Handle any PRNG if needed
1655
+ rngs = {}
1656
+ if dropout_rng is not None:
1657
+ rngs["dropout"] = dropout_rng
1658
+
1659
+ inputs = {"params": params or self.params}
1660
+
1661
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
1662
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
1663
+ # it can be changed by FlaxBartAttention module
1664
+ if past_key_values:
1665
+ inputs["cache"] = past_key_values
1666
+ mutable = ["cache"]
1667
+ else:
1668
+ mutable = False
1669
+
1670
+ def _decoder_forward(
1671
+ module,
1672
+ decoder_input_ids,
1673
+ decoder_attention_mask,
1674
+ decoder_position_ids,
1675
+ **kwargs,
1676
+ ):
1677
+ decoder_module = module._get_decoder_module()
1678
+ outputs = decoder_module(
1679
+ decoder_input_ids,
1680
+ decoder_attention_mask,
1681
+ decoder_position_ids,
1682
+ **kwargs,
1683
+ )
1684
+ hidden_states = outputs[0]
1685
+
1686
+ if self.config.tie_word_embeddings:
1687
+ shared_embedding = module.model.variables["params"]["shared"][
1688
+ "embedding"
1689
+ ]
1690
+ lm_logits = module.lm_head.apply(
1691
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
1692
+ )
1693
+ else:
1694
+ lm_logits = module.lm_head(hidden_states)
1695
+
1696
+ return lm_logits, outputs
1697
+
1698
+ outputs = self.module.apply(
1699
+ inputs,
1700
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
1701
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
1702
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
1703
+ encoder_hidden_states=encoder_hidden_states,
1704
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
1705
+ output_attentions=output_attentions,
1706
+ output_hidden_states=output_hidden_states,
1707
+ return_dict=return_dict,
1708
+ deterministic=not train,
1709
+ rngs=rngs,
1710
+ mutable=mutable,
1711
+ method=_decoder_forward,
1712
+ )
1713
+
1714
+ if past_key_values is None:
1715
+ lm_logits, decoder_outputs = outputs
1716
+ else:
1717
+ (lm_logits, decoder_outputs), past = outputs
1718
+
1719
+ if return_dict:
1720
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
1721
+ logits=lm_logits,
1722
+ hidden_states=decoder_outputs.hidden_states,
1723
+ attentions=decoder_outputs.attentions,
1724
+ cross_attentions=decoder_outputs.cross_attentions,
1725
+ )
1726
+ else:
1727
+ outputs = (lm_logits,) + decoder_outputs[1:]
1728
+
1729
+ # add updated cache to model output
1730
+ if past_key_values is not None and return_dict:
1731
+ outputs["past_key_values"] = unfreeze(past["cache"])
1732
+ return outputs
1733
+ elif past_key_values is not None and not return_dict:
1734
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
1735
+
1736
+ return outputs
1737
+
1738
+ def prepare_inputs_for_generation(
1739
+ self,
1740
+ decoder_input_ids,
1741
+ max_length,
1742
+ attention_mask: Optional[jnp.DeviceArray] = None,
1743
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
1744
+ encoder_outputs=None,
1745
+ **kwargs,
1746
+ ):
1747
+ # initializing the cache
1748
+ batch_size, seq_length = decoder_input_ids.shape
1749
+
1750
+ past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
1751
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1752
+ # But since the decoder uses a causal mask, those positions are masked anyways.
1753
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1754
+ extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
1755
+ if decoder_attention_mask is not None:
1756
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
1757
+ extended_attention_mask = lax.dynamic_update_slice(
1758
+ extended_attention_mask, decoder_attention_mask, (0, 0)
1759
+ )
1760
+ else:
1761
+ position_ids = jnp.broadcast_to(
1762
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
1763
+ )
1764
+
1765
+ return {
1766
+ "past_key_values": past_key_values,
1767
+ "encoder_outputs": encoder_outputs,
1768
+ "encoder_attention_mask": attention_mask,
1769
+ "decoder_attention_mask": extended_attention_mask,
1770
+ "decoder_position_ids": position_ids,
1771
+ }
1772
+
1773
+ def generate(
1774
+ self,
1775
+ input_ids: jnp.ndarray,
1776
+ attention_mask: Optional[jnp.ndarray] = None,
1777
+ max_length: Optional[int] = None,
1778
+ pad_token_id: Optional[int] = None,
1779
+ bos_token_id: Optional[int] = None,
1780
+ eos_token_id: Optional[int] = None,
1781
+ decoder_start_token_id: Optional[int] = None,
1782
+ do_sample: Optional[bool] = None,
1783
+ prng_key: Optional[jnp.ndarray] = None,
1784
+ top_k: Optional[int] = None,
1785
+ top_p: Optional[float] = None,
1786
+ temperature: Optional[float] = None,
1787
+ num_beams: Optional[int] = None,
1788
+ no_repeat_ngram_size: Optional[int] = None,
1789
+ min_length: Optional[int] = None,
1790
+ forced_bos_token_id: Optional[int] = None,
1791
+ forced_eos_token_id: Optional[int] = None,
1792
+ length_penalty: Optional[float] = None,
1793
+ early_stopping: Optional[bool] = None,
1794
+ trace: bool = True,
1795
+ params: Optional[Dict[str, jnp.ndarray]] = None,
1796
+ condition_scale: Optional[float] = 1.0,
1797
+ input_ids_uncond: Optional[jnp.ndarray] = None,
1798
+ attention_mask_uncond: Optional[jnp.ndarray] = None,
1799
+ **model_kwargs,
1800
+ ):
1801
+ """Edit: Allow super conditioning."""
1802
+
1803
+ # set init values
1804
+ max_length = max_length if max_length is not None else self.config.max_length
1805
+ bos_token_id = (
1806
+ bos_token_id if bos_token_id is not None else self.config.bos_token_id
1807
+ )
1808
+ pad_token_id = (
1809
+ pad_token_id if pad_token_id is not None else self.config.pad_token_id
1810
+ )
1811
+ eos_token_id = (
1812
+ eos_token_id if eos_token_id is not None else self.config.eos_token_id
1813
+ )
1814
+ decoder_start_token_id = (
1815
+ decoder_start_token_id
1816
+ if decoder_start_token_id
1817
+ else self.config.decoder_start_token_id
1818
+ )
1819
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
1820
+
1821
+ if decoder_start_token_id is None and self.config.is_encoder_decoder:
1822
+ raise ValueError(
1823
+ "`decoder_start_token_id` has to be defined for encoder-decoder generation."
1824
+ )
1825
+
1826
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
1827
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
1828
+
1829
+ if self.config.is_encoder_decoder:
1830
+ # add encoder_outputs to model_kwargs
1831
+ if model_kwargs.get("encoder_outputs") is None:
1832
+ model_kwargs_input = dict(model_kwargs)
1833
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
1834
+ input_ids,
1835
+ params,
1836
+ {"attention_mask": attention_mask, **model_kwargs_input},
1837
+ )
1838
+ if condition_scale != 1.0:
1839
+ assert (
1840
+ input_ids_uncond is not None
1841
+ ), "`input_ids_uncond` has to be defined for super conditioning."
1842
+ assert (
1843
+ do_sample is True
1844
+ ), "`do_sample` has to be True for super conditioning."
1845
+ assert (
1846
+ num_beams == 1
1847
+ ), "`num_beams` has to be 1 for super conditioning."
1848
+ model_kwargs_uncond = (
1849
+ self._prepare_encoder_decoder_kwargs_for_generation(
1850
+ input_ids_uncond,
1851
+ params,
1852
+ {
1853
+ "attention_mask": attention_mask_uncond,
1854
+ **model_kwargs_input,
1855
+ },
1856
+ )
1857
+ )
1858
+ else:
1859
+ model_kwargs_uncond = None
1860
+ # prepare decoder_input_ids for generation
1861
+ input_ids = (
1862
+ jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
1863
+ )
1864
+
1865
+ if not do_sample and num_beams == 1:
1866
+ logits_processor = self._get_logits_processor(
1867
+ no_repeat_ngram_size,
1868
+ min_length,
1869
+ max_length,
1870
+ eos_token_id,
1871
+ forced_bos_token_id,
1872
+ forced_eos_token_id,
1873
+ )
1874
+ return self._greedy_search(
1875
+ input_ids,
1876
+ max_length,
1877
+ pad_token_id,
1878
+ eos_token_id,
1879
+ logits_processor=logits_processor,
1880
+ trace=trace,
1881
+ params=params,
1882
+ model_kwargs=model_kwargs,
1883
+ )
1884
+ elif do_sample and num_beams == 1:
1885
+ logits_warper = self._get_logits_warper(
1886
+ top_k=top_k, top_p=top_p, temperature=temperature
1887
+ )
1888
+ logits_processor = self._get_logits_processor(
1889
+ no_repeat_ngram_size,
1890
+ min_length,
1891
+ max_length,
1892
+ eos_token_id,
1893
+ forced_bos_token_id,
1894
+ forced_eos_token_id,
1895
+ )
1896
+ return self._sample(
1897
+ input_ids,
1898
+ max_length,
1899
+ pad_token_id,
1900
+ eos_token_id,
1901
+ prng_key,
1902
+ logits_warper=logits_warper,
1903
+ logits_processor=logits_processor,
1904
+ trace=trace,
1905
+ params=params,
1906
+ model_kwargs=model_kwargs,
1907
+ condition_scale=condition_scale,
1908
+ model_kwargs_uncond=model_kwargs_uncond,
1909
+ )
1910
+ elif not do_sample and num_beams > 1:
1911
+ # broadcast input_ids & encoder_outputs
1912
+ input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
1913
+
1914
+ if "encoder_outputs" in model_kwargs:
1915
+ model_kwargs["encoder_outputs"][
1916
+ "last_hidden_state"
1917
+ ] = self._expand_to_num_beams(
1918
+ model_kwargs["encoder_outputs"]["last_hidden_state"],
1919
+ num_beams=num_beams,
1920
+ )
1921
+
1922
+ if "attention_mask" in model_kwargs:
1923
+ model_kwargs["attention_mask"] = self._expand_to_num_beams(
1924
+ model_kwargs["attention_mask"], num_beams=num_beams
1925
+ )
1926
+
1927
+ logits_processor = self._get_logits_processor(
1928
+ no_repeat_ngram_size,
1929
+ min_length,
1930
+ max_length,
1931
+ eos_token_id,
1932
+ forced_bos_token_id,
1933
+ forced_eos_token_id,
1934
+ )
1935
+
1936
+ return self._beam_search(
1937
+ input_ids,
1938
+ max_length,
1939
+ pad_token_id,
1940
+ eos_token_id,
1941
+ length_penalty=length_penalty,
1942
+ early_stopping=early_stopping,
1943
+ logits_processor=logits_processor,
1944
+ trace=trace,
1945
+ params=params,
1946
+ model_kwargs=model_kwargs,
1947
+ )
1948
+ else:
1949
+ raise NotImplementedError("`Beam sampling is currently not implemented.")
1950
+
1951
+ def _sample(
1952
+ self,
1953
+ input_ids: None,
1954
+ max_length: Optional[int] = None,
1955
+ pad_token_id: Optional[int] = None,
1956
+ eos_token_id: Optional[int] = None,
1957
+ prng_key: Optional[jnp.ndarray] = None,
1958
+ logits_processor=None,
1959
+ logits_warper=None,
1960
+ trace: bool = True,
1961
+ params: Optional[Dict[str, jnp.ndarray]] = None,
1962
+ model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
1963
+ condition_scale: float = 1.0,
1964
+ model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
1965
+ ):
1966
+ # init values
1967
+ max_length = max_length if max_length is not None else self.config.max_length
1968
+ pad_token_id = (
1969
+ pad_token_id if pad_token_id is not None else self.config.pad_token_id
1970
+ )
1971
+ eos_token_id = (
1972
+ eos_token_id if eos_token_id is not None else self.config.eos_token_id
1973
+ )
1974
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
1975
+
1976
+ batch_size, cur_len = input_ids.shape
1977
+
1978
+ eos_token_id = jnp.array(eos_token_id)
1979
+ pad_token_id = jnp.array(pad_token_id)
1980
+ cur_len = jnp.array(cur_len)
1981
+
1982
+ # per batch-item holding current token in loop.
1983
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
1984
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
1985
+
1986
+ # per batch-item state bit indicating if sentence has finished.
1987
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
1988
+
1989
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
1990
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
1991
+ model = self.decode if self.config.is_encoder_decoder else self
1992
+
1993
+ # initialize model specific kwargs
1994
+ model_kwargs = self.prepare_inputs_for_generation(
1995
+ input_ids, max_length, **model_kwargs
1996
+ )
1997
+ if condition_scale != 1.0:
1998
+ model_kwargs_uncond = self.prepare_inputs_for_generation(
1999
+ input_ids, max_length, **model_kwargs_uncond
2000
+ )
2001
+
2002
+ # initialize state
2003
+ state = SampleState(
2004
+ cur_len=cur_len,
2005
+ sequences=sequences,
2006
+ running_token=input_ids,
2007
+ is_sent_finished=is_sent_finished,
2008
+ prng_key=prng_key,
2009
+ model_kwargs=model_kwargs,
2010
+ model_kwargs_uncond=model_kwargs_uncond,
2011
+ )
2012
+
2013
+ def sample_search_cond_fn(state):
2014
+ """state termination condition fn."""
2015
+ has_reached_max_length = state.cur_len == max_length
2016
+ all_sequence_finished = jnp.all(state.is_sent_finished)
2017
+ finish_generation = jnp.logical_or(
2018
+ has_reached_max_length, all_sequence_finished
2019
+ )
2020
+ return ~finish_generation
2021
+
2022
+ def sample_search_body_fn(state):
2023
+ """state update fn."""
2024
+ prng_key, prng_key_next = jax.random.split(state.prng_key)
2025
+ model_outputs = model(
2026
+ state.running_token, params=params, **state.model_kwargs
2027
+ )
2028
+
2029
+ logits = model_outputs.logits[:, -1]
2030
+
2031
+ # perform super conditioning
2032
+ # Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
2033
+ if condition_scale != 1.0:
2034
+ model_outputs_uncond = model(
2035
+ state.running_token, params=params, **state.model_kwargs_uncond
2036
+ )
2037
+ logits_uncond = model_outputs_uncond.logits[:, -1]
2038
+ logits = logits_uncond + condition_scale * (logits - logits_uncond)
2039
+ else:
2040
+ model_outputs_uncond = None
2041
+
2042
+ # apply min_length, ...
2043
+ logits = logits_processor(state.sequences, logits, state.cur_len)
2044
+ # apply top_k, top_k, temperature
2045
+ logits = logits_warper(logits, logits, state.cur_len)
2046
+
2047
+ next_token = jax.random.categorical(prng_key, logits, axis=-1)
2048
+
2049
+ next_is_sent_finished = state.is_sent_finished | (
2050
+ next_token == eos_token_id
2051
+ )
2052
+ next_token = (
2053
+ next_token * ~next_is_sent_finished
2054
+ + pad_token_id * next_is_sent_finished
2055
+ )
2056
+ next_token = next_token[:, None]
2057
+
2058
+ next_sequences = lax.dynamic_update_slice(
2059
+ state.sequences, next_token, (0, state.cur_len)
2060
+ )
2061
+ next_model_kwargs = self.update_inputs_for_generation(
2062
+ model_outputs, state.model_kwargs
2063
+ )
2064
+ next_model_kwargs_uncond = (
2065
+ self.update_inputs_for_generation(
2066
+ model_outputs_uncond, state.model_kwargs_uncond
2067
+ )
2068
+ if condition_scale != 1.0
2069
+ else None
2070
+ )
2071
+
2072
+ return SampleState(
2073
+ cur_len=state.cur_len + 1,
2074
+ sequences=next_sequences,
2075
+ running_token=next_token,
2076
+ is_sent_finished=next_is_sent_finished,
2077
+ model_kwargs=next_model_kwargs,
2078
+ model_kwargs_uncond=next_model_kwargs_uncond,
2079
+ prng_key=prng_key_next,
2080
+ )
2081
+
2082
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
2083
+ if input_ids.shape[1] > 1:
2084
+ state = sample_search_body_fn(state)
2085
+
2086
+ if not trace:
2087
+ state = self._run_loop_in_debug(
2088
+ sample_search_cond_fn, sample_search_body_fn, state
2089
+ )
2090
+ else:
2091
+ state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
2092
+
2093
+ return FlaxSampleOutput(sequences=state.sequences)
src/dalle_mini/model/partitions.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from flax.core.frozen_dict import freeze
4
+ from flax.traverse_util import flatten_dict, unflatten_dict
5
+ from jax.experimental import PartitionSpec as P
6
+
7
+ # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
8
+ # Sentinels
9
+ _unmatched = object()
10
+
11
+ # For specifying empty leaf dict `{}`
12
+ empty_dict = object()
13
+
14
+
15
+ def _match(qs, ks):
16
+ """Return True if regexes in qs match any window of strings in tuple ks."""
17
+ # compile regexes and force complete match
18
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
19
+ for i in range(len(ks) - len(qs) + 1):
20
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
21
+ if matches and all(matches):
22
+ return True
23
+ return False
24
+
25
+
26
+ def _replacement_rules(rules):
27
+ def replace(key, val):
28
+ for rule, replacement in rules:
29
+ if _match(rule, key):
30
+ return replacement
31
+ return val
32
+
33
+ return replace
34
+
35
+
36
+ def _get_partition_rules():
37
+ return [
38
+ # embeddings
39
+ (("embed_positions", "embedding"), P("mp", None)),
40
+ (("embed_tokens", "embedding"), P("mp", None)),
41
+ (("rel_bias", "embedding"), P(None, "mp")),
42
+ # attention
43
+ (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
44
+ (("out_proj", "kernel"), P("mp", None)),
45
+ # FFN
46
+ (("Dense_0", "kernel"), P(None, "mp")),
47
+ (("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
48
+ (("GLU.*", "Dense_2", "kernel"), P("mp", None)),
49
+ (("FFN.*", "Dense_1", "kernel"), P("mp", None)),
50
+ # layer norms
51
+ (("(bias|scale)",), None),
52
+ (("lm_head", "kernel"), P(None, "mp")),
53
+ # head scale and tau
54
+ (("(head_scale|tau)",), None),
55
+ ]
56
+
57
+
58
+ def set_partitions(in_dict):
59
+ rules = _get_partition_rules()
60
+ replace = _replacement_rules(rules)
61
+ initd = {k: _unmatched for k in flatten_dict(in_dict)}
62
+ result = {k: replace(k, v) for k, v in initd.items()}
63
+ for k, v in result.items():
64
+ if v == _unmatched:
65
+ print(f"Unmatched -> {k}")
66
+ assert _unmatched not in result.values(), "Incomplete partition spec."
67
+ return freeze(unflatten_dict(result))
src/dalle_mini/model/processor.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart processor """
2
+
3
+ import jax.numpy as jnp
4
+
5
+ from .configuration import DalleBartConfig
6
+ from .text import TextNormalizer
7
+ from .tokenizer import DalleBartTokenizer
8
+ from .utils import PretrainedFromWandbMixin
9
+
10
+
11
+ class DalleBartProcessorBase:
12
+ def __init__(
13
+ self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int
14
+ ):
15
+ self.tokenizer = tokenizer
16
+ self.normalize_text = normalize_text
17
+ self.max_text_length = max_text_length
18
+ if normalize_text:
19
+ self.text_processor = TextNormalizer()
20
+ # create unconditional tokens
21
+ uncond = self.tokenizer(
22
+ "",
23
+ return_tensors="jax",
24
+ padding="max_length",
25
+ truncation=True,
26
+ max_length=self.max_text_length,
27
+ ).data
28
+ self.input_ids_uncond = uncond["input_ids"]
29
+ self.attention_mask_uncond = uncond["attention_mask"]
30
+
31
+ def __call__(self, text: str = None):
32
+ # check that text is not a string
33
+ assert not isinstance(text, str), "text must be a list of strings"
34
+
35
+ if self.normalize_text:
36
+ text = [self.text_processor(t) for t in text]
37
+ res = self.tokenizer(
38
+ text,
39
+ return_tensors="jax",
40
+ padding="max_length",
41
+ truncation=True,
42
+ max_length=self.max_text_length,
43
+ ).data
44
+ # tokens used only with super conditioning
45
+ n = len(text)
46
+ res["input_ids_uncond"] = jnp.repeat(self.input_ids_uncond, n, axis=0)
47
+ res["attention_mask_uncond"] = jnp.repeat(self.attention_mask_uncond, n, axis=0)
48
+ return res
49
+
50
+ @classmethod
51
+ def from_pretrained(cls, *args, **kwargs):
52
+ tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
53
+ config = DalleBartConfig.from_pretrained(*args, **kwargs)
54
+ return cls(tokenizer, config.normalize_text, config.max_text_length)
55
+
56
+
57
+ class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
58
+ pass
src/dalle_mini/model/text.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utilities for processing text.
3
+ """
4
+
5
+ import html
6
+ import math
7
+ import random
8
+ import re
9
+ from pathlib import Path
10
+
11
+ import emoji
12
+ import ftfy
13
+ from huggingface_hub import hf_hub_download
14
+ from unidecode import unidecode
15
+
16
+ # based on wiki word occurence
17
+ person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
18
+ temp_token = "xtokx" # avoid repeating chars
19
+
20
+
21
+ class HashtagProcessor:
22
+ # Adapted from wordninja library
23
+ # We use our wikipedia word count + a good heuristic to make it work
24
+ def __init__(self):
25
+ wiki_word_frequency = hf_hub_download(
26
+ "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
27
+ )
28
+ self._word_cost = (
29
+ l.split()[0]
30
+ for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
31
+ )
32
+ self._word_cost = {
33
+ str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
34
+ }
35
+ self._max_word = max(len(x) for x in self._word_cost.keys())
36
+ self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")
37
+
38
+ def __call__(self, s):
39
+ """Uses dynamic programming to infer the location of spaces in a string without spaces."""
40
+ l = [self._split(x) for x in self._SPLIT_RE.split(s)]
41
+ return " ".join([item for sublist in l for item in sublist])
42
+
43
+ def _split(self, s):
44
+ # Find the best match for the i first characters, assuming cost has
45
+ # been built for the i-1 first characters.
46
+ # Returns a pair (match_cost, match_length).
47
+ def best_match(i):
48
+ candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
49
+ return min(
50
+ (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
51
+ for k, c in candidates
52
+ )
53
+
54
+ # Build the cost array
55
+ cost = [0]
56
+ for i in range(1, len(s) + 1):
57
+ c, k = best_match(i)
58
+ cost.append(c)
59
+
60
+ # Backtrack to recover the minimal-cost string.
61
+ out = []
62
+ i = len(s)
63
+ while i > 0:
64
+ c, k = best_match(i)
65
+ assert c == cost[i]
66
+ newToken = True
67
+ if not s[i - k : i] == "'": # ignore a lone apostrophe
68
+ if len(out) > 0:
69
+ # re-attach split 's and split digits
70
+ if out[-1] == "'s" or (
71
+ s[i - 1].isdigit() and out[-1][0].isdigit()
72
+ ): # digit followed by digit
73
+ out[-1] = (
74
+ s[i - k : i] + out[-1]
75
+ ) # combine current token with previous token
76
+ newToken = False
77
+
78
+ if newToken:
79
+ out.append(s[i - k : i])
80
+
81
+ i -= k
82
+
83
+ return reversed(out)
84
+
85
+
86
+ def replace_person_token(t):
87
+ "Used for CC12M"
88
+ t = re.sub("<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
89
+ while "<person>" in t:
90
+ t = t.replace(
91
+ "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
92
+ )
93
+ return t
94
+
95
+
96
+ def fix_html(t):
97
+ # from OpenAI CLIP
98
+ return html.unescape(html.unescape(t))
99
+
100
+
101
+ def replace_punctuation_with_commas(t):
102
+ return re.sub("[()[\].,|:;?!=+~\-\/{}]", ",", t)
103
+
104
+
105
+ def simplify_quotes(t):
106
+ return re.sub("""['"`]""", ' " ', t)
107
+
108
+
109
+ def merge_quotes(t):
110
+ return re.sub('(\s*"+\s*)+', ' " ', t)
111
+
112
+
113
+ def remove_comma_numbers(t):
114
+ def _f(t):
115
+ return re.sub("(\d),(\d{3})", r"\1\2", t)
116
+
117
+ return _f(_f(t))
118
+
119
+
120
+ def pre_process_dot_numbers(t):
121
+ return re.sub("(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
122
+
123
+
124
+ def post_process_dot_numbers(t):
125
+ return re.sub(f"{temp_token}dot{temp_token}", ".", t)
126
+
127
+
128
+ def pre_process_quotes(t):
129
+ # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
130
+ return re.sub(
131
+ r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", rf"{temp_token}quote{temp_token}", t
132
+ )
133
+
134
+
135
+ def post_process_quotes(t):
136
+ return re.sub(f"{temp_token}quote{temp_token}", "'", t)
137
+
138
+
139
+ def pre_process_dates(t):
140
+ return re.sub("(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
141
+
142
+
143
+ def post_process_dates(t):
144
+ return re.sub(f"{temp_token}slash{temp_token}", "/", t)
145
+
146
+
147
+ def merge_commas(t):
148
+ return re.sub("(\s*,+\s*)+", ", ", t)
149
+
150
+
151
+ def add_space_after_commas(t):
152
+ return re.sub(",", ", ", t)
153
+
154
+
155
+ def handle_special_chars(t):
156
+ "Handle special characters"
157
+ # replace "-" with a space when between words without space
158
+ t = re.sub("(\w)-(\w)", r"\1 \2", t)
159
+ # always add space around some characters
160
+ return re.sub("([%&\/$*])", r" \1 ", t)
161
+
162
+
163
+ def expand_hashtags(t, hashtag_processor):
164
+ "Remove # and try to split words"
165
+ return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
166
+
167
+
168
+ _re_ignore_chars = r"[_#\\]"
169
+
170
+
171
+ def ignore_chars(t):
172
+ "Ignore useless characters"
173
+ return re.sub(_re_ignore_chars, " ", t)
174
+
175
+
176
+ def remove_extra_spaces(t):
177
+ "Remove extra spaces (including \t and \n)"
178
+ return re.sub("\s+", " ", t)
179
+
180
+
181
+ def remove_repeating_chars(t):
182
+ "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with single instance"
183
+ return re.sub(r"(\D)(\1{3,})", r"\1", t)
184
+
185
+
186
+ def remove_urls(t):
187
+ return re.sub(r"http\S+", "", t)
188
+
189
+
190
+ def remove_html_tags(t):
191
+ return re.sub("<[^<]+?>", "", t)
192
+
193
+
194
+ def remove_first_last_commas(t):
195
+ t = t.strip()
196
+ t = t[:-1] if t and t[-1] == "," else t
197
+ t = t[1:] if t and t[0] == "," else t
198
+ return t.strip()
199
+
200
+
201
+ def remove_wiki_ref(t):
202
+ t = re.sub(r"\A\s*\[\d+\]", "", t)
203
+ return re.sub(r"\[\d+\]\s*\Z", "", t)
204
+
205
+
206
+ class TextNormalizer:
207
+ "Normalize text"
208
+
209
+ def __init__(self):
210
+ self._hashtag_processor = HashtagProcessor()
211
+
212
+ def __call__(self, t):
213
+ # fix some characters
214
+ t = ftfy.fix_text(t)
215
+ # fix html
216
+ t = fix_html(t)
217
+ # decode emojis (would be removed by unidecode)
218
+ t = emoji.demojize(t)
219
+ # decode and simplify text: see unidecode library
220
+ t = unidecode(t)
221
+ # lower case
222
+ t = t.lower()
223
+ # replace <PERSON> (for CC12M)
224
+ t = replace_person_token(t)
225
+ # remove wiki reference (for WIT)
226
+ t = remove_wiki_ref(t)
227
+ # remove html tags
228
+ t = remove_html_tags(t)
229
+ # remove urls
230
+ t = remove_urls(t)
231
+ # remove commas in numbers
232
+ t = remove_comma_numbers(t)
233
+ # handle dots in numbers and quotes - Part 1
234
+ t = pre_process_dot_numbers(t)
235
+ t = pre_process_quotes(t)
236
+ t = pre_process_dates(t)
237
+ # handle special characters
238
+ t = handle_special_chars(t)
239
+ # handle hashtags
240
+ t = expand_hashtags(t, self._hashtag_processor)
241
+ # ignore useless characters
242
+ t = ignore_chars(t)
243
+ # simplify quotes
244
+ t = simplify_quotes(t)
245
+ # all punctuation becomes commas
246
+ t = replace_punctuation_with_commas(t)
247
+ # handle dots in numbers and quotes - Part 2
248
+ t = post_process_dot_numbers(t)
249
+ t = post_process_quotes(t)
250
+ t = post_process_dates(t)
251
+ # handle repeating characters
252
+ t = remove_repeating_chars(t)
253
+ # merge quotes
254
+ t = merge_quotes(t)
255
+ # merge commas
256
+ t = merge_commas(t)
257
+ # remove multiple spaces
258
+ t = remove_extra_spaces(t)
259
+ # remove first and last comma
260
+ t = remove_first_last_commas(t)
261
+ # always start with a space
262
+ return f" {t}"
src/dalle_mini/model/tokenizer.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """ DalleBart tokenizer """
2
+ from transformers import BartTokenizerFast
3
+
4
+ from .utils import PretrainedFromWandbMixin
5
+
6
+
7
+ class DalleBartTokenizer(PretrainedFromWandbMixin, BartTokenizerFast):
8
+ pass
src/dalle_mini/model/utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import wandb
6
+
7
+
8
+ class PretrainedFromWandbMixin:
9
+ @classmethod
10
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
11
+ """
12
+ Initializes from a wandb artifact or delegates loading to the superclass.
13
+ """
14
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
15
+ if ":" in pretrained_model_name_or_path and not os.path.isdir(
16
+ pretrained_model_name_or_path
17
+ ):
18
+ # wandb artifact
19
+ if wandb.run is not None:
20
+ artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
21
+ else:
22
+ artifact = wandb.Api().artifact(pretrained_model_name_or_path)
23
+ pretrained_model_name_or_path = artifact.download(tmp_dir)
24
+
25
+ return super(PretrainedFromWandbMixin, cls).from_pretrained(
26
+ pretrained_model_name_or_path, *model_args, **kwargs
27
+ )
tools/dataset/encode_dataset.ipynb ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Pre-encoding a dataset for DALLE·mini"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "ba7b31e6",
14
+ "metadata": {},
15
+ "source": [
16
+ "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
17
+ "\n",
18
+ "Adapt it to your own dataset and image encoder.\n",
19
+ "\n",
20
+ "At the end you should have a dataset of pairs:\n",
21
+ "* a caption defined as a string\n",
22
+ "* an encoded image defined as a list of int."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "3b59489e",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "from tqdm.notebook import tqdm\n",
33
+ "\n",
34
+ "import torchvision.transforms as T\n",
35
+ "\n",
36
+ "import webdataset as wds\n",
37
+ "\n",
38
+ "import jax\n",
39
+ "import braceexpand\n",
40
+ "from pathlib import Path"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "c7c4c1e6",
46
+ "metadata": {},
47
+ "source": [
48
+ "## Configuration Parameters"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "1265dbfe",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "shards = \"my_images/shard-{0000..0008}.tar\" # defined using braceexpand format as used by webdataset\n",
59
+ "encoded_output = Path(\"encoded_data\") # where we will save our encoded data\n",
60
+ "\n",
61
+ "VQGAN_REPO, VQGAN_COMMIT_ID = (\n",
62
+ " \"dalle-mini/vqgan_imagenet_f16_16384\",\n",
63
+ " \"85eb5d3b51a1c62a0cc8f4ccdee9882c0d0bd384\",\n",
64
+ ")\n",
65
+ "\n",
66
+ "# good defaults for a TPU v3-8\n",
67
+ "batch_size = 128 # Per device\n",
68
+ "num_workers = 8 # For parallel processing\n",
69
+ "total_bs = batch_size * jax.device_count() # You can use a smaller size while testing\n",
70
+ "save_frequency = 128 # Number of batches to create a new file (180MB for f16 and 720MB for f8 per file)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 5,
76
+ "id": "cd956ec6-7d98-4d4d-a454-f80fe857eadd",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "['XXX/shard-0000.tar',\n",
83
+ " 'XXX/shard-0001.tar',\n",
84
+ " 'XXX/shard-0002.tar',\n",
85
+ " 'XXX/shard-0003.tar',\n",
86
+ " 'XXX/shard-0004.tar',\n",
87
+ " 'XXX/shard-0005.tar',\n",
88
+ " 'XXX/shard-0006.tar',\n",
89
+ " 'XXX/shard-0007.tar',\n",
90
+ " 'XXX/shard-0008.tar']"
91
+ ]
92
+ },
93
+ "execution_count": 5,
94
+ "metadata": {},
95
+ "output_type": "execute_result"
96
+ }
97
+ ],
98
+ "source": [
99
+ "shards = list(\n",
100
+ " braceexpand.braceexpand(shards)\n",
101
+ ") # better display for tqdm with known length"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "id": "75dba8e2",
107
+ "metadata": {},
108
+ "source": [
109
+ "## Load data"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "markdown",
114
+ "id": "a1e8fb95",
115
+ "metadata": {},
116
+ "source": [
117
+ "We load data using `webdataset`."
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "id": "9ef5de9e",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "ds = (\n",
128
+ " wds.WebDataset(shards, handler=wds.warn_and_continue)\n",
129
+ " .decode(\"rgb\", handler=wds.warn_and_continue)\n",
130
+ " .to_tuple(\"jpg\", \"txt\") # assumes image is in `jpg` and caption in `txt`\n",
131
+ " .batched(total_bs) # load in batch per worker (faster)\n",
132
+ ")"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "id": "90981824",
138
+ "metadata": {},
139
+ "source": [
140
+ "Note:\n",
141
+ "* you can also shuffle shards and items using `shardshuffle` and `shuffle` if necessary.\n",
142
+ "* you may need to resize images in your pipeline (with `map_dict` for example), we assume they are already set to 256x256.\n",
143
+ "* you can also filter out some items using `select`."
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "129c377d",
149
+ "metadata": {},
150
+ "source": [
151
+ "We can now inspect our data."
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": null,
157
+ "id": "8cac98cb",
158
+ "metadata": {
159
+ "scrolled": true
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "%%time\n",
164
+ "images, captions = next(iter(ds))"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "cd268fbf",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "images.shape"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "5acfc4d8",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "captions[:10]"
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "id": "c24693c0",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "T.ToPILImage()(images[0].permute(2, 0, 1))"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "markdown",
199
+ "id": "3059ffb1",
200
+ "metadata": {},
201
+ "source": [
202
+ "Finally we create our dataloader."
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "c227c551",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "dl = (\n",
213
+ " wds.WebLoader(ds, batch_size=None, num_workers=8).unbatched().batched(total_bs)\n",
214
+ ") # avoid partial batch at the end of each worker"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "markdown",
219
+ "id": "a354472b",
220
+ "metadata": {},
221
+ "source": [
222
+ "## Image encoder\n",
223
+ "\n",
224
+ "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": null,
230
+ "id": "47a8b818",
231
+ "metadata": {
232
+ "scrolled": true
233
+ },
234
+ "outputs": [],
235
+ "source": [
236
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
237
+ "from flax.jax_utils import replicate\n",
238
+ "\n",
239
+ "vqgan = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")\n",
240
+ "vqgan_params = replicate(vqgan.params)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "62ad01c3",
246
+ "metadata": {},
247
+ "source": [
248
+ "## Encoding"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "id": "20357f74",
254
+ "metadata": {},
255
+ "source": [
256
+ "Encoding is really simple using `shard` to automatically distribute batches across devices and `pmap`."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": null,
262
+ "id": "322a4619",
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "from flax.training.common_utils import shard\n",
267
+ "from functools import partial\n",
268
+ "\n",
269
+ "\n",
270
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
271
+ "def p_encode(batch, params):\n",
272
+ " # Not sure if we should `replicate` params, does not seem to have any effect\n",
273
+ " _, indices = vqgan.encode(batch, params=params)\n",
274
+ " return indices"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "ff6c10d4",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "import pandas as pd\n",
285
+ "\n",
286
+ "\n",
287
+ "def encode_dataset(dataloader, output_dir, save_frequency):\n",
288
+ " output_dir.mkdir(parents=True, exist_ok=True)\n",
289
+ " all_captions = []\n",
290
+ " all_encoding = []\n",
291
+ " n_file = 1\n",
292
+ " for idx, (images, captions) in enumerate(tqdm(dataloader)):\n",
293
+ " images = images.numpy()\n",
294
+ " n = len(images) // 8 * 8\n",
295
+ " if n != len(images):\n",
296
+ " # get the max number of images we can (multiple of 8)\n",
297
+ " print(f\"Different sizes {n} vs {len(images)}\")\n",
298
+ " images = images[:n]\n",
299
+ " captions = captions[:n]\n",
300
+ " if not len(captions):\n",
301
+ " print(f\"No images/captions in batch...\")\n",
302
+ " continue\n",
303
+ " images = shard(images)\n",
304
+ " encoded = p_encode(images, vqgan_params)\n",
305
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
306
+ " all_captions.extend(captions)\n",
307
+ " all_encoding.extend(encoded.tolist())\n",
308
+ "\n",
309
+ " # save files\n",
310
+ " if (idx + 1) % save_frequency == 0:\n",
311
+ " print(f\"Saving file {n_file}\")\n",
312
+ " batch_df = pd.DataFrame.from_dict(\n",
313
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
314
+ " )\n",
315
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")\n",
316
+ " all_captions = []\n",
317
+ " all_encoding = []\n",
318
+ " n_file += 1\n",
319
+ "\n",
320
+ " if len(all_captions):\n",
321
+ " print(f\"Saving final file {n_file}\")\n",
322
+ " batch_df = pd.DataFrame.from_dict(\n",
323
+ " {\"caption\": all_captions, \"encoding\": all_encoding}\n",
324
+ " )\n",
325
+ " batch_df.to_parquet(f\"{output_dir}/{n_file:03d}.parquet\")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "7704863d",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "encode_dataset(dl, output_dir=encoded_output, save_frequency=save_frequency)"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "markdown",
340
+ "id": "8953dd84",
341
+ "metadata": {},
342
+ "source": [
343
+ "----"
344
+ ]
345
+ }
346
+ ],
347
+ "metadata": {
348
+ "interpreter": {
349
+ "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
350
+ },
351
+ "kernelspec": {
352
+ "display_name": "Python 3 (ipykernel)",
353
+ "language": "python",
354
+ "name": "python3"
355
+ },
356
+ "language_info": {
357
+ "codemirror_mode": {
358
+ "name": "ipython",
359
+ "version": 3
360
+ },
361
+ "file_extension": ".py",
362
+ "mimetype": "text/x-python",
363
+ "name": "python",
364
+ "nbconvert_exporter": "python",
365
+ "pygments_lexer": "ipython3",
366
+ "version": "3.9.7"
367
+ }
368
+ },
369
+ "nbformat": 4,
370
+ "nbformat_minor": 5
371
+ }
tools/inference/inference_pipeline.ipynb ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "118UKH5bWCGa"
17
+ },
18
+ "source": [
19
+ "# DALL·E mini - Inference pipeline\n",
20
+ "\n",
21
+ "*Generate images from a text prompt*\n",
22
+ "\n",
23
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
+ "\n",
25
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
+ "\n",
27
+ "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
28
+ "\n",
29
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {
35
+ "id": "dS8LbaonYm3a"
36
+ },
37
+ "source": [
38
+ "## 🛠️ Installation and set-up"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "uzjAM2GBYpZX"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "# Install required libraries\n",
50
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
51
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {
58
+ "id": "ozHzTkyv8cqU"
59
+ },
60
+ "source": [
61
+ "We load required models:\n",
62
+ "* dalle·mini for text to encoded images\n",
63
+ "* VQGAN for decoding images\n",
64
+ "* CLIP for scoring predictions"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {
71
+ "id": "K6CxW2o42f-w"
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "# Model references\n",
76
+ "\n",
77
+ "# dalle-mini\n",
78
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3f0lem84:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
+ "DALLE_COMMIT_ID = None\n",
80
+ "\n",
81
+ "# VQGAN model\n",
82
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
83
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
+ "\n",
85
+ "# CLIP model\n",
86
+ "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
87
+ "CLIP_COMMIT_ID = None"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {
94
+ "id": "Yv-aR3t4Oe5v"
95
+ },
96
+ "outputs": [],
97
+ "source": [
98
+ "import jax\n",
99
+ "import jax.numpy as jnp\n",
100
+ "\n",
101
+ "# check how many devices are available\n",
102
+ "jax.local_device_count()"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {
109
+ "id": "HWnQrQuXOe5w"
110
+ },
111
+ "outputs": [],
112
+ "source": [
113
+ "# type used for computation - use bfloat16 on TPU's\n",
114
+ "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
115
+ "\n",
116
+ "# TODO: fix issue with bfloat16\n",
117
+ "dtype = jnp.float32"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "metadata": {
124
+ "id": "92zYmvsQ38vL"
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "# Load models & tokenizer\n",
129
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
130
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
132
+ "\n",
133
+ "# Load dalle-mini\n",
134
+ "model = DalleBart.from_pretrained(\n",
135
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
136
+ ")\n",
137
+ "\n",
138
+ "# Load VQGAN\n",
139
+ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
140
+ "\n",
141
+ "# Load CLIP\n",
142
+ "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
143
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "metadata": {
149
+ "id": "o_vH2X1tDtzA"
150
+ },
151
+ "source": [
152
+ "Model parameters are replicated on each device for faster inference."
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "metadata": {
159
+ "id": "wtvLoM48EeVw"
160
+ },
161
+ "outputs": [],
162
+ "source": [
163
+ "from flax.jax_utils import replicate\n",
164
+ "\n",
165
+ "# convert model parameters for inference if requested\n",
166
+ "if dtype == jnp.bfloat16:\n",
167
+ " model.params = model.to_bf16(model.params)\n",
168
+ "\n",
169
+ "model._params = replicate(model.params)\n",
170
+ "vqgan._params = replicate(vqgan.params)\n",
171
+ "clip._params = replicate(clip.params)"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "metadata": {
177
+ "id": "0A9AHQIgZ_qw"
178
+ },
179
+ "source": [
180
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "metadata": {
187
+ "id": "sOtoOmYsSYPz"
188
+ },
189
+ "outputs": [],
190
+ "source": [
191
+ "from functools import partial\n",
192
+ "\n",
193
+ "# model inference\n",
194
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
195
+ "def p_generate(\n",
196
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
197
+ "):\n",
198
+ " return model.generate(\n",
199
+ " **tokenized_prompt,\n",
200
+ " prng_key=key,\n",
201
+ " params=params,\n",
202
+ " top_k=top_k,\n",
203
+ " top_p=top_p,\n",
204
+ " temperature=temperature,\n",
205
+ " condition_scale=condition_scale,\n",
206
+ " )\n",
207
+ "\n",
208
+ "\n",
209
+ "# decode images\n",
210
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
211
+ "def p_decode(indices, params):\n",
212
+ " return vqgan.decode_code(indices, params=params)\n",
213
+ "\n",
214
+ "\n",
215
+ "# score images\n",
216
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
217
+ "def p_clip(inputs, params):\n",
218
+ " logits = clip(params=params, **inputs).logits_per_image\n",
219
+ " return logits"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "markdown",
224
+ "metadata": {
225
+ "id": "HmVN6IBwapBA"
226
+ },
227
+ "source": [
228
+ "Keys are passed to the model on each device to generate unique inference per device."
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": null,
234
+ "metadata": {
235
+ "id": "4CTXmlUkThhX"
236
+ },
237
+ "outputs": [],
238
+ "source": [
239
+ "import random\n",
240
+ "\n",
241
+ "# create a random key\n",
242
+ "seed = random.randint(0, 2**32 - 1)\n",
243
+ "key = jax.random.PRNGKey(seed)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "metadata": {
249
+ "id": "BrnVyCo81pij"
250
+ },
251
+ "source": [
252
+ "## 🖍 Text Prompt"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "markdown",
257
+ "metadata": {
258
+ "id": "rsmj0Aj5OQox"
259
+ },
260
+ "source": [
261
+ "Our model requires processing prompts."
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "metadata": {
268
+ "id": "YjjhUychOVxm"
269
+ },
270
+ "outputs": [],
271
+ "source": [
272
+ "from dalle_mini import DalleBartProcessor\n",
273
+ "\n",
274
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "metadata": {
280
+ "id": "BQ7fymSPyvF_"
281
+ },
282
+ "source": [
283
+ "Let's define a text prompt."
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "metadata": {
290
+ "id": "x_0vI9ge1oKr"
291
+ },
292
+ "outputs": [],
293
+ "source": [
294
+ "prompt = \"sunset over the lake in the mountains\""
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": null,
300
+ "metadata": {
301
+ "id": "VKjEZGjtO49k"
302
+ },
303
+ "outputs": [],
304
+ "source": [
305
+ "tokenized_prompt = processor([prompt])"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "markdown",
310
+ "metadata": {
311
+ "id": "-CEJBnuJOe5z"
312
+ },
313
+ "source": [
314
+ "Finally we replicate it onto each device."
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {
321
+ "id": "lQePgju5Oe5z"
322
+ },
323
+ "outputs": [],
324
+ "source": [
325
+ "tokenized_prompt = replicate(tokenized_prompt)"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "metadata": {
331
+ "id": "phQ9bhjRkgAZ"
332
+ },
333
+ "source": [
334
+ "## 🎨 Generate images\n",
335
+ "\n",
336
+ "We generate images using dalle-mini model and decode them with the VQGAN."
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "metadata": {
343
+ "id": "d0wVkXpKqnHA"
344
+ },
345
+ "outputs": [],
346
+ "source": [
347
+ "# number of predictions\n",
348
+ "n_predictions = 32\n",
349
+ "\n",
350
+ "# We can customize top_k/top_p used for generating samples\n",
351
+ "gen_top_k = None\n",
352
+ "gen_top_p = None\n",
353
+ "temperature = 0.85\n",
354
+ "cond_scale = 3.0"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {
361
+ "id": "SDjEx9JxR3v8"
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "from flax.training.common_utils import shard_prng_key\n",
366
+ "import numpy as np\n",
367
+ "from PIL import Image\n",
368
+ "from tqdm.notebook import trange\n",
369
+ "\n",
370
+ "# generate images\n",
371
+ "images = []\n",
372
+ "for i in trange(n_predictions // jax.device_count()):\n",
373
+ " # get a new key\n",
374
+ " key, subkey = jax.random.split(key)\n",
375
+ " # generate images\n",
376
+ " encoded_images = p_generate(\n",
377
+ " tokenized_prompt,\n",
378
+ " shard_prng_key(subkey),\n",
379
+ " model.params,\n",
380
+ " gen_top_k,\n",
381
+ " gen_top_p,\n",
382
+ " temperature,\n",
383
+ " cond_scale,\n",
384
+ " )\n",
385
+ " # remove BOS\n",
386
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
387
+ " # decode images\n",
388
+ " decoded_images = p_decode(encoded_images, vqgan.params)\n",
389
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
390
+ " for img in decoded_images:\n",
391
+ " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "markdown",
396
+ "metadata": {
397
+ "id": "tw02wG9zGmyB"
398
+ },
399
+ "source": [
400
+ "Let's calculate their score with CLIP."
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "metadata": {
407
+ "id": "FoLXpjCmGpju"
408
+ },
409
+ "outputs": [],
410
+ "source": [
411
+ "from flax.training.common_utils import shard\n",
412
+ "\n",
413
+ "# get clip scores\n",
414
+ "clip_inputs = clip_processor(\n",
415
+ " text=[prompt] * jax.device_count(),\n",
416
+ " images=images,\n",
417
+ " return_tensors=\"np\",\n",
418
+ " padding=\"max_length\",\n",
419
+ " max_length=77,\n",
420
+ " truncation=True,\n",
421
+ ").data\n",
422
+ "logits = p_clip(shard(clip_inputs), clip.params)\n",
423
+ "logits = logits.squeeze().flatten()"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "markdown",
428
+ "metadata": {
429
+ "id": "4AAWRm70LgED"
430
+ },
431
+ "source": [
432
+ "Let's display images ranked by CLIP score."
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {
439
+ "id": "zsgxxubLLkIu"
440
+ },
441
+ "outputs": [],
442
+ "source": [
443
+ "print(f\"Prompt: {prompt}\\n\")\n",
444
+ "for idx in logits.argsort()[::-1]:\n",
445
+ " display(images[idx])\n",
446
+ " print(f\"Score: {logits[idx]:.2f}\\n\")"
447
+ ]
448
+ }
449
+ ],
450
+ "metadata": {
451
+ "accelerator": "GPU",
452
+ "colab": {
453
+ "collapsed_sections": [],
454
+ "include_colab_link": true,
455
+ "machine_shape": "hm",
456
+ "name": "DALL·E mini - Inference pipeline.ipynb",
457
+ "provenance": []
458
+ },
459
+ "kernelspec": {
460
+ "display_name": "Python 3 (ipykernel)",
461
+ "language": "python",
462
+ "name": "python3"
463
+ },
464
+ "language_info": {
465
+ "codemirror_mode": {
466
+ "name": "ipython",
467
+ "version": 3
468
+ },
469
+ "file_extension": ".py",
470
+ "mimetype": "text/x-python",
471
+ "name": "python",
472
+ "nbconvert_exporter": "python",
473
+ "pygments_lexer": "ipython3",
474
+ "version": "3.9.7"
475
+ }
476
+ },
477
+ "nbformat": 4,
478
+ "nbformat_minor": 0
479
+ }
tools/train/config/medium/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 1408,
7
+ "decoder_attention_heads": 16,
8
+ "decoder_ffn_dim": 4096,
9
+ "decoder_layerdrop": 0.0,
10
+ "decoder_layers": 14,
11
+ "decoder_start_token_id": 16384,
12
+ "dropout": 0.0,
13
+ "encoder_attention_heads": 16,
14
+ "encoder_ffn_dim": 4096,
15
+ "encoder_layerdrop": 0.0,
16
+ "encoder_layers": 14,
17
+ "encoder_vocab_size": 50264,
18
+ "eos_token_id": 16385,
19
+ "gradient_checkpointing": false,
20
+ "image_length": 256,
21
+ "image_vocab_size": 16384,
22
+ "init_std": 0.01,
23
+ "is_encoder_decoder": true,
24
+ "max_text_length": 64,
25
+ "model_type": "dallebart",
26
+ "normalize_text": true,
27
+ "pad_token_id": 16385,
28
+ "scale_embedding": false,
29
+ "tie_word_embeddings": false,
30
+ "use_cache": true
31
+ }
tools/train/config/mega/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 2048,
7
+ "decoder_attention_heads": 32,
8
+ "decoder_ffn_dim": 8192,
9
+ "decoder_layerdrop": 0.0,
10
+ "decoder_layers": 24,
11
+ "decoder_start_token_id": 16384,
12
+ "dropout": 0.0,
13
+ "encoder_attention_heads": 32,
14
+ "encoder_ffn_dim": 8192,
15
+ "encoder_layerdrop": 0.0,
16
+ "encoder_layers": 24,
17
+ "encoder_vocab_size": 50264,
18
+ "eos_token_id": 16385,
19
+ "image_length": 256,
20
+ "image_vocab_size": 16391,
21
+ "init_std": 0.01,
22
+ "is_encoder_decoder": true,
23
+ "max_text_length": 64,
24
+ "model_type": "dallebart",
25
+ "normalize_text": true,
26
+ "pad_token_id": 16385,
27
+ "scale_embedding": false,
28
+ "tie_word_embeddings": false,
29
+ "use_cache": true
30
+ }
tools/train/config/micro/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 256,
7
+ "decoder_attention_heads": 2,
8
+ "decoder_ffn_dim": 256,
9
+ "decoder_layerdrop": 0.0,
10
+ "decoder_layers": 2,
11
+ "decoder_start_token_id": 16384,
12
+ "dropout": 0.0,
13
+ "encoder_attention_heads": 2,
14
+ "encoder_ffn_dim": 256,
15
+ "encoder_layerdrop": 0.0,
16
+ "encoder_layers": 2,
17
+ "encoder_vocab_size": 50264,
18
+ "eos_token_id": 16385,
19
+ "image_length": 256,
20
+ "image_vocab_size": 16391,
21
+ "init_std": 0.02,
22
+ "is_encoder_decoder": true,
23
+ "max_text_length": 64,
24
+ "model_type": "dallebart",
25
+ "normalize_text": true,
26
+ "pad_token_id": 16385,
27
+ "scale_embedding": false,
28
+ "tie_word_embeddings": false,
29
+ "use_cache": true
30
+ }
tools/train/config/mini/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 1024,
7
+ "decoder_attention_heads": 16,
8
+ "decoder_ffn_dim": 4096,
9
+ "decoder_layers": 12,
10
+ "decoder_start_token_id": 16384,
11
+ "dropout": 0.0,
12
+ "encoder_attention_heads": 16,
13
+ "encoder_ffn_dim": 4096,
14
+ "encoder_layers": 12,
15
+ "encoder_vocab_size": 50264,
16
+ "eos_token_id": 16385,
17
+ "gradient_checkpointing": false,
18
+ "image_length": 256,
19
+ "image_vocab_size": 16384,
20
+ "init_std": 0.02,
21
+ "is_encoder_decoder": true,
22
+ "max_text_length": 64,
23
+ "model_type": "dallebart",
24
+ "normalize_text": true,
25
+ "pad_token_id": 16385,
26
+ "scale_embedding": false,
27
+ "tie_word_embeddings": false,
28
+ "use_cache": true
29
+ }
tools/train/config/mini_glu/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "d_model": 1024,
7
+ "decoder_attention_heads": 16,
8
+ "decoder_ffn_dim": 2730,
9
+ "decoder_layers": 12,
10
+ "decoder_start_token_id": 16384,
11
+ "dropout": 0.0,
12
+ "encoder_attention_heads": 16,
13
+ "encoder_ffn_dim": 2730,
14
+ "encoder_layers": 12,
15
+ "encoder_vocab_size": 50264,
16
+ "eos_token_id": 16385,
17
+ "gradient_checkpointing": false,
18
+ "image_length": 256,
19
+ "image_vocab_size": 16384,
20
+ "init_std": 0.02,
21
+ "is_encoder_decoder": true,
22
+ "max_text_length": 64,
23
+ "model_type": "dallebart",
24
+ "normalize_text": true,
25
+ "pad_token_id": 16385,
26
+ "scale_embedding": false,
27
+ "tie_word_embeddings": false,
28
+ "use_cache": true
29
+ }
tools/train/scalable_shampoo/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Notes
2
+
3
+ Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax).
4
+
5
+ Imports have been modified to be relative.
6
+
7
+ This will eventually be replaced with `optax-shampoo` package.
tools/train/scalable_shampoo/distributed_shampoo.py ADDED
@@ -0,0 +1,2267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of distributed Shampoo optimizer from:
17
+ #
18
+ # Scalable Second Order Optimization for Deep Learning
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
20
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
21
+ #
22
+ # This implementation moves computation of inverse pth root back to the
23
+ # accelerator (if higher precision is available).
24
+ #
25
+ # Authors: Rohan Anil (rohananil at google dot com)
26
+ # & Vineet Gupta (vineet at google dot com)
27
+ #
28
+ """Distributed Shampoo Implementation."""
29
+
30
+ import enum
31
+ import functools
32
+ import itertools
33
+ from typing import Any, List, NamedTuple, Tuple
34
+
35
+ import chex
36
+ import jax
37
+ import jax.experimental.pjit as pjit
38
+ import jax.numpy as jnp
39
+ import numpy as np
40
+ import optax
41
+ from flax import struct
42
+ from jax import lax
43
+
44
+ from .quantization_utils import QuantizedValue
45
+ from .symmetric_matrices import symmetric_matrices
46
+
47
+ # Dtype for inverse-pth root routine
48
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
49
+ # jax_enable_x64 for this to work, otherwise it will default to float32.
50
+ _MAT_INV_PTH_ROOT_DTYPE = jnp.float64
51
+
52
+
53
+ @struct.dataclass
54
+ class TrainingMetrics:
55
+ inverse_pth_root_errors: chex.Array # Error for inverse-pth roots.
56
+ # TODO(rohananil): Add more important metrics to track during training.
57
+
58
+
59
+ # Per parameter optimizer state used in data-parallel training.
60
+ class ParameterStats(NamedTuple):
61
+ """State associated to each parameter of the model being trained."""
62
+
63
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
64
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
65
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
66
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
67
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
68
+ training_metrics: TrainingMetrics # Metrics (optional for training).
69
+
70
+
71
+ # For training extremely large model; We keep a global state with a concatenated
72
+ # statistics and preconditioner states for all vars. This is so that we can
73
+ # annotate the leading axis to be sharded to save memory at the cost of
74
+ # communication.
75
+ @struct.dataclass
76
+ class GlobalShardedParameterStats:
77
+ statistics: chex.Array # Statistics
78
+ preconditioners: chex.Array # Preconditioners
79
+ exponents: chex.Array # exponents
80
+
81
+
82
+ # These are per-parameter local states; All statistics here mirror the parameter
83
+ # Thus the sharding is copied over from the param specification.
84
+ @struct.dataclass
85
+ class LocalShardedParameterStats:
86
+ """State associated to each parameter of the model being trained."""
87
+
88
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
89
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
90
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
91
+ training_metrics: TrainingMetrics # Metrics (optional for training).
92
+ index_start: np.int32 = struct.field(
93
+ pytree_node=False
94
+ ) # Index into global statistics array
95
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
96
+
97
+
98
+ def init_training_metrics(num_statistics):
99
+ # Since the downstream apis expect a jnp.array - we create a dummy one if
100
+ # num_statistics=0.
101
+ n = 1 if not num_statistics else num_statistics
102
+ return TrainingMetrics(jnp.zeros([n], jnp.float32))
103
+
104
+
105
+ def init_training_metrics_shapes(num_statistics):
106
+ # Since the downstream apis expect a jnp.array - we create a dummy one if
107
+ # num_statistics=0.
108
+ n = 1 if not num_statistics else num_statistics
109
+ return TrainingMetrics([[n], jnp.float32])
110
+
111
+
112
+ def init_training_metrics_pspec():
113
+ return TrainingMetrics(pjit.PartitionSpec())
114
+
115
+
116
+ class ShardedShampooStats(NamedTuple):
117
+ """Shampoo state in sharded mode."""
118
+
119
+ global_stats: Any
120
+ local_stats: Any
121
+
122
+
123
+ class ShampooState(NamedTuple):
124
+ count: chex.Array
125
+ stats: Any
126
+
127
+
128
+ class InitFnState(NamedTuple):
129
+ init_fn: Any
130
+ pspec_fn: Any
131
+ shape_and_dtype_fn: Any
132
+
133
+
134
+ class GraftingType(enum.IntEnum):
135
+ SGD = 1
136
+ ADAGRAD = 2
137
+ RMSPROP = 3
138
+ RMSPROP_NORMALIZED = 4
139
+ SQRT_N = 5
140
+ ADAGRAD_NORMALIZED = 6
141
+
142
+
143
+ def power_iteration(
144
+ matrix,
145
+ num_iters=100,
146
+ error_tolerance=1e-6,
147
+ precision=lax.Precision.HIGHEST,
148
+ ):
149
+ r"""Power iteration algorithm.
150
+
151
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
152
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
153
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
154
+
155
+ References:
156
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
157
+
158
+ Args:
159
+ matrix: the symmetric PSD matrix.
160
+ num_iters: Number of iterations.
161
+ error_tolerance: Iterative exit condition.
162
+ precision: precision XLA related flag, the available options are: a)
163
+ lax.Precision.DEFAULT (better step time, but not precise) b)
164
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
165
+ (best possible precision, slowest)
166
+
167
+ Returns:
168
+ eigen vector, eigen value
169
+ """
170
+ matrix_size = matrix.shape[-1]
171
+
172
+ def _iter_condition(state):
173
+ i, unused_v, unused_s, unused_s_v, run_step = state
174
+ return jnp.logical_and(i < num_iters, run_step)
175
+
176
+ def _iter_body(state):
177
+ """One step of power iteration."""
178
+ i, new_v, s, s_v, unused_run_step = state
179
+ new_v = new_v / jnp.linalg.norm(new_v)
180
+
181
+ s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
182
+ s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
183
+ return (
184
+ i + 1,
185
+ s_v,
186
+ s_new,
187
+ s_v,
188
+ jnp.greater(jnp.abs(s_new - s), error_tolerance),
189
+ )
190
+
191
+ # Figure out how to use step as seed for random.
192
+ v_0 = (
193
+ np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
194
+ )
195
+
196
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
197
+ _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
198
+ v_out = v_out / jnp.linalg.norm(v_out)
199
+ return v_out, s_out
200
+
201
+
202
+ def mat_power(
203
+ mat_m,
204
+ p,
205
+ precision=lax.Precision.HIGHEST,
206
+ ):
207
+ """A simple matrix power method. M^p where p can be TracedValue."""
208
+ power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
209
+
210
+ def _iter_condition(state):
211
+ i, _, _ = state
212
+ return i > 0
213
+
214
+ def _iter_body(state):
215
+ i, power, mat = state
216
+
217
+ power = jax.lax.cond(
218
+ i % 2 == 1,
219
+ lambda: jnp.matmul(mat, power, precision=precision),
220
+ lambda: power,
221
+ )
222
+ i //= 2
223
+ mat = jnp.matmul(mat, mat, precision=precision)
224
+ return i, power, mat
225
+
226
+ _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
227
+ return result
228
+
229
+
230
+ def matrix_inverse_pth_root(
231
+ matrix,
232
+ p,
233
+ num_iters=100,
234
+ ridge_epsilon=1e-6,
235
+ error_tolerance=1e-6,
236
+ precision=lax.Precision.HIGHEST,
237
+ ):
238
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
239
+
240
+ This function uses the Coupled newton iterations algorithm for
241
+ the computation of a matrix's inverse pth root.
242
+
243
+
244
+ References:
245
+ [Functions of Matrices, Theory and Computation,
246
+ Nicholas J Higham, Pg 184, Eq 7.18](
247
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
248
+
249
+ Args:
250
+ matrix: the symmetric PSD matrix whose power it to be computed
251
+ p: exponent, for p a positive integer.
252
+ num_iters: Maximum number of iterations.
253
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
254
+ error_tolerance: Error indicator, useful for early termination.
255
+ precision: precision XLA related flag, the available options are: a)
256
+ lax.Precision.DEFAULT (better step time, but not precise) b)
257
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
258
+ (best possible precision, slowest)
259
+
260
+ Returns:
261
+ matrix^(-1/p)
262
+ """
263
+
264
+ # If the input is not square, materialize it from the concatenated form.
265
+ if matrix.shape[0] != matrix.shape[1]:
266
+ matrix = symmetric_matrices.materialize_matrix_from_concat(matrix)
267
+
268
+ assert matrix.shape[0] == matrix.shape[1]
269
+
270
+ # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
271
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
272
+ # jax_enable_x64 for this to work.
273
+ matrix_size = matrix.shape[0]
274
+ orig_dtype = matrix.dtype
275
+ matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
276
+ alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
277
+ identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
278
+ _, max_ev = power_iteration(
279
+ matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
280
+ )
281
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
282
+
283
+ def _iter_condition(state):
284
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
285
+ error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
286
+ return jnp.logical_and(i < num_iters, error_above_threshold)
287
+
288
+ def _iter_body(state):
289
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
290
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
291
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
292
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
293
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
294
+ # sometimes error increases after an iteration before decreasing and
295
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
296
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)
297
+
298
+ if matrix_size == 1:
299
+ resultant_mat_h = (matrix + ridge_epsilon) ** alpha
300
+ error = 0
301
+ else:
302
+ damped_matrix = matrix + ridge_epsilon * identity
303
+
304
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
305
+ new_mat_m_0 = damped_matrix * z
306
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
307
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
308
+ init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
309
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
310
+ _iter_condition, _iter_body, init_state
311
+ )
312
+ error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
313
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
314
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
315
+ resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
316
+ return resultant_mat_h, error
317
+
318
+
319
+ def merge_small_dims(shape_to_merge, max_dim):
320
+ """Merge small dimensions.
321
+
322
+ If there are some small dimensions, we collapse them:
323
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
324
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
325
+
326
+ Args:
327
+ shape_to_merge: Shape to merge small dimensions.
328
+ max_dim: Maximal dimension of output shape used in merging.
329
+
330
+ Returns:
331
+ Merged shape.
332
+ """
333
+ if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
334
+ return [1]
335
+
336
+ resulting_shape = []
337
+ product = 1
338
+ for d in shape_to_merge:
339
+ if product * d <= max_dim:
340
+ product *= d
341
+ else:
342
+ if product > 1:
343
+ resulting_shape.append(product)
344
+ product = d
345
+ if product > 1:
346
+ resulting_shape.append(product)
347
+ return resulting_shape
348
+
349
+
350
+ def pad_square_matrix(mat, max_size):
351
+ """Pad a square matrix up to max_size.
352
+
353
+ Args:
354
+ mat: a matrix to pad.
355
+ max_size: matrix size requested.
356
+
357
+ Returns:
358
+ Given M returns [[M, 0], [0, I]]
359
+ """
360
+ rows, cols = mat.shape
361
+ if rows != cols:
362
+ raise ValueError(
363
+ "Must have rows == cols, instead got " f"rows={rows}, cols={cols}"
364
+ )
365
+ if cols > max_size:
366
+ raise ValueError(
367
+ "Must have cols <= max_size. Instead got "
368
+ f"cols={cols}, max_size={max_size}."
369
+ )
370
+ if rows == max_size:
371
+ return mat
372
+ pad_size = max_size - rows
373
+
374
+ zs1 = jnp.zeros([rows, pad_size], dtype=mat.dtype)
375
+ zs2 = jnp.zeros([pad_size, rows], dtype=mat.dtype)
376
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
377
+ mat = jnp.concatenate([mat, zs1], 1)
378
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
379
+ return mat
380
+
381
+
382
+ def make_sliced_padding(
383
+ symmetric_block_size,
384
+ num_blocks,
385
+ starting_block,
386
+ dtype,
387
+ ):
388
+ """Returns padding for symmetric block matrix.
389
+
390
+ Specifically, the padding is given concatenated rectangular matrices
391
+ representing the lower-triangular rows below the starting block. For example,
392
+ if we want to pad the symmetric matrix
393
+
394
+ M = [[A, B^T]
395
+ [B, C]],
396
+
397
+ the desired output (in terms of the full matrix) with num_blocks = 4 is
398
+
399
+ M_padded = [[A, B^T, 0, 0]
400
+ [B, C, 0, 0]
401
+ [0, 0, I, 0]
402
+ 0, 0, 0, I].
403
+
404
+ We would represent M as the block matrix mat = [A, B, C]. In this form, the
405
+ additional padding to provide has form [0, 0, I, 0, 0, 0, I] (only the lower
406
+ triangular parts in the third and fourth rows).
407
+
408
+ Args:
409
+ symmetric_block_size: The size of each block.
410
+ num_blocks: The total number of blocks.
411
+ starting_block: The block where to start the padding.
412
+ dtype: The type to use for the blocks.
413
+ """
414
+ if starting_block == num_blocks:
415
+ return jnp.zeros(shape=(symmetric_block_size, 0), dtype=dtype)
416
+
417
+ blocks = []
418
+ for i in range(starting_block, num_blocks):
419
+ blocks.append(
420
+ jnp.zeros(
421
+ shape=(symmetric_block_size, symmetric_block_size * i), dtype=dtype
422
+ )
423
+ )
424
+ blocks.append(jnp.eye(symmetric_block_size, dtype=dtype))
425
+ return jnp.concatenate(blocks, axis=-1)
426
+
427
+
428
+ def pad_block_symmetric_matrix(
429
+ mat,
430
+ symmetric_block_size,
431
+ max_num_blocks,
432
+ ):
433
+ """Returns the padded blocked symmetric matrix.
434
+
435
+ The size of the padded matrix will be:
436
+ [symmetric_block_size, symmetric_block_size * max_num_blocks]
437
+
438
+ The input matrix can either:
439
+ - Be square with size less or equal to symmetric_block_size. In this case,
440
+ mat will first be padded to a square matrix of size symmetric_block_size,
441
+ and then be padded again up to the full size of the blocked matrix.
442
+ - Be a rectangle with number of rows equal to block size.
443
+ In this case, number of columns must be a multiple of number of rows, and
444
+ the ratio must correspond to a block representation of a symmetric matrix.
445
+ That is, the ratio must have form x * (x + 1) / 2. Here, x represents the
446
+ number of block rows represented by the matrix.
447
+
448
+ Args:
449
+ mat: The input block matrix.
450
+ symmetric_block_size: The size of blocks.
451
+ max_num_blocks: The largest number of blocks to pad to.
452
+ """
453
+ rows, cols = mat.shape
454
+ if rows > symmetric_block_size:
455
+ raise ValueError(
456
+ "Must have rows <= symmetric_block_size. Instead got "
457
+ f"rows={rows}, symmetric_block_size={symmetric_block_size}."
458
+ )
459
+ if rows > cols:
460
+ raise ValueError(
461
+ "Must have rows <= cols, instead got " f"rows={rows}, cols={cols}."
462
+ )
463
+ if cols > symmetric_block_size * max_num_blocks:
464
+ raise ValueError(
465
+ "Must have cols <= symmetric_block_size * max_num_blocks "
466
+ f"Instead got cols={cols}, "
467
+ f"symmetric_block_size={symmetric_block_size}, "
468
+ f"max_num_blocks={max_num_blocks}."
469
+ )
470
+ if rows < symmetric_block_size:
471
+ mat = pad_square_matrix(mat, max_size=symmetric_block_size)
472
+ # Update rows and cols after possibly padding in pad_square_matrix.
473
+ rows, cols = mat.shape
474
+ assert rows == symmetric_block_size
475
+ assert cols % rows == 0
476
+ filled_blocks = cols // rows
477
+ padding_blocks = make_sliced_padding(
478
+ symmetric_block_size=symmetric_block_size,
479
+ num_blocks=symmetric_matrices.num_blocks_from_total_blocks(max_num_blocks),
480
+ starting_block=symmetric_matrices.num_blocks_from_total_blocks(filled_blocks),
481
+ dtype=mat.dtype,
482
+ )
483
+ return jnp.concatenate([mat, padding_blocks], axis=-1)
484
+
485
+
486
+ def pad_vector(vec, max_size):
487
+ """Pad a vector to a max_size.
488
+
489
+ Args:
490
+ vec: a vector to pad.
491
+ max_size: matrix size requested.
492
+
493
+ Returns:
494
+ Given V returns [V, 0]
495
+ """
496
+ size = vec.shape[0]
497
+ assert size <= max_size
498
+ if size == max_size:
499
+ return vec
500
+ pad_size = max_size - size
501
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
502
+ return jnp.concatenate([vec, zs1], 0)
503
+
504
+
505
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
506
+ """Avoids wasteful buffer allocation with XLA."""
507
+
508
+ def _iter_body(unused_state):
509
+ results = compute_fn(*args, **kwargs)
510
+ return tuple([False] + list(results))
511
+
512
+ def _iter_condition(state):
513
+ return state[0]
514
+
515
+ results = jax.lax.while_loop(
516
+ _iter_condition, _iter_body, tuple([predicate] + init_state)
517
+ )
518
+ return tuple(results[1:])
519
+
520
+
521
+ class BlockPartitioner:
522
+ """Partitions a tensor into smaller tensors."""
523
+
524
+ def __init__(self, param, block_size):
525
+ self._shape = param.shape
526
+ self._splits = []
527
+ split_sizes = []
528
+ # We split params into smaller blocks. Here we store the metadata to make
529
+ # that split.
530
+ for i, d in enumerate(param.shape):
531
+ if 0 < block_size < d:
532
+ # d-1, otherwise split appends a 0-size array.
533
+ nsplit = (d - 1) // block_size
534
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
535
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
536
+ sizes[-1] = d - indices[-1]
537
+ self._splits.append((i, indices))
538
+ split_sizes.append(sizes)
539
+ else:
540
+ split_sizes.append(np.array([d], dtype=np.int32))
541
+ self._num_splits = len(split_sizes)
542
+ self._preconditioner_shapes = []
543
+ for t in itertools.product(*split_sizes):
544
+ self._preconditioner_shapes.extend([[d, d] for d in t])
545
+
546
+ def shapes_for_preconditioners(self):
547
+ return self._preconditioner_shapes
548
+
549
+ def num_splits(self):
550
+ return self._num_splits
551
+
552
+ def partition(self, tensor):
553
+ """Partition tensor into blocks."""
554
+
555
+ assert tensor.shape == self._shape
556
+ tensors = [tensor]
557
+ for (i, indices) in self._splits:
558
+ tensors_local = []
559
+ for t in tensors:
560
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
561
+ tensors = tensors_local
562
+ return tensors
563
+
564
+ def merge_partitions(self, partitions):
565
+ """Merge partitions back to original shape."""
566
+
567
+ for (i, indices) in reversed(self._splits):
568
+ n = len(indices) + 1
569
+ partial_merged_tensors = []
570
+ ind = 0
571
+ while ind < len(partitions):
572
+ partial_merged_tensors.append(
573
+ jnp.concatenate(partitions[ind : ind + n], axis=i)
574
+ )
575
+ ind += n
576
+ partitions = partial_merged_tensors
577
+ assert len(partitions) == 1
578
+ return partitions[0]
579
+
580
+
581
+ class Preconditioner:
582
+ """Compute statistics/shape from gradients for preconditioning."""
583
+
584
+ def __init__(self, param, block_size, best_effort_shape_interpretation):
585
+ self._original_shape = param.shape
586
+ self._transformed_shape = param.shape
587
+ if best_effort_shape_interpretation:
588
+ self._transformed_shape = merge_small_dims(self._original_shape, block_size)
589
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
590
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
591
+
592
+ def statistics_from_grad(self, grad):
593
+ """Compute statistics from gradients.
594
+
595
+ Args:
596
+ grad: Gradient to compute statistics from.
597
+
598
+ Returns:
599
+ A list of gradient statistics for each partition.
600
+ """
601
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
602
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
603
+ stats = []
604
+ for g in partitioned_grads:
605
+ g_stats = []
606
+ rank = len(g.shape)
607
+ for i in range(rank):
608
+ axes = list(range(i)) + list(range(i + 1, rank))
609
+ stat = jnp.tensordot(g, g, axes=(axes, axes))
610
+ g_stats.append(stat)
611
+ stats.extend(g_stats)
612
+ return stats
613
+
614
+ def shapes_for_preconditioners(self):
615
+ """Returns shape from statistics."""
616
+ return self._partitioner.shapes_for_preconditioners()
617
+
618
+ def exponent_for_preconditioner(self):
619
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
620
+ return 2 * len(self._transformed_shape)
621
+
622
+ def preconditioned_grad(self, grad, preconditioners):
623
+ """Precondition the gradient.
624
+
625
+ Args:
626
+ grad: A gradient tensor to precondition.
627
+ preconditioners: A list of preconditioners to apply.
628
+
629
+ Returns:
630
+ A preconditioned gradient.
631
+ """
632
+
633
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
634
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
635
+ preconditioned_partitioned_grads = []
636
+ num_splits = self._partitioner.num_splits()
637
+ for i, g in enumerate(partitioned_grads):
638
+ preconditioners_for_grad = preconditioners[
639
+ i * num_splits : (i + 1) * num_splits
640
+ ]
641
+ rank = len(g.shape)
642
+ precond_g = g
643
+ for j in range(rank):
644
+ precond_g = jnp.tensordot(
645
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]]
646
+ )
647
+ preconditioned_partitioned_grads.append(precond_g)
648
+ merged_grad = self._partitioner.merge_partitions(
649
+ preconditioned_partitioned_grads
650
+ )
651
+ return jnp.reshape(merged_grad, self._original_shape)
652
+
653
+
654
+ def _convert_to_parameter_stats(global_stats, local_stat):
655
+ """Creates parameter stats from sharded stats."""
656
+ index_start = int(local_stat.index_start)
657
+ index_end = int(len(local_stat.sizes)) + index_start
658
+ statistics = global_stats.statistics[index_start:index_end, :, :]
659
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
660
+ new_statistics = []
661
+ new_preconditioners = []
662
+ for i, size in enumerate(local_stat.sizes):
663
+ new_statistics.append(statistics[i][:size, :size])
664
+ new_preconditioners.append(preconditioners[i][:size, :size])
665
+ return ParameterStats(
666
+ local_stat.diagonal_statistics,
667
+ new_statistics,
668
+ new_preconditioners,
669
+ local_stat.diagonal_momentum,
670
+ local_stat.momentum,
671
+ local_stat.training_metrics,
672
+ )
673
+
674
+
675
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
676
+ """Creates sharded stats from paramter stats."""
677
+ return LocalShardedParameterStats(
678
+ parameter_stats.diagonal_statistics,
679
+ parameter_stats.diagonal_momentum,
680
+ parameter_stats.momentum,
681
+ parameter_stats.training_metrics,
682
+ local_stats.index_start,
683
+ local_stats.sizes,
684
+ )
685
+
686
+
687
+ def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
688
+ """Adds errors back into local statistics."""
689
+ new_local_stats = []
690
+ for local_stat in local_stats:
691
+ index_start = int(local_stat.index_start)
692
+ index_end = int(len(local_stat.sizes)) + index_start
693
+ per_stat_error = errors[index_start:index_end]
694
+ if local_stat.sizes:
695
+ per_stat_error = jnp.where(
696
+ jnp.logical_and(
697
+ per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
698
+ ),
699
+ per_stat_error,
700
+ local_stat.training_metrics.inverse_pth_root_errors,
701
+ )
702
+ new_local_stats.append(
703
+ LocalShardedParameterStats(
704
+ local_stat.diagonal_statistics,
705
+ local_stat.diagonal_momentum,
706
+ local_stat.momentum,
707
+ TrainingMetrics(per_stat_error),
708
+ local_stat.index_start,
709
+ local_stat.sizes,
710
+ )
711
+ )
712
+ return new_local_stats
713
+
714
+
715
+ def batch(x, num_devices):
716
+ """Batch `x` so that so that leading axis is num_devices."""
717
+ n = len(x)
718
+ b = int(n / num_devices)
719
+ return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
720
+
721
+
722
+ def unbatch(batched_values):
723
+ """Unbatch values across leading axis and return a list of elements."""
724
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
725
+ results = []
726
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
727
+ v_array = jnp.squeeze(v_array)
728
+ # b2 = batches (number of preconditioner computation) per core.
729
+ if b2 > 1:
730
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
731
+ results.append(jnp.squeeze(v))
732
+ else:
733
+ results.append(v_array)
734
+ return results
735
+
736
+
737
+ def distributed_shampoo(
738
+ learning_rate,
739
+ block_size,
740
+ beta1=0.9,
741
+ beta2=0.999,
742
+ diagonal_epsilon=1e-10,
743
+ matrix_epsilon=1e-6,
744
+ weight_decay=0.0,
745
+ start_preconditioning_step=5,
746
+ preconditioning_compute_steps=1,
747
+ statistics_compute_steps=1,
748
+ best_effort_shape_interpretation=True,
749
+ graft_type=GraftingType.SGD,
750
+ nesterov=True,
751
+ exponent_override=0,
752
+ # Pass pmap 'batch axis name' in pmap mode.
753
+ batch_axis_name=None,
754
+ ### Only set following 3 params in pjit/spmd mode.
755
+ ### WARNING: Experimental
756
+ statistics_partition_spec=None,
757
+ preconditioner_partition_spec=None,
758
+ num_devices_for_pjit=None,
759
+ shard_optimizer_states=False,
760
+ ###
761
+ ### Experimental memory reduction mode
762
+ best_effort_memory_usage_reduction=False,
763
+ ###
764
+ inverse_failure_threshold=0.1,
765
+ moving_average_for_momentum=False,
766
+ skip_preconditioning_dim_size_gt=4096,
767
+ clip_by_scaled_gradient_norm=None,
768
+ precision=lax.Precision.HIGHEST,
769
+ ):
770
+ """Distributed Shampoo optimizer.
771
+
772
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
773
+ variant of full-matrix Adagrad), that provides significant convergence and
774
+ wall-clock time improvements compared to conventional first-order methods,
775
+ and that has been shown to scale to large state-of-the-art deep learning
776
+ models.
777
+
778
+ References:
779
+ Scalable Second Order Optimization for Deep Learning,
780
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
781
+
782
+ Preprint: https://arxiv.org/abs/2002.09018
783
+
784
+ Args:
785
+ learning_rate: the step size used to update the parameters.
786
+ block_size: Block size for large layers (if > 0). Preconditioning compute
787
+ operation is cubic in the dimension of the tensor. Block size allows us to
788
+ chunk the layers into sub-layers of maximal dimension dictated by this
789
+ value. Use 128 as default (increase if you have compute budget).
790
+ beta1: momentum parameter.
791
+ beta2: second moment averaging parameter.
792
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
793
+ to AdaGrad is enabled).
794
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
795
+ root. If you are running in f32 precision for inverse pth root
796
+ (recommended today) this can go upto 1e-6. If you have latest hardware
797
+ with native f64 precision, set this upto 1e-12.
798
+ weight_decay: Weight decay for regularization.
799
+ start_preconditioning_step: When to start Shampoo update before which
800
+ diagonal update is used. This is because we dont have enough information
801
+ to do stable inverse.
802
+ preconditioning_compute_steps: How often to compute preconditioner.
803
+ Performance tuning params for controlling memory and compute requirements.
804
+ Ideally set this and statistics_compute_steps params to 1.
805
+ statistics_compute_steps: How often to compute statistics.
806
+ best_effort_shape_interpretation: If there are some small dimensions,
807
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
808
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
809
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
810
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
811
+ where SGD/AdaGrad is already well tuned.
812
+ nesterov: Nesterov momentum.
813
+ exponent_override: Override the exponent used in matrix inverse.
814
+ batch_axis_name: labeled axis over pmap for data-parallel training the
815
+ optimizer used for.
816
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
817
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
818
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
819
+ shard_optimizer_states: Shard optimizer states to save memory in model
820
+ parallel training.
821
+ best_effort_memory_usage_reduction: Best effort memory usage reduction. -
822
+ diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x) -> jnp.int8 -
823
+ statistics, preconditioners -> jnp.int16 + diagonals
824
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
825
+ determine that using this threshold.
826
+ moving_average_for_momentum: Whether to use moving average for momentum
827
+ instead of exponential moving average.
828
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
829
+ greater than this value.
830
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful when
831
+ using RMSProp Grafting).
832
+ precision: precision XLA related flag, the available options are: a)
833
+ lax.Precision.DEFAULT (better step time, but not precise) b)
834
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
835
+ (best possible precision, slowest)
836
+
837
+ Returns:
838
+ a GradientTransformation.
839
+ """
840
+
841
+ def _graft_type_has_diagonal_statistics():
842
+ """Returns True if using diagonal firt order method for grafting."""
843
+ return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N
844
+
845
+ def _graft_type_has_diagonal_momentum_states():
846
+ """Returns False if using SQRT_N for grafting."""
847
+ return graft_type != GraftingType.SQRT_N
848
+
849
+ def quantized_dtype_for_momentum_buffers():
850
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
851
+
852
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
853
+ def quantized_dtype_for_diagonal_statistics_buffers():
854
+ return jnp.float32
855
+
856
+ # Preconditioner and statistics are both stores as int16 in this mode.
857
+ # We take out the diagonal to make quantization easier.
858
+ def quantized_dtype_for_second_moment_statistics_buffers():
859
+ return (
860
+ jnp.int16
861
+ if best_effort_memory_usage_reduction and batch_axis_name
862
+ else jnp.float32
863
+ )
864
+
865
+ # Preconditioner and statistics are both stores as int16 in this mode.
866
+ # We take out the diagonal to make quantization easier.
867
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
868
+ return (
869
+ jnp.int16
870
+ if best_effort_memory_usage_reduction and batch_axis_name
871
+ else jnp.float32
872
+ )
873
+
874
+ def _to_float(maybe_quantized):
875
+ if isinstance(maybe_quantized, QuantizedValue):
876
+ return maybe_quantized.to_float()
877
+ else:
878
+ return maybe_quantized
879
+
880
+ def _maybe_quantize_statistics(statistics_list):
881
+ return _maybe_quantize_matrices_with_dtype(
882
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers()
883
+ )
884
+
885
+ def _maybe_quantize_preconditioners(statistics_list):
886
+ return _maybe_quantize_matrices_with_dtype(
887
+ statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers()
888
+ )
889
+
890
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
891
+ if quantized_dtype != jnp.float32:
892
+ return [
893
+ QuantizedValue.from_float_value(
894
+ s, quantized_dtype, extract_diagonal=True
895
+ )
896
+ for s in statistics_list
897
+ ]
898
+ else:
899
+ return statistics_list
900
+
901
+ def _maybe_dequantize_preconditioners(preconditioner_list):
902
+ return _maybe_dequantize_matrices_with_dtype(
903
+ preconditioner_list,
904
+ quantized_dtype_for_second_moment_preconditioner_buffers(),
905
+ )
906
+
907
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
908
+ if quantized_dtype != jnp.float32:
909
+ return [s.to_float() for s in statistics_list]
910
+ else:
911
+ return statistics_list
912
+
913
+ def _quantize_diagonal_statistics(diagonal_statistics):
914
+ return QuantizedValue.from_float_value(
915
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers()
916
+ )
917
+
918
+ def _quantize_momentum(momentum_statistics):
919
+ return QuantizedValue.from_float_value(
920
+ momentum_statistics, quantized_dtype_for_momentum_buffers()
921
+ )
922
+
923
+ def sharded_init_fn(params):
924
+ """Returns optimizer state (for PJIT mode).
925
+
926
+ Args:
927
+ params: the parameters that should be updated.
928
+ """
929
+ params_flat, treedef = jax.tree_flatten(params)
930
+ # Find max size to pad to.
931
+ max_size = 0
932
+ for param in params_flat:
933
+ preconditioner = Preconditioner(
934
+ param, block_size, best_effort_shape_interpretation
935
+ )
936
+ if not _skip_preconditioning(param):
937
+ shapes = preconditioner.shapes_for_preconditioners()
938
+ sizes = [s[0] for s in shapes]
939
+ max_size = max(max(sizes), max_size)
940
+
941
+ padded_statistics = []
942
+ padded_preconditioners = []
943
+ local_stats_flat = []
944
+ exponents = []
945
+ for param in params_flat:
946
+ preconditioner = Preconditioner(
947
+ param, block_size, best_effort_shape_interpretation
948
+ )
949
+ shapes = preconditioner.shapes_for_preconditioners()
950
+ sizes = []
951
+
952
+ statistics = []
953
+ preconditioners = []
954
+ index_start = len(padded_statistics)
955
+ if not _skip_preconditioning(param):
956
+ sizes = [s[0] for s in shapes]
957
+ shapes = preconditioner.shapes_for_preconditioners()
958
+ statistics = [
959
+ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
960
+ for s in shapes
961
+ ]
962
+ preconditioners = [jnp.eye(max_size, dtype=jnp.float32) for s in shapes]
963
+ padded_statistics.extend(statistics)
964
+ padded_preconditioners.extend(preconditioners)
965
+ exponent = (
966
+ preconditioner.exponent_for_preconditioner()
967
+ if exponent_override == 0
968
+ else exponent_override
969
+ )
970
+ exponents.extend([exponent] * len(shapes))
971
+
972
+ diagonal_statistics = []
973
+ if _graft_type_has_diagonal_statistics():
974
+ diagonal_statistics = jnp.zeros_like(param)
975
+
976
+ diagonal_momentum = _quantize_momentum([])
977
+ momentum = _quantize_momentum(jnp.zeros_like(param))
978
+ if _graft_type_has_diagonal_momentum_states():
979
+ diagonal_momentum = _quantize_momentum((jnp.zeros_like(param)))
980
+
981
+ local_stats_flat.append(
982
+ LocalShardedParameterStats(
983
+ _quantize_diagonal_statistics(diagonal_statistics),
984
+ diagonal_momentum,
985
+ momentum,
986
+ init_training_metrics(len(sizes)),
987
+ index_start,
988
+ sizes,
989
+ )
990
+ )
991
+
992
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
993
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
994
+ if max_size == 0:
995
+ to_pad = num_devices_for_pjit
996
+ max_size = block_size
997
+ stat_dtype = jnp.float32
998
+ else:
999
+ stat_dtype = padded_statistics[0].dtype
1000
+ # Pad the statistics and preconditioner matrices to be a multiple of
1001
+ # num devices.
1002
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1003
+ # is split on.
1004
+ padded_statistics.extend(
1005
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1006
+ )
1007
+ padded_preconditioners.extend(
1008
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
1009
+ )
1010
+ exponents.extend([1 for _ in range(to_pad)])
1011
+ global_stats = GlobalShardedParameterStats(
1012
+ jnp.stack(padded_statistics),
1013
+ jnp.stack(padded_preconditioners),
1014
+ jnp.stack(exponents),
1015
+ )
1016
+ return ShampooState(
1017
+ count=jnp.zeros([], jnp.int32),
1018
+ stats=ShardedShampooStats(global_stats, local_stats),
1019
+ )
1020
+
1021
+ def _max_statistics_size_from_params(params):
1022
+ max_size = 0
1023
+ for param in params:
1024
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1025
+ preconditioner = Preconditioner(
1026
+ param_clone, block_size, best_effort_shape_interpretation
1027
+ )
1028
+ if not _skip_preconditioning(param):
1029
+ shapes = preconditioner.shapes_for_preconditioners()
1030
+ sizes = [s[0] for s in shapes]
1031
+ max_size = max(max(sizes), max_size)
1032
+ return max_size
1033
+
1034
+ def _remove_leading_sharding_annotation(pspec):
1035
+ """Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
1036
+ # None and PSpec(None) are valid PSpecs.
1037
+ if pspec and len(pspec) > 1:
1038
+ return pjit.PartitionSpec(*pspec[1:])
1039
+ else:
1040
+ return []
1041
+
1042
+ def sharded_init_partition_spec_fn(
1043
+ params, params_partition_spec, partition_spec_for_statistics
1044
+ ):
1045
+ """Returns a parallel state tree with PartitionSpec associated with state.
1046
+
1047
+
1048
+ Args:
1049
+ params: A pytree with params.
1050
+ params_partition_spec: A pytree with PartitionSpec for params.
1051
+ partition_spec_for_statistics: PartitionSpec for the statistics.
1052
+ """
1053
+ # Parallel lists of spec, and params.
1054
+ param_pspec_flat, _ = jax.tree_flatten(
1055
+ params_partition_spec, is_leaf=lambda x: x is None
1056
+ )
1057
+ params_flat, treedef = jax.tree_flatten(params)
1058
+ assert param_pspec_flat
1059
+ assert params_flat
1060
+ # Step is replicated across cores.
1061
+ # None means cores.
1062
+ local_stats_flat = []
1063
+ num_statistics = 0
1064
+ for param, param_pspec in zip(params_flat, param_pspec_flat):
1065
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1066
+ preconditioner = Preconditioner(
1067
+ param_clone, block_size, best_effort_shape_interpretation
1068
+ )
1069
+ shapes = preconditioner.shapes_for_preconditioners()
1070
+ sizes = []
1071
+
1072
+ index_start = num_statistics
1073
+ if not _skip_preconditioning(param):
1074
+ sizes = [s[0] for s in shapes]
1075
+ shapes = preconditioner.shapes_for_preconditioners()
1076
+ num_statistics += len(shapes)
1077
+
1078
+ diagonal_statistics_pspec = []
1079
+ diagonal_statistics_scale_pspec = []
1080
+ if _graft_type_has_diagonal_statistics():
1081
+ # Identically shaped param.
1082
+ diagonal_statistics_pspec = param_pspec
1083
+ if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
1084
+ diagonal_statistics_scale_pspec = (
1085
+ _remove_leading_sharding_annotation(param_pspec)
1086
+ )
1087
+
1088
+ m1_pspec = []
1089
+ m1_scale_pspec = []
1090
+ if _graft_type_has_diagonal_momentum_states():
1091
+ m1_pspec = param_pspec
1092
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1093
+ m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
1094
+
1095
+ m2_pspec = param_pspec
1096
+ m2_scale_pspec = []
1097
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1098
+ m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
1099
+
1100
+ local_stats_flat.append(
1101
+ LocalShardedParameterStats(
1102
+ QuantizedValue(
1103
+ diagonal_statistics_pspec,
1104
+ [],
1105
+ diagonal_statistics_scale_pspec,
1106
+ quantized_dtype_for_diagonal_statistics_buffers(),
1107
+ False,
1108
+ list(param.shape),
1109
+ ),
1110
+ QuantizedValue(
1111
+ m1_pspec,
1112
+ [],
1113
+ m1_scale_pspec,
1114
+ quantized_dtype_for_momentum_buffers(),
1115
+ False,
1116
+ list(param.shape),
1117
+ ),
1118
+ QuantizedValue(
1119
+ m2_pspec,
1120
+ [],
1121
+ m2_scale_pspec,
1122
+ quantized_dtype_for_momentum_buffers(),
1123
+ False,
1124
+ list(param.shape),
1125
+ ),
1126
+ init_training_metrics_pspec(),
1127
+ index_start,
1128
+ sizes,
1129
+ )
1130
+ )
1131
+
1132
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1133
+ global_stats = GlobalShardedParameterStats(
1134
+ partition_spec_for_statistics,
1135
+ partition_spec_for_statistics,
1136
+ pjit.PartitionSpec(),
1137
+ )
1138
+ count_pspec = pjit.PartitionSpec()
1139
+ return ShampooState(
1140
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
1141
+ )
1142
+
1143
+ def sharded_init_shape_and_dtype_fn(params):
1144
+ """Returns a parallel state tree with shape, dtype associated with state.
1145
+
1146
+
1147
+ Args:
1148
+ params: A pytree with params.
1149
+ """
1150
+ # Parallel lists of spec, and params.
1151
+ params_flat, treedef = jax.tree_flatten(params)
1152
+ assert params_flat
1153
+ # Step is replicated across cores.
1154
+ # None means cores.
1155
+ local_stats_flat = []
1156
+ num_statistics = 0
1157
+ for param in params_flat:
1158
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1159
+ preconditioner = Preconditioner(
1160
+ param_clone, block_size, best_effort_shape_interpretation
1161
+ )
1162
+ shapes = preconditioner.shapes_for_preconditioners()
1163
+ sizes = []
1164
+
1165
+ index_start = num_statistics
1166
+ if not _skip_preconditioning(param):
1167
+ sizes = [s[0] for s in shapes]
1168
+ shapes = preconditioner.shapes_for_preconditioners()
1169
+ num_statistics += len(shapes)
1170
+
1171
+ diagonal_statistics_shape_and_dtype = []
1172
+ diagonal_statistics_scale_shape_and_dtype = []
1173
+ if _graft_type_has_diagonal_statistics():
1174
+ diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
1175
+ qdtype = quantized_dtype_for_diagonal_statistics_buffers()
1176
+ if qdtype != jnp.float32:
1177
+ diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
1178
+ diagonal_statistics_scale_shape_and_dtype = [
1179
+ list(param.shape)[1:],
1180
+ param.dtype,
1181
+ ]
1182
+
1183
+ qdtype = quantized_dtype_for_momentum_buffers()
1184
+ m1_shape_and_dtype = []
1185
+ m1_scale_shape_and_dtype = []
1186
+ if _graft_type_has_diagonal_momentum_states():
1187
+ m1_shape_and_dtype = [list(param.shape), qdtype]
1188
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1189
+ m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1190
+
1191
+ m2_shape_and_dtype = [list(param.shape), param.dtype]
1192
+ m2_scale_shape_and_dtype = []
1193
+ if qdtype != jnp.float32:
1194
+ m2_shape_and_dtype = [list(param.shape), qdtype]
1195
+ m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1196
+
1197
+ local_stats_flat.append(
1198
+ LocalShardedParameterStats(
1199
+ QuantizedValue(
1200
+ diagonal_statistics_shape_and_dtype,
1201
+ [],
1202
+ diagonal_statistics_scale_shape_and_dtype,
1203
+ quantized_dtype_for_diagonal_statistics_buffers(),
1204
+ False,
1205
+ list(param.shape),
1206
+ ),
1207
+ QuantizedValue(
1208
+ m1_shape_and_dtype,
1209
+ [],
1210
+ m1_scale_shape_and_dtype,
1211
+ quantized_dtype_for_momentum_buffers(),
1212
+ False,
1213
+ list(param.shape),
1214
+ ),
1215
+ QuantizedValue(
1216
+ m2_shape_and_dtype,
1217
+ [],
1218
+ m2_scale_shape_and_dtype,
1219
+ quantized_dtype_for_momentum_buffers(),
1220
+ False,
1221
+ list(param.shape),
1222
+ ),
1223
+ init_training_metrics_shapes(len(sizes)),
1224
+ index_start,
1225
+ sizes,
1226
+ )
1227
+ )
1228
+
1229
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1230
+ max_statistics_size = _max_statistics_size_from_params(params_flat)
1231
+ to_pad = -num_statistics % num_devices_for_pjit
1232
+ num_statistics += to_pad
1233
+ if num_statistics == 0:
1234
+ num_statistics = num_devices_for_pjit
1235
+ max_statistics_size = block_size
1236
+ statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
1237
+ global_stats = GlobalShardedParameterStats(
1238
+ [statistics_shape, jnp.float32],
1239
+ [statistics_shape, jnp.float32],
1240
+ [[num_statistics], jnp.int32],
1241
+ )
1242
+ return ShampooState(
1243
+ count=[[], jnp.float32],
1244
+ stats=ShardedShampooStats(global_stats, local_stats),
1245
+ )
1246
+
1247
+ def sharded_update_fn(grads, state, params):
1248
+ """Transform the input gradient and update all statistics in sharded mode.
1249
+
1250
+ Args:
1251
+ grads: the gradient tensors for the parameters.
1252
+ state: a named tuple containing the state of the optimizer
1253
+ params: the parameters that should be updated.
1254
+
1255
+ Returns:
1256
+ A tuple containing the new parameters and the new optimizer state.
1257
+ """
1258
+ params_flat, treedef = jax.tree_flatten(params)
1259
+ grads_flat = treedef.flatten_up_to(grads)
1260
+
1261
+ global_stats = state.stats.global_stats
1262
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
1263
+ stats_flat = [
1264
+ _convert_to_parameter_stats(global_stats, local_stat)
1265
+ for local_stat in local_stats_flat
1266
+ ]
1267
+ new_stats_flat = jax.tree_multimap(
1268
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
1269
+ grads_flat,
1270
+ stats_flat,
1271
+ params_flat,
1272
+ )
1273
+
1274
+ outputs = jax.tree_multimap(
1275
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
1276
+ grads_flat,
1277
+ new_stats_flat,
1278
+ params_flat,
1279
+ )
1280
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1281
+
1282
+ updates = jax.tree_unflatten(treedef, updates_flat)
1283
+ # Create new local_stats
1284
+ new_local_stats_flat = [
1285
+ _convert_from_parameter_stats(new_stat, local_stat)
1286
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
1287
+ ]
1288
+
1289
+ max_size = global_stats.statistics.shape[1]
1290
+ new_padded_statistics = []
1291
+ for stat in new_stats_flat:
1292
+ new_padded_statistics.extend(
1293
+ [pad_square_matrix(stat, max_size) for stat in stat.statistics]
1294
+ )
1295
+
1296
+ # Create global stats
1297
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
1298
+ # stack/pad can be obviated away.
1299
+ # Pad the statistics and preconditioner matrices to be a multiple of
1300
+ # num devices.
1301
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1302
+ # is split on.
1303
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
1304
+ new_padded_statistics.extend(
1305
+ [
1306
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
1307
+ for _ in range(to_pad)
1308
+ ]
1309
+ )
1310
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
1311
+ new_stacked_padded_statistics = pjit.with_sharding_constraint(
1312
+ new_stacked_padded_statistics, statistics_partition_spec
1313
+ )
1314
+
1315
+ def _internal_inverse_pth_root_all():
1316
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1317
+ new_stacked_padded_statistics,
1318
+ global_stats.exponents,
1319
+ statistics_partition_spec,
1320
+ )
1321
+ return preconditioners, errors
1322
+
1323
+ if preconditioning_compute_steps == 1:
1324
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
1325
+ else:
1326
+ # Passing statistics instead of preconditioners as they are similarly
1327
+ # shaped tensors. Note statistics will be ignored as we are passing in
1328
+ # a large init value for error.
1329
+ preconditioners_init = new_stacked_padded_statistics
1330
+ n = new_stacked_padded_statistics.shape[0]
1331
+ errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
1332
+ init_state = [preconditioners_init, errors_init]
1333
+ perform_step = state.count % preconditioning_compute_steps == 0
1334
+ new_preconditioners, errors = efficient_cond(
1335
+ perform_step, _internal_inverse_pth_root_all, init_state
1336
+ )
1337
+
1338
+ new_local_stats_flat = _add_error_into_local_stats(
1339
+ new_local_stats_flat, errors, inverse_failure_threshold
1340
+ )
1341
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
1342
+ errors = errors.reshape((-1, 1, 1))
1343
+ predicate = jnp.logical_or(
1344
+ jnp.isnan(errors), errors >= inverse_failure_threshold
1345
+ ).astype(new_preconditioners.dtype)
1346
+ # TODO(rohananil): Check for numerical instabilities.
1347
+ new_conditional_preconditioners = (
1348
+ predicate * global_stats.preconditioners
1349
+ + (1.0 - predicate) * new_preconditioners
1350
+ )
1351
+ new_global_stats = GlobalShardedParameterStats(
1352
+ new_stacked_padded_statistics,
1353
+ new_conditional_preconditioners,
1354
+ global_stats.exponents,
1355
+ )
1356
+ new_shampoo_state = ShampooState(
1357
+ count=state.count + 1,
1358
+ stats=ShardedShampooStats(new_global_stats, new_local_stats),
1359
+ )
1360
+ return updates, new_shampoo_state
1361
+
1362
+ def init_fn(params):
1363
+ """Initialise the optimiser's state."""
1364
+
1365
+ def _init(param):
1366
+ preconditioner = Preconditioner(
1367
+ param, block_size, best_effort_shape_interpretation
1368
+ )
1369
+ statistics = []
1370
+ preconditioners = []
1371
+ if not _skip_preconditioning(param):
1372
+ shapes = preconditioner.shapes_for_preconditioners()
1373
+ statistics = [
1374
+ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
1375
+ ]
1376
+ preconditioners = [jnp.eye(s[0], dtype=jnp.float32) for s in shapes]
1377
+
1378
+ diagonal_statistics = []
1379
+ if _graft_type_has_diagonal_statistics():
1380
+ diagonal_statistics = jnp.zeros_like(param)
1381
+
1382
+ diagonal_momentum = _quantize_momentum([])
1383
+ momentum = _quantize_momentum(jnp.zeros_like(param))
1384
+ if _graft_type_has_diagonal_momentum_states():
1385
+ diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))
1386
+
1387
+ return ParameterStats(
1388
+ _quantize_diagonal_statistics(diagonal_statistics),
1389
+ _maybe_quantize_statistics(statistics),
1390
+ _maybe_quantize_preconditioners(preconditioners),
1391
+ diagonal_momentum,
1392
+ momentum,
1393
+ init_training_metrics(len(statistics)),
1394
+ )
1395
+
1396
+ return ShampooState(
1397
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
1398
+ )
1399
+
1400
+ def _skip_preconditioning(param):
1401
+ return len(param.shape) < 1 or any(
1402
+ [s > skip_preconditioning_dim_size_gt for s in param.shape]
1403
+ )
1404
+
1405
+ def _compute_stats(grad, state, param, step):
1406
+ """Compute per-parameter statistics."""
1407
+ preconditioner = Preconditioner(
1408
+ param, block_size, best_effort_shape_interpretation
1409
+ )
1410
+ new_statistics = [[]] * len(state.statistics)
1411
+ w1 = beta2
1412
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1413
+ if not _skip_preconditioning(param):
1414
+
1415
+ def compute_updated_statistics():
1416
+ new_stats = preconditioner.statistics_from_grad(grad)
1417
+ new_stats_accumulators = []
1418
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
1419
+ new_stats_accumulators.append(
1420
+ w1 * _to_float(stat_accumulator) + w2 * stat
1421
+ )
1422
+ return _maybe_quantize_statistics(new_stats_accumulators)
1423
+
1424
+ if statistics_compute_steps > 1:
1425
+ perform_step = step % statistics_compute_steps == 0
1426
+ init_state = state.statistics
1427
+ new_statistics = list(
1428
+ efficient_cond(perform_step, compute_updated_statistics, init_state)
1429
+ )
1430
+ else:
1431
+ new_statistics = compute_updated_statistics()
1432
+ return ParameterStats(
1433
+ state.diagonal_statistics,
1434
+ new_statistics,
1435
+ state.preconditioners,
1436
+ state.diagonal_momentum,
1437
+ state.momentum,
1438
+ state.training_metrics,
1439
+ )
1440
+
1441
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1442
+ mi_pth_root = functools.partial(
1443
+ matrix_inverse_pth_root, ridge_epsilon=matrix_epsilon, precision=precision
1444
+ )
1445
+ return jax.vmap(mi_pth_root)(xs, ps)
1446
+
1447
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1448
+ def _quantized_to_float(qx, qd, qb):
1449
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1450
+ return qv.to_float()
1451
+
1452
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1453
+ v = _quantized_to_float(qx, qd, qb)
1454
+ preconditioner, error = matrix_inverse_pth_root(
1455
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision
1456
+ )
1457
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1458
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1459
+
1460
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1461
+
1462
+ def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
1463
+ # Partition the concatenated statistics matrix across all cores.
1464
+ pspec_for_partition = preconditioner_partition_spec
1465
+ partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
1466
+ partitioned_ps = pjit.with_sharding_constraint(
1467
+ ps, pjit.PartitionSpec(preconditioner_partition_spec[0])
1468
+ )
1469
+ # Run matrix inverse pth root on each shard.
1470
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1471
+ partitioned_xs, partitioned_ps
1472
+ )
1473
+ # Reshard output to have the same PSpec as input. This is required to avoid
1474
+ # vmap seeing the full set of statistics.
1475
+ partitioned_preconditioners = pjit.with_sharding_constraint(
1476
+ partitioned_preconditioners, pspec_for_partition
1477
+ )
1478
+ # Recombine the outputs at each core.
1479
+ preconditioners = pjit.with_sharding_constraint(
1480
+ partitioned_preconditioners, statistics_partition_spec
1481
+ )
1482
+ errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec())
1483
+ return preconditioners, errors
1484
+
1485
+ def _pmap_compute_preconditioners(
1486
+ states,
1487
+ step,
1488
+ statistics,
1489
+ num_statistics_per_state,
1490
+ original_shapes,
1491
+ exponents,
1492
+ max_size,
1493
+ prev_preconditioners,
1494
+ ):
1495
+ """Computes preconditioners for given statistics in states in PMAP mode.
1496
+
1497
+ Args:
1498
+ states: A list of optimizer states.
1499
+ step: Current step number
1500
+ statistics: A list of statistics for all variables (for every dim)
1501
+ num_statistics_per_state: Number of statistis per state to reconstruct
1502
+ output states.
1503
+ original_shapes: A list of shapes of the statistics.
1504
+ exponents: Exponent power to use for inverse-pth roots.
1505
+ max_size: Maximum dim of the statistics to pad.
1506
+ prev_preconditioners: Previously available preconditioner.
1507
+
1508
+ Returns:
1509
+ New optimizer states after computing the preconditioner.
1510
+ """
1511
+ num_devices = lax.psum(1, batch_axis_name)
1512
+ num_statistics = len(statistics)
1513
+ # Pad statistics and exponents to next multiple of num_devices.
1514
+ packed_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
1515
+ to_pad = -num_statistics % num_devices
1516
+ packed_statistics.extend(
1517
+ [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
1518
+ )
1519
+ exponents.extend([1 for _ in range(to_pad)])
1520
+
1521
+ if not packed_statistics:
1522
+ return states
1523
+
1524
+ all_statistics = batch(packed_statistics, num_devices)
1525
+ all_exponents = batch(exponents, num_devices)
1526
+
1527
+ def _internal_inverse_pth_root_all():
1528
+ current_replica = lax.axis_index(batch_axis_name)
1529
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1530
+ all_statistics[current_replica], all_exponents[current_replica]
1531
+ )
1532
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1533
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1534
+ preconditioners_flat = unbatch(preconditioners)
1535
+ errors_flat = unbatch(errors)
1536
+ return preconditioners_flat, errors_flat
1537
+
1538
+ if preconditioning_compute_steps == 1:
1539
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1540
+ else:
1541
+ # Passing statistics instead of preconditioners as they are similarly
1542
+ # shaped tensors. Note statistics will be ignored as we are passing in
1543
+ # a large init value for error.
1544
+ preconditioners_init = packed_statistics
1545
+ errors_init = [inverse_failure_threshold] * len(packed_statistics)
1546
+ init_state = [preconditioners_init, errors_init]
1547
+ perform_step = step % preconditioning_compute_steps == 0
1548
+ preconditioners_flat, errors_flat = efficient_cond(
1549
+ perform_step, _internal_inverse_pth_root_all, init_state
1550
+ )
1551
+
1552
+ def _skip(error):
1553
+ condition = jnp.logical_or(
1554
+ jnp.isnan(error), error >= inverse_failure_threshold
1555
+ )
1556
+ return condition.astype(error.dtype)
1557
+
1558
+ def _select_preconditioner(error, new_p, old_p):
1559
+ return lax.cond(
1560
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1561
+ )
1562
+
1563
+ new_preconditioners_flat = []
1564
+ new_errors_flat = []
1565
+ for p, shape, prev_p, error in zip(
1566
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1567
+ ):
1568
+ new_preconditioners_flat.append(
1569
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1570
+ )
1571
+ new_errors_flat.append(error)
1572
+
1573
+ assert len(states) == len(num_statistics_per_state)
1574
+ assert len(new_preconditioners_flat) == num_statistics
1575
+ assert len(new_errors_flat) == num_statistics
1576
+
1577
+ # Add back empty preconditioners so we that we can set the optimizer state.
1578
+ preconditioners_for_states = []
1579
+ idx = 0
1580
+ errors_for_states = []
1581
+ for num_statistics, state in zip(num_statistics_per_state, states):
1582
+ if num_statistics == 0:
1583
+ preconditioners_for_states.append([])
1584
+ errors_for_states.append([])
1585
+ else:
1586
+ preconditioners_for_state = new_preconditioners_flat[
1587
+ idx : idx + num_statistics
1588
+ ]
1589
+ assert len(state.statistics) == len(preconditioners_for_state)
1590
+ preconditioners_for_states.append(preconditioners_for_state)
1591
+
1592
+ errors_for_state = jnp.stack(
1593
+ new_errors_flat[idx : idx + num_statistics]
1594
+ )
1595
+ assert len(state.statistics) == len(errors_for_state)
1596
+ errors_for_states.append(errors_for_state)
1597
+
1598
+ idx += num_statistics
1599
+ new_states = []
1600
+ for state, new_preconditioners, new_errors in zip(
1601
+ states, preconditioners_for_states, errors_for_states
1602
+ ):
1603
+ if state.statistics:
1604
+ new_errors = jnp.where(
1605
+ jnp.logical_and(
1606
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1607
+ ),
1608
+ new_errors,
1609
+ state.training_metrics.inverse_pth_root_errors,
1610
+ )
1611
+ new_training_metrics = TrainingMetrics(new_errors)
1612
+ new_states.append(
1613
+ ParameterStats(
1614
+ state.diagonal_statistics,
1615
+ state.statistics,
1616
+ new_preconditioners,
1617
+ state.diagonal_momentum,
1618
+ state.momentum,
1619
+ new_training_metrics,
1620
+ )
1621
+ )
1622
+
1623
+ return new_states
1624
+
1625
+ def _pmap_quantized_compute_preconditioners(
1626
+ states,
1627
+ step,
1628
+ statistics,
1629
+ num_statistics_per_state,
1630
+ original_shapes,
1631
+ exponents,
1632
+ max_size,
1633
+ prev_preconditioners,
1634
+ ):
1635
+ """Computes preconditioners for given statistics in states in PMAP mode.
1636
+
1637
+ For quantization, each statistic is represented by three values:
1638
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1639
+ without ever recreating the original matrix in f32.
1640
+
1641
+ Args:
1642
+ states: A list of optimizer states.
1643
+ step: Current step number
1644
+ statistics: A list of statistics for all variables (for every dim)
1645
+ num_statistics_per_state: Number of statistis per state to reconstruct
1646
+ output states.
1647
+ original_shapes: A list of shapes of the statistics.
1648
+ exponents: Exponent power to use for inverse-pth roots.
1649
+ max_size: Maximum dim of the statistics to pad.
1650
+ prev_preconditioners: Previously available preconditioner.
1651
+
1652
+ Returns:
1653
+ New optimizer states after computing the preconditioner.
1654
+ """
1655
+ num_devices = lax.psum(1, batch_axis_name)
1656
+ num_statistics = len(statistics)
1657
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1658
+ # Complexity here is around: shapes needing be statically shaped,
1659
+ # our custom quantization type requires a different type of packing.
1660
+
1661
+ # Parallel tensors:
1662
+ # quantized [dxd]
1663
+ # diagonals [d] f32
1664
+ # bucket_sizes [d] f32
1665
+ packed_quantized_statistics = [
1666
+ pad_square_matrix(stat.quantized, max_size) for stat in statistics
1667
+ ]
1668
+ packed_quantized_diagonals = [
1669
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1670
+ ]
1671
+ packed_quantized_bucket_sizes = [
1672
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1673
+ ]
1674
+
1675
+ to_pad = -num_statistics % num_devices
1676
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1677
+ quantized_eye = QuantizedValue.from_float_value(
1678
+ padded_eye, quantized_dtype, True
1679
+ )
1680
+ packed_quantized_statistics.extend(
1681
+ [quantized_eye.quantized for _ in range(to_pad)]
1682
+ )
1683
+ packed_quantized_diagonals.extend(
1684
+ [quantized_eye.diagonal for _ in range(to_pad)]
1685
+ )
1686
+ packed_quantized_bucket_sizes.extend(
1687
+ [quantized_eye.bucket_size for _ in range(to_pad)]
1688
+ )
1689
+ exponents.extend([1 for _ in range(to_pad)])
1690
+
1691
+ if not packed_quantized_statistics:
1692
+ return states
1693
+
1694
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1695
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1696
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices)
1697
+ all_exponents = batch(exponents, num_devices)
1698
+
1699
+ def _internal_inverse_pth_root_all():
1700
+ current_replica = lax.axis_index(batch_axis_name)
1701
+ (
1702
+ quantized_preconditioners,
1703
+ quantized_diagonals,
1704
+ quantized_bucket_sizes,
1705
+ errors,
1706
+ ) = _quantized_matrix_inverse_pth_root_vmap(
1707
+ all_quantized_statistics[current_replica],
1708
+ all_quantized_diagonals[current_replica],
1709
+ all_quantized_bucket_sizes[current_replica],
1710
+ all_exponents[current_replica],
1711
+ )
1712
+ quantized_preconditioners = jax.lax.all_gather(
1713
+ quantized_preconditioners, batch_axis_name
1714
+ )
1715
+ quantized_diagonals = jax.lax.all_gather(
1716
+ quantized_diagonals, batch_axis_name
1717
+ )
1718
+ quantized_bucket_sizes = jax.lax.all_gather(
1719
+ quantized_bucket_sizes, batch_axis_name
1720
+ )
1721
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1722
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1723
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1724
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1725
+ errors_flat = unbatch(errors)
1726
+ return (
1727
+ quantized_preconditioners_flat,
1728
+ quantized_diagonals_flat,
1729
+ quantized_bucket_sizes_flat,
1730
+ errors_flat,
1731
+ )
1732
+
1733
+ if preconditioning_compute_steps == 1:
1734
+ (
1735
+ quantized_preconditioners_flat,
1736
+ quantized_diagonals_flat,
1737
+ quantized_bucket_sizes_flat,
1738
+ errors_flat,
1739
+ ) = _internal_inverse_pth_root_all()
1740
+ else:
1741
+ # Passing statistics instead of preconditioners as they are similarly
1742
+ # shaped tensors. Note statistics will be ignored as we are passing in
1743
+ # a large init value for error.
1744
+ quantized_preconditioners_init = packed_quantized_statistics
1745
+ quantized_diagonals_init = packed_quantized_diagonals
1746
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1747
+ errors_init = [inverse_failure_threshold] * len(
1748
+ quantized_preconditioners_init
1749
+ )
1750
+ init_state = [
1751
+ quantized_preconditioners_init,
1752
+ quantized_diagonals_init,
1753
+ quantized_bucket_sizes_init,
1754
+ errors_init,
1755
+ ]
1756
+ perform_step = step % preconditioning_compute_steps == 0
1757
+ (
1758
+ quantized_preconditioners_flat,
1759
+ quantized_diagonals_flat,
1760
+ quantized_bucket_sizes_flat,
1761
+ errors_flat,
1762
+ ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state)
1763
+
1764
+ def _skip(error):
1765
+ condition = jnp.logical_or(
1766
+ jnp.isnan(error), error >= inverse_failure_threshold
1767
+ )
1768
+ return condition.astype(error.dtype)
1769
+
1770
+ def _select_preconditioner(error, new_p, old_p):
1771
+ return lax.cond(
1772
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1773
+ )
1774
+
1775
+ new_quantized_preconditioners_flat = []
1776
+ new_quantized_diagonals_flat = []
1777
+ new_quantized_bucket_sizes_flat = []
1778
+ new_errors_flat = []
1779
+ for p, d, b, shape, prev_p, error in zip(
1780
+ quantized_preconditioners_flat,
1781
+ quantized_diagonals_flat,
1782
+ quantized_bucket_sizes_flat,
1783
+ original_shapes,
1784
+ prev_preconditioners,
1785
+ errors_flat,
1786
+ ):
1787
+ new_quantized_preconditioners_flat.append(
1788
+ _select_preconditioner(
1789
+ error, p[: shape[0], : shape[1]], prev_p.quantized
1790
+ )
1791
+ )
1792
+ new_quantized_diagonals_flat.append(
1793
+ _select_preconditioner(error, d[: shape[0]], prev_p.diagonal)
1794
+ )
1795
+ new_quantized_bucket_sizes_flat.append(
1796
+ _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
1797
+ )
1798
+ new_errors_flat.append(error)
1799
+
1800
+ assert len(states) == len(num_statistics_per_state)
1801
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1802
+ assert len(new_quantized_diagonals_flat) == num_statistics
1803
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1804
+
1805
+ # Add back empty preconditioners so we that we can set the optimizer state.
1806
+ preconditioners_for_states = []
1807
+ errors_for_states = []
1808
+ idx = 0
1809
+ for num_statistics, state in zip(num_statistics_per_state, states):
1810
+ if num_statistics == 0:
1811
+ preconditioners_for_states.append([])
1812
+ errors_for_states.append([])
1813
+ else:
1814
+ quantized_preconditioners_for_state = (
1815
+ new_quantized_preconditioners_flat[idx : idx + num_statistics]
1816
+ )
1817
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1818
+ idx : idx + num_statistics
1819
+ ]
1820
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1821
+ idx : idx + num_statistics
1822
+ ]
1823
+ errors_for_state = jnp.stack(
1824
+ new_errors_flat[idx : idx + num_statistics]
1825
+ )
1826
+
1827
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1828
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1829
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1830
+ assert len(state.statistics) == len(errors_for_state)
1831
+
1832
+ quantized_preconditioners = []
1833
+ for qv, qd, qb in zip(
1834
+ quantized_preconditioners_for_state,
1835
+ quantized_diagonals_for_state,
1836
+ quantized_bucket_sizes_for_state,
1837
+ ):
1838
+ quantized_preconditioners.append(
1839
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
1840
+ )
1841
+ preconditioners_for_states.append(quantized_preconditioners)
1842
+ errors_for_states.append(errors_for_state)
1843
+ idx += num_statistics
1844
+ new_states = []
1845
+ for state, new_preconditioners, new_errors in zip(
1846
+ states, preconditioners_for_states, errors_for_states
1847
+ ):
1848
+ if state.statistics:
1849
+ new_errors = jnp.where(
1850
+ jnp.logical_and(
1851
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1852
+ ),
1853
+ new_errors,
1854
+ state.training_metrics.inverse_pth_root_errors,
1855
+ )
1856
+ new_training_metrics = TrainingMetrics(new_errors)
1857
+ new_states.append(
1858
+ ParameterStats(
1859
+ state.diagonal_statistics,
1860
+ state.statistics,
1861
+ new_preconditioners,
1862
+ state.diagonal_momentum,
1863
+ state.momentum,
1864
+ new_training_metrics,
1865
+ )
1866
+ )
1867
+
1868
+ return new_states
1869
+
1870
+ def _pjit_compute_preconditioners(
1871
+ states,
1872
+ step,
1873
+ statistics,
1874
+ num_statistics_per_state,
1875
+ original_shapes,
1876
+ exponents,
1877
+ max_size,
1878
+ prev_preconditioners,
1879
+ ):
1880
+ """Computes preconditioners for given statistics in states in PJIT mode.
1881
+
1882
+ Args:
1883
+ states: A list of optimizer states.
1884
+ step: Current step number
1885
+ statistics: A list of statistics for all variables (for every dim)
1886
+ num_statistics_per_state: Number of statistis per state to reconstruct
1887
+ output states.
1888
+ original_shapes: A list of shapes of the statistics.
1889
+ exponents: Exponent power to use for inverse-pth roots.
1890
+ max_size: Maximum dim of the statistics to pad.
1891
+ prev_preconditioners: Previously available preconditioner.
1892
+
1893
+ Returns:
1894
+ New optimizer states after computing the preconditioner.
1895
+ """
1896
+ num_statistics = len(statistics)
1897
+ to_pad = -num_statistics % num_devices_for_pjit
1898
+ padded_statistics = [pad_square_matrix(stat, max_size) for stat in statistics]
1899
+ padded_statistics.extend(
1900
+ [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
1901
+ )
1902
+ exponents.extend([1 for _ in range(to_pad)])
1903
+ all_statistics = jnp.stack(padded_statistics)
1904
+ all_exponents = jnp.stack(exponents)
1905
+
1906
+ def _internal_inverse_pth_root_all():
1907
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1908
+ all_statistics, all_exponents
1909
+ )
1910
+ b1 = preconditioners.shape[0]
1911
+
1912
+ def split(batched_values):
1913
+ return [
1914
+ jnp.squeeze(v)
1915
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1916
+ ]
1917
+
1918
+ return split(preconditioners), split(errors)
1919
+
1920
+ if preconditioning_compute_steps == 1:
1921
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1922
+ else:
1923
+ # Passing statistics instead of preconditioners as they are similarly
1924
+ # shaped tensors. Note statistics will be ignored as we are passing in
1925
+ # a large init value for error.
1926
+ preconditioners_init = padded_statistics
1927
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1928
+ init_state = [preconditioners_init, errors_init]
1929
+ perform_step = step % preconditioning_compute_steps == 0
1930
+ preconditioners_flat, errors_flat = efficient_cond(
1931
+ perform_step, _internal_inverse_pth_root_all, init_state
1932
+ )
1933
+
1934
+ def _skip(error):
1935
+ condition = jnp.logical_or(
1936
+ jnp.isnan(error), error >= inverse_failure_threshold
1937
+ )
1938
+ return condition.astype(error.dtype)
1939
+
1940
+ def _select_preconditioner(error, new_p, old_p):
1941
+ return lax.cond(
1942
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1943
+ )
1944
+
1945
+ new_preconditioners_flat = []
1946
+ new_errors_flat = []
1947
+ for p, shape, prev_p, error in zip(
1948
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1949
+ ):
1950
+ new_preconditioners_flat.append(
1951
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1952
+ )
1953
+ new_errors_flat.append(error)
1954
+
1955
+ assert len(states) == len(num_statistics_per_state)
1956
+ assert len(new_preconditioners_flat) == num_statistics
1957
+
1958
+ # Add back empty preconditioners so we that we can set the optimizer state.
1959
+ preconditioners_for_states = []
1960
+ errors_for_states = []
1961
+ idx = 0
1962
+ for num_statistics, state in zip(num_statistics_per_state, states):
1963
+ if num_statistics == 0:
1964
+ preconditioners_for_states.append([])
1965
+ errors_for_states.append([])
1966
+ else:
1967
+ preconditioners_for_state = new_preconditioners_flat[
1968
+ idx : idx + num_statistics
1969
+ ]
1970
+ assert len(state.statistics) == len(preconditioners_for_state)
1971
+ preconditioners_for_states.append(preconditioners_for_state)
1972
+
1973
+ errors_for_state = jnp.stack(
1974
+ new_errors_flat[idx : idx + num_statistics]
1975
+ )
1976
+ assert len(state.statistics) == len(errors_for_state)
1977
+ errors_for_states.append(errors_for_state)
1978
+ idx += num_statistics
1979
+
1980
+ new_states = []
1981
+ for state, new_preconditioners, new_errors in zip(
1982
+ states, preconditioners_for_states, errors_for_states
1983
+ ):
1984
+ if state.statistics:
1985
+ new_errors = jnp.where(
1986
+ jnp.logical_and(
1987
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1988
+ ),
1989
+ new_errors,
1990
+ state.training_metrics.inverse_pth_root_errors,
1991
+ )
1992
+ new_training_metrics = TrainingMetrics(new_errors)
1993
+ new_states.append(
1994
+ ParameterStats(
1995
+ state.diagonal_statistics,
1996
+ state.statistics,
1997
+ new_preconditioners,
1998
+ state.diagonal_momentum,
1999
+ state.momentum,
2000
+ new_training_metrics,
2001
+ )
2002
+ )
2003
+
2004
+ return new_states
2005
+
2006
+ def _compute_preconditioners(states, params, step):
2007
+ """Computes preconditioners for given statistics in states.
2008
+
2009
+ Args:
2010
+ states: A list of optimizer states.
2011
+ params: A list of params.
2012
+ step: Current step number
2013
+
2014
+ Returns:
2015
+ New optimizer states after computing the preconditioner.
2016
+ """
2017
+ statistics = []
2018
+ num_statistics_per_state = []
2019
+ original_shapes = []
2020
+ exponents = []
2021
+ max_size = 0
2022
+ prev_preconditioners = []
2023
+
2024
+ for state, param in zip(states, params):
2025
+ num_statistics = len(state.statistics)
2026
+ num_statistics_per_state.append(num_statistics)
2027
+ original_shapes_for_state = []
2028
+ if num_statistics > 0:
2029
+ preconditioner = Preconditioner(
2030
+ param, block_size, best_effort_shape_interpretation
2031
+ )
2032
+ for statistic in state.statistics:
2033
+ exponents.append(
2034
+ preconditioner.exponent_for_preconditioner()
2035
+ if exponent_override == 0
2036
+ else exponent_override
2037
+ )
2038
+ original_shapes_for_state.append(statistic.shape)
2039
+ max_size = max(max_size, statistic.shape[0])
2040
+
2041
+ statistics.extend(state.statistics)
2042
+ prev_preconditioners.extend(state.preconditioners)
2043
+ original_shapes.extend(original_shapes_for_state)
2044
+
2045
+ if batch_axis_name:
2046
+ # Quantization is only enabled if batch_axis_name is not set.
2047
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
2048
+
2049
+ if quantized_dtype == jnp.float32:
2050
+ return _pmap_compute_preconditioners(
2051
+ states,
2052
+ step,
2053
+ statistics,
2054
+ num_statistics_per_state,
2055
+ original_shapes,
2056
+ exponents,
2057
+ max_size,
2058
+ prev_preconditioners,
2059
+ )
2060
+ else:
2061
+ return _pmap_quantized_compute_preconditioners(
2062
+ states,
2063
+ step,
2064
+ statistics,
2065
+ num_statistics_per_state,
2066
+ original_shapes,
2067
+ exponents,
2068
+ max_size,
2069
+ prev_preconditioners,
2070
+ )
2071
+
2072
+ else:
2073
+ return _pjit_compute_preconditioners(
2074
+ states,
2075
+ step,
2076
+ statistics,
2077
+ num_statistics_per_state,
2078
+ original_shapes,
2079
+ exponents,
2080
+ max_size,
2081
+ prev_preconditioners,
2082
+ )
2083
+
2084
+ def _transform_grad(grad, state, param, step):
2085
+ """Transform per-parameter gradients."""
2086
+ preconditioner = Preconditioner(
2087
+ param, block_size, best_effort_shape_interpretation
2088
+ )
2089
+ sgd_update = grad
2090
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
2091
+ if (
2092
+ graft_type == GraftingType.ADAGRAD
2093
+ or graft_type == GraftingType.ADAGRAD_NORMALIZED
2094
+ ):
2095
+
2096
+ scaled_grad = grad
2097
+ if graft_type == GraftingType.ADAGRAD_NORMALIZED:
2098
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
2099
+
2100
+ new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
2101
+ scaled_grad
2102
+ )
2103
+ adagrad_update = scaled_grad / (
2104
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2105
+ )
2106
+ grafting_update = adagrad_update
2107
+ elif (
2108
+ graft_type == GraftingType.RMSPROP
2109
+ or graft_type == GraftingType.RMSPROP_NORMALIZED
2110
+ ):
2111
+
2112
+ scaled_grad = grad
2113
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
2114
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
2115
+
2116
+ w1 = beta2
2117
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
2118
+
2119
+ new_diagonal_statistics = (
2120
+ w1 * state.diagonal_statistics.to_float() + w2 * jnp.square(scaled_grad)
2121
+ )
2122
+ rmsprop_update = scaled_grad / (
2123
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2124
+ )
2125
+
2126
+ if clip_by_scaled_gradient_norm:
2127
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
2128
+ jnp.sqrt(float(rmsprop_update.size))
2129
+ )
2130
+ clipping_denom = jnp.maximum(
2131
+ 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
2132
+ )
2133
+ rmsprop_update /= clipping_denom
2134
+
2135
+ grafting_update = rmsprop_update
2136
+ elif graft_type == GraftingType.SGD:
2137
+ grafting_update = sgd_update
2138
+ else:
2139
+ grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update)
2140
+
2141
+ precond_grad = grad
2142
+ if not _skip_preconditioning(param):
2143
+ precond_grad = preconditioner.preconditioned_grad(
2144
+ precond_grad, _maybe_dequantize_preconditioners(state.preconditioners)
2145
+ )
2146
+ else:
2147
+ precond_grad = grafting_update
2148
+
2149
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
2150
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
2151
+
2152
+ multiplier = grafting_update_norm / (precond_grad_norm + 1e-16)
2153
+ shampoo_update = precond_grad * multiplier
2154
+
2155
+ shampoo_update_with_wd = shampoo_update
2156
+ grafting_update_with_wd = grafting_update
2157
+ if weight_decay != 0:
2158
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
2159
+ grafting_update_with_wd = grafting_update + weight_decay * param
2160
+
2161
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
2162
+
2163
+ shampoo_update_with_wd_momentum = (
2164
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
2165
+ )
2166
+
2167
+ if _graft_type_has_diagonal_momentum_states():
2168
+ grafting_update_with_wd_momentum = (
2169
+ state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
2170
+ )
2171
+ else:
2172
+ # Share the momentum buffer
2173
+ grafting_update_with_wd_momentum = (
2174
+ state.momentum.to_float() * beta1 + w * grafting_update_with_wd
2175
+ )
2176
+
2177
+ run_shampoo = (step >= start_preconditioning_step).astype(
2178
+ grafting_update_with_wd_momentum.dtype
2179
+ )
2180
+
2181
+ momentum_update = (
2182
+ run_shampoo * shampoo_update_with_wd_momentum
2183
+ + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
2184
+ )
2185
+
2186
+ wd_update = (
2187
+ run_shampoo * shampoo_update_with_wd
2188
+ + (1.0 - run_shampoo) * grafting_update_with_wd
2189
+ )
2190
+
2191
+ nesterov_momentum_update = momentum_update
2192
+ if nesterov:
2193
+ nesterov_momentum_update = w * wd_update + beta1 * momentum_update
2194
+
2195
+ lr = learning_rate
2196
+ if callable(learning_rate):
2197
+ lr = learning_rate(step)
2198
+ transformed_update = -1.0 * lr * nesterov_momentum_update
2199
+
2200
+ new_diagonal_momentum = grafting_update_with_wd_momentum
2201
+ new_momentum = shampoo_update_with_wd_momentum
2202
+ if not _graft_type_has_diagonal_momentum_states():
2203
+ new_diagonal_momentum = []
2204
+ new_momentum = momentum_update
2205
+
2206
+ param_stats = ParameterStats(
2207
+ _quantize_diagonal_statistics(new_diagonal_statistics),
2208
+ state.statistics,
2209
+ state.preconditioners,
2210
+ _quantize_momentum(new_diagonal_momentum),
2211
+ _quantize_momentum(new_momentum),
2212
+ state.training_metrics,
2213
+ )
2214
+
2215
+ return transformed_update, param_stats
2216
+
2217
+ def update_fn(grads, state, params):
2218
+ """Transform the input gradient and update all statistics.
2219
+
2220
+ Args:
2221
+ grads: the gradient tensors for the parameters.
2222
+ state: a named tuple containing the state of the optimizer
2223
+ params: the parameters that should be updated.
2224
+
2225
+ Returns:
2226
+ A tuple containing the new parameters and the new optimizer state.
2227
+ """
2228
+ params_flat, treedef = jax.tree_flatten(params)
2229
+ stats_flat = treedef.flatten_up_to(state.stats)
2230
+ grads_flat = treedef.flatten_up_to(grads)
2231
+
2232
+ new_stats_flat = jax.tree_multimap(
2233
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
2234
+ grads_flat,
2235
+ stats_flat,
2236
+ params_flat,
2237
+ )
2238
+ new_stats_flat = _compute_preconditioners(
2239
+ new_stats_flat, params_flat, state.count
2240
+ )
2241
+ outputs = jax.tree_multimap(
2242
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
2243
+ grads_flat,
2244
+ new_stats_flat,
2245
+ params_flat,
2246
+ )
2247
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
2248
+
2249
+ updates = jax.tree_unflatten(treedef, updates_flat)
2250
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
2251
+
2252
+ new_state = ShampooState(count=state.count + 1, stats=new_stats)
2253
+ return updates, new_state
2254
+
2255
+ if shard_optimizer_states:
2256
+ # Hijacks the init_fn signature so we can return an OptState with
2257
+ # appropriate init_fns.
2258
+ def _init_fns(unused_params):
2259
+ return InitFnState(
2260
+ init_fn=sharded_init_fn,
2261
+ pspec_fn=sharded_init_partition_spec_fn,
2262
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
2263
+ )
2264
+
2265
+ return optax.GradientTransformation(_init_fns, sharded_update_fn)
2266
+ else:
2267
+ return optax.GradientTransformation(init_fn, update_fn)
tools/train/scalable_shampoo/quantization_utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Helper routines for quantization."""
17
+
18
+ from typing import Any
19
+
20
+ import chex
21
+ import jax.numpy as jnp
22
+ from flax import struct
23
+
24
+
25
+ # pylint:disable=no-value-for-parameter
26
+ @struct.dataclass
27
+ class QuantizedValue:
28
+ """State associated with quantized value."""
29
+
30
+ quantized: chex.Array
31
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
32
+ bucket_size: chex.Array
33
+ quantized_dtype: jnp.dtype = struct.field(
34
+ pytree_node=False
35
+ ) # Dtype for the quantized value.
36
+ extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
37
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
38
+
39
+ @classmethod
40
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
41
+ if isinstance(fvalue, list) and not fvalue:
42
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
43
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
44
+ fvalue, quantized_dtype, extract_diagonal
45
+ )
46
+ return QuantizedValue(
47
+ quantized,
48
+ diagonal_fvalue,
49
+ bucket_size,
50
+ quantized_dtype,
51
+ extract_diagonal,
52
+ list(quantized.shape),
53
+ )
54
+
55
+ # Quantization is from Lingvo JAX optimizers.
56
+ # We extend it for int16 quantization of PSD matrices.
57
+ @classmethod
58
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
59
+ """Returns quantized value and the bucket."""
60
+ if quantized_dtype == jnp.float32:
61
+ return fvalue, [], []
62
+ elif quantized_dtype == jnp.bfloat16:
63
+ return fvalue.astype(jnp.bfloat16), [], []
64
+
65
+ float_dtype = fvalue.dtype
66
+ if quantized_dtype == jnp.int8:
67
+ # value -128 is not used.
68
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
69
+ elif quantized_dtype == jnp.int16:
70
+ # value -32768 is not used.
71
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
72
+ else:
73
+ raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
74
+ # max value is mapped to num_buckets
75
+
76
+ if extract_diagonal and fvalue.ndim != 2:
77
+ raise ValueError(
78
+ f"Input array {fvalue} must be 2D to work with extract_diagonal."
79
+ )
80
+
81
+ diagonal_fvalue = []
82
+ if extract_diagonal:
83
+ diagonal_fvalue = jnp.diag(fvalue)
84
+ # Remove the diagonal entries.
85
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
86
+
87
+ # TODO(rohananil): Extend this by making use of information about the blocks
88
+ # SM3 style which will be useful for diagonal statistics
89
+ # We first decide the scale.
90
+ if fvalue.ndim < 1:
91
+ raise ValueError(
92
+ f"Input array {fvalue} must have a strictly positive number of "
93
+ "dimensions."
94
+ )
95
+
96
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
97
+ bucket_size = max_abs / num_buckets
98
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
99
+ # To avoid divide by 0.0
100
+ bs_nonzero = jnp.where(
101
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
102
+ )
103
+ ratio = fvalue / bs_nonzero
104
+ # We use rounding to remove bias.
105
+ quantized = jnp.round(ratio)
106
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
107
+
108
+ def to_float(self):
109
+ """Returns the float value."""
110
+ if isinstance(self.quantized, list) and not self.quantized:
111
+ return self.quantized
112
+
113
+ if self.quantized_dtype == jnp.float32:
114
+ return self.quantized
115
+
116
+ if self.quantized_dtype == jnp.bfloat16:
117
+ return self.quantized.astype(jnp.float32)
118
+
119
+ float_dtype = self.bucket_size.dtype
120
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
121
+ val = self.quantized.astype(float_dtype) * bucket_size
122
+ if self.extract_diagonal:
123
+ val += jnp.diag(self.diagonal)
124
+ return val
tools/train/scalable_shampoo/sm3.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of SM3 from:
17
+ #
18
+ # Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
20
+ #
21
+ # Author: Rohan Anil (rohananil at google dot com)
22
+ #
23
+
24
+ """SM3 Implementation."""
25
+
26
+ import functools
27
+ from typing import Any, NamedTuple
28
+
29
+ import chex
30
+ import jax
31
+ import jax.numpy as jnp
32
+ import optax
33
+
34
+ from .quantization_utils import QuantizedValue
35
+
36
+
37
+ class SM3State(NamedTuple):
38
+ count: chex.Array
39
+ stats: Any
40
+
41
+
42
+ # Per parameter optimizer state used in data-parallel training.
43
+ class ParameterStats(NamedTuple):
44
+ """State associated to each parameter of the model being trained."""
45
+
46
+ diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
47
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
48
+
49
+
50
+ def sm3(
51
+ learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
52
+ ):
53
+ """SM3 optimizer.
54
+
55
+ Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
56
+ Yoram Singer
57
+
58
+ https://arxiv.org/abs/1901.11150
59
+
60
+ Args:
61
+ learning_rate: the step size used to update the parameters.
62
+ beta1: momentum parameter.
63
+ beta2: second moment averaging parameter.
64
+ diagonal_epsilon: epsilon for sm3
65
+ normalize_grads: Whether to normalize grads. Author finds it useful when
66
+ grads are high variance.
67
+
68
+ Returns:
69
+ a GradientTransformation.
70
+ """
71
+
72
+ def _quantize_momentum(momentum_statistics):
73
+ return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)
74
+
75
+ def init_fn(params):
76
+ """Initialise the optimiser's state."""
77
+
78
+ def _init(param):
79
+ accumulators = [jnp.zeros([s]) for s in param.shape]
80
+ momentum = _quantize_momentum(jnp.zeros_like(param))
81
+ return ParameterStats(accumulators, momentum)
82
+
83
+ return SM3State(
84
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
85
+ )
86
+
87
+ def _get_expanded_shape(shape, i):
88
+ rank = len(shape)
89
+ # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
90
+ # For eg: i = 1 returns [1, N, 1].
91
+ return [1] * i + [shape[i]] + [1] * (rank - i - 1)
92
+
93
+ def _moving_averages(grad, accumulators):
94
+ w = (1.0 - beta2) if beta2 != 1.0 else 1.0
95
+ if grad.ndim < 2:
96
+ return beta2 * accumulators[0] + w * grad**2
97
+ else:
98
+ min_accumulator = functools.reduce(jnp.minimum, accumulators)
99
+ return beta2 * min_accumulator + w * grad**2
100
+
101
+ def _moving_averages_momentum(grad, momentum):
102
+ w = (1.0 - beta1) if beta1 != 1.0 else 1.0
103
+ return beta1 * momentum.to_float() + w * grad
104
+
105
+ def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
106
+ all_diagonal_statistics = []
107
+ for i in range(grad.ndim):
108
+ axes = list(range(i)) + list(range(i + 1, grad.ndim))
109
+ dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
110
+ all_diagonal_statistics.append(dim_diagonal_statistics)
111
+ if grad.ndim == 1:
112
+ all_diagonal_statistics[0] = updated_diagonal_statistics
113
+ return all_diagonal_statistics
114
+
115
+ def update_fn(updates, state, params=None):
116
+ del params
117
+ stats = state.stats
118
+ if normalize_grads:
119
+ updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
120
+ # Reshape all vectors into N-d tensors to compute min over them.
121
+ # [n], [m] -> [n, 1], [1, m]
122
+ expanded_diagonal_statistics = jax.tree_multimap(
123
+ lambda grad, state: [ # pylint:disable=g-long-lambda
124
+ jnp.reshape(
125
+ state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
126
+ )
127
+ for i in range(grad.ndim)
128
+ ],
129
+ updates,
130
+ stats,
131
+ )
132
+
133
+ # Compute new diagonal statistics
134
+ new_diagonal_statistics = jax.tree_multimap(
135
+ _moving_averages, updates, expanded_diagonal_statistics
136
+ )
137
+
138
+ # Compute preconditioners (1/sqrt(s)) where s is the statistics.
139
+ new_preconditioners = jax.tree_map(
140
+ lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
141
+ )
142
+ preconditioned_grads = jax.tree_multimap(
143
+ lambda g, p: g * p, updates, new_preconditioners
144
+ )
145
+
146
+ # Compute updated momentum (also handle quantization)
147
+ updated_momentum = jax.tree_multimap(
148
+ lambda preconditioned_grad, state: _moving_averages_momentum( # pylint:disable=g-long-lambda
149
+ preconditioned_grad, state.diagonal_momentum
150
+ ),
151
+ preconditioned_grads,
152
+ stats,
153
+ )
154
+
155
+ # Update diagonal statistics.
156
+ updated_diagonal_statistics = jax.tree_multimap(
157
+ _sketch_diagonal_statistics, updates, new_diagonal_statistics
158
+ )
159
+
160
+ # Update momentum.
161
+ new_sm3_stats = jax.tree_multimap(
162
+ lambda momentum, diagonal_stats: ParameterStats( # pylint:disable=g-long-lambda
163
+ diagonal_stats, _quantize_momentum(momentum)
164
+ ),
165
+ updated_momentum,
166
+ updated_diagonal_statistics,
167
+ )
168
+
169
+ lr = learning_rate
170
+ if callable(learning_rate):
171
+ lr = learning_rate(state.count)
172
+
173
+ new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
174
+ return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)
175
+
176
+ return optax.GradientTransformation(init_fn, update_fn)
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
17
+
18
+ import functools
19
+ from typing import Any, List, Optional, Sequence, Union
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+ from flax import struct
25
+ from jax import lax
26
+
27
+
28
+ @struct.dataclass
29
+ class SlicedSymmetricMatrix:
30
+ """A symmetric matrix represented by lower-triangular block row slices.
31
+
32
+ For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
33
+ by the block rows a and [b, c].
34
+
35
+ The matrix may be batched, in which case each entry of block_rows may have
36
+ dimension greater than 2. The last two dimensions represent the rows and cols.
37
+ """
38
+
39
+ block_rows: List[jnp.ndarray]
40
+
41
+
42
+ def product_with_transpose(
43
+ mat1,
44
+ mat2,
45
+ axes,
46
+ precision=lax.Precision.DEFAULT,
47
+ ):
48
+ """Returns mat1 * mat2^T for two matrices (possibly batched).
49
+
50
+ The rows and columns are the last two dimensions for each matrix.
51
+
52
+ Args:
53
+ mat1: First matrix.
54
+ mat2: Second matrix.
55
+ axes: The axes over which to apply the product.
56
+ precision: JAX precision to use for the multiplication.
57
+ """
58
+ return jnp.tensordot(a=mat1, b=mat2, axes=axes, precision=precision)
59
+
60
+
61
+ @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
62
+ def sliced_transposed_product(
63
+ mat,
64
+ block_size,
65
+ axes=(-1,),
66
+ precision=lax.Precision.DEFAULT,
67
+ ):
68
+ """Returns the blocked slices representing a symmetric contraction.
69
+
70
+ Specifically, the output is a contraction of the input mat with itself, in the
71
+ specified axes.
72
+
73
+ Args:
74
+ mat: The matrix for which we will compute a contraction with itself.
75
+ block_size: The size of row blocks to compute.
76
+ axes: Axes to use for the contraction.
77
+ precision: The precision to use in each computation.
78
+
79
+ Raises:
80
+ ValueError: Raised when the specified block size does not evenly divide
81
+ the number of rows of the input mat.
82
+ """
83
+ rank = len(mat.shape)
84
+
85
+ def _make_axis_positive(ax):
86
+ assert -rank <= ax < rank
87
+ return ax + rank if ax < 0 else ax
88
+
89
+ positive_axes = [_make_axis_positive(ax) for ax in axes]
90
+ assert len(positive_axes) == len(axes)
91
+ remaining_axes = set(range(rank)) - set(positive_axes)
92
+ assert len(remaining_axes) == 1
93
+ remaining_ax = remaining_axes.pop()
94
+
95
+ num_rows = mat.shape[remaining_ax]
96
+ if num_rows % block_size != 0:
97
+ raise ValueError(
98
+ "The row dimension must be divisible by block_size. "
99
+ f"Instead got row dimension={num_rows} and block_size={block_size}."
100
+ )
101
+
102
+ block_rows = []
103
+ for i in range(num_rows // block_size):
104
+ start_indices = [0] * rank
105
+ start_indices[remaining_ax] = i * block_size
106
+
107
+ slice_sizes = list(mat.shape)
108
+ slice_sizes[remaining_ax] = block_size
109
+
110
+ slice_sizes_full = list(mat.shape)
111
+ slice_sizes_full[remaining_ax] = (i + 1) * block_size
112
+
113
+ block_rows.append(
114
+ product_with_transpose(
115
+ lax.dynamic_slice(
116
+ mat, start_indices=start_indices, slice_sizes=slice_sizes
117
+ ),
118
+ lax.dynamic_slice(
119
+ mat, start_indices=[0] * rank, slice_sizes=slice_sizes_full
120
+ ),
121
+ axes=(axes, axes),
122
+ precision=precision,
123
+ )
124
+ )
125
+
126
+ return SlicedSymmetricMatrix(block_rows=block_rows)
127
+
128
+
129
+ @functools.partial(jax.jit, static_argnames=("block_size", "axes", "precision"))
130
+ def sliced_transposed_product_concat(
131
+ mat,
132
+ block_size,
133
+ axes=(-1,),
134
+ precision=lax.Precision.DEFAULT,
135
+ ):
136
+ """Returns the concatenated slices representing mat*mat^T.
137
+
138
+ Args:
139
+ mat: The matrix for which we will compute mat*mat^T. It does not need to be
140
+ square, and may be batched.
141
+ block_size: The size of row blocks to compute.
142
+ axes: Axes to use for the contraction.
143
+ precision: The precision to use in each computation.
144
+
145
+ Raises:
146
+ ValueError: Raised when the specified block size does not evenly divide
147
+ the number of rows of the input mat.
148
+ """
149
+ sliced_symmetric_matrix = sliced_transposed_product(
150
+ mat=mat, block_size=block_size, axes=axes, precision=precision
151
+ )
152
+ return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
153
+
154
+
155
+ @jax.jit
156
+ def materialize_matrix(symmetric_matrix):
157
+ """Returns a materialized symmetric matrix.
158
+
159
+ Args:
160
+ symmetric_matrix: the matrix represented by lower-triangular block slices.
161
+ """
162
+ block_rows = symmetric_matrix.block_rows
163
+ block_size = block_rows[0].shape[-2]
164
+ num_blocks = len(block_rows)
165
+
166
+ # Slice the lower-triangular and diagonal blocks into blocks.
167
+ blocks = [
168
+ [
169
+ block_row[Ellipsis, i * block_size : (i + 1) * block_size]
170
+ for i in range(k + 1)
171
+ ]
172
+ for k, block_row in enumerate(block_rows)
173
+ ]
174
+
175
+ # Generate the (off-diagonal) upper-triangular blocks.
176
+ off_diags = [[] for _ in range(num_blocks - 1)]
177
+ for k, block_row in enumerate(block_rows[1:]):
178
+ for i in range(k + 1):
179
+ off_diags[i].append(
180
+ jnp.swapaxes(
181
+ a=block_row[Ellipsis, i * block_size : (i + 1) * block_size],
182
+ axis1=-1,
183
+ axis2=-2,
184
+ )
185
+ )
186
+
187
+ return jnp.block(
188
+ [row + row_t for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]]
189
+ )
190
+
191
+
192
+ @functools.partial(jax.jit, static_argnames=("num_blocks"))
193
+ def materialize_matrix_from_concat(
194
+ block_rows_concat,
195
+ num_blocks=None,
196
+ ):
197
+ """Returns a materialized symmetric matrix from concatenated slices.
198
+
199
+ Args:
200
+ block_rows_concat: The matrix represented as the concatenated
201
+ lower-triangular blocks.
202
+ num_blocks: The number of block-rows used to represent the symmetric matrix.
203
+ If not specified, it is inferred from the shape of block_rows_concat.
204
+ """
205
+ if num_blocks is None:
206
+ num_blocks = find_num_blocks(block_rows_concat)
207
+
208
+ block_size = block_rows_concat.shape[-2]
209
+
210
+ block_rows = [
211
+ block_rows_concat[
212
+ Ellipsis,
213
+ (k * (k + 1))
214
+ // 2
215
+ * block_size : (((k + 1) * (k + 2)) // 2 + 1)
216
+ * block_size,
217
+ ]
218
+ for k in range(num_blocks)
219
+ ]
220
+
221
+ return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
222
+
223
+
224
+ @functools.partial(jax.jit, static_argnames=("alpha", "beta", "axes"))
225
+ def update_sliced_rows(
226
+ symmetric_matrix,
227
+ mat,
228
+ alpha,
229
+ beta,
230
+ axes=(-1,),
231
+ ):
232
+ """Implements the blocked equivalent of SYRK.
233
+
234
+ Specifically, the symmetric matrix (represented using lower-triangular block
235
+ rows) is updated using the sliced product of mat.
236
+
237
+ Args:
238
+ symmetric_matrix: The symmetric matrix to update.
239
+ mat: The matrix to use for the update = mat * mat^T. The number of rows
240
+ should match that of symmetric_matrix.
241
+ alpha: The weight for the update.
242
+ beta: The weight for the original symmetric matrix.
243
+ axes: Axes to use for the contraction of the update.
244
+
245
+ Returns:
246
+ The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
247
+ """
248
+ block_size = symmetric_matrix.block_rows[0].shape[-2]
249
+ sym_prod = sliced_transposed_product(mat=mat, block_size=block_size, axes=axes)
250
+ return SlicedSymmetricMatrix(
251
+ block_rows=[
252
+ update * alpha + row * beta
253
+ for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
254
+ ]
255
+ )
256
+
257
+
258
+ def num_blocks_from_total_blocks(total_blocks):
259
+ """Returns the number of blocks (i.e.
260
+
261
+ block rows) from the total blocks.
262
+
263
+ This is the inverse of the function x -> x*(x+1)/2.
264
+
265
+ For example, the matrix M = [[A, B^T], [B, C]] may be represented using a
266
+ total of 3 blocks ([A, B, C]). The number of corresponding block rows is 2.
267
+
268
+ Args:
269
+ total_blocks: The total blocks used to represent the matrix.
270
+ """
271
+ num_blocks = np.round((np.sqrt(8 * total_blocks + 1) - 1) / 2).astype(np.int32)
272
+ if (num_blocks * (num_blocks + 1)) / 2 != total_blocks:
273
+ raise ValueError(
274
+ f"total_blocks={total_blocks} does not correspond to "
275
+ "a symmetric matrix. It must have the form total_blocks = x*(x+1)/2."
276
+ )
277
+ return num_blocks
278
+
279
+
280
+ def find_num_blocks(block_rows_concat):
281
+ """Returns the number of (row) blocks representing the concatenated matrix.
282
+
283
+ For example, an input with dimensions [256, 2560] represents 10 square blocks,
284
+ which matches 4 lower-triangular block rows (1+2+3+4). So this function will
285
+ return 4.
286
+
287
+ Use ordinary numpy functions here so that the returned value is static.
288
+
289
+ Args:
290
+ block_rows_concat: The concatenated block array.
291
+
292
+ Raises:
293
+ ValueError: When the dimensions of the matrix do not correspond to a lower
294
+ triangular block representation.
295
+ """
296
+ # Compute the number of square blocks used to represent the matrix.
297
+ total_blocks = block_rows_concat.shape[-1] / block_rows_concat.shape[-2]
298
+ # Determine the number of block rows by inverting y = x*(x+1)/2.
299
+ return num_blocks_from_total_blocks(total_blocks)
300
+
301
+
302
+ @functools.partial(jax.jit, static_argnames=("block_size"))
303
+ def slice_symmetric_matrix(
304
+ mat,
305
+ block_size,
306
+ ):
307
+ """Returns sliced row blocks.
308
+
309
+ Args:
310
+ mat: A symmetric matrix.
311
+ block_size: The size of the row slices.
312
+ """
313
+ num_rows = mat.shape[-2]
314
+ num_cols = mat.shape[-1]
315
+ if num_rows != num_cols:
316
+ raise ValueError("mat is not square.")
317
+ if num_rows % block_size != 0:
318
+ raise ValueError(
319
+ "block size does not evenly divide rows. "
320
+ f"num_rows={num_rows}, block_size={block_size}"
321
+ )
322
+ return SlicedSymmetricMatrix(
323
+ block_rows=[
324
+ mat[
325
+ Ellipsis,
326
+ i * block_size : (i + 1) * block_size,
327
+ 0 : (i + 1) * block_size,
328
+ ]
329
+ for i in range(num_rows // block_size)
330
+ ]
331
+ )
332
+
333
+
334
+ @functools.partial(jax.jit, static_argnames=("block_size"))
335
+ def slice_symmetric_matrix_concat(
336
+ mat,
337
+ block_size,
338
+ ):
339
+ """Returns the concatenated sliced row blocks.
340
+
341
+ Args:
342
+ mat: A symmetric matrix.
343
+ block_size: The size of the row slices.
344
+ """
345
+ sliced_symmetric_matrix = slice_symmetric_matrix(mat=mat, block_size=block_size)
346
+ return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
347
+
348
+
349
+ def sliced_matrix_diag(mat):
350
+ """Returns the diagonal of the symmetric matrix.
351
+
352
+ Args:
353
+ mat: The symmetric matrix represented in concatenated block form.
354
+ """
355
+ rows, cols = mat.shape
356
+ total_blocks = cols // rows
357
+ num_blocks = num_blocks_from_total_blocks(total_blocks)
358
+ diags = []
359
+ for i in range(num_blocks):
360
+ last_index = rows * ((i + 2) * (i + 1)) // 2
361
+ first_index = last_index - rows
362
+ diags.append(jnp.diag(mat[Ellipsis, first_index:last_index]))
363
+ return jnp.concatenate(diags, axis=-1)
364
+
365
+
366
+ def diag_as_concat(diag, block_size):
367
+ """Returns the representation of a diagonal matrix in symmetric block form.
368
+
369
+ Args:
370
+ diag: The 1D array for the diagonals.
371
+ block_size: The size of blocks to use. Must divide the length of diag.
372
+ """
373
+ assert len(diag.shape) == 1 # diag must be 1D.
374
+ assert len(diag) % block_size == 0
375
+ num_diag_blocks = len(diag) // block_size
376
+ blocks = []
377
+ for i in range(num_diag_blocks):
378
+ blocks.append(jnp.zeros(shape=(block_size, block_size * i), dtype=diag.dtype))
379
+ blocks.append(jnp.diag(diag[i * block_size : (i + 1) * block_size]))
380
+ return jnp.concatenate(blocks, axis=-1)
381
+
382
+
383
+ def row_abs_maxes(mat):
384
+ """Returns the max of the absolute values of the rows of the full matrix.
385
+
386
+ For example the symmetric matrix M = [[1, 6], [6, 2]] is represented using
387
+ mat = [1, 6, 2] with block_size = 1. In this case the function returns the
388
+ aboslute row maxes of the original symmetric matrix, [6, 6].
389
+
390
+ Args:
391
+ mat: The symmetric matrix represented as the concatenated blocks.
392
+ """
393
+ rows, cols = mat.shape
394
+
395
+ # Find col and row max for each block.
396
+ col_maxes = []
397
+ row_maxes = []
398
+ for i in range(cols // rows):
399
+ block = jnp.abs(mat[Ellipsis, i * rows : (i + 1) * rows])
400
+ col_maxes.append(jnp.max(block, axis=1))
401
+ row_maxes.append(jnp.max(block, axis=0))
402
+
403
+ # global row max from block maxes.
404
+ num_blocks = num_blocks_from_total_blocks(cols // rows)
405
+ maxes = []
406
+ for i in range(num_blocks):
407
+ maxes.append(
408
+ jnp.concatenate(
409
+ row_maxes[(i * (i + 1) // 2) : ((i + 2) * (i + 1) // 2)]
410
+ + [
411
+ col_maxes[((j + 1) * (j + 2)) // 2 - (j - i + 1)]
412
+ for j in range(i + 1, num_blocks)
413
+ ],
414
+ axis=-1,
415
+ )
416
+ )
417
+
418
+ return jnp.max(jnp.stack(maxes), axis=0)
419
+
420
+
421
+ def times_vector(mat, vec):
422
+ """Returns the symmetric block-concatenated matrix multiplied by a vector.
423
+
424
+ Specifically, each value in the vector is multiplied by a row of the full
425
+ matrix. That is, the vector is broadcast and multiplied element-wise. Note
426
+ this would be the transpose of full_mat * vec if full_mat represented the full
427
+ symmetric matrix.
428
+
429
+ Args:
430
+ mat: The symmetric matrix represented as the concatenated blocks.
431
+ vec: The vector, having the same dimension as the materialized matrix.
432
+ """
433
+ rows, cols = mat.shape
434
+ num_blocks = num_blocks_from_total_blocks(cols // rows)
435
+ multiplied = []
436
+ for i in range(num_blocks):
437
+ mat_block = mat[
438
+ Ellipsis, rows * ((i + 1) * i) // 2 : rows * ((i + 1) * (i + 2)) // 2
439
+ ]
440
+ vec_block = vec[Ellipsis, rows * i : rows * (i + 1)]
441
+ multiplied.append(jnp.einsum("...ij,...i->ij", mat_block, vec_block))
442
+ return jnp.concatenate(multiplied, axis=-1)
tools/train/sweep.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ program: train.py
2
+ project: dalle-mini
3
+ method: random
4
+ metric:
5
+ name: eval/loss
6
+ goal: minimize
7
+ parameters:
8
+ optim:
9
+ value: distributed_shampoo
10
+ learning_rate:
11
+ distribution: log_uniform
12
+ # from exp(min) to exp(max)
13
+ min: -9.2
14
+ max: -6.9
15
+ tokenizer_name:
16
+ value: boris/dalle-mini-tokenizer
17
+ config_name:
18
+ value: ./config/mini
19
+ dtype:
20
+ value: bfloat16
21
+ dataset_repo_or_path:
22
+ value: ./data
23
+ per_device_train_batch_size:
24
+ value: 64
25
+ per_device_eval_batch_size:
26
+ value: 64
27
+ gradient_accumulation_steps:
28
+ value: 1
29
+ warmup_steps:
30
+ value: 1000
31
+ num_train_epochs:
32
+ value: 1
33
+ max_train_samples:
34
+ value: 1000000
35
+ logging_steps:
36
+ value: 40
37
+ eval_steps:
38
+ value: 200
39
+
40
+ command:
41
+ - python3
42
+ - ${program}
43
+ - "--streaming"
44
+ - "--output_dir"
45
+ - "./output"
46
+ - "--overwrite_output_dir"
47
+ - "--do_train"
48
+ - "--do_eval"
49
+ - ${args}
tools/train/train.py ADDED
@@ -0,0 +1,1436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Training DALL·E Mini.
18
+ Script adapted from run_summarization_flax.py
19
+ """
20
+
21
+ import io
22
+ import logging
23
+ import os
24
+ import sys
25
+ import tempfile
26
+ import time
27
+ from dataclasses import asdict, dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any, Callable, NamedTuple, Optional
30
+
31
+ import datasets
32
+ import flax
33
+ import jax
34
+ import jax.numpy as jnp
35
+ import jaxlib
36
+ import numpy as np
37
+ import optax
38
+ import transformers
39
+ import wandb
40
+ from datasets import Dataset
41
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
42
+ from flax.serialization import from_bytes, to_bytes
43
+ from flax.training import train_state
44
+ from flax.training.common_utils import onehot
45
+ from jax.experimental import PartitionSpec, maps
46
+ from jax.experimental.compilation_cache import compilation_cache as cc
47
+ from jax.experimental.pjit import pjit, with_sharding_constraint
48
+ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
49
+ from tqdm import tqdm
50
+ from transformers import HfArgumentParser
51
+
52
+ import dalle_mini
53
+ from dalle_mini.data import Dataset
54
+ from dalle_mini.model import (
55
+ DalleBart,
56
+ DalleBartConfig,
57
+ DalleBartTokenizer,
58
+ set_partitions,
59
+ )
60
+
61
+ try:
62
+ from google.cloud import storage
63
+ except:
64
+ storage = None
65
+
66
+ cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
67
+
68
+ logger = logging.getLogger(__name__)
69
+
70
+
71
+ @dataclass
72
+ class ModelArguments:
73
+ """
74
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
75
+ """
76
+
77
+ model_name_or_path: Optional[str] = field(
78
+ default=None,
79
+ metadata={
80
+ "help": "The model checkpoint for weights initialization. "
81
+ "Don't set if you want to train a model from scratch. "
82
+ "W&B artifact references are supported in addition to the sources supported by `PreTrainedModel`."
83
+ },
84
+ )
85
+ config_name: Optional[str] = field(
86
+ default=None,
87
+ metadata={
88
+ "help": "Pretrained config name or path if not the same as model_name_or_path"
89
+ },
90
+ )
91
+ tokenizer_name: Optional[str] = field(
92
+ default=None,
93
+ metadata={
94
+ "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
95
+ },
96
+ )
97
+ dtype: Optional[str] = field(
98
+ default="float32",
99
+ metadata={
100
+ "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
101
+ },
102
+ )
103
+ restore_state: Optional[bool] = field(
104
+ default=False,
105
+ metadata={
106
+ "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
107
+ },
108
+ )
109
+
110
+ def __post_init__(self):
111
+ if self.tokenizer_name is None:
112
+ self.tokenizer_name = self.model_name_or_path
113
+ assert (
114
+ self.tokenizer_name is not None
115
+ ), "Tokenizer name or model name/path needs to be specified"
116
+ if self.restore_state:
117
+ assert self.model_name_or_path is not None and (
118
+ "/model-" in self.model_name_or_path
119
+ ), "Restoring state only available with W&B artifact reference"
120
+
121
+ def get_metadata(self):
122
+ if self.restore_state:
123
+ if jax.process_index() == 0:
124
+ artifact = wandb.run.use_artifact(self.model_name_or_path)
125
+ else:
126
+ artifact = wandb.Api().artifact(self.model_name_or_path)
127
+ return artifact.metadata
128
+ else:
129
+ return dict()
130
+
131
+ def get_opt_state(self):
132
+ with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
133
+ if self.restore_state is True:
134
+ # wandb artifact
135
+ state_artifact = self.model_name_or_path.replace(
136
+ "/model-", "/state-", 1
137
+ )
138
+ if jax.process_index() == 0:
139
+ artifact = wandb.run.use_artifact(state_artifact)
140
+ else:
141
+ artifact = wandb.Api().artifact(state_artifact)
142
+ if artifact.metadata.get("bucket_path"):
143
+ # we will read directly file contents
144
+ self.restore_state = artifact.metadata["bucket_path"]
145
+ else:
146
+ artifact_dir = artifact.download(tmp_dir)
147
+ self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
148
+
149
+ if self.restore_state.startswith("gs://"):
150
+ bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
151
+ bucket, blob_name = str(bucket_path).split("/", 1)
152
+ assert (
153
+ storage is not None
154
+ ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
155
+ client = storage.Client()
156
+ bucket = client.bucket(bucket)
157
+ blob = bucket.blob(blob_name)
158
+ return blob.download_as_bytes()
159
+
160
+ with Path(self.restore_state).open("rb") as f:
161
+ return f.read()
162
+
163
+
164
+ @dataclass
165
+ class DataTrainingArguments:
166
+ """
167
+ Arguments pertaining to what data we are going to input our model for training and eval.
168
+ """
169
+
170
+ text_column: Optional[str] = field(
171
+ default="caption",
172
+ metadata={
173
+ "help": "The name of the column in the datasets containing the full texts (for summarization)."
174
+ },
175
+ )
176
+ encoding_column: Optional[str] = field(
177
+ default="encoding",
178
+ metadata={
179
+ "help": "The name of the column in the datasets containing the image encodings."
180
+ },
181
+ )
182
+ dataset_repo_or_path: str = field(
183
+ default=None,
184
+ metadata={"help": "The dataset repository containing encoded files."},
185
+ )
186
+ train_file: Optional[str] = field(
187
+ default=None,
188
+ metadata={
189
+ "help": "The input training data file (glob & braceexpand acceptable)."
190
+ },
191
+ )
192
+ validation_file: Optional[str] = field(
193
+ default=None,
194
+ metadata={
195
+ "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
196
+ },
197
+ )
198
+ # data loading should not be a bottleneck so we use "streaming" mode by default
199
+ streaming: Optional[bool] = field(
200
+ default=True,
201
+ metadata={"help": "Whether to stream the dataset."},
202
+ )
203
+ use_auth_token: Optional[bool] = field(
204
+ default=False,
205
+ metadata={
206
+ "help": "Whether to use the authentication token for private datasets."
207
+ },
208
+ )
209
+ shard_by_host: Optional[bool] = field(
210
+ default=False,
211
+ metadata={
212
+ "help": "Whether to shard data files by host in multi-host environments."
213
+ },
214
+ )
215
+ blank_caption_prob: Optional[float] = field(
216
+ default=0.0,
217
+ metadata={
218
+ "help": "Probability of removing some captions for classifier-free guidance."
219
+ },
220
+ )
221
+ clip_score_column: Optional[str] = field(
222
+ default="clip_score",
223
+ metadata={"help": "Column that containts clip score for filtering."},
224
+ )
225
+ min_clip_score: Optional[float] = field(
226
+ default=None,
227
+ metadata={"help": "Minimum clip score required."},
228
+ )
229
+ max_clip_score: Optional[float] = field(
230
+ default=None,
231
+ metadata={"help": "Maximum clip score required."},
232
+ )
233
+ filter_column: Optional[str] = field(
234
+ default=None,
235
+ metadata={"help": "Column that containts classes to be filtered."},
236
+ )
237
+ filter_value: Optional[str] = field(
238
+ default=None,
239
+ metadata={"help": "Class value to be kept during filtering."},
240
+ )
241
+ max_train_samples: Optional[int] = field(
242
+ default=None,
243
+ metadata={
244
+ "help": "For debugging purposes or quicker training, truncate the number of training examples."
245
+ },
246
+ )
247
+ max_eval_samples: Optional[int] = field(
248
+ default=None,
249
+ metadata={
250
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
251
+ },
252
+ )
253
+ preprocessing_num_workers: Optional[int] = field(
254
+ default=None,
255
+ metadata={
256
+ "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
257
+ },
258
+ )
259
+ overwrite_cache: bool = field(
260
+ default=False,
261
+ metadata={
262
+ "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
263
+ },
264
+ )
265
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
266
+ seed_dataset: int = field(
267
+ default=None,
268
+ metadata={
269
+ "help": "Random seed for the dataset that will be set at the beginning of training."
270
+ },
271
+ )
272
+
273
+ def __post_init__(self):
274
+ if self.dataset_repo_or_path is None:
275
+ raise ValueError("Need a dataset repository or path.")
276
+
277
+
278
+ @dataclass
279
+ class TrainingArguments:
280
+ """
281
+ Arguments pertaining to training parameters.
282
+ """
283
+
284
+ output_dir: str = field(
285
+ metadata={
286
+ "help": "The output directory where the model predictions and checkpoints will be written."
287
+ },
288
+ )
289
+ overwrite_output_dir: bool = field(
290
+ default=False,
291
+ metadata={
292
+ "help": (
293
+ "Overwrite the content of the output directory. "
294
+ "Use this to continue training if output_dir points to a checkpoint directory."
295
+ )
296
+ },
297
+ )
298
+
299
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
300
+ do_eval: bool = field(
301
+ default=False, metadata={"help": "Whether to run eval on the validation set."}
302
+ )
303
+
304
+ per_device_train_batch_size: int = field(
305
+ default=8,
306
+ metadata={"help": "Batch size per data parallel device for training."},
307
+ )
308
+ per_device_eval_batch_size: Optional[int] = field(
309
+ default=None,
310
+ metadata={
311
+ "help": "Batch size per data parallel device for evaluation. Same as training batch size if not set."
312
+ },
313
+ )
314
+
315
+ gradient_accumulation_steps: int = field(
316
+ default=1,
317
+ metadata={
318
+ "help": "Number of updates steps to accumulate before performing an update pass."
319
+ },
320
+ )
321
+ gradient_checkpointing: bool = field(
322
+ default=False, metadata={"help": "Use gradient checkpointing."}
323
+ )
324
+
325
+ learning_rate: float = field(
326
+ default=5e-5, metadata={"help": "The initial learning rate."}
327
+ )
328
+ optim: str = field(
329
+ default="distributed_shampoo",
330
+ metadata={
331
+ "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
332
+ },
333
+ )
334
+ beta1: float = field(
335
+ default=0.9,
336
+ metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
337
+ )
338
+ beta2: float = field(
339
+ default=0.999,
340
+ metadata={"help": "Beta2 for for Adam & Distributed Shampoo."},
341
+ )
342
+ adam_epsilon: float = field(
343
+ default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
344
+ )
345
+ max_grad_norm: float = field(
346
+ default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
347
+ )
348
+ block_size: int = field(
349
+ default=1024,
350
+ metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
351
+ )
352
+ preconditioning_compute_steps: int = field(
353
+ default=10, metadata={"help": "Number of steps to update preconditioner."}
354
+ )
355
+ skip_preconditioning_dim_size_gt: int = field(
356
+ default=4096,
357
+ metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
358
+ )
359
+ graft_type: str = field(
360
+ default="rmsprop_normalized",
361
+ metadata={
362
+ "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
363
+ },
364
+ )
365
+ optim_quantized: bool = field(
366
+ default=False,
367
+ metadata={
368
+ "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
369
+ },
370
+ )
371
+
372
+ num_train_epochs: int = field(
373
+ default=3, metadata={"help": "Total number of training epochs to perform."}
374
+ )
375
+
376
+ warmup_steps: int = field(
377
+ default=0, metadata={"help": "Linear warmup over warmup_steps."}
378
+ )
379
+ lr_decay: str = field(
380
+ default=None,
381
+ metadata={
382
+ "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
383
+ },
384
+ )
385
+ lr_transition_steps: int = field(
386
+ default=None,
387
+ metadata={
388
+ "help": "Number of transition steps associated with learning rate decay when using exponential decay."
389
+ },
390
+ )
391
+ lr_decay_rate: float = field(
392
+ default=None,
393
+ metadata={
394
+ "help": "Decay rate associated with learning rate when using exponential decay."
395
+ },
396
+ )
397
+ lr_staircase: bool = field(
398
+ default=False,
399
+ metadata={
400
+ "help": "Whether to use staircase or continuous learning rate when using exponential decay."
401
+ },
402
+ )
403
+
404
+ logging_steps: int = field(
405
+ default=40, metadata={"help": "Log every X updates steps."}
406
+ )
407
+ eval_steps: int = field(
408
+ default=400, metadata={"help": "Run an evaluation every X steps."}
409
+ )
410
+ save_steps: int = field(
411
+ default=4000, metadata={"help": "Save checkpoint every X updates steps."}
412
+ )
413
+ log_model: bool = field(
414
+ default=False,
415
+ metadata={"help": "Log model to wandb at `save_steps` frequency."},
416
+ )
417
+ log_norm_steps: int = field(
418
+ default=True,
419
+ metadata={"help": "Log parameters and gradients norm at this frequency."},
420
+ )
421
+ log_histogram_steps: int = field(
422
+ default=False,
423
+ metadata={
424
+ "help": "Log parameters and gradients histograms at this frequency. Slows down training."
425
+ },
426
+ )
427
+
428
+ seed_model: int = field(
429
+ default=42,
430
+ metadata={
431
+ "help": "Random seed for the model that will be set at the beginning of training."
432
+ },
433
+ )
434
+
435
+ wandb_entity: Optional[str] = field(
436
+ default=None,
437
+ metadata={"help": "The wandb entity to use (for teams)."},
438
+ )
439
+ wandb_project: str = field(
440
+ default="dalle-mini",
441
+ metadata={"help": "The name of the wandb project."},
442
+ )
443
+ wandb_job_type: str = field(
444
+ default="Seq2Seq",
445
+ metadata={"help": "The name of the wandb job type."},
446
+ )
447
+
448
+ assert_TPU_available: bool = field(
449
+ default=False,
450
+ metadata={"help": "Verify that TPU is not in use."},
451
+ )
452
+
453
+ mp_devices: Optional[int] = field(
454
+ default=1,
455
+ metadata={
456
+ "help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
457
+ },
458
+ )
459
+
460
+ dp_devices: int = field(init=False)
461
+
462
+ def __post_init__(self):
463
+ if self.assert_TPU_available:
464
+ assert (
465
+ jax.local_device_count() == 8
466
+ ), "TPUs in use, please check running processes"
467
+ if self.output_dir.startswith("gs://"):
468
+ assert (
469
+ storage is not None
470
+ ), 'Could not find google.storage. Install with "pip install google-cloud-storage"'
471
+ assert self.optim in [
472
+ "distributed_shampoo",
473
+ "adam",
474
+ "adafactor",
475
+ ], f"Selected optimizer not supported: {self.optim}"
476
+ assert self.graft_type in [
477
+ "rmsprop_normalized",
478
+ "rmsprop",
479
+ "adagrad",
480
+ "adagrad_normalized",
481
+ "sgd",
482
+ "sqrt_n",
483
+ ], f"Selected graft type not supported: {self.graft_type}"
484
+ assert self.lr_decay in [
485
+ None,
486
+ "linear",
487
+ "exponential",
488
+ ], f"Selected learning rate decay not supported: {self.lr_decay}"
489
+ if self.per_device_eval_batch_size is None:
490
+ self.per_device_eval_batch_size = self.per_device_train_batch_size
491
+ if self.log_norm_steps is True:
492
+ self.log_norm_steps = self.logging_steps
493
+ if (
494
+ os.path.exists(self.output_dir)
495
+ and os.listdir(self.output_dir)
496
+ and self.do_train
497
+ and not self.overwrite_output_dir
498
+ ):
499
+ raise ValueError(
500
+ f"Output directory ({self.output_dir}) already exists and is not empty."
501
+ "Use --overwrite_output_dir to overcome."
502
+ )
503
+ assert (
504
+ self.mp_devices > 0
505
+ ), f"Number of devices for model parallelism must be > 0"
506
+ assert (
507
+ jax.device_count() % self.mp_devices == 0
508
+ ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
509
+ self.dp_devices = jax.device_count() // self.mp_devices
510
+
511
+
512
+ class TrainState(train_state.TrainState):
513
+ dropout_rng: jnp.ndarray = None
514
+ epoch: int = 0
515
+ train_time: float = 0.0 # total time the model trained
516
+ train_samples: int = 0 # number of samples seen
517
+
518
+
519
+ def main():
520
+ # See all possible arguments by passing the --help flag to this script.
521
+ parser = HfArgumentParser(
522
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
523
+ )
524
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
525
+ # If we pass only one argument to the script and it's the path to a json file,
526
+ # let's parse it to get our arguments.
527
+ model_args, data_args, training_args = parser.parse_json_file(
528
+ json_file=os.path.abspath(sys.argv[1])
529
+ )
530
+ else:
531
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
532
+
533
+ # Make one log on every process with the configuration for debugging.
534
+ logging.basicConfig(
535
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
536
+ datefmt="%m/%d/%Y %H:%M:%S",
537
+ level=logging.INFO,
538
+ )
539
+ # Setup logging, we only want one process per machine to log things on the screen.
540
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
541
+ if jax.process_index() == 0:
542
+ datasets.utils.logging.set_verbosity_warning()
543
+ transformers.utils.logging.set_verbosity_info()
544
+ else:
545
+ datasets.utils.logging.set_verbosity_error()
546
+ transformers.utils.logging.set_verbosity_error()
547
+
548
+ # Set the verbosity to info of the Transformers logger (on main process only):
549
+ logger.info(f"Training/evaluation parameters {training_args}")
550
+
551
+ # Load dataset
552
+ dataset = Dataset(
553
+ **asdict(data_args),
554
+ do_train=training_args.do_train,
555
+ do_eval=training_args.do_eval,
556
+ )
557
+
558
+ logger.info(f"Local TPUs: {jax.local_device_count()}")
559
+ logger.info(f"Global TPUs: {jax.device_count()}")
560
+
561
+ # Set up wandb run
562
+ if jax.process_index() == 0:
563
+ wandb.init(
564
+ entity=training_args.wandb_entity,
565
+ project=training_args.wandb_project,
566
+ job_type=training_args.wandb_job_type,
567
+ config=parser.parse_args(),
568
+ )
569
+
570
+ # Set up our new model config
571
+ if model_args.config_name:
572
+ config = DalleBartConfig.from_pretrained(model_args.config_name)
573
+ config.gradient_checkpointing = training_args.gradient_checkpointing
574
+ else:
575
+ config = None
576
+
577
+ # Load or create new model
578
+ if model_args.model_name_or_path:
579
+ model = DalleBart.from_pretrained(
580
+ model_args.model_name_or_path,
581
+ config=config,
582
+ seed=training_args.seed_model,
583
+ dtype=getattr(jnp, model_args.dtype),
584
+ abstract_init=True, # we overwrite them with loaded checkpoint
585
+ gradient_checkpointing=training_args.gradient_checkpointing,
586
+ )
587
+ else:
588
+ model = DalleBart(
589
+ config,
590
+ seed=training_args.seed_model,
591
+ dtype=getattr(jnp, model_args.dtype),
592
+ abstract_init=True,
593
+ )
594
+
595
+ # get model metadata
596
+ model_metadata = model_args.get_metadata()
597
+
598
+ # get PartitionSpec for model params (required to be a dict)
599
+ param_spec = set_partitions(model.params)
600
+
601
+ # convert params to frozen dict
602
+ model._params = freeze(model.params)
603
+
604
+ # Load tokenizer
605
+ tokenizer = DalleBartTokenizer.from_pretrained(
606
+ model_args.tokenizer_name, use_fast=True
607
+ )
608
+
609
+ # Preprocessing the datasets.
610
+ # We need to normalize and tokenize inputs and targets.
611
+ dataset.preprocess(tokenizer=tokenizer, config=model.config)
612
+
613
+ # Initialize our training
614
+ dropout_rng = jax.random.PRNGKey(training_args.seed_model)
615
+
616
+ # Store some constant
617
+ num_epochs = training_args.num_train_epochs
618
+ # batch size
619
+ batch_size_per_node_per_grad_step = (
620
+ training_args.per_device_train_batch_size
621
+ * jax.local_device_count()
622
+ // training_args.mp_devices
623
+ )
624
+ batch_size_per_node = (
625
+ batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
626
+ )
627
+ batch_size_per_step = batch_size_per_node * jax.process_count()
628
+ eval_batch_size_per_node = (
629
+ training_args.per_device_eval_batch_size
630
+ * jax.local_device_count()
631
+ // training_args.mp_devices
632
+ )
633
+ eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
634
+ len_train_dataset, len_eval_dataset = dataset.length
635
+ steps_per_epoch = (
636
+ len_train_dataset // batch_size_per_node
637
+ if len_train_dataset is not None
638
+ else None
639
+ )
640
+ num_train_steps = (
641
+ steps_per_epoch * num_epochs if steps_per_epoch is not None else None
642
+ )
643
+ num_params = model.num_params
644
+
645
+ logger.info("***** Running training *****")
646
+ logger.info(f" Num examples = {len_train_dataset}")
647
+ logger.info(f" Num Epochs = {num_epochs}")
648
+ logger.info(
649
+ f" Batch size per dp device = {training_args.per_device_train_batch_size}"
650
+ )
651
+ logger.info(f" Number of devices = {jax.device_count()}")
652
+ logger.info(
653
+ f" Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
654
+ )
655
+ logger.info(f" Batch size per update = {batch_size_per_step}")
656
+ logger.info(f" Model parameters = {num_params:,}")
657
+
658
+ # set up wandb run
659
+ if jax.process_index() == 0:
660
+ # set default x-axis as 'train/step'
661
+ wandb.define_metric("*", step_metric="train/step")
662
+
663
+ # add interesting config parameters
664
+ wandb.config.update(
665
+ {
666
+ "len_train_dataset": len_train_dataset,
667
+ "len_eval_dataset": len_eval_dataset,
668
+ "batch_size_per_step": batch_size_per_step,
669
+ "num_params": num_params,
670
+ "model_config": model.config.to_dict(),
671
+ "num_devices": jax.device_count(),
672
+ "versions": {
673
+ "jax": jax.__version__,
674
+ "jaxlib": jaxlib.__version__,
675
+ "flax": flax.__version__,
676
+ "transformers": transformers.__version__,
677
+ "datasets": datasets.__version__,
678
+ "wandb": wandb.__version__,
679
+ "dalle_mini": dalle_mini.__version__,
680
+ },
681
+ }
682
+ )
683
+
684
+ # Create learning rate schedule
685
+ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
686
+ """Create the learning rate function."""
687
+ warmup_fn = optax.linear_schedule(
688
+ init_value=0.0,
689
+ end_value=training_args.learning_rate,
690
+ transition_steps=training_args.warmup_steps + 1, # ensure not 0
691
+ )
692
+ # offset step when resuming
693
+ if model_metadata.get("step", 0):
694
+ warmup_fn = optax.join_schedules(
695
+ schedules=[optax.constant_schedule(0.0), warmup_fn],
696
+ boundaries=[model_metadata["step"]],
697
+ )
698
+ if training_args.lr_decay is None:
699
+ return warmup_fn
700
+ elif training_args.lr_decay == "linear":
701
+ assert (
702
+ num_train_steps is not None
703
+ ), "linear decay requires knowing the dataset length"
704
+ decay_fn = optax.linear_schedule(
705
+ init_value=training_args.learning_rate,
706
+ end_value=0,
707
+ transition_steps=num_train_steps - training_args.warmup_steps,
708
+ )
709
+ elif training_args.lr_decay == "exponential":
710
+ decay_fn = optax.exponential_decay(
711
+ init_value=training_args.learning_rate,
712
+ transition_steps=training_args.lr_transition_steps,
713
+ decay_rate=training_args.lr_decay_rate,
714
+ staircase=training_args.lr_staircase,
715
+ )
716
+ schedule_fn = optax.join_schedules(
717
+ schedules=[warmup_fn, decay_fn],
718
+ boundaries=[model_metadata.get("step", 0) + training_args.warmup_steps],
719
+ )
720
+ return schedule_fn
721
+
722
+ learning_rate_fn = create_learning_rate_fn()
723
+
724
+ # create adam optimizer
725
+ if training_args.optim == "distributed_shampoo":
726
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
727
+ graft_type = {
728
+ "sgd": GraftingType.SGD,
729
+ "adagrad": GraftingType.ADAGRAD,
730
+ "rmsprop": GraftingType.RMSPROP,
731
+ "rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
732
+ "sqrt_n": GraftingType.SQRT_N,
733
+ "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
734
+ }[training_args.graft_type]
735
+ optimizer = distributed_shampoo(
736
+ learning_rate_fn,
737
+ block_size=training_args.block_size,
738
+ beta1=training_args.beta1,
739
+ beta2=training_args.beta2,
740
+ diagonal_epsilon=1e-10,
741
+ matrix_epsilon=1e-6,
742
+ start_preconditioning_step=max(
743
+ training_args.preconditioning_compute_steps + 1, 101
744
+ ),
745
+ preconditioning_compute_steps=training_args.preconditioning_compute_steps,
746
+ statistics_compute_steps=1,
747
+ best_effort_shape_interpretation=True,
748
+ graft_type=graft_type,
749
+ nesterov=False,
750
+ exponent_override=0,
751
+ statistics_partition_spec=PartitionSpec(None, "dp", None),
752
+ preconditioner_partition_spec=PartitionSpec("dp", None, None),
753
+ num_devices_for_pjit=training_args.dp_devices,
754
+ shard_optimizer_states=True,
755
+ inverse_failure_threshold=0.1,
756
+ moving_average_for_momentum=True,
757
+ skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
758
+ clip_by_scaled_gradient_norm=None,
759
+ precision=jax.lax.Precision.HIGHEST,
760
+ best_effort_memory_usage_reduction=training_args.optim_quantized,
761
+ )
762
+ # get the real optimizer and helper functions
763
+ update_fn = optimizer.update
764
+ optimizer = optimizer.init(model.params)
765
+ opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
766
+ optimizer.pspec_fn, optimizer.shape_and_dtype_fn
767
+ )
768
+ optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
769
+
770
+ elif training_args.optim == "adam":
771
+ optimizer = optax.adamw(
772
+ learning_rate=learning_rate_fn,
773
+ b1=training_args.beta1,
774
+ b2=training_args.beta2,
775
+ eps=training_args.adam_epsilon,
776
+ )
777
+ elif training_args.optim == "adafactor":
778
+ # We use the default parameters here to initialize adafactor,
779
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
780
+ optimizer = optax.adafactor(
781
+ learning_rate=learning_rate_fn,
782
+ clipping_threshold=training_args.max_grad_norm,
783
+ )
784
+
785
+ # get PartitionSpec for optimizer state
786
+ def get_opt_state_spec_and_shape(param_spec):
787
+ # get opt_state shape without actual init
788
+ opt_state_shape = jax.eval_shape(optimizer.init, model.params)
789
+
790
+ if training_args.optim == "adam":
791
+
792
+ def _opt_state_spec_per_leaf(x):
793
+ if isinstance(x, FrozenDict):
794
+ # variables with same structure as params
795
+ return param_spec
796
+ else:
797
+ # other variables such as count
798
+ return None
799
+
800
+ opt_state_spec = jax.tree_map(
801
+ _opt_state_spec_per_leaf,
802
+ opt_state_shape,
803
+ # return None spec for empty elements
804
+ is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
805
+ )
806
+
807
+ elif training_args.optim == "adafactor":
808
+ # factorized state must be replicated (rank different than params)
809
+ opt_state_spec = None
810
+
811
+ elif training_args.optim == "distributed_shampoo":
812
+ opt_state_spec = opt_fn.pspec_fn(
813
+ params=model.params,
814
+ params_partition_spec=param_spec,
815
+ partition_spec_for_statistics=PartitionSpec(None, "dp", None),
816
+ )
817
+ else:
818
+ raise NotImplementedError
819
+ return opt_state_spec, opt_state_shape
820
+
821
+ opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
822
+
823
+ # create a mesh
824
+ mesh_shape = (training_args.dp_devices, training_args.mp_devices)
825
+ devices = np.asarray(jax.devices()).reshape(*mesh_shape)
826
+ mesh = maps.Mesh(devices, ("dp", "mp"))
827
+ logger.info(f" Mesh shape: {mesh_shape}")
828
+
829
+ # define state spec
830
+ state_spec = TrainState(
831
+ params=param_spec,
832
+ opt_state=opt_state_spec,
833
+ dropout_rng=None,
834
+ step=None,
835
+ epoch=None,
836
+ train_time=None,
837
+ train_samples=None,
838
+ apply_fn=model.__call__,
839
+ tx=optimizer,
840
+ )
841
+
842
+ # init params if not available yet
843
+ def maybe_init_params(params):
844
+ if model_args.model_name_or_path:
845
+ # model params are correctly loaded
846
+ return params
847
+ else:
848
+ # params have not been initialized yet
849
+ return model.init_weights()
850
+
851
+ with mesh:
852
+ logger.info(" Creating state")
853
+ if not model_args.restore_state:
854
+
855
+ def init_state(params):
856
+ return TrainState.create(
857
+ apply_fn=model.__call__,
858
+ tx=optimizer,
859
+ params=maybe_init_params(params),
860
+ dropout_rng=dropout_rng,
861
+ )
862
+
863
+ state = pjit(
864
+ init_state,
865
+ in_axis_resources=(param_spec,)
866
+ if model_args.model_name_or_path
867
+ else None,
868
+ out_axis_resources=state_spec,
869
+ donate_argnums=(0,),
870
+ )(model.params if model_args.model_name_or_path else None)
871
+
872
+ else:
873
+ # load opt_state
874
+ opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
875
+
876
+ # restore other attributes
877
+ attr_state = {
878
+ k: model_metadata[k]
879
+ for k in ["step", "epoch", "train_time", "train_samples"]
880
+ }
881
+
882
+ def restore_state(params, opt_state):
883
+ return TrainState(
884
+ apply_fn=model.__call__,
885
+ tx=optimizer,
886
+ params=params,
887
+ opt_state=opt_state,
888
+ dropout_rng=dropout_rng,
889
+ **attr_state,
890
+ )
891
+
892
+ state = pjit(
893
+ restore_state,
894
+ in_axis_resources=(
895
+ param_spec,
896
+ opt_state_spec,
897
+ ),
898
+ out_axis_resources=state_spec,
899
+ donate_argnums=(0, 1),
900
+ )(model.params, opt_state)
901
+
902
+ # remove opt_state from CPU
903
+ del opt_state
904
+
905
+ # free CPU memory
906
+ del model._params, opt_state_spec, opt_state_shape
907
+
908
+ # define batch specs
909
+ batch_spec = PartitionSpec("dp")
910
+ grad_batch_spec = PartitionSpec(None, "dp")
911
+
912
+ # define loss
913
+ def loss_fn(logits, labels):
914
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
915
+ loss = loss.mean()
916
+ return loss
917
+
918
+ # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
919
+ # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
920
+ use_vmap_trick = True
921
+
922
+ # make grad_param_spec for vmap
923
+ if use_vmap_trick:
924
+ grad_param_spec = jax.tree_map(
925
+ lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))),
926
+ param_spec,
927
+ )
928
+
929
+ # Define gradient update step fn
930
+ def train_step(state, batch, train_time):
931
+
932
+ # get a minibatch (one gradient accumulation slice)
933
+ def get_minibatch(batch, grad_idx):
934
+ return jax.tree_map(
935
+ lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
936
+ batch,
937
+ )
938
+
939
+ def compute_loss(params, minibatch, dropout_rng):
940
+ # minibatch has dim (batch_size, ...)
941
+ minibatch, labels = minibatch.pop("labels")
942
+ logits = state.apply_fn(
943
+ **minibatch, params=params, dropout_rng=dropout_rng, train=True
944
+ )[0]
945
+ return loss_fn(logits, labels)
946
+
947
+ grad_fn = jax.value_and_grad(compute_loss)
948
+
949
+ def loss_and_grad(grad_idx, dropout_rng):
950
+ # minibatch at grad_idx for gradient accumulation (None otherwise)
951
+ minibatch = (
952
+ get_minibatch(batch, grad_idx) if grad_idx is not None else batch
953
+ )
954
+ # ensure it is sharded properly
955
+ minibatch = with_sharding_constraint(minibatch, batch_spec)
956
+ # only 1 single rng per grad step, let us handle larger batch size (not sure why)
957
+ dropout_rng, _ = jax.random.split(dropout_rng)
958
+
959
+ if use_vmap_trick:
960
+ # "vmap trick", calculate loss and grads independently per dp_device
961
+ loss, grads = jax.vmap(
962
+ grad_fn, in_axes=(None, 0, None), out_axes=(0, 0)
963
+ )(state.params, minibatch, dropout_rng)
964
+ # ensure they are sharded correctly
965
+ loss = with_sharding_constraint(loss, batch_spec)
966
+ grads = with_sharding_constraint(grads, grad_param_spec)
967
+ # average across all devices
968
+ # Note: we could average per device only after gradient accumulation, right before params update
969
+ loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads))
970
+ else:
971
+ # "vmap trick" does not work in multi-hosts and requires too much hbm
972
+ loss, grads = grad_fn(state.params, minibatch, dropout_rng)
973
+ # ensure grads are sharded
974
+ grads = with_sharding_constraint(grads, param_spec)
975
+ # return loss and grads
976
+ return loss, grads, dropout_rng
977
+
978
+ if training_args.gradient_accumulation_steps == 1:
979
+ loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
980
+ else:
981
+ # create initial state for cumul_minibatch_step loop
982
+ init_minibatch_step = (
983
+ 0.0,
984
+ with_sharding_constraint(
985
+ jax.tree_map(jnp.zeros_like, state.params), param_spec
986
+ ),
987
+ state.dropout_rng,
988
+ )
989
+
990
+ # accumulate gradients
991
+ def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
992
+ cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout
993
+ loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
994
+ cumul_loss, cumul_grads = jax.tree_map(
995
+ jnp.add, (cumul_loss, cumul_grads), (loss, grads)
996
+ )
997
+ cumul_grads = with_sharding_constraint(cumul_grads, param_spec)
998
+ return cumul_loss, cumul_grads, dropout_rng
999
+
1000
+ # loop over gradients
1001
+ loss, grads, dropout_rng = jax.lax.fori_loop(
1002
+ 0,
1003
+ training_args.gradient_accumulation_steps,
1004
+ cumul_minibatch_step,
1005
+ init_minibatch_step,
1006
+ )
1007
+ grads = with_sharding_constraint(grads, param_spec)
1008
+ # sum -> mean
1009
+ loss, grads = jax.tree_map(
1010
+ lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
1011
+ )
1012
+
1013
+ grads = with_sharding_constraint(grads, param_spec)
1014
+
1015
+ # update state
1016
+ state = state.apply_gradients(
1017
+ grads=grads,
1018
+ dropout_rng=dropout_rng,
1019
+ train_time=train_time,
1020
+ train_samples=state.train_samples + batch_size_per_step,
1021
+ )
1022
+
1023
+ metrics = {
1024
+ "loss": loss,
1025
+ "learning_rate": learning_rate_fn(state.step),
1026
+ }
1027
+
1028
+ def maybe_fn(fn, val, zeros, freq):
1029
+ """Call fn only if it is a logging step"""
1030
+ return jax.lax.cond(
1031
+ state.step % freq == 0,
1032
+ fn,
1033
+ lambda _: zeros,
1034
+ val,
1035
+ )
1036
+
1037
+ if training_args.log_norm_steps:
1038
+ zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
1039
+
1040
+ def norm(val):
1041
+ return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
1042
+
1043
+ gradients_norm = maybe_fn(
1044
+ norm, grads, zeros_norm, training_args.log_norm_steps
1045
+ )
1046
+ params_norm = maybe_fn(
1047
+ norm, state.params, zeros_norm, training_args.log_norm_steps
1048
+ )
1049
+
1050
+ metrics.update(
1051
+ {
1052
+ "gradients_norm": gradients_norm,
1053
+ "params_norm": params_norm,
1054
+ }
1055
+ )
1056
+
1057
+ if training_args.log_histogram_steps:
1058
+ zeros_hist = jax.tree_map(
1059
+ lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
1060
+ )
1061
+
1062
+ def histogram(val):
1063
+ return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
1064
+
1065
+ gradients_hist = maybe_fn(
1066
+ histogram, grads, zeros_hist, training_args.log_histogram_steps
1067
+ )
1068
+ params_hist = maybe_fn(
1069
+ histogram, state.params, zeros_hist, training_args.log_histogram_steps
1070
+ )
1071
+
1072
+ metrics.update(
1073
+ {
1074
+ "params_hist": params_hist,
1075
+ "gradients_hist": gradients_hist,
1076
+ }
1077
+ )
1078
+
1079
+ return state, metrics
1080
+
1081
+ # Define eval fn
1082
+ def eval_step(state, batch):
1083
+ def compute_eval_loss(batch):
1084
+ batch, labels = batch.pop("labels")
1085
+ logits = model(**batch, params=state.params, train=False)[0]
1086
+ return loss_fn(logits, labels)
1087
+
1088
+ if use_vmap_trick:
1089
+ loss = jax.vmap(compute_eval_loss)(batch)
1090
+ # ensure they are sharded correctly
1091
+ loss = with_sharding_constraint(loss, batch_spec)
1092
+ # average across all devices
1093
+ loss = jnp.mean(loss)
1094
+ else:
1095
+ loss = compute_eval_loss(batch)
1096
+
1097
+ return loss
1098
+
1099
+ # Create parallel version of the train and eval step
1100
+ p_train_step = pjit(
1101
+ train_step,
1102
+ in_axis_resources=(
1103
+ state_spec,
1104
+ grad_batch_spec
1105
+ if training_args.gradient_accumulation_steps > 1
1106
+ else batch_spec,
1107
+ None,
1108
+ ),
1109
+ out_axis_resources=(state_spec, None),
1110
+ donate_argnums=(0,),
1111
+ )
1112
+ p_eval_step = pjit(
1113
+ eval_step,
1114
+ in_axis_resources=(state_spec, batch_spec),
1115
+ out_axis_resources=None,
1116
+ )
1117
+
1118
+ # define metrics logger
1119
+ class MetricsLogger:
1120
+ def __init__(self, step):
1121
+ # keep state
1122
+ self.state_dict = {}
1123
+ # estimate speed
1124
+ self.step = step
1125
+ self.time = time.perf_counter()
1126
+ self.offset_time = 0.0
1127
+
1128
+ def update_state_metrics(self, state):
1129
+ """Update internal state metrics (logged at each call to be used as x-axis)"""
1130
+ self.state_dict = {
1131
+ f'train/{k.split("_")[-1]}': state[k]
1132
+ for k in ["step", "epoch", "train_time", "train_samples"]
1133
+ }
1134
+ # timing metrics
1135
+ new_step = int(state["step"])
1136
+ new_time = time.perf_counter()
1137
+ if new_step > self.step:
1138
+ # remove time for eval & save
1139
+ delta_time = new_time - self.time - self.offset_time
1140
+ self.offset_time = 0
1141
+ time_per_step = delta_time / (new_step - self.step)
1142
+ self.step = new_step
1143
+ self.time = new_time
1144
+ self.log_time("train_per_step", time_per_step, offset=False)
1145
+ self.log_time("train_per_log", delta_time, offset=False)
1146
+
1147
+ def log_time(self, key, duration, offset=True):
1148
+ wandb.log({f"time/{key}": duration, **self.state_dict})
1149
+ if offset:
1150
+ self.offset_time += duration
1151
+
1152
+ def log(self, metrics, prefix=None):
1153
+ if jax.process_index() == 0:
1154
+ log_metrics = {}
1155
+ for k, v in metrics.items():
1156
+ if "_norm" in k:
1157
+ if self.step % training_args.log_norm_steps == 0:
1158
+ log_metrics[f"{k}/"] = unfreeze(v)
1159
+ elif "_hist" in k:
1160
+ if self.step % training_args.log_histogram_steps == 0:
1161
+ v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
1162
+ v = jax.tree_map(
1163
+ lambda x: wandb.Histogram(np_histogram=x),
1164
+ v,
1165
+ is_leaf=lambda x: isinstance(x, tuple),
1166
+ )
1167
+ log_metrics[f"{k}/"] = v
1168
+ else:
1169
+ if prefix is not None:
1170
+ k = f"{prefix}/{k}"
1171
+ log_metrics[k] = v
1172
+ wandb.log({**log_metrics, **self.state_dict})
1173
+
1174
+ # keep local copy of state
1175
+ local_state = {
1176
+ k: jax.device_get(getattr(state, k)).item()
1177
+ for k in ["step", "epoch", "train_time", "train_samples"]
1178
+ }
1179
+ # init variables
1180
+ start_time = time.perf_counter() - local_state["train_time"]
1181
+ train_metrics = None
1182
+ metrics_logger = MetricsLogger(local_state["step"])
1183
+ epochs = tqdm(
1184
+ range(local_state["epoch"], num_epochs),
1185
+ desc=f"Epoch ... (1/{num_epochs})",
1186
+ position=0,
1187
+ disable=jax.process_index() > 0,
1188
+ )
1189
+
1190
+ def run_evaluation():
1191
+ # ======================== Evaluating ==============================
1192
+ if training_args.do_eval:
1193
+ start_eval_time = time.perf_counter()
1194
+ eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
1195
+ eval_steps = (
1196
+ len_eval_dataset // eval_batch_size_per_step
1197
+ if len_eval_dataset is not None
1198
+ else None
1199
+ )
1200
+ eval_loss = []
1201
+ for batch in tqdm(
1202
+ eval_loader,
1203
+ desc="Evaluating...",
1204
+ position=2,
1205
+ leave=False,
1206
+ total=eval_steps,
1207
+ disable=jax.process_index() > 0,
1208
+ ):
1209
+ # need to keep only eval_batch_size_per_node items relevant to the node
1210
+ batch = jax.tree_map(
1211
+ lambda x: x.reshape(
1212
+ (jax.process_count(), eval_batch_size_per_node) + x.shape[1:]
1213
+ ),
1214
+ batch,
1215
+ )
1216
+ batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
1217
+
1218
+ # add dp dimension when using "vmap trick"
1219
+ if use_vmap_trick:
1220
+ bs_shape = (
1221
+ jax.local_device_count() // training_args.mp_devices,
1222
+ training_args.per_device_eval_batch_size,
1223
+ )
1224
+ batch = jax.tree_map(
1225
+ lambda x: x.reshape(bs_shape + x.shape[1:]), batch
1226
+ )
1227
+
1228
+ # freeze batch to pass safely to jax transforms
1229
+ batch = freeze(batch)
1230
+ # accumulate losses async
1231
+ eval_loss.append(p_eval_step(state, batch))
1232
+
1233
+ # get the mean of the loss
1234
+ eval_loss = jnp.stack(eval_loss)
1235
+ eval_loss = jnp.mean(eval_loss)
1236
+ eval_metrics = {"loss": eval_loss}
1237
+
1238
+ # log metrics
1239
+ metrics_logger.log(eval_metrics, prefix="eval")
1240
+ metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
1241
+
1242
+ # Print metrics and update progress bar
1243
+ desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
1244
+ epochs.write(desc)
1245
+ epochs.desc = desc
1246
+
1247
+ return eval_metrics
1248
+
1249
+ def run_save_model(state, eval_metrics=None):
1250
+ if jax.process_index() == 0:
1251
+
1252
+ start_save_time = time.perf_counter()
1253
+ output_dir = training_args.output_dir
1254
+ use_bucket = output_dir.startswith("gs://")
1255
+ if use_bucket:
1256
+ bucket_path = Path(output_dir[5:]) / wandb.run.id / f"step_{state.step}"
1257
+ bucket, dir_path = str(bucket_path).split("/", 1)
1258
+ tmp_dir = tempfile.TemporaryDirectory()
1259
+ output_dir = tmp_dir.name
1260
+
1261
+ # save model
1262
+ params = jax.device_get(state.params)
1263
+ model.save_pretrained(
1264
+ output_dir,
1265
+ params=params,
1266
+ )
1267
+
1268
+ # save tokenizer
1269
+ tokenizer.save_pretrained(output_dir)
1270
+
1271
+ # copy to bucket
1272
+ if use_bucket:
1273
+ client = storage.Client()
1274
+ bucket = client.bucket(bucket)
1275
+ for filename in Path(output_dir).glob("*"):
1276
+ blob_name = str(Path(dir_path) / "model" / filename.name)
1277
+ blob = bucket.blob(blob_name)
1278
+ blob.upload_from_filename(str(filename))
1279
+ tmp_dir.cleanup()
1280
+
1281
+ # save state
1282
+ opt_state = jax.device_get(state.opt_state)
1283
+ if use_bucket:
1284
+ blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
1285
+ blob = bucket.blob(blob_name)
1286
+ blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
1287
+ else:
1288
+ with (Path(output_dir) / "opt_state.msgpack").open("wb") as f:
1289
+ f.write(to_bytes(opt_state))
1290
+
1291
+ # save to W&B
1292
+ if training_args.log_model:
1293
+ # save some space
1294
+ c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
1295
+ c.cleanup(wandb.util.from_human_size("20GB"))
1296
+
1297
+ metadata = {
1298
+ k: jax.device_get(getattr(state, k)).item()
1299
+ for k in ["step", "epoch", "train_time", "train_samples"]
1300
+ }
1301
+ metadata["num_params"] = num_params
1302
+ if eval_metrics is not None:
1303
+ metadata["eval"] = eval_metrics
1304
+
1305
+ # create model artifact
1306
+ if use_bucket:
1307
+ metadata["bucket_path"] = f"gs://{bucket_path}/model"
1308
+ artifact = wandb.Artifact(
1309
+ name=f"model-{wandb.run.id}",
1310
+ type="DalleBart_model",
1311
+ metadata=metadata,
1312
+ )
1313
+ if use_bucket:
1314
+ artifact.add_reference(metadata["bucket_path"])
1315
+ else:
1316
+ for filename in [
1317
+ "config.json",
1318
+ "flax_model.msgpack",
1319
+ "merges.txt",
1320
+ "special_tokens_map.json",
1321
+ "tokenizer.json",
1322
+ "tokenizer_config.json",
1323
+ "vocab.json",
1324
+ ]:
1325
+ artifact.add_file(
1326
+ f"{Path(training_args.output_dir) / filename}"
1327
+ )
1328
+ wandb.run.log_artifact(artifact)
1329
+
1330
+ # create state artifact
1331
+ if use_bucket:
1332
+ metadata["bucket_path"] = f"gs://{bucket_path}/state"
1333
+ artifact_state = wandb.Artifact(
1334
+ name=f"state-{wandb.run.id}",
1335
+ type="DalleBart_state",
1336
+ metadata=metadata,
1337
+ )
1338
+ if use_bucket:
1339
+ artifact_state.add_reference(metadata["bucket_path"])
1340
+ else:
1341
+ artifact_state.add_file(
1342
+ f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1343
+ )
1344
+ wandb.run.log_artifact(artifact_state)
1345
+ metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
1346
+
1347
+ logger.info(" Ready to start training")
1348
+ with mesh:
1349
+ for epoch in epochs:
1350
+ state.replace(epoch=epoch)
1351
+ local_state["epoch"] = epoch
1352
+ # ======================== Training ================================
1353
+ metrics_logger.update_state_metrics(local_state)
1354
+ metrics_logger.log({})
1355
+
1356
+ # Generate an epoch by shuffling sampling indices from the train dataset
1357
+ train_loader = dataset.dataloader(
1358
+ "train",
1359
+ batch_size_per_node,
1360
+ epoch,
1361
+ )
1362
+ # train
1363
+ for batch in tqdm(
1364
+ train_loader,
1365
+ desc="Training...",
1366
+ position=1,
1367
+ leave=False,
1368
+ total=steps_per_epoch,
1369
+ disable=jax.process_index() > 0,
1370
+ ):
1371
+ # calculate delta time (we have a lag of one step but it's ok)
1372
+ train_time = time.perf_counter() - start_time
1373
+
1374
+ # set correct shape to batch
1375
+ # - add grad_step dim if gradient_accumulation_steps > 1
1376
+ # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
1377
+ bs_shape = (
1378
+ (batch_size_per_node_per_grad_step,)
1379
+ if not use_vmap_trick
1380
+ else (
1381
+ jax.local_device_count()
1382
+ // training_args.mp_devices, # local dp devices
1383
+ training_args.per_device_train_batch_size,
1384
+ )
1385
+ )
1386
+ if training_args.gradient_accumulation_steps > 1:
1387
+ # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1388
+ # to avoid any data redistribution when sharding
1389
+ bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape
1390
+
1391
+ # reshape batch
1392
+ batch = jax.tree_map(
1393
+ lambda x: x.reshape(bs_shape + x.shape[1:]),
1394
+ batch,
1395
+ )
1396
+ # freeze batch to pass safely to jax transforms
1397
+ batch = freeze(batch)
1398
+
1399
+ # train step
1400
+ state, train_metrics = p_train_step(state, batch, train_time)
1401
+ local_state["step"] += 1
1402
+ local_state["train_time"] = train_time
1403
+ local_state["train_samples"] += batch_size_per_step
1404
+
1405
+ if (
1406
+ local_state["step"] % training_args.logging_steps == 0
1407
+ and jax.process_index() == 0
1408
+ ):
1409
+ metrics_logger.update_state_metrics(local_state)
1410
+ metrics_logger.log(train_metrics, prefix="train")
1411
+
1412
+ eval_metrics = None
1413
+ if local_state["step"] % training_args.eval_steps == 0:
1414
+ eval_metrics = run_evaluation()
1415
+
1416
+ if local_state["step"] % training_args.save_steps == 0:
1417
+ run_save_model(state, eval_metrics)
1418
+
1419
+ # log final train metrics
1420
+ if train_metrics is not None:
1421
+ metrics_logger.update_state_metrics(state)
1422
+ metrics_logger.log(train_metrics, prefix="train")
1423
+
1424
+ epochs.write(
1425
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1426
+ )
1427
+
1428
+ # Final evaluation
1429
+ eval_metrics = run_evaluation()
1430
+
1431
+ # save checkpoint after each epoch
1432
+ run_save_model(state, eval_metrics)
1433
+
1434
+
1435
+ if __name__ == "__main__":
1436
+ main()