boris committed on
Commit
fdbe19f
2 Parent(s): 4a4820f 5f6b691

Merge pull request #90 from borisdayma/feat-new

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +18 -31
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -100,12 +100,6 @@ class ModelArguments:
100
  "help": "Pretrained config name or path if not the same as model_name"
101
  },
102
  )
103
- tokenizer_name: Optional[str] = field(
104
- default=None,
105
- metadata={
106
- "help": "Pretrained tokenizer name or path if not the same as model_name"
107
- },
108
- )
109
  cache_dir: Optional[str] = field(
110
  default=None,
111
  metadata={
@@ -422,7 +416,7 @@ def wandb_log(metrics, step=None, prefix=None):
422
  f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
423
  }
424
  if step is not None:
425
- log_metrics["train/step"] = unreplicate(step)
426
  wandb.log(log_metrics)
427
 
428
 
@@ -534,11 +528,6 @@ def main():
534
  )
535
 
536
  else:
537
- base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
538
- model_args.model_name_or_path,
539
- seed=training_args.seed,
540
- dtype=getattr(jnp, model_args.dtype),
541
- )
542
  # Set up our new model config
543
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
544
  config.tie_word_embeddings = False
@@ -563,11 +552,6 @@ def main():
563
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
564
  )
565
 
566
- # Use pre-trained weights for encoder
567
- model.params["model"]["encoder"] = base_model.params["model"]["encoder"]
568
- model.params["model"]["shared"] = base_model.params["model"]["shared"]
569
- del base_model
570
-
571
  # Load tokenizer if it has not been set
572
  if tokenizer is None:
573
  tokenizer = AutoTokenizer.from_pretrained(
@@ -862,7 +846,7 @@ def main():
862
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
863
  )
864
  logger.info(
865
- f" Total train batch size (w. parallel & distributed) = {batch_size_per_update}"
866
  )
867
  logger.info(f" Total global steps = {total_steps}")
868
  logger.info(f" Total optimization steps = {total_optimization_steps}")
@@ -870,7 +854,7 @@ def main():
870
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
871
 
872
  # set default x-axis as 'train/step'
873
- wandb_log({}, step=state.step)
874
  wandb.define_metric("*", step_metric="train/step")
875
 
876
  # add interesting config parameters
@@ -909,7 +893,7 @@ def main():
909
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
910
 
911
  # log metrics
912
- wandb_log(eval_metrics, step=state.step, prefix="eval")
913
 
914
  # Print metrics and update progress bar
915
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -943,6 +927,10 @@ def main():
943
 
944
  # save to W&B
945
  if data_args.log_model:
 
 
 
 
946
  metadata = {"step": step, "epoch": epoch}
947
  if eval_metrics is not None:
948
  metadata["eval/loss"] = eval_metrics["loss"]
@@ -970,11 +958,8 @@ def main():
970
  artifact.add_file(
971
  str(Path(training_args.output_dir) / "training_state.json")
972
  )
973
- wandb.run.log_artifact(artifact)
974
 
975
- # save some space
976
- c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
977
- c.cleanup(wandb.util.from_human_size("5GB"))
978
 
979
  # save to the hub
980
  if training_args.push_to_hub:
@@ -988,7 +973,8 @@ def main():
988
 
989
  for epoch in epochs:
990
  # ======================== Training ================================
991
- wandb_log({"train/epoch": epoch}, step=state.step)
 
992
 
993
  # Create sampling rng
994
  rng, input_rng = jax.random.split(rng)
@@ -1010,19 +996,20 @@ def main():
1010
  total=steps_per_epoch,
1011
  ):
1012
  state, train_metric = p_train_step(state, batch)
 
1013
 
1014
- if state.step % data_args.log_interval == 0 and jax.process_index() == 0:
1015
  # log metrics
1016
- wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
1017
 
1018
- if training_args.eval_steps and state.step % training_args.eval_steps == 0:
1019
  run_evaluation()
1020
 
1021
- if state.step % data_args.save_model_steps == 0:
1022
- run_save_model(state, state.step, epoch)
1023
 
1024
  # log final train metrics
1025
- wandb_log(unreplicate(train_metric), step=state.step, prefix="train")
1026
 
1027
  train_metric = unreplicate(train_metric)
1028
  epochs.write(
 
100
  "help": "Pretrained config name or path if not the same as model_name"
101
  },
102
  )
 
 
 
 
 
 
103
  cache_dir: Optional[str] = field(
104
  default=None,
105
  metadata={
 
416
  f"{prefix}/{k}" if prefix is not None else k: v for k, v in metrics.items()
417
  }
418
  if step is not None:
419
+ log_metrics["train/step"] = step
420
  wandb.log(log_metrics)
421
 
422
 
 
528
  )
529
 
530
  else:
 
 
 
 
 
531
  # Set up our new model config
532
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
533
  config.tie_word_embeddings = False
 
552
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
553
  )
554
 
 
 
 
 
 
555
  # Load tokenizer if it has not been set
556
  if tokenizer is None:
557
  tokenizer = AutoTokenizer.from_pretrained(
 
846
  f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
847
  )
848
  logger.info(
849
+ f" Total train batch size (w. parallel, distributed & gradient accumulation) = {batch_size_per_update}"
850
  )
851
  logger.info(f" Total global steps = {total_steps}")
852
  logger.info(f" Total optimization steps = {total_optimization_steps}")
 
854
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
855
 
856
  # set default x-axis as 'train/step'
857
+ wandb_log({}, step=unreplicate(state.step))
858
  wandb.define_metric("*", step_metric="train/step")
859
 
860
  # add interesting config parameters
 
893
  eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
894
 
895
  # log metrics
896
+ wandb_log(eval_metrics, step=unreplicate(state.step), prefix="eval")
897
 
898
  # Print metrics and update progress bar
899
  desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
 
927
 
928
  # save to W&B
929
  if data_args.log_model:
930
+ # save some space
931
+ c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
932
+ c.cleanup(wandb.util.from_human_size("5GB"))
933
+
934
  metadata = {"step": step, "epoch": epoch}
935
  if eval_metrics is not None:
936
  metadata["eval/loss"] = eval_metrics["loss"]
 
958
  artifact.add_file(
959
  str(Path(training_args.output_dir) / "training_state.json")
960
  )
 
961
 
962
+ wandb.run.log_artifact(artifact)
 
 
963
 
964
  # save to the hub
965
  if training_args.push_to_hub:
 
973
 
974
  for epoch in epochs:
975
  # ======================== Training ================================
976
+ step = unreplicate(state.step)
977
+ wandb_log({"train/epoch": epoch}, step=step)
978
 
979
  # Create sampling rng
980
  rng, input_rng = jax.random.split(rng)
 
996
  total=steps_per_epoch,
997
  ):
998
  state, train_metric = p_train_step(state, batch)
999
+ step = unreplicate(state.step)
1000
 
1001
+ if step % data_args.log_interval == 0 and jax.process_index() == 0:
1002
  # log metrics
1003
+ wandb_log(unreplicate(train_metric), step=step, prefix="train")
1004
 
1005
+ if training_args.eval_steps and step % training_args.eval_steps == 0:
1006
  run_evaluation()
1007
 
1008
+ if step % data_args.save_model_steps == 0:
1009
+ run_save_model(state, step, epoch)
1010
 
1011
  # log final train metrics
1012
+ wandb_log(unreplicate(train_metric), step=step, prefix="train")
1013
 
1014
  train_metric = unreplicate(train_metric)
1015
  epochs.write(