boris committed on
Commit 193c88c
2 Parent(s): f5dba1e 25862e8

Merge pull request #118 from borisdayma/feat-optim

src/dalle_mini/data.py CHANGED
@@ -161,7 +161,7 @@ class Dataset:
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
-            Shuffle batches if `shuffle` is `True`.
+            Shuffle batches if rng is set.
             """
             steps_per_epoch = len(dataset) // batch_size
 
@@ -182,19 +182,20 @@ class Dataset:
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset, batch_size: int, epoch: int
+            dataset: Dataset, split: str, batch_size: int, epoch: int
         ):
-            # epoch is only use for multi-host
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}
-            first_loop = True
-            while self.multi_hosts or first_loop:
+            first_loop = True  # stop after one loop in some cases
+            while (self.multi_hosts and split == "train") or first_loop:
                 # in multi-host, we run forever (no epoch) as hosts need to stop
-                # at the same time and we don't know how much data is on each host
-                if not first_loop:
-                    # multi-host setting, we reshuffle shards
-                    epoch += 1
+                # at the same time and training data may not be split equally
+                # For validation data we put the entire set on each host as we could lose
+                # too many samples on pods
+                if epoch is not None:
+                    # reshuffle training data at each epoch (not applicable with validation set)
                     dataset.set_epoch(epoch)
+                    epoch += 1
                 for item in dataset:
                     for k, v in item.items():
                         batch[k].append(v)
@@ -213,9 +214,7 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            if split == "train":
-                ds.set_epoch(epoch)
-            return _dataloader_datasets_streaming(ds, batch_size, epoch)
+            return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
        else:
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
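
The streaming dataloader above accumulates examples into per-key lists and yields a batch every `batch_size` items, looping forever only for the training split on multi-host setups. Below is a minimal, self-contained sketch of that batching pattern; the `fake_stream` generator and its feature names are hypothetical stand-ins for a real streaming dataset, not code from the repository.

# Minimal sketch of the streaming batching pattern (illustrative only:
# `fake_stream` stands in for a datasets.IterableDataset).
import numpy as np


def fake_stream(n_items=10):
    # yields tokenized-example dicts, like a real streaming dataset would
    for i in range(n_items):
        yield {"input_ids": [i, i + 1], "labels": [i + 1, i + 2]}


def batches_from_stream(stream, batch_size):
    keys = ["input_ids", "labels"]
    batch = {k: [] for k in keys}
    for item in stream:
        for k in keys:
            batch[k].append(item[k])
        if len(batch[keys[0]]) == batch_size:
            # emit a full batch as arrays, then start a new one
            # (incomplete trailing items are dropped in this sketch)
            yield {k: np.array(v) for k, v in batch.items()}
            batch = {k: [] for k in keys}


for b in batches_from_stream(fake_stream(), batch_size=4):
    print({k: v.shape for k, v in b.items()})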
tools/inference/inference_pipeline.ipynb CHANGED
@@ -70,15 +70,15 @@
     "# Model references\n",
     "\n",
     "# dalle-mini\n",
-    "DALLE_MODEL = 'dalle-mini/dalle-mini/model-3bqwu04f:latest' # can be wandb artifact or 🤗 Hub or local folder\n",
+    "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
     "DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
     "\n",
     "# VQGAN model\n",
-    "VQGAN_REPO = 'dalle-mini/vqgan_imagenet_f16_16384'\n",
-    "VQGAN_COMMIT_ID = 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
+    "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+    "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
     "\n",
     "# CLIP model\n",
-    "CLIP_REPO = 'openai/clip-vit-base-patch16'\n",
+    "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
     "CLIP_COMMIT_ID = None"
    ]
   },
@@ -121,18 +121,28 @@
     "import wandb\n",
     "\n",
     "# Load dalle-mini\n",
-    "if ':' in DALLE_MODEL:\n",
+    "if \":\" in DALLE_MODEL:\n",
     "    # wandb artifact\n",
     "    artifact = wandb.Api().artifact(DALLE_MODEL)\n",
     "    # we only download required files (no need for opt_state which is large)\n",
-    "    model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
+    "    model_files = [\n",
+    "        \"config.json\",\n",
+    "        \"flax_model.msgpack\",\n",
+    "        \"merges.txt\",\n",
+    "        \"special_tokens_map.json\",\n",
+    "        \"tokenizer.json\",\n",
+    "        \"tokenizer_config.json\",\n",
+    "        \"vocab.json\",\n",
+    "    ]\n",
     "    for f in model_files:\n",
-    "        artifact.get_path(f).download('model')\n",
-    "    model = DalleBart.from_pretrained('model', dtype=dtype, abstract_init=True)\n",
-    "    tokenizer = AutoTokenizer.from_pretrained('model')\n",
+    "        artifact.get_path(f).download(\"model\")\n",
+    "    model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
     "else:\n",
     "    # local folder or 🤗 Hub\n",
-    "    model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True)\n",
+    "    model = DalleBart.from_pretrained(\n",
+    "        DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
+    "    )\n",
     "    tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
     "\n",
     "# Load VQGAN\n",
@@ -191,7 +201,7 @@
     "from functools import partial\n",
     "\n",
     "# model inference\n",
-    "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3,4))\n",
+    "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
     "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
     "    return model.generate(\n",
     "        **tokenized_prompt,\n",
@@ -203,11 +213,13 @@
     "        top_p=top_p\n",
     "    )\n",
     "\n",
+    "\n",
     "# decode images\n",
     "@partial(jax.pmap, axis_name=\"batch\")\n",
     "def p_decode(indices, params):\n",
     "    return vqgan.decode_code(indices, params=params)\n",
     "\n",
+    "\n",
     "# score images\n",
     "@partial(jax.pmap, axis_name=\"batch\")\n",
     "def p_clip(inputs, params):\n",
@@ -235,7 +247,7 @@
     "import random\n",
     "\n",
     "# create a random key\n",
-    "seed = random.randint(0, 2**32-1)\n",
+    "seed = random.randint(0, 2 ** 32 - 1)\n",
     "key = jax.random.PRNGKey(seed)"
    ]
   },
@@ -287,7 +299,7 @@
   },
   "outputs": [],
   "source": [
-    "prompt = 'a red T-shirt'"
+    "prompt = \"a red T-shirt\""
    ]
   },
   {
@@ -323,7 +335,13 @@
     "repeated_prompts = [processed_prompt] * jax.device_count()\n",
     "\n",
     "# tokenize\n",
-    "tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
+    "tokenized_prompt = tokenizer(\n",
+    "    repeated_prompts,\n",
+    "    return_tensors=\"jax\",\n",
+    "    padding=\"max_length\",\n",
+    "    truncation=True,\n",
+    "    max_length=128,\n",
+    ").data\n",
     "tokenized_prompt"
    ]
   },
@@ -408,12 +426,14 @@
     "    # get a new key\n",
     "    key, subkey = jax.random.split(key)\n",
     "    # generate images\n",
-    "    encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p)\n",
+    "    encoded_images = p_generate(\n",
+    "        tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
+    "    )\n",
     "    # remove BOS\n",
     "    encoded_images = encoded_images.sequences[..., 1:]\n",
     "    # decode images\n",
     "    decoded_images = p_decode(encoded_images, vqgan_params)\n",
-    "    decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
+    "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
     "    for img in decoded_images:\n",
     "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
    ]
@@ -436,7 +456,14 @@
   "outputs": [],
   "source": [
     "# get clip scores\n",
-    "clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+    "clip_inputs = processor(\n",
+    "    text=[prompt] * jax.device_count(),\n",
+    "    images=images,\n",
+    "    return_tensors=\"np\",\n",
+    "    padding=\"max_length\",\n",
+    "    max_length=77,\n",
+    "    truncation=True,\n",
+    ").data\n",
     "logits = p_clip(shard(clip_inputs), clip_params)\n",
     "logits = logits.squeeze().flatten()"
    ]
@@ -458,10 +485,10 @@
   },
   "outputs": [],
   "source": [
-    "print(f'Prompt: {prompt}\\n')\n",
+    "print(f\"Prompt: {prompt}\\n\")\n",
     "for idx in logits.argsort()[::-1]:\n",
     "    display(images[idx])\n",
-    "    print(f'Score: {logits[idx]:.2f}\\n')"
+    "    print(f\"Score: {logits[idx]:.2f}\\n\")"
    ]
   }
  ],
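
The notebook's `p_generate` wraps `model.generate` with `jax.pmap` and declares `top_k`/`top_p` as `static_broadcasted_argnums` because they are plain Python values rather than per-device arrays. Here is a tiny, self-contained illustration of that calling pattern under the same assumptions; the `scale` function is hypothetical and is not the notebook's generation call.

# Toy pmap example with one static argument; `scale` is illustrative only.
from functools import partial

import jax
import numpy as np


@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1,))
def scale(x, factor):
    # `factor` is a static Python value, baked in at trace time,
    # while `x` is split along its leading axis across devices
    return x * factor


n = jax.local_device_count()
x = np.arange(n * 3, dtype=np.float32).reshape(n, 3)  # leading axis = one row per device
print(scale(x, 2))

The same reasoning explains `shard_prng_key(subkey)` in the generation loop: array arguments must carry a leading device axis, so the PRNG key is split into one key per device before the pmapped call.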
tools/train/train.py CHANGED
@@ -65,7 +65,7 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Pretrained config name or path if not the same as model_name"
+            "help": "Pretrained config name or path if not the same as model_name_or_path"
         },
     )
     tokenizer_name: Optional[str] = field(
@@ -77,7 +77,7 @@ class ModelArguments:
     dtype: Optional[str] = field(
         default="float32",
         metadata={
-            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
 
@@ -106,11 +106,15 @@ class DataTrainingArguments:
     )
     train_file: Optional[str] = field(
         default=None,
-        metadata={"help": "The input training data file (glob acceptable)."},
+        metadata={
+            "help": "The input training data file (glob & braceexpand acceptable)."
+        },
     )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={"help": "An optional input evaluation data file (glob acceptable)."},
+        metadata={
+            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
+        },
     )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: Optional[bool] = field(
@@ -132,15 +136,13 @@ class DataTrainingArguments:
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of training examples."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
         },
     )
     preprocessing_num_workers: Optional[int] = field(
@@ -191,42 +193,40 @@ class TrainingArguments:
 
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(
-        default=False, metadata={"help": "Whether to run eval on the dev set."}
+        default=False, metadata={"help": "Whether to run eval on the validation set."}
     )
 
     per_device_train_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
     per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
     )
 
     gradient_accumulation_steps: int = field(
         default=1,
         metadata={
-            "help": "Number of updates steps to accumulate before performing a backward/update pass."
+            "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
     )
-    adafactor: bool = field(
-        default=False,
-        metadata={"help": "Use Adafactor instead of AdamW."},
-    )
-    distributed_shampoo: bool = field(
-        default=False,
-        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
-    )
-    weight_decay: float = field(
-        default=None, metadata={"help": "Weight decay if we apply some."}
+    optim: str = field(
+        default="distributed_shampoo",
+        metadata={
+            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
+        },
     )
-    adam_beta1: float = field(
-        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
+    weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
+    beta1: float = field(
+        default=0.9,
+        metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
    )
-    adam_beta2: float = field(
-        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
+    beta2: float = field(
+        default=0.999,
+        metadata={"help": "Beta2 for Adam & Distributed Shampoo."},
    )
    adam_epsilon: float = field(
        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
@@ -234,9 +234,47 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
-    use_decay: bool = field(
+    block_size: int = field(
+        default=1024,
+        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
+    )
+    preconditioning_compute_steps: int = field(
+        default=10, metadata={"help": "Number of steps to update preconditioner."}
+    )
+    skip_preconditioning_dim_size_gt: int = field(
+        default=4096,
+        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
+    )
+    optim_quantized: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
+        },
+    )
+
+    lr_decay: str = field(
+        default=None,
+        metadata={
+            "help": "Decay to be used in the learning rate scheduler. Can be None (default), linear or exponential."
+        },
+    )
+    lr_transition_steps: int = field(
+        default=None,
+        metadata={
+            "help": "Number of transition steps associated with learning rate decay when using exponential decay."
+        },
+    )
+    lr_decay_rate: float = field(
+        default=None,
+        metadata={
+            "help": "Decay rate associated with learning rate when using exponential decay."
+        },
+    )
+    lr_staircase: bool = field(
         default=False,
-        metadata={"help": "Whether to use decay in the learning rate scheduler."},
+        metadata={
+            "help": "Whether to use staircase or continuous learning rate when using exponential decay."
+        },
     )
 
     num_train_epochs: float = field(
@@ -267,18 +305,18 @@
         },
     )
 
-    push_to_hub: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether or not to upload the trained model to the model hub after training."
-        },
-    )
-
     resume_from_checkpoint: Optional[str] = field(
         default=None,
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    def __post_init__(self):
+        assert self.optim in [
+            "distributed_shampoo",
+            "adam",
+            "adafactor",
+        ], f"Selected optimizer not supported: {self.optim}"
+
 
 class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray = None
@@ -309,33 +347,6 @@ class TrainState(train_state.TrainState):
     )
 
 
-def create_learning_rate_fn(
-    num_warmup_steps: int,
-    learning_rate: float,
-    use_decay: bool,
-    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
-) -> Callable[[int], jnp.array]:
-    """Returns a linear warmup, linear_decay learning rate function."""
-    if use_decay:
-        assert (
-            num_train_steps is not None
-        ), "Learning rate with decay requires number of training steps"
-    warmup_fn = optax.linear_schedule(
-        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
-    )
-    if not use_decay:
-        return warmup_fn
-    decay_fn = optax.linear_schedule(
-        init_value=learning_rate,
-        end_value=0,
-        transition_steps=num_train_steps - num_warmup_steps,
-    )
-    schedule_fn = optax.join_schedules(
-        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
-    )
-    return schedule_fn
-
-
 class MetricsLogger:
     def __init__(self, state):
         self.step = state.step
@@ -529,12 +540,37 @@ def main():
     num_params = model.num_params
 
     # Create learning rate schedule
-    learning_rate_fn = create_learning_rate_fn(
-        training_args.warmup_steps,
-        training_args.learning_rate,
-        training_args.use_decay,
-        num_train_steps,
-    )
+    def create_learning_rate_fn() -> Callable[[int], jnp.array]:
+        """Create the learning rate function."""
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0,
+            end_value=training_args.learning_rate,
+            transition_steps=training_args.warmup_steps,
+        )
+        if training_args.lr_decay is None:
+            return warmup_fn
+        elif training_args.lr_decay == "linear":
+            assert (
+                num_train_steps is not None
+            ), "linear decay requires knowing the dataset length"
+            decay_fn = optax.linear_schedule(
+                init_value=training_args.learning_rate,
+                end_value=0,
+                transition_steps=num_train_steps - training_args.warmup_steps,
+            )
+        elif training_args.lr_decay == "exponential":
+            decay_fn = optax.exponential_decay(
+                init_value=training_args.learning_rate,
+                transition_steps=training_args.lr_transition_steps,
+                decay_rate=training_args.lr_decay_rate,
+                staircase=training_args.lr_staircase,
+            )
+        schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+        )
+        return schedule_fn
+
+    learning_rate_fn = create_learning_rate_fn()
 
     # We use Optax's "masking" functionality to not apply weight decay
     # to bias and LayerNorm scale parameters. decay_mask_fn returns a
@@ -558,29 +594,22 @@
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    if training_args.adafactor:
-        # We use the default parameters here to initialize adafactor,
-        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
-        optimizer = optax.adafactor(
-            learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
-            clipping_threshold=training_args.max_grad_norm,
-        )
-    elif training_args.distributed_shampoo:
+    if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
-        # - mask for weight decay is not implemented but we don't use it anyway
+        # - mask for weight decay is not implemented
        optimizer = distributed_shampoo(
            learning_rate_fn,
-            block_size=1024,  # recommended default for large LM is 1536
-            beta1=0.9,
-            beta2=0.999,
+            block_size=training_args.block_size,
+            beta1=training_args.beta1,
+            beta2=training_args.beta2,
            diagonal_epsilon=1e-10,
            matrix_epsilon=1e-8,
-            weight_decay=0.0,
-            start_preconditioning_step=1001,
-            preconditioning_compute_steps=10,
+            weight_decay=training_args.weight_decay
+            if training_args.weight_decay is not None
+            else 0.0,
+            start_preconditioning_step=training_args.warmup_steps,
+            preconditioning_compute_steps=training_args.preconditioning_compute_steps,
            statistics_compute_steps=1,
            best_effort_shape_interpretation=True,
            graft_type=GraftingType.RMSPROP_NORMALIZED,
@@ -589,23 +618,32 @@
            batch_axis_name="batch",
            inverse_failure_threshold=0.1,
            moving_average_for_momentum=True,
-            skip_preconditioning_dim_size_gt=4096,
+            skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
            clip_by_scaled_gradient_norm=None,
            precision=jax.lax.Precision.HIGHEST,
-            best_effort_memory_usage_reduction=False,
+            best_effort_memory_usage_reduction=training_args.optim_quantized,
        )
 
-    else:
+    elif training_args.optim == "adam":
        optimizer = optax.adamw(
            learning_rate=learning_rate_fn,
-            b1=training_args.adam_beta1,
-            b2=training_args.adam_beta2,
+            b1=training_args.beta1,
+            b2=training_args.beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay
            if training_args.weight_decay is not None
            else 0.0,
            mask=decay_mask_fn,
        )
+    elif training_args.optim == "adafactor":
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
+        )
 
     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:
@@ -821,16 +859,6 @@
 
             wandb.run.log_artifact(artifact)
 
-            # save to the hub
-            if training_args.push_to_hub:
-                model.save_pretrained(
-                    training_args.output_dir,
-                    params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs at step {unreplicate(state.step)+1}",
-                    temp_dir=True,  # avoid issues with being in a repository
-                )
-
     # init variables
     last_time = time.perf_counter()
     train_metrics = None
@@ -841,7 +869,7 @@
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size)
+        train_loader = dataset.dataloader("train", train_batch_size, epoch)
        # train
        for batch in tqdm(
            train_loader,
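
The training script now builds its schedule from `--warmup_steps` plus the new `--lr_decay` options instead of the old `use_decay` flag. Below is a short, standalone sketch of the same warmup-then-exponential-decay wiring with optax; the numeric constants are placeholders for the trainer's arguments, not its defaults.

# Standalone sketch of the warmup + exponential decay schedule
# (values are illustrative, not the trainer's settings).
import optax

warmup_steps = 100
learning_rate = 5e-5

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
decay_fn = optax.exponential_decay(
    init_value=learning_rate,
    transition_steps=1000,  # would come from --lr_transition_steps
    decay_rate=0.8,         # would come from --lr_decay_rate
    staircase=False,        # would come from --lr_staircase
)
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)

for step in (0, 50, 100, 500, 2000):
    print(step, float(schedule_fn(step)))

Because optax optimizers accept a schedule wherever they accept a constant learning rate, passing such a `schedule_fn` in place of `learning_rate_fn` (to `optax.adamw`, `optax.adafactor`, or `distributed_shampoo`) gives a per-step learning rate without further plumbing.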