cahya committed
Commit db85c97
Parent: c3ae139

add wandb integration

Files changed (1): run_clm_flax.py (+18 -0)
run_clm_flax.py CHANGED
@@ -53,6 +53,7 @@ from transformers import (
     is_tensorboard_available,
 )
 from transformers.testing_utils import CaptureLogger
+import wandb
 
 
 logger = logging.getLogger(__name__)
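
Note: wandb is not a dependency of the example scripts, and this commit imports it unconditionally. A possible guard (a sketch, not part of the commit) in the style of the script's existing is_tensorboard_available check would keep the script usable without wandb installed:

    try:
        import wandb  # optional Weights & Biases logging

        has_wandb = True
    except ImportError:
        has_wandb = False
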
@@ -232,6 +233,13 @@ def main():
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
+    if jax.process_index() == 0:
+        wandb.init(
+            entity=os.getenv("WANDB_ENTITY", "indonesian-nlp"),
+            project=os.getenv("WANDB_PROJECT", "huggingface"),
+            sync_tensorboard=True,
+        )
+
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
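
With sync_tensorboard=True, wandb mirrors the scalars the script already writes through its TensorBoard SummaryWriter, so the existing train/eval metrics reach W&B without further wandb.log calls. The entity and project fall back to hard-coded defaults ("indonesian-nlp" / "huggingface"); to log to a different workspace, set the environment variables before wandb.init runs. A minimal sketch (the values here are placeholders, not part of the commit):

    import os

    # Read by the wandb.init call above; replace with your own workspace.
    os.environ["WANDB_ENTITY"] = "my-team"
    os.environ["WANDB_PROJECT"] = "my-clm-experiments"
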
@@ -250,6 +258,13 @@ def main():
             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
             "Use --overwrite_output_dir to overcome."
         )
+    # log your configs with wandb.config, which accepts a dict
+    if jax.process_index() == 0:
+        wandb.config.update(training_args)  # optional, log your configs
+        wandb.config.update(model_args)  # optional, log your configs
+        wandb.config.update(data_args)  # optional, log your configs
+
+        wandb.config['test_log'] = 12345  # log additional things
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
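
wandb.config.update accepts plain dicts as well as argument objects, which is why the parsed HfArgumentParser dataclasses can, in practice, be passed in directly. An equivalent, more explicit variant (a sketch, assuming model_args and data_args are the flat dataclasses defined in this script):

    from dataclasses import asdict

    if jax.process_index() == 0:
        wandb.config.update(asdict(model_args))
        wandb.config.update(asdict(data_args))
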
@@ -591,6 +606,8 @@ def main():
                 epochs.write(
                     f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                 )
+                if jax.process_index() == 0:
+                    wandb.log({'my_metric': train_metrics})
 
                 train_metrics = []
 
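As written, this logs the entire train_metrics list (one metrics dict per accumulated step, with values still held as JAX device arrays) under the single key 'my_metric', which W&B will likely not chart as a scalar time series. A sketch of logging the same information as plain scalars keyed to the training step (train_metric and cur_step as in the surrounding code):

    if jax.process_index() == 0:
        wandb.log(
            {
                "train/loss": float(train_metric["loss"].mean()),
                "train/learning_rate": float(train_metric["learning_rate"].mean()),
            },
            step=cur_step,
        )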
 
@@ -623,6 +640,7 @@ def main():
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
             write_eval_metric(summary_writer, eval_metrics, cur_step)
+            wandb.log({'my_metric': eval_metrics})
 
         if cur_step % training_args.save_steps == 0 and cur_step > 0:
             # save checkpoint after each epoch and push checkpoint to the hub
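
The eval-side call similarly nests the whole eval_metrics dict under one key. A sketch that namespaces each entry and pins the step so train and eval curves align (same assumptions as above; this line already runs under the jax.process_index() == 0 guard):

    wandb.log(
        {f"eval/{k}": float(v) for k, v in eval_metrics.items()},
        step=cur_step,
    )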
 