Spaces:
Running
Running
feat: implement transformer variants (#144)
Browse files* added DeepNet
* added Swin v2
* added NormFormer
* added RMSNorm
* added GLU variants
- .gitignore +1 -0
- README.md +88 -20
- src/dalle_mini/data.py +7 -4
- src/dalle_mini/model/configuration.py +29 -6
- src/dalle_mini/model/modeling.py +740 -90
- src/dalle_mini/model/partitions.py +14 -16
- src/dalle_mini/model/utils.py +3 -19
- tools/inference/inference_pipeline.ipynb +3 -24
- tools/train/config/medium/config.json +0 -1
- tools/train/config/mega/config.json +0 -1
- tools/train/config/micro/config.json +0 -1
- tools/train/config/mini/config.json +2 -5
- tools/train/train.py +143 -108
.gitignore
CHANGED
@@ -3,3 +3,4 @@ __pycache__
|
|
3 |
.streamlit
|
4 |
wandb/
|
5 |
*.egg-info/
|
|
|
|
3 |
.streamlit
|
4 |
wandb/
|
5 |
*.egg-info/
|
6 |
+
jax_cache/
|
README.md
CHANGED
@@ -94,26 +94,43 @@ Many thanks to the people who helped make it better:
|
|
94 |
|
95 |
- the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
|
96 |
- [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
|
|
|
97 |
- [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
|
98 |
|
99 |
## Citing DALL·E mini
|
100 |
|
101 |
If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
|
102 |
|
103 |
-
```
|
104 |
@misc{Dayma_DALL·E_Mini_2021,
|
105 |
-
author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
|
106 |
-
doi = {10.5281/zenodo.5146400},
|
107 |
-
month = {7},
|
108 |
-
title = {DALL·E Mini},
|
109 |
-
url = {https://github.com/borisdayma/dalle-mini},
|
110 |
-
year = {2021}
|
111 |
}
|
112 |
```
|
113 |
|
114 |
## References
|
115 |
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
@misc{ramesh2021zeroshot,
|
118 |
title={Zero-Shot Text-to-Image Generation},
|
119 |
author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
|
@@ -124,7 +141,18 @@ year = {2021}
|
|
124 |
}
|
125 |
```
|
126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
```
|
|
|
|
|
128 |
@misc{esser2021taming,
|
129 |
title={Taming Transformers for High-Resolution Image Synthesis},
|
130 |
author={Patrick Esser and Robin Rombach and Björn Ommer},
|
@@ -135,7 +163,7 @@ year = {2021}
|
|
135 |
}
|
136 |
```
|
137 |
|
138 |
-
```
|
139 |
@misc{lewis2019bart,
|
140 |
title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
|
141 |
author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
|
@@ -146,24 +174,64 @@ year = {2021}
|
|
146 |
}
|
147 |
```
|
148 |
|
149 |
-
```
|
150 |
-
@misc{
|
151 |
-
title={
|
152 |
-
author={
|
153 |
year={2021},
|
154 |
-
eprint={
|
155 |
archivePrefix={arXiv},
|
156 |
-
primaryClass={cs.
|
157 |
}
|
158 |
```
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
```
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
year={2021},
|
165 |
-
eprint={
|
166 |
archivePrefix={arXiv},
|
167 |
-
primaryClass={cs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
}
|
169 |
```
|
|
|
94 |
|
95 |
- the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
|
96 |
- [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
|
97 |
+
- [Phil Wang](https://github.com/lucidrains) has provided a lot of cool implementations of transformer variants and gives interesting insights with [x-transformers](https://github.com/lucidrains/x-transformers)
|
98 |
- [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
|
99 |
|
100 |
## Citing DALL·E mini
|
101 |
|
102 |
If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
|
103 |
|
104 |
+
```text
|
105 |
@misc{Dayma_DALL·E_Mini_2021,
|
106 |
+
author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
|
107 |
+
doi = {10.5281/zenodo.5146400},
|
108 |
+
month = {7},
|
109 |
+
title = {DALL·E Mini},
|
110 |
+
url = {https://github.com/borisdayma/dalle-mini},
|
111 |
+
year = {2021}
|
112 |
}
|
113 |
```
|
114 |
|
115 |
## References
|
116 |
|
117 |
+
Original DALL·E from "[Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092)" with image quantization from "[Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)".
|
118 |
+
|
119 |
+
Image encoder from "[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841v2)".
|
120 |
+
|
121 |
+
Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461v1)" with implementation of a few variants:
|
122 |
+
|
123 |
+
- "[GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)"
|
124 |
+
- "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
|
125 |
+
- "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
|
126 |
+
- "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
|
127 |
+
- "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
|
128 |
+
|
129 |
+
Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
|
130 |
+
|
131 |
+
### Citations
|
132 |
+
|
133 |
+
```text
|
134 |
@misc{ramesh2021zeroshot,
|
135 |
title={Zero-Shot Text-to-Image Generation},
|
136 |
author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
|
|
|
141 |
}
|
142 |
```
|
143 |
|
144 |
+
```text
|
145 |
+
@misc{radford2021learning,
|
146 |
+
title={Learning Transferable Visual Models From Natural Language Supervision},
|
147 |
+
author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
|
148 |
+
year={2021},
|
149 |
+
eprint={2103.00020},
|
150 |
+
archivePrefix={arXiv},
|
151 |
+
primaryClass={cs.CV}
|
152 |
+
}
|
153 |
```
|
154 |
+
|
155 |
+
```text
|
156 |
@misc{esser2021taming,
|
157 |
title={Taming Transformers for High-Resolution Image Synthesis},
|
158 |
author={Patrick Esser and Robin Rombach and Björn Ommer},
|
|
|
163 |
}
|
164 |
```
|
165 |
|
166 |
+
```text
|
167 |
@misc{lewis2019bart,
|
168 |
title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
|
169 |
author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
|
|
|
174 |
}
|
175 |
```
|
176 |
|
177 |
+
```text
|
178 |
+
@misc{anil2021scalable,
|
179 |
+
title={Scalable Second Order Optimization for Deep Learning},
|
180 |
+
author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
|
181 |
year={2021},
|
182 |
+
eprint={2002.09018},
|
183 |
archivePrefix={arXiv},
|
184 |
+
primaryClass={cs.LG}
|
185 |
}
|
186 |
```
|
187 |
|
188 |
+
```text
|
189 |
+
@misc{shazeer2020glu,
|
190 |
+
title={GLU Variants Improve Transformer},
|
191 |
+
author={Noam Shazeer},
|
192 |
+
year={2020},
|
193 |
+
url={https://arxiv.org/abs/2002.05202}
|
194 |
+
}
|
195 |
```
|
196 |
+
|
197 |
+
```text
|
198 |
+
@misc{wang_ma_dong_huang_zhang_wei_2022,
|
199 |
+
title={DeepNet: Scaling transformers to 1,000 layers},
|
200 |
+
author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
|
201 |
+
year={2022},
|
202 |
+
eprint={2203.00555}
|
203 |
+
archivePrefix={arXiv},
|
204 |
+
primaryClass={cs.LG}
|
205 |
+
}
|
206 |
+
```
|
207 |
+
|
208 |
+
```text
|
209 |
+
@misc{shleifer2021normformer,
|
210 |
+
title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
|
211 |
+
author={Sam Shleifer and Jason Weston and Myle Ott},
|
212 |
year={2021},
|
213 |
+
eprint={2110.09456},
|
214 |
archivePrefix={arXiv},
|
215 |
+
primaryClass={cs.CL}
|
216 |
+
}
|
217 |
+
```
|
218 |
+
|
219 |
+
```text
|
220 |
+
@inproceedings{liu2021swinv2,
|
221 |
+
title={Swin Transformer V2: Scaling Up Capacity and Resolution},
|
222 |
+
author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
|
223 |
+
booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
|
224 |
+
year={2022}
|
225 |
+
}
|
226 |
+
```
|
227 |
+
|
228 |
+
```text
|
229 |
+
@misc{zhang2019root,
|
230 |
+
title = {Root Mean Square Layer Normalization},
|
231 |
+
author = {Biao Zhang and Rico Sennrich},
|
232 |
+
year = {2019},
|
233 |
+
eprint = {1910.07467},
|
234 |
+
archivePrefix = {arXiv},
|
235 |
+
primaryClass = {cs.LG}
|
236 |
}
|
237 |
```
|
src/dalle_mini/data.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from dataclasses import dataclass, field
|
2 |
from functools import partial
|
3 |
|
@@ -39,6 +40,9 @@ class Dataset:
|
|
39 |
multi_hosts: bool = field(init=False)
|
40 |
|
41 |
def __post_init__(self):
|
|
|
|
|
|
|
42 |
self.multi_hosts = jax.process_count() > 1
|
43 |
# feed blank captions only in streaming mode for now
|
44 |
# otherwise dataset could be cached with same blanked captions
|
@@ -106,11 +110,10 @@ class Dataset:
|
|
106 |
if self.streaming:
|
107 |
# we need to shuffle early in streaming mode
|
108 |
if hasattr(self, "train_dataset"):
|
109 |
-
self.train_dataset = self.train_dataset.shuffle(
|
|
|
|
|
110 |
else:
|
111 |
-
# prepare rng for later shuffling
|
112 |
-
if self.seed_dataset is None:
|
113 |
-
self.seed_dataset = np.random.get_state()[1][0]
|
114 |
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
|
115 |
|
116 |
# filter data
|
|
|
1 |
+
import random
|
2 |
from dataclasses import dataclass, field
|
3 |
from functools import partial
|
4 |
|
|
|
40 |
multi_hosts: bool = field(init=False)
|
41 |
|
42 |
def __post_init__(self):
|
43 |
+
if self.seed_dataset is None:
|
44 |
+
# create a random seed
|
45 |
+
self.seed_dataset = random.randint(0, 2**32 - 1)
|
46 |
self.multi_hosts = jax.process_count() > 1
|
47 |
# feed blank captions only in streaming mode for now
|
48 |
# otherwise dataset could be cached with same blanked captions
|
|
|
110 |
if self.streaming:
|
111 |
# we need to shuffle early in streaming mode
|
112 |
if hasattr(self, "train_dataset"):
|
113 |
+
self.train_dataset = self.train_dataset.shuffle(
|
114 |
+
buffer_size=5000, seed=self.seed_dataset
|
115 |
+
)
|
116 |
else:
|
|
|
|
|
|
|
117 |
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
|
118 |
|
119 |
# filter data
|
src/dalle_mini/model/configuration.py
CHANGED
@@ -44,15 +44,12 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
44 |
decoder_layers=12,
|
45 |
decoder_ffn_dim=4096,
|
46 |
decoder_attention_heads=16,
|
47 |
-
encoder_layerdrop=0.0,
|
48 |
-
decoder_layerdrop=0.0,
|
49 |
activation_function="gelu",
|
50 |
d_model=1024,
|
51 |
dropout=0.1,
|
52 |
attention_dropout=0.0,
|
53 |
activation_dropout=0.0,
|
54 |
init_std=0.02,
|
55 |
-
classifier_dropout=0.0,
|
56 |
scale_embedding=False,
|
57 |
gradient_checkpointing=False,
|
58 |
use_cache=True,
|
@@ -60,9 +57,38 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
60 |
forced_eos_token_id=None,
|
61 |
tie_word_embeddings=False, # different modalities and sizes
|
62 |
do_sample=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
**kwargs,
|
64 |
):
|
|
|
65 |
self.normalize_text = normalize_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
self.encoder_vocab_size = encoder_vocab_size
|
67 |
self.image_vocab_size = image_vocab_size
|
68 |
self.image_length = image_length
|
@@ -79,9 +105,6 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
79 |
self.activation_dropout = activation_dropout
|
80 |
self.activation_function = activation_function
|
81 |
self.init_std = init_std
|
82 |
-
self.encoder_layerdrop = encoder_layerdrop
|
83 |
-
self.decoder_layerdrop = decoder_layerdrop
|
84 |
-
self.classifier_dropout = classifier_dropout
|
85 |
self.use_cache = use_cache
|
86 |
self.gradient_checkpointing = gradient_checkpointing
|
87 |
self.scale_embedding = (
|
|
|
44 |
decoder_layers=12,
|
45 |
decoder_ffn_dim=4096,
|
46 |
decoder_attention_heads=16,
|
|
|
|
|
47 |
activation_function="gelu",
|
48 |
d_model=1024,
|
49 |
dropout=0.1,
|
50 |
attention_dropout=0.0,
|
51 |
activation_dropout=0.0,
|
52 |
init_std=0.02,
|
|
|
53 |
scale_embedding=False,
|
54 |
gradient_checkpointing=False,
|
55 |
use_cache=True,
|
|
|
57 |
forced_eos_token_id=None,
|
58 |
tie_word_embeddings=False, # different modalities and sizes
|
59 |
do_sample=True,
|
60 |
+
# transformer variants
|
61 |
+
head_scale=False, # used in NormFormer
|
62 |
+
ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
|
63 |
+
ln_positions="deepnet", # layer normalization positions, "normformer", "swinv2", "deepnet" (same as post-ln)
|
64 |
+
use_cosine_attention=False, # used in Swin v2
|
65 |
+
tau_init=0.05, # used only in cosine attention (Swin v2)
|
66 |
+
use_deepnet_scaling=False, # used in Deepnet
|
67 |
+
use_glu=False, # "GLU Variants Improve Transformer"
|
68 |
**kwargs,
|
69 |
):
|
70 |
+
# text normalizer
|
71 |
self.normalize_text = normalize_text
|
72 |
+
|
73 |
+
# transformer variants
|
74 |
+
self.head_scale = head_scale # per Normformer
|
75 |
+
assert ln_type in [
|
76 |
+
"rmsnorm",
|
77 |
+
"layernorm",
|
78 |
+
], "ln_type must be 'rmsnorm' or 'layernorm'"
|
79 |
+
self.ln_type = ln_type
|
80 |
+
assert ln_positions in [
|
81 |
+
"normformer",
|
82 |
+
"swinv2",
|
83 |
+
"deepnet",
|
84 |
+
], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
|
85 |
+
self.ln_positions = ln_positions
|
86 |
+
self.use_cosine_attention = use_cosine_attention
|
87 |
+
self.tau_init = tau_init
|
88 |
+
self.use_deepnet_scaling = use_deepnet_scaling
|
89 |
+
self.use_glu = use_glu
|
90 |
+
|
91 |
+
# common parameters
|
92 |
self.encoder_vocab_size = encoder_vocab_size
|
93 |
self.image_vocab_size = image_vocab_size
|
94 |
self.image_length = image_length
|
|
|
105 |
self.activation_dropout = activation_dropout
|
106 |
self.activation_function = activation_function
|
107 |
self.init_std = init_std
|
|
|
|
|
|
|
108 |
self.use_cache = use_cache
|
109 |
self.gradient_checkpointing = gradient_checkpointing
|
110 |
self.scale_embedding = (
|
src/dalle_mini/model/modeling.py
CHANGED
@@ -18,7 +18,7 @@ import math
|
|
18 |
import os
|
19 |
from functools import partial
|
20 |
from pickle import UnpicklingError
|
21 |
-
from typing import Dict, Optional, Tuple, Union
|
22 |
|
23 |
import flax
|
24 |
import flax.linen as nn
|
@@ -26,7 +26,9 @@ import jax
|
|
26 |
import jax.numpy as jnp
|
27 |
import msgpack.exceptions
|
28 |
from flax.core.frozen_dict import unfreeze
|
29 |
-
from flax.linen import make_causal_mask
|
|
|
|
|
30 |
from flax.serialization import from_bytes
|
31 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
32 |
from jax import lax
|
@@ -42,6 +44,8 @@ from transformers.file_utils import (
|
|
42 |
)
|
43 |
from transformers.generation_flax_utils import FlaxSampleOutput
|
44 |
from transformers.modeling_flax_outputs import (
|
|
|
|
|
45 |
FlaxCausalLMOutputWithCrossAttentions,
|
46 |
FlaxSeq2SeqLMOutput,
|
47 |
)
|
@@ -49,11 +53,7 @@ from transformers.modeling_flax_utils import ACT2FN
|
|
49 |
from transformers.models.bart.modeling_flax_bart import (
|
50 |
FlaxBartAttention,
|
51 |
FlaxBartDecoder,
|
52 |
-
FlaxBartDecoderLayer,
|
53 |
-
FlaxBartDecoderLayerCollection,
|
54 |
FlaxBartEncoder,
|
55 |
-
FlaxBartEncoderLayer,
|
56 |
-
FlaxBartEncoderLayerCollection,
|
57 |
FlaxBartForConditionalGeneration,
|
58 |
FlaxBartForConditionalGenerationModule,
|
59 |
FlaxBartModule,
|
@@ -66,13 +66,124 @@ from .utils import PretrainedFromWandbMixin
|
|
66 |
|
67 |
logger = logging.get_logger(__name__)
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
class FlaxBartAttention(FlaxBartAttention):
|
71 |
"""
|
72 |
Edits:
|
73 |
- causal mask is used only in decoder and considers image_length
|
|
|
74 |
"""
|
75 |
|
|
|
|
|
76 |
def setup(self) -> None:
|
77 |
self.head_dim = self.embed_dim // self.num_heads
|
78 |
if self.head_dim * self.num_heads != self.embed_dim:
|
@@ -86,142 +197,667 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
86 |
self.embed_dim,
|
87 |
use_bias=self.bias,
|
88 |
dtype=self.dtype,
|
89 |
-
kernel_init=jax.nn.initializers.normal(self.config.init_std),
|
90 |
)
|
91 |
|
92 |
-
|
93 |
-
|
|
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
if self.causal:
|
98 |
# used only in decoder
|
99 |
self.causal_mask = make_causal_mask(
|
100 |
jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
|
101 |
)
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
- no bias
|
108 |
-
- use custom FlaxBartAttention
|
109 |
-
"""
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
dtype=self.dtype,
|
|
|
120 |
)
|
121 |
-
self.
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
dtype=self.dtype,
|
128 |
use_bias=False,
|
129 |
-
kernel_init=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
)
|
131 |
-
|
|
|
132 |
self.embed_dim,
|
133 |
dtype=self.dtype,
|
134 |
use_bias=False,
|
135 |
-
kernel_init=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
)
|
137 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
|
140 |
-
class
|
141 |
"""
|
142 |
Edits:
|
143 |
-
-
|
144 |
-
-
|
145 |
"""
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
)
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
]
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
|
159 |
|
160 |
-
class FlaxBartDecoderLayer(
|
161 |
"""
|
162 |
Edits:
|
163 |
- no bias
|
164 |
-
-
|
165 |
"""
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
config=self.config,
|
171 |
-
embed_dim=
|
172 |
num_heads=self.config.decoder_attention_heads,
|
173 |
dropout=self.config.attention_dropout,
|
174 |
causal=True,
|
175 |
bias=False,
|
176 |
dtype=self.dtype,
|
|
|
|
|
|
|
|
|
|
|
177 |
)
|
178 |
-
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
179 |
-
self.activation_fn = ACT2FN[self.config.activation_function]
|
180 |
-
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
181 |
|
182 |
-
self.
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
bias=False,
|
189 |
-
dtype=self.dtype,
|
190 |
)
|
191 |
-
|
192 |
-
self.
|
193 |
-
self.config.
|
194 |
-
|
195 |
-
|
196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
)
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
)
|
204 |
-
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
|
205 |
|
206 |
|
207 |
-
class FlaxBartDecoderLayerCollection(
|
|
|
|
|
208 |
"""
|
209 |
Edits:
|
210 |
- use custom FlaxBartDecoderLayer
|
211 |
- allow Gradient Checkpointing (nn.remat)
|
212 |
"""
|
213 |
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
if self.config.gradient_checkpointing
|
218 |
else FlaxBartDecoderLayer
|
219 |
)
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
]
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
|
227 |
class FlaxBartEncoder(FlaxBartEncoder):
|
@@ -246,10 +882,14 @@ class FlaxBartEncoder(FlaxBartEncoder):
|
|
246 |
self.embed_positions = nn.Embed(
|
247 |
self.config.max_text_length + self.offset,
|
248 |
embed_dim,
|
249 |
-
embedding_init=
|
|
|
|
|
250 |
)
|
251 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
252 |
-
self.layernorm_embedding =
|
|
|
|
|
253 |
|
254 |
|
255 |
class FlaxBartDecoder(FlaxBartDecoder):
|
@@ -276,11 +916,15 @@ class FlaxBartDecoder(FlaxBartDecoder):
|
|
276 |
self.embed_positions = nn.Embed(
|
277 |
self.config.image_length + self.offset, # image length for BOS
|
278 |
embed_dim,
|
279 |
-
embedding_init=
|
|
|
|
|
280 |
)
|
281 |
|
282 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
283 |
-
self.layernorm_embedding =
|
|
|
|
|
284 |
|
285 |
|
286 |
class FlaxBartModule(FlaxBartModule):
|
@@ -294,12 +938,16 @@ class FlaxBartModule(FlaxBartModule):
|
|
294 |
encoder_embed_tokens = nn.Embed(
|
295 |
self.config.encoder_vocab_size,
|
296 |
self.config.d_model,
|
297 |
-
embedding_init=
|
|
|
|
|
298 |
)
|
299 |
decoder_embed_tokens = nn.Embed(
|
300 |
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
|
301 |
self.config.d_model,
|
302 |
-
embedding_init=
|
|
|
|
|
303 |
)
|
304 |
|
305 |
self.encoder = FlaxBartEncoder(
|
@@ -639,7 +1287,9 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
|
|
639 |
+ 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
|
640 |
use_bias=False,
|
641 |
dtype=self.dtype,
|
642 |
-
kernel_init=
|
|
|
|
|
643 |
)
|
644 |
|
645 |
def __call__(
|
|
|
18 |
import os
|
19 |
from functools import partial
|
20 |
from pickle import UnpicklingError
|
21 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
22 |
|
23 |
import flax
|
24 |
import flax.linen as nn
|
|
|
26 |
import jax.numpy as jnp
|
27 |
import msgpack.exceptions
|
28 |
from flax.core.frozen_dict import unfreeze
|
29 |
+
from flax.linen import combine_masks, make_causal_mask
|
30 |
+
from flax.linen import partitioning as nn_partitioning
|
31 |
+
from flax.linen.attention import dot_product_attention_weights
|
32 |
from flax.serialization import from_bytes
|
33 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
34 |
from jax import lax
|
|
|
44 |
)
|
45 |
from transformers.generation_flax_utils import FlaxSampleOutput
|
46 |
from transformers.modeling_flax_outputs import (
|
47 |
+
FlaxBaseModelOutput,
|
48 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
49 |
FlaxCausalLMOutputWithCrossAttentions,
|
50 |
FlaxSeq2SeqLMOutput,
|
51 |
)
|
|
|
53 |
from transformers.models.bart.modeling_flax_bart import (
|
54 |
FlaxBartAttention,
|
55 |
FlaxBartDecoder,
|
|
|
|
|
56 |
FlaxBartEncoder,
|
|
|
|
|
57 |
FlaxBartForConditionalGeneration,
|
58 |
FlaxBartForConditionalGenerationModule,
|
59 |
FlaxBartModule,
|
|
|
66 |
|
67 |
logger = logging.get_logger(__name__)
|
68 |
|
69 |
+
remat = nn_partitioning.remat
|
70 |
+
|
71 |
+
|
72 |
+
# deepnet initialization
|
73 |
+
def deepnet_init(gain=1):
|
74 |
+
init = jax.nn.initializers.glorot_normal()
|
75 |
+
|
76 |
+
def _init(*args, **kwargs):
|
77 |
+
return gain * init(*args, **kwargs)
|
78 |
+
|
79 |
+
return _init
|
80 |
+
|
81 |
+
|
82 |
+
# deepnet gain
|
83 |
+
deepnet_gain = {
|
84 |
+
"encoder": {
|
85 |
+
"alpha": lambda config: 0.81
|
86 |
+
* (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
|
87 |
+
"beta": lambda config: 0.87
|
88 |
+
* (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
|
89 |
+
},
|
90 |
+
"decoder": {
|
91 |
+
"alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
|
92 |
+
"beta": lambda config: (12 * config.decoder_layers) ** -0.25,
|
93 |
+
},
|
94 |
+
}
|
95 |
+
|
96 |
+
|
97 |
+
class RMSNorm(nn.Module):
|
98 |
+
"""
|
99 |
+
From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
|
100 |
+
|
101 |
+
Adapted from flax.linen.LayerNorm
|
102 |
+
"""
|
103 |
+
|
104 |
+
epsilon: float = 1e-6
|
105 |
+
dtype: Any = jnp.float32
|
106 |
+
param_dtype: Any = jnp.float32
|
107 |
+
use_scale: bool = True
|
108 |
+
scale_init: Any = jax.nn.initializers.ones
|
109 |
+
|
110 |
+
@nn.compact
|
111 |
+
def __call__(self, x):
|
112 |
+
reduction_axes = (-1,)
|
113 |
+
feature_axes = (-1,)
|
114 |
+
|
115 |
+
rms_sq = self._compute_rms_sq(x, reduction_axes)
|
116 |
+
|
117 |
+
return self._normalize(
|
118 |
+
self,
|
119 |
+
x,
|
120 |
+
rms_sq,
|
121 |
+
reduction_axes,
|
122 |
+
feature_axes,
|
123 |
+
self.dtype,
|
124 |
+
self.param_dtype,
|
125 |
+
self.epsilon,
|
126 |
+
self.use_scale,
|
127 |
+
self.scale_init,
|
128 |
+
)
|
129 |
+
|
130 |
+
def _compute_rms_sq(self, x, axes):
|
131 |
+
x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
|
132 |
+
rms_sq = jnp.mean(jax.lax.square(x), axes)
|
133 |
+
return rms_sq
|
134 |
+
|
135 |
+
def _normalize(
|
136 |
+
self,
|
137 |
+
mdl,
|
138 |
+
x,
|
139 |
+
rms_sq,
|
140 |
+
reduction_axes,
|
141 |
+
feature_axes,
|
142 |
+
dtype,
|
143 |
+
param_dtype,
|
144 |
+
epsilon,
|
145 |
+
use_scale,
|
146 |
+
scale_init,
|
147 |
+
):
|
148 |
+
reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
|
149 |
+
feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
|
150 |
+
stats_shape = list(x.shape)
|
151 |
+
for axis in reduction_axes:
|
152 |
+
stats_shape[axis] = 1
|
153 |
+
rms_sq = rms_sq.reshape(stats_shape)
|
154 |
+
feature_shape = [1] * x.ndim
|
155 |
+
reduced_feature_shape = []
|
156 |
+
for ax in feature_axes:
|
157 |
+
feature_shape[ax] = x.shape[ax]
|
158 |
+
reduced_feature_shape.append(x.shape[ax])
|
159 |
+
mul = lax.rsqrt(rms_sq + epsilon)
|
160 |
+
if use_scale:
|
161 |
+
scale = mdl.param(
|
162 |
+
"scale", scale_init, reduced_feature_shape, param_dtype
|
163 |
+
).reshape(feature_shape)
|
164 |
+
mul *= scale
|
165 |
+
y = mul * x
|
166 |
+
return jnp.asarray(y, dtype)
|
167 |
+
|
168 |
+
|
169 |
+
def norm(type, *args, **kwargs):
|
170 |
+
if type == "rmsnorm":
|
171 |
+
return RMSNorm(*args, **kwargs)
|
172 |
+
elif type == "layernorm":
|
173 |
+
return nn.LayerNorm(*args, **kwargs)
|
174 |
+
else:
|
175 |
+
raise ValueError(f"Unknown norm type {type}")
|
176 |
+
|
177 |
|
178 |
class FlaxBartAttention(FlaxBartAttention):
|
179 |
"""
|
180 |
Edits:
|
181 |
- causal mask is used only in decoder and considers image_length
|
182 |
+
- scale attention heads per NormFormer paper
|
183 |
"""
|
184 |
|
185 |
+
is_encoder: bool = False
|
186 |
+
|
187 |
def setup(self) -> None:
|
188 |
self.head_dim = self.embed_dim // self.num_heads
|
189 |
if self.head_dim * self.num_heads != self.embed_dim:
|
|
|
197 |
self.embed_dim,
|
198 |
use_bias=self.bias,
|
199 |
dtype=self.dtype,
|
|
|
200 |
)
|
201 |
|
202 |
+
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
|
203 |
+
self.config
|
204 |
+
)
|
205 |
|
206 |
+
self.q_proj = dense(
|
207 |
+
kernel_init=deepnet_init()
|
208 |
+
if self.config.use_deepnet_scaling
|
209 |
+
else jax.nn.initializers.normal(self.config.init_std)
|
210 |
+
)
|
211 |
+
self.k_proj = dense(
|
212 |
+
kernel_init=deepnet_init()
|
213 |
+
if self.config.use_deepnet_scaling
|
214 |
+
else jax.nn.initializers.normal(self.config.init_std)
|
215 |
+
)
|
216 |
+
self.v_proj = dense(
|
217 |
+
kernel_init=deepnet_init(gain)
|
218 |
+
if self.config.use_deepnet_scaling
|
219 |
+
else jax.nn.initializers.normal(self.config.init_std)
|
220 |
+
)
|
221 |
+
self.out_proj = dense(
|
222 |
+
kernel_init=deepnet_init(gain)
|
223 |
+
if self.config.use_deepnet_scaling
|
224 |
+
else jax.nn.initializers.normal(self.config.init_std)
|
225 |
+
)
|
226 |
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
227 |
|
228 |
+
if self.config.head_scale:
|
229 |
+
self.head_scale = self.param(
|
230 |
+
"head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
|
231 |
+
)
|
232 |
+
|
233 |
+
if self.config.use_cosine_attention:
|
234 |
+
self.tau = self.param(
|
235 |
+
"tau",
|
236 |
+
jax.nn.initializers.constant(self.config.tau_init),
|
237 |
+
(1, self.num_heads, 1, 1),
|
238 |
+
)
|
239 |
+
|
240 |
if self.causal:
|
241 |
# used only in decoder
|
242 |
self.causal_mask = make_causal_mask(
|
243 |
jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
|
244 |
)
|
245 |
|
246 |
+
def __call__(
|
247 |
+
self,
|
248 |
+
hidden_states: jnp.ndarray,
|
249 |
+
key_value_states: Optional[jnp.ndarray] = None,
|
250 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
251 |
+
init_cache: bool = False,
|
252 |
+
deterministic: bool = True,
|
253 |
+
) -> Tuple[jnp.ndarray]:
|
254 |
+
"""Input shape: Batch x Time x Channel"""
|
255 |
+
|
256 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
257 |
+
# for the decoder
|
258 |
+
is_cross_attention = key_value_states is not None
|
259 |
+
batch_size = hidden_states.shape[0]
|
260 |
+
|
261 |
+
# get query proj
|
262 |
+
query_states = self.q_proj(hidden_states)
|
263 |
+
# get key, value proj
|
264 |
+
if is_cross_attention:
|
265 |
+
# cross_attentions
|
266 |
+
key_states = self.k_proj(key_value_states)
|
267 |
+
value_states = self.v_proj(key_value_states)
|
268 |
+
else:
|
269 |
+
# self_attention
|
270 |
+
key_states = self.k_proj(hidden_states)
|
271 |
+
value_states = self.v_proj(hidden_states)
|
272 |
|
273 |
+
query_states = self._split_heads(query_states)
|
274 |
+
key_states = self._split_heads(key_states)
|
275 |
+
value_states = self._split_heads(value_states)
|
|
|
|
|
|
|
276 |
|
277 |
+
# handle cache prepare causal attention mask
|
278 |
+
if self.causal:
|
279 |
+
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
280 |
+
if self.has_variable("cache", "cached_key"):
|
281 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
282 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
|
283 |
+
causal_mask = lax.dynamic_slice(
|
284 |
+
self.causal_mask,
|
285 |
+
(0, 0, mask_shift, 0),
|
286 |
+
(1, 1, query_length, max_decoder_length),
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
290 |
+
causal_mask = jnp.broadcast_to(
|
291 |
+
causal_mask, (batch_size,) + causal_mask.shape[1:]
|
292 |
+
)
|
293 |
+
|
294 |
+
# combine masks if needed
|
295 |
+
if attention_mask is not None and self.causal:
|
296 |
+
attention_mask = jnp.broadcast_to(
|
297 |
+
jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
|
298 |
+
)
|
299 |
+
attention_mask = combine_masks(attention_mask, causal_mask)
|
300 |
+
elif self.causal:
|
301 |
+
attention_mask = causal_mask
|
302 |
+
elif attention_mask is not None:
|
303 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
304 |
+
|
305 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
306 |
+
# and cache the keys and values step by step.
|
307 |
+
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
308 |
+
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
309 |
+
key_states, value_states, query_states, attention_mask
|
310 |
+
)
|
311 |
+
|
312 |
+
# Convert the boolean attention mask to an attention bias.
|
313 |
+
if attention_mask is not None:
|
314 |
+
# attention mask in the form of attention bias
|
315 |
+
attention_bias = lax.select(
|
316 |
+
attention_mask > 0,
|
317 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
318 |
+
jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
|
319 |
+
)
|
320 |
+
else:
|
321 |
+
attention_bias = None
|
322 |
+
|
323 |
+
dropout_rng = None
|
324 |
+
if not deterministic and self.dropout > 0.0:
|
325 |
+
dropout_rng = self.make_rng("dropout")
|
326 |
+
|
327 |
+
if self.config.use_cosine_attention:
|
328 |
+
# normalize q and k
|
329 |
+
query_states = query_states / (
|
330 |
+
jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
|
331 |
+
)
|
332 |
+
key_states = key_states / (
|
333 |
+
jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
|
334 |
+
)
|
335 |
+
attn_weights = dot_product_attention_weights(
|
336 |
+
query_states,
|
337 |
+
key_states,
|
338 |
+
bias=attention_bias,
|
339 |
+
dropout_rng=dropout_rng,
|
340 |
+
dropout_rate=self.dropout,
|
341 |
+
broadcast_dropout=True,
|
342 |
+
deterministic=deterministic,
|
343 |
dtype=self.dtype,
|
344 |
+
precision=None,
|
345 |
)
|
346 |
+
if self.config.use_cosine_attention:
|
347 |
+
# divide by tau
|
348 |
+
attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
|
349 |
+
|
350 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
351 |
+
if self.config.head_scale:
|
352 |
+
# per Normformer
|
353 |
+
attn_output = attn_output * self.head_scale
|
354 |
+
attn_output = self._merge_heads(attn_output)
|
355 |
+
attn_output = self.out_proj(attn_output)
|
356 |
+
|
357 |
+
return attn_output, attn_weights
|
358 |
+
|
359 |
+
|
360 |
+
class GLU(nn.Module):
|
361 |
+
"""From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202"""
|
362 |
+
|
363 |
+
config: DalleBartConfig
|
364 |
+
ffn_dim: int
|
365 |
+
embed_dim: int
|
366 |
+
dtype: jnp.dtype = jnp.float32
|
367 |
+
is_encoder: bool = False
|
368 |
+
|
369 |
+
@nn.compact
|
370 |
+
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
|
371 |
+
|
372 |
+
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
|
373 |
+
self.config
|
374 |
+
)
|
375 |
+
|
376 |
+
if self.config.ln_positions in ["normformer"]:
|
377 |
+
x = norm(
|
378 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
|
379 |
+
)(x)
|
380 |
+
w = nn.Dense(
|
381 |
+
self.ffn_dim,
|
382 |
+
dtype=self.dtype,
|
383 |
+
use_bias=False,
|
384 |
+
kernel_init=deepnet_init(gain)
|
385 |
+
if self.config.use_deepnet_scaling
|
386 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
387 |
+
)(x)
|
388 |
+
w = ACT2FN[self.config.activation_function](w)
|
389 |
+
v = nn.Dense(
|
390 |
+
self.ffn_dim,
|
391 |
dtype=self.dtype,
|
392 |
use_bias=False,
|
393 |
+
kernel_init=deepnet_init(gain)
|
394 |
+
if self.config.use_deepnet_scaling
|
395 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
396 |
+
)(x)
|
397 |
+
x = w * v
|
398 |
+
if self.config.ln_positions in ["normformer"]:
|
399 |
+
x = norm(
|
400 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
|
401 |
+
)(x)
|
402 |
+
x = nn.Dropout(rate=self.config.activation_dropout)(
|
403 |
+
x, deterministic=deterministic
|
404 |
)
|
405 |
+
|
406 |
+
x = nn.Dense(
|
407 |
self.embed_dim,
|
408 |
dtype=self.dtype,
|
409 |
use_bias=False,
|
410 |
+
kernel_init=deepnet_init(gain)
|
411 |
+
if self.config.use_deepnet_scaling
|
412 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
413 |
+
)(x)
|
414 |
+
if self.config.ln_positions in ["swinv2"]:
|
415 |
+
x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
|
416 |
+
x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
|
417 |
+
return x
|
418 |
+
|
419 |
+
|
420 |
+
class FFN(nn.Module):
|
421 |
+
"""Simple FFN layer"""
|
422 |
+
|
423 |
+
config: DalleBartConfig
|
424 |
+
ffn_dim: int
|
425 |
+
embed_dim: int
|
426 |
+
dtype: jnp.dtype = jnp.float32
|
427 |
+
is_encoder: bool = False
|
428 |
+
|
429 |
+
@nn.compact
|
430 |
+
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
|
431 |
+
|
432 |
+
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
|
433 |
+
self.config
|
434 |
)
|
435 |
+
if self.config.ln_positions in ["normformer"]:
|
436 |
+
x = norm(
|
437 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
|
438 |
+
)(x)
|
439 |
+
x = nn.Dense(
|
440 |
+
self.ffn_dim,
|
441 |
+
dtype=self.dtype,
|
442 |
+
use_bias=False,
|
443 |
+
kernel_init=deepnet_init(gain)
|
444 |
+
if self.config.use_deepnet_scaling
|
445 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
446 |
+
)(x)
|
447 |
+
x = ACT2FN[self.config.activation_function](x)
|
448 |
+
if self.config.ln_positions in ["normformer"]:
|
449 |
+
x = norm(
|
450 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
|
451 |
+
)(x)
|
452 |
+
x = nn.Dropout(rate=self.config.activation_dropout)(
|
453 |
+
x, deterministic=deterministic
|
454 |
+
)
|
455 |
+
x = nn.Dense(
|
456 |
+
self.embed_dim,
|
457 |
+
dtype=self.dtype,
|
458 |
+
use_bias=False,
|
459 |
+
kernel_init=deepnet_init(gain)
|
460 |
+
if self.config.use_deepnet_scaling
|
461 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
462 |
+
)(x)
|
463 |
+
if self.config.ln_positions in ["swinv2"]:
|
464 |
+
x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
|
465 |
+
x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
|
466 |
+
return x
|
467 |
|
468 |
|
469 |
+
class FlaxBartEncoderLayer(nn.Module):
|
470 |
"""
|
471 |
Edits:
|
472 |
+
- no bias
|
473 |
+
- use custom FlaxBartAttention
|
474 |
"""
|
475 |
|
476 |
+
config: DalleBartConfig
|
477 |
+
dtype: jnp.dtype = jnp.float32
|
478 |
+
add_norm: bool = False
|
479 |
+
use_scale: bool = True
|
480 |
+
|
481 |
+
@nn.compact
|
482 |
+
def __call__(
|
483 |
+
self,
|
484 |
+
hidden_states: jnp.ndarray,
|
485 |
+
attention_mask: jnp.ndarray,
|
486 |
+
output_attentions: bool = True,
|
487 |
+
deterministic: bool = True,
|
488 |
+
) -> Tuple[jnp.ndarray]:
|
489 |
+
|
490 |
+
res_gain = (
|
491 |
+
deepnet_gain["encoder"]["alpha"](self.config)
|
492 |
+
if self.config.use_deepnet_scaling
|
493 |
+
else 1
|
494 |
)
|
495 |
+
|
496 |
+
embed_dim = self.config.d_model
|
497 |
+
residual = hidden_states
|
498 |
+
if self.config.ln_positions in ["normformer"]:
|
499 |
+
hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
|
500 |
+
hidden_states
|
501 |
+
)
|
502 |
+
hidden_states, attn_weights = FlaxBartAttention(
|
503 |
+
config=self.config,
|
504 |
+
embed_dim=embed_dim,
|
505 |
+
num_heads=self.config.encoder_attention_heads,
|
506 |
+
dropout=self.config.attention_dropout,
|
507 |
+
bias=False,
|
508 |
+
dtype=self.dtype,
|
509 |
+
is_encoder=True,
|
510 |
+
)(hidden_states=hidden_states, attention_mask=attention_mask)
|
511 |
+
|
512 |
+
if self.config.ln_positions in ["normformer", "swinv2"]:
|
513 |
+
hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
|
514 |
+
hidden_states
|
515 |
+
)
|
516 |
+
hidden_states = nn.Dropout(rate=self.config.dropout)(
|
517 |
+
hidden_states, deterministic=deterministic
|
518 |
+
)
|
519 |
+
hidden_states = residual * res_gain + hidden_states
|
520 |
+
if self.config.ln_positions in ["deepnet"]:
|
521 |
+
hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
|
522 |
+
hidden_states
|
523 |
+
)
|
524 |
+
|
525 |
+
residual = hidden_states
|
526 |
+
ff_block = (
|
527 |
+
GLU(
|
528 |
+
config=self.config,
|
529 |
+
ffn_dim=self.config.encoder_ffn_dim,
|
530 |
+
embed_dim=embed_dim,
|
531 |
+
dtype=self.dtype,
|
532 |
+
is_encoder=True,
|
533 |
+
)
|
534 |
+
if self.config.use_glu
|
535 |
+
else FFN(
|
536 |
+
config=self.config,
|
537 |
+
ffn_dim=self.config.encoder_ffn_dim,
|
538 |
+
embed_dim=embed_dim,
|
539 |
+
dtype=self.dtype,
|
540 |
+
is_encoder=True,
|
541 |
+
)
|
542 |
+
)
|
543 |
+
hidden_states = ff_block(hidden_states, deterministic=deterministic)
|
544 |
+
hidden_states = residual * res_gain + hidden_states
|
545 |
+
if self.add_norm or self.config.ln_positions in ["deepnet"]:
|
546 |
+
use_scale = self.use_scale or self.config.ln_positions == "deepnet"
|
547 |
+
hidden_states = norm(
|
548 |
+
self.config.ln_type,
|
549 |
+
dtype=self.dtype,
|
550 |
+
epsilon=1e-05,
|
551 |
+
use_scale=use_scale,
|
552 |
+
)(hidden_states)
|
553 |
+
|
554 |
+
outputs = (hidden_states,)
|
555 |
+
|
556 |
+
if output_attentions:
|
557 |
+
outputs += (attn_weights,)
|
558 |
+
|
559 |
+
return outputs
|
560 |
|
561 |
|
562 |
+
class FlaxBartDecoderLayer(nn.Module):
|
563 |
"""
|
564 |
Edits:
|
565 |
- no bias
|
566 |
+
- use custom FlaxBartAttention
|
567 |
"""
|
568 |
|
569 |
+
config: DalleBartConfig
|
570 |
+
dtype: jnp.dtype = jnp.float32
|
571 |
+
add_norm: bool = False
|
572 |
+
use_scale: bool = False
|
573 |
+
|
574 |
+
@nn.compact
|
575 |
+
def __call__(
|
576 |
+
self,
|
577 |
+
hidden_states: jnp.ndarray,
|
578 |
+
attention_mask: jnp.ndarray,
|
579 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
580 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
581 |
+
init_cache: bool = False,
|
582 |
+
output_attentions: bool = True,
|
583 |
+
deterministic: bool = True,
|
584 |
+
) -> Tuple[jnp.ndarray]:
|
585 |
+
|
586 |
+
res_gain = (
|
587 |
+
deepnet_gain["decoder"]["alpha"](self.config)
|
588 |
+
if self.config.use_deepnet_scaling
|
589 |
+
else 1
|
590 |
+
)
|
591 |
+
|
592 |
+
embed_dim = self.config.d_model
|
593 |
+
residual = hidden_states
|
594 |
+
|
595 |
+
# Self Attention
|
596 |
+
if self.config.ln_positions in ["normformer"]:
|
597 |
+
hidden_states = norm(
|
598 |
+
self.config.ln_type,
|
599 |
+
dtype=self.dtype,
|
600 |
+
epsilon=1e-05,
|
601 |
+
use_scale=False,
|
602 |
+
)(hidden_states)
|
603 |
+
hidden_states, attn_weights = FlaxBartAttention(
|
604 |
config=self.config,
|
605 |
+
embed_dim=embed_dim,
|
606 |
num_heads=self.config.decoder_attention_heads,
|
607 |
dropout=self.config.attention_dropout,
|
608 |
causal=True,
|
609 |
bias=False,
|
610 |
dtype=self.dtype,
|
611 |
+
is_encoder=False,
|
612 |
+
)(
|
613 |
+
hidden_states=hidden_states,
|
614 |
+
attention_mask=attention_mask,
|
615 |
+
init_cache=init_cache,
|
616 |
)
|
|
|
|
|
|
|
617 |
|
618 |
+
if self.config.ln_positions in ["normformer", "swinv2"]:
|
619 |
+
hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
|
620 |
+
hidden_states
|
621 |
+
)
|
622 |
+
hidden_states = nn.Dropout(rate=self.config.dropout)(
|
623 |
+
hidden_states, deterministic=deterministic
|
|
|
|
|
624 |
)
|
625 |
+
hidden_states = residual * res_gain + hidden_states
|
626 |
+
if self.config.ln_positions in ["deepnet"]:
|
627 |
+
hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
|
628 |
+
hidden_states
|
629 |
+
)
|
630 |
+
|
631 |
+
# Cross Attention
|
632 |
+
cross_attn_weights = None
|
633 |
+
if encoder_hidden_states is not None:
|
634 |
+
residual = hidden_states
|
635 |
+
if self.config.ln_positions in ["normformer"]:
|
636 |
+
hidden_states = norm(
|
637 |
+
self.config.ln_type,
|
638 |
+
dtype=self.dtype,
|
639 |
+
epsilon=1e-05,
|
640 |
+
use_scale=False,
|
641 |
+
)(hidden_states)
|
642 |
+
hidden_states, cross_attn_weights = FlaxBartAttention(
|
643 |
+
config=self.config,
|
644 |
+
embed_dim=embed_dim,
|
645 |
+
num_heads=self.config.decoder_attention_heads,
|
646 |
+
dropout=self.config.attention_dropout,
|
647 |
+
bias=False,
|
648 |
+
dtype=self.dtype,
|
649 |
+
is_encoder=False,
|
650 |
+
)(
|
651 |
+
hidden_states=hidden_states,
|
652 |
+
key_value_states=encoder_hidden_states,
|
653 |
+
attention_mask=encoder_attention_mask,
|
654 |
+
)
|
655 |
+
if self.config.ln_positions in ["normformer", "swinv2"]:
|
656 |
+
hidden_states = norm(
|
657 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
658 |
+
)(hidden_states)
|
659 |
+
hidden_states = nn.Dropout(rate=self.config.dropout)(
|
660 |
+
hidden_states, deterministic=deterministic
|
661 |
+
)
|
662 |
+
hidden_states = residual * res_gain + hidden_states
|
663 |
+
if self.config.ln_positions in ["deepnet"]:
|
664 |
+
hidden_states = norm(
|
665 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
666 |
+
)(hidden_states)
|
667 |
+
|
668 |
+
# Feed forward
|
669 |
+
residual = hidden_states
|
670 |
+
ff_block = (
|
671 |
+
GLU(
|
672 |
+
config=self.config,
|
673 |
+
ffn_dim=self.config.decoder_ffn_dim,
|
674 |
+
embed_dim=embed_dim,
|
675 |
+
dtype=self.dtype,
|
676 |
+
is_encoder=False,
|
677 |
+
)
|
678 |
+
if self.config.use_glu
|
679 |
+
else FFN(
|
680 |
+
config=self.config,
|
681 |
+
ffn_dim=self.config.decoder_ffn_dim,
|
682 |
+
embed_dim=embed_dim,
|
683 |
+
dtype=self.dtype,
|
684 |
+
is_encoder=False,
|
685 |
+
)
|
686 |
)
|
687 |
+
hidden_states = ff_block(hidden_states, deterministic=deterministic)
|
688 |
+
hidden_states = residual * res_gain + hidden_states
|
689 |
+
if self.add_norm or self.config.ln_positions in ["deepnet"]:
|
690 |
+
use_scale = self.use_scale or self.config.ln_positions == "deepnet"
|
691 |
+
hidden_states = norm(
|
692 |
+
self.config.ln_type,
|
693 |
+
dtype=self.dtype,
|
694 |
+
epsilon=1e-05,
|
695 |
+
use_scale=use_scale,
|
696 |
+
)(hidden_states)
|
697 |
+
|
698 |
+
outputs = (hidden_states,)
|
699 |
+
|
700 |
+
if output_attentions:
|
701 |
+
outputs += (attn_weights, cross_attn_weights)
|
702 |
+
|
703 |
+
return outputs
|
704 |
+
|
705 |
+
|
706 |
+
class FlaxBartEncoderLayerCollection(nn.Module):
|
707 |
+
config: DalleBartConfig
|
708 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
709 |
+
"""
|
710 |
+
Edits:
|
711 |
+
- use custom FlaxBartEncoderLayer
|
712 |
+
- allow Gradient Checkpointing (nn.remat)
|
713 |
+
"""
|
714 |
+
|
715 |
+
@nn.compact
|
716 |
+
def __call__(
|
717 |
+
self,
|
718 |
+
hidden_states,
|
719 |
+
attention_mask,
|
720 |
+
deterministic: bool = True,
|
721 |
+
output_attentions: bool = False,
|
722 |
+
output_hidden_states: bool = False,
|
723 |
+
return_dict: bool = True,
|
724 |
+
):
|
725 |
+
all_hidden_states = () if output_hidden_states else None
|
726 |
+
all_self_attns = () if output_attentions else None
|
727 |
+
|
728 |
+
n_layers = self.config.encoder_layers
|
729 |
+
layer = (
|
730 |
+
remat(FlaxBartEncoderLayer, static_argnums=(2, 3))
|
731 |
+
if self.config.gradient_checkpointing
|
732 |
+
else FlaxBartEncoderLayer
|
733 |
+
)
|
734 |
+
for i in range(n_layers):
|
735 |
+
if output_hidden_states:
|
736 |
+
all_hidden_states += (hidden_states,)
|
737 |
+
# final layernorm on the output of the last layer
|
738 |
+
# or every 6 layers for Swin v2
|
739 |
+
# ignored args for deepnet which always add a norm with scale
|
740 |
+
add_norm = (i == n_layers - 1) or (
|
741 |
+
(self.config.ln_positions == "swinv2") and ((i + 1) % 6 == 0)
|
742 |
+
)
|
743 |
+
# we don't need to scale the norm for the last layer
|
744 |
+
use_scale = i != n_layers - 1
|
745 |
+
layer_outputs = layer(
|
746 |
+
self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
|
747 |
+
)(
|
748 |
+
hidden_states,
|
749 |
+
attention_mask,
|
750 |
+
output_attentions,
|
751 |
+
deterministic,
|
752 |
+
)
|
753 |
+
hidden_states = layer_outputs[0]
|
754 |
+
if output_attentions:
|
755 |
+
all_self_attns += (layer_outputs[1],)
|
756 |
+
|
757 |
+
# add hidden states from the last layer
|
758 |
+
if output_hidden_states:
|
759 |
+
all_hidden_states += (hidden_states,)
|
760 |
+
|
761 |
+
outputs = [
|
762 |
+
hidden_states,
|
763 |
+
all_hidden_states,
|
764 |
+
all_self_attns,
|
765 |
+
]
|
766 |
+
|
767 |
+
if not return_dict:
|
768 |
+
return tuple(v for v in outputs if v is not None)
|
769 |
+
|
770 |
+
return FlaxBaseModelOutput(
|
771 |
+
last_hidden_state=hidden_states,
|
772 |
+
hidden_states=all_hidden_states,
|
773 |
+
attentions=all_self_attns,
|
774 |
)
|
|
|
775 |
|
776 |
|
777 |
+
class FlaxBartDecoderLayerCollection(nn.Module):
|
778 |
+
config: DalleBartConfig
|
779 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
780 |
"""
|
781 |
Edits:
|
782 |
- use custom FlaxBartDecoderLayer
|
783 |
- allow Gradient Checkpointing (nn.remat)
|
784 |
"""
|
785 |
|
786 |
+
@nn.compact
|
787 |
+
def __call__(
|
788 |
+
self,
|
789 |
+
hidden_states,
|
790 |
+
attention_mask,
|
791 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
792 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
793 |
+
deterministic: bool = True,
|
794 |
+
init_cache: bool = False,
|
795 |
+
output_attentions: bool = False,
|
796 |
+
output_hidden_states: bool = False,
|
797 |
+
return_dict: bool = True,
|
798 |
+
):
|
799 |
+
# decoder layers
|
800 |
+
all_hidden_states = () if output_hidden_states else None
|
801 |
+
all_self_attns = () if output_attentions else None
|
802 |
+
all_cross_attentions = (
|
803 |
+
() if (output_attentions and encoder_hidden_states is not None) else None
|
804 |
+
)
|
805 |
+
|
806 |
+
n_layers = self.config.decoder_layers
|
807 |
+
layer = (
|
808 |
+
remat(FlaxBartDecoderLayer, static_argnums=(4, 5, 6))
|
809 |
if self.config.gradient_checkpointing
|
810 |
else FlaxBartDecoderLayer
|
811 |
)
|
812 |
+
for i in range(n_layers):
|
813 |
+
if output_hidden_states:
|
814 |
+
all_hidden_states += (hidden_states,)
|
815 |
+
# final layernorm on the output of the last layer
|
816 |
+
# or every 6 layers for Swin v2
|
817 |
+
add_norm = (i == n_layers - 1) or (
|
818 |
+
(self.config.ln_positions == "swinv2") and ((i + 1) % 6 == 0)
|
819 |
+
)
|
820 |
+
# we don't need to scale the norm for the last layer
|
821 |
+
use_scale = i != n_layers - 1
|
822 |
+
layer_outputs = layer(
|
823 |
+
self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
|
824 |
+
)(
|
825 |
+
hidden_states,
|
826 |
+
attention_mask,
|
827 |
+
encoder_hidden_states,
|
828 |
+
encoder_attention_mask,
|
829 |
+
init_cache,
|
830 |
+
output_attentions,
|
831 |
+
deterministic,
|
832 |
+
)
|
833 |
+
|
834 |
+
hidden_states = layer_outputs[0]
|
835 |
+
if output_attentions:
|
836 |
+
all_self_attns += (layer_outputs[1],)
|
837 |
+
|
838 |
+
if encoder_hidden_states is not None:
|
839 |
+
all_cross_attentions += (layer_outputs[2],)
|
840 |
+
|
841 |
+
# add hidden states from the last decoder layer
|
842 |
+
if output_hidden_states:
|
843 |
+
all_hidden_states += (hidden_states,)
|
844 |
+
|
845 |
+
outputs = [
|
846 |
+
hidden_states,
|
847 |
+
all_hidden_states,
|
848 |
+
all_self_attns,
|
849 |
+
all_cross_attentions,
|
850 |
]
|
851 |
+
|
852 |
+
if not return_dict:
|
853 |
+
return tuple(v for v in outputs if v is not None)
|
854 |
+
|
855 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
856 |
+
last_hidden_state=hidden_states,
|
857 |
+
hidden_states=all_hidden_states,
|
858 |
+
attentions=all_self_attns,
|
859 |
+
cross_attentions=all_cross_attentions,
|
860 |
+
)
|
861 |
|
862 |
|
863 |
class FlaxBartEncoder(FlaxBartEncoder):
|
|
|
882 |
self.embed_positions = nn.Embed(
|
883 |
self.config.max_text_length + self.offset,
|
884 |
embed_dim,
|
885 |
+
embedding_init=deepnet_init()
|
886 |
+
if self.config.use_deepnet_scaling
|
887 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
888 |
)
|
889 |
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
|
890 |
+
self.layernorm_embedding = norm(
|
891 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
892 |
+
)
|
893 |
|
894 |
|
895 |
class FlaxBartDecoder(FlaxBartDecoder):
|
|
|
916 |
self.embed_positions = nn.Embed(
|
917 |
self.config.image_length + self.offset, # image length for BOS
|
918 |
embed_dim,
|
919 |
+
embedding_init=deepnet_init()
|
920 |
+
if self.config.use_deepnet_scaling
|
921 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
922 |
)
|
923 |
|
924 |
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
|
925 |
+
self.layernorm_embedding = norm(
|
926 |
+
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
|
927 |
+
)
|
928 |
|
929 |
|
930 |
class FlaxBartModule(FlaxBartModule):
|
|
|
938 |
encoder_embed_tokens = nn.Embed(
|
939 |
self.config.encoder_vocab_size,
|
940 |
self.config.d_model,
|
941 |
+
embedding_init=deepnet_init()
|
942 |
+
if self.config.use_deepnet_scaling
|
943 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
944 |
)
|
945 |
decoder_embed_tokens = nn.Embed(
|
946 |
self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
|
947 |
self.config.d_model,
|
948 |
+
embedding_init=deepnet_init()
|
949 |
+
if self.config.use_deepnet_scaling
|
950 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
951 |
)
|
952 |
|
953 |
self.encoder = FlaxBartEncoder(
|
|
|
1287 |
+ 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
|
1288 |
use_bias=False,
|
1289 |
dtype=self.dtype,
|
1290 |
+
kernel_init=deepnet_init()
|
1291 |
+
if self.config.use_deepnet_scaling
|
1292 |
+
else jax.nn.initializers.normal(self.config.init_std),
|
1293 |
)
|
1294 |
|
1295 |
def __call__(
|
src/dalle_mini/model/partitions.py
CHANGED
@@ -36,23 +36,21 @@ def _replacement_rules(rules):
|
|
36 |
def _get_partition_rules():
|
37 |
return [
|
38 |
# embeddings
|
39 |
-
((
|
40 |
-
((
|
41 |
-
#
|
42 |
-
((
|
43 |
-
((
|
44 |
-
# enc-dec attention
|
45 |
-
((r"encoder_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
46 |
-
((r"encoder_attn", "out_proj", "kernel"), P("mp", None)),
|
47 |
# FFN
|
48 |
-
((
|
49 |
-
((
|
|
|
|
|
50 |
# layer norms
|
51 |
-
((
|
52 |
-
((
|
53 |
-
|
54 |
-
((
|
55 |
-
((r"lm_head", "kernel"), P(None, "mp")),
|
56 |
]
|
57 |
|
58 |
|
@@ -63,6 +61,6 @@ def set_partitions(in_dict):
|
|
63 |
result = {k: replace(k, v) for k, v in initd.items()}
|
64 |
for k, v in result.items():
|
65 |
if v == _unmatched:
|
66 |
-
print(k)
|
67 |
assert _unmatched not in result.values(), "Incomplete partition spec."
|
68 |
return freeze(unflatten_dict(result))
|
|
|
36 |
def _get_partition_rules():
|
37 |
return [
|
38 |
# embeddings
|
39 |
+
(("embed_positions", "embedding"), P("mp", None)),
|
40 |
+
(("embed_tokens", "embedding"), P("mp", None)),
|
41 |
+
# attention
|
42 |
+
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
43 |
+
(("out_proj", "kernel"), P("mp", None)),
|
|
|
|
|
|
|
44 |
# FFN
|
45 |
+
(("Dense_0", "kernel"), P(None, "mp")),
|
46 |
+
(("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
|
47 |
+
(("GLU.*", "Dense_2", "kernel"), P("mp", None)),
|
48 |
+
(("FFN.*", "Dense_1", "kernel"), P("mp", None)),
|
49 |
# layer norms
|
50 |
+
(("(bias|scale)",), None),
|
51 |
+
(("lm_head", "kernel"), P(None, "mp")),
|
52 |
+
# head scale and tau
|
53 |
+
(("(head_scale|tau)",), None),
|
|
|
54 |
]
|
55 |
|
56 |
|
|
|
61 |
result = {k: replace(k, v) for k, v in initd.items()}
|
62 |
for k, v in result.items():
|
63 |
if v == _unmatched:
|
64 |
+
print(f"Unmatched -> {k}")
|
65 |
assert _unmatched not in result.values(), "Incomplete partition spec."
|
66 |
return freeze(unflatten_dict(result))
|
src/dalle_mini/model/utils.py
CHANGED
@@ -9,13 +9,11 @@ class PretrainedFromWandbMixin:
|
|
9 |
@classmethod
|
10 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
11 |
"""
|
12 |
-
Initializes from a wandb artifact
|
13 |
"""
|
14 |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
|
15 |
-
if (
|
16 |
-
|
17 |
-
and not os.path.isdir(pretrained_model_name_or_path)
|
18 |
-
and not pretrained_model_name_or_path.startswith("gs")
|
19 |
):
|
20 |
# wandb artifact
|
21 |
if wandb.run is not None:
|
@@ -27,17 +25,3 @@ class PretrainedFromWandbMixin:
|
|
27 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
28 |
pretrained_model_name_or_path, *model_args, **kwargs
|
29 |
)
|
30 |
-
|
31 |
-
|
32 |
-
def copy_blobs(source_path, dest_path):
|
33 |
-
assert source_path.startswith("gs://")
|
34 |
-
from google.cloud import storage
|
35 |
-
|
36 |
-
bucket_path = Path(source_path[5:])
|
37 |
-
bucket, dir_path = str(bucket_path).split("/", 1)
|
38 |
-
client = storage.Client()
|
39 |
-
bucket = client.bucket(bucket)
|
40 |
-
blobs = client.list_blobs(bucket, prefix=f"{dir_path}/")
|
41 |
-
for blob in blobs:
|
42 |
-
dest_name = str(Path(dest_path) / Path(blob.name).name)
|
43 |
-
blob.download_to_filename(dest_name)
|
|
|
9 |
@classmethod
|
10 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
11 |
"""
|
12 |
+
Initializes from a wandb artifact or delegates loading to the superclass.
|
13 |
"""
|
14 |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
|
15 |
+
if ":" in pretrained_model_name_or_path and not os.path.isdir(
|
16 |
+
pretrained_model_name_or_path
|
|
|
|
|
17 |
):
|
18 |
# wandb artifact
|
19 |
if wandb.run is not None:
|
|
|
25 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
26 |
pretrained_model_name_or_path, *model_args, **kwargs
|
27 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -47,7 +47,7 @@
|
|
47 |
"outputs": [],
|
48 |
"source": [
|
49 |
"# Install required libraries\n",
|
50 |
-
"!pip install -q transformers\n",
|
51 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
52 |
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
|
53 |
]
|
@@ -75,7 +75,7 @@
|
|
75 |
"# Model references\n",
|
76 |
"\n",
|
77 |
"# dalle-mini\n",
|
78 |
-
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-
|
79 |
"DALLE_COMMIT_ID = None\n",
|
80 |
"\n",
|
81 |
"# VQGAN model\n",
|
@@ -302,21 +302,7 @@
|
|
302 |
},
|
303 |
"outputs": [],
|
304 |
"source": [
|
305 |
-
"tokenized_prompt = processor([prompt])
|
306 |
-
"tokenized_prompt"
|
307 |
-
]
|
308 |
-
},
|
309 |
-
{
|
310 |
-
"cell_type": "markdown",
|
311 |
-
"metadata": {
|
312 |
-
"id": "_Y5dqFj7prMQ"
|
313 |
-
},
|
314 |
-
"source": [
|
315 |
-
"Notes:\n",
|
316 |
-
"\n",
|
317 |
-
"* `0`: BOS, special token representing the beginning of a sequence\n",
|
318 |
-
"* `2`: EOS, special token representing the end of a sequence\n",
|
319 |
-
"* `1`: special token representing the padding of a sequence when requesting a specific length"
|
320 |
]
|
321 |
},
|
322 |
{
|
@@ -459,13 +445,6 @@
|
|
459 |
" display(images[idx])\n",
|
460 |
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
461 |
]
|
462 |
-
},
|
463 |
-
{
|
464 |
-
"cell_type": "code",
|
465 |
-
"execution_count": null,
|
466 |
-
"metadata": {},
|
467 |
-
"outputs": [],
|
468 |
-
"source": []
|
469 |
}
|
470 |
],
|
471 |
"metadata": {
|
|
|
47 |
"outputs": [],
|
48 |
"source": [
|
49 |
"# Install required libraries\n",
|
50 |
+
"!pip install -q git+https://github.com/huggingface/transformers.git\n",
|
51 |
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
52 |
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
|
53 |
]
|
|
|
75 |
"# Model references\n",
|
76 |
"\n",
|
77 |
"# dalle-mini\n",
|
78 |
+
"DALLE_MODEL = \"dalle-mini/dalle-mini/model-3e2l7fxk:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
|
79 |
"DALLE_COMMIT_ID = None\n",
|
80 |
"\n",
|
81 |
"# VQGAN model\n",
|
|
|
302 |
},
|
303 |
"outputs": [],
|
304 |
"source": [
|
305 |
+
"tokenized_prompt = processor([prompt])"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
]
|
307 |
},
|
308 |
{
|
|
|
445 |
" display(images[idx])\n",
|
446 |
" print(f\"Score: {logits[idx]:.2f}\\n\")"
|
447 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
448 |
}
|
449 |
],
|
450 |
"metadata": {
|
tools/train/config/medium/config.json
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
6 |
-
"classifier_dropout": 0.0,
|
7 |
"d_model": 1408,
|
8 |
"decoder_attention_heads": 16,
|
9 |
"decoder_ffn_dim": 4096,
|
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
|
|
6 |
"d_model": 1408,
|
7 |
"decoder_attention_heads": 16,
|
8 |
"decoder_ffn_dim": 4096,
|
tools/train/config/mega/config.json
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
6 |
-
"classifier_dropout": 0.0,
|
7 |
"d_model": 2048,
|
8 |
"decoder_attention_heads": 32,
|
9 |
"decoder_ffn_dim": 8192,
|
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
|
|
6 |
"d_model": 2048,
|
7 |
"decoder_attention_heads": 32,
|
8 |
"decoder_ffn_dim": 8192,
|
tools/train/config/micro/config.json
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
6 |
-
"classifier_dropout": 0.0,
|
7 |
"d_model": 256,
|
8 |
"decoder_attention_heads": 2,
|
9 |
"decoder_ffn_dim": 256,
|
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
|
|
6 |
"d_model": 256,
|
7 |
"decoder_attention_heads": 2,
|
8 |
"decoder_ffn_dim": 256,
|
tools/train/config/mini/config.json
CHANGED
@@ -3,17 +3,14 @@
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
6 |
-
"classifier_dropout": 0.0,
|
7 |
"d_model": 1024,
|
8 |
"decoder_attention_heads": 16,
|
9 |
-
"decoder_ffn_dim":
|
10 |
-
"decoder_layerdrop": 0.0,
|
11 |
"decoder_layers": 12,
|
12 |
"decoder_start_token_id": 16384,
|
13 |
"dropout": 0.0,
|
14 |
"encoder_attention_heads": 16,
|
15 |
-
"encoder_ffn_dim":
|
16 |
-
"encoder_layerdrop": 0.0,
|
17 |
"encoder_layers": 12,
|
18 |
"encoder_vocab_size": 50264,
|
19 |
"eos_token_id": 16385,
|
|
|
3 |
"activation_function": "gelu",
|
4 |
"attention_dropout": 0.0,
|
5 |
"bos_token_id": 16385,
|
|
|
6 |
"d_model": 1024,
|
7 |
"decoder_attention_heads": 16,
|
8 |
+
"decoder_ffn_dim": 2560,
|
|
|
9 |
"decoder_layers": 12,
|
10 |
"decoder_start_token_id": 16384,
|
11 |
"dropout": 0.0,
|
12 |
"encoder_attention_heads": 16,
|
13 |
+
"encoder_ffn_dim": 2560,
|
|
|
14 |
"encoder_layers": 12,
|
15 |
"encoder_vocab_size": 50264,
|
16 |
"eos_token_id": 16385,
|
tools/train/train.py
CHANGED
@@ -18,7 +18,6 @@ Training DALL·E Mini.
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
21 |
-
import copy
|
22 |
import io
|
23 |
import logging
|
24 |
import os
|
@@ -30,8 +29,10 @@ from pathlib import Path
|
|
30 |
from typing import Any, Callable, NamedTuple, Optional
|
31 |
|
32 |
import datasets
|
|
|
33 |
import jax
|
34 |
import jax.numpy as jnp
|
|
|
35 |
import numpy as np
|
36 |
import optax
|
37 |
import transformers
|
@@ -405,10 +406,14 @@ class TrainingArguments:
|
|
405 |
default=False,
|
406 |
metadata={"help": "Log model to wandb at `save_steps` frequency."},
|
407 |
)
|
408 |
-
|
|
|
|
|
|
|
|
|
409 |
default=False,
|
410 |
metadata={
|
411 |
-
"help": "Log parameters and gradients histograms. Slows down training."
|
412 |
},
|
413 |
)
|
414 |
|
@@ -471,6 +476,8 @@ class TrainingArguments:
|
|
471 |
], f"Selected learning rate decay not supported: {self.lr_decay}"
|
472 |
if self.per_device_eval_batch_size is None:
|
473 |
self.per_device_eval_batch_size = self.per_device_train_batch_size
|
|
|
|
|
474 |
if (
|
475 |
os.path.exists(self.output_dir)
|
476 |
and os.listdir(self.output_dir)
|
@@ -497,48 +504,6 @@ class TrainState(train_state.TrainState):
|
|
497 |
train_samples: int = 0 # number of samples seen
|
498 |
|
499 |
|
500 |
-
class MetricsLogger:
|
501 |
-
def __init__(self, step):
|
502 |
-
self.step = step
|
503 |
-
self.time = time.perf_counter()
|
504 |
-
self.state_dict = {}
|
505 |
-
|
506 |
-
def update_state_metrics(self, state):
|
507 |
-
"""Update internal state metrics (logged at each call to be used as x-axis)"""
|
508 |
-
self.state_dict = {
|
509 |
-
f'train/{k.split("_")[-1]}': getattr(state, k)
|
510 |
-
for k in ["step", "epoch", "train_time", "train_samples"]
|
511 |
-
}
|
512 |
-
# timing metrics
|
513 |
-
new_step = int(state.step)
|
514 |
-
new_time = time.perf_counter()
|
515 |
-
if new_step > self.step:
|
516 |
-
time_per_step = (new_time - self.time) / (new_step - self.step)
|
517 |
-
self.step = new_step
|
518 |
-
self.time = new_time
|
519 |
-
self.state_dict["train/time_per_step"] = time_per_step
|
520 |
-
|
521 |
-
def log(self, metrics, prefix=None):
|
522 |
-
if jax.process_index() == 0:
|
523 |
-
log_metrics = {}
|
524 |
-
for k, v in metrics.items():
|
525 |
-
if prefix is not None:
|
526 |
-
k = f"{prefix}/{k}"
|
527 |
-
if "_norm" in k:
|
528 |
-
log_metrics[f"{k}/"] = unfreeze(v)
|
529 |
-
elif "_hist" in k:
|
530 |
-
v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
|
531 |
-
v = jax.tree_map(
|
532 |
-
lambda x: wandb.Histogram(np_histogram=x),
|
533 |
-
v,
|
534 |
-
is_leaf=lambda x: isinstance(x, tuple),
|
535 |
-
)
|
536 |
-
log_metrics[f"{k}/"] = v
|
537 |
-
else:
|
538 |
-
log_metrics[k] = v
|
539 |
-
wandb.log({**log_metrics, **self.state_dict})
|
540 |
-
|
541 |
-
|
542 |
def main():
|
543 |
# See all possible arguments by passing the --help flag to this script.
|
544 |
parser = HfArgumentParser(
|
@@ -593,9 +558,7 @@ def main():
|
|
593 |
# Set up our new model config
|
594 |
if model_args.config_name:
|
595 |
config = DalleBartConfig.from_pretrained(model_args.config_name)
|
596 |
-
|
597 |
-
# we correctly set it later per training_args
|
598 |
-
config.gradient_checkpointing = False
|
599 |
else:
|
600 |
config = None
|
601 |
|
@@ -607,9 +570,7 @@ def main():
|
|
607 |
seed=training_args.seed_model,
|
608 |
dtype=getattr(jnp, model_args.dtype),
|
609 |
abstract_init=True, # we overwrite them with loaded checkpoint
|
610 |
-
|
611 |
-
# we correctly set it later per training_args
|
612 |
-
gradient_checkpointing=False,
|
613 |
)
|
614 |
else:
|
615 |
model = DalleBart(
|
@@ -619,21 +580,6 @@ def main():
|
|
619 |
abstract_init=True,
|
620 |
)
|
621 |
|
622 |
-
# define model eval and train functions
|
623 |
-
eval_fn = model.__call__
|
624 |
-
if training_args.gradient_checkpointing:
|
625 |
-
remat_config = copy.deepcopy(model.config)
|
626 |
-
remat_config.gradient_checkpointing = True
|
627 |
-
remat_model = DalleBart(
|
628 |
-
remat_config,
|
629 |
-
seed=training_args.seed_model,
|
630 |
-
dtype=getattr(jnp, model_args.dtype),
|
631 |
-
init_weights=False,
|
632 |
-
)
|
633 |
-
train_fn = remat_model.__call__
|
634 |
-
else:
|
635 |
-
train_fn = model.__call__
|
636 |
-
|
637 |
# get model metadata
|
638 |
model_metadata = model_args.get_metadata()
|
639 |
|
@@ -708,8 +654,16 @@ def main():
|
|
708 |
"len_train_dataset": len_train_dataset,
|
709 |
"len_eval_dataset": len_eval_dataset,
|
710 |
"batch_size_per_step": batch_size_per_step,
|
711 |
-
"num_params": num_params,
|
712 |
"num_devices": jax.device_count(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
713 |
}
|
714 |
)
|
715 |
|
@@ -719,7 +673,7 @@ def main():
|
|
719 |
warmup_fn = optax.linear_schedule(
|
720 |
init_value=0.0,
|
721 |
end_value=training_args.learning_rate,
|
722 |
-
transition_steps=training_args.warmup_steps,
|
723 |
)
|
724 |
# offset step when resuming
|
725 |
if model_metadata.get("step", 0):
|
@@ -867,7 +821,7 @@ def main():
|
|
867 |
epoch=None,
|
868 |
train_time=None,
|
869 |
train_samples=None,
|
870 |
-
apply_fn=
|
871 |
tx=optimizer,
|
872 |
)
|
873 |
|
@@ -880,13 +834,13 @@ def main():
|
|
880 |
# params have not been initialized yet
|
881 |
return model.init_weights()
|
882 |
|
883 |
-
with
|
884 |
logger.info(" Creating state")
|
885 |
if not model_args.restore_state:
|
886 |
|
887 |
def init_state(params):
|
888 |
return TrainState.create(
|
889 |
-
apply_fn=
|
890 |
tx=optimizer,
|
891 |
params=maybe_init_params(params),
|
892 |
dropout_rng=dropout_rng,
|
@@ -913,7 +867,7 @@ def main():
|
|
913 |
|
914 |
def restore_state(params, opt_state):
|
915 |
return TrainState(
|
916 |
-
apply_fn=
|
917 |
tx=optimizer,
|
918 |
params=params,
|
919 |
opt_state=opt_state,
|
@@ -959,7 +913,7 @@ def main():
|
|
959 |
)
|
960 |
|
961 |
# Define gradient update step fn
|
962 |
-
def train_step(state, batch,
|
963 |
|
964 |
# get a minibatch (one gradient accumulation slice)
|
965 |
def get_minibatch(batch, grad_idx):
|
@@ -1048,36 +1002,45 @@ def main():
|
|
1048 |
state = state.apply_gradients(
|
1049 |
grads=grads,
|
1050 |
dropout_rng=dropout_rng,
|
1051 |
-
train_time=
|
1052 |
train_samples=state.train_samples + batch_size_per_step,
|
1053 |
)
|
1054 |
|
1055 |
-
|
1056 |
-
|
|
|
|
|
1057 |
|
1058 |
-
def maybe_fn(fn, val, zeros):
|
1059 |
"""Call fn only if it is a logging step"""
|
1060 |
return jax.lax.cond(
|
1061 |
-
state.step %
|
1062 |
fn,
|
1063 |
lambda _: zeros,
|
1064 |
val,
|
1065 |
)
|
1066 |
|
1067 |
-
|
1068 |
-
|
1069 |
|
1070 |
-
|
1071 |
-
|
1072 |
|
1073 |
-
|
1074 |
-
|
1075 |
-
|
1076 |
-
|
1077 |
-
|
1078 |
-
|
1079 |
|
1080 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1081 |
zeros_hist = jax.tree_map(
|
1082 |
lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
|
1083 |
)
|
@@ -1085,8 +1048,12 @@ def main():
|
|
1085 |
def histogram(val):
|
1086 |
return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
|
1087 |
|
1088 |
-
gradients_hist = maybe_fn(
|
1089 |
-
|
|
|
|
|
|
|
|
|
1090 |
|
1091 |
metrics.update(
|
1092 |
{
|
@@ -1101,7 +1068,7 @@ def main():
|
|
1101 |
def eval_step(state, batch):
|
1102 |
def compute_eval_loss(batch):
|
1103 |
batch, labels = batch.pop("labels")
|
1104 |
-
logits =
|
1105 |
return loss_fn(logits, labels)
|
1106 |
|
1107 |
if use_vmap_trick:
|
@@ -1134,13 +1101,73 @@ def main():
|
|
1134 |
out_axis_resources=None,
|
1135 |
)
|
1136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1137 |
# init variables
|
1138 |
-
|
1139 |
train_metrics = None
|
1140 |
-
|
1141 |
-
metrics_logger = MetricsLogger(step)
|
1142 |
epochs = tqdm(
|
1143 |
-
range(
|
1144 |
desc=f"Epoch ... (1/{num_epochs})",
|
1145 |
position=0,
|
1146 |
disable=jax.process_index() > 0,
|
@@ -1149,6 +1176,7 @@ def main():
|
|
1149 |
def run_evaluation():
|
1150 |
# ======================== Evaluating ==============================
|
1151 |
if training_args.do_eval:
|
|
|
1152 |
eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
|
1153 |
eval_steps = (
|
1154 |
len_eval_dataset // eval_batch_size_per_step
|
@@ -1195,6 +1223,7 @@ def main():
|
|
1195 |
|
1196 |
# log metrics
|
1197 |
metrics_logger.log(eval_metrics, prefix="eval")
|
|
|
1198 |
|
1199 |
# Print metrics and update progress bar
|
1200 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
@@ -1206,6 +1235,7 @@ def main():
|
|
1206 |
def run_save_model(state, eval_metrics=None):
|
1207 |
if jax.process_index() == 0:
|
1208 |
|
|
|
1209 |
output_dir = training_args.output_dir
|
1210 |
use_bucket = output_dir.startswith("gs://")
|
1211 |
if use_bucket:
|
@@ -1298,13 +1328,15 @@ def main():
|
|
1298 |
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
1299 |
)
|
1300 |
wandb.run.log_artifact(artifact_state)
|
|
|
1301 |
|
1302 |
logger.info(" Ready to start training")
|
1303 |
-
with
|
1304 |
for epoch in epochs:
|
1305 |
state.replace(epoch=epoch)
|
|
|
1306 |
# ======================== Training ================================
|
1307 |
-
metrics_logger.update_state_metrics(
|
1308 |
metrics_logger.log({})
|
1309 |
|
1310 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
@@ -1323,9 +1355,7 @@ def main():
|
|
1323 |
disable=jax.process_index() > 0,
|
1324 |
):
|
1325 |
# calculate delta time (we have a lag of one step but it's ok)
|
1326 |
-
|
1327 |
-
delta_time = new_time - last_time
|
1328 |
-
last_time = new_time
|
1329 |
|
1330 |
# set correct shape to batch
|
1331 |
# - add grad_step dim if gradient_accumulation_steps > 1
|
@@ -1353,18 +1383,23 @@ def main():
|
|
1353 |
batch = freeze(batch)
|
1354 |
|
1355 |
# train step
|
1356 |
-
state, train_metrics = p_train_step(state, batch,
|
1357 |
-
step += 1
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
|
|
|
|
|
|
|
|
|
|
1361 |
metrics_logger.log(train_metrics, prefix="train")
|
1362 |
|
1363 |
eval_metrics = None
|
1364 |
-
if step % training_args.eval_steps == 0:
|
1365 |
eval_metrics = run_evaluation()
|
1366 |
|
1367 |
-
if step % training_args.save_steps == 0:
|
1368 |
run_save_model(state, eval_metrics)
|
1369 |
|
1370 |
# log final train metrics
|
|
|
18 |
Script adapted from run_summarization_flax.py
|
19 |
"""
|
20 |
|
|
|
21 |
import io
|
22 |
import logging
|
23 |
import os
|
|
|
29 |
from typing import Any, Callable, NamedTuple, Optional
|
30 |
|
31 |
import datasets
|
32 |
+
import flax
|
33 |
import jax
|
34 |
import jax.numpy as jnp
|
35 |
+
import jaxlib
|
36 |
import numpy as np
|
37 |
import optax
|
38 |
import transformers
|
|
|
406 |
default=False,
|
407 |
metadata={"help": "Log model to wandb at `save_steps` frequency."},
|
408 |
)
|
409 |
+
log_norm_steps: int = field(
|
410 |
+
default=True,
|
411 |
+
metadata={"help": "Log parameters and gradients norm at this frequency."},
|
412 |
+
)
|
413 |
+
log_histogram_steps: int = field(
|
414 |
default=False,
|
415 |
metadata={
|
416 |
+
"help": "Log parameters and gradients histograms at this frequency. Slows down training."
|
417 |
},
|
418 |
)
|
419 |
|
|
|
476 |
], f"Selected learning rate decay not supported: {self.lr_decay}"
|
477 |
if self.per_device_eval_batch_size is None:
|
478 |
self.per_device_eval_batch_size = self.per_device_train_batch_size
|
479 |
+
if self.log_norm_steps is True:
|
480 |
+
self.log_norm_steps = self.logging_steps
|
481 |
if (
|
482 |
os.path.exists(self.output_dir)
|
483 |
and os.listdir(self.output_dir)
|
|
|
504 |
train_samples: int = 0 # number of samples seen
|
505 |
|
506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
def main():
|
508 |
# See all possible arguments by passing the --help flag to this script.
|
509 |
parser = HfArgumentParser(
|
|
|
558 |
# Set up our new model config
|
559 |
if model_args.config_name:
|
560 |
config = DalleBartConfig.from_pretrained(model_args.config_name)
|
561 |
+
config.gradient_checkpointing = training_args.gradient_checkpointing
|
|
|
|
|
562 |
else:
|
563 |
config = None
|
564 |
|
|
|
570 |
seed=training_args.seed_model,
|
571 |
dtype=getattr(jnp, model_args.dtype),
|
572 |
abstract_init=True, # we overwrite them with loaded checkpoint
|
573 |
+
gradient_checkpointing=training_args.gradient_checkpointing,
|
|
|
|
|
574 |
)
|
575 |
else:
|
576 |
model = DalleBart(
|
|
|
580 |
abstract_init=True,
|
581 |
)
|
582 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
583 |
# get model metadata
|
584 |
model_metadata = model_args.get_metadata()
|
585 |
|
|
|
654 |
"len_train_dataset": len_train_dataset,
|
655 |
"len_eval_dataset": len_eval_dataset,
|
656 |
"batch_size_per_step": batch_size_per_step,
|
657 |
+
"model": {"num_params": num_params, "config": model.config.to_dict()},
|
658 |
"num_devices": jax.device_count(),
|
659 |
+
"versions": {
|
660 |
+
"jax": jax.__version__,
|
661 |
+
"jaxlib": jaxlib.__version__,
|
662 |
+
"flax": flax.__version__,
|
663 |
+
"transformers": transformers.__version__,
|
664 |
+
"datasets": datasets.__version__,
|
665 |
+
"wandb": wandb.__version__,
|
666 |
+
},
|
667 |
}
|
668 |
)
|
669 |
|
|
|
673 |
warmup_fn = optax.linear_schedule(
|
674 |
init_value=0.0,
|
675 |
end_value=training_args.learning_rate,
|
676 |
+
transition_steps=training_args.warmup_steps + 1, # ensure not 0
|
677 |
)
|
678 |
# offset step when resuming
|
679 |
if model_metadata.get("step", 0):
|
|
|
821 |
epoch=None,
|
822 |
train_time=None,
|
823 |
train_samples=None,
|
824 |
+
apply_fn=model.__call__,
|
825 |
tx=optimizer,
|
826 |
)
|
827 |
|
|
|
834 |
# params have not been initialized yet
|
835 |
return model.init_weights()
|
836 |
|
837 |
+
with mesh:
|
838 |
logger.info(" Creating state")
|
839 |
if not model_args.restore_state:
|
840 |
|
841 |
def init_state(params):
|
842 |
return TrainState.create(
|
843 |
+
apply_fn=model.__call__,
|
844 |
tx=optimizer,
|
845 |
params=maybe_init_params(params),
|
846 |
dropout_rng=dropout_rng,
|
|
|
867 |
|
868 |
def restore_state(params, opt_state):
|
869 |
return TrainState(
|
870 |
+
apply_fn=model.__call__,
|
871 |
tx=optimizer,
|
872 |
params=params,
|
873 |
opt_state=opt_state,
|
|
|
913 |
)
|
914 |
|
915 |
# Define gradient update step fn
|
916 |
+
def train_step(state, batch, train_time):
|
917 |
|
918 |
# get a minibatch (one gradient accumulation slice)
|
919 |
def get_minibatch(batch, grad_idx):
|
|
|
1002 |
state = state.apply_gradients(
|
1003 |
grads=grads,
|
1004 |
dropout_rng=dropout_rng,
|
1005 |
+
train_time=train_time,
|
1006 |
train_samples=state.train_samples + batch_size_per_step,
|
1007 |
)
|
1008 |
|
1009 |
+
metrics = {
|
1010 |
+
"loss": loss,
|
1011 |
+
"learning_rate": learning_rate_fn(state.step),
|
1012 |
+
}
|
1013 |
|
1014 |
+
def maybe_fn(fn, val, zeros, freq):
|
1015 |
"""Call fn only if it is a logging step"""
|
1016 |
return jax.lax.cond(
|
1017 |
+
state.step % freq == 0,
|
1018 |
fn,
|
1019 |
lambda _: zeros,
|
1020 |
val,
|
1021 |
)
|
1022 |
|
1023 |
+
if training_args.log_norm_steps:
|
1024 |
+
zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
|
1025 |
|
1026 |
+
def norm(val):
|
1027 |
+
return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
|
1028 |
|
1029 |
+
gradients_norm = maybe_fn(
|
1030 |
+
norm, grads, zeros_norm, training_args.log_norm_steps
|
1031 |
+
)
|
1032 |
+
params_norm = maybe_fn(
|
1033 |
+
norm, state.params, zeros_norm, training_args.log_norm_steps
|
1034 |
+
)
|
1035 |
|
1036 |
+
metrics.update(
|
1037 |
+
{
|
1038 |
+
"gradients_norm": gradients_norm,
|
1039 |
+
"params_norm": params_norm,
|
1040 |
+
}
|
1041 |
+
)
|
1042 |
+
|
1043 |
+
if training_args.log_histogram_steps:
|
1044 |
zeros_hist = jax.tree_map(
|
1045 |
lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
|
1046 |
)
|
|
|
1048 |
def histogram(val):
|
1049 |
return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
|
1050 |
|
1051 |
+
gradients_hist = maybe_fn(
|
1052 |
+
histogram, grads, zeros_hist, training_args.log_histogram_steps
|
1053 |
+
)
|
1054 |
+
params_hist = maybe_fn(
|
1055 |
+
histogram, state.params, zeros_hist, training_args.log_histogram_steps
|
1056 |
+
)
|
1057 |
|
1058 |
metrics.update(
|
1059 |
{
|
|
|
1068 |
def eval_step(state, batch):
|
1069 |
def compute_eval_loss(batch):
|
1070 |
batch, labels = batch.pop("labels")
|
1071 |
+
logits = model(**batch, params=state.params, train=False)[0]
|
1072 |
return loss_fn(logits, labels)
|
1073 |
|
1074 |
if use_vmap_trick:
|
|
|
1101 |
out_axis_resources=None,
|
1102 |
)
|
1103 |
|
1104 |
+
# define metrics logger
|
1105 |
+
class MetricsLogger:
|
1106 |
+
def __init__(self, step):
|
1107 |
+
# keep state
|
1108 |
+
self.state_dict = {}
|
1109 |
+
# estimate speed
|
1110 |
+
self.step = step
|
1111 |
+
self.time = time.perf_counter()
|
1112 |
+
self.offset_time = 0.0
|
1113 |
+
|
1114 |
+
def update_state_metrics(self, state):
|
1115 |
+
"""Update internal state metrics (logged at each call to be used as x-axis)"""
|
1116 |
+
self.state_dict = {
|
1117 |
+
f'train/{k.split("_")[-1]}': state[k]
|
1118 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
1119 |
+
}
|
1120 |
+
# timing metrics
|
1121 |
+
new_step = int(state["step"])
|
1122 |
+
new_time = time.perf_counter()
|
1123 |
+
if new_step > self.step:
|
1124 |
+
# remove time for eval & save
|
1125 |
+
delta_time = new_time - self.time - self.offset_time
|
1126 |
+
self.offset_time = 0
|
1127 |
+
time_per_step = delta_time / (new_step - self.step)
|
1128 |
+
self.step = new_step
|
1129 |
+
self.time = new_time
|
1130 |
+
self.log_time("train_per_step", time_per_step, offset=False)
|
1131 |
+
self.log_time("train_per_log", delta_time, offset=False)
|
1132 |
+
|
1133 |
+
def log_time(self, key, duration, offset=True):
|
1134 |
+
wandb.log({f"time/{key}": duration, **self.state_dict})
|
1135 |
+
if offset:
|
1136 |
+
self.offset_time += duration
|
1137 |
+
|
1138 |
+
def log(self, metrics, prefix=None):
|
1139 |
+
if jax.process_index() == 0:
|
1140 |
+
log_metrics = {}
|
1141 |
+
for k, v in metrics.items():
|
1142 |
+
if "_norm" in k:
|
1143 |
+
if self.step % training_args.log_norm_steps == 0:
|
1144 |
+
log_metrics[f"{k}/"] = unfreeze(v)
|
1145 |
+
elif "_hist" in k:
|
1146 |
+
if self.step % training_args.log_histogram_steps == 0:
|
1147 |
+
v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
|
1148 |
+
v = jax.tree_map(
|
1149 |
+
lambda x: wandb.Histogram(np_histogram=x),
|
1150 |
+
v,
|
1151 |
+
is_leaf=lambda x: isinstance(x, tuple),
|
1152 |
+
)
|
1153 |
+
log_metrics[f"{k}/"] = v
|
1154 |
+
else:
|
1155 |
+
if prefix is not None:
|
1156 |
+
k = f"{prefix}/{k}"
|
1157 |
+
log_metrics[k] = v
|
1158 |
+
wandb.log({**log_metrics, **self.state_dict})
|
1159 |
+
|
1160 |
+
# keep local copy of state
|
1161 |
+
local_state = {
|
1162 |
+
k: jax.device_get(getattr(state, k)).item()
|
1163 |
+
for k in ["step", "epoch", "train_time", "train_samples"]
|
1164 |
+
}
|
1165 |
# init variables
|
1166 |
+
start_time = time.perf_counter() - local_state["train_time"]
|
1167 |
train_metrics = None
|
1168 |
+
metrics_logger = MetricsLogger(local_state["step"])
|
|
|
1169 |
epochs = tqdm(
|
1170 |
+
range(local_state["epoch"], num_epochs),
|
1171 |
desc=f"Epoch ... (1/{num_epochs})",
|
1172 |
position=0,
|
1173 |
disable=jax.process_index() > 0,
|
|
|
1176 |
def run_evaluation():
|
1177 |
# ======================== Evaluating ==============================
|
1178 |
if training_args.do_eval:
|
1179 |
+
start_eval_time = time.perf_counter()
|
1180 |
eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
|
1181 |
eval_steps = (
|
1182 |
len_eval_dataset // eval_batch_size_per_step
|
|
|
1223 |
|
1224 |
# log metrics
|
1225 |
metrics_logger.log(eval_metrics, prefix="eval")
|
1226 |
+
metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
|
1227 |
|
1228 |
# Print metrics and update progress bar
|
1229 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
|
|
1235 |
def run_save_model(state, eval_metrics=None):
|
1236 |
if jax.process_index() == 0:
|
1237 |
|
1238 |
+
start_save_time = time.perf_counter()
|
1239 |
output_dir = training_args.output_dir
|
1240 |
use_bucket = output_dir.startswith("gs://")
|
1241 |
if use_bucket:
|
|
|
1328 |
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
1329 |
)
|
1330 |
wandb.run.log_artifact(artifact_state)
|
1331 |
+
metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
|
1332 |
|
1333 |
logger.info(" Ready to start training")
|
1334 |
+
with mesh:
|
1335 |
for epoch in epochs:
|
1336 |
state.replace(epoch=epoch)
|
1337 |
+
local_state["epoch"] = epoch
|
1338 |
# ======================== Training ================================
|
1339 |
+
metrics_logger.update_state_metrics(local_state)
|
1340 |
metrics_logger.log({})
|
1341 |
|
1342 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
|
|
1355 |
disable=jax.process_index() > 0,
|
1356 |
):
|
1357 |
# calculate delta time (we have a lag of one step but it's ok)
|
1358 |
+
train_time = time.perf_counter() - start_time
|
|
|
|
|
1359 |
|
1360 |
# set correct shape to batch
|
1361 |
# - add grad_step dim if gradient_accumulation_steps > 1
|
|
|
1383 |
batch = freeze(batch)
|
1384 |
|
1385 |
# train step
|
1386 |
+
state, train_metrics = p_train_step(state, batch, train_time)
|
1387 |
+
local_state["step"] += 1
|
1388 |
+
local_state["train_time"] = train_time
|
1389 |
+
local_state["train_samples"] += batch_size_per_step
|
1390 |
+
|
1391 |
+
if (
|
1392 |
+
local_state["step"] % training_args.logging_steps == 0
|
1393 |
+
and jax.process_index() == 0
|
1394 |
+
):
|
1395 |
+
metrics_logger.update_state_metrics(local_state)
|
1396 |
metrics_logger.log(train_metrics, prefix="train")
|
1397 |
|
1398 |
eval_metrics = None
|
1399 |
+
if local_state["step"] % training_args.eval_steps == 0:
|
1400 |
eval_metrics = run_evaluation()
|
1401 |
|
1402 |
+
if local_state["step"] % training_args.save_steps == 0:
|
1403 |
run_save_model(state, eval_metrics)
|
1404 |
|
1405 |
# log final train metrics
|