boris commited on
Commit
df01fa8
1 Parent(s): 4fa53a5

fix: style

Browse files
Files changed (1) hide show
  1. tools/train/train.py +6 -2
tools/train/train.py CHANGED
@@ -647,7 +647,9 @@ def main():
647
 
648
  # add gradient accumulation
649
  if training_args.gradient_accumulation_steps > 1:
650
- optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
 
 
651
 
652
  # Setup train state
653
  state = TrainState.create(
@@ -691,7 +693,9 @@ def main():
691
 
692
  metrics = {
693
  "loss": loss,
694
- "learning_rate": learning_rate_fn(state.step // training_args.gradient_accumulation_steps),
 
 
695
  }
696
  metrics = jax.lax.pmean(metrics, axis_name="batch")
697
 
 
647
 
648
  # add gradient accumulation
649
  if training_args.gradient_accumulation_steps > 1:
650
+ optimizer = optax.MultiSteps(
651
+ optimizer, training_args.gradient_accumulation_steps
652
+ )
653
 
654
  # Setup train state
655
  state = TrainState.create(
 
693
 
694
  metrics = {
695
  "loss": loss,
696
+ "learning_rate": learning_rate_fn(
697
+ state.step // training_args.gradient_accumulation_steps
698
+ ),
699
  }
700
  metrics = jax.lax.pmean(metrics, axis_name="batch")
701