boris committed on
Commit 803ccbf
1 Parent(s): 2e02683

feat: support pod (#139)

src/dalle_mini/data.py CHANGED
@@ -27,6 +27,7 @@ class Dataset:
27
  do_eval: bool = True
28
  seed_dataset: int = None
29
  shard_by_host: bool = False
 
30
  train_dataset: Dataset = field(init=False)
31
  eval_dataset: Dataset = field(init=False)
32
  rng_dataset: jnp.ndarray = field(init=False)
@@ -34,6 +35,11 @@ class Dataset:
34
 
35
  def __post_init__(self):
36
  self.multi_hosts = jax.process_count() > 1
37
  # define data_files
38
  if self.train_file is not None or self.validation_file is not None:
39
  # accept braceexpand notation
@@ -101,6 +107,25 @@ class Dataset:
101
  self.seed_dataset = np.random.get_state()[1][0]
102
  self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
103
 
104
  # normalize text
105
  if normalize_text:
106
  text_normalizer = TextNormalizer()
@@ -144,6 +169,10 @@ class Dataset:
144
  getattr(self, ds).map(
145
  partial_preprocess_function,
146
  batched=True,
147
  )
148
  if self.streaming
149
  else getattr(self, ds).map(
@@ -193,8 +222,8 @@ class Dataset:
193
  while (self.multi_hosts and split == "train") or first_loop:
194
  # in multi-host, we run forever (no epoch) as hosts need to stop
195
  # at the same time and training data may not be split equally
196
- # For validation data we put the entire set on each host as we could lose
197
- # too many samples on pods
198
  if epoch is not None:
199
  assert split == "train"
200
  # reshuffle training data at each epoch
@@ -252,6 +281,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
252
  return shifted_input_ids
253
 
254
 
255
  def normalize_function(example, text_column, text_normalizer):
256
  example[text_column] = text_normalizer(example[text_column])
257
  return example
 
27
  do_eval: bool = True
28
  seed_dataset: int = None
29
  shard_by_host: bool = False
30
+ blank_caption_prob: float = 0.0
31
  train_dataset: Dataset = field(init=False)
32
  eval_dataset: Dataset = field(init=False)
33
  rng_dataset: jnp.ndarray = field(init=False)
 
35
 
36
  def __post_init__(self):
37
  self.multi_hosts = jax.process_count() > 1
38
+ # feed blank captions only in streaming mode for now
39
+ if self.blank_caption_prob:
40
+ assert (
41
+ self.streaming is True
42
+ ), "blank_caption_prob can only be used in streaming mode"
43
  # define data_files
44
  if self.train_file is not None or self.validation_file is not None:
45
  # accept braceexpand notation
 
107
  self.seed_dataset = np.random.get_state()[1][0]
108
  self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
109
 
110
+ # blank captions
111
+ if self.blank_caption_prob:
112
+ partial_blank_caption_function = partial(
113
+ blank_caption_function,
114
+ text_column=self.text_column,
115
+ blank_caption_prob=self.blank_caption_prob,
116
+ )
117
+ if hasattr(self, "train_dataset"):
118
+ self.train_dataset = (
119
+ self.train_dataset.map(partial_blank_caption_function)
120
+ if self.streaming
121
+ else self.train_dataset.map(
122
+ partial_blank_caption_function,
123
+ num_proc=self.preprocessing_num_workers,
124
+ load_from_cache_file=False,
125
+ desc="Blanking some captions",
126
+ )
127
+ )
128
+
129
  # normalize text
130
  if normalize_text:
131
  text_normalizer = TextNormalizer()
 
169
  getattr(self, ds).map(
170
  partial_preprocess_function,
171
  batched=True,
172
+ remove_columns=[
173
+ self.text_column,
174
+ self.encoding_column,
175
+ ],
176
  )
177
  if self.streaming
178
  else getattr(self, ds).map(
 
222
  while (self.multi_hosts and split == "train") or first_loop:
223
  # in multi-host, we run forever (no epoch) as hosts need to stop
224
  # at the same time and training data may not be split equally
225
+ # For validation data we put the entire batch on each host and then
226
+ # keep only the one specific to each host (could be improved but not necessary)
227
  if epoch is not None:
228
  assert split == "train"
229
  # reshuffle training data at each epoch
 
281
  return shifted_input_ids
282
 
283
 
284
+ def blank_caption_function(example, text_column, blank_caption_prob):
285
+ if blank_caption_prob and np.random.rand() < blank_caption_prob:
286
+ example[text_column] = ""
287
+ return example
288
+
289
+
290
  def normalize_function(example, text_column, text_normalizer):
291
  example[text_column] = text_normalizer(example[text_column])
292
  return example
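
Note: the blank-caption augmentation added above replaces a caption with an empty string at probability `blank_caption_prob`, so the model also sees unconditioned examples. A minimal standalone sketch of the same idea (plain Python/NumPy stand-in for the `datasets.map` call; the toy data and probability are illustrative, not from the commit):

```python
from functools import partial

import numpy as np


def blank_caption_function(example, text_column, blank_caption_prob):
    # with probability blank_caption_prob, drop the caption entirely
    if blank_caption_prob and np.random.rand() < blank_caption_prob:
        example[text_column] = ""
    return example


fn = partial(blank_caption_function, text_column="caption", blank_caption_prob=0.2)
examples = [{"caption": f"sample {i}"} for i in range(10)]
examples = [fn(e) for e in examples]  # stands in for train_dataset.map(fn)
print(sum(e["caption"] == "" for e in examples), "captions blanked")
```
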
src/dalle_mini/model/modeling.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and the DalleBart team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -328,6 +328,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
328
  dtype: jnp.dtype = jnp.float32,
329
  abstract_init: bool = False,
330
  load_on_cpu: bool = False,
 
331
  **kwargs,
332
  ):
333
  module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -347,25 +348,34 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
347
  self.key = PRNGKey(seed)
348
  self.dtype = dtype
349
 
350
- # init weights on CPU
351
- if load_on_cpu:
352
- # init weights on CPU
353
- init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
354
- else:
355
- init_fn = self.init_weights
356
 
357
- # randomly initialized parameters
358
- random_params = self.init_weights(self.key, input_shape)
359
  if abstract_init:
360
  # only set shape and dtype, load parameters separately
361
  init_fn = partial(init_fn, input_shape=input_shape)
362
- random_params = jax.eval_shape(init_fn, self.key)
363
  else:
364
- random_params = init_fn(self.key, input_shape)
365
-
366
- # save required_params as set
367
- self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
368
- self.params = random_params
369
 
370
  @property
371
  def num_params(self):
 
1
  # coding=utf-8
2
+ # Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
328
  dtype: jnp.dtype = jnp.float32,
329
  abstract_init: bool = False,
330
  load_on_cpu: bool = False,
331
+ init_weights: bool = True,
332
  **kwargs,
333
  ):
334
  module = self.module_class(config=config, dtype=dtype, **kwargs)
 
348
  self.key = PRNGKey(seed)
349
  self.dtype = dtype
350
 
351
+ if init_weights:
352
+ # get shape of params only
353
+ random_params = self.init_weights(
354
+ self.key,
355
+ input_shape,
356
+ abstract_init=abstract_init,
357
+ load_on_cpu=load_on_cpu,
358
+ )
359
+
360
+ # save required_params as set
361
+ self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
362
+ self.params = random_params
363
 
364
+ def init_weights(
365
+ self, rng=None, input_shape=(1, 1), abstract_init=False, load_on_cpu=False
366
+ ):
367
+ if rng is None:
368
+ rng = self.key
369
+ init_fn = super().init_weights
370
+ if load_on_cpu:
371
+ init_fn = jax.jit(init_fn, static_argnums=(1,), backend="cpu")
372
  if abstract_init:
373
  # only set shape and dtype, load parameters separately
374
  init_fn = partial(init_fn, input_shape=input_shape)
375
+ params = jax.eval_shape(init_fn, rng)
376
  else:
377
+ params = init_fn(rng, input_shape)
378
+ return params
379
 
380
  @property
381
  def num_params(self):
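
Note: the refactor above routes weight creation through an overridable `init_weights`, so `abstract_init` only evaluates parameter shapes while `load_on_cpu` jit-compiles the init on the CPU backend. A toy sketch of those two paths, assuming a stand-in init function rather than the DalleBart module (the `backend="cpu"` argument mirrors the diff):

```python
from functools import partial

import jax
import jax.numpy as jnp


def init_weights(rng, input_shape):
    # stand-in for a Flax module init: returns a small params pytree
    k1, k2 = jax.random.split(rng)
    return {
        "dense": {
            "kernel": jax.random.normal(k1, (input_shape[-1], 8)),
            "bias": jnp.zeros((8,)),
        },
        "embed": jax.random.normal(k2, (16, input_shape[-1])),
    }


rng = jax.random.PRNGKey(0)

# abstract init: shapes/dtypes only, no device memory allocated
shapes = jax.eval_shape(partial(init_weights, input_shape=(1, 4)), rng)
print(jax.tree_util.tree_map(lambda x: x.shape, shapes))

# concrete init forced onto the CPU backend
cpu_init = jax.jit(init_weights, static_argnums=(1,), backend="cpu")
params = cpu_init(rng, (1, 4))
```
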
src/dalle_mini/model/utils.py CHANGED
@@ -23,12 +23,6 @@ class PretrainedFromWandbMixin:
23
  else:
24
  artifact = wandb.Api().artifact(pretrained_model_name_or_path)
25
  pretrained_model_name_or_path = artifact.download(tmp_dir)
26
- if artifact.metadata.get("bucket_path"):
27
- pretrained_model_name_or_path = artifact.metadata["bucket_path"]
28
-
29
- if pretrained_model_name_or_path.startswith("gs://"):
30
- copy_blobs(pretrained_model_name_or_path, tmp_dir)
31
- pretrained_model_name_or_path = tmp_dir
32
 
33
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
34
  pretrained_model_name_or_path, *model_args, **kwargs
 
23
  else:
24
  artifact = wandb.Api().artifact(pretrained_model_name_or_path)
25
  pretrained_model_name_or_path = artifact.download(tmp_dir)
26
 
27
  return super(PretrainedFromWandbMixin, cls).from_pretrained(
28
  pretrained_model_name_or_path, *model_args, **kwargs
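
Note: after this simplification, a wandb artifact reference is always materialized to a local directory via `artifact.download` and handed to the parent `from_pretrained`; the `bucket_path` / `gs://` branch is gone. A hedged sketch of the resulting flow (the artifact name below is hypothetical):

```python
import tempfile

import wandb

artifact_ref = "dalle-mini/dalle-mini/model-example:latest"  # hypothetical name

with tempfile.TemporaryDirectory() as tmp_dir:
    artifact = wandb.Api().artifact(artifact_ref)
    local_path = artifact.download(tmp_dir)
    # model = DalleBart.from_pretrained(local_path)  # as in the mixin above
```
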
tools/inference/inference_pipeline.ipynb CHANGED
@@ -83,7 +83,7 @@
83
  "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
  "\n",
85
  "# CLIP model\n",
86
- "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
87
  "CLIP_COMMIT_ID = None"
88
  ]
89
  },
@@ -129,7 +129,6 @@
129
  "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
130
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
  "from transformers import CLIPProcessor, FlaxCLIPModel\n",
132
- "import wandb\n",
133
  "\n",
134
  "# Load dalle-mini\n",
135
  "model = DalleBart.from_pretrained(\n",
@@ -168,9 +167,9 @@
168
  "if dtype == jnp.bfloat16:\n",
169
  " model.params = model.to_bf16(model.params)\n",
170
  "\n",
171
- "model_params = replicate(model.params)\n",
172
- "vqgan_params = replicate(vqgan.params)\n",
173
- "clip_params = replicate(clip.params)"
174
  ]
175
  },
176
  {
@@ -292,7 +291,7 @@
292
  },
293
  "outputs": [],
294
  "source": [
295
- "prompt = \"a blue table\""
296
  ]
297
  },
298
  {
@@ -414,12 +413,12 @@
414
  " key, subkey = jax.random.split(key)\n",
415
  " # generate images\n",
416
  " encoded_images = p_generate(\n",
417
- " tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
418
  " )\n",
419
  " # remove BOS\n",
420
  " encoded_images = encoded_images.sequences[..., 1:]\n",
421
  " # decode images\n",
422
- " decoded_images = p_decode(encoded_images, vqgan_params)\n",
423
  " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
424
  " for img in decoded_images:\n",
425
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
@@ -453,7 +452,7 @@
453
  " max_length=77,\n",
454
  " truncation=True,\n",
455
  ").data\n",
456
- "logits = p_clip(shard(clip_inputs), clip_params)\n",
457
  "logits = logits.squeeze().flatten()"
458
  ]
459
  },
@@ -479,6 +478,13 @@
479
  " display(images[idx])\n",
480
  " print(f\"Score: {logits[idx]:.2f}\\n\")"
481
  ]
482
  }
483
  ],
484
  "metadata": {
 
83
  "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
  "\n",
85
  "# CLIP model\n",
86
+ "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
87
  "CLIP_COMMIT_ID = None"
88
  ]
89
  },
 
129
  "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
130
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131
  "from transformers import CLIPProcessor, FlaxCLIPModel\n",
 
132
  "\n",
133
  "# Load dalle-mini\n",
134
  "model = DalleBart.from_pretrained(\n",
 
167
  "if dtype == jnp.bfloat16:\n",
168
  " model.params = model.to_bf16(model.params)\n",
169
  "\n",
170
+ "model._params = replicate(model.params)\n",
171
+ "vqgan._params = replicate(vqgan.params)\n",
172
+ "clip._params = replicate(clip.params)"
173
  ]
174
  },
175
  {
 
291
  },
292
  "outputs": [],
293
  "source": [
294
+ "prompt = \"view of the beach during sunset\""
295
  ]
296
  },
297
  {
 
413
  " key, subkey = jax.random.split(key)\n",
414
  " # generate images\n",
415
  " encoded_images = p_generate(\n",
416
+ " tokenized_prompt, shard_prng_key(subkey), model.params, gen_top_k, gen_top_p\n",
417
  " )\n",
418
  " # remove BOS\n",
419
  " encoded_images = encoded_images.sequences[..., 1:]\n",
420
  " # decode images\n",
421
+ " decoded_images = p_decode(encoded_images, vqgan.params)\n",
422
  " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
423
  " for img in decoded_images:\n",
424
  " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
 
452
  " max_length=77,\n",
453
  " truncation=True,\n",
454
  ").data\n",
455
+ "logits = p_clip(shard(clip_inputs), clip.params)\n",
456
  "logits = logits.squeeze().flatten()"
457
  ]
458
  },
 
478
  " display(images[idx])\n",
479
  " print(f\"Score: {logits[idx]:.2f}\\n\")"
480
  ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": null,
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": []
488
  }
489
  ],
490
  "metadata": {
tools/train/config/medium/config.json CHANGED
@@ -28,6 +28,5 @@
28
  "pad_token_id": 16385,
29
  "scale_embedding": false,
30
  "tie_word_embeddings": false,
31
- "transformers_version": "4.13.0.dev0",
32
  "use_cache": true
33
  }
 
28
  "pad_token_id": 16385,
29
  "scale_embedding": false,
30
  "tie_word_embeddings": false,
 
31
  "use_cache": true
32
  }
tools/train/config/mega/config.json CHANGED
@@ -5,21 +5,20 @@
5
  "bos_token_id": 16385,
6
  "classifier_dropout": 0.0,
7
  "d_model": 2048,
8
- "decoder_attention_heads": 16,
9
- "decoder_ffn_dim": 4096,
10
  "decoder_layerdrop": 0.0,
11
- "decoder_layers": 31,
12
  "decoder_start_token_id": 16384,
13
- "dropout": 0.1,
14
- "encoder_attention_heads": 16,
15
- "encoder_ffn_dim": 4096,
16
  "encoder_layerdrop": 0.0,
17
- "encoder_layers": 31,
18
  "encoder_vocab_size": 50264,
19
  "eos_token_id": 16385,
20
- "gradient_checkpointing": false,
21
  "image_length": 256,
22
- "image_vocab_size": 16384,
23
  "init_std": 0.01,
24
  "is_encoder_decoder": true,
25
  "max_text_length": 64,
@@ -28,6 +27,5 @@
28
  "pad_token_id": 16385,
29
  "scale_embedding": false,
30
  "tie_word_embeddings": false,
31
- "transformers_version": "4.13.0.dev0",
32
  "use_cache": true
33
  }
 
5
  "bos_token_id": 16385,
6
  "classifier_dropout": 0.0,
7
  "d_model": 2048,
8
+ "decoder_attention_heads": 32,
9
+ "decoder_ffn_dim": 8192,
10
  "decoder_layerdrop": 0.0,
11
+ "decoder_layers": 24,
12
  "decoder_start_token_id": 16384,
13
+ "dropout": 0.0,
14
+ "encoder_attention_heads": 32,
15
+ "encoder_ffn_dim": 8192,
16
  "encoder_layerdrop": 0.0,
17
+ "encoder_layers": 24,
18
  "encoder_vocab_size": 50264,
19
  "eos_token_id": 16385,
 
20
  "image_length": 256,
21
+ "image_vocab_size": 16391,
22
  "init_std": 0.01,
23
  "is_encoder_decoder": true,
24
  "max_text_length": 64,
 
27
  "pad_token_id": 16385,
28
  "scale_embedding": false,
29
  "tie_word_embeddings": false,
 
30
  "use_cache": true
31
  }
tools/train/config/micro/config.json CHANGED
@@ -4,22 +4,21 @@
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
6
  "classifier_dropout": 0.0,
7
- "d_model": 1024,
8
- "decoder_attention_heads": 16,
9
- "decoder_ffn_dim": 2048,
10
  "decoder_layerdrop": 0.0,
11
  "decoder_layers": 2,
12
  "decoder_start_token_id": 16384,
13
  "dropout": 0.0,
14
- "encoder_attention_heads": 16,
15
- "encoder_ffn_dim": 2048,
16
  "encoder_layerdrop": 0.0,
17
  "encoder_layers": 2,
18
  "encoder_vocab_size": 50264,
19
  "eos_token_id": 16385,
20
- "gradient_checkpointing": false,
21
  "image_length": 256,
22
- "image_vocab_size": 16384,
23
  "init_std": 0.02,
24
  "is_encoder_decoder": true,
25
  "max_text_length": 64,
@@ -28,6 +27,5 @@
28
  "pad_token_id": 16385,
29
  "scale_embedding": false,
30
  "tie_word_embeddings": false,
31
- "transformers_version": "4.13.0.dev0",
32
  "use_cache": true
33
  }
 
4
  "attention_dropout": 0.0,
5
  "bos_token_id": 16385,
6
  "classifier_dropout": 0.0,
7
+ "d_model": 256,
8
+ "decoder_attention_heads": 2,
9
+ "decoder_ffn_dim": 256,
10
  "decoder_layerdrop": 0.0,
11
  "decoder_layers": 2,
12
  "decoder_start_token_id": 16384,
13
  "dropout": 0.0,
14
+ "encoder_attention_heads": 2,
15
+ "encoder_ffn_dim": 256,
16
  "encoder_layerdrop": 0.0,
17
  "encoder_layers": 2,
18
  "encoder_vocab_size": 50264,
19
  "eos_token_id": 16385,
 
20
  "image_length": 256,
21
+ "image_vocab_size": 16391,
22
  "init_std": 0.02,
23
  "is_encoder_decoder": true,
24
  "max_text_length": 64,
 
27
  "pad_token_id": 16385,
28
  "scale_embedding": false,
29
  "tie_word_embeddings": false,
 
30
  "use_cache": true
31
  }
tools/train/config/mini/config.json CHANGED
@@ -28,6 +28,5 @@
28
  "pad_token_id": 16385,
29
  "scale_embedding": false,
30
  "tie_word_embeddings": false,
31
- "transformers_version": "4.13.0.dev0",
32
  "use_cache": true
33
  }
 
28
  "pad_token_id": 16385,
29
  "scale_embedding": false,
30
  "tie_word_embeddings": false,
 
31
  "use_cache": true
32
  }
tools/train/scalable_shampoo/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # Notes
2
+
3
+ Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax).
4
+
5
+ Imports have been modified to be relative.
6
+
7
+ This will be replaced with `optax-shampoo` package eventually.
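
Note: until the `optax-shampoo` package replaces this copy, the optimizer is built from the vendored module. A hedged construction sketch; the import path assumes running from `tools/train`, and the schedule and keyword values are illustrative (they mirror arguments visible in `train.py`, not exact settings):

```python
import optax

# vendored import, as used by tools/train/train.py after this change
from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo

learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=5e-3, warmup_steps=100, decay_steps=10_000
)

optimizer = distributed_shampoo(
    learning_rate_fn,
    block_size=1024,
    beta1=0.9,
    beta2=0.99,
    diagonal_epsilon=1e-10,
    preconditioning_compute_steps=10,
    statistics_compute_steps=1,
    best_effort_shape_interpretation=True,
    graft_type=GraftingType.RMSPROP_NORMALIZED,
    nesterov=False,
    exponent_override=0,
)
```
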
tools/train/{distributed_shampoo.py → scalable_shampoo/distributed_shampoo.py} RENAMED
@@ -1,5 +1,3 @@
1
- # file from: https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
2
-
3
  # coding=utf-8
4
  # Copyright 2022 The Google Research Authors.
5
  #
@@ -44,107 +42,12 @@ import optax
44
  from flax import struct
45
  from jax import lax
46
 
 
47
 
48
- # pylint:disable=no-value-for-parameter
49
- @struct.dataclass
50
- class QuantizedValue:
51
- """State associated with quantized value."""
52
-
53
- quantized: chex.Array
54
- diagonal: chex.Array # Diagonal (if extract_diagonal is set)
55
- bucket_size: chex.Array
56
- quantized_dtype: jnp.dtype = struct.field(
57
- pytree_node=False
58
- ) # Dtype for the quantized value.
59
- extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
60
- shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
61
-
62
- @classmethod
63
- def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
64
- if isinstance(fvalue, list) and not fvalue:
65
- return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
66
- quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
67
- fvalue, quantized_dtype, extract_diagonal
68
- )
69
- return QuantizedValue(
70
- quantized,
71
- diagonal_fvalue,
72
- bucket_size,
73
- quantized_dtype,
74
- extract_diagonal,
75
- list(quantized.shape),
76
- )
77
-
78
- # Quantization is from Lingvo JAX optimizers.
79
- # We extend it for int16 quantization of PSD matrices.
80
- @classmethod
81
- def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
82
- """Returns quantized value and the bucket."""
83
- if quantized_dtype == jnp.float32:
84
- return fvalue, [], []
85
- elif quantized_dtype == jnp.bfloat16:
86
- return fvalue.astype(jnp.bfloat16), [], []
87
-
88
- float_dtype = fvalue.dtype
89
- if quantized_dtype == jnp.int8:
90
- # value -128 is not used.
91
- num_buckets = jnp.array(127.0, dtype=float_dtype)
92
- elif quantized_dtype == jnp.int16:
93
- # value -32768 is not used.
94
- num_buckets = jnp.array(32767.0, dtype=float_dtype)
95
- else:
96
- raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
97
- # max value is mapped to num_buckets
98
-
99
- if extract_diagonal and fvalue.ndim != 2:
100
- raise ValueError(
101
- f"Input array {fvalue} must be 2D to work with extract_diagonal."
102
- )
103
-
104
- diagonal_fvalue = []
105
- if extract_diagonal:
106
- diagonal_fvalue = jnp.diag(fvalue)
107
- # Remove the diagonal entries.
108
- fvalue = fvalue - jnp.diag(diagonal_fvalue)
109
-
110
- # TODO(rohananil): Extend this by making use of information about the blocks
111
- # SM3 style which will be useful for diagonal statistics
112
- # We first decide the scale.
113
- if fvalue.ndim < 1:
114
- raise ValueError(
115
- f"Input array {fvalue} must have a strictly positive number of "
116
- "dimensions."
117
- )
118
-
119
- max_abs = jnp.max(jnp.abs(fvalue), axis=0)
120
- bucket_size = max_abs / num_buckets
121
- bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
122
- # To avoid divide by 0.0
123
- bs_nonzero = jnp.where(
124
- bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
125
- )
126
- ratio = fvalue / bs_nonzero
127
- # We use rounding to remove bias.
128
- quantized = jnp.round(ratio)
129
- return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
130
-
131
- def to_float(self):
132
- """Returns the float value."""
133
- if isinstance(self.quantized, list) and not self.quantized:
134
- return self.quantized
135
-
136
- if self.quantized_dtype == jnp.float32:
137
- return self.quantized
138
-
139
- if self.quantized_dtype == jnp.bfloat16:
140
- return self.quantized.astype(jnp.float32)
141
-
142
- float_dtype = self.bucket_size.dtype
143
- bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
144
- val = self.quantized.astype(float_dtype) * bucket_size
145
- if self.extract_diagonal:
146
- val += jnp.diag(self.diagonal)
147
- return val
148
 
149
 
150
  @struct.dataclass
@@ -193,24 +96,21 @@ class LocalShardedParameterStats:
193
 
194
 
195
  def init_training_metrics(num_statistics):
196
- if num_statistics:
197
- return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
198
- else:
199
- return TrainingMetrics([])
200
 
201
 
202
  def init_training_metrics_shapes(num_statistics):
203
- if num_statistics:
204
- return TrainingMetrics([[num_statistics], jnp.float32])
205
- else:
206
- return TrainingMetrics([None, jnp.float32])
207
 
208
 
209
- def init_training_metrics_pspec(num_statistics):
210
- if num_statistics:
211
- return TrainingMetrics(pjit.PartitionSpec())
212
- else:
213
- return TrainingMetrics(None)
214
 
215
 
216
  class ShardedShampooStats(NamedTuple):
@@ -296,6 +196,30 @@ def power_iteration(
296
  return v_out, s_out
297
 
298
 
299
  def matrix_inverse_pth_root(
300
  matrix,
301
  p,
@@ -332,57 +256,19 @@ def matrix_inverse_pth_root(
332
 
333
  assert matrix.shape[0] == matrix.shape[1]
334
 
335
- # We use float32 for the matrix inverse pth root.
336
- # Switch to f64 if you have hardware that supports it.
 
337
  matrix_size = matrix.shape[0]
338
- alpha = jnp.asarray(-1.0 / p, jnp.float32)
339
- identity = jnp.eye(matrix_size, dtype=jnp.float32)
 
 
340
  _, max_ev = power_iteration(
341
  matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
342
  )
343
  ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
344
 
345
- def _unrolled_mat_pow_1(mat_m):
346
- """Computes mat_m^1."""
347
- return mat_m
348
-
349
- def _unrolled_mat_pow_2(mat_m):
350
- """Computes mat_m^2."""
351
- return jnp.matmul(mat_m, mat_m, precision=precision)
352
-
353
- def _unrolled_mat_pow_4(mat_m):
354
- """Computes mat_m^4."""
355
- mat_pow_2 = _unrolled_mat_pow_2(mat_m)
356
- return jnp.matmul(mat_pow_2, mat_pow_2, precision=precision)
357
-
358
- def _unrolled_mat_pow_8(mat_m):
359
- """Computes mat_m^4."""
360
- mat_pow_4 = _unrolled_mat_pow_4(mat_m)
361
- return jnp.matmul(mat_pow_4, mat_pow_4, precision=precision)
362
-
363
- def mat_power(mat_m, p):
364
- """Computes mat_m^p, for p == 1, 2, 4 or 8.
365
-
366
- Args:
367
- mat_m: a square matrix
368
- p: a positive integer
369
-
370
- Returns:
371
- mat_m^p
372
- """
373
- # We unrolled the loop for performance reasons.
374
- exponent = jnp.round(jnp.log2(p))
375
- return lax.switch(
376
- jnp.asarray(exponent, jnp.int32),
377
- [
378
- _unrolled_mat_pow_1,
379
- _unrolled_mat_pow_2,
380
- _unrolled_mat_pow_4,
381
- _unrolled_mat_pow_8,
382
- ],
383
- (mat_m),
384
- )
385
-
386
  def _iter_condition(state):
387
  (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
388
  error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
@@ -412,10 +298,10 @@ def matrix_inverse_pth_root(
412
  _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
413
  _iter_condition, _iter_body, init_state
414
  )
415
- error = jnp.max(jnp.abs(mat_m - identity))
416
  is_converged = jnp.asarray(convergence, old_mat_h.dtype)
417
  resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
418
- resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
419
  return resultant_mat_h, error
420
 
421
 
@@ -433,6 +319,9 @@ def merge_small_dims(shape_to_merge, max_dim):
433
  Returns:
434
  Merged shape.
435
  """
436
  resulting_shape = []
437
  product = 1
438
  for d in shape_to_merge:
@@ -975,16 +864,22 @@ def distributed_shampoo(
975
  )
976
 
977
  local_stats = jax.tree_unflatten(treedef, local_stats_flat)
978
  # Pad the statistics and preconditioner matrices to be a multiple of
979
  # num devices.
980
  # TODO(rohananil): Relax to only the size of the mesh axis where the dim
981
  # is split on.
982
- to_pad = -len(padded_statistics) % num_devices_for_pjit
983
  padded_statistics.extend(
984
- [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
985
  )
986
  padded_preconditioners.extend(
987
- [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
988
  )
989
  exponents.extend([1 for _ in range(to_pad)])
990
  global_stats = GlobalShardedParameterStats(
@@ -1016,7 +911,7 @@ def distributed_shampoo(
1016
  if pspec and len(pspec) > 1:
1017
  return pjit.PartitionSpec(*pspec[1:])
1018
  else:
1019
- return None
1020
 
1021
  def sharded_init_partition_spec_fn(
1022
  params, params_partition_spec, partition_spec_for_statistics
@@ -1102,7 +997,7 @@ def distributed_shampoo(
1102
  False,
1103
  list(param.shape),
1104
  ),
1105
- init_training_metrics_pspec(len(sizes)),
1106
  index_start,
1107
  sizes,
1108
  )
@@ -1209,6 +1104,9 @@ def distributed_shampoo(
1209
  max_statistics_size = _max_statistics_size_from_params(params_flat)
1210
  to_pad = -num_statistics % num_devices_for_pjit
1211
  num_statistics += to_pad
1212
  statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
1213
  global_stats = GlobalShardedParameterStats(
1214
  [statistics_shape, jnp.float32],
@@ -2069,7 +1967,7 @@ def distributed_shampoo(
2069
 
2070
  scaled_grad = grad
2071
  if graft_type == GraftingType.ADAGRAD_NORMALIZED:
2072
- scaled_grad = grad / jnp.linalg.norm(grad)
2073
 
2074
  new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
2075
  scaled_grad
@@ -2085,7 +1983,7 @@ def distributed_shampoo(
2085
 
2086
  scaled_grad = grad
2087
  if graft_type == GraftingType.RMSPROP_NORMALIZED:
2088
- scaled_grad = grad / jnp.linalg.norm(grad)
2089
 
2090
  w1 = beta2
2091
  w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
@@ -2212,7 +2110,6 @@ def distributed_shampoo(
2212
  new_stats_flat = _compute_preconditioners(
2213
  new_stats_flat, params_flat, state.count
2214
  )
2215
-
2216
  outputs = jax.tree_multimap(
2217
  lambda g, s, p: _transform_grad(g, s, p, state.count),
2218
  grads_flat,
 
 
 
1
  # coding=utf-8
2
  # Copyright 2022 The Google Research Authors.
3
  #
 
42
  from flax import struct
43
  from jax import lax
44
 
45
+ from .quantization_utils import QuantizedValue
46
 
47
+ # Dtype for inverse-pth root routine
48
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
49
+ # jax_enable_x64 for this to work, otherwise it will default to float32.
50
+ _MAT_INV_PTH_ROOT_DTYPE = jnp.float64
51
 
52
 
53
  @struct.dataclass
 
96
 
97
 
98
  def init_training_metrics(num_statistics):
99
+ # Since the downstream apis expect a jnp.array - we create a dummy one if
100
+ # num_statistics=0.
101
+ n = 1 if not num_statistics else num_statistics
102
+ return TrainingMetrics(jnp.zeros([n], jnp.float32))
103
 
104
 
105
  def init_training_metrics_shapes(num_statistics):
106
+ # Since the downstream apis expect a jnp.array - we create a dummy one if
107
+ # num_statistics=0.
108
+ n = 1 if not num_statistics else num_statistics
109
+ return TrainingMetrics([[n], jnp.float32])
110
 
111
 
112
+ def init_training_metrics_pspec():
113
+ return TrainingMetrics(pjit.PartitionSpec())
114
 
115
 
116
  class ShardedShampooStats(NamedTuple):
 
196
  return v_out, s_out
197
 
198
 
199
+ def mat_power(mat_m, p, precision=lax.Precision.HIGHEST):
200
+ """A simple matrix power method. M^p where p can be TracedValue."""
201
+ power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
202
+
203
+ def _iter_condition(state):
204
+ i, _, _ = state
205
+ return i > 0
206
+
207
+ def _iter_body(state):
208
+ i, power, mat = state
209
+
210
+ power = jax.lax.cond(
211
+ i % 2 == 1,
212
+ lambda: jnp.matmul(mat, power, precision=precision),
213
+ lambda: power,
214
+ )
215
+ i //= 2
216
+ mat = jnp.matmul(mat, mat, precision=precision)
217
+ return i, power, mat
218
+
219
+ _, result, _ = lax.while_loop(_iter_condition, _iter_body, (p, power, mat_m))
220
+ return result
221
+
222
+
223
  def matrix_inverse_pth_root(
224
  matrix,
225
  p,
 
256
 
257
  assert matrix.shape[0] == matrix.shape[1]
258
 
259
+ # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
260
+ # Switch to f64 if you have hardware that supports it. Enable the jax flag
261
+ # jax_enable_x64 for this to work.
262
  matrix_size = matrix.shape[0]
263
+ orig_dtype = matrix.dtype
264
+ matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
265
+ alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
266
+ identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
267
  _, max_ev = power_iteration(
268
  matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
269
  )
270
  ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
271
 
272
  def _iter_condition(state):
273
  (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, run_step) = state
274
  error_above_threshold = jnp.logical_and(error > error_tolerance, run_step)
 
298
  _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
299
  _iter_condition, _iter_body, init_state
300
  )
301
+ error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
302
  is_converged = jnp.asarray(convergence, old_mat_h.dtype)
303
  resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
304
+ resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
305
  return resultant_mat_h, error
306
 
307
 
 
319
  Returns:
320
  Merged shape.
321
  """
322
+ if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
323
+ return [1]
324
+
325
  resulting_shape = []
326
  product = 1
327
  for d in shape_to_merge:
 
864
  )
865
 
866
  local_stats = jax.tree_unflatten(treedef, local_stats_flat)
867
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
868
+ if max_size == 0:
869
+ to_pad = num_devices_for_pjit
870
+ max_size = block_size
871
+ stat_dtype = jnp.float32
872
+ else:
873
+ stat_dtype = padded_statistics[0].dtype
874
  # Pad the statistics and preconditioner matrices to be a multiple of
875
  # num devices.
876
  # TODO(rohananil): Relax to only the size of the mesh axis where the dim
877
  # is split on.
 
878
  padded_statistics.extend(
879
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
880
  )
881
  padded_preconditioners.extend(
882
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
883
  )
884
  exponents.extend([1 for _ in range(to_pad)])
885
  global_stats = GlobalShardedParameterStats(
 
911
  if pspec and len(pspec) > 1:
912
  return pjit.PartitionSpec(*pspec[1:])
913
  else:
914
+ return []
915
 
916
  def sharded_init_partition_spec_fn(
917
  params, params_partition_spec, partition_spec_for_statistics
 
997
  False,
998
  list(param.shape),
999
  ),
1000
+ init_training_metrics_pspec(),
1001
  index_start,
1002
  sizes,
1003
  )
 
1104
  max_statistics_size = _max_statistics_size_from_params(params_flat)
1105
  to_pad = -num_statistics % num_devices_for_pjit
1106
  num_statistics += to_pad
1107
+ if num_statistics == 0:
1108
+ num_statistics = num_devices_for_pjit
1109
+ max_statistics_size = block_size
1110
  statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
1111
  global_stats = GlobalShardedParameterStats(
1112
  [statistics_shape, jnp.float32],
 
1967
 
1968
  scaled_grad = grad
1969
  if graft_type == GraftingType.ADAGRAD_NORMALIZED:
1970
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
1971
 
1972
  new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
1973
  scaled_grad
 
1983
 
1984
  scaled_grad = grad
1985
  if graft_type == GraftingType.RMSPROP_NORMALIZED:
1986
+ scaled_grad = grad / (jnp.linalg.norm(grad) + 1e-16)
1987
 
1988
  w1 = beta2
1989
  w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
 
2110
  new_stats_flat = _compute_preconditioners(
2111
  new_stats_flat, params_flat, state.count
2112
  )
 
2113
  outputs = jax.tree_multimap(
2114
  lambda g, s, p: _transform_grad(g, s, p, state.count),
2115
  grads_flat,
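
Note: the inverse-pth-root path above now computes in `_MAT_INV_PTH_ROOT_DTYPE = jnp.float64`, which, as the in-code comment says, only takes effect when JAX's x64 flag is enabled; otherwise arrays silently stay float32. A quick check (the flag must be set before arrays are created):

```python
import jax

jax.config.update("jax_enable_x64", True)  # or export JAX_ENABLE_X64=1

import jax.numpy as jnp

x = jnp.eye(3, dtype=jnp.float64)
print(x.dtype)  # float64 with the flag enabled, float32 otherwise
```
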
tools/train/scalable_shampoo/quantization_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Helper routines for quantization."""
17
+
18
+ from typing import Any
19
+
20
+ import chex
21
+ import jax.numpy as jnp
22
+ from flax import struct
23
+
24
+
25
+ # pylint:disable=no-value-for-parameter
26
+ @struct.dataclass
27
+ class QuantizedValue:
28
+ """State associated with quantized value."""
29
+
30
+ quantized: chex.Array
31
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
32
+ bucket_size: chex.Array
33
+ quantized_dtype: jnp.dtype = struct.field(
34
+ pytree_node=False
35
+ ) # Dtype for the quantized value.
36
+ extract_diagonal: bool = struct.field(pytree_node=False) # In case its centered.
37
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
38
+
39
+ @classmethod
40
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
41
+ if isinstance(fvalue, list) and not fvalue:
42
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
43
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
44
+ fvalue, quantized_dtype, extract_diagonal
45
+ )
46
+ return QuantizedValue(
47
+ quantized,
48
+ diagonal_fvalue,
49
+ bucket_size,
50
+ quantized_dtype,
51
+ extract_diagonal,
52
+ list(quantized.shape),
53
+ )
54
+
55
+ # Quantization is from Lingvo JAX optimizers.
56
+ # We extend it for int16 quantization of PSD matrices.
57
+ @classmethod
58
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
59
+ """Returns quantized value and the bucket."""
60
+ if quantized_dtype == jnp.float32:
61
+ return fvalue, [], []
62
+ elif quantized_dtype == jnp.bfloat16:
63
+ return fvalue.astype(jnp.bfloat16), [], []
64
+
65
+ float_dtype = fvalue.dtype
66
+ if quantized_dtype == jnp.int8:
67
+ # value -128 is not used.
68
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
69
+ elif quantized_dtype == jnp.int16:
70
+ # value -32768 is not used.
71
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
72
+ else:
73
+ raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")
74
+ # max value is mapped to num_buckets
75
+
76
+ if extract_diagonal and fvalue.ndim != 2:
77
+ raise ValueError(
78
+ f"Input array {fvalue} must be 2D to work with extract_diagonal."
79
+ )
80
+
81
+ diagonal_fvalue = []
82
+ if extract_diagonal:
83
+ diagonal_fvalue = jnp.diag(fvalue)
84
+ # Remove the diagonal entries.
85
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
86
+
87
+ # TODO(rohananil): Extend this by making use of information about the blocks
88
+ # SM3 style which will be useful for diagonal statistics
89
+ # We first decide the scale.
90
+ if fvalue.ndim < 1:
91
+ raise ValueError(
92
+ f"Input array {fvalue} must have a strictly positive number of "
93
+ "dimensions."
94
+ )
95
+
96
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
97
+ bucket_size = max_abs / num_buckets
98
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
99
+ # To avoid divide by 0.0
100
+ bs_nonzero = jnp.where(
101
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
102
+ )
103
+ ratio = fvalue / bs_nonzero
104
+ # We use rounding to remove bias.
105
+ quantized = jnp.round(ratio)
106
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
107
+
108
+ def to_float(self):
109
+ """Returns the float value."""
110
+ if isinstance(self.quantized, list) and not self.quantized:
111
+ return self.quantized
112
+
113
+ if self.quantized_dtype == jnp.float32:
114
+ return self.quantized
115
+
116
+ if self.quantized_dtype == jnp.bfloat16:
117
+ return self.quantized.astype(jnp.float32)
118
+
119
+ float_dtype = self.bucket_size.dtype
120
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
121
+ val = self.quantized.astype(float_dtype) * bucket_size
122
+ if self.extract_diagonal:
123
+ val += jnp.diag(self.diagonal)
124
+ return val
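
Note: a quick round-trip through the extracted `QuantizedValue` helper, quantizing to int8 and dequantizing back; the error should stay within about one bucket (max|x|/127 per column). The import path below is assumed, not taken from the commit:

```python
import jax.numpy as jnp
import numpy as np

from scalable_shampoo.quantization_utils import QuantizedValue  # assumed path

x = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)
q = QuantizedValue.from_float_value(x, jnp.int8)
x_hat = q.to_float()

print(q.quantized.dtype)                              # int8
print(float(np.max(np.abs(np.asarray(x - x_hat)))))   # small quantization error
```
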
tools/train/scalable_shampoo/sm3.py ADDED
@@ -0,0 +1,176 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of SM3 from:
17
+ #
18
+ # Memory-Efficient Adaptive Optimization, https://arxiv.org/pdf/1901.11150.pdf
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Yoram Singer
20
+ #
21
+ # Author: Rohan Anil (rohananil at google dot com)
22
+ #
23
+
24
+ """SM3 Implementation."""
25
+
26
+ import functools
27
+ from typing import Any, NamedTuple
28
+
29
+ import chex
30
+ import jax
31
+ import jax.numpy as jnp
32
+ import optax
33
+
34
+ from .quantization_utils import QuantizedValue
35
+
36
+
37
+ class SM3State(NamedTuple):
38
+ count: chex.Array
39
+ stats: Any
40
+
41
+
42
+ # Per parameter optimizer state used in data-parallel training.
43
+ class ParameterStats(NamedTuple):
44
+ """State associated to each parameter of the model being trained."""
45
+
46
+ diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
47
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
48
+
49
+
50
+ def sm3(
51
+ learning_rate, beta1=0.9, beta2=0.999, diagonal_epsilon=1e-10, normalize_grads=False
52
+ ):
53
+ """SM3 optimizer.
54
+
55
+ Memory-Efficient Adaptive Optimization, Rohan Anil, Vineet Gupta, Tomer Koren,
56
+ Yoram Singer
57
+
58
+ https://arxiv.org/abs/1901.11150
59
+
60
+ Args:
61
+ learning_rate: the step size used to update the parameters.
62
+ beta1: momentum parameter.
63
+ beta2: second moment averaging parameter.
64
+ diagonal_epsilon: epsilon for sm3
65
+ normalize_grads: Whether to normalize grads. Author finds it useful when
66
+ grads are high variance.
67
+
68
+ Returns:
69
+ a GradientTransformation.
70
+ """
71
+
72
+ def _quantize_momentum(momentum_statistics):
73
+ return QuantizedValue.from_float_value(momentum_statistics, jnp.int8)
74
+
75
+ def init_fn(params):
76
+ """Initialise the optimiser's state."""
77
+
78
+ def _init(param):
79
+ accumulators = [jnp.zeros([s]) for s in param.shape]
80
+ momentum = _quantize_momentum(jnp.zeros_like(param))
81
+ return ParameterStats(accumulators, momentum)
82
+
83
+ return SM3State(
84
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)
85
+ )
86
+
87
+ def _get_expanded_shape(shape, i):
88
+ rank = len(shape)
89
+ # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i.
90
+ # For eg: i = 1 returns [1, N, 1].
91
+ return [1] * i + [shape[i]] + [1] * (rank - i - 1)
92
+
93
+ def _moving_averages(grad, accumulators):
94
+ w = (1.0 - beta2) if beta2 != 1.0 else 1.0
95
+ if grad.ndim < 2:
96
+ return beta2 * accumulators[0] + w * grad**2
97
+ else:
98
+ min_accumulator = functools.reduce(jnp.minimum, accumulators)
99
+ return beta2 * min_accumulator + w * grad**2
100
+
101
+ def _moving_averages_momentum(grad, momentum):
102
+ w = (1.0 - beta1) if beta1 != 1.0 else 1.0
103
+ return beta1 * momentum.to_float() + w * grad
104
+
105
+ def _sketch_diagonal_statistics(grad, updated_diagonal_statistics):
106
+ all_diagonal_statistics = []
107
+ for i in range(grad.ndim):
108
+ axes = list(range(i)) + list(range(i + 1, grad.ndim))
109
+ dim_diagonal_statistics = jnp.max(updated_diagonal_statistics, axis=axes)
110
+ all_diagonal_statistics.append(dim_diagonal_statistics)
111
+ if grad.ndim == 1:
112
+ all_diagonal_statistics[0] = updated_diagonal_statistics
113
+ return all_diagonal_statistics
114
+
115
+ def update_fn(updates, state, params=None):
116
+ del params
117
+ stats = state.stats
118
+ if normalize_grads:
119
+ updates = jax.tree_map(lambda g: g / (jnp.linalg.norm(g) + 1e-16), updates)
120
+ # Reshape all vectors into N-d tensors to compute min over them.
121
+ # [n], [m] -> [n, 1], [1, m]
122
+ expanded_diagonal_statistics = jax.tree_multimap(
123
+ lambda grad, state: [ # pylint:disable=g-long-lambda
124
+ jnp.reshape(
125
+ state.diagonal_statistics[i], _get_expanded_shape(grad.shape, i)
126
+ )
127
+ for i in range(grad.ndim)
128
+ ],
129
+ updates,
130
+ stats,
131
+ )
132
+
133
+ # Compute new diagonal statistics
134
+ new_diagonal_statistics = jax.tree_multimap(
135
+ _moving_averages, updates, expanded_diagonal_statistics
136
+ )
137
+
138
+ # Compute preconditioners (1/sqrt(s)) where s is the statistics.
139
+ new_preconditioners = jax.tree_map(
140
+ lambda t: 1.0 / jnp.sqrt(t + diagonal_epsilon), new_diagonal_statistics
141
+ )
142
+ preconditioned_grads = jax.tree_multimap(
143
+ lambda g, p: g * p, updates, new_preconditioners
144
+ )
145
+
146
+ # Compute updated momentum (also handle quantization)
147
+ updated_momentum = jax.tree_multimap(
148
+ lambda preconditioned_grad, state: _moving_averages_momentum( # pylint:disable=g-long-lambda
149
+ preconditioned_grad, state.diagonal_momentum
150
+ ),
151
+ preconditioned_grads,
152
+ stats,
153
+ )
154
+
155
+ # Update diagonal statistics.
156
+ updated_diagonal_statistics = jax.tree_multimap(
157
+ _sketch_diagonal_statistics, updates, new_diagonal_statistics
158
+ )
159
+
160
+ # Update momentum.
161
+ new_sm3_stats = jax.tree_multimap(
162
+ lambda momentum, diagonal_stats: ParameterStats( # pylint:disable=g-long-lambda
163
+ diagonal_stats, _quantize_momentum(momentum)
164
+ ),
165
+ updated_momentum,
166
+ updated_diagonal_statistics,
167
+ )
168
+
169
+ lr = learning_rate
170
+ if callable(learning_rate):
171
+ lr = learning_rate(state.count)
172
+
173
+ new_updates = jax.tree_map(lambda pg: -lr * pg, updated_momentum)
174
+ return new_updates, SM3State(count=state.count + 1, stats=new_sm3_stats)
175
+
176
+ return optax.GradientTransformation(init_fn, update_fn)
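
Note: the vendored SM3 file returns a standard `optax.GradientTransformation`, so it plugs into the usual init/update loop. A minimal, assumed usage sketch (toy params and gradients; import path assumed):

```python
import jax
import jax.numpy as jnp
import optax

from scalable_shampoo.sm3 import sm3  # assumed path

params = {"w": jnp.ones((4, 3)), "b": jnp.zeros((3,))}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

opt = sm3(learning_rate=0.1, beta1=0.9, beta2=0.999)
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
```
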
tools/train/scalable_shampoo/symmetric_matrices/symmetric_matrices.py ADDED
@@ -0,0 +1,211 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """JAX Ops for symmetric matrices used by the Shampoo optimizer."""
17
+
18
+ import functools
19
+ from typing import List, Union
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ from flax import struct
24
+ from jax import lax
25
+
26
+
27
+ @struct.dataclass
28
+ class SlicedSymmetricMatrix:
29
+ """A symmetric matrix represented by lower-triangular block row slices.
30
+
31
+ For example, the symmetric matrix M = [[a, b^T], [b, c]] would be represented
32
+ by the block rows a and [b, c].
33
+
34
+ The matrix may be batched, in which case each entry of block_rows may have
35
+ dimension greater than 2. The last two dimensions represent the rows and cols.
36
+ """
37
+
38
+ block_rows: List[jnp.ndarray]
39
+
40
+
41
+ def product_with_transpose(
42
+ mat1,
43
+ mat2,
44
+ precision=lax.Precision.DEFAULT,
45
+ ):
46
+ """Returns mat1 * mat2^T for two matrices (possibly batched).
47
+
48
+ The rows and columns are the last two dimensions for each matrix.
49
+
50
+ Args:
51
+ mat1: First matrix.
52
+ mat2: Second matrix.
53
+ precision: JAX precision to use for the multiplication.
54
+ """
55
+ return jnp.einsum("...ij,...kj->...ik", mat1, mat2, precision=precision)
56
+
57
+
58
+ @functools.partial(jax.jit, static_argnames=("block_size", "precision"))
59
+ def sliced_transposed_product(
60
+ mat,
61
+ block_size,
62
+ precision=lax.Precision.DEFAULT,
63
+ ):
64
+ """Returns the blocked slices representing a symmetric matrix mat*mat^T.
65
+
66
+ Args:
67
+ mat: The matrix for which we will compute mat*mat^T. It does not need to be
68
+ square, and may be batched.
69
+ block_size: The size of row blocks to compute.
70
+ precision: The precision to use in each computation.
71
+
72
+ Raises:
73
+ ValueError: Raised when the specified block size does not evenly divide
74
+ the number of rows of the input mat.
75
+ """
76
+ num_rows = mat.shape[-2]
77
+ if num_rows % block_size != 0:
78
+ raise ValueError(
79
+ "The row dimension must be divisible by block_size. "
80
+ f"Instead got row dimension={num_rows} and block_size={block_size}."
81
+ )
82
+ block_rows = [
83
+ product_with_transpose(
84
+ mat[Ellipsis, i * block_size : (i + 1) * block_size, :],
85
+ mat[Ellipsis, 0 : (i + 1) * block_size, :],
86
+ precision,
87
+ )
88
+ for i in range(num_rows // block_size)
89
+ ]
90
+ return SlicedSymmetricMatrix(block_rows=block_rows)
91
+
92
+
93
+ @functools.partial(jax.jit, static_argnames=("block_size", "precision"))
94
+ def sliced_transposed_product_concat(
95
+ mat,
96
+ block_size,
97
+ precision=lax.Precision.DEFAULT,
98
+ ):
99
+ """Returns the concatenated slices representing mat*mat^T.
100
+
101
+ Args:
102
+ mat: The matrix for which we will compute mat*mat^T. It does not need to be
103
+ square, and may be batched.
104
+ block_size: The size of row blocks to compute.
105
+ precision: The precision to use in each computation.
106
+
107
+ Raises:
108
+ ValueError: Raised when the specified block size does not evenly divide
109
+ the number of rows of the input mat.
110
+ """
111
+ sliced_symmetric_matrix = sliced_transposed_product(
112
+ mat=mat, block_size=block_size, precision=precision
113
+ )
114
+ return jnp.concatenate(sliced_symmetric_matrix.block_rows, axis=-1)
115
+
116
+
117
+ @jax.jit
118
+ def materialize_matrix(symmetric_matrix):
119
+ """Returns a materialized symmetric matrix.
120
+
121
+ Args:
122
+ symmetric_matrix: the matrix represented by lower-triangular block slices.
123
+ """
124
+ block_rows = symmetric_matrix.block_rows
125
+ block_size = block_rows[0].shape[-2]
126
+ num_blocks = len(block_rows)
127
+
128
+ # Slice the lower-triangular and diagonal blocks into blocks.
129
+ blocks = [
130
+ [
131
+ block_row[Ellipsis, i * block_size : (i + 1) * block_size]
132
+ for i in range(k + 1)
133
+ ]
134
+ for k, block_row in enumerate(block_rows)
135
+ ]
136
+
137
+ # Generate the (off-diagonal) upper-triangular blocks.
138
+ off_diags = [[] for _ in range(num_blocks - 1)]
139
+ for k, block_row in enumerate(block_rows[1:]):
140
+ for i in range(k + 1):
141
+ off_diags[i].append(
142
+ jnp.swapaxes(
143
+ a=block_row[Ellipsis, i * block_size : (i + 1) * block_size],
144
+ axis1=-1,
145
+ axis2=-2,
146
+ )
147
+ )
148
+
149
+ return jnp.block(
150
+ [row + row_t for row, row_t in zip(blocks[:-1], off_diags)] + [blocks[-1]]
151
+ )
152
+
153
+
154
+ @functools.partial(jax.jit, static_argnames=("num_blocks"))
155
+ def materialize_matrix_from_concat(
156
+ block_rows_concat,
157
+ num_blocks,
158
+ ):
159
+ """Returns a materialized symmetric matrix from concatenated slices.
160
+
161
+ Args:
162
+ block_rows_concat: The matrix represented as the concatenated
163
+ lower-triangular blocks.
164
+ num_blocks: The number of block-rows used to represent the symmetric matrix.
165
+ """
166
+ block_size = block_rows_concat.shape[-2]
167
+
168
+ block_rows = [
169
+ block_rows_concat[
170
+ Ellipsis,
171
+ (k * (k + 1))
172
+ // 2
173
+ * block_size : (((k + 1) * (k + 2)) // 2 + 1)
174
+ * block_size,
175
+ ]
176
+ for k in range(num_blocks)
177
+ ]
178
+
179
+ return materialize_matrix(SlicedSymmetricMatrix(block_rows=block_rows))
180
+
181
+
182
+ @functools.partial(jax.jit, static_argnames=("alpha", "beta"))
183
+ def update_sliced_rows(
184
+ symmetric_matrix,
185
+ mat,
186
+ alpha,
187
+ beta,
188
+ ):
189
+ """Implements the blocked equivalent of SYRK.
190
+
191
+ Specifically, the symmetric matrix (represented using lower-triangular block
192
+ rows) is updated using the sliced product of mat.
193
+
194
+ Args:
195
+ symmetric_matrix: The symmetric matrix to update.
196
+ mat: The matrix to use for the update = mat * mat^T. The number of rows
197
+ should match that of symmetric_matrix.
198
+ alpha: The weight for the update.
199
+ beta: The weight for the original symmetric matrix.
200
+
201
+ Returns:
202
+ The updated rows of alpha * mat * mat^T + beta * symmetric_matrix.
203
+ """
204
+ block_size = symmetric_matrix.block_rows[0].shape[-2]
205
+ sym_prod = sliced_transposed_product(mat=mat, block_size=block_size)
206
+ return SlicedSymmetricMatrix(
207
+ block_rows=[
208
+ update * alpha + row * beta
209
+ for update, row in zip(sym_prod.block_rows, symmetric_matrix.block_rows)
210
+ ]
211
+ )
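
Note: a small check (assumed import path) that the sliced representation round-trips: materializing the block rows of mat*mat^T reproduces the full Gram matrix.

```python
import jax.numpy as jnp
import numpy as np

from scalable_shampoo.symmetric_matrices.symmetric_matrices import (  # assumed path
    materialize_matrix,
    sliced_transposed_product,
)

mat = jnp.arange(24.0).reshape(8, 3)  # block_size must divide the 8 rows
sliced = sliced_transposed_product(mat, block_size=4)
full = materialize_matrix(sliced)

print(np.allclose(np.asarray(full), np.asarray(mat @ mat.T)))  # True
```
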
tools/train/train.py CHANGED
@@ -1,6 +1,6 @@
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
- # Copyright 2021-2022 The HuggingFace & DALL·E Mini Team All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
@@ -37,7 +37,6 @@ import optax
37
  import transformers
38
  import wandb
39
  from datasets import Dataset
40
- from distributed_shampoo import GraftingType, distributed_shampoo
41
  from flax.core.frozen_dict import FrozenDict, freeze
42
  from flax.serialization import from_bytes, to_bytes
43
  from flax.training import train_state
@@ -46,6 +45,7 @@ from google.cloud import storage
46
  from jax.experimental import PartitionSpec, maps
47
  from jax.experimental.compilation_cache import compilation_cache as cc
48
  from jax.experimental.pjit import pjit, with_sharding_constraint
 
49
  from tqdm import tqdm
50
  from transformers import HfArgumentParser
51
 
@@ -57,7 +57,7 @@ from dalle_mini.model import (
57
  set_partitions,
58
  )
59
 
60
- cc.initialize_cache("./jax_cache", max_cache_size_bytes=5 * 2**30)
61
 
62
  logger = logging.getLogger(__name__)
63
 
@@ -203,6 +203,12 @@ class DataTrainingArguments:
203
  "help": "Whether to shard data files by host in multi-host environments."
204
  },
205
  )
206
  max_train_samples: Optional[int] = field(
207
  default=None,
208
  metadata={
@@ -314,10 +320,6 @@ class TrainingArguments:
314
  default=1024,
315
  metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
316
  )
317
- start_preconditioning_step: int = field(
318
- default=100,
319
- metadata={"help": "Number of steps before starting to update preconditioner."},
320
- )
321
  preconditioning_compute_steps: int = field(
322
  default=10, metadata={"help": "Number of steps to update preconditioner."}
323
  )
@@ -325,6 +327,12 @@ class TrainingArguments:
325
  default=4096,
326
  metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
327
  )
328
  optim_quantized: bool = field(
329
  default=False,
330
  metadata={
@@ -413,11 +421,28 @@ class TrainingArguments:
413
  dp_devices: int = field(init=False)
414
 
415
  def __post_init__(self):
416
  assert self.optim in [
417
  "distributed_shampoo",
418
  "adam",
419
  "adafactor",
420
  ], f"Selected optimizer not supported: {self.optim}"
421
  if self.per_device_eval_batch_size is None:
422
  self.per_device_eval_batch_size = self.per_device_train_batch_size
423
  if (
@@ -430,6 +455,9 @@ class TrainingArguments:
430
  f"Output directory ({self.output_dir}) already exists and is not empty."
431
  "Use --overwrite_output_dir to overcome."
432
  )
433
  assert (
434
  jax.device_count() % self.mp_devices == 0
435
  ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
@@ -514,10 +542,6 @@ def main():
514
 
515
  logger.info(f"Local TPUs: {jax.local_device_count()}")
516
  logger.info(f"Global TPUs: {jax.device_count()}")
517
- if training_args.assert_TPU_available:
518
- assert (
519
- jax.local_device_count() == 8
520
- ), "TPUs in use, please check running processes"
521
 
522
  # Set up wandb run
523
  if jax.process_index() == 0:
@@ -544,8 +568,7 @@ def main():
544
  config=config,
545
  seed=training_args.seed_model,
546
  dtype=getattr(jnp, model_args.dtype),
547
- abstract_init=True,
548
- load_on_cpu=True,
549
  # initializing params with gradient checkpointing creates issues
550
  # we correctly set it later per training_args
551
  gradient_checkpointing=False,
@@ -555,29 +578,23 @@ def main():
555
  config,
556
  seed=training_args.seed_model,
557
  dtype=getattr(jnp, model_args.dtype),
558
- load_on_cpu=True,
559
  )
560
 
561
- # update model config per training args
562
- # Done after initialization of weights to avoid issues with remat
563
- # This is still considered correctly during training as function is pjitted
564
- model.config.gradient_checkpointing = training_args.gradient_checkpointing
565
-
566
  if training_args.gradient_checkpointing:
567
- # eval model cannot use remat
568
- eval_config = copy.deepcopy(model.config)
569
- eval_config.gradient_checkpointing = False
570
- eval_model = DalleBart(
571
- eval_config,
572
  seed=training_args.seed_model,
573
  dtype=getattr(jnp, model_args.dtype),
574
- abstract_init=True,
575
- load_on_cpu=True,
576
  )
577
- del eval_model._params
578
- eval_fn = eval_model.__call__
579
  else:
580
- eval_fn = model.__call__
581
 
582
  # get model metadata
583
  model_metadata = model_args.get_metadata()
@@ -620,7 +637,7 @@ def main():
620
  eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
621
  len_train_dataset, len_eval_dataset = dataset.length
622
  steps_per_epoch = (
623
- len_train_dataset // batch_size_per_step
624
  if len_train_dataset is not None
625
  else None
626
  )
@@ -633,7 +650,7 @@ def main():
633
  logger.info(f" Num examples = {len_train_dataset}")
634
  logger.info(f" Num Epochs = {num_epochs}")
635
  logger.info(
636
- f" Batch size per device = {training_args.per_device_train_batch_size}"
637
  )
638
  logger.info(f" Number of devices = {jax.device_count()}")
639
  logger.info(
@@ -701,22 +718,32 @@ def main():
701
  # create adam optimizer
702
  if training_args.optim == "distributed_shampoo":
703
  # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
704
  optimizer = distributed_shampoo(
705
  learning_rate_fn,
706
  block_size=training_args.block_size,
707
  beta1=training_args.beta1,
708
  beta2=training_args.beta2,
709
  diagonal_epsilon=1e-10,
710
- matrix_epsilon=1e-8,
711
- start_preconditioning_step=training_args.start_preconditioning_step,
712
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
713
  statistics_compute_steps=1,
714
  best_effort_shape_interpretation=True,
715
- graft_type=GraftingType.RMSPROP_NORMALIZED,
716
  nesterov=False,
717
  exponent_override=0,
718
- statistics_partition_spec=PartitionSpec(None, "batch", None),
719
- preconditioner_partition_spec=PartitionSpec("batch", None, None),
720
  num_devices_for_pjit=training_args.dp_devices,
721
  shard_optimizer_states=True,
722
  inverse_failure_threshold=0.1,
@@ -779,7 +806,7 @@ def main():
779
  opt_state_spec = opt_fn.pspec_fn(
780
  params=model.params,
781
  params_partition_spec=param_spec,
782
- partition_spec_for_statistics=PartitionSpec(None, "batch", None),
783
  )
784
  else:
785
  raise NotImplementedError
@@ -790,7 +817,8 @@ def main():
790
  # create a mesh
791
  mesh_shape = (training_args.dp_devices, training_args.mp_devices)
792
  devices = np.asarray(jax.devices()).reshape(*mesh_shape)
793
- mesh = maps.Mesh(devices, ("batch", "mp"))
 
794
 
795
  # define state spec
796
  state_spec = TrainState(
@@ -801,28 +829,39 @@ def main():
801
  epoch=None,
802
  train_time=None,
803
  train_samples=None,
804
- apply_fn=model.__call__,
805
  tx=optimizer,
806
  )
807
 
808
- # create training state
 
 
 
 
 
 
 
 
809
  with maps.mesh(mesh.devices, mesh.axis_names):
 
810
  if not model_args.restore_state:
811
 
812
  def init_state(params):
813
  return TrainState.create(
814
- apply_fn=model.__call__,
815
  tx=optimizer,
816
- params=params,
817
  dropout_rng=dropout_rng,
818
  )
819
 
820
  state = pjit(
821
  init_state,
822
- in_axis_resources=(param_spec,),
823
  out_axis_resources=state_spec,
824
  donate_argnums=(0,),
825
- )(model.params)
826
 
827
  else:
828
  # load opt_state
@@ -836,7 +875,7 @@ def main():
836
 
837
  def restore_state(params, opt_state):
838
  return TrainState(
839
- apply_fn=model.__call__,
840
  tx=optimizer,
841
  params=params,
842
  opt_state=opt_state,
@@ -846,7 +885,10 @@ def main():
846
 
847
  state = pjit(
848
  restore_state,
849
- in_axis_resources=(param_spec, opt_state_spec),
850
  out_axis_resources=state_spec,
851
  donate_argnums=(0, 1),
852
  )(model.params, opt_state)
@@ -854,37 +896,32 @@ def main():
854
  # remove opt_state from CPU
855
  del opt_state
856
 
857
- # free memory
858
  del model._params, opt_state_spec, opt_state_shape
859
 
860
  # define batch specs
861
- keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
862
- batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
863
- grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})
864
 
865
- # label smoothed cross entropy
866
  def loss_fn(logits, labels):
867
  loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
868
  loss = loss.mean()
869
  return loss
870
 
871
  # Define gradient update step fn
872
  def train_step(state, batch, delta_time):
873
- # we reshape to (gradient_accumulation_steps, dp_devices, ...)
874
- # allows feeding partial batch size per node for full model parallel
875
- batch = jax.tree_map(
876
- lambda x: x.reshape(
877
- (
878
- training_args.gradient_accumulation_steps,
879
- training_args.dp_devices,
880
- training_args.per_device_train_batch_size,
881
- )
882
- + x.shape[2:]
883
- ),
884
- batch,
885
- )
886
- # ensure data is sharded correctly per dp device
887
- batch = with_sharding_constraint(batch, grad_batch_spec)
888
 
889
  # get a minibatch (one gradient accumulation slice)
890
  def get_minibatch(batch, grad_idx):
@@ -904,62 +941,71 @@ def main():
904
  grad_fn = jax.value_and_grad(compute_loss)
905
 
906
  def loss_and_grad(grad_idx, dropout_rng):
907
- # minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
908
- minibatch = get_minibatch(batch, grad_idx)
909
- # calculate loss and grads independently per dp_device
910
- dropout_rng, _ = jax.random.split(dropout_rng)
911
- # ensure inputs are sharded per device
912
- minibatch = jax.tree_map(
913
- lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
914
- minibatch,
915
- )
916
- # only 1 single rng per grad step, let us handle larger batch size
917
- loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
918
- state.params, minibatch, dropout_rng
919
  )
920
- # ensure outputs are sharded per device
921
- loss_grads = jax.tree_map(
922
- lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
923
- loss_grads,
924
- )
925
- # average across all devices
926
- loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
927
  # return loss and grads
928
- return loss_grads, dropout_rng
929
 
930
  if training_args.gradient_accumulation_steps == 1:
931
- loss_grad, dropout_rng = loss_and_grad(0, state.dropout_rng)
932
  else:
933
  # create initial state for cumul_minibatch_step loop
934
  init_minibatch_step = (
935
- (
936
- 0.0,
937
- jax.tree_map(jnp.zeros_like, state.params),
938
  ),
939
  state.dropout_rng,
940
  )
941
 
942
  # accumulate gradients
943
  def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
944
- cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
945
- loss_grad, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
946
- cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
947
- return cumul_loss_grad, dropout_rng
948
 
949
  # loop over gradients
950
- loss_grad, dropout_rng = jax.lax.fori_loop(
951
  0,
952
  training_args.gradient_accumulation_steps,
953
  cumul_minibatch_step,
954
  init_minibatch_step,
955
  )
 
956
  # sum -> mean
957
- loss_grad = jax.tree_map(
958
- lambda x: x / training_args.gradient_accumulation_steps, loss_grad
959
  )
960
 
961
  # update state
962
- loss, grads = loss_grad
963
  state = state.apply_gradients(
964
  grads=grads,
965
  dropout_rng=dropout_rng,
@@ -976,37 +1022,32 @@ def main():
976
 
977
  # Define eval fn
978
  def eval_step(state, batch):
979
- # we reshape to (dp_devices, ...)
980
- batch = jax.tree_map(
981
- lambda x: x.reshape(
982
- (
983
- training_args.dp_devices,
984
- training_args.per_device_eval_batch_size,
985
- )
986
- + x.shape[1:]
987
- ),
988
- batch,
989
- )
990
- # ensure data is sharded correctly per dp device
991
- batch = with_sharding_constraint(batch, batch_spec)
992
-
993
  def compute_eval_loss(batch):
994
  batch, labels = batch.pop("labels")
995
  logits = eval_fn(**batch, params=state.params, train=False)[0]
996
  return loss_fn(logits, labels)
997
 
998
- # calculate loss independently per dp_device
999
- loss = jax.vmap(compute_eval_loss, in_axes=(0,), out_axes=0)(batch)
1000
- # ensure they are sharded over dp devices
1001
- loss = with_sharding_constraint(loss, PartitionSpec("batch"))
1002
- # average across all devices
1003
- loss = jnp.mean(loss)
1004
  return loss
1005
 
1006
  # Create parallel version of the train and eval step
1007
  p_train_step = pjit(
1008
  train_step,
1009
- in_axis_resources=(state_spec, grad_batch_spec, None),
1010
  out_axis_resources=(state_spec, None),
1011
  donate_argnums=(0,),
1012
  )
@@ -1022,7 +1063,10 @@ def main():
1022
  step = int(state.step)
1023
  metrics_logger = MetricsLogger(step)
1024
  epochs = tqdm(
1025
- range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
  )
1027
 
1028
  def run_evaluation():
@@ -1041,6 +1085,7 @@ def main():
1041
  position=2,
1042
  leave=False,
1043
  total=eval_steps,
 
1044
  ):
1045
  # need to keep only eval_batch_size_per_node items relevant to the node
1046
  batch = jax.tree_map(
@@ -1050,6 +1095,17 @@ def main():
1050
  batch,
1051
  )
1052
  batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
1053
  # freeze batch to pass safely to jax transforms
1054
  batch = freeze(batch)
1055
  # accumulate losses async
@@ -1166,6 +1222,7 @@ def main():
1166
  )
1167
  wandb.run.log_artifact(artifact_state)
1168
 
 
1169
  with maps.mesh(mesh.devices, mesh.axis_names):
1170
  for epoch in epochs:
1171
  state.replace(epoch=epoch)
@@ -1186,21 +1243,33 @@ def main():
1186
  position=1,
1187
  leave=False,
1188
  total=steps_per_epoch,
 
1189
  ):
1190
  # calculate delta time (we have a lag of one step but it's ok)
1191
  new_time = time.perf_counter()
1192
  delta_time = new_time - last_time
1193
  last_time = new_time
1194
 
1195
- # reshape data into (gradient_accumulation_steps, dp_devices, batch_per_dp, ...)
1196
  batch = jax.tree_map(
1197
- lambda x: x.reshape(
1198
- (
1199
- training_args.gradient_accumulation_steps,
1200
- batch_size_per_node_per_grad_step,
1201
- )
1202
- + x.shape[1:]
1203
- ),
1204
  batch,
1205
  )
1206
  # freeze batch to pass safely to jax transforms
 
1
  #!/usr/bin/env python
2
  # coding=utf-8
3
+ # Copyright 2021-2022 The HuggingFace & DALL·E Mini team. All rights reserved.
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
37
  import transformers
38
  import wandb
39
  from datasets import Dataset
 
40
  from flax.core.frozen_dict import FrozenDict, freeze
41
  from flax.serialization import from_bytes, to_bytes
42
  from flax.training import train_state
 
45
  from jax.experimental import PartitionSpec, maps
46
  from jax.experimental.compilation_cache import compilation_cache as cc
47
  from jax.experimental.pjit import pjit, with_sharding_constraint
48
+ from scalable_shampoo.distributed_shampoo import GraftingType, distributed_shampoo
49
  from tqdm import tqdm
50
  from transformers import HfArgumentParser
51
 
 
57
  set_partitions,
58
  )
59
 
60
+ cc.initialize_cache("./jax_cache", max_cache_size_bytes=10 * 2**30)
61
 
62
  logger = logging.getLogger(__name__)
63
 
 
203
  "help": "Whether to shard data files by host in multi-host environments."
204
  },
205
  )
206
+ blank_caption_prob: Optional[float] = field(
207
+ default=0.0,
208
+ metadata={
209
+ "help": "Probability of removing some captions for classifier-free guidance."
210
+ },
211
+ )
212
  max_train_samples: Optional[int] = field(
213
  default=None,
214
  metadata={
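The new blank_caption_prob argument controls how often captions are dropped so the model also sees unconditional examples, which is what classifier-free guidance relies on. A minimal sketch of how such a probability can be applied to a single example; the helper name and fields are illustrative, not the exact implementation used by the data pipeline:

    import random

    def blank_caption_function(example, text_column, blank_caption_prob):
        # With probability blank_caption_prob, replace the caption with an empty string
        # so the model also learns to generate images without text conditioning.
        if blank_caption_prob and random.random() < blank_caption_prob:
            example[text_column] = ""
        return example

    # e.g. dataset.map(functools.partial(blank_caption_function, text_column="caption", blank_caption_prob=0.1))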
 
320
  default=1024,
321
  metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
322
  )
323
  preconditioning_compute_steps: int = field(
324
  default=10, metadata={"help": "Number of steps to update preconditioner."}
325
  )
 
327
  default=4096,
328
  metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
329
  )
330
+ graft_type: str = field(
331
+ default="rmsprop_normalized",
332
+ metadata={
333
+ "help": "The type of grafting to use. Can be 'rmsprop_normalized' (default), 'rmsprop', 'adagrad', 'adagrad_normalized', 'sgd' or 'sqrt_n'"
334
+ },
335
+ )
336
  optim_quantized: bool = field(
337
  default=False,
338
  metadata={
 
421
  dp_devices: int = field(init=False)
422
 
423
  def __post_init__(self):
424
+ if self.assert_TPU_available:
425
+ assert (
426
+ jax.local_device_count() == 8
427
+ ), "TPUs in use, please check running processes"
428
  assert self.optim in [
429
  "distributed_shampoo",
430
  "adam",
431
  "adafactor",
432
  ], f"Selected optimizer not supported: {self.optim}"
433
+ assert self.graft_type in [
434
+ "rmsprop_normalized",
435
+ "rmsprop",
436
+ "adagrad",
437
+ "adagrad_normalized",
438
+ "sgd",
439
+ "sqrt_n",
440
+ ], f"Selected graft type not supported: {self.graft_type}"
441
+ assert self.lr_decay in [
442
+ None,
443
+ "linear",
444
+ "exponential",
445
+ ], f"Selected learning rate decay not supported: {self.lr_decay}"
446
  if self.per_device_eval_batch_size is None:
447
  self.per_device_eval_batch_size = self.per_device_train_batch_size
448
  if (
 
455
  f"Output directory ({self.output_dir}) already exists and is not empty."
456
  "Use --overwrite_output_dir to overcome."
457
  )
458
+ assert (
459
+ self.mp_devices > 0
460
+ ), f"Number of devices for model parallelism must be > 0"
461
  assert (
462
  jax.device_count() % self.mp_devices == 0
463
  ), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
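Given the divisibility check above, the number of data-parallel devices presumably follows as the total device count divided by mp_devices (dp_devices is an init=False field, so the script fills it in itself; the exact assignment is not shown in this hunk). Illustrative arithmetic only, with made-up values:

    device_count, mp_devices = 8, 2            # hypothetical single-host pod slice
    assert device_count % mp_devices == 0
    dp_devices = device_count // mp_devices    # presumably the data-parallel axis size -> 4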
 
542
 
543
  logger.info(f"Local TPUs: {jax.local_device_count()}")
544
  logger.info(f"Global TPUs: {jax.device_count()}")
545
 
546
  # Set up wandb run
547
  if jax.process_index() == 0:
 
568
  config=config,
569
  seed=training_args.seed_model,
570
  dtype=getattr(jnp, model_args.dtype),
571
+ abstract_init=True, # we overwrite them with loaded checkpoint
 
572
  # initializing params with gradient checkpointing creates issues
573
  # we correctly set it later per training_args
574
  gradient_checkpointing=False,
 
578
  config,
579
  seed=training_args.seed_model,
580
  dtype=getattr(jnp, model_args.dtype),
581
+ abstract_init=True,
582
  )
583
 
584
+ # define model eval and train functions
585
+ eval_fn = model.__call__
586
  if training_args.gradient_checkpointing:
587
+ remat_config = copy.deepcopy(model.config)
588
+ remat_config.gradient_checkpointing = True
589
+ remat_model = DalleBart(
590
+ remat_config,
 
591
  seed=training_args.seed_model,
592
  dtype=getattr(jnp, model_args.dtype),
593
+ init_weights=False,
 
594
  )
595
+ train_fn = remat_model.__call__
 
596
  else:
597
+ train_fn = model.__call__
598
 
599
  # get model metadata
600
  model_metadata = model_args.get_metadata()
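When gradient checkpointing is enabled, a second DalleBart is built with gradient_checkpointing=True (and init_weights=False, so it does not initialize a second copy of the parameters) and its __call__ becomes train_fn, while eval_fn keeps the cheaper non-remat forward pass. The general remat idea, in a tiny generic JAX example that is unrelated to DalleBart internals:

    import jax
    import jax.numpy as jnp

    @jax.checkpoint  # a.k.a. jax.remat: recompute activations in the backward pass to save memory
    def mlp_block(params, x):
        w1, w2 = params
        return jnp.tanh(x @ w1) @ w2

    params = (jnp.ones((8, 16)), jnp.ones((16, 4)))
    x = jnp.ones((2, 8))
    grads = jax.grad(lambda p: mlp_block(p, x).sum())(params)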
 
637
  eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
638
  len_train_dataset, len_eval_dataset = dataset.length
639
  steps_per_epoch = (
640
+ len_train_dataset // batch_size_per_node
641
  if len_train_dataset is not None
642
  else None
643
  )
 
650
  logger.info(f" Num examples = {len_train_dataset}")
651
  logger.info(f" Num Epochs = {num_epochs}")
652
  logger.info(
653
+ f" Batch size per dp device = {training_args.per_device_train_batch_size}"
654
  )
655
  logger.info(f" Number of devices = {jax.device_count()}")
656
  logger.info(
 
718
  # create adam optimizer
719
  if training_args.optim == "distributed_shampoo":
720
  # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
721
+ graft_type = {
722
+ "sgd": GraftingType.SGD,
723
+ "adagrad": GraftingType.ADAGRAD,
724
+ "rmsprop": GraftingType.RMSPROP,
725
+ "rmsprop_normalized": GraftingType.RMSPROP_NORMALIZED,
726
+ "sqrt_n": GraftingType.SQRT_N,
727
+ "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
728
+ }[training_args.graft_type]
729
  optimizer = distributed_shampoo(
730
  learning_rate_fn,
731
  block_size=training_args.block_size,
732
  beta1=training_args.beta1,
733
  beta2=training_args.beta2,
734
  diagonal_epsilon=1e-10,
735
+ matrix_epsilon=1e-6,
736
+ start_preconditioning_step=max(
737
+ training_args.preconditioning_compute_steps + 1, 101
738
+ ),
739
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
740
  statistics_compute_steps=1,
741
  best_effort_shape_interpretation=True,
742
+ graft_type=graft_type,
743
  nesterov=False,
744
  exponent_override=0,
745
+ statistics_partition_spec=PartitionSpec(None, "dp", None),
746
+ preconditioner_partition_spec=PartitionSpec("dp", None, None),
747
  num_devices_for_pjit=training_args.dp_devices,
748
  shard_optimizer_states=True,
749
  inverse_failure_threshold=0.1,
 
806
  opt_state_spec = opt_fn.pspec_fn(
807
  params=model.params,
808
  params_partition_spec=param_spec,
809
+ partition_spec_for_statistics=PartitionSpec(None, "dp", None),
810
  )
811
  else:
812
  raise NotImplementedError
 
817
  # create a mesh
818
  mesh_shape = (training_args.dp_devices, training_args.mp_devices)
819
  devices = np.asarray(jax.devices()).reshape(*mesh_shape)
820
+ mesh = maps.Mesh(devices, ("dp", "mp"))
821
+ logger.info(f" Mesh shape: {mesh_shape}")
822
 
823
  # define state spec
824
  state_spec = TrainState(
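The device mesh now uses the axis names ("dp", "mp") instead of ("batch", "mp"), matching the PartitionSpecs above. A small self-contained pjit sketch with the same (older) jax.experimental API, assuming 8 local devices arranged as 4 data-parallel by 2 model-parallel; sizes and the matmul are only illustrative:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental import PartitionSpec, maps
    from jax.experimental.pjit import pjit

    devices = np.asarray(jax.devices()).reshape(4, 2)   # requires 8 devices
    mesh = maps.Mesh(devices, ("dp", "mp"))

    matmul = pjit(
        lambda x, w: x @ w,
        in_axis_resources=(PartitionSpec("dp", None), PartitionSpec(None, "mp")),
        out_axis_resources=PartitionSpec("dp", "mp"),
    )
    with maps.mesh(mesh.devices, mesh.axis_names):
        y = matmul(jnp.ones((8, 16)), jnp.ones((16, 8)))  # rows sharded over dp, columns over mp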
 
829
  epoch=None,
830
  train_time=None,
831
  train_samples=None,
832
+ apply_fn=train_fn,
833
  tx=optimizer,
834
  )
835
 
836
+ # init params if not available yet
837
+ def maybe_init_params(params):
838
+ if model_args.model_name_or_path:
839
+ # model params are correctly loaded
840
+ return params
841
+ else:
842
+ # params have not been initialized yet
843
+ return model.init_weights()
844
+
845
  with maps.mesh(mesh.devices, mesh.axis_names):
846
+ logger.info(" Creating state")
847
  if not model_args.restore_state:
848
 
849
  def init_state(params):
850
  return TrainState.create(
851
+ apply_fn=train_fn,
852
  tx=optimizer,
853
+ params=maybe_init_params(params),
854
  dropout_rng=dropout_rng,
855
  )
856
 
857
  state = pjit(
858
  init_state,
859
+ in_axis_resources=(param_spec,)
860
+ if model_args.model_name_or_path
861
+ else None,
862
  out_axis_resources=state_spec,
863
  donate_argnums=(0,),
864
+ )(model.params if model_args.model_name_or_path else None)
865
 
866
  else:
867
  # load opt_state
 
875
 
876
  def restore_state(params, opt_state):
877
  return TrainState(
878
+ apply_fn=train_fn,
879
  tx=optimizer,
880
  params=params,
881
  opt_state=opt_state,
 
885
 
886
  state = pjit(
887
  restore_state,
888
+ in_axis_resources=(
889
+ param_spec,
890
+ opt_state_spec,
891
+ ),
892
  out_axis_resources=state_spec,
893
  donate_argnums=(0, 1),
894
  )(model.params, opt_state)
 
896
  # remove opt_state from CPU
897
  del opt_state
898
 
899
+ # free CPU memory
900
  del model._params, opt_state_spec, opt_state_shape
901
 
902
  # define batch specs
903
+ batch_spec = PartitionSpec("dp")
904
+ grad_batch_spec = PartitionSpec(None, "dp")
 
905
 
906
+ # define loss
907
  def loss_fn(logits, labels):
908
  loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
909
  loss = loss.mean()
910
  return loss
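loss_fn is plain (un-smoothed) cross entropy over one-hot targets. A quick standalone check with made-up logits, assuming onehot comes from flax.training.common_utils as is usual in these scripts:

    import jax.numpy as jnp
    import optax
    from flax.training.common_utils import onehot

    logits = jnp.array([[2.0, 0.5, 0.1]])
    labels = jnp.array([0])
    loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()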
911
 
912
+ # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
913
+ # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
914
+ use_vmap_trick = True
915
+
916
+ # make grad_param_spec for vmap
917
+ if use_vmap_trick:
918
+ grad_param_spec = jax.tree_map(
919
+ lambda x: PartitionSpec(*("dp",) + (x if x is not None else (None,))),
920
+ param_spec,
921
+ )
922
+
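The "vmap trick" used below computes loss and gradients independently for each local data-parallel shard with jax.vmap and then averages them, rather than relying on pjit alone to split the batch; per the comment above it also works around a crash seen with mp_devices > 1. A toy version with a made-up loss and no mesh or pjit involved:

    import jax
    import jax.numpy as jnp

    def toy_loss(params, batch):
        return jnp.mean((batch @ params) ** 2)

    grad_fn = jax.value_and_grad(toy_loss)

    params = jnp.ones((4, 1))
    batch = jnp.ones((8, 3, 4))                  # (dp shards, per-device batch, features)
    loss, grads = jax.vmap(grad_fn, in_axes=(None, 0), out_axes=(0, 0))(params, batch)
    loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads))  # average over shards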
923
  # Define gradient update step fn
924
  def train_step(state, batch, delta_time):
925
 
926
  # get a minibatch (one gradient accumulation slice)
927
  def get_minibatch(batch, grad_idx):
 
941
  grad_fn = jax.value_and_grad(compute_loss)
942
 
943
  def loss_and_grad(grad_idx, dropout_rng):
944
+ # minibatch at grad_idx for gradient accumulation (None otherwise)
945
+ minibatch = (
946
+ get_minibatch(batch, grad_idx) if grad_idx is not None else batch
947
  )
948
+ # ensure it is sharded properly
949
+ minibatch = with_sharding_constraint(minibatch, batch_spec)
950
+ # only 1 single rng per grad step, let us handle larger batch size (not sure why)
951
+ dropout_rng, _ = jax.random.split(dropout_rng)
952
+
953
+ if use_vmap_trick:
954
+ # "vmap trick", calculate loss and grads independently per dp_device
955
+ loss, grads = jax.vmap(
956
+ grad_fn, in_axes=(None, 0, None), out_axes=(0, 0)
957
+ )(state.params, minibatch, dropout_rng)
958
+ # ensure they are sharded correctly
959
+ loss = with_sharding_constraint(loss, batch_spec)
960
+ grads = with_sharding_constraint(grads, grad_param_spec)
961
+ # average across all devices
962
+ # Note: we could average per device only after gradient accumulation, right before params update
963
+ loss, grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), (loss, grads))
964
+ else:
965
+ # "vmap trick" does not work in multi-hosts and requires too much hbm
966
+ loss, grads = grad_fn(state.params, minibatch, dropout_rng)
967
+ # ensure grads are sharded
968
+ grads = with_sharding_constraint(grads, param_spec)
969
  # return loss and grads
970
+ return loss, grads, dropout_rng
971
 
972
  if training_args.gradient_accumulation_steps == 1:
973
+ loss, grads, dropout_rng = loss_and_grad(None, state.dropout_rng)
974
  else:
975
  # create initial state for cumul_minibatch_step loop
976
  init_minibatch_step = (
977
+ 0.0,
978
+ with_sharding_constraint(
979
+ jax.tree_map(jnp.zeros_like, state.params), param_spec
980
  ),
981
  state.dropout_rng,
982
  )
983
 
984
  # accumulate gradients
985
  def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
986
+ cumul_loss, cumul_grads, dropout_rng = cumul_loss_grad_dropout
987
+ loss, grads, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
988
+ cumul_loss, cumul_grads = jax.tree_map(
989
+ jnp.add, (cumul_loss, cumul_grads), (loss, grads)
990
+ )
991
+ cumul_grads = with_sharding_constraint(cumul_grads, param_spec)
992
+ return cumul_loss, cumul_grads, dropout_rng
993
 
994
  # loop over gradients
995
+ loss, grads, dropout_rng = jax.lax.fori_loop(
996
  0,
997
  training_args.gradient_accumulation_steps,
998
  cumul_minibatch_step,
999
  init_minibatch_step,
1000
  )
1001
+ grads = with_sharding_constraint(grads, param_spec)
1002
  # sum -> mean
1003
+ loss, grads = jax.tree_map(
1004
+ lambda x: x / training_args.gradient_accumulation_steps, (loss, grads)
1005
  )
1006
 
1007
  # update state
1008
+ grads = with_sharding_constraint(grads, param_spec)
1009
  state = state.apply_gradients(
1010
  grads=grads,
1011
  dropout_rng=dropout_rng,
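With gradient_accumulation_steps > 1, per-minibatch losses and gradients are summed inside a jax.lax.fori_loop and divided by the number of steps afterwards. The same pattern on a scalar toy problem (values made up):

    import jax
    import jax.numpy as jnp

    grad_fn = jax.value_and_grad(lambda p, x: jnp.mean((x * p) ** 2))
    params = jnp.float32(2.0)
    minibatches = jnp.arange(1.0, 5.0)           # 4 accumulation steps

    def accumulate(i, carry):
        loss_sum, grad_sum = carry
        loss, grad = grad_fn(params, minibatches[i])
        return loss_sum + loss, grad_sum + grad

    loss, grad = jax.lax.fori_loop(
        0, 4, accumulate, (jnp.float32(0.0), jnp.zeros_like(params))
    )
    loss, grad = jax.tree_map(lambda x: x / 4, (loss, grad))  # sum -> mean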
 
1022
 
1023
  # Define eval fn
1024
  def eval_step(state, batch):
1025
  def compute_eval_loss(batch):
1026
  batch, labels = batch.pop("labels")
1027
  logits = eval_fn(**batch, params=state.params, train=False)[0]
1028
  return loss_fn(logits, labels)
1029
 
1030
+ if use_vmap_trick:
1031
+ loss = jax.vmap(compute_eval_loss)(batch)
1032
+ # ensure they are sharded correctly
1033
+ loss = with_sharding_constraint(loss, batch_spec)
1034
+ # average across all devices
1035
+ loss = jnp.mean(loss)
1036
+ else:
1037
+ loss = compute_eval_loss(batch)
1038
+
1039
  return loss
1040
 
1041
  # Create parallel version of the train and eval step
1042
  p_train_step = pjit(
1043
  train_step,
1044
+ in_axis_resources=(
1045
+ state_spec,
1046
+ grad_batch_spec
1047
+ if training_args.gradient_accumulation_steps > 1
1048
+ else batch_spec,
1049
+ None,
1050
+ ),
1051
  out_axis_resources=(state_spec, None),
1052
  donate_argnums=(0,),
1053
  )
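p_train_step now picks its batch partitioning from the accumulation setting: with gradient accumulation the batch carries a leading accumulation axis that is not partitioned, and only the second axis is sharded over "dp". The choice in isolation (same import path as this script; the step count is illustrative):

    from jax.experimental import PartitionSpec

    gradient_accumulation_steps = 4
    batch_spec = PartitionSpec("dp")              # (batch, ...) sharded over dp
    grad_batch_spec = PartitionSpec(None, "dp")   # (accum_steps, batch, ...), accum axis replicated
    in_spec = grad_batch_spec if gradient_accumulation_steps > 1 else batch_spec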
 
1063
  step = int(state.step)
1064
  metrics_logger = MetricsLogger(step)
1065
  epochs = tqdm(
1066
+ range(state.epoch, num_epochs),
1067
+ desc=f"Epoch ... (1/{num_epochs})",
1068
+ position=0,
1069
+ disable=jax.process_index() > 0,
1070
  )
1071
 
1072
  def run_evaluation():
 
1085
  position=2,
1086
  leave=False,
1087
  total=eval_steps,
1088
+ disable=jax.process_index() > 0,
1089
  ):
1090
  # need to keep only eval_batch_size_per_node items relevant to the node
1091
  batch = jax.tree_map(
 
1095
  batch,
1096
  )
1097
  batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
1098
+
1099
+ # add dp dimension when using "vmap trick"
1100
+ if use_vmap_trick:
1101
+ bs_shape = (
1102
+ jax.local_device_count() // training_args.mp_devices,
1103
+ training_args.per_device_eval_batch_size,
1104
+ )
1105
+ batch = jax.tree_map(
1106
+ lambda x: x.reshape(bs_shape + x.shape[1:]), batch
1107
+ )
1108
+
1109
  # freeze batch to pass safely to jax transforms
1110
  batch = freeze(batch)
1111
  # accumulate losses async
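During evaluation each process first keeps only its own slice of the gathered batch, then, when the vmap trick is on, splits it into (local dp devices, per_device_eval_batch_size, ...). A NumPy-only illustration with made-up sizes:

    import numpy as np

    process_index, local_devices, mp_devices, eval_bsz = 0, 8, 1, 2
    # batch as gathered across 2 processes: (process_count, per-node items, seq_len)
    batch = {"input_ids": np.zeros((2, (local_devices // mp_devices) * eval_bsz, 64))}

    node_batch = {k: v[process_index] for k, v in batch.items()}        # keep this node's slice
    bs_shape = (local_devices // mp_devices, eval_bsz)                  # add the dp dimension
    node_batch = {k: v.reshape(bs_shape + v.shape[1:]) for k, v in node_batch.items()}
    print(node_batch["input_ids"].shape)                                # (8, 2, 64)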
 
1222
  )
1223
  wandb.run.log_artifact(artifact_state)
1224
 
1225
+ logger.info(" Ready to start training")
1226
  with maps.mesh(mesh.devices, mesh.axis_names):
1227
  for epoch in epochs:
1228
  state.replace(epoch=epoch)
 
1243
  position=1,
1244
  leave=False,
1245
  total=steps_per_epoch,
1246
+ disable=jax.process_index() > 0,
1247
  ):
1248
  # calculate delta time (we have a lag of one step but it's ok)
1249
  new_time = time.perf_counter()
1250
  delta_time = new_time - last_time
1251
  last_time = new_time
1252
 
1253
+ # set correct shape to batch
1254
+ # - add grad_step dim if gradient_accumulation_steps > 1
1255
+ # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
1256
+ bs_shape = (
1257
+ (batch_size_per_node_per_grad_step,)
1258
+ if not use_vmap_trick
1259
+ else (
1260
+ jax.local_device_count()
1261
+ // training_args.mp_devices, # local dp devices
1262
+ training_args.per_device_train_batch_size,
1263
+ )
1264
+ )
1265
+ if training_args.gradient_accumulation_steps > 1:
1266
+ # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1267
+ # to avoid any data redistribution when sharding
1268
+ bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape
1269
+
1270
+ # reshape batch
1271
  batch = jax.tree_map(
1272
+ lambda x: x.reshape(bs_shape + x.shape[1:]),
1273
  batch,
1274
  )
1275
  # freeze batch to pass safely to jax transforms
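The reshape above gives each array in the host-local batch the shape train_step expects: a leading gradient-accumulation axis (if any), then one axis per local data-parallel device when the vmap trick is active. A quick NumPy check with illustrative sizes:

    import numpy as np

    local_devices, mp_devices = 8, 1
    per_device_train_batch_size, gradient_accumulation_steps = 2, 4
    local_dp = local_devices // mp_devices

    batch = {"input_ids": np.zeros((gradient_accumulation_steps * local_dp * per_device_train_batch_size, 256))}

    bs_shape = (local_dp, per_device_train_batch_size)           # vmap trick: split per local dp device
    bs_shape = (gradient_accumulation_steps,) + bs_shape         # prepend the accumulation axis
    batch = {k: v.reshape(bs_shape + v.shape[1:]) for k, v in batch.items()}
    print(batch["input_ids"].shape)                              # (4, 8, 2, 256)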