boris committed on
Commit: 36cb737
Parent: 3cd6d41

fix: comment

Files changed (1):
  dev/seq2seq/run_seq2seq_flax.py (+2 -1)
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -753,6 +753,7 @@ def main():
         # restore optimizer state and step
         state = state.restore_state(artifact_dir)
         # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted
+        # TODO: optimizer may use a different step for learning rate, we should serialize/restore entire state
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
@@ -937,7 +938,7 @@ def main():
     for epoch in epochs:
         # ======================== Training ================================
         step = unreplicate(state.step)
-        # wandb_log({"train/epoch": epoch}, step=step)
+        wandb_log({"train/epoch": epoch}, step=step)
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         if data_args.streaming:
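The new TODO in the first hunk points at a real pitfall when resuming training: if the learning-rate schedule is driven by a step counter kept inside the optimizer state rather than by state.step, restoring only the parameters (or only the step) leaves the schedule out of sync. The sketch below is a minimal illustration of round-tripping the entire train state with flax.serialization; it assumes a standard flax.training.train_state.TrainState with an optax optimizer, and the helper names, toy model, and path are placeholders rather than this repository's code.

# Sketch only (not this repo's restore_state): serialize and restore the whole
# TrainState, including opt_state, so schedule counters survive a resume.
import jax.numpy as jnp
import optax
from flax import serialization
from flax.training import train_state


def save_full_state(state, path):
    # Serializes step, params *and* opt_state (apply_fn/tx are not data fields).
    with open(path, "wb") as f:
        f.write(serialization.to_bytes(state))


def restore_full_state(template, path):
    # `from_bytes` needs a template object with the same structure as what was saved.
    with open(path, "rb") as f:
        return serialization.from_bytes(template, f.read())


if __name__ == "__main__":
    # Toy params and a step-dependent learning-rate schedule (illustrative only).
    params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}
    schedule = optax.linear_schedule(init_value=1e-3, end_value=0.0, transition_steps=1000)
    tx = optax.adamw(learning_rate=schedule)
    state = train_state.TrainState.create(
        apply_fn=lambda p, x: x @ p["w"] + p["b"], params=params, tx=tx
    )

    save_full_state(state, "/tmp/full_state.msgpack")
    restored = restore_full_state(state, "/tmp/full_state.msgpack")
    # The optimizer's internal counters come back along with params and step.
    assert int(restored.step) == int(state.step)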
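The second hunk re-enables epoch logging at the top of each epoch, keyed to the current (unreplicated) optimizer step. The script's wandb_log helper is not part of this diff; the sketch below is a hypothetical minimal version, assuming it wraps wandb.log with an explicit step and only logs from the first JAX process.

# Hypothetical minimal wandb_log wrapper (the script's actual helper may differ).
import jax
import wandb


def wandb_log(metrics, step=None, prefix=None):
    if jax.process_index() != 0:
        return  # avoid duplicate logging in multi-host training
    if prefix is not None:
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
    # W&B expects `step` to be non-decreasing across calls, which is why the
    # training loop passes the current training step explicitly.
    wandb.log(metrics, step=step)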