boris commited on
Commit
699e1d9
2 Parent(s): a11eff5 09362db

Merge pull request #32 from borisdayma/feat-model

Browse files

feat: save and restore checkpoints
Former-commit-id: 6254697762481523764fcb4c8856e63203d2f117

Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +55 -12
seq2seq/run_seq2seq_flax.py CHANGED
@@ -271,6 +271,10 @@ class TrainState(train_state.TrainState):
271
 
272
  class CustomFlaxBartModule(FlaxBartModule):
273
  def setup(self):
 
 
 
 
274
  # we keep shared to easily load pre-trained weights
275
  self.shared = nn.Embed(
276
  self.config.vocab_size,
@@ -280,7 +284,7 @@ class CustomFlaxBartModule(FlaxBartModule):
280
  )
281
  # a separate embedding is used for the decoder
282
  self.decoder_embed = nn.Embed(
283
- OUTPUT_VOCAB_SIZE,
284
  self.config.d_model,
285
  embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
286
  dtype=self.dtype,
@@ -289,20 +293,23 @@ class CustomFlaxBartModule(FlaxBartModule):
289
 
290
  # the decoder has a different config
291
  decoder_config = BartConfig(self.config.to_dict())
292
- decoder_config.max_position_embeddings = OUTPUT_LENGTH
293
- decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
294
  self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
295
 
296
  class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
297
  def setup(self):
 
 
 
298
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
299
  self.lm_head = nn.Dense(
300
- OUTPUT_VOCAB_SIZE,
301
  use_bias=False,
302
  dtype=self.dtype,
303
  kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
304
  )
305
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
306
 
307
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
308
  module_class = CustomFlaxBartForConditionalGenerationModule
@@ -429,11 +436,24 @@ def main():
429
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
430
  # https://huggingface.co/docs/datasets/loading_datasets.html.
431
 
432
- # Load pretrained model and tokenizer
433
- tokenizer = AutoTokenizer.from_pretrained(
434
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
435
- )
436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
437
  if model_args.from_checkpoint is not None:
438
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
439
  artifact_dir = artifact.download()
@@ -448,6 +468,12 @@ def main():
448
  # used in the preprocessing function
449
  config = model.config
450
 
 
 
 
 
 
 
451
  else:
452
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
453
  model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
@@ -473,6 +499,12 @@ def main():
473
  model.params['model']['shared'] = base_model.params['model']['shared']
474
  del base_model
475
 
 
 
 
 
 
 
476
  print(f"TPUs: {jax.device_count()}")
477
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
478
 
@@ -669,6 +701,9 @@ def main():
669
  grad_accum=jax.tree_map(jnp.zeros_like, model.params),
670
  optimizer_step=0,
671
  )
 
 
 
672
 
673
  # label smoothed cross entropy
674
  def loss_fn(logits, labels):
@@ -811,13 +846,16 @@ def main():
811
  params=params,
812
  )
813
 
 
 
 
814
  # save state
815
  state = unreplicate(state)
816
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
817
  f.write(to_bytes(state.opt_state))
818
  with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
819
  json.dump({'step': state.step.item()}, f)
820
-
821
  # save to W&B
822
  if data_args.log_model:
823
  metadata = {'step': step, 'epoch': epoch}
@@ -826,8 +864,13 @@ def main():
826
  artifact = wandb.Artifact(
827
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
828
  )
829
- artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
830
- artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
 
 
 
 
 
831
  artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
832
  artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
833
  wandb.run.log_artifact(artifact)
 
271
 
272
  class CustomFlaxBartModule(FlaxBartModule):
273
  def setup(self):
274
+ # check config is valid, otherwise set default values
275
+ self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
276
+ self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
277
+
278
  # we keep shared to easily load pre-trained weights
279
  self.shared = nn.Embed(
280
  self.config.vocab_size,
 
284
  )
285
  # a separate embedding is used for the decoder
286
  self.decoder_embed = nn.Embed(
287
+ self.config.vocab_size_output,
288
  self.config.d_model,
289
  embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
290
  dtype=self.dtype,
 
293
 
294
  # the decoder has a different config
295
  decoder_config = BartConfig(self.config.to_dict())
296
+ decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
297
+ decoder_config.vocab_size = self.config.vocab_size_output
298
  self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
299
 
300
  class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
301
  def setup(self):
302
+ # check config is valid, otherwise set default values
303
+ self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
304
+
305
  self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
306
  self.lm_head = nn.Dense(
307
+ self.config.vocab_size_output,
308
  use_bias=False,
309
  dtype=self.dtype,
310
  kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
311
  )
312
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
313
 
314
  class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
315
  module_class = CustomFlaxBartForConditionalGenerationModule
 
436
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
437
  # https://huggingface.co/docs/datasets/loading_datasets.html.
438
 
439
+ # Set up items to load or create
440
+ tokenizer = None
441
+ artifact_dir = None
 
442
 
443
+ def restore_state(state, artifact_dir):
444
+ # restore optimizer state
445
+ if (Path(artifact_dir) / 'opt_state.msgpack').exists():
446
+ with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
447
+ opt_state = from_bytes(state.opt_state, f.read())
448
+
449
+ # restore steps
450
+ if (Path(artifact_dir) / 'training_state.json').exists():
451
+ with (Path(artifact_dir) / 'training_state.json').open('r') as f:
452
+ training_state = json.load(f)
453
+ step = training_state['step']
454
+ optimizer_step = step // training_args.gradient_accumulation_steps
455
+ state.replace(step=step, optimizer_step=optimizer_step)
456
+
457
  if model_args.from_checkpoint is not None:
458
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
459
  artifact_dir = artifact.download()
 
468
  # used in the preprocessing function
469
  config = model.config
470
 
471
+ # load tokenizer if present
472
+ if (Path(artifact_dir) / 'tokenizer_config.json').exists():
473
+ tokenizer = AutoTokenizer.from_pretrained(
474
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
475
+ )
476
+
477
  else:
478
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
479
  model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
499
  model.params['model']['shared'] = base_model.params['model']['shared']
500
  del base_model
501
 
502
+ # Load tokenizer if it has not been set
503
+ if tokenizer is None:
504
+ tokenizer = AutoTokenizer.from_pretrained(
505
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
506
+ )
507
+
508
  print(f"TPUs: {jax.device_count()}")
509
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
510
 
 
701
  grad_accum=jax.tree_map(jnp.zeros_like, model.params),
702
  optimizer_step=0,
703
  )
704
+ if model_args.from_checkpoint is not None:
705
+ # restore optimizer state, step and optimizer_step
706
+ restore_state(state, artifact_dir)
707
 
708
  # label smoothed cross entropy
709
  def loss_fn(logits, labels):
 
846
  params=params,
847
  )
848
 
849
+ # save tokenizer
850
+ tokenizer.save_pretrained(training_args.output_dir)
851
+
852
  # save state
853
  state = unreplicate(state)
854
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
855
  f.write(to_bytes(state.opt_state))
856
  with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
857
  json.dump({'step': state.step.item()}, f)
858
+
859
  # save to W&B
860
  if data_args.log_model:
861
  metadata = {'step': step, 'epoch': epoch}
 
864
  artifact = wandb.Artifact(
865
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
866
  )
867
+ artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
868
+ artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
869
+ artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer.json'))
870
+ artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
871
+ artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
872
+ artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
873
+ artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
874
  artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
875
  artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
876
  wandb.run.log_artifact(artifact)