boris committed on
Commit
542378c
1 Parent(s): b7b619a

feat: implement transformer variants (#144)


* added DeepNet
* added Swin v2
* added NormFormer
* added RMSNorm
* added GLU variants

.gitignore CHANGED
@@ -3,3 +3,4 @@ __pycache__
3
  .streamlit
4
  wandb/
5
  *.egg-info/
 
 
3
  .streamlit
4
  wandb/
5
  *.egg-info/
6
+ jax_cache/
README.md CHANGED
@@ -94,26 +94,43 @@ Many thanks to the people who helped make it better:
94
 
95
  - the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
96
  - [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
 
97
  - [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
98
 
99
  ## Citing DALL·E mini
100
 
101
  If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
102
 
103
- ```
104
  @misc{Dayma_DALL·E_Mini_2021,
105
- author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
106
- doi = {10.5281/zenodo.5146400},
107
- month = {7},
108
- title = {DALL·E Mini},
109
- url = {https://github.com/borisdayma/dalle-mini},
110
- year = {2021}
111
  }
112
  ```
113
 
114
  ## References
115
 
116
- ```
117
  @misc{ramesh2021zeroshot,
118
  title={Zero-Shot Text-to-Image Generation},
119
  author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
@@ -124,7 +141,18 @@ year = {2021}
124
  }
125
  ```
126
127
  ```
128
  @misc{esser2021taming,
129
  title={Taming Transformers for High-Resolution Image Synthesis},
130
  author={Patrick Esser and Robin Rombach and Björn Ommer},
@@ -135,7 +163,7 @@ year = {2021}
135
  }
136
  ```
137
 
138
- ```
139
  @misc{lewis2019bart,
140
  title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
141
  author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
@@ -146,24 +174,64 @@ year = {2021}
146
  }
147
  ```
148
 
149
- ```
150
- @misc{radford2021learning,
151
- title={Learning Transferable Visual Models From Natural Language Supervision},
152
- author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
153
  year={2021},
154
- eprint={2103.00020},
155
  archivePrefix={arXiv},
156
- primaryClass={cs.CV}
157
  }
158
  ```
159
160
  ```
161
- @misc{anil2021scalable,
162
- title={Scalable Second Order Optimization for Deep Learning},
163
- author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
164
  year={2021},
165
- eprint={2002.09018},
166
  archivePrefix={arXiv},
167
- primaryClass={cs.LG}
168
  }
169
  ```
 
94
 
95
  - the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
96
  - [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
97
+ - [Phil Wang](https://github.com/lucidrains) for his many cool implementations of transformer variants and the interesting insights shared through [x-transformers](https://github.com/lucidrains/x-transformers)
98
  - [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
99
 
100
  ## Citing DALL·E mini
101
 
102
  If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
103
 
104
+ ```text
105
  @misc{Dayma_DALL·E_Mini_2021,
106
+ author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
107
+ doi = {10.5281/zenodo.5146400},
108
+ month = {7},
109
+ title = {DALL·E Mini},
110
+ url = {https://github.com/borisdayma/dalle-mini},
111
+ year = {2021}
112
  }
113
  ```
114
 
115
  ## References
116
 
117
+ Original DALL·E from "[Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092)", with generated images ranked by CLIP from "[Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)".
118
+
119
+ Image encoder from "[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841v2)".
120
+
121
+ Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461v1)" with implementation of a few variants:
122
+
123
+ - "[GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)"
124
+ - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
125
+ - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
126
+ - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
127
+ - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
128
+
129
+ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
130
+
131
+ ### Citations
132
+
133
+ ```text
134
  @misc{ramesh2021zeroshot,
135
  title={Zero-Shot Text-to-Image Generation},
136
  author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
 
141
  }
142
  ```
143
 
144
+ ```text
145
+ @misc{radford2021learning,
146
+ title={Learning Transferable Visual Models From Natural Language Supervision},
147
+ author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
148
+ year={2021},
149
+ eprint={2103.00020},
150
+ archivePrefix={arXiv},
151
+ primaryClass={cs.CV}
152
+ }
153
  ```
154
+
155
+ ```text
156
  @misc{esser2021taming,
157
  title={Taming Transformers for High-Resolution Image Synthesis},
158
  author={Patrick Esser and Robin Rombach and Björn Ommer},
 
163
  }
164
  ```
165
 
166
+ ```text
167
  @misc{lewis2019bart,
168
  title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
169
  author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
 
174
  }
175
  ```
176
 
177
+ ```text
178
+ @misc{anil2021scalable,
179
+ title={Scalable Second Order Optimization for Deep Learning},
180
+ author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
181
  year={2021},
182
+ eprint={2002.09018},
183
  archivePrefix={arXiv},
184
+ primaryClass={cs.LG}
185
  }
186
  ```
187
 
188
+ ```text
189
+ @misc{shazeer2020glu,
190
+ title={GLU Variants Improve Transformer},
191
+ author={Noam Shazeer},
192
+ year={2020},
193
+ url={https://arxiv.org/abs/2002.05202}
194
+ }
195
  ```
196
+
197
+ ```text
198
+ @misc{wang_ma_dong_huang_zhang_wei_2022,
199
+ title={DeepNet: Scaling transformers to 1,000 layers},
200
+ author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
201
+ year={2022},
202
+ eprint={2203.00555},
203
+ archivePrefix={arXiv},
204
+ primaryClass={cs.LG}
205
+ }
206
+ ```
207
+
208
+ ```text
209
+ @misc{shleifer2021normformer,
210
+ title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
211
+ author={Sam Shleifer and Jason Weston and Myle Ott},
212
  year={2021},
213
+ eprint={2110.09456},
214
  archivePrefix={arXiv},
215
+ primaryClass={cs.CL}
216
+ }
217
+ ```
218
+
219
+ ```text
220
+ @inproceedings{liu2021swinv2,
221
+ title={Swin Transformer V2: Scaling Up Capacity and Resolution},
222
+ author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
223
+ booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
224
+ year={2022}
225
+ }
226
+ ```
227
+
228
+ ```text
229
+ @misc{zhang2019root,
230
+ title = {Root Mean Square Layer Normalization},
231
+ author = {Biao Zhang and Rico Sennrich},
232
+ year = {2019},
233
+ eprint = {1910.07467},
234
+ archivePrefix = {arXiv},
235
+ primaryClass = {cs.LG}
236
  }
237
  ```
src/dalle_mini/data.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from dataclasses import dataclass, field
2
  from functools import partial
3
 
@@ -39,6 +40,9 @@ class Dataset:
39
  multi_hosts: bool = field(init=False)
40
 
41
  def __post_init__(self):
42
  self.multi_hosts = jax.process_count() > 1
43
  # feed blank captions only in streaming mode for now
44
  # otherwise dataset could be cached with same blanked captions
@@ -106,11 +110,10 @@ class Dataset:
106
  if self.streaming:
107
  # we need to shuffle early in streaming mode
108
  if hasattr(self, "train_dataset"):
109
- self.train_dataset = self.train_dataset.shuffle(5000, self.seed_dataset)
110
  else:
111
- # prepare rng for later shuffling
112
- if self.seed_dataset is None:
113
- self.seed_dataset = np.random.get_state()[1][0]
114
  self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
115
 
116
  # filter data
 
1
+ import random
2
  from dataclasses import dataclass, field
3
  from functools import partial
4
 
 
40
  multi_hosts: bool = field(init=False)
41
 
42
  def __post_init__(self):
43
+ if self.seed_dataset is None:
44
+ # create a random seed
45
+ self.seed_dataset = random.randint(0, 2**32 - 1)
46
  self.multi_hosts = jax.process_count() > 1
47
  # feed blank captions only in streaming mode for now
48
  # otherwise dataset could be cached with same blanked captions
 
110
  if self.streaming:
111
  # we need to shuffle early in streaming mode
112
  if hasattr(self, "train_dataset"):
113
+ self.train_dataset = self.train_dataset.shuffle(
114
+ buffer_size=5000, seed=self.seed_dataset
115
+ )
116
  else:
117
  self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
118
 
119
  # filter data
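The seeding logic added above can be summarized with a small standalone sketch (the helper name `resolve_dataset_seed` is hypothetical; in the diff the same logic lives in `Dataset.__post_init__`, the streaming `shuffle` call and `jax.random.PRNGKey`):

```python
import random

import jax


def resolve_dataset_seed(seed_dataset=None):
    # Draw a random 32-bit seed once when none is provided, so the streaming
    # shuffle and the JAX-side shuffling derive from the same value.
    if seed_dataset is None:
        seed_dataset = random.randint(0, 2**32 - 1)
    return seed_dataset, jax.random.PRNGKey(seed_dataset)


seed, rng = resolve_dataset_seed()
# In streaming mode the same seed is what gets passed to
# train_dataset.shuffle(buffer_size=5000, seed=seed).
```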
src/dalle_mini/model/configuration.py CHANGED
@@ -44,15 +44,12 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
44
  decoder_layers=12,
45
  decoder_ffn_dim=4096,
46
  decoder_attention_heads=16,
47
- encoder_layerdrop=0.0,
48
- decoder_layerdrop=0.0,
49
  activation_function="gelu",
50
  d_model=1024,
51
  dropout=0.1,
52
  attention_dropout=0.0,
53
  activation_dropout=0.0,
54
  init_std=0.02,
55
- classifier_dropout=0.0,
56
  scale_embedding=False,
57
  gradient_checkpointing=False,
58
  use_cache=True,
@@ -60,9 +57,38 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
60
  forced_eos_token_id=None,
61
  tie_word_embeddings=False, # different modalities and sizes
62
  do_sample=True,
63
  **kwargs,
64
  ):
 
65
  self.normalize_text = normalize_text
66
  self.encoder_vocab_size = encoder_vocab_size
67
  self.image_vocab_size = image_vocab_size
68
  self.image_length = image_length
@@ -79,9 +105,6 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
79
  self.activation_dropout = activation_dropout
80
  self.activation_function = activation_function
81
  self.init_std = init_std
82
- self.encoder_layerdrop = encoder_layerdrop
83
- self.decoder_layerdrop = decoder_layerdrop
84
- self.classifier_dropout = classifier_dropout
85
  self.use_cache = use_cache
86
  self.gradient_checkpointing = gradient_checkpointing
87
  self.scale_embedding = (
 
44
  decoder_layers=12,
45
  decoder_ffn_dim=4096,
46
  decoder_attention_heads=16,
 
 
47
  activation_function="gelu",
48
  d_model=1024,
49
  dropout=0.1,
50
  attention_dropout=0.0,
51
  activation_dropout=0.0,
52
  init_std=0.02,
 
53
  scale_embedding=False,
54
  gradient_checkpointing=False,
55
  use_cache=True,
 
57
  forced_eos_token_id=None,
58
  tie_word_embeddings=False, # different modalities and sizes
59
  do_sample=True,
60
+ # transformer variants
61
+ head_scale=False, # used in NormFormer
62
+ ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
63
+ ln_positions="deepnet", # layer normalization positions, "normformer", "swinv2", "deepnet" (same as post-ln)
64
+ use_cosine_attention=False, # used in Swin v2
65
+ tau_init=0.05, # used only in cosine attention (Swin v2)
66
+ use_deepnet_scaling=False, # used in Deepnet
67
+ use_glu=False, # "GLU Variants Improve Transformer"
68
  **kwargs,
69
  ):
70
+ # text normalizer
71
  self.normalize_text = normalize_text
72
+
73
+ # transformer variants
74
+ self.head_scale = head_scale # per Normformer
75
+ assert ln_type in [
76
+ "rmsnorm",
77
+ "layernorm",
78
+ ], "ln_type must be 'rmsnorm' or 'layernorm'"
79
+ self.ln_type = ln_type
80
+ assert ln_positions in [
81
+ "normformer",
82
+ "swinv2",
83
+ "deepnet",
84
+ ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
85
+ self.ln_positions = ln_positions
86
+ self.use_cosine_attention = use_cosine_attention
87
+ self.tau_init = tau_init
88
+ self.use_deepnet_scaling = use_deepnet_scaling
89
+ self.use_glu = use_glu
90
+
91
+ # common parameters
92
  self.encoder_vocab_size = encoder_vocab_size
93
  self.image_vocab_size = image_vocab_size
94
  self.image_length = image_length
 
105
  self.activation_dropout = activation_dropout
106
  self.activation_function = activation_function
107
  self.init_std = init_std
108
  self.use_cache = use_cache
109
  self.gradient_checkpointing = gradient_checkpointing
110
  self.scale_embedding = (
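A minimal sketch of how the new configuration options might be combined (illustrative values, not repository defaults beyond what the diff shows; the import path follows the file above):

```python
from dalle_mini.model.configuration import DalleBartConfig

# Hypothetical setup: NormFormer-style extra normalization with RMSNorm,
# gated feed-forward (GLU) and Swin v2 cosine attention; DeepNet scaling off.
config = DalleBartConfig(
    ln_type="rmsnorm",          # "layernorm" or "rmsnorm"
    ln_positions="normformer",  # "normformer", "swinv2" or "deepnet" (post-ln)
    head_scale=True,            # learned per-head scaling (NormFormer)
    use_cosine_attention=True,  # cosine attention with learned tau (Swin v2)
    tau_init=0.05,
    use_deepnet_scaling=False,  # DeepNet residual / init scaling
    use_glu=True,               # "GLU Variants Improve Transformer"
)
```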
src/dalle_mini/model/modeling.py CHANGED
@@ -18,7 +18,7 @@ import math
18
  import os
19
  from functools import partial
20
  from pickle import UnpicklingError
21
- from typing import Dict, Optional, Tuple, Union
22
 
23
  import flax
24
  import flax.linen as nn
@@ -26,7 +26,9 @@ import jax
26
  import jax.numpy as jnp
27
  import msgpack.exceptions
28
  from flax.core.frozen_dict import unfreeze
29
- from flax.linen import make_causal_mask
 
 
30
  from flax.serialization import from_bytes
31
  from flax.traverse_util import flatten_dict, unflatten_dict
32
  from jax import lax
@@ -42,6 +44,8 @@ from transformers.file_utils import (
42
  )
43
  from transformers.generation_flax_utils import FlaxSampleOutput
44
  from transformers.modeling_flax_outputs import (
 
 
45
  FlaxCausalLMOutputWithCrossAttentions,
46
  FlaxSeq2SeqLMOutput,
47
  )
@@ -49,11 +53,7 @@ from transformers.modeling_flax_utils import ACT2FN
49
  from transformers.models.bart.modeling_flax_bart import (
50
  FlaxBartAttention,
51
  FlaxBartDecoder,
52
- FlaxBartDecoderLayer,
53
- FlaxBartDecoderLayerCollection,
54
  FlaxBartEncoder,
55
- FlaxBartEncoderLayer,
56
- FlaxBartEncoderLayerCollection,
57
  FlaxBartForConditionalGeneration,
58
  FlaxBartForConditionalGenerationModule,
59
  FlaxBartModule,
@@ -66,13 +66,124 @@ from .utils import PretrainedFromWandbMixin
66
 
67
  logger = logging.get_logger(__name__)
68
 
 
 
 
 
 
69
 
70
  class FlaxBartAttention(FlaxBartAttention):
71
  """
72
  Edits:
73
  - causal mask is used only in decoder and considers image_length
 
74
  """
75
 
 
 
76
  def setup(self) -> None:
77
  self.head_dim = self.embed_dim // self.num_heads
78
  if self.head_dim * self.num_heads != self.embed_dim:
@@ -86,142 +197,667 @@ class FlaxBartAttention(FlaxBartAttention):
86
  self.embed_dim,
87
  use_bias=self.bias,
88
  dtype=self.dtype,
89
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
90
  )
91
 
92
- self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
93
- self.out_proj = dense()
 
94
 
 
 
 
 
 
 
 
95
  self.dropout_layer = nn.Dropout(rate=self.dropout)
96
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  if self.causal:
98
  # used only in decoder
99
  self.causal_mask = make_causal_mask(
100
  jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
101
  )
102
 
 
 
 
 
 
 
 
103
 
104
- class FlaxBartEncoderLayer(FlaxBartEncoderLayer):
105
- """
106
- Edits:
107
- - no bias
108
- - use custom FlaxBartAttention
109
- """
110
 
111
- def setup(self) -> None:
112
- self.embed_dim = self.config.d_model
113
- self.self_attn = FlaxBartAttention(
114
- config=self.config,
115
- embed_dim=self.embed_dim,
116
- num_heads=self.config.encoder_attention_heads,
117
- dropout=self.config.attention_dropout,
118
- bias=False,
 
 
 
 
 
 
 
 
119
  dtype=self.dtype,
 
120
  )
121
- self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
122
- self.dropout_layer = nn.Dropout(rate=self.config.dropout)
123
- self.activation_fn = ACT2FN[self.config.activation_function]
124
- self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
125
- self.fc1 = nn.Dense(
126
- self.config.encoder_ffn_dim,
 
 
 
 
 
 
 
 
 
127
  dtype=self.dtype,
128
  use_bias=False,
129
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
 
 
 
 
 
 
 
 
 
130
  )
131
- self.fc2 = nn.Dense(
 
132
  self.embed_dim,
133
  dtype=self.dtype,
134
  use_bias=False,
135
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
 
 
 
 
 
 
136
  )
137
- self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
 
 
 
 
 
 
 
138
 
139
 
140
- class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection):
141
  """
142
  Edits:
143
- - use custom FlaxBartEncoderLayer
144
- - allow Gradient Checkpointing (nn.remat)
145
  """
146
 
147
- def setup(self):
148
- layer_module = (
149
- nn.remat(FlaxBartEncoderLayer, concrete=True)
150
- if self.config.gradient_checkpointing
151
- else FlaxBartEncoderLayer
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  )
153
- self.layers = [
154
- layer_module(self.config, name=str(i), dtype=self.dtype)
155
- for i in range(self.config.encoder_layers)
156
- ]
157
- self.layerdrop = self.config.encoder_layerdrop
 
 
 
 
 
 
 
 
158
 
159
 
160
- class FlaxBartDecoderLayer(FlaxBartDecoderLayer):
161
  """
162
  Edits:
163
  - no bias
164
- - uses custom FlaxBartAttention
165
  """
166
 
167
- def setup(self) -> None:
168
- self.embed_dim = self.config.d_model
169
- self.self_attn = FlaxBartAttention(
 
 
 
 
 
 
 
 
170
  config=self.config,
171
- embed_dim=self.embed_dim,
172
  num_heads=self.config.decoder_attention_heads,
173
  dropout=self.config.attention_dropout,
174
  causal=True,
175
  bias=False,
176
  dtype=self.dtype,
 
 
 
 
 
177
  )
178
- self.dropout_layer = nn.Dropout(rate=self.config.dropout)
179
- self.activation_fn = ACT2FN[self.config.activation_function]
180
- self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
181
 
182
- self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
183
- self.encoder_attn = FlaxBartAttention(
184
- config=self.config,
185
- embed_dim=self.embed_dim,
186
- num_heads=self.config.decoder_attention_heads,
187
- dropout=self.config.attention_dropout,
188
- bias=False,
189
- dtype=self.dtype,
190
  )
191
- self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
192
- self.fc1 = nn.Dense(
193
- self.config.encoder_ffn_dim,
194
- dtype=self.dtype,
195
- use_bias=False,
196
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
 
 
 
 
 
 
 
197
  )
198
- self.fc2 = nn.Dense(
199
- self.embed_dim,
200
- dtype=self.dtype,
201
- use_bias=False,
202
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
 
 
 
 
 
 
 
203
  )
204
- self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
205
 
206
 
207
- class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection):
 
 
208
  """
209
  Edits:
210
  - use custom FlaxBartDecoderLayer
211
  - allow Gradient Checkpointing (nn.remat)
212
  """
213
 
214
- def setup(self):
215
- layer_module = (
216
- nn.remat(FlaxBartDecoderLayer, concrete=True)
 
 
 
 
 
 
 
 
217
  if self.config.gradient_checkpointing
218
  else FlaxBartDecoderLayer
219
  )
220
- self.layers = [
221
- layer_module(self.config, name=str(i), dtype=self.dtype)
222
- for i in range(self.config.decoder_layers)
 
 
 
 
 
 
 
 
223
  ]
224
- self.layerdrop = self.config.decoder_layerdrop
 
 
 
 
 
 
 
 
 
225
 
226
 
227
  class FlaxBartEncoder(FlaxBartEncoder):
@@ -246,10 +882,14 @@ class FlaxBartEncoder(FlaxBartEncoder):
246
  self.embed_positions = nn.Embed(
247
  self.config.max_text_length + self.offset,
248
  embed_dim,
249
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
250
  )
251
  self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
252
- self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
 
 
253
 
254
 
255
  class FlaxBartDecoder(FlaxBartDecoder):
@@ -276,11 +916,15 @@ class FlaxBartDecoder(FlaxBartDecoder):
276
  self.embed_positions = nn.Embed(
277
  self.config.image_length + self.offset, # image length for BOS
278
  embed_dim,
279
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
280
  )
281
 
282
  self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
283
- self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
 
 
284
 
285
 
286
  class FlaxBartModule(FlaxBartModule):
@@ -294,12 +938,16 @@ class FlaxBartModule(FlaxBartModule):
294
  encoder_embed_tokens = nn.Embed(
295
  self.config.encoder_vocab_size,
296
  self.config.d_model,
297
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
298
  )
299
  decoder_embed_tokens = nn.Embed(
300
  self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
301
  self.config.d_model,
302
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
 
 
303
  )
304
 
305
  self.encoder = FlaxBartEncoder(
@@ -639,7 +1287,9 @@ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationMod
639
  + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
640
  use_bias=False,
641
  dtype=self.dtype,
642
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
 
 
643
  )
644
 
645
  def __call__(
 
18
  import os
19
  from functools import partial
20
  from pickle import UnpicklingError
21
+ from typing import Any, Dict, Optional, Tuple, Union
22
 
23
  import flax
24
  import flax.linen as nn
 
26
  import jax.numpy as jnp
27
  import msgpack.exceptions
28
  from flax.core.frozen_dict import unfreeze
29
+ from flax.linen import combine_masks, make_causal_mask
30
+ from flax.linen import partitioning as nn_partitioning
31
+ from flax.linen.attention import dot_product_attention_weights
32
  from flax.serialization import from_bytes
33
  from flax.traverse_util import flatten_dict, unflatten_dict
34
  from jax import lax
 
44
  )
45
  from transformers.generation_flax_utils import FlaxSampleOutput
46
  from transformers.modeling_flax_outputs import (
47
+ FlaxBaseModelOutput,
48
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
49
  FlaxCausalLMOutputWithCrossAttentions,
50
  FlaxSeq2SeqLMOutput,
51
  )
 
53
  from transformers.models.bart.modeling_flax_bart import (
54
  FlaxBartAttention,
55
  FlaxBartDecoder,
 
 
56
  FlaxBartEncoder,
 
 
57
  FlaxBartForConditionalGeneration,
58
  FlaxBartForConditionalGenerationModule,
59
  FlaxBartModule,
 
66
 
67
  logger = logging.get_logger(__name__)
68
 
69
+ remat = nn_partitioning.remat
70
+
71
+
72
+ # deepnet initialization
73
+ def deepnet_init(gain=1):
74
+ init = jax.nn.initializers.glorot_normal()
75
+
76
+ def _init(*args, **kwargs):
77
+ return gain * init(*args, **kwargs)
78
+
79
+ return _init
80
+
81
+
82
+ # deepnet gain
83
+ deepnet_gain = {
84
+ "encoder": {
85
+ "alpha": lambda config: 0.81
86
+ * (config.encoder_layers**4 * config.decoder_layers) ** 0.0625,
87
+ "beta": lambda config: 0.87
88
+ * (config.encoder_layers**4 * config.decoder_layers) ** -0.0625,
89
+ },
90
+ "decoder": {
91
+ "alpha": lambda config: (3 * config.decoder_layers) ** 0.25,
92
+ "beta": lambda config: (12 * config.decoder_layers) ** -0.25,
93
+ },
94
+ }
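For reference, these lambdas implement the encoder-decoder gains prescribed by the DeepNet paper, with N encoder layers and M decoder layers; alpha scales the residual branch and beta scales the initialization of the value/output and feed-forward projections (a summary of the code above, not additional functionality):

```latex
\alpha_{\text{enc}} = 0.81\,(N^{4} M)^{1/16}, \qquad \beta_{\text{enc}} = 0.87\,(N^{4} M)^{-1/16}
\alpha_{\text{dec}} = (3M)^{1/4}, \qquad \beta_{\text{dec}} = (12M)^{-1/4}
```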
95
+
96
+
97
+ class RMSNorm(nn.Module):
98
+ """
99
+ From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
100
+
101
+ Adapted from flax.linen.LayerNorm
102
+ """
103
+
104
+ epsilon: float = 1e-6
105
+ dtype: Any = jnp.float32
106
+ param_dtype: Any = jnp.float32
107
+ use_scale: bool = True
108
+ scale_init: Any = jax.nn.initializers.ones
109
+
110
+ @nn.compact
111
+ def __call__(self, x):
112
+ reduction_axes = (-1,)
113
+ feature_axes = (-1,)
114
+
115
+ rms_sq = self._compute_rms_sq(x, reduction_axes)
116
+
117
+ return self._normalize(
118
+ self,
119
+ x,
120
+ rms_sq,
121
+ reduction_axes,
122
+ feature_axes,
123
+ self.dtype,
124
+ self.param_dtype,
125
+ self.epsilon,
126
+ self.use_scale,
127
+ self.scale_init,
128
+ )
129
+
130
+ def _compute_rms_sq(self, x, axes):
131
+ x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
132
+ rms_sq = jnp.mean(jax.lax.square(x), axes)
133
+ return rms_sq
134
+
135
+ def _normalize(
136
+ self,
137
+ mdl,
138
+ x,
139
+ rms_sq,
140
+ reduction_axes,
141
+ feature_axes,
142
+ dtype,
143
+ param_dtype,
144
+ epsilon,
145
+ use_scale,
146
+ scale_init,
147
+ ):
148
+ reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
149
+ feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
150
+ stats_shape = list(x.shape)
151
+ for axis in reduction_axes:
152
+ stats_shape[axis] = 1
153
+ rms_sq = rms_sq.reshape(stats_shape)
154
+ feature_shape = [1] * x.ndim
155
+ reduced_feature_shape = []
156
+ for ax in feature_axes:
157
+ feature_shape[ax] = x.shape[ax]
158
+ reduced_feature_shape.append(x.shape[ax])
159
+ mul = lax.rsqrt(rms_sq + epsilon)
160
+ if use_scale:
161
+ scale = mdl.param(
162
+ "scale", scale_init, reduced_feature_shape, param_dtype
163
+ ).reshape(feature_shape)
164
+ mul *= scale
165
+ y = mul * x
166
+ return jnp.asarray(y, dtype)
167
+
168
+
169
+ def norm(type, *args, **kwargs):
170
+ if type == "rmsnorm":
171
+ return RMSNorm(*args, **kwargs)
172
+ elif type == "layernorm":
173
+ return nn.LayerNorm(*args, **kwargs)
174
+ else:
175
+ raise ValueError(f"Unknown norm type {type}")
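In equation form, the RMSNorm defined above normalizes each feature vector by its root mean square and applies an optional learned scale g, with the epsilon declared on the module:

```latex
\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon}} \odot g
```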
176
+
177
 
178
  class FlaxBartAttention(FlaxBartAttention):
179
  """
180
  Edits:
181
  - causal mask is used only in decoder and considers image_length
182
+ - scale attention heads per NormFormer paper
183
  """
184
 
185
+ is_encoder: bool = False
186
+
187
  def setup(self) -> None:
188
  self.head_dim = self.embed_dim // self.num_heads
189
  if self.head_dim * self.num_heads != self.embed_dim:
 
197
  self.embed_dim,
198
  use_bias=self.bias,
199
  dtype=self.dtype,
 
200
  )
201
 
202
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
203
+ self.config
204
+ )
205
 
206
+ self.q_proj = dense(
207
+ kernel_init=deepnet_init()
208
+ if self.config.use_deepnet_scaling
209
+ else jax.nn.initializers.normal(self.config.init_std)
210
+ )
211
+ self.k_proj = dense(
212
+ kernel_init=deepnet_init()
213
+ if self.config.use_deepnet_scaling
214
+ else jax.nn.initializers.normal(self.config.init_std)
215
+ )
216
+ self.v_proj = dense(
217
+ kernel_init=deepnet_init(gain)
218
+ if self.config.use_deepnet_scaling
219
+ else jax.nn.initializers.normal(self.config.init_std)
220
+ )
221
+ self.out_proj = dense(
222
+ kernel_init=deepnet_init(gain)
223
+ if self.config.use_deepnet_scaling
224
+ else jax.nn.initializers.normal(self.config.init_std)
225
+ )
226
  self.dropout_layer = nn.Dropout(rate=self.dropout)
227
 
228
+ if self.config.head_scale:
229
+ self.head_scale = self.param(
230
+ "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
231
+ )
232
+
233
+ if self.config.use_cosine_attention:
234
+ self.tau = self.param(
235
+ "tau",
236
+ jax.nn.initializers.constant(self.config.tau_init),
237
+ (1, self.num_heads, 1, 1),
238
+ )
239
+
240
  if self.causal:
241
  # used only in decoder
242
  self.causal_mask = make_causal_mask(
243
  jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool"
244
  )
245
 
246
+ def __call__(
247
+ self,
248
+ hidden_states: jnp.ndarray,
249
+ key_value_states: Optional[jnp.ndarray] = None,
250
+ attention_mask: Optional[jnp.ndarray] = None,
251
+ init_cache: bool = False,
252
+ deterministic: bool = True,
253
+ ) -> Tuple[jnp.ndarray]:
254
+ """Input shape: Batch x Time x Channel"""
255
+
256
+ # if key_value_states are provided this layer is used as a cross-attention layer
257
+ # for the decoder
258
+ is_cross_attention = key_value_states is not None
259
+ batch_size = hidden_states.shape[0]
260
+
261
+ # get query proj
262
+ query_states = self.q_proj(hidden_states)
263
+ # get key, value proj
264
+ if is_cross_attention:
265
+ # cross_attentions
266
+ key_states = self.k_proj(key_value_states)
267
+ value_states = self.v_proj(key_value_states)
268
+ else:
269
+ # self_attention
270
+ key_states = self.k_proj(hidden_states)
271
+ value_states = self.v_proj(hidden_states)
272
 
273
+ query_states = self._split_heads(query_states)
274
+ key_states = self._split_heads(key_states)
275
+ value_states = self._split_heads(value_states)
 
 
 
276
 
277
+ # handle cache prepare causal attention mask
278
+ if self.causal:
279
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
280
+ if self.has_variable("cache", "cached_key"):
281
+ mask_shift = self.variables["cache"]["cache_index"]
282
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
283
+ causal_mask = lax.dynamic_slice(
284
+ self.causal_mask,
285
+ (0, 0, mask_shift, 0),
286
+ (1, 1, query_length, max_decoder_length),
287
+ )
288
+ else:
289
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
290
+ causal_mask = jnp.broadcast_to(
291
+ causal_mask, (batch_size,) + causal_mask.shape[1:]
292
+ )
293
+
294
+ # combine masks if needed
295
+ if attention_mask is not None and self.causal:
296
+ attention_mask = jnp.broadcast_to(
297
+ jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape
298
+ )
299
+ attention_mask = combine_masks(attention_mask, causal_mask)
300
+ elif self.causal:
301
+ attention_mask = causal_mask
302
+ elif attention_mask is not None:
303
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
304
+
305
+ # During fast autoregressive decoding, we feed one position at a time,
306
+ # and cache the keys and values step by step.
307
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
308
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
309
+ key_states, value_states, query_states, attention_mask
310
+ )
311
+
312
+ # Convert the boolean attention mask to an attention bias.
313
+ if attention_mask is not None:
314
+ # attention mask in the form of attention bias
315
+ attention_bias = lax.select(
316
+ attention_mask > 0,
317
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
318
+ jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
319
+ )
320
+ else:
321
+ attention_bias = None
322
+
323
+ dropout_rng = None
324
+ if not deterministic and self.dropout > 0.0:
325
+ dropout_rng = self.make_rng("dropout")
326
+
327
+ if self.config.use_cosine_attention:
328
+ # normalize q and k
329
+ query_states = query_states / (
330
+ jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8
331
+ )
332
+ key_states = key_states / (
333
+ jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8
334
+ )
335
+ attn_weights = dot_product_attention_weights(
336
+ query_states,
337
+ key_states,
338
+ bias=attention_bias,
339
+ dropout_rng=dropout_rng,
340
+ dropout_rate=self.dropout,
341
+ broadcast_dropout=True,
342
+ deterministic=deterministic,
343
  dtype=self.dtype,
344
+ precision=None,
345
  )
346
+ if self.config.use_cosine_attention:
347
+ # divide by tau
348
+ attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
349
+
350
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
351
+ if self.config.head_scale:
352
+ # per Normformer
353
+ attn_output = attn_output * self.head_scale
354
+ attn_output = self._merge_heads(attn_output)
355
+ attn_output = self.out_proj(attn_output)
356
+
357
+ return attn_output, attn_weights
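When `use_cosine_attention` is enabled, the attention computed above is cosine-similarity attention in the spirit of Swin v2: q and k are L2-normalized before the dot product, so the logits are cosines scaled by the usual 1/sqrt(d_h), and the weights are then rescaled by a learned per-head tau clamped at 0.01. Note that, as written, the tau division is applied to the softmax output rather than to the logits. Schematically (B is the combined attention/causal bias):

```latex
A = \frac{\mathrm{softmax}\!\big(\cos(q, k)/\sqrt{d_h} + B\big)}{\max(\tau,\ 0.01)}
```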
358
+
359
+
360
+ class GLU(nn.Module):
361
+ """From "GLU Variants Improve Transformer" by https://arxiv.org/abs/2002.05202"""
362
+
363
+ config: DalleBartConfig
364
+ ffn_dim: int
365
+ embed_dim: int
366
+ dtype: jnp.dtype = jnp.float32
367
+ is_encoder: bool = False
368
+
369
+ @nn.compact
370
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
371
+
372
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
373
+ self.config
374
+ )
375
+
376
+ if self.config.ln_positions in ["normformer"]:
377
+ x = norm(
378
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
379
+ )(x)
380
+ w = nn.Dense(
381
+ self.ffn_dim,
382
+ dtype=self.dtype,
383
+ use_bias=False,
384
+ kernel_init=deepnet_init(gain)
385
+ if self.config.use_deepnet_scaling
386
+ else jax.nn.initializers.normal(self.config.init_std),
387
+ )(x)
388
+ w = ACT2FN[self.config.activation_function](w)
389
+ v = nn.Dense(
390
+ self.ffn_dim,
391
  dtype=self.dtype,
392
  use_bias=False,
393
+ kernel_init=deepnet_init(gain)
394
+ if self.config.use_deepnet_scaling
395
+ else jax.nn.initializers.normal(self.config.init_std),
396
+ )(x)
397
+ x = w * v
398
+ if self.config.ln_positions in ["normformer"]:
399
+ x = norm(
400
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
401
+ )(x)
402
+ x = nn.Dropout(rate=self.config.activation_dropout)(
403
+ x, deterministic=deterministic
404
  )
405
+
406
+ x = nn.Dense(
407
  self.embed_dim,
408
  dtype=self.dtype,
409
  use_bias=False,
410
+ kernel_init=deepnet_init(gain)
411
+ if self.config.use_deepnet_scaling
412
+ else jax.nn.initializers.normal(self.config.init_std),
413
+ )(x)
414
+ if self.config.ln_positions in ["swinv2"]:
415
+ x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
416
+ x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
417
+ return x
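Ignoring the optional normalization and dropout, the gated feed-forward above follows "GLU Variants Improve Transformer": the configured activation (GELU by default) gates one of two parallel bias-free projections before the output projection:

```latex
\mathrm{FFN}_{\mathrm{GLU}}(x) = \big(\sigma(x W_1) \odot x V\big)\, W_2
```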
418
+
419
+
420
+ class FFN(nn.Module):
421
+ """Simple FFN layer"""
422
+
423
+ config: DalleBartConfig
424
+ ffn_dim: int
425
+ embed_dim: int
426
+ dtype: jnp.dtype = jnp.float32
427
+ is_encoder: bool = False
428
+
429
+ @nn.compact
430
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
431
+
432
+ gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
433
+ self.config
434
  )
435
+ if self.config.ln_positions in ["normformer"]:
436
+ x = norm(
437
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
438
+ )(x)
439
+ x = nn.Dense(
440
+ self.ffn_dim,
441
+ dtype=self.dtype,
442
+ use_bias=False,
443
+ kernel_init=deepnet_init(gain)
444
+ if self.config.use_deepnet_scaling
445
+ else jax.nn.initializers.normal(self.config.init_std),
446
+ )(x)
447
+ x = ACT2FN[self.config.activation_function](x)
448
+ if self.config.ln_positions in ["normformer"]:
449
+ x = norm(
450
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=False
451
+ )(x)
452
+ x = nn.Dropout(rate=self.config.activation_dropout)(
453
+ x, deterministic=deterministic
454
+ )
455
+ x = nn.Dense(
456
+ self.embed_dim,
457
+ dtype=self.dtype,
458
+ use_bias=False,
459
+ kernel_init=deepnet_init(gain)
460
+ if self.config.use_deepnet_scaling
461
+ else jax.nn.initializers.normal(self.config.init_std),
462
+ )(x)
463
+ if self.config.ln_positions in ["swinv2"]:
464
+ x = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
465
+ x = nn.Dropout(rate=self.config.dropout)(x, deterministic=deterministic)
466
+ return x
467
 
468
 
469
+ class FlaxBartEncoderLayer(nn.Module):
470
  """
471
  Edits:
472
+ - no bias
473
+ - use custom FlaxBartAttention
474
  """
475
 
476
+ config: DalleBartConfig
477
+ dtype: jnp.dtype = jnp.float32
478
+ add_norm: bool = False
479
+ use_scale: bool = True
480
+
481
+ @nn.compact
482
+ def __call__(
483
+ self,
484
+ hidden_states: jnp.ndarray,
485
+ attention_mask: jnp.ndarray,
486
+ output_attentions: bool = True,
487
+ deterministic: bool = True,
488
+ ) -> Tuple[jnp.ndarray]:
489
+
490
+ res_gain = (
491
+ deepnet_gain["encoder"]["alpha"](self.config)
492
+ if self.config.use_deepnet_scaling
493
+ else 1
494
  )
495
+
496
+ embed_dim = self.config.d_model
497
+ residual = hidden_states
498
+ if self.config.ln_positions in ["normformer"]:
499
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
500
+ hidden_states
501
+ )
502
+ hidden_states, attn_weights = FlaxBartAttention(
503
+ config=self.config,
504
+ embed_dim=embed_dim,
505
+ num_heads=self.config.encoder_attention_heads,
506
+ dropout=self.config.attention_dropout,
507
+ bias=False,
508
+ dtype=self.dtype,
509
+ is_encoder=True,
510
+ )(hidden_states=hidden_states, attention_mask=attention_mask)
511
+
512
+ if self.config.ln_positions in ["normformer", "swinv2"]:
513
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
514
+ hidden_states
515
+ )
516
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
517
+ hidden_states, deterministic=deterministic
518
+ )
519
+ hidden_states = residual * res_gain + hidden_states
520
+ if self.config.ln_positions in ["deepnet"]:
521
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
522
+ hidden_states
523
+ )
524
+
525
+ residual = hidden_states
526
+ ff_block = (
527
+ GLU(
528
+ config=self.config,
529
+ ffn_dim=self.config.encoder_ffn_dim,
530
+ embed_dim=embed_dim,
531
+ dtype=self.dtype,
532
+ is_encoder=True,
533
+ )
534
+ if self.config.use_glu
535
+ else FFN(
536
+ config=self.config,
537
+ ffn_dim=self.config.encoder_ffn_dim,
538
+ embed_dim=embed_dim,
539
+ dtype=self.dtype,
540
+ is_encoder=True,
541
+ )
542
+ )
543
+ hidden_states = ff_block(hidden_states, deterministic=deterministic)
544
+ hidden_states = residual * res_gain + hidden_states
545
+ if self.add_norm or self.config.ln_positions in ["deepnet"]:
546
+ use_scale = self.use_scale or self.config.ln_positions == "deepnet"
547
+ hidden_states = norm(
548
+ self.config.ln_type,
549
+ dtype=self.dtype,
550
+ epsilon=1e-05,
551
+ use_scale=use_scale,
552
+ )(hidden_states)
553
+
554
+ outputs = (hidden_states,)
555
+
556
+ if output_attentions:
557
+ outputs += (attn_weights,)
558
+
559
+ return outputs
560
 
561
 
562
+ class FlaxBartDecoderLayer(nn.Module):
563
  """
564
  Edits:
565
  - no bias
566
+ - use custom FlaxBartAttention
567
  """
568
 
569
+ config: DalleBartConfig
570
+ dtype: jnp.dtype = jnp.float32
571
+ add_norm: bool = False
572
+ use_scale: bool = False
573
+
574
+ @nn.compact
575
+ def __call__(
576
+ self,
577
+ hidden_states: jnp.ndarray,
578
+ attention_mask: jnp.ndarray,
579
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
580
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
581
+ init_cache: bool = False,
582
+ output_attentions: bool = True,
583
+ deterministic: bool = True,
584
+ ) -> Tuple[jnp.ndarray]:
585
+
586
+ res_gain = (
587
+ deepnet_gain["decoder"]["alpha"](self.config)
588
+ if self.config.use_deepnet_scaling
589
+ else 1
590
+ )
591
+
592
+ embed_dim = self.config.d_model
593
+ residual = hidden_states
594
+
595
+ # Self Attention
596
+ if self.config.ln_positions in ["normformer"]:
597
+ hidden_states = norm(
598
+ self.config.ln_type,
599
+ dtype=self.dtype,
600
+ epsilon=1e-05,
601
+ use_scale=False,
602
+ )(hidden_states)
603
+ hidden_states, attn_weights = FlaxBartAttention(
604
  config=self.config,
605
+ embed_dim=embed_dim,
606
  num_heads=self.config.decoder_attention_heads,
607
  dropout=self.config.attention_dropout,
608
  causal=True,
609
  bias=False,
610
  dtype=self.dtype,
611
+ is_encoder=False,
612
+ )(
613
+ hidden_states=hidden_states,
614
+ attention_mask=attention_mask,
615
+ init_cache=init_cache,
616
  )
 
 
 
617
 
618
+ if self.config.ln_positions in ["normformer", "swinv2"]:
619
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
620
+ hidden_states
621
+ )
622
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
623
+ hidden_states, deterministic=deterministic
 
 
624
  )
625
+ hidden_states = residual * res_gain + hidden_states
626
+ if self.config.ln_positions in ["deepnet"]:
627
+ hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(
628
+ hidden_states
629
+ )
630
+
631
+ # Cross Attention
632
+ cross_attn_weights = None
633
+ if encoder_hidden_states is not None:
634
+ residual = hidden_states
635
+ if self.config.ln_positions in ["normformer"]:
636
+ hidden_states = norm(
637
+ self.config.ln_type,
638
+ dtype=self.dtype,
639
+ epsilon=1e-05,
640
+ use_scale=False,
641
+ )(hidden_states)
642
+ hidden_states, cross_attn_weights = FlaxBartAttention(
643
+ config=self.config,
644
+ embed_dim=embed_dim,
645
+ num_heads=self.config.decoder_attention_heads,
646
+ dropout=self.config.attention_dropout,
647
+ bias=False,
648
+ dtype=self.dtype,
649
+ is_encoder=False,
650
+ )(
651
+ hidden_states=hidden_states,
652
+ key_value_states=encoder_hidden_states,
653
+ attention_mask=encoder_attention_mask,
654
+ )
655
+ if self.config.ln_positions in ["normformer", "swinv2"]:
656
+ hidden_states = norm(
657
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
658
+ )(hidden_states)
659
+ hidden_states = nn.Dropout(rate=self.config.dropout)(
660
+ hidden_states, deterministic=deterministic
661
+ )
662
+ hidden_states = residual * res_gain + hidden_states
663
+ if self.config.ln_positions in ["deepnet"]:
664
+ hidden_states = norm(
665
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
666
+ )(hidden_states)
667
+
668
+ # Feed forward
669
+ residual = hidden_states
670
+ ff_block = (
671
+ GLU(
672
+ config=self.config,
673
+ ffn_dim=self.config.decoder_ffn_dim,
674
+ embed_dim=embed_dim,
675
+ dtype=self.dtype,
676
+ is_encoder=False,
677
+ )
678
+ if self.config.use_glu
679
+ else FFN(
680
+ config=self.config,
681
+ ffn_dim=self.config.decoder_ffn_dim,
682
+ embed_dim=embed_dim,
683
+ dtype=self.dtype,
684
+ is_encoder=False,
685
+ )
686
  )
687
+ hidden_states = ff_block(hidden_states, deterministic=deterministic)
688
+ hidden_states = residual * res_gain + hidden_states
689
+ if self.add_norm or self.config.ln_positions in ["deepnet"]:
690
+ use_scale = self.use_scale or self.config.ln_positions == "deepnet"
691
+ hidden_states = norm(
692
+ self.config.ln_type,
693
+ dtype=self.dtype,
694
+ epsilon=1e-05,
695
+ use_scale=use_scale,
696
+ )(hidden_states)
697
+
698
+ outputs = (hidden_states,)
699
+
700
+ if output_attentions:
701
+ outputs += (attn_weights, cross_attn_weights)
702
+
703
+ return outputs
704
+
705
+
706
+ class FlaxBartEncoderLayerCollection(nn.Module):
707
+ config: DalleBartConfig
708
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
709
+ """
710
+ Edits:
711
+ - use custom FlaxBartEncoderLayer
712
+ - allow Gradient Checkpointing (nn.remat)
713
+ """
714
+
715
+ @nn.compact
716
+ def __call__(
717
+ self,
718
+ hidden_states,
719
+ attention_mask,
720
+ deterministic: bool = True,
721
+ output_attentions: bool = False,
722
+ output_hidden_states: bool = False,
723
+ return_dict: bool = True,
724
+ ):
725
+ all_hidden_states = () if output_hidden_states else None
726
+ all_self_attns = () if output_attentions else None
727
+
728
+ n_layers = self.config.encoder_layers
729
+ layer = (
730
+ remat(FlaxBartEncoderLayer, static_argnums=(2, 3))
731
+ if self.config.gradient_checkpointing
732
+ else FlaxBartEncoderLayer
733
+ )
734
+ for i in range(n_layers):
735
+ if output_hidden_states:
736
+ all_hidden_states += (hidden_states,)
737
+ # final layernorm on the output of the last layer
738
+ # or every 6 layers for Swin v2
739
+ # ignored args for deepnet which always add a norm with scale
740
+ add_norm = (i == n_layers - 1) or (
741
+ (self.config.ln_positions == "swinv2") and ((i + 1) % 6 == 0)
742
+ )
743
+ # we don't need to scale the norm for the last layer
744
+ use_scale = i != n_layers - 1
745
+ layer_outputs = layer(
746
+ self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
747
+ )(
748
+ hidden_states,
749
+ attention_mask,
750
+ output_attentions,
751
+ deterministic,
752
+ )
753
+ hidden_states = layer_outputs[0]
754
+ if output_attentions:
755
+ all_self_attns += (layer_outputs[1],)
756
+
757
+ # add hidden states from the last layer
758
+ if output_hidden_states:
759
+ all_hidden_states += (hidden_states,)
760
+
761
+ outputs = [
762
+ hidden_states,
763
+ all_hidden_states,
764
+ all_self_attns,
765
+ ]
766
+
767
+ if not return_dict:
768
+ return tuple(v for v in outputs if v is not None)
769
+
770
+ return FlaxBaseModelOutput(
771
+ last_hidden_state=hidden_states,
772
+ hidden_states=all_hidden_states,
773
+ attentions=all_self_attns,
774
  )
 
775
 
776
 
777
+ class FlaxBartDecoderLayerCollection(nn.Module):
778
+ config: DalleBartConfig
779
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
780
  """
781
  Edits:
782
  - use custom FlaxBartDecoderLayer
783
  - allow Gradient Checkpointing (nn.remat)
784
  """
785
 
786
+ @nn.compact
787
+ def __call__(
788
+ self,
789
+ hidden_states,
790
+ attention_mask,
791
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
792
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
793
+ deterministic: bool = True,
794
+ init_cache: bool = False,
795
+ output_attentions: bool = False,
796
+ output_hidden_states: bool = False,
797
+ return_dict: bool = True,
798
+ ):
799
+ # decoder layers
800
+ all_hidden_states = () if output_hidden_states else None
801
+ all_self_attns = () if output_attentions else None
802
+ all_cross_attentions = (
803
+ () if (output_attentions and encoder_hidden_states is not None) else None
804
+ )
805
+
806
+ n_layers = self.config.decoder_layers
807
+ layer = (
808
+ remat(FlaxBartDecoderLayer, static_argnums=(4, 5, 6))
809
  if self.config.gradient_checkpointing
810
  else FlaxBartDecoderLayer
811
  )
812
+ for i in range(n_layers):
813
+ if output_hidden_states:
814
+ all_hidden_states += (hidden_states,)
815
+ # final layernorm on the output of the last layer
816
+ # or every 6 layers for Swin v2
817
+ add_norm = (i == n_layers - 1) or (
818
+ (self.config.ln_positions == "swinv2") and ((i + 1) % 6 == 0)
819
+ )
820
+ # we don't need to scale the norm for the last layer
821
+ use_scale = i != n_layers - 1
822
+ layer_outputs = layer(
823
+ self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
824
+ )(
825
+ hidden_states,
826
+ attention_mask,
827
+ encoder_hidden_states,
828
+ encoder_attention_mask,
829
+ init_cache,
830
+ output_attentions,
831
+ deterministic,
832
+ )
833
+
834
+ hidden_states = layer_outputs[0]
835
+ if output_attentions:
836
+ all_self_attns += (layer_outputs[1],)
837
+
838
+ if encoder_hidden_states is not None:
839
+ all_cross_attentions += (layer_outputs[2],)
840
+
841
+ # add hidden states from the last decoder layer
842
+ if output_hidden_states:
843
+ all_hidden_states += (hidden_states,)
844
+
845
+ outputs = [
846
+ hidden_states,
847
+ all_hidden_states,
848
+ all_self_attns,
849
+ all_cross_attentions,
850
  ]
851
+
852
+ if not return_dict:
853
+ return tuple(v for v in outputs if v is not None)
854
+
855
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
856
+ last_hidden_state=hidden_states,
857
+ hidden_states=all_hidden_states,
858
+ attentions=all_self_attns,
859
+ cross_attentions=all_cross_attentions,
860
+ )
861
 
862
 
863
  class FlaxBartEncoder(FlaxBartEncoder):
 
882
  self.embed_positions = nn.Embed(
883
  self.config.max_text_length + self.offset,
884
  embed_dim,
885
+ embedding_init=deepnet_init()
886
+ if self.config.use_deepnet_scaling
887
+ else jax.nn.initializers.normal(self.config.init_std),
888
  )
889
  self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
890
+ self.layernorm_embedding = norm(
891
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
892
+ )
893
 
894
 
895
  class FlaxBartDecoder(FlaxBartDecoder):
 
916
  self.embed_positions = nn.Embed(
917
  self.config.image_length + self.offset, # image length for BOS
918
  embed_dim,
919
+ embedding_init=deepnet_init()
920
+ if self.config.use_deepnet_scaling
921
+ else jax.nn.initializers.normal(self.config.init_std),
922
  )
923
 
924
  self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
925
+ self.layernorm_embedding = norm(
926
+ self.config.ln_type, dtype=self.dtype, epsilon=1e-05
927
+ )
928
 
929
 
930
  class FlaxBartModule(FlaxBartModule):
 
938
  encoder_embed_tokens = nn.Embed(
939
  self.config.encoder_vocab_size,
940
  self.config.d_model,
941
+ embedding_init=deepnet_init()
942
+ if self.config.use_deepnet_scaling
943
+ else jax.nn.initializers.normal(self.config.init_std),
944
  )
945
  decoder_embed_tokens = nn.Embed(
946
  self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
947
  self.config.d_model,
948
+ embedding_init=deepnet_init()
949
+ if self.config.use_deepnet_scaling
950
+ else jax.nn.initializers.normal(self.config.init_std),
951
  )
952
 
953
  self.encoder = FlaxBartEncoder(
 
1287
  + 1, # image vocab size + 1 for BOS to have same size as decoder inputs (for sharding)
1288
  use_bias=False,
1289
  dtype=self.dtype,
1290
+ kernel_init=deepnet_init()
1291
+ if self.config.use_deepnet_scaling
1292
+ else jax.nn.initializers.normal(self.config.init_std),
1293
  )
1294
 
1295
  def __call__(
src/dalle_mini/model/partitions.py CHANGED
@@ -36,23 +36,21 @@ def _replacement_rules(rules):
36
  def _get_partition_rules():
37
  return [
38
  # embeddings
39
- ((r"embed_positions", "embedding"), P("mp", None)),
40
- ((r"embed_tokens", "embedding"), P("mp", None)),
41
- # self-attention
42
- ((r"self_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
43
- ((r"self_attn", "out_proj", "kernel"), P("mp", None)),
44
- # enc-dec attention
45
- ((r"encoder_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
46
- ((r"encoder_attn", "out_proj", "kernel"), P("mp", None)),
47
  # FFN
48
- ((r"fc1", "kernel"), P(None, "mp")),
49
- ((r"fc2", "kernel"), P("mp", None)),
 
 
50
  # layer norms
51
- ((r"layernorm_embedding", "(bias|scale)"), None),
52
- ((r"self_attn_layer_norm", "(bias|scale)"), None),
53
- ((r"encoder_attn_layer_norm", "(bias|scale)"), None),
54
- ((r"final_layer_norm", "(bias|scale)"), None),
55
- ((r"lm_head", "kernel"), P(None, "mp")),
56
  ]
57
 
58
 
@@ -63,6 +61,6 @@ def set_partitions(in_dict):
63
  result = {k: replace(k, v) for k, v in initd.items()}
64
  for k, v in result.items():
65
  if v == _unmatched:
66
- print(k)
67
  assert _unmatched not in result.values(), "Incomplete partition spec."
68
  return freeze(unflatten_dict(result))
 
36
  def _get_partition_rules():
37
  return [
38
  # embeddings
39
+ (("embed_positions", "embedding"), P("mp", None)),
40
+ (("embed_tokens", "embedding"), P("mp", None)),
41
+ # attention
42
+ (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
43
+ (("out_proj", "kernel"), P("mp", None)),
 
 
 
44
  # FFN
45
+ (("Dense_0", "kernel"), P(None, "mp")),
46
+ (("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
47
+ (("GLU.*", "Dense_2", "kernel"), P("mp", None)),
48
+ (("FFN.*", "Dense_1", "kernel"), P("mp", None)),
49
  # layer norms
50
+ (("(bias|scale)",), None),
51
+ (("lm_head", "kernel"), P(None, "mp")),
52
+ # head scale and tau
53
+ (("(head_scale|tau)",), None),
 
54
  ]
55
 
56
 
 
61
  result = {k: replace(k, v) for k, v in initd.items()}
62
  for k, v in result.items():
63
  if v == _unmatched:
64
+ print(f"Unmatched -> {k}")
65
  assert _unmatched not in result.values(), "Incomplete partition spec."
66
  return freeze(unflatten_dict(result))
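As a rough illustration of how such regex rules resolve a parameter path to a partition spec (hypothetical paths, plain tuples standing in for `PartitionSpec`, and a simplified matcher; the real helpers in `partitions.py` differ in detail):

```python
import re

# Simplified rule table mirroring the new rules above: attention kernels are
# sharded over the model-parallel ("mp") axis, norm parameters stay replicated.
RULES = [
    (("(q_proj|k_proj|v_proj)", "kernel"), (None, "mp")),
    (("out_proj", "kernel"), ("mp", None)),
    (("(bias|scale)",), None),
]


def find_spec(param_path):
    path = "/".join(param_path)
    for rule, spec in RULES:
        if re.search("/".join(rule), path):
            return spec
    return "unmatched"


print(find_spec(("encoder", "layers", "0", "q_proj", "kernel")))  # (None, 'mp')
print(find_spec(("decoder", "layernorm_embedding", "scale")))     # None
```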
src/dalle_mini/model/utils.py CHANGED
@@ -9,13 +9,11 @@ class PretrainedFromWandbMixin:
9
  @classmethod
10
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
11
  """
12
- Initializes from a wandb artifact, google bucket path or delegates loading to the superclass.
13
  """
14
  with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
15
- if (
16
- ":" in pretrained_model_name_or_path
17
- and not os.path.isdir(pretrained_model_name_or_path)
18
- and not pretrained_model_name_or_path.startswith("gs")
19
  ):
20
  # wandb artifact
21
  if wandb.run is not None:
@@ -27,17 +25,3 @@ class PretrainedFromWandbMixin:
27
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
28
  pretrained_model_name_or_path, *model_args, **kwargs
29
  )
30
-
31
-
32
- def copy_blobs(source_path, dest_path):
33
- assert source_path.startswith("gs://")
34
- from google.cloud import storage
35
-
36
- bucket_path = Path(source_path[5:])
37
- bucket, dir_path = str(bucket_path).split("/", 1)
38
- client = storage.Client()
39
- bucket = client.bucket(bucket)
40
- blobs = client.list_blobs(bucket, prefix=f"{dir_path}/")
41
- for blob in blobs:
42
- dest_name = str(Path(dest_path) / Path(blob.name).name)
43
- blob.download_to_filename(dest_name)
 
9
  @classmethod
10
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
11
  """
12
+ Initializes from a wandb artifact or delegates loading to the superclass.
13
  """
14
  with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
15
+ if ":" in pretrained_model_name_or_path and not os.path.isdir(
16
+ pretrained_model_name_or_path
 
 
17
  ):
18
  # wandb artifact
19
  if wandb.run is not None:
 
25
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
26
  pretrained_model_name_or_path, *model_args, **kwargs
27
  )
 
tools/inference/inference_pipeline.ipynb CHANGED
@@ -47,7 +47,7 @@
47
  "outputs": [],
48
  "source": [
49
  "# Install required libraries\n",
50
- "!pip install -q transformers\n",
51
  "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
  "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
  ]
@@ -75,7 +75,7 @@
75
  "# Model references\n",
76
  "\n",
77
  "# dalle-mini\n",
78
- "DALLE_MODEL = \"dalle-mini/dalle-mini/model-2vm4itcx:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
  "DALLE_COMMIT_ID = None\n",
80
  "\n",
81
  "# VQGAN model\n",
@@ -302,21 +302,7 @@
302
  },
303
  "outputs": [],
304
  "source": [
305
- "tokenized_prompt = processor([prompt])\n",
306
- "tokenized_prompt"
307
- ]
308
- },
309
- {
310
- "cell_type": "markdown",
311
- "metadata": {
312
- "id": "_Y5dqFj7prMQ"
313
- },
314
- "source": [
315
- "Notes:\n",
316
- "\n",
317
- "* `0`: BOS, special token representing the beginning of a sequence\n",
318
- "* `2`: EOS, special token representing the end of a sequence\n",
319
- "* `1`: special token representing the padding of a sequence when requesting a specific length"
320
  ]
321
  },
322
  {
@@ -459,13 +445,6 @@
459
  " display(images[idx])\n",
460
  " print(f\"Score: {logits[idx]:.2f}\\n\")"
461
  ]
462
- },
463
- {
464
- "cell_type": "code",
465
- "execution_count": null,
466
- "metadata": {},
467
- "outputs": [],
468
- "source": []
469
  }
470
  ],
471
  "metadata": {
 
47
  "outputs": [],
48
  "source": [
49
  "# Install required libraries\n",
50
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
51
  "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
  "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
  ]
 
75
  "# Model references\n",
76
  "\n",
77
  "# dalle-mini\n",
78
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3e2l7fxk:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
  "DALLE_COMMIT_ID = None\n",
80
  "\n",
81
  "# VQGAN model\n",
 
302
  },
303
  "outputs": [],
304
  "source": [
305
+ "tokenized_prompt = processor([prompt])"
306
  ]
307
  },
308
  {
 
445
  " display(images[idx])\n",
446
  " print(f\"Score: {logits[idx]:.2f}\\n\")"
447
  ]
448
  }
449
  ],
450
  "metadata": {
tools/train/config/medium/config.json CHANGED
@@ -3,7 +3,6 @@
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
6
- "classifier_dropout": 0.0,
7
  "d_model": 1408,
8
  "decoder_attention_heads": 16,
9
  "decoder_ffn_dim": 4096,
 
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
 
6
  "d_model": 1408,
7
  "decoder_attention_heads": 16,
8
  "decoder_ffn_dim": 4096,
tools/train/config/mega/config.json CHANGED
@@ -3,7 +3,6 @@
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
6
- "classifier_dropout": 0.0,
7
  "d_model": 2048,
8
  "decoder_attention_heads": 32,
9
  "decoder_ffn_dim": 8192,
 
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
 
6
  "d_model": 2048,
7
  "decoder_attention_heads": 32,
8
  "decoder_ffn_dim": 8192,
tools/train/config/micro/config.json CHANGED
@@ -3,7 +3,6 @@
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
6
- "classifier_dropout": 0.0,
7
  "d_model": 256,
8
  "decoder_attention_heads": 2,
9
  "decoder_ffn_dim": 256,
 
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
 
6
  "d_model": 256,
7
  "decoder_attention_heads": 2,
8
  "decoder_ffn_dim": 256,
tools/train/config/mini/config.json CHANGED
@@ -3,17 +3,14 @@
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
6
- "classifier_dropout": 0.0,
7
  "d_model": 1024,
8
  "decoder_attention_heads": 16,
9
- "decoder_ffn_dim": 4096,
10
- "decoder_layerdrop": 0.0,
11
  "decoder_layers": 12,
12
  "decoder_start_token_id": 16384,
13
  "dropout": 0.0,
14
  "encoder_attention_heads": 16,
15
- "encoder_ffn_dim": 4096,
16
- "encoder_layerdrop": 0.0,
17
  "encoder_layers": 12,
18
  "encoder_vocab_size": 50264,
19
  "eos_token_id": 16385,
 
3
  "activation_function": "gelu",
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
 
6
  "d_model": 1024,
7
  "decoder_attention_heads": 16,
8
+ "decoder_ffn_dim": 2560,
 
9
  "decoder_layers": 12,
10
  "decoder_start_token_id": 16384,
11
  "dropout": 0.0,
12
  "encoder_attention_heads": 16,
13
+ "encoder_ffn_dim": 2560,
 
14
  "encoder_layers": 12,
15
  "encoder_vocab_size": 50264,
16
  "eos_token_id": 16385,
tools/train/train.py CHANGED
@@ -18,7 +18,6 @@ Training DALL·E Mini.
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
21
- import copy
22
  import io
23
  import logging
24
  import os
@@ -30,8 +29,10 @@ from pathlib import Path
30
  from typing import Any, Callable, NamedTuple, Optional
31
 
32
  import datasets
 
33
  import jax
34
  import jax.numpy as jnp
 
35
  import numpy as np
36
  import optax
37
  import transformers
@@ -405,10 +406,14 @@ class TrainingArguments:
405
  default=False,
406
  metadata={"help": "Log model to wandb at `save_steps` frequency."},
407
  )
408
- log_histograms: bool = field(
409
  default=False,
410
  metadata={
411
- "help": "Log parameters and gradients histograms. Slows down training."
412
  },
413
  )
414
 
@@ -471,6 +476,8 @@ class TrainingArguments:
471
  ], f"Selected learning rate decay not supported: {self.lr_decay}"
472
  if self.per_device_eval_batch_size is None:
473
  self.per_device_eval_batch_size = self.per_device_train_batch_size
 
 
474
  if (
475
  os.path.exists(self.output_dir)
476
  and os.listdir(self.output_dir)
@@ -497,48 +504,6 @@ class TrainState(train_state.TrainState):
497
  train_samples: int = 0 # number of samples seen
498
 
499
 
500
- class MetricsLogger:
501
- def __init__(self, step):
502
- self.step = step
503
- self.time = time.perf_counter()
504
- self.state_dict = {}
505
-
506
- def update_state_metrics(self, state):
507
- """Update internal state metrics (logged at each call to be used as x-axis)"""
508
- self.state_dict = {
509
- f'train/{k.split("_")[-1]}': getattr(state, k)
510
- for k in ["step", "epoch", "train_time", "train_samples"]
511
- }
512
- # timing metrics
513
- new_step = int(state.step)
514
- new_time = time.perf_counter()
515
- if new_step > self.step:
516
- time_per_step = (new_time - self.time) / (new_step - self.step)
517
- self.step = new_step
518
- self.time = new_time
519
- self.state_dict["train/time_per_step"] = time_per_step
520
-
521
- def log(self, metrics, prefix=None):
522
- if jax.process_index() == 0:
523
- log_metrics = {}
524
- for k, v in metrics.items():
525
- if prefix is not None:
526
- k = f"{prefix}/{k}"
527
- if "_norm" in k:
528
- log_metrics[f"{k}/"] = unfreeze(v)
529
- elif "_hist" in k:
530
- v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
531
- v = jax.tree_map(
532
- lambda x: wandb.Histogram(np_histogram=x),
533
- v,
534
- is_leaf=lambda x: isinstance(x, tuple),
535
- )
536
- log_metrics[f"{k}/"] = v
537
- else:
538
- log_metrics[k] = v
539
- wandb.log({**log_metrics, **self.state_dict})
540
-
541
-
542
  def main():
543
  # See all possible arguments by passing the --help flag to this script.
544
  parser = HfArgumentParser(
@@ -593,9 +558,7 @@ def main():
593
  # Set up our new model config
594
  if model_args.config_name:
595
  config = DalleBartConfig.from_pretrained(model_args.config_name)
596
- # initializing params with gradient checkpointing creates issues
597
- # we correctly set it later per training_args
598
- config.gradient_checkpointing = False
599
  else:
600
  config = None
601
 
@@ -607,9 +570,7 @@ def main():
607
  seed=training_args.seed_model,
608
  dtype=getattr(jnp, model_args.dtype),
609
  abstract_init=True, # we overwrite them with loaded checkpoint
610
- # initializing params with gradient checkpointing creates issues
611
- # we correctly set it later per training_args
612
- gradient_checkpointing=False,
613
  )
614
  else:
615
  model = DalleBart(
@@ -619,21 +580,6 @@ def main():
619
  abstract_init=True,
620
  )
621
 
622
- # define model eval and train functions
623
- eval_fn = model.__call__
624
- if training_args.gradient_checkpointing:
625
- remat_config = copy.deepcopy(model.config)
626
- remat_config.gradient_checkpointing = True
627
- remat_model = DalleBart(
628
- remat_config,
629
- seed=training_args.seed_model,
630
- dtype=getattr(jnp, model_args.dtype),
631
- init_weights=False,
632
- )
633
- train_fn = remat_model.__call__
634
- else:
635
- train_fn = model.__call__
636
-
637
  # get model metadata
638
  model_metadata = model_args.get_metadata()
639
 
@@ -708,8 +654,16 @@ def main():
708
  "len_train_dataset": len_train_dataset,
709
  "len_eval_dataset": len_eval_dataset,
710
  "batch_size_per_step": batch_size_per_step,
711
- "num_params": num_params,
712
  "num_devices": jax.device_count(),
713
  }
714
  )
715
 
@@ -719,7 +673,7 @@ def main():
719
  warmup_fn = optax.linear_schedule(
720
  init_value=0.0,
721
  end_value=training_args.learning_rate,
722
- transition_steps=training_args.warmup_steps,
723
  )
724
  # offset step when resuming
725
  if model_metadata.get("step", 0):
@@ -867,7 +821,7 @@ def main():
867
  epoch=None,
868
  train_time=None,
869
  train_samples=None,
870
- apply_fn=train_fn,
871
  tx=optimizer,
872
  )
873
 
@@ -880,13 +834,13 @@ def main():
880
  # params have not been initialized yet
881
  return model.init_weights()
882
 
883
- with maps.mesh(mesh.devices, mesh.axis_names):
884
  logger.info(" Creating state")
885
  if not model_args.restore_state:
886
 
887
  def init_state(params):
888
  return TrainState.create(
889
- apply_fn=train_fn,
890
  tx=optimizer,
891
  params=maybe_init_params(params),
892
  dropout_rng=dropout_rng,
@@ -913,7 +867,7 @@ def main():
913
 
914
  def restore_state(params, opt_state):
915
  return TrainState(
916
- apply_fn=train_fn,
917
  tx=optimizer,
918
  params=params,
919
  opt_state=opt_state,
@@ -959,7 +913,7 @@ def main():
959
  )
960
 
961
  # Define gradient update step fn
962
- def train_step(state, batch, delta_time):
963
 
964
  # get a minibatch (one gradient accumulation slice)
965
  def get_minibatch(batch, grad_idx):
@@ -1048,36 +1002,45 @@ def main():
1048
  state = state.apply_gradients(
1049
  grads=grads,
1050
  dropout_rng=dropout_rng,
1051
- train_time=state.train_time + delta_time,
1052
  train_samples=state.train_samples + batch_size_per_step,
1053
  )
1054
 
1055
- # get norm and histogram of grads and params
1056
- zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
 
 
1057
 
1058
- def maybe_fn(fn, val, zeros):
1059
  """Call fn only if it is a logging step"""
1060
  return jax.lax.cond(
1061
- state.step % training_args.logging_steps == 0,
1062
  fn,
1063
  lambda _: zeros,
1064
  val,
1065
  )
1066
 
1067
- def norm(val):
1068
- return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
1069
 
1070
- gradients_norm = maybe_fn(norm, grads, zeros_norm)
1071
- params_norm = maybe_fn(norm, state.params, zeros_norm)
1072
 
1073
- metrics = {
1074
- "loss": loss,
1075
- "learning_rate": learning_rate_fn(state.step),
1076
- "gradients_norm": gradients_norm,
1077
- "params_norm": params_norm,
1078
- }
1079
 
1080
- if training_args.log_histograms:
1081
  zeros_hist = jax.tree_map(
1082
  lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
1083
  )
@@ -1085,8 +1048,12 @@ def main():
1085
  def histogram(val):
1086
  return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
1087
 
1088
- gradients_hist = maybe_fn(histogram, grads, zeros_hist)
1089
- params_hist = maybe_fn(histogram, state.params, zeros_hist)
1090
 
1091
  metrics.update(
1092
  {
@@ -1101,7 +1068,7 @@ def main():
1101
  def eval_step(state, batch):
1102
  def compute_eval_loss(batch):
1103
  batch, labels = batch.pop("labels")
1104
- logits = eval_fn(**batch, params=state.params, train=False)[0]
1105
  return loss_fn(logits, labels)
1106
 
1107
  if use_vmap_trick:
@@ -1134,13 +1101,73 @@ def main():
1134
  out_axis_resources=None,
1135
  )
1136
 
1137
  # init variables
1138
- last_time = time.perf_counter()
1139
  train_metrics = None
1140
- step = int(state.step)
1141
- metrics_logger = MetricsLogger(step)
1142
  epochs = tqdm(
1143
- range(state.epoch, num_epochs),
1144
  desc=f"Epoch ... (1/{num_epochs})",
1145
  position=0,
1146
  disable=jax.process_index() > 0,
@@ -1149,6 +1176,7 @@ def main():
1149
  def run_evaluation():
1150
  # ======================== Evaluating ==============================
1151
  if training_args.do_eval:
 
1152
  eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
1153
  eval_steps = (
1154
  len_eval_dataset // eval_batch_size_per_step
@@ -1195,6 +1223,7 @@ def main():
1195
 
1196
  # log metrics
1197
  metrics_logger.log(eval_metrics, prefix="eval")
 
1198
 
1199
  # Print metrics and update progress bar
1200
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -1206,6 +1235,7 @@ def main():
1206
  def run_save_model(state, eval_metrics=None):
1207
  if jax.process_index() == 0:
1208
 
 
1209
  output_dir = training_args.output_dir
1210
  use_bucket = output_dir.startswith("gs://")
1211
  if use_bucket:
@@ -1298,13 +1328,15 @@ def main():
1298
  f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1299
  )
1300
  wandb.run.log_artifact(artifact_state)
 
1301
 
1302
  logger.info(" Ready to start training")
1303
- with maps.mesh(mesh.devices, mesh.axis_names):
1304
  for epoch in epochs:
1305
  state.replace(epoch=epoch)
 
1306
  # ======================== Training ================================
1307
- metrics_logger.update_state_metrics(state)
1308
  metrics_logger.log({})
1309
 
1310
  # Generate an epoch by shuffling sampling indices from the train dataset
@@ -1323,9 +1355,7 @@ def main():
1323
  disable=jax.process_index() > 0,
1324
  ):
1325
  # calculate delta time (we have a lag of one step but it's ok)
1326
- new_time = time.perf_counter()
1327
- delta_time = new_time - last_time
1328
- last_time = new_time
1329
 
1330
  # set correct shape to batch
1331
  # - add grad_step dim if gradient_accumulation_steps > 1
@@ -1353,18 +1383,23 @@ def main():
1353
  batch = freeze(batch)
1354
 
1355
  # train step
1356
- state, train_metrics = p_train_step(state, batch, delta_time)
1357
- step += 1
1358
-
1359
- if step % training_args.logging_steps == 0 and jax.process_index() == 0:
1360
- metrics_logger.update_state_metrics(state)
1361
  metrics_logger.log(train_metrics, prefix="train")
1362
 
1363
  eval_metrics = None
1364
- if step % training_args.eval_steps == 0:
1365
  eval_metrics = run_evaluation()
1366
 
1367
- if step % training_args.save_steps == 0:
1368
  run_save_model(state, eval_metrics)
1369
 
1370
  # log final train metrics
 
18
  Script adapted from run_summarization_flax.py
19
  """
20
 
 
21
  import io
22
  import logging
23
  import os
 
29
  from typing import Any, Callable, NamedTuple, Optional
30
 
31
  import datasets
32
+ import flax
33
  import jax
34
  import jax.numpy as jnp
35
+ import jaxlib
36
  import numpy as np
37
  import optax
38
  import transformers
 
406
  default=False,
407
  metadata={"help": "Log model to wandb at `save_steps` frequency."},
408
  )
409
+ log_norm_steps: int = field(
410
+ default=True,
411
+ metadata={"help": "Log parameters and gradients norm at this frequency."},
412
+ )
413
+ log_histogram_steps: int = field(
414
  default=False,
415
  metadata={
416
+ "help": "Log parameters and gradients histograms at this frequency. Slows down training."
417
  },
418
  )
419
 
 
476
  ], f"Selected learning rate decay not supported: {self.lr_decay}"
477
  if self.per_device_eval_batch_size is None:
478
  self.per_device_eval_batch_size = self.per_device_train_batch_size
479
+ if self.log_norm_steps is True:
480
+ self.log_norm_steps = self.logging_steps
481
  if (
482
  os.path.exists(self.output_dir)
483
  and os.listdir(self.output_dir)
 
504
  train_samples: int = 0 # number of samples seen
505
 
506
 
507
  def main():
508
  # See all possible arguments by passing the --help flag to this script.
509
  parser = HfArgumentParser(
 
558
  # Set up our new model config
559
  if model_args.config_name:
560
  config = DalleBartConfig.from_pretrained(model_args.config_name)
561
+ config.gradient_checkpointing = training_args.gradient_checkpointing
 
 
562
  else:
563
  config = None
564
 
 
570
  seed=training_args.seed_model,
571
  dtype=getattr(jnp, model_args.dtype),
572
  abstract_init=True, # we overwrite them with loaded checkpoint
573
+ gradient_checkpointing=training_args.gradient_checkpointing,
 
 
574
  )
575
  else:
576
  model = DalleBart(
 
580
  abstract_init=True,
581
  )
582
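Gradient checkpointing is now taken directly from `training_args` when the config and model are built, replacing the earlier workaround that initialized with `gradient_checkpointing=False` and rebuilt a separate remat model for training. How `DalleBart` applies the flag internally is not visible in this diff; purely as an illustration, activation rematerialization in Flax is typically enabled by wrapping a module with `nn.remat`:

```python
import flax.linen as nn


class FeedForward(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        h = nn.gelu(nn.Dense(4 * self.features)(x))
        return nn.Dense(self.features)(h)


class Model(nn.Module):
    features: int
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # nn.remat recomputes activations in the backward pass to save memory
        block = nn.remat(FeedForward) if self.gradient_checkpointing else FeedForward
        return block(self.features)(x)
```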
 
583
  # get model metadata
584
  model_metadata = model_args.get_metadata()
585
 
 
654
  "len_train_dataset": len_train_dataset,
655
  "len_eval_dataset": len_eval_dataset,
656
  "batch_size_per_step": batch_size_per_step,
657
+ "model": {"num_params": num_params, "config": model.config.to_dict()},
658
  "num_devices": jax.device_count(),
659
+ "versions": {
660
+ "jax": jax.__version__,
661
+ "jaxlib": jaxlib.__version__,
662
+ "flax": flax.__version__,
663
+ "transformers": transformers.__version__,
664
+ "datasets": datasets.__version__,
665
+ "wandb": wandb.__version__,
666
+ },
667
  }
668
  )
669
 
 
673
  warmup_fn = optax.linear_schedule(
674
  init_value=0.0,
675
  end_value=training_args.learning_rate,
676
+ transition_steps=training_args.warmup_steps + 1, # ensure not 0
677
  )
678
  # offset step when resuming
679
  if model_metadata.get("step", 0):
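The `+ 1` guards against `warmup_steps == 0`: optax treats a non-positive `transition_steps` as a constant schedule at `init_value`, which here would pin the learning rate at 0. A minimal sketch (values are illustrative):

```python
import optax

learning_rate = 5e-5  # illustrative
warmup_steps = 0      # the edge case the "+ 1" protects against

warmup_fn = optax.linear_schedule(
    init_value=0.0,
    end_value=learning_rate,
    transition_steps=warmup_steps + 1,  # ensure not 0
)
print(warmup_fn(0), warmup_fn(warmup_steps + 1))  # 0.0 then 5e-5
```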
 
821
  epoch=None,
822
  train_time=None,
823
  train_samples=None,
824
+ apply_fn=model.__call__,
825
  tx=optimizer,
826
  )
827
 
 
834
  # params have not been initialized yet
835
  return model.init_weights()
836
 
837
+ with mesh:
838
  logger.info(" Creating state")
839
  if not model_args.restore_state:
840
 
841
  def init_state(params):
842
  return TrainState.create(
843
+ apply_fn=model.__call__,
844
  tx=optimizer,
845
  params=maybe_init_params(params),
846
  dropout_rng=dropout_rng,
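`with maps.mesh(mesh.devices, mesh.axis_names):` becomes `with mesh:`, since a `Mesh` object is itself a context manager and does not need to be re-wrapped. A standalone sketch (the axis name is illustrative):

```python
import numpy as np

import jax
from jax.experimental import maps

# A 1-D device mesh; pjit-compiled functions run under it inside the context
mesh = maps.Mesh(np.asarray(jax.devices()), ("dp",))

with mesh:  # equivalent to: with maps.mesh(mesh.devices, mesh.axis_names):
    pass  # run pjit-compiled train/eval steps here
```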
 
867
 
868
  def restore_state(params, opt_state):
869
  return TrainState(
870
+ apply_fn=model.__call__,
871
  tx=optimizer,
872
  params=params,
873
  opt_state=opt_state,
 
913
  )
914
 
915
  # Define gradient update step fn
916
+ def train_step(state, batch, train_time):
917
 
918
  # get a minibatch (one gradient accumulation slice)
919
  def get_minibatch(batch, grad_idx):
 
1002
  state = state.apply_gradients(
1003
  grads=grads,
1004
  dropout_rng=dropout_rng,
1005
+ train_time=train_time,
1006
  train_samples=state.train_samples + batch_size_per_step,
1007
  )
1008
 
1009
+ metrics = {
1010
+ "loss": loss,
1011
+ "learning_rate": learning_rate_fn(state.step),
1012
+ }
1013
 
1014
+ def maybe_fn(fn, val, zeros, freq):
1015
  """Call fn only if it is a logging step"""
1016
  return jax.lax.cond(
1017
+ state.step % freq == 0,
1018
  fn,
1019
  lambda _: zeros,
1020
  val,
1021
  )
1022
 
1023
+ if training_args.log_norm_steps:
1024
+ zeros_norm = jax.tree_map(lambda _: jnp.float32(0), state.params)
1025
 
1026
+ def norm(val):
1027
+ return jax.tree_map(lambda x: jnp.linalg.norm(x), val)
1028
 
1029
+ gradients_norm = maybe_fn(
1030
+ norm, grads, zeros_norm, training_args.log_norm_steps
1031
+ )
1032
+ params_norm = maybe_fn(
1033
+ norm, state.params, zeros_norm, training_args.log_norm_steps
1034
+ )
1035
 
1036
+ metrics.update(
1037
+ {
1038
+ "gradients_norm": gradients_norm,
1039
+ "params_norm": params_norm,
1040
+ }
1041
+ )
1042
+
1043
+ if training_args.log_histogram_steps:
1044
  zeros_hist = jax.tree_map(
1045
  lambda _: jnp.histogram(jnp.zeros(1), density=True), state.params
1046
  )
 
1048
  def histogram(val):
1049
  return jax.tree_map(lambda x: jnp.histogram(x, density=True), val)
1050
 
1051
+ gradients_hist = maybe_fn(
1052
+ histogram, grads, zeros_hist, training_args.log_histogram_steps
1053
+ )
1054
+ params_hist = maybe_fn(
1055
+ histogram, state.params, zeros_hist, training_args.log_histogram_steps
1056
+ )
1057
 
1058
  metrics.update(
1059
  {
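Norms and histograms are now computed inside the jitted `train_step` only every `log_norm_steps` / `log_histogram_steps` steps; on other steps `jax.lax.cond` returns a pre-built zero pytree of the same structure so the traced output shapes stay fixed. A standalone sketch of the pattern:

```python
import jax
import jax.numpy as jnp


def maybe_fn(fn, val, zeros, freq, step):
    """Call fn(val) only when step is a multiple of freq, else return zeros."""
    return jax.lax.cond(step % freq == 0, fn, lambda _: zeros, val)


params = {"b": jnp.zeros(3), "w": jnp.ones((3, 3))}
zeros_norm = jax.tree_map(lambda _: jnp.float32(0), params)
norm = lambda tree: jax.tree_map(jnp.linalg.norm, tree)

print(maybe_fn(norm, params, zeros_norm, freq=5, step=jnp.array(10)))  # real norms
print(maybe_fn(norm, params, zeros_norm, freq=5, step=jnp.array(11)))  # zero placeholders
```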
 
1068
  def eval_step(state, batch):
1069
  def compute_eval_loss(batch):
1070
  batch, labels = batch.pop("labels")
1071
+ logits = model(**batch, params=state.params, train=False)[0]
1072
  return loss_fn(logits, labels)
1073
 
1074
  if use_vmap_trick:
 
1101
  out_axis_resources=None,
1102
  )
1103
 
1104
+ # define metrics logger
1105
+ class MetricsLogger:
1106
+ def __init__(self, step):
1107
+ # keep state
1108
+ self.state_dict = {}
1109
+ # estimate speed
1110
+ self.step = step
1111
+ self.time = time.perf_counter()
1112
+ self.offset_time = 0.0
1113
+
1114
+ def update_state_metrics(self, state):
1115
+ """Update internal state metrics (logged at each call to be used as x-axis)"""
1116
+ self.state_dict = {
1117
+ f'train/{k.split("_")[-1]}': state[k]
1118
+ for k in ["step", "epoch", "train_time", "train_samples"]
1119
+ }
1120
+ # timing metrics
1121
+ new_step = int(state["step"])
1122
+ new_time = time.perf_counter()
1123
+ if new_step > self.step:
1124
+ # remove time for eval & save
1125
+ delta_time = new_time - self.time - self.offset_time
1126
+ self.offset_time = 0
1127
+ time_per_step = delta_time / (new_step - self.step)
1128
+ self.step = new_step
1129
+ self.time = new_time
1130
+ self.log_time("train_per_step", time_per_step, offset=False)
1131
+ self.log_time("train_per_log", delta_time, offset=False)
1132
+
1133
+ def log_time(self, key, duration, offset=True):
1134
+ wandb.log({f"time/{key}": duration, **self.state_dict})
1135
+ if offset:
1136
+ self.offset_time += duration
1137
+
1138
+ def log(self, metrics, prefix=None):
1139
+ if jax.process_index() == 0:
1140
+ log_metrics = {}
1141
+ for k, v in metrics.items():
1142
+ if "_norm" in k:
1143
+ if self.step % training_args.log_norm_steps == 0:
1144
+ log_metrics[f"{k}/"] = unfreeze(v)
1145
+ elif "_hist" in k:
1146
+ if self.step % training_args.log_histogram_steps == 0:
1147
+ v = jax.tree_map(lambda x: jax.device_get(x), unfreeze(v))
1148
+ v = jax.tree_map(
1149
+ lambda x: wandb.Histogram(np_histogram=x),
1150
+ v,
1151
+ is_leaf=lambda x: isinstance(x, tuple),
1152
+ )
1153
+ log_metrics[f"{k}/"] = v
1154
+ else:
1155
+ if prefix is not None:
1156
+ k = f"{prefix}/{k}"
1157
+ log_metrics[k] = v
1158
+ wandb.log({**log_metrics, **self.state_dict})
1159
+
1160
+ # keep local copy of state
1161
+ local_state = {
1162
+ k: jax.device_get(getattr(state, k)).item()
1163
+ for k in ["step", "epoch", "train_time", "train_samples"]
1164
+ }
1165
  # init variables
1166
+ start_time = time.perf_counter() - local_state["train_time"]
1167
  train_metrics = None
1168
+ metrics_logger = MetricsLogger(local_state["step"])
 
1169
  epochs = tqdm(
1170
+ range(local_state["epoch"], num_epochs),
1171
  desc=f"Epoch ... (1/{num_epochs})",
1172
  position=0,
1173
  disable=jax.process_index() > 0,
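`MetricsLogger` now lives inside `main`, keys its state off plain Python dicts instead of the `TrainState`, and keeps an `offset_time` so that time reported through `log_time` (evaluation, checkpointing) is excluded from `train/time_per_step`. A standalone sketch of just the offset bookkeeping (class and method names are illustrative):

```python
import time


class StepTimer:
    """Per-step timing that excludes externally reported durations (eval, save)."""

    def __init__(self, step=0):
        self.step = step
        self.time = time.perf_counter()
        self.offset = 0.0

    def add_offset(self, duration):
        # e.g. duration of an eval or checkpoint pass
        self.offset += duration

    def time_per_step(self, new_step):
        now = time.perf_counter()
        delta = now - self.time - self.offset
        per_step = delta / max(new_step - self.step, 1)
        self.step, self.time, self.offset = new_step, now, 0.0
        return per_step
```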
 
1176
  def run_evaluation():
1177
  # ======================== Evaluating ==============================
1178
  if training_args.do_eval:
1179
+ start_eval_time = time.perf_counter()
1180
  eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
1181
  eval_steps = (
1182
  len_eval_dataset // eval_batch_size_per_step
 
1223
 
1224
  # log metrics
1225
  metrics_logger.log(eval_metrics, prefix="eval")
1226
+ metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
1227
 
1228
  # Print metrics and update progress bar
1229
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
1235
  def run_save_model(state, eval_metrics=None):
1236
  if jax.process_index() == 0:
1237
 
1238
+ start_save_time = time.perf_counter()
1239
  output_dir = training_args.output_dir
1240
  use_bucket = output_dir.startswith("gs://")
1241
  if use_bucket:
 
1328
  f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
1329
  )
1330
  wandb.run.log_artifact(artifact_state)
1331
+ metrics_logger.log_time("save_model", time.perf_counter() - start_save_time)
1332
 
1333
  logger.info(" Ready to start training")
1334
+ with mesh:
1335
  for epoch in epochs:
1336
  state.replace(epoch=epoch)
1337
+ local_state["epoch"] = epoch
1338
  # ======================== Training ================================
1339
+ metrics_logger.update_state_metrics(local_state)
1340
  metrics_logger.log({})
1341
 
1342
  # Generate an epoch by shuffling sampling indices from the train dataset
 
1355
  disable=jax.process_index() > 0,
1356
  ):
1357
  # calculate delta time (we have a lag of one step but it's ok)
1358
+ train_time = time.perf_counter() - start_time
 
 
1359
 
1360
  # set correct shape to batch
1361
  # - add grad_step dim if gradient_accumulation_steps > 1
 
1383
  batch = freeze(batch)
1384
 
1385
  # train step
1386
+ state, train_metrics = p_train_step(state, batch, train_time)
1387
+ local_state["step"] += 1
1388
+ local_state["train_time"] = train_time
1389
+ local_state["train_samples"] += batch_size_per_step
1390
+
1391
+ if (
1392
+ local_state["step"] % training_args.logging_steps == 0
1393
+ and jax.process_index() == 0
1394
+ ):
1395
+ metrics_logger.update_state_metrics(local_state)
1396
  metrics_logger.log(train_metrics, prefix="train")
1397
 
1398
  eval_metrics = None
1399
+ if local_state["step"] % training_args.eval_steps == 0:
1400
  eval_metrics = run_evaluation()
1401
 
1402
+ if local_state["step"] % training_args.save_steps == 0:
1403
  run_save_model(state, eval_metrics)
1404
 
1405
  # log final train metrics