Fixing restore checkpoint step
- mc4/mc4.py +10 -7
- run_mlm_flax_stream.py +63 -6
mc4/mc4.py
CHANGED
@@ -1,11 +1,11 @@
-"""mC4 dataset based on Common Crawl."""
+"""Perplexity Sampled mC4 dataset based on Common Crawl."""
 
 
 import gzip
 import json
 
 import datasets
-import kenlm
+import kenlm  # pip install https://github.com/kpu/kenlm/archive/master.zip
 import numpy as np
 from numpy.random import default_rng
 
@@ -289,6 +289,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
         self.sampling_factor = kwargs.pop("sampling_factor", None)
         self.boundaries = kwargs.pop("boundaries", None)
         self.seed = kwargs.pop("seed", None)
+        self.kwargs = kwargs
         if self.sampling_method:
             if self.seed is not None:
                 self.rng = default_rng(self.seed)
@@ -316,7 +317,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             doc_length += length
         return 10.0 ** (-doc_log_score / doc_length)
 
-    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
+    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None, **kwargs):
         perplexity = self.get_perplexity(doc)
         if boundaries is None:
             boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
@@ -331,17 +332,18 @@ class Mc4(datasets.GeneratorBasedBuilder):
         probability = factor / quartile_range
         return self.rng.uniform() < probability
 
-    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
+    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None, **kwargs):
+        width = kwargs.get("width", 9 / 2)  # width (spread) of the exponential curve
         perplexity = self.get_perplexity(doc)
         if boundaries is not None:
             m = boundaries[1]
         else:
             m = 662247.50212365
-        exponential = np.exp(-9 / 2 * ((perplexity - m) / m) ** 2)
+        exponential = np.exp((-1 / width) * ((perplexity - m) / m) ** 2)
         weighted_perplexity = factor * exponential
         return self.rng.uniform() < weighted_perplexity
 
-    def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
+    def _should_keep_doc_random(self, doc, factor=None, boundaries=None, **kwargs):
         if factor is None:
             factor = 0.5
         return self.rng.uniform() <= factor
@@ -415,7 +417,8 @@ class Mc4(datasets.GeneratorBasedBuilder):
                     if self.should_keep_doc(
                         example["text"],
                         factor=self.sampling_factor,
-                        boundaries=self.boundaries):
+                        boundaries=self.boundaries,
+                        **self.kwargs):
                         yield id_, example
                         id_ += 1
             else:
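Note (not part of the commit): with the defaults above, _should_keep_doc_gaussian keeps a document with probability factor * exp(-((perplexity - m) / m)^2 / width), so documents whose perplexity sits near the middle boundary m are kept most often and the keep rate decays smoothly on both sides; the new self.kwargs = kwargs line is what forwards an optional width override from the builder call down into should_keep_doc. A minimal standalone sketch of that weighting, using made-up perplexity values and the defaults shown in the diff:

    import numpy as np
    from numpy.random import default_rng

    def gaussian_keep_probability(perplexity, factor=0.78, m=662247.50212365, width=9 / 2):
        # Same weighting as _should_keep_doc_gaussian: maximal near m, decaying with relative distance.
        return factor * np.exp((-1 / width) * ((perplexity - m) / m) ** 2)

    rng = default_rng(42)
    for ppl in (350_000.0, 662_247.5, 1_300_000.0):  # hypothetical document perplexities
        p = gaussian_keep_probability(ppl)
        print(f"perplexity={ppl:>12,.1f}  keep_prob={p:.3f}  kept={rng.uniform() < p}")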
run_mlm_flax_stream.py
CHANGED
@@ -348,6 +348,24 @@ def save_checkpoint_files(state, data_collator, training_args, save_dir):
         json.dump({"step": unreplicated_state.step.item()}, f)
 
 
+def restore_checkpoint(save_dir, state):
+    logger.info(f"Restoring checkpoint from {save_dir}")
+    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
+        params = from_bytes(state.params, f.read())
+
+    with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
+        opt_state = from_bytes(state.opt_state, f.read())
+
+    args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
+    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
+
+    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
+        training_state = json.load(f)
+    step = training_state["step"]
+
+    return params, opt_state, step, args, data_collator
+
+
 def rotate_checkpoints(path, max_checkpoints=5):
     paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
     if len(paths) > max_checkpoints:
@@ -484,8 +502,6 @@ if __name__ == "__main__":
     has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
         try:
-            from flax.metrics.tensorboard import SummaryWriter
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
             # Enable Weight&Biases
             import wandb
             wandb.init(
@@ -496,6 +512,8 @@ if __name__ == "__main__":
             wandb.config.update(training_args)
             wandb.config.update(model_args)
             wandb.config.update(data_args)
+            from flax.metrics.tensorboard import SummaryWriter
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
         except ImportError as ie:
             has_tensorboard = False
             logger.warning(
@@ -569,6 +587,42 @@ if __name__ == "__main__":
 
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    saved_step = 0
+    if "checkpoint" in model_args.model_name_or_path:
+        params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
+        # Create learning rate schedule
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0, end_value=args.learning_rate, transition_steps=args.warmup_steps
+        )
+        decay_fn = optax.linear_schedule(
+            init_value=args.learning_rate,
+            end_value=0,
+            transition_steps=data_args.num_train_steps - args.warmup_steps,
+        )
+        linear_decay_lr_schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[args.warmup_steps]
+        )
+        # create adam optimizer
+        adamw = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=args.weight_decay,
+            mask=decay_mask_fn,
+        )
+        state = train_state.TrainState(
+            step=saved_step,
+            apply_fn=model.__call__,
+            params=params,
+            tx=adamw,
+            opt_state=opt_state,
+        )
+        # self.args = args
+        # data_collator = data_collator
+        # scheduler_fn = args.learning_rate
+        model.params = params
+
 
     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
@@ -637,7 +691,10 @@ if __name__ == "__main__":
     eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
 
     steps = tqdm(range(num_train_steps), desc="Training...", position=0)
-    for step in range(num_train_steps):
+    for step in range(saved_step, num_train_steps):
+        if step < saved_step:
+            steps.update(1)
+            continue
         # ======================== Training ================================
         try:
             samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
@@ -700,7 +757,7 @@ if __name__ == "__main__":
 
         # save checkpoint after eval_steps
        if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
-            logger.info(f"Saving checkpoint at {step
+            logger.info(f"Saving checkpoint at {step} steps")
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
             model.save_pretrained(
                 training_args.output_dir,
@@ -709,9 +766,9 @@ if __name__ == "__main__":
                 commit_message=f"Saving weights and logs of step {step + 1}",
             )
             save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
-            checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step
+            checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
             checkpoints_dir.mkdir(parents=True, exist_ok=True)
-            model.save_pretrained(checkpoints_dir, params=params
+            model.save_pretrained(checkpoints_dir, params=params)
             save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
             rotate_checkpoints(
                 Path(training_args.output_dir) / "checkpoints",
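Note (not part of the commit): the restore branch only runs when the string "checkpoint" appears in --model_name_or_path, and restore_checkpoint deserializes flax_model.msgpack / optimizer_state.msgpack with flax.serialization.from_bytes, which needs a target pytree of matching structure; that is why the freshly created state is passed in even though its values are then replaced. A self-contained sketch of that round trip on a toy pytree (the arrays and names here are made up, not taken from the script):

    import numpy as np
    from flax.serialization import to_bytes, from_bytes

    # Toy parameter pytree standing in for state.params.
    params = {"dense": {"kernel": np.ones((2, 2), dtype=np.float32), "bias": np.zeros(2, dtype=np.float32)}}

    blob = to_bytes(params)  # what save-time code writes into a .msgpack file

    # Restoring requires a template pytree with the same structure; its values are overwritten.
    template = {"dense": {"kernel": np.zeros((2, 2), dtype=np.float32), "bias": np.ones(2, dtype=np.float32)}}
    restored = from_bytes(template, blob)

    assert np.array_equal(restored["dense"]["kernel"], params["dense"]["kernel"])
    # In the script, the analogous call is:
    #   params, opt_state, saved_step, args, data_collator = restore_checkpoint(save_dir, state)
    # and training then resumes with: for step in range(saved_step, num_train_steps): ...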