boris commited on
Commit
7143593
1 Parent(s): e501f71

fix: weight decay Adam + speed logging

Browse files
Files changed (1) hide show
  1. tools/train/train.py +9 -5
tools/train/train.py CHANGED
@@ -353,10 +353,12 @@ class MetricsLogger:
353
  # timing metrics
354
  new_step = int(unreplicate(state.step))
355
  new_time = time.perf_counter()
356
- time_per_step = (new_time - self.time) / (new_step - self.step)
357
- self.step = new_step
358
- self.time = new_time
359
- return {**metrics, **state_dict, "time_per_step": time_per_step}
 
 
360
 
361
  @staticmethod
362
  def log(metrics, step=None, prefix=None):
@@ -599,7 +601,9 @@ def main():
599
  b1=training_args.adam_beta1,
600
  b2=training_args.adam_beta2,
601
  eps=training_args.adam_epsilon,
602
- weight_decay=training_args.weight_decay,
 
 
603
  mask=decay_mask_fn,
604
  )
605
 
 
353
  # timing metrics
354
  new_step = int(unreplicate(state.step))
355
  new_time = time.perf_counter()
356
+ if new_step > self.step:
357
+ time_per_step = (new_time - self.time) / (new_step - self.step)
358
+ self.step = new_step
359
+ self.time = new_time
360
+ state_dict["time_per_step"] = time_per_step
361
+ return {**metrics, **state_dict}
362
 
363
  @staticmethod
364
  def log(metrics, step=None, prefix=None):
 
601
  b1=training_args.adam_beta1,
602
  b2=training_args.adam_beta2,
603
  eps=training_args.adam_epsilon,
604
+ weight_decay=training_args.weight_decay
605
+ if training_args.weight_decay is not None
606
+ else 0.0,
607
  mask=decay_mask_fn,
608
  )
609