dat
committed on
Commit
•
afb3179
1
Parent(s):
bc11ccf
with grad_accum
Browse files- run_mlm_flax.py +34 -20
run_mlm_flax.py
CHANGED
@@ -288,7 +288,8 @@ def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
|
|
288 |
|
289 |
|
290 |
|
291 |
-
|
|
|
292 |
|
293 |
|
294 |
|
@@ -396,10 +397,10 @@ if __name__ == "__main__":
|
|
396 |
return train, val
|
397 |
train, val = train_val_files()
|
398 |
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
399 |
-
|
400 |
-
|
401 |
-
datasets["train"] = datasets["train"].select(range(10000))
|
402 |
-
datasets["validation"] = datasets["validation"].select(range(10000))
|
403 |
|
404 |
|
405 |
|
@@ -566,7 +567,7 @@ if __name__ == "__main__":
|
|
566 |
|
567 |
# Store some constant
|
568 |
num_epochs = int(training_args.num_train_epochs)
|
569 |
-
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
570 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
571 |
|
572 |
num_train_steps = len(train_dataset) // train_batch_size * num_epochs
|
@@ -613,14 +614,14 @@ if __name__ == "__main__":
|
|
613 |
mask=decay_mask_fn,
|
614 |
)
|
615 |
|
616 |
-
if training_args.gradient_accumulation_steps > 1:
|
617 |
-
|
618 |
-
grad_accum_steps = training_args.gradient_accumulation_steps
|
619 |
|
620 |
# Setup train state
|
621 |
|
622 |
|
623 |
-
state =
|
624 |
|
625 |
if training_args.resume_from_checkpoint:
|
626 |
state = restore_checkpoint(training_args.resume_from_checkpoint, state)
|
@@ -645,17 +646,30 @@ if __name__ == "__main__":
|
|
645 |
# take average
|
646 |
loss = loss.sum() / label_mask.sum()
|
647 |
|
648 |
-
return loss
|
649 |
|
650 |
grad_fn = jax.value_and_grad(loss_fn)
|
651 |
-
loss,
|
652 |
-
|
653 |
-
|
654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
metrics = jax.lax.pmean(
|
656 |
-
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step
|
657 |
)
|
658 |
|
|
|
659 |
return new_state, metrics, new_dropout_rng
|
660 |
|
661 |
# Create parallel version of the train step
|
@@ -699,10 +713,10 @@ if __name__ == "__main__":
|
|
699 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
700 |
num_train_samples = len(train_dataset)
|
701 |
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
702 |
-
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size
|
703 |
|
704 |
# Gather the indexes for creating the batch and do a training step
|
705 |
-
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step
|
706 |
samples = [train_dataset[int(idx)] for idx in batch_idx]
|
707 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
708 |
|
@@ -716,7 +730,7 @@ if __name__ == "__main__":
|
|
716 |
if cur_step < resume_step:
|
717 |
continue
|
718 |
|
719 |
-
if (cur_step % training_args.logging_steps
|
720 |
# Save metrics
|
721 |
train_metric = jax_utils.unreplicate(train_metric)
|
722 |
train_time += time.time() - train_start
|
@@ -733,7 +747,7 @@ if __name__ == "__main__":
|
|
733 |
|
734 |
train_metrics = []
|
735 |
|
736 |
-
if cur_step % (training_args.eval_steps
|
737 |
# ======================== Evaluating ==============================
|
738 |
num_eval_samples = len(eval_dataset)
|
739 |
eval_samples_idx = jnp.arange(num_eval_samples)
|
288 |
|
289 |
|
290 |
|
291 |
+
class TrainState(train_state.TrainState):
|
292 |
+
grad_accum: jnp.ndarray
|
293 |
|
294 |
|
295 |
|
397 |
return train, val
|
398 |
train, val = train_val_files()
|
399 |
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
400 |
+
datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
|
401 |
+
datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
|
402 |
+
#datasets["train"] = datasets["train"].select(range(10000))
|
403 |
+
#datasets["validation"] = datasets["validation"].select(range(10000))
|
404 |
|
405 |
|
406 |
|
567 |
|
568 |
# Store some constant
|
569 |
num_epochs = int(training_args.num_train_epochs)
|
570 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
571 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
572 |
|
573 |
num_train_steps = len(train_dataset) // train_batch_size * num_epochs
|
614 |
mask=decay_mask_fn,
|
615 |
)
|
616 |
|
617 |
+
#if training_args.gradient_accumulation_steps > 1:
|
618 |
+
# optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
|
619 |
+
#grad_accum_steps = training_args.gradient_accumulation_steps
|
620 |
|
621 |
# Setup train state
|
622 |
|
623 |
|
624 |
+
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,grad_accum=jax.tree_map(jnp.zeros_like, model.params))
|
625 |
|
626 |
if training_args.resume_from_checkpoint:
|
627 |
state = restore_checkpoint(training_args.resume_from_checkpoint, state)
|
646 |
# take average
|
647 |
loss = loss.sum() / label_mask.sum()
|
648 |
|
649 |
+
return loss / training_args.gradient_accumulation_steps
|
650 |
|
651 |
grad_fn = jax.value_and_grad(loss_fn)
|
652 |
+
loss, grads = grad_fn(state.params)
|
653 |
+
grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
|
654 |
+
|
655 |
+
def update_fn():
|
656 |
+
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
657 |
+
grads = jax.lax.pmean(grad_accum, "batch")
|
658 |
+
new_state = state.apply_gradients(grads=grads,grad_accum=jax.tree_map(jnp.zeros_like, grads))
|
659 |
+
return new_state
|
660 |
+
|
661 |
+
new_state = jax.lax.cond(
|
662 |
+
state.step % training_args.gradient_accumulation_steps == 0,
|
663 |
+
lambda _: update_fn(),
|
664 |
+
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
665 |
+
None,
|
666 |
+
)
|
667 |
+
|
668 |
metrics = jax.lax.pmean(
|
669 |
+
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch" #
|
670 |
)
|
671 |
|
672 |
+
#return new_state.replace(new_dropout_rng=new_dropout_rng), metrics
|
673 |
return new_state, metrics, new_dropout_rng
|
674 |
|
675 |
# Create parallel version of the train step
|
713 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
714 |
num_train_samples = len(train_dataset)
|
715 |
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
716 |
+
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) #// grad_accum_steps
|
717 |
|
718 |
# Gather the indexes for creating the batch and do a training step
|
719 |
+
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)): #grad_accum
|
720 |
samples = [train_dataset[int(idx)] for idx in batch_idx]
|
721 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
722 |
|
730 |
if cur_step < resume_step:
|
731 |
continue
|
732 |
|
733 |
+
if (cur_step % training_args.logging_steps) == 0 and cur_step > 0: # * grad_accum_steps
|
734 |
# Save metrics
|
735 |
train_metric = jax_utils.unreplicate(train_metric)
|
736 |
train_time += time.time() - train_start
|
747 |
|
748 |
train_metrics = []
|
749 |
|
750 |
+
if cur_step % (training_args.eval_steps) == 0 and cur_step > 0: #* grad_accum_steps
|
751 |
# ======================== Evaluating ==============================
|
752 |
num_eval_samples = len(eval_dataset)
|
753 |
eval_samples_idx = jnp.arange(num_eval_samples)
|