boris committed on
Commit
bc4734f
1 Parent(s): 9f5e879

fix(train): consider schedule offset

Browse files
Files changed (1) hide show
  1. tools/train/train.py +2 -1
tools/train/train.py CHANGED
@@ -688,7 +688,8 @@ def main():
688
  staircase=training_args.lr_staircase,
689
  )
690
  schedule_fn = optax.join_schedules(
691
- schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
 
692
  )
693
  return schedule_fn
694
 
 
688
  staircase=training_args.lr_staircase,
689
  )
690
  schedule_fn = optax.join_schedules(
691
+ schedules=[warmup_fn, decay_fn],
692
+ boundaries=[model_metadata.get("step", 0) + training_args.warmup_steps],
693
  )
694
  return schedule_fn
695