boris committed on
Commit 6523a6d
1 Parent(s): 87fac28

feat: add metrics + cleanup

Files changed (1)
  1. dev/seq2seq/run_seq2seq_flax.py +83 -81
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -21,6 +21,7 @@ Script adapted from run_summarization_flax.py
 import os
 import logging
 import sys
+import time
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional
@@ -37,7 +38,6 @@ import optax
 import transformers
 from flax import jax_utils, traverse_util
 from flax.serialization import from_bytes, to_bytes
-import flax.linen as nn
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
@@ -136,14 +136,6 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Whether to stream the dataset."},
     )
-    len_train: Optional[int] = field(
-        default=None,
-        metadata={"help": "Length of training dataset, required for streaming"},
-    )
-    len_eval: Optional[int] = field(
-        default=None,
-        metadata={"help": "Length of validation dataset, required for streaming"},
-    )
     max_source_length: Optional[int] = field(
         default=128,
         metadata={
@@ -189,10 +181,6 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Log frequency for model"},
     )
-    save_model_steps: Optional[int] = field(
-        default=5000,
-        metadata={"help": "For saving/logging the model more frequently"},
-    )
 
     def __post_init__(self):
         if self.dataset_repo_or_path is None:
@@ -201,6 +189,9 @@ class DataTrainingArguments:
 
 class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray = None
+    epoch: int = 0
+    train_time: float = 0.0  # total time the model trained
+    train_samples: int = 0  # number of samples seen
 
     def replicate(self):
         return jax_utils.replicate(self).replace(
@@ -212,13 +203,17 @@ class TrainState(train_state.TrainState):
         with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
             new_opt_state = from_bytes(self.opt_state, f.read())
 
-        # restore steps
+        # restore other parameters
        with (Path(artifact_dir) / "training_state.json").open("r") as f:
             training_state = json.load(f)
-        new_step = training_state["step"]
 
         # replace state
-        return self.replace(step=new_step, opt_state=new_opt_state)
+        return self.replace(
+            opt_state=new_opt_state,
+            step=training_state["step"],
+            train_time=training_state["train_time"],
+            train_samples=training_state["train_samples"],
+        )
 
 
 def data_loader(
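Note: restore_state now expects training_state.json to carry the new counters alongside the step. A minimal sketch of a compatible file, written out as a Python dict (only the keys are implied by the diff; the values are made up for illustration):

# hypothetical contents of training_state.json, for illustration only
training_state = {
    "step": 25000,
    "epoch": 2,
    "train_time": 86400.0,  # seconds of training so far
    "train_samples": 1_600_000,
}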
@@ -259,16 +254,16 @@ def data_loader_streaming(dataset: Dataset, batch_size: int):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int,
-    train_batch_size: int,
-    num_train_epochs: int,
     num_warmup_steps: int,
     learning_rate: float,
     use_decay: bool,
+    num_train_steps: int = None,  # used only with `use_decay`, typically train_size // batch_size * num_epochs
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
-    steps_per_epoch = train_ds_size // train_batch_size
-    num_train_steps = steps_per_epoch * num_train_epochs
+    if use_decay:
+        assert (
+            num_train_steps is not None
+        ), "Learning rate with decay requires number of training steps"
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
     )
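Note: only the warmup part of the schedule is visible in this hunk. A rough sketch of how a warmup-then-optional-linear-decay schedule is usually assembled with optax (the helper name and the decay-to-zero choice are assumptions, not necessarily the file's exact body):

# sketch: linear warmup followed by optional linear decay, assembled with optax
import optax

def make_schedule(learning_rate, num_warmup_steps, use_decay, num_train_steps=None):
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
    )
    if not use_decay:
        return warmup_fn
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0.0,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])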
@@ -364,7 +359,6 @@ def main():
         project="dalle-mini",
         job_type="Seq2Seq",
         config=parser.parse_args(),
-        save_code=True,
     )
 
     if model_args.from_checkpoint is not None:
@@ -562,35 +556,26 @@ def main():
     )
     batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+    len_train_dataset, len_eval_dataset = None, None
     if data_args.streaming:
-        len_train_dataset = data_args.len_train
-        if (
-            data_args.max_train_samples is not None
-            and data_args.max_train_samples < len_train_dataset
-        ):
+        # we don't know the length, let's just assume max_samples if defined
+        if data_args.max_train_samples is not None:
             len_train_dataset = data_args.max_train_samples
-
-        len_eval_dataset = data_args.len_eval
-        if (
-            data_args.max_eval_samples is not None
-            and data_args.max_eval_samples < len_eval_dataset
-        ):
+        if data_args.max_eval_samples is not None:
             len_eval_dataset = data_args.max_eval_samples
     else:
         len_train_dataset = len(train_dataset)
         len_eval_dataset = len(eval_dataset)
-    steps_per_epoch = len_train_dataset // train_batch_size
-    total_steps = steps_per_epoch * num_epochs
-    total_optimization_steps = (len_train_dataset // batch_size_per_update) * num_epochs
+    steps_per_epoch = (
+        len_train_dataset // train_batch_size if len_train_dataset is not None else None
+    )
 
     # Create learning rate schedule
     learning_rate_fn = create_learning_rate_fn(
-        len_train_dataset,
-        train_batch_size,
-        training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
         data_args.use_decay,
+        steps_per_epoch * num_epochs,
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
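Note: a quick worked example of the numbers fed into create_learning_rate_fn above (all values hypothetical):

# hypothetical sizes, just to illustrate the arithmetic
len_train_dataset = 1_000_000
train_batch_size = 256
num_epochs = 3
steps_per_epoch = len_train_dataset // train_batch_size  # 3906
num_train_steps = steps_per_epoch * num_epochs  # 11718, passed as the last argument above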
@@ -621,7 +606,7 @@ def main():
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
             weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn
+            weight_decay_mask=decay_mask_fn,
         )
     else:
         optimizer = optax.adamw(
@@ -647,10 +632,9 @@ def main():
         dropout_rng=dropout_rng,
     )
     if model_args.from_checkpoint is not None:
-        # restore optimizer state and step
+        # restore optimizer state and other parameters
+        # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
         state = state.restore_state(artifact_dir)
-        # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
-        # TODO: optimizer may use a different step for learning rate, we should serialize/restore entire state
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
@@ -659,7 +643,7 @@ def main():
         return loss
 
     # Define gradient update step fn
-    def train_step(state, batch):
+    def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
         def compute_loss(params, batch):
@@ -673,14 +657,20 @@ def main():
         grad_fn = jax.value_and_grad(compute_loss)
         loss, grads = grad_fn(state.params, batch)
         grads = jax.lax.pmean(grads, "batch")
-        state = state.apply_gradients(grads=grads)
+        state = state.apply_gradients(
+            grads=grads,
+            dropout_rng=new_dropout_rng,
+            train_time=state.train_time + delta_time,
+            train_samples=state.train_samples + train_batch_size,
+        )
 
         metrics = {
             "loss": loss,
             "learning_rate": learning_rate_fn(state.step),
         }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
-        return state.replace(dropout_rng=new_dropout_rng), metrics
+
+        return state, metrics
 
     # Define eval fn
     def eval_step(params, batch):
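Note: passing dropout_rng, train_time and train_samples straight into apply_gradients works because flax's TrainState.apply_gradients forwards unknown keyword arguments to .replace(). A simplified sketch of the behaviour this relies on (see the flax source for the real implementation):

# simplified view of flax.training.train_state.TrainState.apply_gradients
def apply_gradients(self, *, grads, **kwargs):
    updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
    new_params = optax.apply_updates(self.params, updates)
    return self.replace(
        step=self.step + 1,
        params=new_params,
        opt_state=new_opt_state,
        **kwargs,  # e.g. dropout_rng, train_time, train_samples
    )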
@@ -697,10 +687,6 @@ def main():
     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
     p_eval_step = jax.pmap(eval_step, "batch")
 
-    # Replicate the train state on each device
-    del model._params
-    state = state.replicate()
-
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len_train_dataset}")
     logger.info(f"  Num Epochs = {num_epochs}")
@@ -710,13 +696,12 @@ def main():
     logger.info(
         f"  Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
     )
-    logger.info(f"  Total global steps = {total_steps}")
-    logger.info(f"  Total optimization steps = {total_optimization_steps}")
-
-    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+    epochs = tqdm(
+        range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
+    )
 
     # set default x-axis as 'train/step'
-    wandb_log({}, step=unreplicate(state.step))
+    wandb_log({}, step=state.step)
     wandb.define_metric("*", step_metric="train/step")
 
     # add interesting config parameters
@@ -725,11 +710,12 @@ def main():
             "len_train": len_train_dataset,
             "len_eval": len_eval_dataset,
             "batch_size_per_update": batch_size_per_update,
-            "total_steps": total_steps,
-            "total_optimization_steps": total_optimization_steps,
         }
     )
 
+    # replicate state on each device
+    state = state.replicate()
+
     def run_evaluation():
         # ======================== Evaluating ==============================
         eval_metrics = []
@@ -755,7 +741,7 @@ def main():
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
         # log metrics
-        wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
+        wandb_log(eval_metrics, step=get_metrics(state.step), prefix="eval")
 
         # Print metrics and update progress bar
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
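Note: wandb_log is a small helper defined elsewhere in this file. A hedged sketch of the pattern it follows (the exact body may differ):

# rough sketch of the wandb_log helper used above: prefix the keys and pin the step metric
def wandb_log(metrics, step=None, prefix=None):
    if jax.process_index() == 0:
        log_metrics = {
            f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
        }
        if step is not None:
            log_metrics["train/step"] = step
        wandb.log(log_metrics)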
@@ -764,10 +750,9 @@ def main():
 
         return eval_metrics
 
-    def run_save_model(state, step, epoch, eval_metrics=None):
+    def run_save_model(state, eval_metrics=None):
         if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-
+            params = jax.device_get(unreplicate(state.params))
             # save model locally
             model.save_pretrained(
                 training_args.output_dir,
@@ -778,24 +763,32 @@ def main():
             tokenizer.save_pretrained(training_args.output_dir)
 
             # save state
-            # TODO: maybe we should just save the full state object without params
-            state = unreplicate(state)
+            opt_state = unreplicate(state.opt_state)
             with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
-                f.write(to_bytes(state.opt_state))
+                f.write(to_bytes(opt_state))
             with (Path(training_args.output_dir) / "training_state.json").open(
                 "w"
             ) as f:
-                json.dump({"step": state.step.item()}, f)
+                json.dump(
+                    {
+                        k: get_metrics(state[k])
+                        for k in ["step", "epoch", "train_time", "train_samples"]
+                    },
+                    f,
+                )
 
             # save to W&B
             if data_args.log_model:
                 # save some space
                 c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
-                c.cleanup(wandb.util.from_human_size("5GB"))
+                c.cleanup(wandb.util.from_human_size("10GB"))
 
-                metadata = {"step": step, "epoch": epoch}
+                metadata = {
+                    k: get_metrics(state[k])
+                    for k in ["step", "epoch", "train_time", "train_samples"]
+                }
                 if eval_metrics is not None:
-                    metadata["eval/loss"] = eval_metrics["loss"]
+                    metadata["eval"] = eval_metrics
                 artifact = wandb.Artifact(
                     name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
                 )
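Note: the rest of this block (outside the hunk) is expected to attach the files saved above to the artifact and log it. A hedged sketch of the usual wandb pattern (which files are attached is an assumption):

# sketch of logging the checkpoint files as a W&B artifact
for filename in ["flax_model.msgpack", "config.json", "opt_state.msgpack", "training_state.json"]:
    artifact.add_file(str(Path(training_args.output_dir) / filename))
wandb.run.log_artifact(artifact)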
@@ -829,18 +822,19 @@ def main():
                 training_args.output_dir,
                 params=params,
                 push_to_hub=training_args.push_to_hub,
-                commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                commit_message=f"Saving weights and logs at step {get_metrics(state.step)+1}",
                 temp_dir=True,  # avoid issues with being in a repository
             )
 
+    last_time = time.perf_counter()
     for epoch in epochs:
+        state.replace(epoch=jax_utils.replicate(epoch))
         # ======================== Training ================================
-        step = unreplicate(state.step)
-        wandb_log({"train/epoch": epoch}, step=step)
+        wandb_log({"train/epoch": epoch}, step=get_metrics(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         if data_args.streaming:
-            train_dataset.set_epoch(epoch)
+            train_dataset.set_epoch(epoch)  # shuffle dataset
             train_loader = data_loader_streaming(train_dataset, train_batch_size)
         else:
             rng, input_rng = jax.random.split(rng)
@@ -855,23 +849,31 @@ def main():
             leave=False,
             total=steps_per_epoch,
         ):
-            state, train_metric = p_train_step(state, batch)
-            step = unreplicate(state.step)
+
+            # calculate delta time (we have a lag of one step but it's ok)
+            new_time = time.perf_counter()
+            delta_time = new_time - last_time
+            last_time = new_time
+
+            # train step
+            state, train_metric = p_train_step(state, batch, delta_time)
+            step = get_metrics(state.step)
 
             if step % data_args.log_interval == 0 and jax.process_index() == 0:
                 # log metrics
-                wandb_log(unreplicate(train_metric), step=step, prefix="train")
+                wandb_log(get_metrics(train_metric), step=step, prefix="train")
 
+            eval_metrics = None
             if training_args.eval_steps and step % training_args.eval_steps == 0:
-                run_evaluation()
+                eval_metrics = run_evaluation()
 
-            if step % data_args.save_model_steps == 0:
-                run_save_model(state, step, epoch)
+            if step % training_args.save_steps == 0:
+                run_save_model(state, eval_metrics)
 
         # log final train metrics
-        wandb_log(unreplicate(train_metric), step=step, prefix="train")
+        train_metric = get_metrics(train_metric)
+        wandb_log(train_metric, step=step, prefix="train")
 
-        train_metric = unreplicate(train_metric)
         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
         )
@@ -880,7 +882,7 @@ def main():
         eval_metrics = run_evaluation()
 
         # save checkpoint after each epoch
-        run_save_model(state, state.step, epoch, eval_metrics)
+        run_save_model(state, eval_metrics)
 
 
 if __name__ == "__main__":