boris committed on
Commit: 7c4c287
Parents: 605df32, 386f839

feat(train): split artifact into model/state (#128)

Files changed (2):
  1. src/dalle_mini/text.py (+3 -3)
  2. tools/train/train.py (+114 -122)
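
The commit replaces the single `resume_from_checkpoint` artifact with two W&B artifacts per run: a model artifact (`model-<run_id>`, holding weights, config and tokenizer files) and a state artifact (`state-<run_id>`, holding `opt_state.msgpack` and `training_state.json`). Resuming is now driven by `model_name_or_path` pointing at a model artifact plus the new `restore_state` option; the matching state artifact reference is derived by swapping `/model-` for `/state-`. A minimal sketch of that naming convention (the entity, project and run id below are made up for illustration):

# Sketch of the artifact-name convention introduced by this commit; the
# reference below is a hypothetical example, not a real checkpoint.
model_artifact = "my-entity/my-project/model-1a2b3c4d:latest"

# restore_state requires a W&B model artifact reference
assert "/model-" in model_artifact, "Restoring state only available with W&B artifact reference"

# the state artifact is derived exactly as in ModelArguments.__post_init__
state_artifact = model_artifact.replace("/model-", "/state-", 1)
print(state_artifact)  # my-entity/my-project/state-1a2b3c4d:latest
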
src/dalle_mini/text.py CHANGED
@@ -116,7 +116,7 @@ def remove_comma_numbers(t):
 
 
 def pre_process_dot_numbers(t):
-    return re.sub("(\w)\.(\w)", fr"\1{temp_token}dot{temp_token}\2", t)
+    return re.sub("(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
 
 
 def post_process_dot_numbers(t):
@@ -126,7 +126,7 @@ def post_process_dot_numbers(t):
 def pre_process_quotes(t):
     # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
     return re.sub(
-        r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", fr"{temp_token}quote{temp_token}", t
+        r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)", rf"{temp_token}quote{temp_token}", t
     )
 
 
@@ -135,7 +135,7 @@ def post_process_quotes(t):
 
 
 def pre_process_dates(t):
-    return re.sub("(\d)/(\d)", fr"\1{temp_token}slash{temp_token}\2", t)
+    return re.sub("(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
 
 
 def post_process_dates(t):
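
The only change in this file is the raw f-string prefix order, `fr"..."` to `rf"..."`; the two spellings are equivalent, so behaviour is unchanged. A standalone sketch, assuming `temp_token = "xtokx"` (the real constant is defined elsewhere in src/dalle_mini/text.py):

import re

# temp_token value assumed for this sketch; see src/dalle_mini/text.py for the real one
temp_token = "xtokx"

# fr and rf denote the same raw f-string prefix
assert fr"\1{temp_token}dot{temp_token}\2" == rf"\1{temp_token}dot{temp_token}\2"

def pre_process_dot_numbers(t):
    # replace the dot between word characters with a sentinel token
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)

print(pre_process_dot_numbers("version 1.5"))  # -> version 1xtokxdotxtokx5
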
tools/train/train.py CHANGED
@@ -88,6 +88,23 @@ class ModelArguments:
             "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    restore_state: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Restore optimizer and training state associated with a wandb checkpoint."
+        },
+    )
+
+    state_artifact: str = field(init=False)
+
+    def __post_init__(self):
+        if self.restore_state:
+            assert (
+                "/model-" in self.model_name_or_path
+            ), "Restoring state only available with W&B artifact reference"
+            self.state_artifact = self.model_name_or_path.replace(
+                "/model-", "/state-", 1
+            )
 
 
 @dataclass
@@ -319,11 +336,6 @@ class TrainingArguments:
         },
     )
 
-    resume_from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "Reference to a wandb artifact for resuming training."},
-    )
-
     wandb_entity: Optional[str] = field(
         default=None,
         metadata={"help": "The wandb entity to use (for teams)."},
@@ -349,6 +361,8 @@ class TrainingArguments:
         },
     )
 
+    dp_devices: int = field(init=False)
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
@@ -470,62 +484,40 @@ def main():
         config=parser.parse_args(),
     )
 
-    if training_args.resume_from_checkpoint is not None:
-        if jax.process_index() == 0:
-            artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
-        else:
-            artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
-        artifact_dir = artifact.download()
+    # Set up our new model config
+    if model_args.config_name:
+        config = DalleBartConfig.from_pretrained(model_args.config_name)
+    else:
+        config = None
 
-        # load model
+    # Load or create new model
+    if model_args.model_name_or_path:
         model = DalleBart.from_pretrained(
-            artifact_dir,
+            model_args.model_name_or_path,
+            config=config,
+            seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
         )
+    else:
+        model = DalleBart(
+            config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            load_on_cpu=True,
+        )
 
-        # load tokenizer
+    # Load tokenizer
+    if model_args.tokenizer_name is not None:
         tokenizer = DalleBartTokenizer.from_pretrained(
-            artifact_dir,
-            use_fast=True,
+            model_args.tokenizer_name, use_fast=True
         )
-
     else:
-        # Set up our new model config
-        if model_args.config_name:
-            config = DalleBartConfig.from_pretrained(model_args.config_name)
-        else:
-            config = None
-
-        # Load or create new model
-        if model_args.model_name_or_path:
-            model = DalleBart.from_pretrained(
-                model_args.model_name_or_path,
-                config=config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                abstract_init=True,
-                load_on_cpu=True,
-            )
-        else:
-            model = DalleBart(
-                config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                load_on_cpu=True,
-            )
-
-        # Load tokenizer
-        if model_args.tokenizer_name is not None:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.tokenizer_name, use_fast=True
-            )
-        else:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-            )
+        tokenizer = DalleBartTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=True,
+        )
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
@@ -655,30 +647,29 @@ def main():
 
     # get PartitionSpec for optimizer state
    def get_opt_state_spec_and_shape(param_spec):
-        if training_args.optim in ["adam", "adafactor"]:
-            # get opt_state shape without actual init
-            opt_state_shape = jax.eval_shape(optimizer.init, model.params)
-
-            if training_args.optim == "adam":
-
-                def _opt_state_spec_per_leaf(x):
-                    if isinstance(x, FrozenDict):
-                        # variables with same structure as params
-                        return param_spec
-                    else:
-                        # other variables such as count
-                        return None
-
-                opt_state_spec = jax.tree_map(
-                    _opt_state_spec_per_leaf,
-                    opt_state_shape,
-                    # return None spec for empty elements
-                    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
-                )
+        # get opt_state shape without actual init
+        opt_state_shape = jax.eval_shape(optimizer.init, model.params)
+
+        if training_args.optim == "adam":
+
+            def _opt_state_spec_per_leaf(x):
+                if isinstance(x, FrozenDict):
+                    # variables with same structure as params
+                    return param_spec
+                else:
+                    # other variables such as count
+                    return None
+
+            opt_state_spec = jax.tree_map(
+                _opt_state_spec_per_leaf,
+                opt_state_shape,
+                # return None spec for empty elements
+                is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+            )
 
-            elif training_args.optim == "adafactor":
-                # factorized state must be replicated (rank different than params)
-                opt_state_spec = None
+        elif training_args.optim == "adafactor":
+            # factorized state must be replicated (rank different than params)
+            opt_state_spec = None
 
         elif training_args.optim == "distributed_shampoo":
             opt_state_spec = opt_fn.pspec_fn(
@@ -686,7 +677,6 @@ def main():
                 params_partition_spec=param_spec,
                 partition_spec_for_statistics=PartitionSpec(None, "batch", None),
             )
-            opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
         else:
             raise NotImplementedError
         return opt_state_spec, opt_state_shape
@@ -698,7 +688,7 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    # Create state spec
+    # define state spec
     state_spec = TrainState(
         params=param_spec,
         opt_state=opt_state_spec,
@@ -713,7 +703,7 @@ def main():
 
     # create training state
     with maps.mesh(mesh.devices, mesh.axis_names):
-        if training_args.resume_from_checkpoint is None:
+        if not model_args.restore_state:
 
             def init_state(params):
                 return TrainState.create(
@@ -731,6 +721,13 @@ def main():
            )(model.params)
 
         else:
+            # get state files from artifact
+            if jax.process_index() == 0:
+                artifact = wandb.run.use_artifact(model_args.state_artifact)
+            else:
+                artifact = wandb.Api().artifact(model_args.state_artifact)
+            artifact_dir = artifact.download()
+
             # restore opt_state
             with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
                 opt_state = from_bytes(opt_state_shape, f.read())
@@ -760,7 +757,7 @@ def main():
             del opt_state
 
     # free memory
-    del model._params
+    del model._params, opt_state_spec, opt_state_shape
 
     # define batch specs
     keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
@@ -998,51 +995,46 @@ def main():
                 f,
             )
 
-        if jax.process_index() == 0:
-            # save to W&B
-            if training_args.log_model:
-                # save some space
-                c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
-                c.cleanup(wandb.util.from_human_size("10GB"))
-
-                metadata = dict(state_dict)
-                metadata["num_params"] = num_params
-                if eval_metrics is not None:
-                    metadata["eval"] = eval_metrics
-                artifact = wandb.Artifact(
-                    name=f"model-{wandb.run.id}",
-                    type="bart_model",
-                    metadata=metadata,
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "flax_model.msgpack")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "config.json")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "tokenizer.json")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "tokenizer_config.json")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "vocab.json")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "merges.txt")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "special_tokens_map.json")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "opt_state.msgpack")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "training_state.json")
+        # save to W&B
+        if training_args.log_model:
+            # save some space
+            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+            c.cleanup(wandb.util.from_human_size("10GB"))
+
+            metadata = dict(state_dict)
+            metadata["num_params"] = num_params
+            if eval_metrics is not None:
+                metadata["eval"] = eval_metrics
+
+            # create model artifact
+            artifact = wandb.Artifact(
+                name=f"model-{wandb.run.id}",
+                type="DalleBart_model",
+                metadata=metadata,
+            )
+            for filename in [
+                "config.json",
+                "flax_model.msgpack",
+                "merges.txt",
+                "special_tokens_map.json",
+                "tokenizer.json",
+                "tokenizer_config.json",
+                "vocab.json",
+            ]:
+                artifact.add_file(f"{Path(training_args.output_dir) / filename}")
+            wandb.run.log_artifact(artifact)
+
+            # create state artifact
+            artifact_state = wandb.Artifact(
+                name=f"state-{wandb.run.id}",
+                type="DalleBart_state",
+                metadata=metadata,
+            )
+            for filename in ["opt_state.msgpack", "training_state.json"]:
+                artifact_state.add_file(
+                    f"{Path(training_args.output_dir) / filename}"
                 )
-
-                wandb.run.log_artifact(artifact)
+            wandb.run.log_artifact(artifact_state)
 
     # init variables
     last_time = time.perf_counter()
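
For reference, the pattern now used unconditionally in `get_opt_state_spec_and_shape` — computing the optimizer-state shape with `jax.eval_shape` and assigning a partition spec per leaf — can be exercised on its own. This is a minimal sketch with a toy parameter tree, a placeholder spec and `optax.adamw` standing in for the script's configured optimizer; none of these values come from the repo:

import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict

# Toy parameter tree and placeholder spec, for illustration only.
params = FrozenDict({"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros((8,))}})
param_spec = "param-spec-placeholder"
optimizer = optax.adamw(learning_rate=1e-3)

# get opt_state shape without actual init (no device memory is allocated)
opt_state_shape = jax.eval_shape(optimizer.init, params)

def _opt_state_spec_per_leaf(x):
    # leaves mirroring the param tree get the param spec; scalars such as `count` stay unpartitioned
    return param_spec if isinstance(x, FrozenDict) else None

# same call as the script's jax.tree_map, via the stable tree_util path
opt_state_spec = jax.tree_util.tree_map(
    _opt_state_spec_per_leaf,
    opt_state_shape,
    # treat whole FrozenDicts and empty optax states as leaves
    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
)
print(opt_state_spec)
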