boris committed on
Commit 3a3d375
2 parents: 26651dd af807f7

Merge pull request #115 from borisdayma/feat-shampoo

README.md CHANGED
@@ -154,3 +154,14 @@ year = {2021}
   primaryClass={cs.CV}
 }
 ```
+
+```
+@misc{anil2021scalable,
+      title={Scalable Second Order Optimization for Deep Learning},
+      author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
+      year={2021},
+      eprint={2002.09018},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG}
+}
+```
dalle_mini/data.py CHANGED
@@ -4,6 +4,7 @@ from functools import partial
 import jax
 import jax.numpy as jnp
 import numpy as np
+from braceexpand import braceexpand
 from datasets import Dataset, load_dataset
 from flax.training.common_utils import shard
 
@@ -15,12 +16,10 @@ class Dataset:
     dataset_repo_or_path: str
     train_file: str = None
     validation_file: str = None
-    dataset_type: str = "dataset"
     streaming: bool = True
     use_auth_token: bool = False
     text_column: str = "caption"
     encoding_column: str = "encoding"
-    max_source_length: int = 128
     max_train_samples: int = None
     max_eval_samples: int = None
     preprocessing_num_workers: int = None
@@ -28,13 +27,30 @@ class Dataset:
     do_train: bool = False
     do_eval: bool = True
     seed_dataset: int = None
+    shard_by_host: bool = False
     train_dataset: Dataset = field(init=False)
     eval_dataset: Dataset = field(init=False)
     rng_dataset: jnp.ndarray = field(init=False)
+    multi_hosts: bool = field(init=False)
 
     def __post_init__(self):
+        self.multi_hosts = jax.process_count() > 1
         # define data_files
         if self.train_file is not None or self.validation_file is not None:
+            # accept braceexpand notation
+            for k in ["train_file", "validation_file"]:
+                f = getattr(self, k)
+                if isinstance(f, str):
+                    setattr(self, k, list(braceexpand(f)))
+            # for list of files, split training data shards by host
+            if (
+                isinstance(self.train_file, list)
+                and self.multi_hosts
+                and self.shard_by_host
+            ):
+                self.train_file = self.train_file[
+                    jax.process_index() :: jax.process_count()
+                ]
             data_files = {
                 "train": self.train_file,
                 "validation": self.validation_file,
@@ -70,7 +86,7 @@ class Dataset:
                 else self.eval_dataset.select(range(self.max_eval_samples))
             )
 
-    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
+    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text, max_length):
         if self.streaming:
             # we need to shuffle early in streaming mode
             if hasattr(self, "train_dataset"):
@@ -112,7 +128,7 @@ class Dataset:
             tokenizer=tokenizer,
             text_column=self.text_column,
             encoding_column=self.encoding_column,
-            max_source_length=self.max_source_length,
+            max_length=max_length,
             decoder_start_token_id=decoder_start_token_id,
         )
         for ds in ["train_dataset", "eval_dataset"]:
@@ -165,17 +181,29 @@ class Dataset:
                 batch = shard(batch)
                 yield batch
 
-        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
+        def _dataloader_datasets_streaming(
+            dataset: Dataset, batch_size: int, epoch: int
+        ):
+            # epoch is only use for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            for item in dataset:
-                for k, v in item.items():
-                    batch[k].append(v)
-                if len(batch[keys[0]]) == batch_size:
-                    batch = {k: jnp.array(v) for k, v in batch.items()}
-                    batch = shard(batch)
-                    yield batch
-                    batch = {k: [] for k in keys}
+            first_loop = True
+            while self.multi_hosts or first_loop:
+                # in multi-host, we run forever (no epoch) as hosts need to stop
+                # at the same time and we don't know how much data is on each host
+                if not first_loop:
+                    # multi-host setting, we reshuffle shards
+                    epoch += 1
+                    dataset.set_epoch(epoch)
+                for item in dataset:
+                    for k, v in item.items():
+                        batch[k].append(v)
+                    if len(batch[keys[0]]) == batch_size:
+                        batch = {k: jnp.array(v) for k, v in batch.items()}
+                        batch = shard(batch)
+                        yield batch
+                        batch = {k: [] for k in keys}
+                first_loop = False
 
         if split == "train":
             ds = self.train_dataset
@@ -187,7 +215,7 @@ class Dataset:
         if self.streaming:
             if split == "train":
                 ds.set_epoch(epoch)
-            return _dataloader_datasets_streaming(ds, batch_size)
+            return _dataloader_datasets_streaming(ds, batch_size, epoch)
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
@@ -232,14 +260,14 @@ def preprocess_function(
     tokenizer,
     text_column,
     encoding_column,
-    max_source_length,
+    max_length,
     decoder_start_token_id,
 ):
     inputs = examples[text_column]
     # Setting padding="max_length" as we need fixed length inputs for jitted functions
     model_inputs = tokenizer(
         inputs,
-        max_length=max_source_length,
+        max_length=max_length,
         padding="max_length",
         truncation=True,
         return_tensors="np",
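
The new `braceexpand` support and `shard_by_host` flag above decide which training shards each host reads. A minimal sketch of the expansion and host-slicing logic, using a hypothetical shard pattern and stand-in process indices (only the `braceexpand` package is needed to run it):

```python
from braceexpand import braceexpand

# hypothetical shard pattern; "{000..003}" expands to four zero-padded names
train_files = list(braceexpand("data/train-{000..003}.jsonl"))
print(train_files)
# ['data/train-000.jsonl', 'data/train-001.jsonl', 'data/train-002.jsonl', 'data/train-003.jsonl']

# with shard_by_host=True, each host keeps every process_count()-th file,
# mirroring train_file[jax.process_index() :: jax.process_count()] above
process_index, process_count = 1, 2  # stand-ins for jax.process_index() / jax.process_count()
print(train_files[process_index::process_count])
# ['data/train-001.jsonl', 'data/train-003.jsonl']
```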
dalle_mini/model.py DELETED
@@ -1,64 +0,0 @@
1
- import flax.linen as nn
2
- import jax
3
- from transformers import BartConfig
4
- from transformers.models.bart.modeling_flax_bart import (
5
- FlaxBartDecoder,
6
- FlaxBartEncoder,
7
- FlaxBartForConditionalGeneration,
8
- FlaxBartForConditionalGenerationModule,
9
- FlaxBartModule,
10
- )
11
-
12
-
13
- class CustomFlaxBartModule(FlaxBartModule):
14
- def setup(self):
15
- # we keep shared to easily load pre-trained weights
16
- self.shared = nn.Embed(
17
- self.config.vocab_size,
18
- self.config.d_model,
19
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
20
- )
21
- # a separate embedding is used for the decoder
22
- self.decoder_embed = nn.Embed(
23
- self.config.image_vocab_size + 1,
24
- self.config.d_model,
25
- embedding_init=jax.nn.initializers.normal(self.config.init_std),
26
- )
27
- self.encoder = FlaxBartEncoder(
28
- self.config, dtype=self.dtype, embed_tokens=self.shared
29
- )
30
-
31
- # the decoder has a different config
32
- # TODO: should not be needed once we have custom config/module
33
- decoder_config = BartConfig(self.config.to_dict())
34
- decoder_config.max_position_embeddings = (
35
- self.config.image_length + 1 # image tokens + BOS
36
- )
37
- decoder_config.vocab_size = self.config.image_vocab_size + 1
38
- self.decoder = FlaxBartDecoder(
39
- decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
40
- )
41
-
42
-
43
- class CustomFlaxBartForConditionalGenerationModule(
44
- FlaxBartForConditionalGenerationModule
45
- ):
46
- def setup(self):
47
- # set default config
48
- self.config.normalize_text = getattr(self.config, "normalize_text", False)
49
- self.config.image_length = getattr(self.config, "image_length", 256)
50
- self.config.image_vocab_size = getattr(self.config, "image_vocab_size", 16384)
51
-
52
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
53
- self.lm_head = nn.Dense(
54
- self.config.image_vocab_size + 1, # encoded image token space + 1 for bos
55
- use_bias=False,
56
- kernel_init=jax.nn.initializers.normal(self.config.init_std),
57
- )
58
- self.final_logits_bias = self.param(
59
- "final_logits_bias", self.bias_init, (1, self.config.image_vocab_size + 1)
60
- )
61
-
62
-
63
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
64
- module_class = CustomFlaxBartForConditionalGenerationModule
 
dalle_mini/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .configuration import DalleBartConfig
2
+ from .modeling import DalleBart
dalle_mini/model/configuration.py ADDED
@@ -0,0 +1,121 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model configuration """
16
+ import warnings
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class DalleBartConfig(PretrainedConfig):
25
+ model_type = "dallebart"
26
+ keys_to_ignore_at_inference = ["past_key_values"]
27
+ attribute_map = {
28
+ "num_attention_heads": "encoder_attention_heads",
29
+ "hidden_size": "d_model",
30
+ }
31
+
32
+ def __init__(
33
+ self,
34
+ normalize_text=False,
35
+ encoder_vocab_size=50264,
36
+ image_vocab_size=16384, # encoded image token space
37
+ image_length=256, # number of encoded tokens
38
+ max_text_length=64, # max number of text tokens
39
+ encoder_layers=12,
40
+ encoder_ffn_dim=4096,
41
+ encoder_attention_heads=16,
42
+ decoder_layers=12,
43
+ decoder_ffn_dim=4096,
44
+ decoder_attention_heads=16,
45
+ encoder_layerdrop=0.0,
46
+ decoder_layerdrop=0.0,
47
+ activation_function="gelu",
48
+ d_model=1024,
49
+ dropout=0.1,
50
+ attention_dropout=0.0,
51
+ activation_dropout=0.0,
52
+ init_std=0.02,
53
+ classifier_dropout=0.0,
54
+ scale_embedding=False,
55
+ gradient_checkpointing=False,
56
+ use_cache=True,
57
+ is_encoder_decoder=True,
58
+ forced_eos_token_id=None,
59
+ tie_word_embeddings=False, # different modalities and sizes
60
+ **kwargs,
61
+ ):
62
+ self.normalize_text = normalize_text
63
+ self.encoder_vocab_size = encoder_vocab_size
64
+ self.image_vocab_size = image_vocab_size
65
+ self.image_length = image_length
66
+ self.max_text_length = max_text_length
67
+ self.d_model = d_model
68
+ self.encoder_ffn_dim = encoder_ffn_dim
69
+ self.encoder_layers = encoder_layers
70
+ self.encoder_attention_heads = encoder_attention_heads
71
+ self.decoder_ffn_dim = decoder_ffn_dim
72
+ self.decoder_layers = decoder_layers
73
+ self.decoder_attention_heads = decoder_attention_heads
74
+ self.dropout = dropout
75
+ self.attention_dropout = attention_dropout
76
+ self.activation_dropout = activation_dropout
77
+ self.activation_function = activation_function
78
+ self.init_std = init_std
79
+ self.encoder_layerdrop = encoder_layerdrop
80
+ self.decoder_layerdrop = decoder_layerdrop
81
+ self.classifier_dropout = classifier_dropout
82
+ self.use_cache = use_cache
83
+ self.gradient_checkpointing = gradient_checkpointing
84
+ self.scale_embedding = (
85
+ scale_embedding # scale factor will be sqrt(d_model) if True
86
+ )
87
+
88
+ # remove inferred keys to prevent errors when loading config (passed as kwargs)
89
+ for k in [
90
+ "pad_token_id",
91
+ "bos_token_id",
92
+ "eos_token_id",
93
+ "decoder_start_token_id",
94
+ "min_length",
95
+ "max_length",
96
+ ]:
97
+ kwargs.pop(k, None)
98
+
99
+ super().__init__(
100
+ pad_token_id=image_vocab_size
101
+ + 1, # needed to avoid errors during generation (converted to jnp.array)
102
+ bos_token_id=image_vocab_size + 1, # set to unreachable values
103
+ eos_token_id=image_vocab_size + 1,
104
+ is_encoder_decoder=is_encoder_decoder,
105
+ decoder_start_token_id=image_vocab_size, # BOS appended to vocab
106
+ forced_eos_token_id=forced_eos_token_id,
107
+ tie_word_embeddings=tie_word_embeddings,
108
+ min_length=image_length + 1,
109
+ max_length=image_length + 1,
110
+ **kwargs,
111
+ )
112
+
113
+ # ensure backward compatibility for BART CNN models
114
+ if self.forced_bos_token_id is None and kwargs.get(
115
+ "force_bos_token_to_be_generated", False
116
+ ):
117
+ self.forced_bos_token_id = self.bos_token_id
118
+ warnings.warn(
119
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions."
120
+ "The config can simply be saved and uploaded again to be fixed."
121
+ )
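
As a quick illustration of how the special tokens fall out of these defaults (a sketch, assuming the `dalle_mini` package added in this commit and its `transformers` dependency are installed):

```python
from dalle_mini.model import DalleBartConfig

config = DalleBartConfig()  # defaults: image_vocab_size=16384, image_length=256
print(config.decoder_start_token_id)  # 16384: BOS appended after the image vocab
print(config.bos_token_id, config.eos_token_id, config.pad_token_id)  # 16385 16385 16385 (set to unreachable values)
print(config.min_length, config.max_length)  # 257 257: image_length + 1
```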
dalle_mini/model/modeling.py ADDED
@@ -0,0 +1,563 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and the DalleBart team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ DalleBart model. """
16
+
17
+ import math
18
+ from functools import partial
19
+ from typing import Optional, Tuple
20
+
21
+ import flax.linen as nn
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from flax.core.frozen_dict import unfreeze
25
+ from flax.linen import make_causal_mask
26
+ from flax.traverse_util import flatten_dict
27
+ from jax.random import PRNGKey
28
+ from transformers.modeling_flax_outputs import (
29
+ FlaxCausalLMOutputWithCrossAttentions,
30
+ FlaxSeq2SeqLMOutput,
31
+ )
32
+ from transformers.modeling_flax_utils import ACT2FN
33
+ from transformers.models.bart.modeling_flax_bart import (
34
+ FlaxBartAttention,
35
+ FlaxBartDecoder,
36
+ FlaxBartDecoderLayer,
37
+ FlaxBartDecoderLayerCollection,
38
+ FlaxBartEncoder,
39
+ FlaxBartEncoderLayer,
40
+ FlaxBartEncoderLayerCollection,
41
+ FlaxBartForConditionalGeneration,
42
+ FlaxBartForConditionalGenerationModule,
43
+ FlaxBartModule,
44
+ FlaxBartPreTrainedModel,
45
+ )
46
+ from transformers.utils import logging
47
+
48
+ from .configuration import DalleBartConfig
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ class FlaxBartAttention(FlaxBartAttention):
54
+ """
55
+ Edits:
56
+ - causal mask is used only in decoder and considers image_length + 1 (for BOS)
57
+ """
58
+
59
+ def setup(self) -> None:
60
+ self.head_dim = self.embed_dim // self.num_heads
61
+ if self.head_dim * self.num_heads != self.embed_dim:
62
+ raise ValueError(
63
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
64
+ f" and `num_heads`: {self.num_heads})."
65
+ )
66
+
67
+ dense = partial(
68
+ nn.Dense,
69
+ self.embed_dim,
70
+ use_bias=self.bias,
71
+ dtype=self.dtype,
72
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
73
+ )
74
+
75
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
76
+ self.out_proj = dense()
77
+
78
+ self.dropout_layer = nn.Dropout(rate=self.dropout)
79
+
80
+ if self.causal:
81
+ # used only in decoder
82
+ self.causal_mask = make_causal_mask(
83
+ jnp.ones((1, self.config.image_length + 1), dtype="bool"), dtype="bool"
84
+ )
85
+
86
+
87
+ class FlaxBartEncoderLayer(FlaxBartEncoderLayer):
88
+ """
89
+ Edits:
90
+ - no bias
91
+ - use custom FlaxBartAttention
92
+ """
93
+
94
+ def setup(self) -> None:
95
+ self.embed_dim = self.config.d_model
96
+ self.self_attn = FlaxBartAttention(
97
+ config=self.config,
98
+ embed_dim=self.embed_dim,
99
+ num_heads=self.config.encoder_attention_heads,
100
+ dropout=self.config.attention_dropout,
101
+ bias=False,
102
+ dtype=self.dtype,
103
+ )
104
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
105
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
106
+ self.activation_fn = ACT2FN[self.config.activation_function]
107
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
108
+ self.fc1 = nn.Dense(
109
+ self.config.encoder_ffn_dim,
110
+ dtype=self.dtype,
111
+ use_bias=False,
112
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
113
+ )
114
+ self.fc2 = nn.Dense(
115
+ self.embed_dim,
116
+ dtype=self.dtype,
117
+ use_bias=False,
118
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
119
+ )
120
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
121
+
122
+
123
+ class FlaxBartEncoderLayerCollection(FlaxBartEncoderLayerCollection):
124
+ """
125
+ Edits:
126
+ - use custom FlaxBartEncoderLayer
127
+ - allow Gradient Checkpointing (nn.remat)
128
+ """
129
+
130
+ def setup(self):
131
+ layer_module = (
132
+ nn.remat(FlaxBartEncoderLayer)
133
+ if self.config.gradient_checkpointing
134
+ else FlaxBartEncoderLayer
135
+ )
136
+ self.layers = [
137
+ layer_module(self.config, name=str(i), dtype=self.dtype)
138
+ for i in range(self.config.encoder_layers)
139
+ ]
140
+ self.layerdrop = self.config.encoder_layerdrop
141
+
142
+
143
+ class FlaxBartDecoderLayer(FlaxBartDecoderLayer):
144
+ """
145
+ Edits:
146
+ - no bias
147
+ - uses custom FlaxBartAttention
148
+ """
149
+
150
+ def setup(self) -> None:
151
+ self.embed_dim = self.config.d_model
152
+ self.self_attn = FlaxBartAttention(
153
+ config=self.config,
154
+ embed_dim=self.embed_dim,
155
+ num_heads=self.config.decoder_attention_heads,
156
+ dropout=self.config.attention_dropout,
157
+ causal=True,
158
+ bias=False,
159
+ dtype=self.dtype,
160
+ )
161
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
162
+ self.activation_fn = ACT2FN[self.config.activation_function]
163
+ self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
164
+
165
+ self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
166
+ self.encoder_attn = FlaxBartAttention(
167
+ config=self.config,
168
+ embed_dim=self.embed_dim,
169
+ num_heads=self.config.decoder_attention_heads,
170
+ dropout=self.config.attention_dropout,
171
+ bias=False,
172
+ dtype=self.dtype,
173
+ )
174
+ self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
175
+ self.fc1 = nn.Dense(
176
+ self.config.encoder_ffn_dim,
177
+ dtype=self.dtype,
178
+ use_bias=False,
179
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
180
+ )
181
+ self.fc2 = nn.Dense(
182
+ self.embed_dim,
183
+ dtype=self.dtype,
184
+ use_bias=False,
185
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
186
+ )
187
+ self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
188
+
189
+
190
+ class FlaxBartDecoderLayerCollection(FlaxBartDecoderLayerCollection):
191
+ """
192
+ Edits:
193
+ - use custom FlaxBartDecoderLayer
194
+ - allow Gradient Checkpointing (nn.remat)
195
+ """
196
+
197
+ def setup(self):
198
+ layer_module = (
199
+ nn.remat(FlaxBartDecoderLayer)
200
+ if self.config.gradient_checkpointing
201
+ else FlaxBartDecoderLayer
202
+ )
203
+ self.layers = [
204
+ layer_module(self.config, name=str(i), dtype=self.dtype)
205
+ for i in range(self.config.decoder_layers)
206
+ ]
207
+ self.layerdrop = self.config.decoder_layerdrop
208
+
209
+
210
+ class FlaxBartEncoder(FlaxBartEncoder):
211
+ """
212
+ Edits:
213
+ - offset set to 0 (no padding token)
214
+ - use max_text_length instead of max_position_embeddings
215
+ - use custom FlaxBartEncoderLayerCollection
216
+ - embed_tokens cannot be None (issue at compile time)
217
+ """
218
+
219
+ def setup(self):
220
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
221
+
222
+ embed_dim = self.config.d_model
223
+ self.padding_idx = self.config.pad_token_id
224
+ self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
225
+
226
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
227
+ # and adjust num_embeddings appropriately. Other models don't have this hack
228
+ self.offset = 0
229
+ self.embed_positions = nn.Embed(
230
+ self.config.max_text_length + self.offset,
231
+ embed_dim,
232
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
233
+ )
234
+ self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
235
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
236
+
237
+
238
+ class FlaxBartDecoder(FlaxBartDecoder):
239
+ """
240
+ Edits:
241
+ - offset set to 0 (no padding token)
242
+ - use image_length + 1 (for BOS) instead of max_position_embeddings
243
+ - use custom FlaxBartDecoderLayerCollection
244
+ - embed_tokens cannot be None (issue at compile time)
245
+ """
246
+
247
+ def setup(self):
248
+ self.dropout_layer = nn.Dropout(rate=self.config.dropout)
249
+
250
+ embed_dim = self.config.d_model
251
+ self.padding_idx = self.config.pad_token_id
252
+ self.embed_scale = (
253
+ math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
254
+ )
255
+
256
+ # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
257
+ # and adjust num_embeddings appropriately. Other models don't have this hack
258
+ self.offset = 0
259
+ self.embed_positions = nn.Embed(
260
+ self.config.image_length + 1 + self.offset, # image length + 1 for BOS
261
+ embed_dim,
262
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
263
+ )
264
+
265
+ self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
266
+ self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
267
+
268
+
269
+ class FlaxBartModule(FlaxBartModule):
270
+ """
271
+ Edits
272
+ - use custom FlaxBartEncoder & FlaxBartDecoder
273
+ - use separate embeddings for Encoder & Decoder
274
+ """
275
+
276
+ def setup(self):
277
+ encoder_embed_tokens = nn.Embed(
278
+ self.config.encoder_vocab_size,
279
+ self.config.d_model,
280
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
281
+ )
282
+ decoder_embed_tokens = nn.Embed(
283
+ self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
284
+ self.config.d_model,
285
+ embedding_init=jax.nn.initializers.normal(self.config.init_std),
286
+ )
287
+
288
+ self.encoder = FlaxBartEncoder(
289
+ self.config, dtype=self.dtype, embed_tokens=encoder_embed_tokens
290
+ )
291
+ self.decoder = FlaxBartDecoder(
292
+ self.config, dtype=self.dtype, embed_tokens=decoder_embed_tokens
293
+ )
294
+
295
+
296
+ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
297
+ """
298
+ Edits:
299
+ - added num_params property
300
+ - config_class replaced to DalleBartConfig
301
+ - __init__ accepts abstract_init which does uses parameter shape to initialize the model
302
+ """
303
+
304
+ config_class = DalleBartConfig
305
+
306
+ def __init__(
307
+ self,
308
+ config: DalleBartConfig,
309
+ input_shape: Tuple[int] = (1, 1),
310
+ seed: int = 0,
311
+ dtype: jnp.dtype = jnp.float32,
312
+ abstract_init: bool = False,
313
+ **kwargs,
314
+ ):
315
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
316
+
317
+ # adapted from HuggingFace FlaxPreTrainedModel
318
+ if config is None:
319
+ raise ValueError("config cannot be None")
320
+
321
+ if module is None:
322
+ raise ValueError("module cannot be None")
323
+
324
+ # Those are private to be exposed as typed property on derived classes.
325
+ self._config = config
326
+ self._module = module
327
+
328
+ # Those are public as their type is generic to every derived classes.
329
+ self.key = PRNGKey(seed)
330
+ self.dtype = dtype
331
+
332
+ # randomly initialized parameters
333
+ if abstract_init:
334
+ # init the model weights only abstractly, eval_shape will return a pytree
335
+ # with the structure as weights but without any actual values, this will just contain
336
+ # the shape information. Weights need to be loaded later.
337
+ init_fn = partial(self.init_weights, input_shape=input_shape)
338
+ random_params = jax.eval_shape(init_fn, self.key)
339
+ else:
340
+ random_params = self.init_weights(self.key, input_shape)
341
+
342
+ # save required_params as set
343
+ self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
344
+ self.params = random_params
345
+
346
+ @property
347
+ def num_params(self):
348
+ num_params = jax.tree_map(
349
+ lambda param: param.size, flatten_dict(unfreeze(self.params))
350
+ ).values()
351
+ return sum(list(num_params))
352
+
353
+
354
+ class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
355
+ """
356
+ Edits:
357
+ - no bias
358
+ - lm_head set to image_vocab_size + 1 (for BOS)
359
+ - uses custom FlaxBartModule
360
+ """
361
+
362
+ def setup(self):
363
+ self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
364
+ self.lm_head = nn.Dense(
365
+ self.config.image_vocab_size + 1, # image vocab size + 1 for BOS
366
+ use_bias=False,
367
+ dtype=self.dtype,
368
+ kernel_init=jax.nn.initializers.normal(self.config.init_std),
369
+ )
370
+
371
+ def __call__(
372
+ self,
373
+ input_ids,
374
+ attention_mask,
375
+ decoder_input_ids,
376
+ decoder_attention_mask,
377
+ position_ids,
378
+ decoder_position_ids,
379
+ output_attentions: bool = False,
380
+ output_hidden_states: bool = False,
381
+ return_dict: bool = True,
382
+ deterministic: bool = True,
383
+ ):
384
+ outputs = self.model(
385
+ input_ids=input_ids,
386
+ attention_mask=attention_mask,
387
+ decoder_input_ids=decoder_input_ids,
388
+ decoder_attention_mask=decoder_attention_mask,
389
+ position_ids=position_ids,
390
+ decoder_position_ids=decoder_position_ids,
391
+ output_attentions=output_attentions,
392
+ output_hidden_states=output_hidden_states,
393
+ return_dict=return_dict,
394
+ deterministic=deterministic,
395
+ )
396
+
397
+ hidden_states = outputs[0]
398
+
399
+ if self.config.tie_word_embeddings:
400
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
401
+ lm_logits = self.lm_head.apply(
402
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
403
+ )
404
+ else:
405
+ lm_logits = self.lm_head(hidden_states)
406
+
407
+ if not return_dict:
408
+ output = (lm_logits,) + outputs[1:]
409
+ return output
410
+
411
+ return FlaxSeq2SeqLMOutput(
412
+ logits=lm_logits,
413
+ decoder_hidden_states=outputs.decoder_hidden_states,
414
+ decoder_attentions=outputs.decoder_attentions,
415
+ cross_attentions=outputs.cross_attentions,
416
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
417
+ encoder_hidden_states=outputs.encoder_hidden_states,
418
+ encoder_attentions=outputs.encoder_attentions,
419
+ )
420
+
421
+
422
+ class DalleBart(FlaxBartPreTrainedModel, FlaxBartForConditionalGeneration):
423
+ """
424
+ Edits:
425
+ - renamed from FlaxBartForConditionalGeneration
426
+ - uses custom FlaxBartPreTrainedModel
427
+ - uses custom FlaxBartForConditionalGenerationModule
428
+ - no bias in decode method
429
+ """
430
+
431
+ module_class = FlaxBartForConditionalGenerationModule
432
+
433
+ def decode(
434
+ self,
435
+ decoder_input_ids,
436
+ encoder_outputs,
437
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
438
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
439
+ decoder_position_ids: Optional[jnp.ndarray] = None,
440
+ past_key_values: dict = None,
441
+ output_attentions: Optional[bool] = None,
442
+ output_hidden_states: Optional[bool] = None,
443
+ return_dict: Optional[bool] = None,
444
+ train: bool = False,
445
+ params: dict = None,
446
+ dropout_rng: PRNGKey = None,
447
+ ):
448
+ output_attentions = (
449
+ output_attentions
450
+ if output_attentions is not None
451
+ else self.config.output_attentions
452
+ )
453
+ output_hidden_states = (
454
+ output_hidden_states
455
+ if output_hidden_states is not None
456
+ else self.config.output_hidden_states
457
+ )
458
+ return_dict = (
459
+ return_dict if return_dict is not None else self.config.return_dict
460
+ )
461
+
462
+ encoder_hidden_states = encoder_outputs[0]
463
+ if encoder_attention_mask is None:
464
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
465
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
466
+
467
+ batch_size, sequence_length = decoder_input_ids.shape
468
+ if decoder_attention_mask is None:
469
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
470
+
471
+ if decoder_position_ids is None:
472
+ if past_key_values is not None:
473
+ raise ValueError(
474
+ "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
475
+ )
476
+
477
+ decoder_position_ids = jnp.broadcast_to(
478
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
479
+ )
480
+
481
+ # Handle any PRNG if needed
482
+ rngs = {}
483
+ if dropout_rng is not None:
484
+ rngs["dropout"] = dropout_rng
485
+
486
+ inputs = {"params": params or self.params}
487
+
488
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
489
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
490
+ # it can be changed by FlaxBartAttention module
491
+ if past_key_values:
492
+ inputs["cache"] = past_key_values
493
+ mutable = ["cache"]
494
+ else:
495
+ mutable = False
496
+
497
+ def _decoder_forward(
498
+ module,
499
+ decoder_input_ids,
500
+ decoder_attention_mask,
501
+ decoder_position_ids,
502
+ **kwargs,
503
+ ):
504
+ decoder_module = module._get_decoder_module()
505
+ outputs = decoder_module(
506
+ decoder_input_ids,
507
+ decoder_attention_mask,
508
+ decoder_position_ids,
509
+ **kwargs,
510
+ )
511
+ hidden_states = outputs[0]
512
+
513
+ if self.config.tie_word_embeddings:
514
+ shared_embedding = module.model.variables["params"]["shared"][
515
+ "embedding"
516
+ ]
517
+ lm_logits = module.lm_head.apply(
518
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
519
+ )
520
+ else:
521
+ lm_logits = module.lm_head(hidden_states)
522
+
523
+ return lm_logits, outputs
524
+
525
+ outputs = self.module.apply(
526
+ inputs,
527
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
528
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
529
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
530
+ encoder_hidden_states=encoder_hidden_states,
531
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
532
+ output_attentions=output_attentions,
533
+ output_hidden_states=output_hidden_states,
534
+ return_dict=return_dict,
535
+ deterministic=not train,
536
+ rngs=rngs,
537
+ mutable=mutable,
538
+ method=_decoder_forward,
539
+ )
540
+
541
+ if past_key_values is None:
542
+ lm_logits, decoder_outputs = outputs
543
+ else:
544
+ (lm_logits, decoder_outputs), past = outputs
545
+
546
+ if return_dict:
547
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
548
+ logits=lm_logits,
549
+ hidden_states=decoder_outputs.hidden_states,
550
+ attentions=decoder_outputs.attentions,
551
+ cross_attentions=decoder_outputs.cross_attentions,
552
+ )
553
+ else:
554
+ outputs = (lm_logits,) + decoder_outputs[1:]
555
+
556
+ # add updated cache to model output
557
+ if past_key_values is not None and return_dict:
558
+ outputs["past_key_values"] = unfreeze(past["cache"])
559
+ return outputs
560
+ elif past_key_values is not None and not return_dict:
561
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
562
+
563
+ return outputs
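
The custom `FlaxBartAttention` above sizes its causal mask from `image_length + 1` (image tokens plus BOS) instead of `max_position_embeddings`. A small sketch of the resulting mask shape, using only Flax and the default `image_length` of 256:

```python
import jax.numpy as jnp
from flax.linen import make_causal_mask

image_length = 256  # default from DalleBartConfig
# same call as in the custom attention's setup(): one extra position for BOS
causal_mask = make_causal_mask(
    jnp.ones((1, image_length + 1), dtype="bool"), dtype="bool"
)
print(causal_mask.shape)  # (1, 1, 257, 257)
```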
dalle_mini/model/partitions.py ADDED
@@ -0,0 +1,68 @@
1
+ import re
2
+
3
+ from flax.core.frozen_dict import freeze
4
+ from flax.traverse_util import flatten_dict, unflatten_dict
5
+ from jax.experimental import PartitionSpec as P
6
+
7
+ # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
8
+ # Sentinels
9
+ _unmatched = object()
10
+
11
+ # For specifying empty leaf dict `{}`
12
+ empty_dict = object()
13
+
14
+
15
+ def _match(qs, ks):
16
+ """Return True if regexes in qs match any window of strings in tuple ks."""
17
+ # compile regexes and force complete match
18
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
19
+ for i in range(len(ks) - len(qs) + 1):
20
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
21
+ if matches and all(matches):
22
+ return True
23
+ return False
24
+
25
+
26
+ def _replacement_rules(rules):
27
+ def replace(key, val):
28
+ for rule, replacement in rules:
29
+ if _match(rule, key):
30
+ return replacement
31
+ return val
32
+
33
+ return replace
34
+
35
+
36
+ def _get_partition_rules():
37
+ return [
38
+ # embeddings
39
+ ((r"embed_positions", "embedding"), P("mp", None)),
40
+ ((r"embed_tokens", "embedding"), P("mp", None)),
41
+ # self-attention
42
+ ((r"self_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
43
+ ((r"self_attn", "out_proj", "kernel"), P("mp", None)),
44
+ # enc-dec attention
45
+ ((r"encoder_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
46
+ ((r"encoder_attn", "out_proj", "kernel"), P("mp", None)),
47
+ # FFN
48
+ ((r"fc1", "kernel"), P(None, "mp")),
49
+ ((r"fc2", "kernel"), P("mp", None)),
50
+ # layer norms
51
+ ((r"layernorm_embedding", "(bias|scale)"), None),
52
+ ((r"self_attn_layer_norm", "(bias|scale)"), None),
53
+ ((r"encoder_attn_layer_norm", "(bias|scale)"), None),
54
+ ((r"final_layer_norm", "(bias|scale)"), None),
55
+ ((r"lm_head", "kernel"), P(None, "mp")),
56
+ ]
57
+
58
+
59
+ def set_partitions(in_dict):
60
+ rules = _get_partition_rules()
61
+ replace = _replacement_rules(rules)
62
+ initd = {k: _unmatched for k in flatten_dict(in_dict)}
63
+ result = {k: replace(k, v) for k, v in initd.items()}
64
+ for k, v in result.items():
65
+ if v == _unmatched:
66
+ print(k)
67
+ assert _unmatched not in result.values(), "Incomplete partition spec."
68
+ return freeze(unflatten_dict(result))
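
A minimal sketch of `set_partitions` on a toy parameter tree (the key names and shapes below are placeholders chosen to match the rules above; it assumes the package and its JAX/Flax dependencies are installed):

```python
import numpy as np
from dalle_mini.model.partitions import set_partitions

params = {
    "model": {
        "encoder": {
            "embed_positions": {"embedding": np.zeros((64, 8))},
            "layers": {"0": {"self_attn": {"q_proj": {"kernel": np.zeros((8, 8))}}}},
        }
    },
    "lm_head": {"kernel": np.zeros((8, 16))},
}
spec = set_partitions(params)
# embed_positions/embedding -> PartitionSpec("mp", None);
# q_proj and lm_head kernels -> PartitionSpec(None, "mp")
```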
setup.cfg CHANGED
@@ -23,5 +23,6 @@ dev =
     tqdm
     wandb
     optax
+    braceexpand
     black[jupyter]
     isort
tools/inference/inference_pipeline.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
tools/train/config/medium/config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "classifier_dropout": 0.0,
7
+ "d_model": 1536,
8
+ "decoder_attention_heads": 16,
9
+ "decoder_ffn_dim": 4096,
10
+ "decoder_layerdrop": 0.0,
11
+ "decoder_layers": 18,
12
+ "decoder_start_token_id": 16384,
13
+ "dropout": 0.1,
14
+ "encoder_attention_heads": 16,
15
+ "encoder_ffn_dim": 4096,
16
+ "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 18,
18
+ "encoder_vocab_size": 50264,
19
+ "eos_token_id": 16385,
20
+ "gradient_checkpointing": false,
21
+ "image_length": 256,
22
+ "image_vocab_size": 16384,
23
+ "init_std": 0.01,
24
+ "is_encoder_decoder": true,
25
+ "max_text_length": 64,
26
+ "model_type": "dallebart",
27
+ "normalize_text": true,
28
+ "pad_token_id": 16385,
29
+ "scale_embedding": false,
30
+ "tie_word_embeddings": false,
31
+ "transformers_version": "4.13.0.dev0",
32
+ "use_cache": true
33
+ }
tools/train/config/mega/config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "classifier_dropout": 0.0,
7
+ "d_model": 2048,
8
+ "decoder_attention_heads": 16,
9
+ "decoder_ffn_dim": 4096,
10
+ "decoder_layerdrop": 0.0,
11
+ "decoder_layers": 31,
12
+ "decoder_start_token_id": 16384,
13
+ "dropout": 0.1,
14
+ "encoder_attention_heads": 16,
15
+ "encoder_ffn_dim": 4096,
16
+ "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 31,
18
+ "encoder_vocab_size": 50264,
19
+ "eos_token_id": 16385,
20
+ "gradient_checkpointing": false,
21
+ "image_length": 256,
22
+ "image_vocab_size": 16384,
23
+ "init_std": 0.01,
24
+ "is_encoder_decoder": true,
25
+ "max_text_length": 64,
26
+ "model_type": "dallebart",
27
+ "normalize_text": true,
28
+ "pad_token_id": 16385,
29
+ "scale_embedding": false,
30
+ "tie_word_embeddings": false,
31
+ "transformers_version": "4.13.0.dev0",
32
+ "use_cache": true
33
+ }
tools/train/config/micro/config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "classifier_dropout": 0.0,
7
+ "d_model": 1024,
8
+ "decoder_attention_heads": 16,
9
+ "decoder_ffn_dim": 2048,
10
+ "decoder_layerdrop": 0.0,
11
+ "decoder_layers": 6,
12
+ "decoder_start_token_id": 16384,
13
+ "dropout": 0.1,
14
+ "encoder_attention_heads": 16,
15
+ "encoder_ffn_dim": 2048,
16
+ "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 6,
18
+ "encoder_vocab_size": 50264,
19
+ "eos_token_id": 16385,
20
+ "gradient_checkpointing": false,
21
+ "image_length": 256,
22
+ "image_vocab_size": 16384,
23
+ "init_std": 0.02,
24
+ "is_encoder_decoder": true,
25
+ "max_text_length": 64,
26
+ "model_type": "dallebart",
27
+ "normalize_text": true,
28
+ "pad_token_id": 16385,
29
+ "scale_embedding": false,
30
+ "tie_word_embeddings": false,
31
+ "transformers_version": "4.13.0.dev0",
32
+ "use_cache": true
33
+ }
tools/train/config/mini/config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "activation_function": "gelu",
4
+ "attention_dropout": 0.0,
5
+ "bos_token_id": 16385,
6
+ "classifier_dropout": 0.0,
7
+ "d_model": 1024,
8
+ "decoder_attention_heads": 16,
9
+ "decoder_ffn_dim": 4096,
10
+ "decoder_layerdrop": 0.0,
11
+ "decoder_layers": 12,
12
+ "decoder_start_token_id": 16384,
13
+ "dropout": 0.1,
14
+ "encoder_attention_heads": 16,
15
+ "encoder_ffn_dim": 4096,
16
+ "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 12,
18
+ "encoder_vocab_size": 50264,
19
+ "eos_token_id": 16385,
20
+ "gradient_checkpointing": false,
21
+ "image_length": 256,
22
+ "image_vocab_size": 16384,
23
+ "init_std": 0.02,
24
+ "is_encoder_decoder": true,
25
+ "max_text_length": 64,
26
+ "model_type": "dallebart",
27
+ "normalize_text": true,
28
+ "pad_token_id": 16385,
29
+ "scale_embedding": false,
30
+ "tie_word_embeddings": false,
31
+ "transformers_version": "4.13.0.dev0",
32
+ "use_cache": true
33
+ }
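
These presets can be loaded through the standard `from_pretrained` path handling of `PretrainedConfig`; a sketch, assuming the repository root as the working directory:

```python
from dalle_mini.model import DalleBartConfig

# loads tools/train/config/mini/config.json shown above
config = DalleBartConfig.from_pretrained("tools/train/config/mini")
print(config.encoder_layers, config.decoder_layers, config.d_model)  # 12 12 1024
```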
tools/train/distributed_shampoo.py ADDED
@@ -0,0 +1,1826 @@
1
+ """File copied from https://github.com/google-research/google-research/edit/master/scalable_shampoo/optax/distributed_shampoo.py"""
2
+
3
+ # coding=utf-8
4
+ # Copyright 2021 The Google Research Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # An implementation of distributed Shampoo optimizer from:
19
+ #
20
+ # Scalable Second Order Optimization for Deep Learning
21
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
22
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
23
+ #
24
+ # This implementation moves computation of inverse pth root back to the
25
+ # accelerator (if higher precision is available).
26
+ #
27
+ # Authors: Rohan Anil (rohananil at google dot com)
28
+ # & Vineet Gupta (vineet at google dot com)
29
+ #
30
+
31
+ """Distributed Shampoo Implementation."""
32
+
33
+ import enum
34
+ import functools
35
+ import itertools
36
+ from typing import Any, List, NamedTuple
37
+
38
+ import chex
39
+ import jax
40
+ import jax.experimental.pjit as pjit
41
+ import jax.numpy as jnp
42
+ import numpy as np
43
+ import optax
44
+ from flax import struct
45
+ from jax import lax
46
+
47
+
48
+ # pylint:disable=no-value-for-parameter
49
+ @struct.dataclass
50
+ class QuantizedValue:
51
+ """State associated with quantized value."""
52
+
53
+ quantized: chex.Array
54
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
55
+ bucket_size: chex.Array
56
+ quantized_dtype: jnp.dtype = struct.field(
57
+ pytree_node=False
58
+ ) # Dtype for the quantized value.
59
+ extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
60
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
61
+
62
+ @classmethod
63
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
64
+ if isinstance(fvalue, list) and not fvalue:
65
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
66
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
67
+ fvalue, quantized_dtype, extract_diagonal
68
+ )
69
+ return QuantizedValue(
70
+ quantized,
71
+ diagonal_fvalue,
72
+ bucket_size,
73
+ quantized_dtype,
74
+ extract_diagonal,
75
+ list(quantized.shape),
76
+ )
77
+
78
+ # Quantization is from Lingvo JAX optimizers.
79
+ # We extend it for int16 quantization of PSD matrices.
80
+ @classmethod
81
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
82
+ """Returns quantized value and the bucket."""
83
+ if quantized_dtype == jnp.float32:
84
+ return fvalue, [], []
85
+ elif quantized_dtype == jnp.bfloat16:
86
+ return fvalue.astype(jnp.bfloat16), [], []
87
+
88
+ float_dtype = fvalue.dtype
89
+ if quantized_dtype == jnp.int8:
90
+ # value -128 is not used.
91
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
92
+ elif quantized_dtype == jnp.int16:
93
+ # value -32768 is not used.
94
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
95
+ else:
96
+ raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
97
+ # max value is mapped to num_buckets
98
+
99
+ if extract_diagonal and fvalue.ndim != 2:
100
+ raise ValueError(
101
+ f"Input array {fvalue} must be 2D to work with extract_diagonal."
102
+ )
103
+
104
+ diagonal_fvalue = []
105
+ if extract_diagonal:
106
+ diagonal_fvalue = jnp.diag(fvalue)
107
+ # Remove the diagonal entries.
108
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
109
+
110
+ # TODO(rohananil): Extend this by making use of information about the blocks
111
+ # SM3 style which will be useful for diagonal statistics
112
+ # We first decide the scale.
113
+ if fvalue.ndim < 1:
114
+ raise ValueError(
115
+ f"Input array {fvalue} must have a strictly positive number of "
116
+ "dimensions."
117
+ )
118
+
119
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
120
+ bucket_size = max_abs / num_buckets
121
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
122
+ # To avoid divide by 0.0
123
+ bs_nonzero = jnp.where(
124
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
125
+ )
126
+ ratio = fvalue / bs_nonzero
127
+ # We use rounding to remove bias.
128
+ quantized = jnp.round(ratio)
129
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
130
+
131
+ def to_float(self):
132
+ """Returns the float value."""
133
+ if isinstance(self.quantized, list) and not self.quantized:
134
+ return self.quantized
135
+
136
+ if self.quantized_dtype == jnp.float32:
137
+ return self.quantized
138
+
139
+ if self.quantized_dtype == jnp.bfloat16:
140
+ return self.quantized.astype(jnp.float32)
141
+
142
+ float_dtype = self.bucket_size.dtype
143
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
144
+ val = self.quantized.astype(float_dtype) * bucket_size
145
+ if self.extract_diagonal:
146
+ val += jnp.diag(self.diagonal)
147
+ return val
148
+
149
+
150
+ # Per parameter optimizer state used in data-parallel training.
151
+ class ParameterStats(NamedTuple):
152
+ """State associated to each parameter of the model being trained."""
153
+
154
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
155
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
156
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
157
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
158
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
159
+
160
+
161
+ # For training extremely large model; We keep a global state with a concatenated
162
+ # statistics and preconditioner states for all vars. This is so that we can
163
+ # annotate the leading axis to be sharded to save memory at the cost of
164
+ # communication.
165
+ @struct.dataclass
166
+ class GlobalShardedParameterStats:
167
+ statistics: chex.Array # Statistics
168
+ preconditioners: chex.Array # Preconditioners
169
+
170
+
171
+ # These are per-parameter local states; All statistics here mirror the parameter
172
+ # Thus the sharding is copied over from the param specification.
173
+ @struct.dataclass
174
+ class LocalShardedParameterStats:
175
+ """State associated to each parameter of the model being trained."""
176
+
177
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
178
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
179
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
180
+ index_start: np.int32 = struct.field(
181
+ pytree_node=False
182
+ ) # Index into global statistics array
183
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
184
+
185
+
186
+ class ShardedShampooStats(NamedTuple):
187
+ """Shampoo state in sharded mode."""
188
+
189
+ global_stats: Any
190
+ local_stats: Any
191
+
192
+
193
+ class ShampooState(NamedTuple):
194
+ count: chex.Array
195
+ stats: Any
196
+
197
+
198
+ class GraftingType(enum.IntEnum):
199
+ SGD = 1
200
+ ADAGRAD = 2
201
+ RMSPROP = 3
202
+ RMSPROP_NORMALIZED = 4
203
+
204
+
205
+ def power_iteration(
206
+ matrix, num_iters=100, error_tolerance=1e-6, precision=lax.Precision.HIGHEST
207
+ ):
208
+ r"""Power iteration algorithm.
209
+
210
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
211
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
212
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
213
+
214
+ References:
215
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
216
+
217
+ Args:
218
+ matrix: the symmetric PSD matrix.
219
+ num_iters: Number of iterations.
220
+ error_tolerance: Iterative exit condition.
221
+ precision: precision XLA related flag, the available options are:
222
+ a) lax.Precision.DEFAULT (better step time, but not precise)
223
+ b) lax.Precision.HIGH (increased precision, slower)
224
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
225
+
226
+ Returns:
227
+ eigen vector, eigen value
228
+ """
229
+ matrix_size = matrix.shape[-1]
230
+
231
+ def _iter_condition(state):
232
+ i, unused_v, unused_s, unused_s_v, run_step = state
233
+ return jnp.logical_and(i < num_iters, run_step)
234
+
235
+ def _iter_body(state):
236
+ """One step of power iteration."""
237
+ i, new_v, s, s_v, unused_run_step = state
238
+ new_v = new_v / jnp.linalg.norm(new_v)
239
+
240
+ s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
241
+ s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
242
+ return (
243
+ i + 1,
244
+ s_v,
245
+ s_new,
246
+ s_v,
247
+ jnp.greater(jnp.abs(s_new - s), error_tolerance),
248
+ )
249
+
250
+ # Figure out how to use step as seed for random.
251
+ v_0 = (
252
+ np.random.RandomState(1729).uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
253
+ )
254
+
255
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
256
+ _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body, init_state)
257
+ v_out = v_out / jnp.linalg.norm(v_out)
258
+ return v_out, s_out
259
+
260
+
261
+ def matrix_inverse_pth_root(
262
+ matrix,
263
+ p,
264
+ num_iters=100,
265
+ ridge_epsilon=1e-6,
266
+ error_tolerance=1e-6,
267
+ precision=lax.Precision.HIGHEST,
268
+ ):
269
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
270
+
271
+ This function uses the Coupled newton iterations algorithm for
272
+ the computation of a matrix's inverse pth root.
273
+
274
+
275
+ References:
276
+ [Functions of Matrices, Theory and Computation,
277
+ Nicholas J Higham, Pg 184, Eq 7.18](
278
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
279
+
280
+ Args:
281
+ matrix: the symmetric PSD matrix whose power it to be computed
282
+ p: exponent, for p a positive integer.
283
+ num_iters: Maximum number of iterations.
284
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
285
+ error_tolerance: Error indicator, useful for early termination.
286
+ precision: precision XLA related flag, the available options are:
287
+ a) lax.Precision.DEFAULT (better step time, but not precise)
288
+ b) lax.Precision.HIGH (increased precision, slower)
289
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
290
+
291
+ Returns:
292
+ matrix^(-1/p)
293
+ """
294
+
295
+ # We use float32 for the matrix inverse pth root.
296
+ # Switch to f64 if you have hardware that supports it.
297
+ matrix_size = matrix.shape[0]
298
+ alpha = jnp.asarray(-1.0 / p, jnp.float32)
299
+ identity = jnp.eye(matrix_size, dtype=jnp.float32)
300
+ _, max_ev = power_iteration(
301
+ matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
302
+ )
303
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
304
+
305
+ def _unrolled_mat_pow_1(mat_m):
306
+ """Computes mat_m^1."""
307
+ return mat_m
308
+
309
+ def _unrolled_mat_pow_2(mat_m):
310
+ """Computes mat_m^2."""
311
+ return jnp.matmul(mat_m, mat_m, precision=precision)
312
+
313
+ def _unrolled_mat_pow_4(mat_m):
314
+ """Computes mat_m^4."""
315
+ mat_pow_2 = _unrolled_mat_pow_2(mat_m)
316
+ return jnp.matmul(mat_pow_2, mat_pow_2, precision=precision)
317
+
318
+ def _unrolled_mat_pow_8(mat_m):
319
+ """Computes mat_m^4."""
320
+ mat_pow_4 = _unrolled_mat_pow_4(mat_m)
321
+ return jnp.matmul(mat_pow_4, mat_pow_4, precision=precision)
322
+
323
+ def mat_power(mat_m, p):
324
+ """Computes mat_m^p, for p == 1, 2, 4 or 8.
325
+
326
+ Args:
327
+ mat_m: a square matrix
328
+ p: a positive integer
329
+
330
+ Returns:
331
+ mat_m^p
332
+ """
333
+ # We unrolled the loop for performance reasons.
334
+ exponent = jnp.round(jnp.log2(p))
335
+ return lax.switch(
336
+ jnp.asarray(exponent, jnp.int32),
337
+ [
338
+ _unrolled_mat_pow_1,
339
+ _unrolled_mat_pow_2,
340
+ _unrolled_mat_pow_4,
341
+ _unrolled_mat_pow_8,
342
+ ],
343
+ (mat_m),
344
+ )
345
+
346
+ def _iter_condition(state):
347
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
348
+ error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
349
+ return jnp.logical_and(i < num_iters, error_above_threshold)
350
+
351
+ def _iter_body(state):
352
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
353
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
354
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
355
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
356
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
357
+ # sometimes error increases after an iteration before decreasing and
358
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
359
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error < error * 1.2)
360
+
361
+ if matrix_size == 1:
362
+ resultant_mat_h = (matrix + ridge_epsilon) ** alpha
363
+ error = 0
364
+ else:
365
+ damped_matrix = matrix + ridge_epsilon * identity
366
+
367
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
368
+ new_mat_m_0 = damped_matrix * z
369
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
370
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
371
+ init_state = tuple([0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
372
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
373
+ _iter_condition, _iter_body, init_state
374
+ )
375
+ error = jnp.max(jnp.abs(mat_m - identity))
376
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
377
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
378
+ resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
379
+ return resultant_mat_h, error
380
+
381
+
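A minimal editorial sketch (not from the patch) of what the coupled Newton iteration computes: for a well-conditioned PSD matrix `A` and `p = 2`, the returned `X ≈ A^(-1/2)` should satisfy `X X A ≈ I` up to the small ridge term.

```python
# Hypothetical check: X = A^(-1/2), so X @ X @ A should be close to the identity.
a = np.random.RandomState(0).normal(size=(6, 6)).astype(np.float32)
psd = jnp.asarray(a @ a.T + 6.0 * np.eye(6, dtype=np.float32))

inv_sqrt, error = matrix_inverse_pth_root(psd, p=2)
assert jnp.allclose(inv_sqrt @ inv_sqrt @ psd, jnp.eye(6), atol=1e-2)
```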
382
+ def merge_small_dims(shape_to_merge, max_dim):
383
+ """Merge small dimensions.
384
+
385
+ If there are some small dimensions, we collapse them:
386
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
387
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
388
+
389
+ Args:
390
+ shape_to_merge: Shape to merge small dimensions.
391
+ max_dim: Maximal dimension of output shape used in merging.
392
+
393
+ Returns:
394
+ Merged shape.
395
+ """
396
+ resulting_shape = []
397
+ product = 1
398
+ for d in shape_to_merge:
399
+ if product * d <= max_dim:
400
+ product *= d
401
+ else:
402
+ if product > 1:
403
+ resulting_shape.append(product)
404
+ product = d
405
+ if product > 1:
406
+ resulting_shape.append(product)
407
+ return resulting_shape
408
+
409
+
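The docstring examples above can be checked directly (editorial sketch, not part of the patch):

```python
assert merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024) == [1024, 2048, 12]
assert merge_small_dims([1, 2, 768, 1, 2048], 1024) == [2, 768, 2048]
```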
410
+ def pad_matrix(mat, max_size):
411
+ """Pad a matrix to a max_size.
412
+
413
+ Args:
414
+ mat: a matrix to pad.
415
+ max_size: matrix size requested.
416
+
417
+ Returns:
418
+ Given M returns [[M, 0], [0, I]]
419
+ """
420
+ size = mat.shape[0]
421
+ assert size <= max_size
422
+ if size == max_size:
423
+ return mat
424
+ pad_size = max_size - size
425
+ zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
426
+ zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
427
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
428
+ mat = jnp.concatenate([mat, zs1], 1)
429
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
430
+ return mat
431
+
432
+
433
+ def pad_vector(vec, max_size):
434
+ """Pad a vector to a max_size.
435
+
436
+ Args:
437
+ vec: a vector to pad.
438
+ max_size: vector size requested.
439
+
440
+ Returns:
441
+ Given V returns [V, 0]
442
+ """
443
+ size = vec.shape[0]
444
+ assert size <= max_size
445
+ if size == max_size:
446
+ return vec
447
+ pad_size = max_size - size
448
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
449
+ return jnp.concatenate([vec, zs1], 0)
450
+
451
+
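A small editorial sketch of the padding helpers above. The padded matrix keeps `M` in the top-left block and an identity in the bottom-right, so its inverse pth root leaves the padding untouched.

```python
m = jnp.arange(4.0).reshape(2, 2)
padded = pad_matrix(m, 4)
assert padded.shape == (4, 4)
assert jnp.allclose(padded[:2, :2], m)           # top-left block is M
assert jnp.allclose(padded[2:, 2:], jnp.eye(2))  # bottom-right block is I
assert jnp.allclose(pad_vector(jnp.ones(3), 5), jnp.array([1.0, 1.0, 1.0, 0.0, 0.0]))
```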
452
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
453
+ """Avoids wasteful buffer allocation with XLA."""
454
+
455
+ def _iter_body(unused_state):
456
+ results = compute_fn(*args, **kwargs)
457
+ return tuple([False] + list(results))
458
+
459
+ def _iter_condition(state):
460
+ return state[0]
461
+
462
+ results = jax.lax.while_loop(
463
+ _iter_condition, _iter_body, tuple([predicate] + init_state)
464
+ )
465
+ return tuple(results[1:])
466
+
467
+
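Editorial sketch of how `efficient_cond` is used further below: a traced predicate decides whether `compute_fn` runs once or the initial state is passed through unchanged, which is how preconditioners are recomputed only every k steps without allocating both branches. Names here are illustrative.

```python
def expensive_update():
    # Stands in for e.g. the inverse pth-root computation.
    return [jnp.full((2, 2), 7.0)]

step = jnp.asarray(10)
init = [jnp.zeros((2, 2))]
(result,) = efficient_cond(step % 5 == 0, expensive_update, init)
# `result` is the freshly computed value when step % 5 == 0, otherwise init[0].
```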
468
+ class BlockPartitioner:
469
+ """Partitions a tensor into smaller tensors."""
470
+
471
+ def __init__(self, param, block_size):
472
+ self._shape = param.shape
473
+ self._splits = []
474
+ split_sizes = []
475
+ # We split params into smaller blocks. Here we store the metadata to make
476
+ # that split.
477
+ for i, d in enumerate(param.shape):
478
+ if 0 < block_size < d:
479
+ # d-1, otherwise split appends a 0-size array.
480
+ nsplit = (d - 1) // block_size
481
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
482
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
483
+ sizes[-1] = d - indices[-1]
484
+ self._splits.append((i, indices))
485
+ split_sizes.append(sizes)
486
+ else:
487
+ split_sizes.append(np.array([d], dtype=np.int32))
488
+ self._num_splits = len(split_sizes)
489
+ self._preconditioner_shapes = []
490
+ for t in itertools.product(*split_sizes):
491
+ self._preconditioner_shapes.extend([[d, d] for d in t])
492
+
493
+ def shapes_for_preconditioners(self):
494
+ return self._preconditioner_shapes
495
+
496
+ def num_splits(self):
497
+ return self._num_splits
498
+
499
+ def partition(self, tensor):
500
+ """Partition tensor into blocks."""
501
+
502
+ assert tensor.shape == self._shape
503
+ tensors = [tensor]
504
+ for (i, indices) in self._splits:
505
+ tensors_local = []
506
+ for t in tensors:
507
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
508
+ tensors = tensors_local
509
+ return tensors
510
+
511
+ def merge_partitions(self, partitions):
512
+ """Merge partitions back to original shape."""
513
+
514
+ for (i, indices) in reversed(self._splits):
515
+ n = len(indices) + 1
516
+ partial_merged_tensors = []
517
+ ind = 0
518
+ while ind < len(partitions):
519
+ partial_merged_tensors.append(
520
+ jnp.concatenate(partitions[ind : ind + n], axis=i)
521
+ )
522
+ ind += n
523
+ partitions = partial_merged_tensors
524
+ assert len(partitions) == 1
525
+ return partitions[0]
526
+
527
+
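Editorial round-trip sketch of the partitioner above: a (300, 5) parameter with block_size=128 is split along its first axis into 128 + 128 + 44, each block contributes one preconditioner shape per axis, and merging recovers the original tensor.

```python
param = jnp.arange(300.0 * 5).reshape(300, 5)
partitioner = BlockPartitioner(param, block_size=128)

blocks = partitioner.partition(param)
assert [b.shape for b in blocks] == [(128, 5), (128, 5), (44, 5)]
assert partitioner.shapes_for_preconditioners() == [
    [128, 128], [5, 5], [128, 128], [5, 5], [44, 44], [5, 5]
]
assert jnp.array_equal(partitioner.merge_partitions(blocks), param)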
528
+ class Preconditioner:
529
+ """Compute statistics/shape from gradients for preconditioning."""
530
+
531
+ def __init__(self, param, block_size, best_effort_shape_interpretation):
532
+ self._original_shape = param.shape
533
+ self._transformed_shape = param.shape
534
+ if best_effort_shape_interpretation:
535
+ self._transformed_shape = merge_small_dims(self._original_shape, block_size)
536
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
537
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
538
+
539
+ def statistics_from_grad(self, grad):
540
+ """Compute statistics from gradients.
541
+
542
+ Args:
543
+ grad: Gradient to compute statistics from.
544
+
545
+ Returns:
546
+ A list of gradient statistics for each partition.
547
+ """
548
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
549
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
550
+ stats = []
551
+ for g in partitioned_grads:
552
+ g_stats = []
553
+ rank = len(g.shape)
554
+ for i in range(rank):
555
+ axes = list(range(i)) + list(range(i + 1, rank))
556
+ stat = jnp.tensordot(g, g, axes=(axes, axes))
557
+ g_stats.append(stat)
558
+ stats.extend(g_stats)
559
+ return stats
560
+
561
+ def shapes_for_preconditioners(self):
562
+ """Returns shape from statistics."""
563
+ return self._partitioner.shapes_for_preconditioners()
564
+
565
+ def exponent_for_preconditioner(self):
566
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
567
+ return 2 * len(self._transformed_shape)
568
+
569
+ def preconditioned_grad(self, grad, preconditioners):
570
+ """Precondition the gradient.
571
+
572
+ Args:
573
+ grad: A gradient tensor to precondition.
574
+ preconditioners: A list of preconditioners to apply.
575
+
576
+ Returns:
577
+ A preconditioned gradient.
578
+ """
579
+
580
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
581
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
582
+ preconditioned_partitioned_grads = []
583
+ num_splits = self._partitioner.num_splits()
584
+ for i, g in enumerate(partitioned_grads):
585
+ preconditioners_for_grad = preconditioners[
586
+ i * num_splits : (i + 1) * num_splits
587
+ ]
588
+ rank = len(g.shape)
589
+ precond_g = g
590
+ for j in range(rank):
591
+ precond_g = jnp.tensordot(
592
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]]
593
+ )
594
+ preconditioned_partitioned_grads.append(precond_g)
595
+ merged_grad = self._partitioner.merge_partitions(
596
+ preconditioned_partitioned_grads
597
+ )
598
+ return jnp.reshape(merged_grad, self._original_shape)
599
+
600
+
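Editorial sketch (not part of the patch) of what `statistics_from_grad` accumulates for an unblocked rank-2 gradient `G`: the two Gram matrices `G Gᵀ` and `Gᵀ G`, which are later raised to the `-1/(2·rank)` power (exponent 4 here) to precondition `G` from the left and right.

```python
g = jnp.asarray(np.random.RandomState(0).normal(size=(64, 32)), dtype=jnp.float32)
prec = Preconditioner(g, block_size=128, best_effort_shape_interpretation=True)

left_stat, right_stat = prec.statistics_from_grad(g)
assert jnp.allclose(left_stat, g @ g.T, atol=1e-3)   # (64, 64)
assert jnp.allclose(right_stat, g.T @ g, atol=1e-3)  # (32, 32)
assert prec.exponent_for_preconditioner() == 4       # 2 * rank
```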
601
+ def _convert_to_parameter_stats(global_stats, local_stat):
602
+ """Creates parameter stats from sharded stats."""
603
+ index_start = int(local_stat.index_start)
604
+ index_end = int(len(local_stat.sizes)) + index_start
605
+ statistics = global_stats.statistics[index_start:index_end, :, :]
606
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
607
+ new_statistics = []
608
+ new_preconditioners = []
609
+ for i, size in enumerate(local_stat.sizes):
610
+ new_statistics.append(statistics[i][:size, :size])
611
+ new_preconditioners.append(preconditioners[i][:size, :size])
612
+ return ParameterStats(
613
+ local_stat.diagonal_statistics,
614
+ new_statistics,
615
+ new_preconditioners,
616
+ local_stat.diagonal_momentum,
617
+ local_stat.momentum,
618
+ )
619
+
620
+
621
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
622
+ """Creates sharded stats from paramter stats."""
623
+ return LocalShardedParameterStats(
624
+ parameter_stats.diagonal_statistics,
625
+ parameter_stats.diagonal_momentum,
626
+ parameter_stats.momentum,
627
+ local_stats.index_start,
628
+ local_stats.sizes,
629
+ )
630
+
631
+
632
+ def batch(x, num_devices):
633
+ """Batch `x` so that so that leading axis is num_devices."""
634
+ n = len(x)
635
+ b = int(n / num_devices)
636
+ return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
637
+
638
+
639
+ def unbatch(batched_values):
640
+ """Unbatch values across leading axis and return a list of elements."""
641
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
642
+ results = []
643
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
644
+ v_array = jnp.squeeze(v_array)
645
+ # b2 = batches (number of preconditioner computations) per core.
646
+ if b2 > 1:
647
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
648
+ results.append(jnp.squeeze(v))
649
+ else:
650
+ results.append(v_array)
651
+ return results
652
+
653
+
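Editorial round-trip sketch of the batching helpers above, which lay out preconditioner computations with `num_devices` as the leading axis (here 8 matrices over 4 "devices"):

```python
mats = [jnp.full((3, 3), float(i)) for i in range(8)]
batched = batch(mats, 4)
assert batched.shape == (4, 2, 3, 3)  # 2 preconditioner computations per device

recovered = unbatch(batched)
assert len(recovered) == 8
assert all(jnp.array_equal(a, b) for a, b in zip(recovered, mats))
```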
654
+ def distributed_shampoo(
655
+ learning_rate,
656
+ block_size,
657
+ beta1=0.9,
658
+ beta2=0.999,
659
+ diagonal_epsilon=1e-10,
660
+ matrix_epsilon=1e-6,
661
+ weight_decay=0.0,
662
+ start_preconditioning_step=5,
663
+ preconditioning_compute_steps=1,
664
+ statistics_compute_steps=1,
665
+ best_effort_shape_interpretation=True,
666
+ graft_type=GraftingType.SGD,
667
+ nesterov=True,
668
+ exponent_override=0,
669
+ # Pass pmap 'batch axis name' in pmap mode.
670
+ batch_axis_name=None,
671
+ ### Only set following 3 params in pjit/spmd mode.
672
+ ### WARNING: Experimental
673
+ mesh_axis_names=None,
674
+ num_devices_for_pjit=None,
675
+ shard_optimizer_states=False,
676
+ ###
677
+ ### Experimental memory reduction mode
678
+ best_effort_memory_usage_reduction=False,
679
+ ###
680
+ inverse_failure_threshold=0.1,
681
+ moving_average_for_momentum=False,
682
+ skip_preconditioning_dim_size_gt=4096,
683
+ clip_by_scaled_gradient_norm=None,
684
+ precision=lax.Precision.HIGHEST,
685
+ ):
686
+ """Distributed Shampoo optimizer.
687
+
688
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
689
+ variant of full-matrix Adagrad), that provides significant convergence and
690
+ wall-clock time improvements compared to conventional first-order methods,
691
+ and that has been shown to scale to large state-of-the-art deep learning
692
+ models.
693
+
694
+ References:
695
+ Scalable Second Order Optimization for Deep Learning,
696
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
697
+
698
+ Preprint: https://arxiv.org/abs/2002.09018
699
+
700
+ Args:
701
+ learning_rate: the step size used to update the parameters.
702
+ block_size: Block size for large layers (if > 0). Preconditioning compute
703
+ operation is cubic in the dimension of the tensor. Block size allows us to
704
+ chunk the layers into sub-layers of maximal dimension dictated by this
705
+ value. Use 128 as default (increase if you have compute budget).
706
+ beta1: momentum parameter.
707
+ beta2: second moment averaging parameter.
708
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
709
+ to AdaGrad is enabled).
710
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
711
+ root. If you are running in f32 precision for inverse pth root
712
+ (recommended today) this can go up to 1e-6. If you have the latest hardware
713
+ with native f64 precision, set this up to 1e-12.
714
+ weight_decay: Weight decay for regularization.
715
+ start_preconditioning_step: When to start Shampoo update before which
716
+ diagonal update is used. This is because we don't have enough information
717
+ to compute a stable inverse.
718
+ preconditioning_compute_steps: How often to compute preconditioner.
719
+ Performance tuning params for controlling memory and compute requirements.
720
+ Ideally set this and statistics_compute_steps params to 1.
721
+ statistics_compute_steps: How often to compute statistics.
722
+ best_effort_shape_interpretation: If there are some small dimensions,
723
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
724
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
725
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
726
+ optimizer. This allows us to plug the Shampoo optimizer into settings
727
+ where SGD/AdaGrad is already well tuned. Available options are:
728
+ GraftingType.SGD, GraftingType.ADAGRAD, GraftingType.RMSPROP and GraftingType.RMSPROP_NORMALIZED.
729
+ nesterov: Nesterov momentum.
730
+ exponent_override: Override the exponent used in matrix inverse.
731
+ batch_axis_name: labeled pmap axis used for data-parallel training with
732
+ this optimizer.
733
+ mesh_axis_names: Axis names for the mesh (used in pjit).
734
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
735
+ shard_optimizer_states: Shard optimizer states to save memory in model
736
+ parallel training.
737
+ best_effort_memory_usage_reduction: Best effort memory usage reduction.
738
+ diagonal_statistics -> jnp.bfloat16
739
+ momentum buffers (2x) -> jnp.int8
740
+ statistics, preconditioners -> jnp.int16 + diagonals
741
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
742
+ determine that using this threshold.
743
+ moving_average_for_momentum: Whether to use moving average for momentum
744
+ instead of exponential moving average.
745
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
746
+ greater than this value.
747
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
748
+ when using RMSProp Grafting).
749
+ precision: precision XLA related flag, the available options are: a)
750
+ lax.Precision.DEFAULT (better step time, but not precise) b)
751
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
752
+ (best possible precision, slowest)
753
+
754
+ Returns:
755
+ a GradientTransformation.
756
+ """
757
+
758
+ def quantized_dtype_for_momentum_buffers():
759
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
760
+
761
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
762
+ def quantized_dtype_for_diagonal_statistics_buffers():
763
+ return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
764
+
765
+ # Preconditioner and statistics are both stored as int16 in this mode.
766
+ # We take out the diagonal to make quantization easier.
767
+ def quantized_dtype_for_second_moment_statistics_buffers():
768
+ return (
769
+ jnp.int16
770
+ if best_effort_memory_usage_reduction and batch_axis_name
771
+ else jnp.float32
772
+ )
773
+
774
+ # Preconditioner and statistics are both stored as int16 in this mode.
775
+ # We take out the diagonal to make quantization easier.
776
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
777
+ return (
778
+ jnp.int16
779
+ if best_effort_memory_usage_reduction and batch_axis_name
780
+ else jnp.float32
781
+ )
782
+
783
+ def _to_float(maybe_quantized):
784
+ if isinstance(maybe_quantized, QuantizedValue):
785
+ return maybe_quantized.to_float()
786
+ else:
787
+ return maybe_quantized
788
+
789
+ def _maybe_quantize_statistics(statistics_list):
790
+ return _maybe_quantize_matrices_with_dtype(
791
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers()
792
+ )
793
+
794
+ def _maybe_quantize_preconditioners(statistics_list):
795
+ return _maybe_quantize_matrices_with_dtype(
796
+ statistics_list, quantized_dtype_for_second_moment_preconditioner_buffers()
797
+ )
798
+
799
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
800
+ if quantized_dtype != jnp.float32:
801
+ return [
802
+ QuantizedValue.from_float_value(
803
+ s, quantized_dtype, extract_diagonal=True
804
+ )
805
+ for s in statistics_list
806
+ ]
807
+ else:
808
+ return statistics_list
809
+
810
+ def _maybe_dequantize_preconditioners(preconditioner_list):
811
+ return _maybe_dequantize_matrices_with_dtype(
812
+ preconditioner_list,
813
+ quantized_dtype_for_second_moment_preconditioner_buffers(),
814
+ )
815
+
816
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
817
+ if quantized_dtype != jnp.float32:
818
+ return [s.to_float() for s in statistics_list]
819
+ else:
820
+ return statistics_list
821
+
822
+ def _quantize_diagonal_statistics(diagonal_statistics):
823
+ return QuantizedValue.from_float_value(
824
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers()
825
+ )
826
+
827
+ def _quantize_momentum(momentum_statistics):
828
+ return QuantizedValue.from_float_value(
829
+ momentum_statistics, quantized_dtype_for_momentum_buffers()
830
+ )
831
+
832
+ def sharded_init_fn(params):
833
+ params_flat, treedef = jax.tree_flatten(params)
834
+ # Find max size to pad to.
835
+ max_size = 0
836
+ for param in params_flat:
837
+ preconditioner = Preconditioner(
838
+ param, block_size, best_effort_shape_interpretation
839
+ )
840
+ if not _skip_preconditioning(param):
841
+ shapes = preconditioner.shapes_for_preconditioners()
842
+ sizes = [s[0] for s in shapes]
843
+ max_size = max(max(sizes), max_size)
844
+
845
+ padded_statistics = []
846
+ padded_preconditioners = []
847
+ local_stats_flat = []
848
+ for param in params_flat:
849
+ preconditioner = Preconditioner(
850
+ param, block_size, best_effort_shape_interpretation
851
+ )
852
+ shapes = preconditioner.shapes_for_preconditioners()
853
+ sizes = []
854
+
855
+ statistics = []
856
+ preconditioners = []
857
+ index_start = len(padded_statistics)
858
+ if not _skip_preconditioning(param):
859
+ sizes = [s[0] for s in shapes]
860
+ shapes = preconditioner.shapes_for_preconditioners()
861
+ statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
862
+ preconditioners = [jnp.eye(max_size) for s in shapes]
863
+ padded_statistics.extend(statistics)
864
+ padded_preconditioners.extend(preconditioners)
865
+
866
+ diagonal_statistics = []
867
+ if graft_type != GraftingType.SGD:
868
+ diagonal_statistics = jnp.zeros_like(param)
869
+ local_stats_flat.append(
870
+ LocalShardedParameterStats(
871
+ _quantize_diagonal_statistics(diagonal_statistics),
872
+ _quantize_momentum(jnp.zeros_like(param)),
873
+ _quantize_momentum(jnp.zeros_like(param)),
874
+ index_start,
875
+ sizes,
876
+ )
877
+ )
878
+
879
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
880
+ # Pad the statistics and preconditioner matrices to be a multiple of
881
+ # num devices.
882
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
883
+ # is split on.
884
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
885
+ padded_statistics.extend(
886
+ [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
887
+ )
888
+ padded_preconditioners.extend(
889
+ [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
890
+ )
891
+ global_stats = GlobalShardedParameterStats(
892
+ jnp.stack(padded_statistics), jnp.stack(padded_preconditioners)
893
+ )
894
+ return ShampooState(
895
+ count=jnp.zeros([], jnp.int32),
896
+ stats=ShardedShampooStats(global_stats, local_stats),
897
+ )
898
+
899
+ def sharded_update_fn(grads, state, params):
900
+ """Transform the input gradient and update all statistics in sharded mode.
901
+
902
+ Args:
903
+ grads: the gradient tensors for the parameters.
904
+ state: a named tuple containing the state of the optimizer
905
+ params: the parameters that should be updated.
906
+
907
+ Returns:
908
+ A tuple containing the new parameters and the new optimizer state.
909
+ """
910
+ params_flat, treedef = jax.tree_flatten(params)
911
+ grads_flat = treedef.flatten_up_to(grads)
912
+
913
+ global_stats = state.stats.global_stats
914
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
915
+ stats_flat = [
916
+ _convert_to_parameter_stats(global_stats, local_stat)
917
+ for local_stat in local_stats_flat
918
+ ]
919
+ new_stats_flat = jax.tree_multimap(
920
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
921
+ grads_flat,
922
+ stats_flat,
923
+ params_flat,
924
+ )
925
+
926
+ exponents = []
927
+ for stat, param in zip(new_stats_flat, params_flat):
928
+ num_statistics = len(stat.statistics)
929
+ if num_statistics > 0:
930
+ preconditioner = Preconditioner(
931
+ param, block_size, best_effort_shape_interpretation
932
+ )
933
+ exponent = (
934
+ preconditioner.exponent_for_preconditioner()
935
+ if exponent_override == 0
936
+ else exponent_override
937
+ )
938
+ exponents.extend([exponent] * num_statistics)
939
+
940
+ outputs = jax.tree_multimap(
941
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
942
+ grads_flat,
943
+ new_stats_flat,
944
+ params_flat,
945
+ )
946
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
947
+
948
+ updates = jax.tree_unflatten(treedef, updates_flat)
949
+ # Create new local_stats
950
+ new_local_stats_flat = [
951
+ _convert_from_parameter_stats(new_stat, local_stat)
952
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
953
+ ]
954
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
955
+
956
+ max_size = global_stats.statistics.shape[1]
957
+ new_padded_statistics = []
958
+ for stat in new_stats_flat:
959
+ new_padded_statistics.extend(
960
+ [pad_matrix(stat, max_size) for stat in stat.statistics]
961
+ )
962
+
963
+ # Create global stats
964
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
965
+ # stack/pad can be obviated away.
966
+ # Pad the statistics and preconditioner matrices to be a multiple of
967
+ # num devices.
968
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
969
+ # is split on.
970
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
971
+ new_padded_statistics.extend(
972
+ [
973
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
974
+ for _ in range(to_pad)
975
+ ]
976
+ )
977
+ exponents.extend([1 for _ in range(to_pad)])
978
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
979
+ new_stacked_exponents = jnp.stack(exponents)
980
+
981
+ def _matrix_inverse_pth_root_vmap(xs, ps):
982
+ mi_pth_root = functools.partial(
983
+ matrix_inverse_pth_root,
984
+ ridge_epsilon=matrix_epsilon,
985
+ precision=precision,
986
+ )
987
+ preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
988
+ return preconditioners, errors
989
+
990
+ def _internal_inverse_pth_root_all():
991
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
992
+ new_stacked_padded_statistics, new_stacked_exponents
993
+ )
994
+ return preconditioners, errors
995
+
996
+ if preconditioning_compute_steps == 1:
997
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
998
+ else:
999
+ # Passing statistics instead of preconditioners as they are similarly
1000
+ # shaped tensors. Note statistics will be ignored as we are passing in
1001
+ # a large init value for error.
1002
+ preconditioners_init = new_stacked_padded_statistics
1003
+ errors_init = np.stack([inverse_failure_threshold] * len(exponents))
1004
+ init_state = [preconditioners_init, errors_init]
1005
+ perform_step = state.count % preconditioning_compute_steps == 0
1006
+ new_preconditioners, errors = efficient_cond(
1007
+ perform_step, _internal_inverse_pth_root_all, init_state
1008
+ )
1009
+
1010
+ errors = errors.reshape((-1, 1, 1))
1011
+ predicate = jnp.logical_or(
1012
+ jnp.isnan(errors), errors >= inverse_failure_threshold
1013
+ ).astype(new_preconditioners.dtype)
1014
+ # TODO(rohananil): Check for numerical instabilities.
1015
+ new_conditional_preconditioners = (
1016
+ predicate * global_stats.preconditioners
1017
+ + (1.0 - predicate) * new_preconditioners
1018
+ )
1019
+ new_global_stats = GlobalShardedParameterStats(
1020
+ new_stacked_padded_statistics, new_conditional_preconditioners
1021
+ )
1022
+ new_shampoo_state = ShampooState(
1023
+ count=state.count + 1,
1024
+ stats=ShardedShampooStats(new_global_stats, new_local_stats),
1025
+ )
1026
+ return updates, new_shampoo_state
1027
+
1028
+ def init_fn(params):
1029
+ """Initialise the optimiser's state."""
1030
+
1031
+ def _init(param):
1032
+ preconditioner = Preconditioner(
1033
+ param, block_size, best_effort_shape_interpretation
1034
+ )
1035
+ statistics = []
1036
+ preconditioners = []
1037
+ if not _skip_preconditioning(param):
1038
+ shapes = preconditioner.shapes_for_preconditioners()
1039
+ statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
1040
+ preconditioners = [jnp.eye(s[0]) for s in shapes]
1041
+
1042
+ diagonal_statistics = []
1043
+ if graft_type != GraftingType.SGD:
1044
+ diagonal_statistics = jnp.zeros_like(param)
1045
+ return ParameterStats(
1046
+ _quantize_diagonal_statistics(diagonal_statistics),
1047
+ _maybe_quantize_statistics(statistics),
1048
+ _maybe_quantize_preconditioners(preconditioners),
1049
+ _quantize_momentum(jnp.zeros_like(param)),
1050
+ _quantize_momentum(jnp.zeros_like(param)),
1051
+ )
1052
+
1053
+ return ShampooState(
1054
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
1055
+ )
1056
+
1057
+ def _skip_preconditioning(param):
1058
+ return len(param.shape) < 1 or any(
1059
+ [s > skip_preconditioning_dim_size_gt for s in param.shape]
1060
+ )
1061
+
1062
+ def _compute_stats(grad, state, param, step):
1063
+ """Compute per-parameter statistics."""
1064
+ preconditioner = Preconditioner(
1065
+ param, block_size, best_effort_shape_interpretation
1066
+ )
1067
+ new_statistics = [[]] * len(state.statistics)
1068
+ w1 = beta2
1069
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1070
+ if not _skip_preconditioning(param):
1071
+
1072
+ def compute_updated_statistics():
1073
+ new_stats = preconditioner.statistics_from_grad(grad)
1074
+ new_stats_accumulators = []
1075
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
1076
+ new_stats_accumulators.append(
1077
+ w1 * _to_float(stat_accumulator) + w2 * stat
1078
+ )
1079
+ return _maybe_quantize_statistics(new_stats_accumulators)
1080
+
1081
+ if statistics_compute_steps > 1:
1082
+ perform_step = step % statistics_compute_steps == 0
1083
+ init_state = state.statistics
1084
+ new_statistics = list(
1085
+ efficient_cond(perform_step, compute_updated_statistics, init_state)
1086
+ )
1087
+ else:
1088
+ new_statistics = compute_updated_statistics()
1089
+ return ParameterStats(
1090
+ state.diagonal_statistics,
1091
+ new_statistics,
1092
+ state.preconditioners,
1093
+ state.diagonal_momentum,
1094
+ state.momentum,
1095
+ )
1096
+
1097
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1098
+ mi_pth_root = functools.partial(
1099
+ matrix_inverse_pth_root, ridge_epsilon=matrix_epsilon, precision=precision
1100
+ )
1101
+ return jax.vmap(mi_pth_root)(xs, ps)
1102
+
1103
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1104
+ def _quantized_to_float(qx, qd, qb):
1105
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1106
+ return qv.to_float()
1107
+
1108
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1109
+ v = _quantized_to_float(qx, qd, qb)
1110
+ preconditioner, error = matrix_inverse_pth_root(
1111
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision
1112
+ )
1113
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1114
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1115
+
1116
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1117
+
1118
+ def _matrix_inverse_pth_root_pjit(xs, ps):
1119
+ mesh_axis_names_tuple = tuple(mesh_axis_names)
1120
+ # Partition the concatenated statistics matrix across all cores.
1121
+ partitioned_xs, partitioned_ps = pjit.pjit(
1122
+ lambda x, y: (x, y),
1123
+ in_axis_resources=None,
1124
+ out_axis_resources=pjit.PartitionSpec(
1125
+ mesh_axis_names_tuple,
1126
+ ),
1127
+ )(xs, ps)
1128
+ # Run matrix inverse pth root on each shard.
1129
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1130
+ partitioned_xs, partitioned_ps
1131
+ )
1132
+ # Recombine the outputs at each core.
1133
+ preconditioners, errors = pjit.pjit(
1134
+ lambda x, y: (x, y),
1135
+ in_axis_resources=(
1136
+ pjit.PartitionSpec(
1137
+ mesh_axis_names_tuple,
1138
+ ),
1139
+ pjit.PartitionSpec(
1140
+ mesh_axis_names_tuple,
1141
+ ),
1142
+ ),
1143
+ out_axis_resources=(None, None),
1144
+ )(partitioned_preconditioners, partitioned_errors)
1145
+ return preconditioners, errors
1146
+
1147
+ def _pmap_compute_preconditioners(
1148
+ states,
1149
+ step,
1150
+ statistics,
1151
+ num_statistics_per_state,
1152
+ original_shapes,
1153
+ exponents,
1154
+ max_size,
1155
+ prev_preconditioners,
1156
+ ):
1157
+ """Computes preconditioners for given statistics in states in PMAP mode.
1158
+
1159
+ Args:
1160
+ states: A list of optimizer states.
1161
+ step: Current step number
1162
+ statistics: A list of statistics for all variables (for every dim)
1163
+ num_statistics_per_state: Number of statistics per state to reconstruct
1164
+ output states.
1165
+ original_shapes: A list of shapes of the statistics.
1166
+ exponents: Exponent power to use for inverse-pth roots.
1167
+ max_size: Maximum dim of the statistics to pad.
1168
+ prev_preconditioners: Previously available preconditioner.
1169
+
1170
+ Returns:
1171
+ New optimizer states after computing the preconditioner.
1172
+ """
1173
+ num_devices = lax.psum(1, batch_axis_name)
1174
+ num_statistics = len(statistics)
1175
+ # Pad statistics and exponents to next multiple of num_devices.
1176
+ packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1177
+ to_pad = -num_statistics % num_devices
1178
+ packed_statistics.extend(
1179
+ [jnp.eye(max_size, dtype=packed_statistics[0].dtype) for _ in range(to_pad)]
1180
+ )
1181
+ exponents.extend([1 for _ in range(to_pad)])
1182
+
1183
+ if not packed_statistics:
1184
+ return states
1185
+
1186
+ all_statistics = batch(packed_statistics, num_devices)
1187
+ all_exponents = batch(exponents, num_devices)
1188
+
1189
+ def _internal_inverse_pth_root_all():
1190
+ current_replica = lax.axis_index(batch_axis_name)
1191
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1192
+ all_statistics[current_replica], all_exponents[current_replica]
1193
+ )
1194
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1195
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1196
+ preconditioners_flat = unbatch(preconditioners)
1197
+ errors_flat = unbatch(errors)
1198
+ return preconditioners_flat, errors_flat
1199
+
1200
+ if preconditioning_compute_steps == 1:
1201
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1202
+ else:
1203
+ # Passing statistics instead of preconditioners as they are similarly
1204
+ # shaped tensors. Note statistics will be ignored as we are passing in
1205
+ # a large init value for error.
1206
+ preconditioners_init = packed_statistics
1207
+ errors_init = [inverse_failure_threshold] * len(packed_statistics)
1208
+ init_state = [preconditioners_init, errors_init]
1209
+ perform_step = step % preconditioning_compute_steps == 0
1210
+ preconditioners_flat, errors_flat = efficient_cond(
1211
+ perform_step, _internal_inverse_pth_root_all, init_state
1212
+ )
1213
+
1214
+ def _skip(error):
1215
+ condition = jnp.logical_or(
1216
+ jnp.isnan(error), error >= inverse_failure_threshold
1217
+ )
1218
+ return condition.astype(error.dtype)
1219
+
1220
+ def _select_preconditioner(error, new_p, old_p):
1221
+ return lax.cond(
1222
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1223
+ )
1224
+
1225
+ new_preconditioners_flat = []
1226
+ for p, shape, prev_p, error in zip(
1227
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1228
+ ):
1229
+ new_preconditioners_flat.append(
1230
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1231
+ )
1232
+
1233
+ assert len(states) == len(num_statistics_per_state)
1234
+ assert len(new_preconditioners_flat) == num_statistics
1235
+
1236
+ # Add back empty preconditioners so that we can set the optimizer state.
1237
+ preconditioners_for_states = []
1238
+ idx = 0
1239
+ for num_statistics, state in zip(num_statistics_per_state, states):
1240
+ if num_statistics == 0:
1241
+ preconditioners_for_states.append([])
1242
+ else:
1243
+ preconditioners_for_state = new_preconditioners_flat[
1244
+ idx : idx + num_statistics
1245
+ ]
1246
+ assert len(state.statistics) == len(preconditioners_for_state)
1247
+ preconditioners_for_states.append(preconditioners_for_state)
1248
+ idx += num_statistics
1249
+ new_states = []
1250
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1251
+ new_states.append(
1252
+ ParameterStats(
1253
+ state.diagonal_statistics,
1254
+ state.statistics,
1255
+ new_preconditioners,
1256
+ state.diagonal_momentum,
1257
+ state.momentum,
1258
+ )
1259
+ )
1260
+
1261
+ return new_states
1262
+
1263
+ def _pmap_quantized_compute_preconditioners(
1264
+ states,
1265
+ step,
1266
+ statistics,
1267
+ num_statistics_per_state,
1268
+ original_shapes,
1269
+ exponents,
1270
+ max_size,
1271
+ prev_preconditioners,
1272
+ ):
1273
+ """Computes preconditioners for given statistics in states in PMAP mode.
1274
+
1275
+ For quantization, each statistic is represented by three values:
1276
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1277
+ without ever recreating the original matrix in f32.
1278
+
1279
+ Args:
1280
+ states: A list of optimizer states.
1281
+ step: Current step number
1282
+ statistics: A list of statistics for all variables (for every dim)
1283
+ num_statistics_per_state: Number of statistics per state to reconstruct
1284
+ output states.
1285
+ original_shapes: A list of shapes of the statistics.
1286
+ exponents: Exponent power to use for inverse-pth roots.
1287
+ max_size: Maximum dim of the statistics to pad.
1288
+ prev_preconditioners: Previously available preconditioner.
1289
+
1290
+ Returns:
1291
+ New optimizer states after computing the preconditioner.
1292
+ """
1293
+ num_devices = lax.psum(1, batch_axis_name)
1294
+ num_statistics = len(statistics)
1295
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1296
+ # Complexity here comes from shapes needing to be statically known and
1297
+ # our custom quantization type requiring a different kind of packing.
1298
+
1299
+ # Parallel tensors:
1300
+ # quantized [dxd]
1301
+ # diagonals [d] f32
1302
+ # bucket_sizes [d] f32
1303
+ packed_quantized_statistics = [
1304
+ pad_matrix(stat.quantized, max_size) for stat in statistics
1305
+ ]
1306
+ packed_quantized_diagonals = [
1307
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1308
+ ]
1309
+ packed_quantized_bucket_sizes = [
1310
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1311
+ ]
1312
+
1313
+ to_pad = -num_statistics % num_devices
1314
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1315
+ quantized_eye = QuantizedValue.from_float_value(
1316
+ padded_eye, quantized_dtype, True
1317
+ )
1318
+ packed_quantized_statistics.extend(
1319
+ [quantized_eye.quantized for _ in range(to_pad)]
1320
+ )
1321
+ packed_quantized_diagonals.extend(
1322
+ [quantized_eye.diagonal for _ in range(to_pad)]
1323
+ )
1324
+ packed_quantized_bucket_sizes.extend(
1325
+ [quantized_eye.bucket_size for _ in range(to_pad)]
1326
+ )
1327
+ exponents.extend([1 for _ in range(to_pad)])
1328
+
1329
+ if not packed_quantized_statistics:
1330
+ return states
1331
+
1332
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1333
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1334
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes, num_devices)
1335
+ all_exponents = batch(exponents, num_devices)
1336
+
1337
+ def _internal_inverse_pth_root_all():
1338
+ current_replica = lax.axis_index(batch_axis_name)
1339
+ (
1340
+ quantized_preconditioners,
1341
+ quantized_diagonals,
1342
+ quantized_bucket_sizes,
1343
+ errors,
1344
+ ) = _quantized_matrix_inverse_pth_root_vmap(
1345
+ all_quantized_statistics[current_replica],
1346
+ all_quantized_diagonals[current_replica],
1347
+ all_quantized_bucket_sizes[current_replica],
1348
+ all_exponents[current_replica],
1349
+ )
1350
+ quantized_preconditioners = jax.lax.all_gather(
1351
+ quantized_preconditioners, batch_axis_name
1352
+ )
1353
+ quantized_diagonals = jax.lax.all_gather(
1354
+ quantized_diagonals, batch_axis_name
1355
+ )
1356
+ quantized_bucket_sizes = jax.lax.all_gather(
1357
+ quantized_bucket_sizes, batch_axis_name
1358
+ )
1359
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1360
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1361
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1362
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1363
+ errors_flat = unbatch(errors)
1364
+ return (
1365
+ quantized_preconditioners_flat,
1366
+ quantized_diagonals_flat,
1367
+ quantized_bucket_sizes_flat,
1368
+ errors_flat,
1369
+ )
1370
+
1371
+ if preconditioning_compute_steps == 1:
1372
+ (
1373
+ quantized_preconditioners_flat,
1374
+ quantized_diagonals_flat,
1375
+ quantized_bucket_sizes_flat,
1376
+ errors_flat,
1377
+ ) = _internal_inverse_pth_root_all()
1378
+ else:
1379
+ # Passing statistics instead of preconditioners as they are similarly
1380
+ # shaped tensors. Note statistics will be ignored as we are passing in
1381
+ # a large init value for error.
1382
+ quantized_preconditioners_init = packed_quantized_statistics
1383
+ quantized_diagonals_init = packed_quantized_diagonals
1384
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1385
+ errors_init = [inverse_failure_threshold] * len(
1386
+ quantized_preconditioners_init
1387
+ )
1388
+ init_state = [
1389
+ quantized_preconditioners_init,
1390
+ quantized_diagonals_init,
1391
+ quantized_bucket_sizes_init,
1392
+ errors_init,
1393
+ ]
1394
+ perform_step = step % preconditioning_compute_steps == 0
1395
+ (
1396
+ quantized_preconditioners_flat,
1397
+ quantized_diagonals_flat,
1398
+ quantized_bucket_sizes_flat,
1399
+ errors_flat,
1400
+ ) = efficient_cond(perform_step, _internal_inverse_pth_root_all, init_state)
1401
+
1402
+ def _skip(error):
1403
+ condition = jnp.logical_or(
1404
+ jnp.isnan(error), error >= inverse_failure_threshold
1405
+ )
1406
+ return condition.astype(error.dtype)
1407
+
1408
+ def _select_preconditioner(error, new_p, old_p):
1409
+ return lax.cond(
1410
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1411
+ )
1412
+
1413
+ new_quantized_preconditioners_flat = []
1414
+ new_quantized_diagonals_flat = []
1415
+ new_quantized_bucket_sizes_flat = []
1416
+ for p, d, b, shape, prev_p, error in zip(
1417
+ quantized_preconditioners_flat,
1418
+ quantized_diagonals_flat,
1419
+ quantized_bucket_sizes_flat,
1420
+ original_shapes,
1421
+ prev_preconditioners,
1422
+ errors_flat,
1423
+ ):
1424
+ new_quantized_preconditioners_flat.append(
1425
+ _select_preconditioner(
1426
+ error, p[: shape[0], : shape[1]], prev_p.quantized
1427
+ )
1428
+ )
1429
+ new_quantized_diagonals_flat.append(
1430
+ _select_preconditioner(error, d[: shape[0]], prev_p.diagonal)
1431
+ )
1432
+ new_quantized_bucket_sizes_flat.append(
1433
+ _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
1434
+ )
1435
+
1436
+ assert len(states) == len(num_statistics_per_state)
1437
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1438
+ assert len(new_quantized_diagonals_flat) == num_statistics
1439
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1440
+
1441
+ # Add back empty preconditioners so that we can set the optimizer state.
1442
+ preconditioners_for_states = []
1443
+ idx = 0
1444
+ for num_statistics, state in zip(num_statistics_per_state, states):
1445
+ if num_statistics == 0:
1446
+ preconditioners_for_states.append([])
1447
+ else:
1448
+ quantized_preconditioners_for_state = (
1449
+ new_quantized_preconditioners_flat[idx : idx + num_statistics]
1450
+ )
1451
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1452
+ idx : idx + num_statistics
1453
+ ]
1454
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1455
+ idx : idx + num_statistics
1456
+ ]
1457
+
1458
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1459
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1460
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1461
+
1462
+ quantized_preconditioners = []
1463
+ for qv, qd, qb in zip(
1464
+ quantized_preconditioners_for_state,
1465
+ quantized_diagonals_for_state,
1466
+ quantized_bucket_sizes_for_state,
1467
+ ):
1468
+ quantized_preconditioners.append(
1469
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
1470
+ )
1471
+ preconditioners_for_states.append(quantized_preconditioners)
1472
+ idx += num_statistics
1473
+ new_states = []
1474
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1475
+ new_states.append(
1476
+ ParameterStats(
1477
+ state.diagonal_statistics,
1478
+ state.statistics,
1479
+ new_preconditioners,
1480
+ state.diagonal_momentum,
1481
+ state.momentum,
1482
+ )
1483
+ )
1484
+
1485
+ return new_states
1486
+
1487
+ def _pjit_compute_preconditioners(
1488
+ states,
1489
+ step,
1490
+ statistics,
1491
+ num_statistics_per_state,
1492
+ original_shapes,
1493
+ exponents,
1494
+ max_size,
1495
+ prev_preconditioners,
1496
+ ):
1497
+ """Computes preconditioners for given statistics in states in PJIT mode.
1498
+
1499
+ Args:
1500
+ states: A list of optimizer states.
1501
+ step: Current step number
1502
+ statistics: A list of statistics for all variables (for every dim)
1503
+ num_statistics_per_state: Number of statistics per state to reconstruct
1504
+ output states.
1505
+ original_shapes: A list of shapes of the statistics.
1506
+ exponents: Exponent power to use for inverse-pth roots.
1507
+ max_size: Maximum dim of the statistics to pad.
1508
+ prev_preconditioners: Previously available preconditioner.
1509
+
1510
+ Returns:
1511
+ New optimizer states after computing the preconditioner.
1512
+ """
1513
+ num_statistics = len(statistics)
1514
+ to_pad = -num_statistics % num_devices_for_pjit
1515
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1516
+ padded_statistics.extend(
1517
+ [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
1518
+ )
1519
+ exponents.extend([1 for _ in range(to_pad)])
1520
+ all_statistics = jnp.stack(padded_statistics)
1521
+ all_exponents = jnp.stack(exponents)
1522
+
1523
+ def _internal_inverse_pth_root_all():
1524
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1525
+ all_statistics, all_exponents
1526
+ )
1527
+ b1 = preconditioners.shape[0]
1528
+
1529
+ def split(batched_values):
1530
+ return [
1531
+ jnp.squeeze(v)
1532
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1533
+ ]
1534
+
1535
+ return split(preconditioners), split(errors)
1536
+
1537
+ if preconditioning_compute_steps == 1:
1538
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1539
+ else:
1540
+ # Passing statistics instead of preconditioners as they are similarly
1541
+ # shaped tensors. Note statistics will be ignored as we are passing in
1542
+ # a large init value for error.
1543
+ preconditioners_init = padded_statistics
1544
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1545
+ init_state = [preconditioners_init, errors_init]
1546
+ perform_step = step % preconditioning_compute_steps == 0
1547
+ preconditioners_flat, errors_flat = efficient_cond(
1548
+ perform_step, _internal_inverse_pth_root_all, init_state
1549
+ )
1550
+
1551
+ def _skip(error):
1552
+ condition = jnp.logical_or(
1553
+ jnp.isnan(error), error >= inverse_failure_threshold
1554
+ )
1555
+ return condition.astype(error.dtype)
1556
+
1557
+ def _select_preconditioner(error, new_p, old_p):
1558
+ return lax.cond(
1559
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
1560
+ )
1561
+
1562
+ new_preconditioners_flat = []
1563
+ for p, shape, prev_p, error in zip(
1564
+ preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1565
+ ):
1566
+ new_preconditioners_flat.append(
1567
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1568
+ )
1569
+
1570
+ assert len(states) == len(num_statistics_per_state)
1571
+ assert len(new_preconditioners_flat) == num_statistics
1572
+
1573
+ # Add back empty preconditioners so that we can set the optimizer state.
1574
+ preconditioners_for_states = []
1575
+ idx = 0
1576
+ for num_statistics, state in zip(num_statistics_per_state, states):
1577
+ if num_statistics == 0:
1578
+ preconditioners_for_states.append([])
1579
+ else:
1580
+ preconditioners_for_state = new_preconditioners_flat[
1581
+ idx : idx + num_statistics
1582
+ ]
1583
+ assert len(state.statistics) == len(preconditioners_for_state)
1584
+ preconditioners_for_states.append(preconditioners_for_state)
1585
+ idx += num_statistics
1586
+ new_states = []
1587
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1588
+ new_states.append(
1589
+ ParameterStats(
1590
+ state.diagonal_statistics,
1591
+ state.statistics,
1592
+ new_preconditioners,
1593
+ state.diagonal_momentum,
1594
+ state.momentum,
1595
+ )
1596
+ )
1597
+
1598
+ return new_states
1599
+
1600
+ def _compute_preconditioners(states, params, step):
1601
+ """Computes preconditioners for given statistics in states.
1602
+
1603
+ Args:
1604
+ states: A list of optimizer states.
1605
+ params: A list of params.
1606
+ step: Current step number
1607
+
1608
+ Returns:
1609
+ New optimizer states after computing the preconditioner.
1610
+ """
1611
+ statistics = []
1612
+ num_statistics_per_state = []
1613
+ original_shapes = []
1614
+ exponents = []
1615
+ max_size = 0
1616
+ prev_preconditioners = []
1617
+
1618
+ for state, param in zip(states, params):
1619
+ num_statistics = len(state.statistics)
1620
+ num_statistics_per_state.append(num_statistics)
1621
+ original_shapes_for_state = []
1622
+ if num_statistics > 0:
1623
+ preconditioner = Preconditioner(
1624
+ param, block_size, best_effort_shape_interpretation
1625
+ )
1626
+ for statistic in state.statistics:
1627
+ exponents.append(
1628
+ preconditioner.exponent_for_preconditioner()
1629
+ if exponent_override == 0
1630
+ else exponent_override
1631
+ )
1632
+ original_shapes_for_state.append(statistic.shape)
1633
+ max_size = max(max_size, statistic.shape[0])
1634
+
1635
+ statistics.extend(state.statistics)
1636
+ prev_preconditioners.extend(state.preconditioners)
1637
+ original_shapes.extend(original_shapes_for_state)
1638
+
1639
+ if batch_axis_name:
1640
+ # Quantization is only enabled when batch_axis_name is set (pmap mode).
1641
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1642
+
1643
+ if quantized_dtype == jnp.float32:
1644
+ return _pmap_compute_preconditioners(
1645
+ states,
1646
+ step,
1647
+ statistics,
1648
+ num_statistics_per_state,
1649
+ original_shapes,
1650
+ exponents,
1651
+ max_size,
1652
+ prev_preconditioners,
1653
+ )
1654
+ else:
1655
+ return _pmap_quantized_compute_preconditioners(
1656
+ states,
1657
+ step,
1658
+ statistics,
1659
+ num_statistics_per_state,
1660
+ original_shapes,
1661
+ exponents,
1662
+ max_size,
1663
+ prev_preconditioners,
1664
+ )
1665
+
1666
+ else:
1667
+ return _pjit_compute_preconditioners(
1668
+ states,
1669
+ step,
1670
+ statistics,
1671
+ num_statistics_per_state,
1672
+ original_shapes,
1673
+ exponents,
1674
+ max_size,
1675
+ prev_preconditioners,
1676
+ )
1677
+
1678
+ def _transform_grad(grad, state, param, step):
1679
+ """Transform per-parameter gradients."""
1680
+ preconditioner = Preconditioner(
1681
+ param, block_size, best_effort_shape_interpretation
1682
+ )
1683
+ sgd_update = grad
1684
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
1685
+ if graft_type == GraftingType.ADAGRAD:
1686
+ new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
1687
+ grad
1688
+ )
1689
+ adagrad_update = grad / (
1690
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
1691
+ )
1692
+ grafting_update = adagrad_update
1693
+ elif (
1694
+ graft_type == GraftingType.RMSPROP
1695
+ or graft_type == GraftingType.RMSPROP_NORMALIZED
1696
+ ):
1697
+
1698
+ scaled_grad = grad
1699
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
1700
+ scaled_grad = grad / jnp.linalg.norm(grad)
1701
+
1702
+ w1 = beta2
1703
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1704
+
1705
+ new_diagonal_statistics = (
1706
+ w1 * state.diagonal_statistics.to_float() + w2 * jnp.square(scaled_grad)
1707
+ )
1708
+ rmsprop_update = scaled_grad / (
1709
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
1710
+ )
1711
+
1712
+ if clip_by_scaled_gradient_norm:
1713
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
1714
+ jnp.sqrt(float(rmsprop_update.size))
1715
+ )
1716
+ clipping_denom = jnp.maximum(
1717
+ 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
1718
+ )
1719
+ rmsprop_update /= clipping_denom
1720
+
1721
+ grafting_update = rmsprop_update
1722
+ else:
1723
+ grafting_update = sgd_update
1724
+
1725
+ precond_grad = grad
1726
+ if not _skip_preconditioning(param):
1727
+ precond_grad = preconditioner.preconditioned_grad(
1728
+ precond_grad, _maybe_dequantize_preconditioners(state.preconditioners)
1729
+ )
1730
+ else:
1731
+ precond_grad = grafting_update
1732
+
1733
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
1734
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
1735
+
1736
+ multiplier = grafting_update_norm / (precond_grad_norm + 1e-16)
1737
+ shampoo_update = precond_grad * multiplier
1738
+
1739
+ shampoo_update_with_wd = shampoo_update
1740
+ grafting_update_with_wd = grafting_update
1741
+ if weight_decay != 0:
1742
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
1743
+ grafting_update_with_wd = grafting_update + weight_decay * param
1744
+
1745
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1746
+ shampoo_update_with_wd_momentum = (
1747
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
1748
+ )
1749
+ grafting_update_with_wd_momentum = (
1750
+ state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
1751
+ )
1752
+
1753
+ run_shampoo = (step >= start_preconditioning_step).astype(
1754
+ grafting_update_with_wd_momentum.dtype
1755
+ )
1756
+
1757
+ momentum_update = (
1758
+ run_shampoo * shampoo_update_with_wd_momentum
1759
+ + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
1760
+ )
1761
+
1762
+ wd_update = (
1763
+ run_shampoo * shampoo_update_with_wd
1764
+ + (1.0 - run_shampoo) * grafting_update_with_wd
1765
+ )
1766
+
1767
+ if nesterov:
1768
+ momentum_update = w * wd_update + beta1 * momentum_update
1769
+
1770
+ lr = learning_rate
1771
+ if callable(learning_rate):
1772
+ lr = learning_rate(step)
1773
+ transformed_update = -1.0 * lr * momentum_update
1774
+
1775
+ param_stats = ParameterStats(
1776
+ _quantize_diagonal_statistics(new_diagonal_statistics),
1777
+ state.statistics,
1778
+ state.preconditioners,
1779
+ _quantize_momentum(grafting_update_with_wd_momentum),
1780
+ _quantize_momentum(shampoo_update_with_wd_momentum),
1781
+ )
1782
+ return transformed_update, param_stats
1783
+
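To make the grafting step in `_transform_grad` concrete (editorial sketch, not part of the patch): the preconditioned gradient supplies only the update direction, while its per-layer magnitude is borrowed from the grafted update.

```python
precond_grad = jnp.asarray([[3.0, 0.0], [0.0, 4.0]])
grafting_update = jnp.asarray([[0.1, 0.2], [0.2, 0.0]])

multiplier = jnp.linalg.norm(grafting_update) / (jnp.linalg.norm(precond_grad) + 1e-16)
shampoo_update = precond_grad * multiplier
# Same norm as the grafted update, same direction as the preconditioned gradient.
assert jnp.allclose(jnp.linalg.norm(shampoo_update), jnp.linalg.norm(grafting_update))
```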
1784
+ def update_fn(grads, state, params):
1785
+ """Transform the input gradient and update all statistics.
1786
+
1787
+ Args:
1788
+ grads: the gradient tensors for the parameters.
1789
+ state: a named tuple containing the state of the optimizer
1790
+ params: the parameters that should be updated.
1791
+
1792
+ Returns:
1793
+ A tuple containing the new parameters and the new optimizer state.
1794
+ """
1795
+ params_flat, treedef = jax.tree_flatten(params)
1796
+ stats_flat = treedef.flatten_up_to(state.stats)
1797
+ grads_flat = treedef.flatten_up_to(grads)
1798
+
1799
+ new_stats_flat = jax.tree_multimap(
1800
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
1801
+ grads_flat,
1802
+ stats_flat,
1803
+ params_flat,
1804
+ )
1805
+ new_stats_flat = _compute_preconditioners(
1806
+ new_stats_flat, params_flat, state.count
1807
+ )
1808
+
1809
+ outputs = jax.tree_multimap(
1810
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
1811
+ grads_flat,
1812
+ new_stats_flat,
1813
+ params_flat,
1814
+ )
1815
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1816
+
1817
+ updates = jax.tree_unflatten(treedef, updates_flat)
1818
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
1819
+
1820
+ new_state = ShampooState(count=state.count + 1, stats=new_stats)
1821
+ return updates, new_state
1822
+
1823
+ if shard_optimizer_states:
1824
+ return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
1825
+ else:
1826
+ return optax.GradientTransformation(init_fn, update_fn)
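As a usage sketch (editorial; the parameter names and shapes below are illustrative, not taken from this repository): the returned `optax.GradientTransformation` is constructed once and then driven from a pmapped train step whose axis name matches `batch_axis_name`.

```python
import optax

optimizer = distributed_shampoo(
    learning_rate=1e-3,
    block_size=128,
    graft_type=GraftingType.RMSPROP_NORMALIZED,
    weight_decay=1e-2,
    batch_axis_name="batch",  # must match the jax.pmap axis_name
)

params = {"w": jnp.zeros((256, 128)), "b": jnp.zeros((128,))}
opt_state = optimizer.init(params)

def train_step(params, opt_state, grads):
    # Intended to run under jax.pmap(..., axis_name="batch").
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state
```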
tools/train/sweep.yaml CHANGED
@@ -11,44 +11,39 @@ parameters:
11
  # from exp(min) to exp(max)
12
  min: -6.9
13
  max: -3.5
14
  gradient_accumulation_steps:
15
- value: 8
16
  warmup_steps:
17
  value: 4000
18
- #TODO: outdated command
19
  command:
20
  - python3
21
  - ${program}
22
- - "--tokenizer_name"
23
- - "boris/dalle-mini-tokenizer"
24
- - "--config_name"
25
- - "facebook/bart-large-cnn"
26
- - "--dataset_repo_or_path"
27
- - "boris/gis_vqgan_f16_16384"
28
  - "--streaming"
29
- - "--use_auth_token"
30
- - "--image_vocab_size"
31
- - 16384
32
- - "--image_length"
33
- - 256
34
- - "--normalize_text"
35
- - True
36
- - "--per_device_train_batch_size"
37
- - 56
38
- - "--per_device_eval_batch_size"
39
- - 56
40
- - "--adafactor"
41
- - "--do_train"
42
- - "--do_eval"
43
- - "--num_train_epochs"
44
- - 1
45
- - "--logging_steps"
46
- - 40
47
- - "--eval_steps"
48
- - 800
49
  - "--output_dir"
50
  - "./output"
51
  - "--overwrite_output_dir"
52
- - "--max_train_samples"
53
- - 10000000
 
54
  - ${args}
 
11
  # from exp(min) to exp(max)
12
  min: -6.9
13
  max: -3.5
14
+ tokenizer_name:
15
+ value: boris/dalle-mini-tokenizer
16
+ config_name:
17
+ value: ./config/mini
18
+ dtype:
19
+ value: bfloat16
20
+ dataset_repo_or_path:
21
+ value: ./data
22
+ per_device_train_batch_size:
23
+ value: 64
24
+ per_device_eval_batch_size:
25
+ value: 64
26
  gradient_accumulation_steps:
27
+ value: 1
28
  warmup_steps:
29
  value: 4000
30
+ num_train_epochs:
31
+ value: 1
32
+ logging_steps:
33
+ value: 32
34
+ eval_steps:
35
+ value: 800
36
+ max_train_samples:
37
+ value: 1000000
38
+
39
  command:
40
  - python3
41
  - ${program}
42
  - "--streaming"
43
  - "--output_dir"
44
  - "./output"
45
  - "--overwrite_output_dir"
46
+ - "--adafactor"
47
+ - "--do_train"
48
+ - "--do_eval"
49
  - ${args}
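
In the updated sweep file, fixed settings move into `parameters` entries with a single `value`, and W&B expands `${args}` into `--key=value` flags appended after the explicit flags. A rough sketch of the resulting command, assuming the agent's default `--key=value` formatting and using `tools/train/train.py` as a stand-in for `${program}`:

```
# Hypothetical illustration of how the fixed sweep parameters become CLI flags.
params = {
    "tokenizer_name": "boris/dalle-mini-tokenizer",
    "config_name": "./config/mini",
    "dtype": "bfloat16",
    "dataset_repo_or_path": "./data",
    "per_device_train_batch_size": 64,
    "per_device_eval_batch_size": 64,
    "gradient_accumulation_steps": 1,
    "warmup_steps": 4000,
    "num_train_epochs": 1,
    "logging_steps": 32,
    "eval_steps": 800,
    "max_train_samples": 1000000,
}
# sampled parameters such as learning_rate are appended the same way
args = [f"--{k}={v}" for k, v in params.items()]
command = ["python3", "tools/train/train.py", "--streaming",
           "--output_dir", "./output", "--overwrite_output_dir",
           "--adafactor", "--do_train", "--do_eval", *args]
print(" ".join(command))
```
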
tools/train/train.py CHANGED
@@ -34,6 +34,7 @@ import optax
34
  import transformers
35
  import wandb
36
  from datasets import Dataset
 
37
  from flax import jax_utils, traverse_util
38
  from flax.jax_utils import unreplicate
39
  from flax.serialization import from_bytes, to_bytes
@@ -41,10 +42,9 @@ from flax.training import train_state
41
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
42
  from tqdm import tqdm
43
  from transformers import AutoTokenizer, HfArgumentParser
44
- from transformers.models.bart.modeling_flax_bart import BartConfig
45
 
46
  from dalle_mini.data import Dataset
47
- from dalle_mini.model import CustomFlaxBartForConditionalGeneration
48
 
49
  logger = logging.getLogger(__name__)
50
 
@@ -68,26 +68,12 @@ class ModelArguments:
68
  "help": "Pretrained config name or path if not the same as model_name"
69
  },
70
  )
71
- image_vocab_size: Optional[int] = field(
72
- default=None,
73
- metadata={"help": "Vocab size of image encoder"},
74
- )
75
- image_length: Optional[int] = field(
76
- default=None,
77
- metadata={"help": "Number of tokens per image"},
78
- )
79
  tokenizer_name: Optional[str] = field(
80
  default=None,
81
  metadata={
82
  "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
83
  },
84
  )
85
- normalize_text: Optional[bool] = field(
86
- default=None,
87
- metadata={
88
- "help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
89
- },
90
- )
91
  dtype: Optional[str] = field(
92
  default="float32",
93
  metadata={
@@ -126,26 +112,21 @@ class DataTrainingArguments:
126
  default=None,
127
  metadata={"help": "An optional input evaluation data file (glob acceptable)."},
128
  )
129
- dataset_type: str = field(
130
- default="datasets",
131
- metadata={"help": "Either 🤗 'dataset' (default) or 'webdataset'."},
132
- )
133
  # data loading should not be a bottleneck so we use "streaming" mode by default
134
- streaming: bool = field(
135
  default=True,
136
  metadata={"help": "Whether to stream the dataset."},
137
  )
138
- use_auth_token: bool = field(
139
  default=False,
140
  metadata={
141
  "help": "Whether to use the authentication token for private datasets."
142
  },
143
  )
144
- max_source_length: Optional[int] = field(
145
- default=128,
146
  metadata={
147
- "help": "The maximum total input sequence length after tokenization. Sequences longer "
148
- "than this will be truncated, sequences shorter will be padded."
149
  },
150
  )
151
  max_train_samples: Optional[int] = field(
@@ -232,7 +213,11 @@ class TrainingArguments:
232
  )
233
  adafactor: bool = field(
234
  default=False,
235
- metadata={"help": "Whether or not to replace AdamW by Adafactor."},
236
  )
237
  weight_decay: float = field(
238
  default=None, metadata={"help": "Weight decay if we apply some."}
@@ -351,14 +336,39 @@ def create_learning_rate_fn(
351
  return schedule_fn
352
 
353
 
354
- def wandb_log(metrics, step=None, prefix=None):
355
- if jax.process_index() == 0:
356
- log_metrics = {
357
- f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
358
  }
359
- if step is not None:
360
- log_metrics["train/step"] = step
361
- wandb.log(log_metrics)
362
 
363
 
364
  def main():
@@ -411,20 +421,29 @@ def main():
411
  do_eval=training_args.do_eval,
412
  )
413
414
  # Set up wandb run
415
- wandb.init(
416
- entity="dalle-mini",
417
- project="dalle-mini",
418
- job_type="Seq2Seq",
419
- config=parser.parse_args(),
420
- )
 
421
 
422
  if training_args.resume_from_checkpoint is not None:
423
- artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
424
  artifact_dir = artifact.download()
425
 
426
  # load model
427
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
 
 
428
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
429
  print(model.params)
430
 
@@ -436,56 +455,24 @@ def main():
436
 
437
  else:
438
  # Set up our new model config
439
- # TODO: simplify with custom config class
440
  if model_args.config_name:
441
- config = BartConfig.from_pretrained(model_args.config_name)
442
- else:
443
- config = BartConfig.from_pretrained(model_args.model_name_or_path)
444
- if model_args.image_vocab_size:
445
- config.image_vocab_size = model_args.image_vocab_size
446
- assert (
447
- getattr(config, "image_vocab_size") is not None
448
- ), "image_vocab_size must be specified when not present in base model/config"
449
- if model_args.image_length:
450
- config.image_length = model_args.image_length
451
- assert (
452
- getattr(config, "image_length") is not None
453
- ), "image_length must be specified when not present in base model/config"
454
- # we append decoder bos to image vocab
455
- config.decoder_start_token_id = config.image_vocab_size
456
- # ensure we don't generate bos (in addition to decoder start token)
457
- config.force_bos_token_to_be_generated = False
458
- config.forced_bos_token_id = None # we don't need this token
459
- config.forced_eos_token_id = None # we don't need this token
460
-
461
- config.tie_word_embeddings = False
462
- config.min_length = config.image_length + 1
463
- config.max_length = config.image_length + 1
464
-
465
- # below tokens need to be set to avoid error during generation (converted to jnp.array)
466
- # they are not expected to be used and are set to unreachable token id
467
- config.bos_token_id = config.image_vocab_size + 1
468
- config.pos_token_id = config.image_vocab_size + 1
469
- config.eos_token_id = config.image_vocab_size + 1
470
-
471
- # save whether we normalize the text
472
- if model_args.normalize_text is not None:
473
- config.normalize_text = model_args.normalize_text
474
  else:
475
- config.normalize_text = getattr(config, "normalize_text", False)
476
 
477
  # Load or create new model
478
  if model_args.model_name_or_path:
479
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(
480
  model_args.model_name_or_path,
481
  config=config,
482
  seed=training_args.seed_model,
483
  dtype=getattr(jnp, model_args.dtype),
 
484
  )
485
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
486
  print(model.params)
487
  else:
488
- model = CustomFlaxBartForConditionalGeneration(
489
  config,
490
  seed=training_args.seed_model,
491
  dtype=getattr(jnp, model_args.dtype),
@@ -502,9 +489,6 @@ def main():
502
  use_fast=True,
503
  )
504
 
505
- logger.info(f"TPUs: {jax.device_count()}")
506
- assert jax.device_count() == 8, "TPUs in use, please check running processes"
507
-
508
  # Preprocessing the datasets.
509
  # We need to normalize and tokenize inputs and targets.
510
 
@@ -512,6 +496,7 @@ def main():
512
  tokenizer=tokenizer,
513
  decoder_start_token_id=model.config.decoder_start_token_id,
514
  normalize_text=model.config.normalize_text,
 
515
  )
516
 
517
  # Initialize our training
@@ -520,18 +505,28 @@ def main():
520
 
521
  # Store some constant
522
  num_epochs = int(training_args.num_train_epochs)
 
523
  train_batch_size = (
524
- int(training_args.per_device_train_batch_size) * jax.device_count()
525
  )
526
- batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
527
- eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
528
  len_train_dataset, len_eval_dataset = dataset.length
529
  steps_per_epoch = (
530
- len_train_dataset // train_batch_size if len_train_dataset is not None else None
 
 
531
  )
532
  num_train_steps = (
533
  steps_per_epoch * num_epochs if steps_per_epoch is not None else None
534
  )
 
535
 
536
  # Create learning rate schedule
537
  learning_rate_fn = create_learning_rate_fn(
@@ -572,13 +567,43 @@ def main():
572
  weight_decay_mask=decay_mask_fn,
573
  clipping_threshold=training_args.max_grad_norm,
574
  )
575
  else:
576
  optimizer = optax.adamw(
577
  learning_rate=learning_rate_fn,
578
  b1=training_args.adam_beta1,
579
  b2=training_args.adam_beta2,
580
  eps=training_args.adam_epsilon,
581
- weight_decay=training_args.weight_decay,
 
 
582
  mask=decay_mask_fn,
583
  )
584
 
@@ -625,7 +650,7 @@ def main():
625
  grads=grads,
626
  dropout_rng=new_dropout_rng,
627
  train_time=state.train_time + delta_time,
628
- train_samples=state.train_samples + train_batch_size,
629
  )
630
 
631
  metrics = {
@@ -657,25 +682,30 @@ def main():
657
  logger.info(
658
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
659
  )
 
660
  logger.info(
661
  f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
662
  )
 
663
  epochs = tqdm(
664
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
665
  )
666
 
667
- # set default x-axis as 'train/step'
668
- wandb_log({}, step=state.step)
669
- wandb.define_metric("*", step_metric="train/step")
670
-
671
- # add interesting config parameters
672
- wandb.config.update(
673
- {
674
- "len_train_dataset": len_train_dataset,
675
- "len_eval_dataset": len_eval_dataset,
676
- "batch_size_per_update": batch_size_per_update,
677
- }
678
- )
679
 
680
  # replicate state on each device
681
  state = state.replicate()
@@ -706,7 +736,9 @@ def main():
706
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
707
 
708
  # log metrics
709
- wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
 
 
710
 
711
  # Print metrics and update progress bar
712
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -743,51 +775,61 @@ def main():
743
  f,
744
  )
745
 
746
- # save to W&B
747
- if training_args.log_model:
748
- # save some space
749
- c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
750
- c.cleanup(wandb.util.from_human_size("10GB"))
751
-
752
- metadata = dict(state_dict)
753
- if eval_metrics is not None:
754
- metadata["eval"] = eval_metrics
755
- artifact = wandb.Artifact(
756
- name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
757
- )
758
- artifact.add_file(
759
- str(Path(training_args.output_dir) / "flax_model.msgpack")
760
- )
761
- artifact.add_file(str(Path(training_args.output_dir) / "config.json"))
762
- artifact.add_file(
763
- str(Path(training_args.output_dir) / "tokenizer.json")
764
- )
765
- artifact.add_file(
766
- str(Path(training_args.output_dir) / "tokenizer_config.json")
767
- )
768
- artifact.add_file(str(Path(training_args.output_dir) / "vocab.json"))
769
- artifact.add_file(str(Path(training_args.output_dir) / "merges.txt"))
770
- artifact.add_file(
771
- str(Path(training_args.output_dir) / "special_tokens_map.json")
772
- )
773
- artifact.add_file(
774
- str(Path(training_args.output_dir) / "opt_state.msgpack")
775
- )
776
- artifact.add_file(
777
- str(Path(training_args.output_dir) / "training_state.json")
778
- )
779
-
780
- wandb.run.log_artifact(artifact)
781
-
782
- # save to the hub
783
- if training_args.push_to_hub:
784
- model.save_pretrained(
785
- training_args.output_dir,
786
- params=params,
787
- push_to_hub=training_args.push_to_hub,
788
- commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
789
- temp_dir=True, # avoid issues with being in a repository
790
- )
791
 
792
  # init variables
793
  last_time = time.perf_counter()
@@ -796,7 +838,7 @@ def main():
796
  for epoch in epochs:
797
  state.replace(epoch=jax_utils.replicate(epoch))
798
  # ======================== Training ================================
799
- wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
800
 
801
  # Generate an epoch by shuffling sampling indices from the train dataset
802
  train_loader = dataset.dataloader("train", train_batch_size)
@@ -821,14 +863,8 @@ def main():
821
  step = unreplicate(state.step)
822
 
823
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
824
- # log metrics
825
- metrics = unreplicate(train_metrics)
826
- # log state parameters
827
- state_dict = {
828
- k.split("_")[-1]: unreplicate(getattr(state, k))
829
- for k in ["epoch", "train_time", "train_samples"]
830
- }
831
- wandb_log({**metrics, **state_dict}, step=step, prefix="train")
832
 
833
  eval_metrics = None
834
  if training_args.eval_steps and step % training_args.eval_steps == 0:
@@ -839,8 +875,8 @@ def main():
839
 
840
  # log final train metrics
841
  if train_metrics is not None:
842
- train_metrics = unreplicate(train_metrics)
843
- wandb_log(train_metrics, step=step, prefix="train")
844
 
845
  epochs.write(
846
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
 
34
  import transformers
35
  import wandb
36
  from datasets import Dataset
37
+ from distributed_shampoo import GraftingType, distributed_shampoo
38
  from flax import jax_utils, traverse_util
39
  from flax.jax_utils import unreplicate
40
  from flax.serialization import from_bytes, to_bytes
 
42
  from flax.training.common_utils import get_metrics, onehot, shard_prng_key
43
  from tqdm import tqdm
44
  from transformers import AutoTokenizer, HfArgumentParser
 
45
 
46
  from dalle_mini.data import Dataset
47
+ from dalle_mini.model import DalleBart, DalleBartConfig
48
 
49
  logger = logging.getLogger(__name__)
50
 
 
68
  "help": "Pretrained config name or path if not the same as model_name"
69
  },
70
  )
71
  tokenizer_name: Optional[str] = field(
72
  default=None,
73
  metadata={
74
  "help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
75
  },
76
  )
77
  dtype: Optional[str] = field(
78
  default="float32",
79
  metadata={
 
112
  default=None,
113
  metadata={"help": "An optional input evaluation data file (glob acceptable)."},
114
  )
115
  # data loading should not be a bottleneck so we use "streaming" mode by default
116
+ streaming: Optional[bool] = field(
117
  default=True,
118
  metadata={"help": "Whether to stream the dataset."},
119
  )
120
+ use_auth_token: Optional[bool] = field(
121
  default=False,
122
  metadata={
123
  "help": "Whether to use the authentication token for private datasets."
124
  },
125
  )
126
+ shard_by_host: Optional[bool] = field(
127
+ default=False,
128
  metadata={
129
+ "help": "Whether to shard data files by host in multi-host environments."
 
130
  },
131
  )
132
  max_train_samples: Optional[int] = field(
 
213
  )
214
  adafactor: bool = field(
215
  default=False,
216
+ metadata={"help": "Use Adafactor instead of AdamW."},
217
+ )
218
+ distributed_shampoo: bool = field(
219
+ default=False,
220
+ metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
221
  )
222
  weight_decay: float = field(
223
  default=None, metadata={"help": "Weight decay if we apply some."}
 
336
  return schedule_fn
337
 
338
 
339
+ class MetricsLogger:
340
+ def __init__(self, state):
341
+ self.step = state.step
342
+ self.time = time.perf_counter()
343
+
344
+ def get_all_train_metrics(self, train_metrics, state):
345
+ """Make a dict of training metrics to be logged"""
346
+ metrics = unreplicate(train_metrics)
347
+ # get state parameters
348
+ state_dict = {
349
+ k.split("_")[-1]: unreplicate(getattr(state, k))
350
+ for k in ["epoch", "train_time", "train_samples"]
351
  }
352
+ # timing metrics
353
+ new_step = int(unreplicate(state.step))
354
+ new_time = time.perf_counter()
355
+ if new_step > self.step:
356
+ time_per_step = (new_time - self.time) / (new_step - self.step)
357
+ self.step = new_step
358
+ self.time = new_time
359
+ state_dict["time_per_step"] = time_per_step
360
+ return {**metrics, **state_dict}
361
+
362
+ @staticmethod
363
+ def log(metrics, step=None, prefix=None):
364
+ if jax.process_index() == 0:
365
+ log_metrics = {
366
+ f"{prefix}/{k}" if prefix is not None else k: v
367
+ for k, v in metrics.items()
368
+ }
369
+ if step is not None:
370
+ log_metrics["train/step"] = step
371
+ wandb.log(log_metrics)
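
The `MetricsLogger` above replaces the old `wandb_log` helper and additionally reports a `time_per_step` metric derived from wall-clock time between logging calls. A tiny runnable sketch of that bookkeeping, with made-up step numbers:

```
import time

# Hypothetical illustration of the time_per_step computation in MetricsLogger.
prev_step, prev_time = 100, time.perf_counter()
time.sleep(0.05)                 # stand-in for a few training steps
new_step, new_time = 104, time.perf_counter()
if new_step > prev_step:
    time_per_step = (new_time - prev_time) / (new_step - prev_step)
    print(f"time_per_step ~ {time_per_step:.4f}s")
```
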
372
 
373
 
374
  def main():
 
421
  do_eval=training_args.do_eval,
422
  )
423
 
424
+ logger.info(f"Local TPUs: {jax.local_device_count()}")
425
+ assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
426
+
427
  # Set up wandb run
428
+ if jax.process_index() == 0:
429
+ wandb.init(
430
+ entity="dalle-mini",
431
+ project="dalle-mini",
432
+ job_type="Seq2Seq",
433
+ config=parser.parse_args(),
434
+ )
435
 
436
  if training_args.resume_from_checkpoint is not None:
437
+ if jax.process_index() == 0:
438
+ artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
439
+ else:
440
+ artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
441
  artifact_dir = artifact.download()
442
 
443
  # load model
444
+ model = DalleBart.from_pretrained(
445
+ artifact_dir, dtype=getattr(jnp, model_args.dtype), abstract_init=True
446
+ )
447
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
448
  print(model.params)
449
 
 
455
 
456
  else:
457
  # Set up our new model config
 
458
  if model_args.config_name:
459
+ config = DalleBartConfig.from_pretrained(model_args.config_name)
460
  else:
461
+ config = DalleBartConfig.from_pretrained(model_args.model_name_or_path)
462
 
463
  # Load or create new model
464
  if model_args.model_name_or_path:
465
+ model = DalleBart.from_pretrained(
466
  model_args.model_name_or_path,
467
  config=config,
468
  seed=training_args.seed_model,
469
  dtype=getattr(jnp, model_args.dtype),
470
+ abstract_init=True,
471
  )
472
  # avoid OOM on TPU: see https://github.com/google/flax/issues/1658
473
  print(model.params)
474
  else:
475
+ model = DalleBart(
476
  config,
477
  seed=training_args.seed_model,
478
  dtype=getattr(jnp, model_args.dtype),
 
489
  use_fast=True,
490
  )
491
492
  # Preprocessing the datasets.
493
  # We need to normalize and tokenize inputs and targets.
494
 
 
496
  tokenizer=tokenizer,
497
  decoder_start_token_id=model.config.decoder_start_token_id,
498
  normalize_text=model.config.normalize_text,
499
+ max_length=model.config.max_text_length,
500
  )
501
 
502
  # Initialize our training
 
505
 
506
  # Store some constant
507
  num_epochs = int(training_args.num_train_epochs)
508
+ # batch size per node
509
  train_batch_size = (
510
+ int(training_args.per_device_train_batch_size) * jax.local_device_count()
511
+ )
512
+ batch_size_per_update = (
513
+ train_batch_size
514
+ * training_args.gradient_accumulation_steps
515
+ * jax.process_count()
516
+ )
517
+ eval_batch_size = (
518
+ int(training_args.per_device_eval_batch_size) * jax.local_device_count()
519
  )
 
 
520
  len_train_dataset, len_eval_dataset = dataset.length
521
  steps_per_epoch = (
522
+ len_train_dataset // (train_batch_size * jax.process_count())
523
+ if len_train_dataset is not None
524
+ else None
525
  )
526
  num_train_steps = (
527
  steps_per_epoch * num_epochs if steps_per_epoch is not None else None
528
  )
529
+ num_params = model.num_params
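
To make the new per-host accounting concrete, here is the arithmetic under the sweep defaults above (per-device batch 64, no gradient accumulation), assuming a single TPU v3-8 host; the device and process counts are assumptions, not values from the diff:

```
per_device_train_batch_size = 64
gradient_accumulation_steps = 1
local_device_count = 8    # assumed: one TPU v3-8 host
process_count = 1         # assumed: single-host run

train_batch_size = per_device_train_batch_size * local_device_count        # 512 per host
batch_size_per_update = (
    train_batch_size * gradient_accumulation_steps * process_count         # 512 samples per optimizer step
)

len_train_dataset = 1_000_000  # max_train_samples from the sweep file
steps_per_epoch = len_train_dataset // (train_batch_size * process_count)  # 1953
print(train_batch_size, batch_size_per_update, steps_per_epoch)
```
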
530
 
531
  # Create learning rate schedule
532
  learning_rate_fn = create_learning_rate_fn(
 
567
  weight_decay_mask=decay_mask_fn,
568
  clipping_threshold=training_args.max_grad_norm,
569
  )
570
+ elif training_args.distributed_shampoo:
571
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
572
+ # Notes:
573
+ # - a weight decay mask is not implemented, but it is not needed here since weight_decay is set to 0.0
574
+ optimizer = distributed_shampoo(
575
+ learning_rate_fn,
576
+ block_size=1024, # recommended default for large LM is 1536
577
+ beta1=0.9,
578
+ beta2=0.999,
579
+ diagonal_epsilon=1e-10,
580
+ matrix_epsilon=1e-8,
581
+ weight_decay=0.0,
582
+ start_preconditioning_step=1001,
583
+ preconditioning_compute_steps=10,
584
+ statistics_compute_steps=1,
585
+ best_effort_shape_interpretation=True,
586
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
587
+ nesterov=False,
588
+ exponent_override=0,
589
+ batch_axis_name="batch",
590
+ inverse_failure_threshold=0.1,
591
+ moving_average_for_momentum=True,
592
+ skip_preconditioning_dim_size_gt=4096,
593
+ clip_by_scaled_gradient_norm=None,
594
+ precision=jax.lax.Precision.HIGHEST,
595
+ best_effort_memory_usage_reduction=False,
596
+ )
597
+
598
  else:
599
  optimizer = optax.adamw(
600
  learning_rate=learning_rate_fn,
601
  b1=training_args.adam_beta1,
602
  b2=training_args.adam_beta2,
603
  eps=training_args.adam_epsilon,
604
+ weight_decay=training_args.weight_decay
605
+ if training_args.weight_decay is not None
606
+ else 0.0,
607
  mask=decay_mask_fn,
608
  )
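
Whichever branch is selected, the result is a standard `optax.GradientTransformation`, so the rest of the training loop is unchanged (the Shampoo branch is stepped inside the pmapped `train_step`, where its `batch_axis_name="batch"` axis is bound). A self-contained sketch of the interface, using the AdamW branch and toy parameters:

```
import jax
import jax.numpy as jnp
import optax

# Toy parameter tree and gradients standing in for the model pytrees.
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros((3,))}
grads = jax.tree_map(lambda p: 0.01 * jnp.ones_like(p), params)

optimizer = optax.adamw(learning_rate=1e-4, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0)

opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```
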
609
 
 
650
  grads=grads,
651
  dropout_rng=new_dropout_rng,
652
  train_time=state.train_time + delta_time,
653
+ train_samples=state.train_samples + train_batch_size * jax.process_count(),
654
  )
655
 
656
  metrics = {
 
682
  logger.info(
683
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
684
  )
685
+ logger.info(f" Number of devices = {jax.device_count()}")
686
  logger.info(
687
  f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
688
  )
689
+ logger.info(f" Model parameters = {num_params:,}")
690
  epochs = tqdm(
691
  range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
692
  )
693
 
694
+ metrics_logger = MetricsLogger(state)
695
+ if jax.process_index() == 0:
696
+ # set default x-axis as 'train/step'
697
+ metrics_logger.log({}, step=state.step)
698
+ wandb.define_metric("*", step_metric="train/step")
699
+
700
+ # add interesting config parameters
701
+ wandb.config.update(
702
+ {
703
+ "len_train_dataset": len_train_dataset,
704
+ "len_eval_dataset": len_eval_dataset,
705
+ "batch_size_per_update": batch_size_per_update,
706
+ "num_params": num_params,
707
+ }
708
+ )
709
 
710
  # replicate state on each device
711
  state = state.replicate()
 
736
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
737
 
738
  # log metrics
739
+ metrics_logger.log(
740
+ eval_metrics, step=unreplicate(state.step), prefix="eval"
741
+ )
742
 
743
  # Print metrics and update progress bar
744
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
775
  f,
776
  )
777
 
778
+ if jax.process_index() == 0:
779
+ # save to W&B
780
+ if training_args.log_model:
781
+ # save some space
782
+ c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
783
+ c.cleanup(wandb.util.from_human_size("10GB"))
784
+
785
+ metadata = dict(state_dict)
786
+ metadata["num_params"] = num_params
787
+ if eval_metrics is not None:
788
+ metadata["eval"] = eval_metrics
789
+ artifact = wandb.Artifact(
790
+ name=f"model-{wandb.run.id}",
791
+ type="bart_model",
792
+ metadata=metadata,
793
+ )
794
+ artifact.add_file(
795
+ str(Path(training_args.output_dir) / "flax_model.msgpack")
796
+ )
797
+ artifact.add_file(
798
+ str(Path(training_args.output_dir) / "config.json")
799
+ )
800
+ artifact.add_file(
801
+ str(Path(training_args.output_dir) / "tokenizer.json")
802
+ )
803
+ artifact.add_file(
804
+ str(Path(training_args.output_dir) / "tokenizer_config.json")
805
+ )
806
+ artifact.add_file(
807
+ str(Path(training_args.output_dir) / "vocab.json")
808
+ )
809
+ artifact.add_file(
810
+ str(Path(training_args.output_dir) / "merges.txt")
811
+ )
812
+ artifact.add_file(
813
+ str(Path(training_args.output_dir) / "special_tokens_map.json")
814
+ )
815
+ artifact.add_file(
816
+ str(Path(training_args.output_dir) / "opt_state.msgpack")
817
+ )
818
+ artifact.add_file(
819
+ str(Path(training_args.output_dir) / "training_state.json")
820
+ )
821
+
822
+ wandb.run.log_artifact(artifact)
823
+
824
+ # save to the hub
825
+ if training_args.push_to_hub:
826
+ model.save_pretrained(
827
+ training_args.output_dir,
828
+ params=params,
829
+ push_to_hub=training_args.push_to_hub,
830
+ commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
831
+ temp_dir=True, # avoid issues with being in a repository
832
+ )
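
A checkpoint logged this way can later be pulled back and loaded the same way the `resume_from_checkpoint` path earlier in this diff does; a short sketch, where the artifact reference is a placeholder:

```
import wandb
from dalle_mini.model import DalleBart

# Placeholder reference of the form "<entity>/<project>/model-<run_id>:<alias>"
artifact = wandb.Api().artifact("dalle-mini/dalle-mini/model-RUN_ID:latest")
artifact_dir = artifact.download()
model = DalleBart.from_pretrained(artifact_dir)
```
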
833
 
834
  # init variables
835
  last_time = time.perf_counter()
 
838
  for epoch in epochs:
839
  state.replace(epoch=jax_utils.replicate(epoch))
840
  # ======================== Training ================================
841
+ metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
842
 
843
  # Generate an epoch by shuffling sampling indices from the train dataset
844
  train_loader = dataset.dataloader("train", train_batch_size)
 
863
  step = unreplicate(state.step)
864
 
865
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
866
+ all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
867
+ metrics_logger.log(all_metrics, step=step, prefix="train")
868
 
869
  eval_metrics = None
870
  if training_args.eval_steps and step % training_args.eval_steps == 0:
 
875
 
876
  # log final train metrics
877
  if train_metrics is not None:
878
+ all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
879
+ metrics_logger.log(all_metrics, step=step, prefix="train")
880
 
881
  epochs.write(
882
  f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"