versae committed
Commit 6a189a4 (1 parent: 4cad39c)

Fixing restore checkpoint step

Files changed (2):
  1. mc4/mc4.py +10 -7
  2. run_mlm_flax_stream.py +63 -6
mc4/mc4.py CHANGED
@@ -1,11 +1,11 @@
-"""mC4 dataset based on Common Crawl."""
+"""Perplexity Sampled mC4 dataset based on Common Crawl."""
 
 
 import gzip
 import json
 
 import datasets
-import kenlm
+import kenlm  # pip install https://github.com/kpu/kenlm/archive/master.zip
 import numpy as np
 from numpy.random import default_rng
 
@@ -289,6 +289,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
         self.sampling_factor = kwargs.pop("sampling_factor", None)
         self.boundaries = kwargs.pop("boundaries", None)
         self.seed = kwargs.pop("seed", None)
+        self.kwargs = kwargs
         if self.sampling_method:
             if self.seed is not None:
                 self.rng = default_rng(self.seed)
@@ -316,7 +317,7 @@ class Mc4(datasets.GeneratorBasedBuilder):
             doc_length += length
         return 10.0 ** (-doc_log_score / doc_length)
 
-    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None):
+    def _should_keep_doc_step(self, doc, factor=1.5e5, boundaries=None, **kwargs):
         perplexity = self.get_perplexity(doc)
         if boundaries is None:
             boundaries = [536394.99320948, 662247.50212365, 919250.87225178]
@@ -331,17 +332,18 @@ class Mc4(datasets.GeneratorBasedBuilder):
             probability = factor / quartile_range
         return self.rng.uniform() < probability
 
-    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None):
+    def _should_keep_doc_gaussian(self, doc, factor=0.78, boundaries=None, **kwargs):
+        width = kwargs.get("width", 9 / 2)  # width (spread) of the exponential curve
         perplexity = self.get_perplexity(doc)
         if boundaries is not None:
             m = boundaries[1]
         else:
             m = 662247.50212365
-        exponential = np.exp(-9/2 * ((perplexity - m) / m) ** 2)
+        exponential = np.exp((-1 / width) * ((perplexity - m) / m) ** 2)
         weighted_perplexity = factor * exponential
         return self.rng.uniform() < weighted_perplexity
 
-    def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
+    def _should_keep_doc_random(self, doc, factor=None, boundaries=None, **kwargs):
         if factor is None:
             factor = 0.5
         return self.rng.uniform() <= factor
@@ -415,7 +417,8 @@ class Mc4(datasets.GeneratorBasedBuilder):
                    if self.should_keep_doc(
                        example["text"],
                        factor=self.sampling_factor,
-                        boundaries=self.boundaries):
+                        boundaries=self.boundaries,
+                        **self.kwargs):
                        yield id_, example
                        id_ += 1
            else:
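
For reference, the Gaussian branch now reads the spread of the exponential from a `width` keyword (defaulting to the previous hard-coded 9/2) that the builder forwards from its leftover kwargs. Below is a minimal standalone sketch of the weighting it applies; the function name and the sample perplexity values are illustrative and not part of the dataset script, while the formula and default constants mirror the diff above.

import numpy as np
from numpy.random import default_rng

rng = default_rng(42)

def gaussian_keep_probability(perplexity, factor=0.78, m=662247.50212365, width=9 / 2):
    # Mirrors _should_keep_doc_gaussian: the weight peaks at `factor` when the
    # document perplexity sits at the midpoint m and decays as it moves away.
    exponential = np.exp((-1 / width) * ((perplexity - m) / m) ** 2)
    return factor * exponential

# A document near the midpoint is kept roughly 78% of the time,
# while one with much higher perplexity is almost always dropped.
for ppl in (650_000, 5_000_000):
    p = gaussian_keep_probability(ppl)
    print(f"perplexity={ppl:>9}: keep probability={p:.5f}, kept={rng.uniform() < p}")
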
run_mlm_flax_stream.py CHANGED
@@ -348,6 +348,24 @@ def save_checkpoint_files(state, data_collator, training_args, save_dir):
         json.dump({"step": unreplicated_state.step.item()}, f)
 
 
+def restore_checkpoint(save_dir, state):
+    logger.info(f"Restoring checkpoint from {save_dir}")
+    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
+        params = from_bytes(state.params, f.read())
+
+    with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
+        opt_state = from_bytes(state.opt_state, f.read())
+
+    args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
+    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
+
+    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
+        training_state = json.load(f)
+        step = training_state["step"]
+
+    return params, opt_state, step, args, data_collator
+
+
 def rotate_checkpoints(path, max_checkpoints=5):
     paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
     if len(paths) > max_checkpoints:
@@ -484,8 +502,6 @@ if __name__ == "__main__":
     has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
        try:
-            from flax.metrics.tensorboard import SummaryWriter
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
            # Enable Weight&Biases
            import wandb
            wandb.init(
@@ -496,6 +512,8 @@ if __name__ == "__main__":
            wandb.config.update(training_args)
            wandb.config.update(model_args)
            wandb.config.update(data_args)
+            from flax.metrics.tensorboard import SummaryWriter
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
@@ -569,6 +587,42 @@ if __name__ == "__main__":
 
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    saved_step = 0
+    if "checkpoint" in model_args.model_name_or_path:
+        params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
+        # Create learning rate schedule
+        warmup_fn = optax.linear_schedule(
+            init_value=0.0, end_value=args.learning_rate, transition_steps=args.warmup_steps
+        )
+        decay_fn = optax.linear_schedule(
+            init_value=args.learning_rate,
+            end_value=0,
+            transition_steps=data_args.num_train_steps - args.warmup_steps,
+        )
+        linear_decay_lr_schedule_fn = optax.join_schedules(
+            schedules=[warmup_fn, decay_fn], boundaries=[args.warmup_steps]
+        )
+        # create adam optimizer
+        adamw = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=args.weight_decay,
+            mask=decay_mask_fn,
+        )
+        state = train_state.TrainState(
+            step=saved_step,
+            apply_fn=model.__call__,
+            params=params,
+            tx=adamw,
+            opt_state=opt_state,
+        )
+        # self.args = args
+        # data_collator = data_collator
+        # scheduler_fn = args.learning_rate
+        model.params = params
+
 
     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
@@ -637,7 +691,10 @@ if __name__ == "__main__":
     eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
 
     steps = tqdm(range(num_train_steps), desc="Training...", position=0)
-    for step in range(num_train_steps):
+    for step in range(saved_step, num_train_steps):
+        if step < saved_step:
+            steps.update(1)
+            continue
        # ======================== Training ================================
        try:
            samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
@@ -700,7 +757,7 @@ if __name__ == "__main__":
 
        # save checkpoint after eval_steps
        if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
-            logger.info(f"Saving checkpoint at {step + 1} steps")
+            logger.info(f"Saving checkpoint at {step} steps")
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(
                training_args.output_dir,
@@ -709,9 +766,9 @@ if __name__ == "__main__":
                commit_message=f"Saving weights and logs of step {step + 1}",
            )
            save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
-            checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}"
+            checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
            checkpoints_dir.mkdir(parents=True, exist_ok=True)
-            model.save_pretrained(checkpoints_dir, params=params,)
+            model.save_pretrained(checkpoints_dir, params=params)
            save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
            rotate_checkpoints(
                Path(training_args.output_dir) / "checkpoints",
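
Taken together, resuming now depends on the checkpoint directory containing the files written by save_checkpoint_files and read back by restore_checkpoint, and on the path passed as --model_name_or_path containing the substring "checkpoint". A small sketch of a sanity check one could run before resuming; the helper function and the example path are hypothetical and not part of the training script, while the file names come from the diff above.

import os

# Files restore_checkpoint() expects in the directory passed as --model_name_or_path.
REQUIRED_FILES = (
    "flax_model.msgpack",       # model parameters, deserialized with from_bytes
    "optimizer_state.msgpack",  # optax optimizer state
    "training_args.joblib",     # pickled training arguments
    "data_collator.joblib",     # pickled data collator
    "training_state.json",      # {"step": <last saved step>}
)

def is_resumable_checkpoint(save_dir):
    # Hypothetical helper: report which required checkpoint files are missing.
    missing = [name for name in REQUIRED_FILES
               if not os.path.exists(os.path.join(save_dir, name))]
    return not missing, missing

ok, missing = is_resumable_checkpoint("outputs/checkpoints/checkpoint-50000")
print("resumable:", ok, "missing:", missing)
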