Spaces:
Running
Running
fix: weight decay Adam + speed logging
Browse files
- tools/train/train.py +9 -5
tools/train/train.py
CHANGED
@@ -353,10 +353,12 @@ class MetricsLogger:
|
|
353 |
# timing metrics
|
354 |
new_step = int(unreplicate(state.step))
|
355 |
new_time = time.perf_counter()
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
|
|
|
|
360 |
|
361 |
@staticmethod
|
362 |
def log(metrics, step=None, prefix=None):
|
@@ -599,7 +601,9 @@ def main():
|
|
599 |
b1=training_args.adam_beta1,
|
600 |
b2=training_args.adam_beta2,
|
601 |
eps=training_args.adam_epsilon,
|
602 |
-
weight_decay=training_args.weight_decay
|
|
|
|
|
603 |
mask=decay_mask_fn,
|
604 |
)
|
605 |
|
|
|
353 |
# timing metrics
|
354 |
new_step = int(unreplicate(state.step))
|
355 |
new_time = time.perf_counter()
|
356 |
+
if new_step > self.step:
|
357 |
+
time_per_step = (new_time - self.time) / (new_step - self.step)
|
358 |
+
self.step = new_step
|
359 |
+
self.time = new_time
|
360 |
+
state_dict["time_per_step"] = time_per_step
|
361 |
+
return {**metrics, **state_dict}
|
362 |
|
363 |
@staticmethod
|
364 |
def log(metrics, step=None, prefix=None):
|
|
|
601 |
b1=training_args.adam_beta1,
|
602 |
b2=training_args.adam_beta2,
|
603 |
eps=training_args.adam_epsilon,
|
604 |
+
weight_decay=training_args.weight_decay
|
605 |
+
if training_args.weight_decay is not None
|
606 |
+
else 0.0,
|
607 |
mask=decay_mask_fn,
|
608 |
)
|
609 |
|