feat(train): cleanup args
tools/train/train.py  +21 -17
CHANGED
@@ -199,8 +199,11 @@ class TrainingArguments:
     per_device_train_batch_size: int = field(
         default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
-    per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
+    per_device_eval_batch_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Batch size per GPU/TPU/CPU for evaluation. Same as training batch size if not set."
+        },
     )
 
     gradient_accumulation_steps: int = field(
@@ -252,6 +255,13 @@ class TrainingArguments:
         },
     )
 
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
+    )
+
+    warmup_steps: int = field(
+        default=0, metadata={"help": "Linear warmup over warmup_steps."}
+    )
     lr_decay: str = field(
         default=None,
         metadata={
@@ -277,13 +287,6 @@ class TrainingArguments:
         },
     )
 
-    num_train_epochs: int = field(
-        default=3, metadata={"help": "Total number of training epochs to perform."}
-    )
-    warmup_steps: int = field(
-        default=0, metadata={"help": "Linear warmup over warmup_steps."}
-    )
-
     logging_steps: int = field(
         default=40, metadata={"help": "Log every X updates steps."}
     )
@@ -334,6 +337,11 @@ class TrainingArguments:
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if self.per_device_eval_batch_size is None:
+            self.per_device_eval_batch_size = self.per_device_train_batch_size
+        if self.weight_decay is None:
+            if self.optim in ["distributed_shampoo", "adam"]:
+                self.weight_decay = 0.0
         if (
             os.path.exists(self.output_dir)
             and os.listdir(self.output_dir)
@@ -623,9 +631,7 @@ def main():
             beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=training_args.weight_decay
-            if training_args.weight_decay is not None
-            else 0.0,
+            weight_decay=training_args.weight_decay,
             start_preconditioning_step=training_args.warmup_steps,
             preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
@@ -648,9 +654,7 @@ def main():
             b1=training_args.beta1,
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
-            weight_decay=training_args.weight_decay
-            if training_args.weight_decay is not None
-            else 0.0,
+            weight_decay=training_args.weight_decay,
             mask=decay_mask_fn,
         )
     elif training_args.optim == "adafactor":
@@ -749,8 +753,8 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,
-    p_eval_step = jax.pmap(eval_step, "batch"
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+    p_eval_step = jax.pmap(eval_step, "batch")
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len_train_dataset}")
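For reference, the restored jax.pmap calls follow the usual data-parallel pattern. The toy below (its train_step/eval_step are stand-ins, not the functions from train.py) illustrates the two arguments being passed: the "batch" axis name enables cross-device reductions such as jax.lax.pmean on the metrics, and donate_argnums=(0,) lets XLA reuse the buffers of the old state for the updated state, which matters when the state holds large parameters and optimizer statistics.

import jax
import jax.numpy as jnp


def train_step(state, batch):
    # toy update: nudge the per-device "params" towards the batch mean
    new_state = state + 0.1 * (batch.mean() - state)
    # loss is averaged across devices thanks to the "batch" axis name
    loss = jax.lax.pmean(jnp.mean((batch - state) ** 2), axis_name="batch")
    return new_state, {"loss": loss}


def eval_step(state, batch):
    loss = jax.lax.pmean(jnp.mean((batch - state) ** 2), axis_name="batch")
    return {"eval_loss": loss}


p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch")

n_dev = jax.local_device_count()
state = jnp.zeros((n_dev, 1))   # leading axis = one slice per device
batch = jnp.ones((n_dev, 4))
state, train_metrics = p_train_step(state, batch)
eval_metrics = p_eval_step(state, batch)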