Commit b993d27 (parent 2f1e5d9), committed by boris

feat: vmap optimizer (#166)

src/dalle_mini/model/modeling.py CHANGED
@@ -946,15 +946,6 @@ class FlaxBartEncoderLayerCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        # postln is already applied in every layer
-        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
-            hidden_states = norm(
-                self.config.ln_type,
-                dtype=self.dtype,
-                epsilon=1e-05,
-                use_scale=self.config.force_ln_scale,
-            )(hidden_states)
-
         outputs = [
             hidden_states,
             all_hidden_states,
@@ -1034,7 +1025,7 @@ class FlaxBartDecoderLayerCollection(nn.Module):
             self.config,
             dtype=self.dtype,
             add_norm=self.config.ln_positions == "postln",
-            name="FlaxBartEncoderLayers",
+            name="FlaxBartDecoderLayers",
         )(
             hidden_states,
             attention_mask,
@@ -1086,15 +1077,6 @@ class FlaxBartDecoderLayerCollection(nn.Module):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)

-        # postln is already applied in every layer
-        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
-            hidden_states = norm(
-                self.config.ln_type,
-                dtype=self.dtype,
-                epsilon=1e-05,
-                use_scale=self.config.force_ln_scale,
-            )(hidden_states)
-
         outputs = [
             hidden_states,
             all_hidden_states,
@@ -1146,6 +1128,17 @@ class FlaxBartEncoder(nn.Module):
             self.config.ln_type, dtype=self.dtype, epsilon=1e-05
         )

+        # postln is already applied in every layer
+        if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
+            self.final_ln = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )
+        else:
+            self.final_ln = None
+
     def __call__(
         self,
         input_ids,
@@ -1177,11 +1170,16 @@ class FlaxBartEncoder(nn.Module):
             return_dict=return_dict,
         )

+        if self.final_ln is None:
+            final_output = outputs[0]
+        else:
+            final_output = self.final_ln(outputs[0])
+
         if not return_dict:
-            return outputs
+            return (final_output,) + outputs[1:]

         return FlaxBaseModelOutput(
-            last_hidden_state=outputs.last_hidden_state,
+            last_hidden_state=final_output,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
@@ -1223,6 +1221,15 @@ class FlaxBartDecoder(nn.Module):
             self.config.ln_type, dtype=self.dtype, epsilon=1e-05
         )

+        # postln is already applied in every layer
+        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
+            self.final_ln = norm(
+                self.config.ln_type,
+                dtype=self.dtype,
+                epsilon=1e-05,
+                use_scale=self.config.force_ln_scale,
+            )
+
     def __call__(
         self,
         input_ids,
@@ -1260,11 +1267,16 @@ class FlaxBartDecoder(nn.Module):
             return_dict=return_dict,
         )

+        if self.final_ln is None:
+            final_output = outputs[0]
+        else:
+            final_output = self.final_ln(outputs[0])
+
         if not return_dict:
-            return outputs
+            return (final_output,) + outputs[1:]

         return FlaxBaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=outputs.last_hidden_state,
+            last_hidden_state=final_output,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
             cross_attentions=outputs.cross_attentions,
 
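The hunks above move the optional final layer norm out of the (scanned) layer collections and into FlaxBartEncoder / FlaxBartDecoder: the module is created once in setup() and applied to the collection output in __call__. Below is a minimal, self-contained Flax sketch of that pattern only, using a toy module and nn.LayerNorm as a stand-in for the repo's norm() helper; it is not the dalle_mini code.

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyEncoder(nn.Module):
    use_final_ln: bool = True

    def setup(self):
        self.dense = nn.Dense(16)
        # stand-in for norm(self.config.ln_type, ...); None disables the final norm
        self.final_ln = nn.LayerNorm(epsilon=1e-05) if self.use_final_ln else None

    def __call__(self, x):
        hidden_states = self.dense(x)
        if self.final_ln is not None:
            hidden_states = self.final_ln(hidden_states)
        return hidden_states


x = jnp.ones((2, 8))
params = TinyEncoder().init(jax.random.PRNGKey(0), x)
y = TinyEncoder().apply(params, x)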
src/dalle_mini/model/partitions.py CHANGED
@@ -65,7 +65,7 @@ def set_partitions(in_dict, use_scan):
             print(f"Unmatched -> {k}")
     l = list(result.keys())
     if use_scan:
-        # add None dimension to scanned layers
+        # add None dimension to layers
         result = {
             k: (P(*(None,) + v) if v is not None else None)
             if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
 
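For context on the spec transformation above: parameters under FlaxBartEncoderLayers / FlaxBartDecoderLayers are stacked along a leading layer axis by nn.scan, so their PartitionSpec gets a leading None (the layer axis is not sharded). A tiny sketch, assuming the same JAX version this repo uses, where PartitionSpec behaves like a tuple:

from jax.experimental import PartitionSpec as P

v = P("mp", None)               # spec of a single layer's kernel
stacked_spec = P(*(None,) + v)  # scanned kernel gains an unsharded layer axis
assert stacked_spec == P(None, "mp", None)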
tools/train/config/mega/config.json CHANGED
@@ -7,14 +7,14 @@
   "decoder_attention_heads": 32,
   "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers": 26,
+  "decoder_layers": 24,
   "decoder_start_token_id": 16384,
   "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
   "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers": 26,
+  "encoder_layers": 24,
   "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
   "force_ln_scale": false,
 
tools/train/train.py CHANGED
@@ -38,11 +38,10 @@ import optax
 import transformers
 import wandb
 from datasets import Dataset
+from flax import core, struct, traverse_util
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.serialization import from_bytes, to_bytes
-from flax.training import train_state
 from flax.training.common_utils import onehot
-from jax import ShapeDtypeStruct
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.compilation_cache import compilation_cache as cc
 from jax.experimental.pjit import pjit, with_sharding_constraint
@@ -526,60 +525,78 @@ class TrainingArguments:
         self.dp_devices = jax.device_count() // self.mp_devices


-class TrainState(train_state.TrainState):
+def split_params(data):
+    """Split params between scanned and non-scanned"""
+    flat = traverse_util.flatten_dict(unfreeze(data))
+    split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
+    for k, v in flat.items():
+        if "FlaxBartEncoderLayers" in k:
+            split["scanned_encoder"][k] = v
+        elif "FlaxBartDecoderLayers" in k:
+            split["scanned_decoder"][k] = v
+        else:
+            split["standard"][k] = v
+    for k, v in split.items():
+        split[k] = freeze(traverse_util.unflatten_dict(v))
+    return split
+
+
+def unsplit_params(data):
+    flat = {}
+    for k in ["standard", "scanned_encoder", "scanned_decoder"]:
+        flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
+    return freeze(traverse_util.unflatten_dict(flat))
+
+
+class TrainState(struct.PyTreeNode):
+    step: int
+    params: core.FrozenDict[str, Any]
+    opt_state: optax.OptState
+    apply_fn: Callable = struct.field(pytree_node=False)
+    tx: optax.GradientTransformation = struct.field(pytree_node=False)
     dropout_rng: jnp.ndarray = None
     epoch: int = 0
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen

     def apply_gradients(self, *, grads, **kwargs):
-        params = self.unscan(self.params)
-        updates, new_opt_state = self.tx.update(
-            self.unscan(grads), self.opt_state, params
-        )
-        params = optax.apply_updates(params, updates)
+        grads = split_params(grads)
+        params = split_params(self.params)
+        opt_state = {}
+        # we loop over keys: "standard", "scanned_encoder", "scanned_decoder"
+        for k, param in params.items():
+            update_fn = self.tx[k].update
+            if "scanned" in k:
+                update_fn = jax.vmap(update_fn, in_axes=(0, 0, 0), out_axes=(0, 0))
+            updates, new_opt_state = update_fn(grads[k], self.opt_state[k], param)
+            params[k] = optax.apply_updates(param, updates)
+            opt_state[k] = new_opt_state
+        params = unsplit_params(params)
+
         return self.replace(
             step=self.step + 1,
-            params=self.rescan(params),
-            opt_state=new_opt_state,
+            params=params,
+            opt_state=freeze(opt_state),
             **kwargs,
         )

     @classmethod
     def create(cls, *, apply_fn, params, tx, **kwargs):
-        opt_state = tx.init(cls.unscan(params))
+        opt_state = {}
+        for k, p in split_params(params).items():
+            init_fn = tx[k].init
+            if "scanned" in k:
+                init_fn = jax.vmap(init_fn)
+            opt_state[k] = init_fn(p)
         return cls(
             step=0,
             apply_fn=apply_fn,
             params=params,
             tx=tx,
-            opt_state=opt_state,
+            opt_state=freeze(opt_state),
             **kwargs,
         )

-    @staticmethod
-    def unscan(params):
-        params = unfreeze(params)
-        for l in ["encoder", "decoder"]:
-            params["model"][l]["layers"] = jax.tree_map(
-                lambda x: {f"{i}": x[i] for i in range(len(x))},
-                params["model"][l]["layers"],
-            )
-        params = freeze(params)
-        return params
-
-    @staticmethod
-    def rescan(params):
-        params = unfreeze(params)
-        for l in ["encoder", "decoder"]:
-            params["model"][l]["layers"] = jax.tree_map(
-                lambda x: jnp.stack([x[f"{i}"] for i in range(len(x))]),
-                params["model"][l]["layers"],
-                is_leaf=lambda x: "0" in x,
-            )
-        params = freeze(params)
-        return params
-

 def main():
     # See all possible arguments by passing the --help flag to this script.
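The new TrainState above is the heart of this commit: parameters produced by nn.scan keep a leading layer axis, so the optax init/update for those groups is wrapped in jax.vmap, giving every layer its own optimizer state without unstacking anything. A standalone sketch of that mechanism with a toy stacked parameter (names, shapes and the learning rate are illustrative, not the repo's):

import jax
import jax.numpy as jnp
import optax

tx = optax.adamw(1e-3)  # hypothetical learning rate

# toy "scanned" parameter: 4 layers of an (8, 8) kernel stacked on axis 0
stacked = {"kernel": jnp.ones((4, 8, 8))}
grads = jax.tree_util.tree_map(jnp.ones_like, stacked)

# per-layer optimizer state via vmapped init and update
opt_state = jax.vmap(tx.init)(stacked)
updates, opt_state = jax.vmap(tx.update, in_axes=(0, 0, 0), out_axes=(0, 0))(
    grads, opt_state, stacked
)
new_params = optax.apply_updates(stacked, updates)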
@@ -792,23 +809,6 @@ def main():

     learning_rate_fn = create_learning_rate_fn()

-    # reshape params to split scanned layers for optimizers
-    if model.config.use_scan:
-        params_struct = unfreeze(model.params)
-        for l in ["encoder", "decoder"]:
-            params_struct["model"][l]["layers"] = jax.tree_map(
-                lambda x: {
-                    f"{i}": ShapeDtypeStruct(shape=x.shape[1:], dtype=x.dtype)
-                    for i in range(len(x))
-                },
-                params_struct["model"][l]["layers"],
-            )
-        params_struct = freeze(params_struct)
-
-    else:
-        params_struct = model.params
-    opt_param_spec = set_partitions(params_struct, False)
-
     # create adam optimizer
     if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
@@ -820,7 +820,12 @@ def main():
             "sqrt_n": GraftingType.SQRT_N,
             "adagrad_normalized": GraftingType.ADAGRAD_NORMALIZED,
         }[training_args.graft_type]
-        optimizer = distributed_shampoo(
+        statistics_partition_spec = (
+            PartitionSpec(None, training_args.shard_shampoo_across, None)
+            if training_args.shard_shampoo_across != "2d"
+            else PartitionSpec(None, "dp", "mp")
+        )
+        opt = distributed_shampoo(
             learning_rate_fn,
             block_size=training_args.block_size,
             beta1=training_args.beta1,
@@ -836,11 +841,7 @@ def main():
             graft_type=graft_type,
             nesterov=False,
             exponent_override=0,
-            statistics_partition_spec=PartitionSpec(
-                None, training_args.shard_shampoo_across, None
-            )
-            if training_args.shard_shampoo_across != "2d"
-            else PartitionSpec(None, "dp", "mp"),
+            statistics_partition_spec=statistics_partition_spec,
             preconditioner_partition_spec=PartitionSpec(
                 training_args.shard_shampoo_across, None, None
             )
@@ -860,14 +861,18 @@ def main():
             best_effort_memory_usage_reduction=training_args.optim_quantized,
         )
         # get the real optimizer and helper functions
-        update_fn = optimizer.update
-
-        optimizer = optimizer.init(params_struct)
-        opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
-            optimizer.pspec_fn, optimizer.shape_and_dtype_fn
-        )
-
-        optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
+        update_fn = opt.update
+
+        optimizer = {}
+        opt_fn = {}
+        for k, p in split_params(model.params).items():
+            if "scanned" in k:
+                p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
+            optimizer[k] = opt.init(p)
+            opt_fn[k] = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
+                optimizer[k].pspec_fn, optimizer[k].shape_and_dtype_fn
+            )
+            optimizer[k] = optax.GradientTransformation(optimizer[k].init_fn, update_fn)

     elif training_args.optim == "adam":
         optimizer = optax.adamw(
@@ -876,6 +881,8 @@ def main():
             b2=training_args.beta2,
             eps=training_args.adam_epsilon,
         )
+        optimizer = {k: optimizer for k in split_params(model.params)}
+
     elif training_args.optim == "adafactor":
         # We use the default parameters here to initialize adafactor,
         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
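A note on the dict comprehension above: an optax GradientTransformation is a stateless pair of pure functions, so the very same adamw object can be reused for the three groups returned by split_params; each group still gets independent state from its own init call. A small sketch with toy shapes and the group keys assumed from split_params:

import jax.numpy as jnp
import optax

adamw = optax.adamw(1e-3)  # hypothetical learning rate
tx = {"standard": adamw, "scanned_encoder": adamw, "scanned_decoder": adamw}

# independent state per group, even though the transformation object is shared
state_std = tx["standard"].init({"w": jnp.zeros((2, 2))})
state_enc = tx["scanned_encoder"].init({"w": jnp.zeros((4, 2, 2))})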
@@ -883,44 +890,66 @@ def main():
             learning_rate=learning_rate_fn,
             clipping_threshold=training_args.max_grad_norm,
         )
+        optimizer = {k: optimizer for k in split_params(model.params)}

     # get PartitionSpec for optimizer state
-    def get_opt_state_spec_and_shape(param_spec):
+    def get_opt_state_spec_and_shape():
         # get opt_state shape without actual init
-        opt_state_shape = jax.eval_shape(optimizer.init, params_struct)
+        opt_state_shape = {}
+        for k, p in split_params(model.params).items():
+            if "scanned" not in k:
+                opt_state_shape[k] = jax.eval_shape(optimizer[k].init, p)
+            else:
+                opt_state_shape[k] = jax.eval_shape(jax.vmap(optimizer[k].init), p)

-        if training_args.optim == "adam":
+        if training_args.optim == "adafactor":
+            # factorized state must be replicated (rank different than params)
+            opt_state_spec = {k: None for k in split_params(model.params)}

-            def _opt_state_spec_per_leaf(x):
+        elif training_args.optim in ["adam", "distributed_shampoo"]:
+
+            def _opt_state_spec_per_leaf(x, spec):
                 if isinstance(x, FrozenDict):
                     # variables with same structure as params
-                    return param_spec
+                    return spec
                 else:
                     # other variables such as count
                     return None

-            opt_state_spec = jax.tree_map(
-                _opt_state_spec_per_leaf,
-                opt_state_shape,
-                # return None spec for empty elements
-                is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
-            )
-
-        elif training_args.optim == "adafactor":
-            # factorized state must be replicated (rank different than params)
-            opt_state_spec = None
+            split_spec = split_params(set_partitions(model.params, False))
+            opt_state_spec = {}
+            for k, p in split_params(model.params).items():
+                if "scanned" in k:
+                    p = jax.eval_shape(lambda x: jax.tree_map(lambda y: y[0], x), p)
+                if training_args.optim == "adam":
+                    opt_state_spec[k] = jax.tree_map(
+                        _opt_state_spec_per_leaf,
+                        opt_state_shape[k],
+                        split_spec[k],
+                        # return None spec for empty elements
+                        is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
+                    )
+                elif training_args.optim == "distributed_shampoo":
+                    opt_state_spec[k] = opt_fn[k].pspec_fn(
+                        p,
+                        split_spec[k],
+                        statistics_partition_spec,
+                    )
+                # add dimension for scanned params
+                if "scanned" in k:
+                    opt_state_spec[k] = jax.tree_map(
+                        lambda x: PartitionSpec(*(None,) + x)
+                        if x is not None
+                        else None,
+                        opt_state_spec[k],
+                        is_leaf=lambda x: isinstance(x, PartitionSpec),
+                    )

-        elif training_args.optim == "distributed_shampoo":
-            opt_state_spec = opt_fn.pspec_fn(
-                params=params_struct,
-                params_partition_spec=param_spec,
-                partition_spec_for_statistics=PartitionSpec(None, "dp", None),
-            )
         else:
             raise NotImplementedError
-        return opt_state_spec, opt_state_shape
+        return freeze(opt_state_spec), freeze(opt_state_shape)

-    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(opt_param_spec)
+    opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape()

     # create a mesh
     mesh_shape = (training_args.dp_devices, training_args.mp_devices)
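get_opt_state_spec_and_shape above leans on jax.eval_shape to obtain the optimizer-state pytree abstractly (shapes and dtypes only, nothing allocated on device), which is what the PartitionSpec tree for pjit is then built against. A minimal sketch of that trick with a plain adamw and toy parameters (not the repo's sharded setup):

import jax
import jax.numpy as jnp
import optax

tx = optax.adamw(1e-3)
params = {"kernel": jnp.ones((8, 8)), "bias": jnp.ones((8,))}

# pytree of jax.ShapeDtypeStruct mirroring the real optimizer state
opt_state_shape = jax.eval_shape(tx.init, params)
print(jax.tree_util.tree_map(lambda x: x.shape, opt_state_shape))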
 