Spaces:
Running
Running
fix: style
Browse files- tools/train/train.py +6 -2
tools/train/train.py
CHANGED
@@ -647,7 +647,9 @@ def main():
|
|
647 |
|
648 |
# add gradient accumulation
|
649 |
if training_args.gradient_accumulation_steps > 1:
|
650 |
-
optimizer = optax.MultiSteps(
|
|
|
|
|
651 |
|
652 |
# Setup train state
|
653 |
state = TrainState.create(
|
@@ -691,7 +693,9 @@ def main():
|
|
691 |
|
692 |
metrics = {
|
693 |
"loss": loss,
|
694 |
-
"learning_rate": learning_rate_fn(
|
|
|
|
|
695 |
}
|
696 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
697 |
|
|
|
647 |
|
648 |
# add gradient accumulation
|
649 |
if training_args.gradient_accumulation_steps > 1:
|
650 |
+
optimizer = optax.MultiSteps(
|
651 |
+
optimizer, training_args.gradient_accumulation_steps
|
652 |
+
)
|
653 |
|
654 |
# Setup train state
|
655 |
state = TrainState.create(
|
|
|
693 |
|
694 |
metrics = {
|
695 |
"loss": loss,
|
696 |
+
"learning_rate": learning_rate_fn(
|
697 |
+
state.step // training_args.gradient_accumulation_steps
|
698 |
+
),
|
699 |
}
|
700 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
701 |
|