boris committed
Commit 8a9e367
1 Parent(s): cc34d07

feat: update distributed_shampoo + fix None spec

Files changed (1)
  1. tools/train/distributed_shampoo.py +427 -61
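
With this change the sharded (pjit) path takes explicit PartitionSpecs (`statistics_partition_spec`, `preconditioner_partition_spec`) in place of the removed `mesh_axis_names` argument. A minimal construction sketch, not part of the commit: the hyperparameter values, the `"dp"` mesh-axis name, and the import path (taken from the file location) are placeholder assumptions.

    from jax.experimental import pjit
    from tools.train.distributed_shampoo import distributed_shampoo  # path assumed from repo layout

    # Hypothetical values; the specs must match your actual mesh layout.
    optimizer = distributed_shampoo(
        learning_rate=1e-3,   # placeholder
        block_size=1024,      # placeholder
        statistics_partition_spec=pjit.PartitionSpec(None, "dp", None),
        preconditioner_partition_spec=pjit.PartitionSpec("dp", None, None),
        num_devices_for_pjit=8,  # placeholder
        shard_optimizer_states=True,
    )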
tools/train/distributed_shampoo.py CHANGED
@@ -1,7 +1,5 @@
-"""File copied from https://github.com/google-research/google-research/edit/master/scalable_shampoo/optax/distributed_shampoo.py"""
-
 # coding=utf-8
-# Copyright 2021 The Google Research Authors.
+# Copyright 2022 The Google Research Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -147,6 +145,12 @@ class QuantizedValue:
         return val


+@struct.dataclass
+class TrainingMetrics:
+    inverse_pth_root_errors: chex.Array  # Error for inverse-pth roots.
+    # TODO(rohananil): Add more important metrics to track during training.
+
+
 # Per parameter optimizer state used in data-parallel training.
 class ParameterStats(NamedTuple):
     """State associated to each parameter of the model being trained."""
@@ -156,6 +160,7 @@ class ParameterStats(NamedTuple):
     preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
     diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
     momentum: QuantizedValue  # Momentum for the shampoo preconditioner
+    training_metrics: TrainingMetrics  # Metrics (optional for training).


 # For training extremely large model; We keep a global state with a concatenated
@@ -166,6 +171,7 @@ class ParameterStats(NamedTuple):
 class GlobalShardedParameterStats:
     statistics: chex.Array  # Statistics
     preconditioners: chex.Array  # Preconditioners
+    exponents: chex.Array  # exponents


 # These are per-parameter local states; All statistics here mirror the parameter
@@ -177,12 +183,34 @@ class LocalShardedParameterStats:
     diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
     diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
     momentum: QuantizedValue  # Momentum for the shampoo preconditioner
+    training_metrics: TrainingMetrics  # Metrics (optional for training).
     index_start: np.int32 = struct.field(
         pytree_node=False
     )  # Index into global statistics array
     sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.


+def init_training_metrics(num_statistics):
+    if num_statistics:
+        return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
+    else:
+        return TrainingMetrics([])
+
+
+def init_training_metrics_shapes(num_statistics):
+    if num_statistics:
+        return TrainingMetrics([[num_statistics], jnp.float32])
+    else:
+        return TrainingMetrics([None, jnp.float32])
+
+
+def init_training_metrics_pspec(num_statistics):
+    if num_statistics:
+        return TrainingMetrics(pjit.PartitionSpec())
+    else:
+        return TrainingMetrics(None)
+
+
 class ShardedShampooStats(NamedTuple):
     """Shampoo state in sharded mode."""

@@ -195,6 +223,12 @@ class ShampooState(NamedTuple):
     stats: Any


+class InitFnState(NamedTuple):
+    init_fn: Any
+    pspec_fn: Any
+    shape_and_dtype_fn: Any
+
+
 class GraftingType(enum.IntEnum):
     SGD = 1
     ADAGRAD = 2
@@ -292,6 +326,8 @@ def matrix_inverse_pth_root(
       matrix^(-1/p)
     """

+    assert matrix.shape[0] == matrix.shape[1]
+
     # We use float32 for the matrix inverse pth root.
     # Switch to f64 if you have hardware that supports it.
     matrix_size = matrix.shape[0]
@@ -615,6 +651,7 @@ def _convert_to_parameter_stats(global_stats, local_stat):
         new_preconditioners,
         local_stat.diagonal_momentum,
         local_stat.momentum,
+        local_stat.training_metrics,
     )


@@ -624,11 +661,40 @@ def _convert_from_parameter_stats(parameter_stats, local_stats):
        parameter_stats.diagonal_statistics,
        parameter_stats.diagonal_momentum,
        parameter_stats.momentum,
+       parameter_stats.training_metrics,
        local_stats.index_start,
        local_stats.sizes,
    )


+def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
+    """Adds errors back into local statistics."""
+    new_local_stats = []
+    for local_stat in local_stats:
+        index_start = int(local_stat.index_start)
+        index_end = int(len(local_stat.sizes)) + index_start
+        per_stat_error = errors[index_start:index_end]
+        if local_stat.sizes:
+            per_stat_error = jnp.where(
+                jnp.logical_and(
+                    per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
+                ),
+                per_stat_error,
+                local_stat.training_metrics.inverse_pth_root_errors,
+            )
+        new_local_stats.append(
+            LocalShardedParameterStats(
+                local_stat.diagonal_statistics,
+                local_stat.diagonal_momentum,
+                local_stat.momentum,
+                TrainingMetrics(per_stat_error),
+                local_stat.index_start,
+                local_stat.sizes,
+            )
+        )
+    return new_local_stats
+
+
 def batch(x, num_devices):
     """Batch `x` so that so that leading axis is num_devices."""
     n = len(x)
@@ -670,7 +736,8 @@ def distributed_shampoo(
     batch_axis_name=None,
     ### Only set following 3 params in pjit/spmd mode.
     ### WARNING: Experimental
-    mesh_axis_names=None,
+    statistics_partition_spec=None,
+    preconditioner_partition_spec=None,
     num_devices_for_pjit=None,
     shard_optimizer_states=False,
     ###
@@ -730,7 +797,8 @@ def distributed_shampoo(
       exponent_override: Override the exponent used in matrix inverse.
      batch_axis_name: labeled axis over pmap for data-parallel training the
        optimizer used for.
-      mesh_axis_names: Axis names for the mesh (used in pjit).
+      statistics_partition_spec: PartitionSpec to be used in sharded mode.
+      preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
      num_devices_for_pjit: Number of devices to parallelize over when using pjit.
      shard_optimizer_states: Shard optimizer states to save memory in model
        parallel training.
@@ -830,6 +898,11 @@ def distributed_shampoo(
     )

     def sharded_init_fn(params):
+        """Returns optimizer state (for PJIT mode).
+
+        Args:
+          params: the parameters that should be updated.
+        """
         params_flat, treedef = jax.tree_flatten(params)
         # Find max size to pad to.
         max_size = 0
@@ -845,6 +918,7 @@ def distributed_shampoo(
         padded_statistics = []
         padded_preconditioners = []
         local_stats_flat = []
+        exponents = []
         for param in params_flat:
             preconditioner = Preconditioner(
                 param, block_size, best_effort_shape_interpretation
@@ -862,6 +936,12 @@ def distributed_shampoo(
                 preconditioners = [jnp.eye(max_size) for s in shapes]
                 padded_statistics.extend(statistics)
                 padded_preconditioners.extend(preconditioners)
+                exponent = (
+                    preconditioner.exponent_for_preconditioner()
+                    if exponent_override == 0
+                    else exponent_override
+                )
+                exponents.extend([exponent] * len(shapes))

             diagonal_statistics = []
             if graft_type != GraftingType.SGD:
@@ -871,6 +951,7 @@ def distributed_shampoo(
                     _quantize_diagonal_statistics(diagonal_statistics),
                     _quantize_momentum(jnp.zeros_like(param)),
                     _quantize_momentum(jnp.zeros_like(param)),
+                    init_training_metrics(len(sizes)),
                     index_start,
                     sizes,
                 )
@@ -888,14 +969,238 @@ def distributed_shampoo(
         padded_preconditioners.extend(
             [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
         )
+        exponents.extend([1 for _ in range(to_pad)])
         global_stats = GlobalShardedParameterStats(
-            jnp.stack(padded_statistics), jnp.stack(padded_preconditioners)
+            jnp.stack(padded_statistics),
+            jnp.stack(padded_preconditioners),
+            jnp.stack(exponents),
         )
         return ShampooState(
             count=jnp.zeros([], jnp.int32),
             stats=ShardedShampooStats(global_stats, local_stats),
         )

+    def _max_statistics_size_from_params(params):
+        max_size = 0
+        for param in params:
+            param_clone = jnp.zeros(param.shape, dtype=param.dtype)
+            preconditioner = Preconditioner(
+                param_clone, block_size, best_effort_shape_interpretation
+            )
+            if not _skip_preconditioning(param):
+                shapes = preconditioner.shapes_for_preconditioners()
+                sizes = [s[0] for s in shapes]
+                max_size = max(max(sizes), max_size)
+        return max_size
+
+    def _remove_leading_sharding_annotation(pspec):
+        """Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
+        # None and PSpec(None) are valid PSpecs.
+        if pspec and len(pspec) > 1:
+            return pjit.PartitionSpec(*pspec[1:])
+        else:
+            return None
+
+    def sharded_init_partition_spec_fn(
+        params, params_partition_spec, partition_spec_for_statistics
+    ):
+        """Returns a parallel state tree with PartitionSpec associated with state.
+
+
+        Args:
+          params: A pytree with params.
+          params_partition_spec: A pytree with PartitionSpec for params.
+          partition_spec_for_statistics: PartitionSpec for the statistics.
+        """
+        # Parallel lists of spec, and params.
+        param_pspec_flat, _ = jax.tree_flatten(
+            params_partition_spec, is_leaf=lambda x: x is None
+        )
+        params_flat, treedef = jax.tree_flatten(params)
+        assert param_pspec_flat
+        assert params_flat
+        # Step is replicated across cores.
+        # None means cores.
+        local_stats_flat = []
+        num_statistics = 0
+        for param, param_pspec in zip(params_flat, param_pspec_flat):
+            param_clone = jnp.zeros(param.shape, dtype=param.dtype)
+            preconditioner = Preconditioner(
+                param_clone, block_size, best_effort_shape_interpretation
+            )
+            shapes = preconditioner.shapes_for_preconditioners()
+            sizes = []
+
+            index_start = num_statistics
+            if not _skip_preconditioning(param):
+                sizes = [s[0] for s in shapes]
+                shapes = preconditioner.shapes_for_preconditioners()
+                num_statistics += len(shapes)
+
+            diagonal_statistics_pspec = []
+            diagonal_statistics_scale_pspec = []
+            if graft_type != GraftingType.SGD:
+                # Identically shaped param.
+                diagonal_statistics_pspec = param_pspec
+                if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
+                    diagonal_statistics_scale_pspec = (
+                        _remove_leading_sharding_annotation(param_pspec)
+                    )
+
+            m1_pspec = param_pspec
+            m2_pspec = param_pspec
+
+            m1_scale_pspec = []
+            m2_scale_pspec = []
+
+            if quantized_dtype_for_momentum_buffers() != jnp.float32:
+                m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
+                m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
+
+            local_stats_flat.append(
+                LocalShardedParameterStats(
+                    QuantizedValue(
+                        diagonal_statistics_pspec,
+                        [],
+                        diagonal_statistics_scale_pspec,
+                        quantized_dtype_for_diagonal_statistics_buffers(),
+                        False,
+                        list(param.shape),
+                    ),
+                    QuantizedValue(
+                        m1_pspec,
+                        [],
+                        m1_scale_pspec,
+                        quantized_dtype_for_momentum_buffers(),
+                        False,
+                        list(param.shape),
+                    ),
+                    QuantizedValue(
+                        m2_pspec,
+                        [],
+                        m2_scale_pspec,
+                        quantized_dtype_for_momentum_buffers(),
+                        False,
+                        list(param.shape),
+                    ),
+                    init_training_metrics_pspec(len(sizes)),
+                    index_start,
+                    sizes,
+                )
+            )
+
+        local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+        global_stats = GlobalShardedParameterStats(
+            partition_spec_for_statistics,
+            partition_spec_for_statistics,
+            pjit.PartitionSpec(),
+        )
+        count_pspec = pjit.PartitionSpec()
+        return ShampooState(
+            count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
+        )
+
+    def sharded_init_shape_and_dtype_fn(params):
+        """Returns a parallel state tree with shape, dtype associated with state.
+
+
+        Args:
+          params: A pytree with params.
+        """
+        # Parallel lists of spec, and params.
+        params_flat, treedef = jax.tree_flatten(params)
+        assert params_flat
+        # Step is replicated across cores.
+        # None means cores.
+        local_stats_flat = []
+        num_statistics = 0
+        for param in params_flat:
+            param_clone = jnp.zeros(param.shape, dtype=param.dtype)
+            preconditioner = Preconditioner(
+                param_clone, block_size, best_effort_shape_interpretation
+            )
+            shapes = preconditioner.shapes_for_preconditioners()
+            sizes = []
+
+            index_start = num_statistics
+            if not _skip_preconditioning(param):
+                sizes = [s[0] for s in shapes]
+                shapes = preconditioner.shapes_for_preconditioners()
+                num_statistics += len(shapes)
+
+            diagonal_statistics_shape_and_dtype = []
+            diagonal_statistics_scale_shape_and_dtype = []
+            if graft_type != GraftingType.SGD:
+                diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
+                qdtype = quantized_dtype_for_diagonal_statistics_buffers()
+                if qdtype != jnp.float32:
+                    diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
+                    diagonal_statistics_scale_shape_and_dtype = [
+                        list(param.shape)[1:],
+                        param.dtype,
+                    ]
+
+            m1_shape_and_dtype = [list(param.shape), param.dtype]
+            m2_shape_and_dtype = [list(param.shape), param.dtype]
+
+            m1_scale_shape_and_dtype = []
+            m2_scale_shape_and_dtype = []
+
+            qdtype = quantized_dtype_for_momentum_buffers()
+            if qdtype != jnp.float32:
+                m1_shape_and_dtype = [list(param.shape), qdtype]
+                m2_shape_and_dtype = [list(param.shape), qdtype]
+
+                m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
+                m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
+
+            local_stats_flat.append(
+                LocalShardedParameterStats(
+                    QuantizedValue(
+                        diagonal_statistics_shape_and_dtype,
+                        [],
+                        diagonal_statistics_scale_shape_and_dtype,
+                        quantized_dtype_for_diagonal_statistics_buffers(),
+                        False,
+                        list(param.shape),
+                    ),
+                    QuantizedValue(
+                        m1_shape_and_dtype,
+                        [],
+                        m1_scale_shape_and_dtype,
+                        quantized_dtype_for_momentum_buffers(),
+                        False,
+                        list(param.shape),
+                    ),
+                    QuantizedValue(
+                        m2_shape_and_dtype,
+                        [],
+                        m2_scale_shape_and_dtype,
+                        quantized_dtype_for_momentum_buffers(),
+                        False,
+                        list(param.shape),
+                    ),
+                    init_training_metrics_shapes(len(sizes)),
+                    index_start,
+                    sizes,
+                )
+            )
+
+        local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+        max_statistics_size = _max_statistics_size_from_params(params_flat)
+        to_pad = -num_statistics % num_devices_for_pjit
+        num_statistics += to_pad
+        statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
+        global_stats = GlobalShardedParameterStats(
+            [statistics_shape, jnp.float32],
+            [statistics_shape, jnp.float32],
+            [[num_statistics], jnp.int32],
+        )
+        return ShampooState(
+            count=[[], jnp.float32],
+            stats=ShardedShampooStats(global_stats, local_stats),
+        )
+
     def sharded_update_fn(grads, state, params):
         """Transform the input gradient and update all statistics in sharded mode.

@@ -923,20 +1228,6 @@ def distributed_shampoo(
             params_flat,
         )

-        exponents = []
-        for stat, param in zip(new_stats_flat, params_flat):
-            num_statistics = len(stat.statistics)
-            if num_statistics > 0:
-                preconditioner = Preconditioner(
-                    param, block_size, best_effort_shape_interpretation
-                )
-                exponent = (
-                    preconditioner.exponent_for_preconditioner()
-                    if exponent_override == 0
-                    else exponent_override
-                )
-                exponents.extend([exponent] * num_statistics)
-
         outputs = jax.tree_multimap(
             lambda g, s, p: _transform_grad(g, s, p, state.count),
             grads_flat,
@@ -951,7 +1242,6 @@ def distributed_shampoo(
             _convert_from_parameter_stats(new_stat, local_stat)
             for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
         ]
-        new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)

         max_size = global_stats.statistics.shape[1]
         new_padded_statistics = []
@@ -974,22 +1264,16 @@ def distributed_shampoo(
                 for _ in range(to_pad)
             ]
         )
-        exponents.extend([1 for _ in range(to_pad)])
         new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
-        new_stacked_exponents = jnp.stack(exponents)
-
-        def _matrix_inverse_pth_root_vmap(xs, ps):
-            mi_pth_root = functools.partial(
-                matrix_inverse_pth_root,
-                ridge_epsilon=matrix_epsilon,
-                precision=precision,
-            )
-            preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
-            return preconditioners, errors
+        new_stacked_padded_statistics = pjit.with_sharding_constraint(
+            new_stacked_padded_statistics, statistics_partition_spec
+        )

         def _internal_inverse_pth_root_all():
-            preconditioners, errors = _matrix_inverse_pth_root_vmap(
-                new_stacked_padded_statistics, new_stacked_exponents
+            preconditioners, errors = _matrix_inverse_pth_root_pjit(
+                new_stacked_padded_statistics,
+                global_stats.exponents,
+                statistics_partition_spec,
             )
             return preconditioners, errors

@@ -1000,13 +1284,18 @@ def distributed_shampoo(
         # shaped tensors. Note statistics will be ignored as we are passing in
         # a large init value for error.
         preconditioners_init = new_stacked_padded_statistics
-        errors_init = np.stack([inverse_failure_threshold] * len(exponents))
+        n = new_stacked_padded_statistics.shape[0]
+        errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
         init_state = [preconditioners_init, errors_init]
         perform_step = state.count % preconditioning_compute_steps == 0
         new_preconditioners, errors = efficient_cond(
             perform_step, _internal_inverse_pth_root_all, init_state
         )

+        new_local_stats_flat = _add_error_into_local_stats(
+            new_local_stats_flat, errors, inverse_failure_threshold
+        )
+        new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
         errors = errors.reshape((-1, 1, 1))
         predicate = jnp.logical_or(
             jnp.isnan(errors), errors >= inverse_failure_threshold
@@ -1017,7 +1306,9 @@ def distributed_shampoo(
             + (1.0 - predicate) * new_preconditioners
         )
         new_global_stats = GlobalShardedParameterStats(
-            new_stacked_padded_statistics, new_conditional_preconditioners
+            new_stacked_padded_statistics,
+            new_conditional_preconditioners,
+            global_stats.exponents,
         )
         new_shampoo_state = ShampooState(
             count=state.count + 1,
@@ -1048,6 +1339,7 @@ def distributed_shampoo(
                 _maybe_quantize_preconditioners(preconditioners),
                 _quantize_momentum(jnp.zeros_like(param)),
                 _quantize_momentum(jnp.zeros_like(param)),
+                init_training_metrics(len(statistics)),
             )

         return ShampooState(
@@ -1092,6 +1384,7 @@ def distributed_shampoo(
             state.preconditioners,
             state.diagonal_momentum,
             state.momentum,
+            state.training_metrics,
         )

     def _matrix_inverse_pth_root_vmap(xs, ps):
@@ -1115,33 +1408,27 @@ def distributed_shampoo(

         return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)

-    def _matrix_inverse_pth_root_pjit(xs, ps):
-        mesh_axis_names_tuple = tuple(mesh_axis_names)
+    def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
         # Partition the concatenated statistics matrix across all cores.
-        partitioned_xs, partitioned_ps = pjit.pjit(
-            lambda x, y: (x, y),
-            in_axis_resources=None,
-            out_axis_resources=pjit.PartitionSpec(
-                mesh_axis_names_tuple,
-            ),
-        )(xs, ps)
+        pspec_for_partition = preconditioner_partition_spec
+        partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
+        partitioned_ps = pjit.with_sharding_constraint(
+            ps, pjit.PartitionSpec(preconditioner_partition_spec[0])
+        )
         # Run matrix inverse pth root on each shard.
         partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
             partitioned_xs, partitioned_ps
         )
+        # Reshard output to have the same PSpec as input. This is required to avoid
+        # vmap seeing the full set of statistics.
+        partitioned_preconditioners = pjit.with_sharding_constraint(
+            partitioned_preconditioners, pspec_for_partition
+        )
         # Recombine the outputs at each core.
-        preconditioners, errors = pjit.pjit(
-            lambda x, y: (x, y),
-            in_axis_resources=(
-                pjit.PartitionSpec(
-                    mesh_axis_names_tuple,
-                ),
-                pjit.PartitionSpec(
-                    mesh_axis_names_tuple,
-                ),
-            ),
-            out_axis_resources=(None, None),
-        )(partitioned_preconditioners, partitioned_errors)
+        preconditioners = pjit.with_sharding_constraint(
+            partitioned_preconditioners, statistics_partition_spec
+        )
+        errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec())
         return preconditioners, errors

     def _pmap_compute_preconditioners(
@@ -1223,31 +1510,54 @@ def distributed_shampoo(
         )

         new_preconditioners_flat = []
+        new_errors_flat = []
         for p, shape, prev_p, error in zip(
             preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
         ):
             new_preconditioners_flat.append(
                 _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
             )
+            new_errors_flat.append(error)

         assert len(states) == len(num_statistics_per_state)
         assert len(new_preconditioners_flat) == num_statistics
+        assert len(new_errors_flat) == num_statistics

         # Add back empty preconditioners so we that we can set the optimizer state.
         preconditioners_for_states = []
         idx = 0
+        errors_for_states = []
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
+                errors_for_states.append([])
             else:
                 preconditioners_for_state = new_preconditioners_flat[
                     idx : idx + num_statistics
                 ]
                 assert len(state.statistics) == len(preconditioners_for_state)
                 preconditioners_for_states.append(preconditioners_for_state)
+
+                errors_for_state = jnp.stack(
+                    new_errors_flat[idx : idx + num_statistics]
+                )
+                assert len(state.statistics) == len(errors_for_state)
+                errors_for_states.append(errors_for_state)
+
             idx += num_statistics
         new_states = []
-        for state, new_preconditioners in zip(states, preconditioners_for_states):
+        for state, new_preconditioners, new_errors in zip(
+            states, preconditioners_for_states, errors_for_states
+        ):
+            if state.statistics:
+                new_errors = jnp.where(
+                    jnp.logical_and(
+                        new_errors > 0.0, new_errors != inverse_failure_threshold
+                    ),
+                    new_errors,
+                    state.training_metrics.inverse_pth_root_errors,
+                )
+            new_training_metrics = TrainingMetrics(new_errors)
             new_states.append(
                 ParameterStats(
                     state.diagonal_statistics,
@@ -1255,6 +1565,7 @@ def distributed_shampoo(
                     new_preconditioners,
                     state.diagonal_momentum,
                     state.momentum,
+                    new_training_metrics,
                 )
             )

@@ -1413,6 +1724,7 @@ def distributed_shampoo(
         new_quantized_preconditioners_flat = []
         new_quantized_diagonals_flat = []
         new_quantized_bucket_sizes_flat = []
+        new_errors_flat = []
         for p, d, b, shape, prev_p, error in zip(
             quantized_preconditioners_flat,
             quantized_diagonals_flat,
@@ -1432,6 +1744,7 @@ def distributed_shampoo(
             new_quantized_bucket_sizes_flat.append(
                 _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
             )
+            new_errors_flat.append(error)

         assert len(states) == len(num_statistics_per_state)
         assert len(new_quantized_preconditioners_flat) == num_statistics
@@ -1440,10 +1753,12 @@ def distributed_shampoo(

         # Add back empty preconditioners so we that we can set the optimizer state.
         preconditioners_for_states = []
+        errors_for_states = []
         idx = 0
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
+                errors_for_states.append([])
             else:
                 quantized_preconditioners_for_state = (
                     new_quantized_preconditioners_flat[idx : idx + num_statistics]
@@ -1454,10 +1769,14 @@ def distributed_shampoo(
                 quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
                     idx : idx + num_statistics
                 ]
+                errors_for_state = jnp.stack(
+                    new_errors_flat[idx : idx + num_statistics]
+                )

                 assert len(state.statistics) == len(quantized_preconditioners_for_state)
                 assert len(state.statistics) == len(quantized_diagonals_for_state)
                 assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
+                assert len(state.statistics) == len(errors_for_state)

                 quantized_preconditioners = []
                 for qv, qd, qb in zip(
@@ -1469,9 +1788,21 @@ def distributed_shampoo(
                         QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
                     )
                 preconditioners_for_states.append(quantized_preconditioners)
+                errors_for_states.append(errors_for_state)
             idx += num_statistics
         new_states = []
-        for state, new_preconditioners in zip(states, preconditioners_for_states):
+        for state, new_preconditioners, new_errors in zip(
+            states, preconditioners_for_states, errors_for_states
+        ):
+            if state.statistics:
+                new_errors = jnp.where(
+                    jnp.logical_and(
+                        new_errors > 0.0, new_errors != inverse_failure_threshold
+                    ),
+                    new_errors,
+                    state.training_metrics.inverse_pth_root_errors,
+                )
+            new_training_metrics = TrainingMetrics(new_errors)
             new_states.append(
                 ParameterStats(
                     state.diagonal_statistics,
@@ -1479,6 +1810,7 @@ def distributed_shampoo(
                     new_preconditioners,
                     state.diagonal_momentum,
                     state.momentum,
+                    new_training_metrics,
                 )
             )

@@ -1560,31 +1892,53 @@ def distributed_shampoo(
         )

         new_preconditioners_flat = []
+        new_errors_flat = []
         for p, shape, prev_p, error in zip(
             preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
         ):
             new_preconditioners_flat.append(
                 _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
             )
+            new_errors_flat.append(error)

         assert len(states) == len(num_statistics_per_state)
         assert len(new_preconditioners_flat) == num_statistics

         # Add back empty preconditioners so we that we can set the optimizer state.
         preconditioners_for_states = []
+        errors_for_states = []
         idx = 0
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
+                errors_for_states.append([])
             else:
                 preconditioners_for_state = new_preconditioners_flat[
                     idx : idx + num_statistics
                 ]
                 assert len(state.statistics) == len(preconditioners_for_state)
                 preconditioners_for_states.append(preconditioners_for_state)
+
+                errors_for_state = jnp.stack(
+                    new_errors_flat[idx : idx + num_statistics]
+                )
+                assert len(state.statistics) == len(errors_for_state)
+                errors_for_states.append(errors_for_state)
             idx += num_statistics
+
         new_states = []
-        for state, new_preconditioners in zip(states, preconditioners_for_states):
+        for state, new_preconditioners, new_errors in zip(
+            states, preconditioners_for_states, errors_for_states
+        ):
+            if state.statistics:
+                new_errors = jnp.where(
+                    jnp.logical_and(
+                        new_errors > 0.0, new_errors != inverse_failure_threshold
+                    ),
+                    new_errors,
+                    state.training_metrics.inverse_pth_root_errors,
+                )
+            new_training_metrics = TrainingMetrics(new_errors)
             new_states.append(
                 ParameterStats(
                     state.diagonal_statistics,
@@ -1592,6 +1946,7 @@ def distributed_shampoo(
                     new_preconditioners,
                     state.diagonal_momentum,
                     state.momentum,
+                    new_training_metrics,
                 )
             )

@@ -1778,7 +2133,9 @@ def distributed_shampoo(
             state.preconditioners,
             _quantize_momentum(grafting_update_with_wd_momentum),
             _quantize_momentum(shampoo_update_with_wd_momentum),
+            state.training_metrics,
         )
+
         return transformed_update, param_stats

     def update_fn(grads, state, params):
@@ -1821,6 +2178,15 @@ def distributed_shampoo(
         return updates, new_state

     if shard_optimizer_states:
-        return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
+        # Hijacks the init_fn signature so we can return an OptState with
+        # appropriate init_fns.
+        def _init_fns(unused_params):
+            return InitFnState(
+                init_fn=sharded_init_fn,
+                pspec_fn=sharded_init_partition_spec_fn,
+                shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
+            )
+
+        return optax.GradientTransformation(_init_fns, sharded_update_fn)
     else:
         return optax.GradientTransformation(init_fn, update_fn)
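
With shard_optimizer_states=True, the returned transformation's init now yields an InitFnState rather than the optimizer state itself. A rough consumption sketch, assuming the optimizer built as above; variable names such as params_partition_spec are illustrative:

    fns = optimizer.init(None)  # InitFnState(init_fn, pspec_fn, shape_and_dtype_fn); the argument is unused
    opt_state_spec = fns.pspec_fn(params, params_partition_spec, statistics_partition_spec)  # PartitionSpec tree
    opt_state_shapes = fns.shape_and_dtype_fn(params)  # [shape, dtype] tree for the state
    opt_state = fns.init_fn(params)  # the actual ShampooState, typically created inside pjit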