boris committed
Commit a1c047b
Parents (2): b29bab7 b20769d

Merge pull request #22 from borisdayma/feat-axis

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +25 -57
seq2seq/run_seq2seq_flax.py CHANGED
@@ -57,7 +57,6 @@ from transformers import (
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
-    is_tensorboard_available,
 )
 from transformers.models.bart.modeling_flax_bart import *
 from transformers.file_utils import is_offline_mode
@@ -229,12 +228,6 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    eval_interval: Optional[int] = field(
-        default=400,
-        metadata={
-            "help": "Evaluation will be performed every eval_interval steps"
-        },
-    )
     log_model: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -327,19 +320,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
-    summary_writer.scalar("train_time", train_time, step)
-
-    train_metrics = get_metrics(train_metrics)
-    for key, vals in train_metrics.items():
-        tag = f"train_epoch/{key}"
-        for i, val in enumerate(vals):
-            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-    for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval/{metric_name}", value, step)
-
-
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
@@ -356,6 +336,14 @@ def create_learning_rate_fn(
     return schedule_fn
 
 
+def wandb_log(metrics, step=None, prefix=None):
+    if jax.process_index() == 0:
+        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
+        if step is not None:
+            log_metrics = {**log_metrics, 'train/step': step}
+        wandb.log(log_metrics)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -369,6 +357,9 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    logger.warning(f"eval_steps has been manually hardcoded")  # TODO: remove it later, convenient for now
+    training_args.eval_steps = 400
+
     if (
         os.path.exists(training_args.output_dir)
         and os.listdir(training_args.output_dir)
@@ -382,13 +373,16 @@ def main():
 
     # Set up wandb run
     wandb.init(
-        sync_tensorboard=True,
         entity='wandb',
         project='hf-flax-dalle-mini',
         job_type='Seq2SeqVQGAN',
         config=parser.parse_args()
     )
 
+    # set default x-axis as 'train/step'
+    wandb.define_metric('train/step')
+    wandb.define_metric('*', step_metric='train/step')
+
     # Make one log on every process with the configuration for debugging.
     pylogging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -583,24 +577,6 @@ def main():
         result = {k: round(v, 4) for k, v in result.items()}
         return result
 
-    # Enable tensorboard only on the master node
-    has_tensorboard = is_tensorboard_available()
-    if has_tensorboard and jax.process_index() == 0:
-        try:
-            from flax.metrics.tensorboard import SummaryWriter
-
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-        except ImportError as ie:
-            has_tensorboard = False
-            logger.warning(
-                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
-            )
-    else:
-        logger.warning(
-            "Unable to display metrics through TensorBoard because the package is not installed: "
-            "Please run pip install tensorboard to enable."
-        )
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)
@@ -780,10 +756,8 @@ def main():
         eval_metrics = get_metrics(eval_metrics)
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
 
-        if jax.process_index() == 0:
-            for k, v in eval_metrics.items():
-                wandb.log({"eval/step": global_step})
-                wandb.log({f"eval/{k}": jax.device_get(v)})
+        # log metrics
+        wandb_log(eval_metrics, step=global_step, prefix='eval')
 
         # compute ROUGE metrics
         rouge_desc = ""
@@ -796,6 +770,7 @@ def main():
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
         epochs.write(desc)
         epochs.desc = desc
+
         return eval_metrics
 
     for epoch in epochs:
@@ -804,7 +779,6 @@ def main():
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
@@ -814,32 +788,26 @@ def main():
             global_step +=1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
-            train_metrics.append(train_metric)
 
             if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
-                print("logging train loss")
-                for k, v in unreplicate(train_metric).items():
-                    wandb.log({"train/step": global_step})
-                    wandb.log({f"train/{k}": jax.device_get(v)})
+                # log metrics
+                wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
-            if global_step % data_args.eval_interval == 0 and jax.process_index() == 0:
+            if global_step % training_args.eval_steps == 0:
                 run_evaluation()
+
+        # log final train metrics
+        wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
         train_time += time.time() - train_start
-
         train_metric = unreplicate(train_metric)
-
         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
         )
 
+        # Final evaluation
         eval_metrics = run_evaluation()
 
-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
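The diff above drops the TensorBoard path entirely, funnels all logging through the new wandb_log helper, and pins every chart to a shared 'train/step' x-axis via wandb.define_metric. Below is a minimal standalone sketch of that pattern; the project name and logged values are placeholders, not taken from this repository:

import jax
import jax.numpy as jnp
import wandb

def wandb_log(metrics, step=None, prefix=None):
    # Log only from the first process; pull values off accelerator devices first.
    if jax.process_index() == 0:
        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k, v in metrics.items()}
        if step is not None:
            log_metrics = {**log_metrics, 'train/step': step}
        wandb.log(log_metrics)

wandb.init(project='wandb-axis-demo')  # placeholder project name

# Make 'train/step' the default x-axis for all logged metrics.
wandb.define_metric('train/step')
wandb.define_metric('*', step_metric='train/step')

# Train and eval metrics logged at the same step now share one x-axis.
wandb_log({'loss': jnp.array(0.71)}, step=400, prefix='train')
wandb_log({'loss': jnp.array(0.93)}, step=400, prefix='eval')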