boris committed on
Commit
e4401dd
2 Parent(s): f5239e1 2b7f5f1

Merge pull request #127 from borisdayma/pjit-t5x


feat(train): pjit optimization and distributed shampoo support
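Context for the merge: the training loop moves from pmap-style data parallelism to pjit over an explicit device mesh, and distributed Shampoo gains sharded-state support. Below is a minimal, self-contained sketch of the pjit pattern this PR adopts; it is not the PR's actual train step, and the mesh axis names and toy update rule are illustrative assumptions only.

# Minimal pjit sketch (illustrative; axis names and the update rule are assumptions).
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# One data-parallel axis plus a (here trivial) model-parallel axis.
devices = np.asarray(jax.devices()).reshape(jax.device_count(), 1)
mesh = maps.Mesh(devices, ("dp", "mp"))

def sgd_step(params, batch):
    # Toy update: shift every parameter by the mean of the sharded batch.
    grad = jnp.mean(batch)
    return jax.tree_map(lambda p: p - 0.01 * grad, params)

p_sgd_step = pjit(
    sgd_step,
    in_axis_resources=(None, PartitionSpec("dp")),  # replicate params, shard batch rows
    out_axis_resources=None,
)

with maps.mesh(mesh.devices, mesh.axis_names):
    params = {"w": jnp.zeros((4,))}
    batch = jnp.ones((8, 16))
    params = p_sgd_step(params, batch)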

src/dalle_mini/data.py CHANGED
@@ -152,24 +152,15 @@ class Dataset:
152
  ),
153
  )
154
 
155
- def dataloader(
156
- self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
157
- ):
158
- num_devices = jax.local_device_count()
159
-
160
  def _dataloader_datasets_non_streaming(
161
  dataset: Dataset,
162
- per_device_batch_size: int,
163
- gradient_accumulation_steps: int,
164
  rng: jax.random.PRNGKey = None,
165
  ):
166
  """
167
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
168
  Shuffle batches if rng is set.
169
  """
170
- batch_size = (
171
- per_device_batch_size * num_devices * gradient_accumulation_steps
172
- )
173
  steps_per_epoch = len(dataset) // batch_size
174
 
175
  if rng is not None:
@@ -185,18 +176,10 @@ class Dataset:
185
  for idx in batch_idx:
186
  batch = dataset[idx]
187
  batch = {k: jnp.array(v) for k, v in batch.items()}
188
- if gradient_accumulation_steps is not None:
189
- batch = jax.tree_map(
190
- lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
191
- batch,
192
- )
193
  yield batch
194
 
195
  def _dataloader_datasets_streaming(
196
  dataset: Dataset,
197
- split: str,
198
- per_device_batch_size: int,
199
- gradient_accumulation_steps: int,
200
  epoch: int,
201
  ):
202
  keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
@@ -208,28 +191,15 @@ class Dataset:
208
  # For validation data we put the entire set on each host as we could lose
209
  # too many samples on pods
210
  if epoch is not None:
211
- # reshuffle training data at each epoch (not applicable with validation set)
 
212
  dataset.set_epoch(epoch)
213
  epoch += 1
214
  for item in dataset:
215
  for k, v in item.items():
216
  batch[k].append(v)
217
- # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
218
- # (40, 3, 3) -> shard 8 x (5, 3, 3)
219
- # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
220
- if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
221
- gradient_accumulation_steps
222
- if gradient_accumulation_steps is not None
223
- else 1
224
- ):
225
  batch = {k: jnp.array(v) for k, v in batch.items()}
226
- if gradient_accumulation_steps is not None:
227
- batch = jax.tree_map(
228
- lambda x: x.reshape(
229
- (-1, per_device_batch_size) + x.shape[1:]
230
- ),
231
- batch,
232
- )
233
  yield batch
234
  batch = {k: [] for k in keys}
235
  first_loop = False
@@ -242,15 +212,11 @@ class Dataset:
242
  raise ValueError(f'split must be "train" or "eval", got {split}')
243
 
244
  if self.streaming:
245
- return _dataloader_datasets_streaming(
246
- ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
247
- )
248
  else:
249
  if split == "train":
250
  self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
251
- return _dataloader_datasets_non_streaming(
252
- ds, per_device_batch_size, gradient_accumulation_steps, input_rng
253
- )
254
 
255
  @property
256
  def length(self):
 
152
  ),
153
  )
154
 
155
+ def dataloader(self, split, batch_size, epoch=None):
 
 
 
 
156
  def _dataloader_datasets_non_streaming(
157
  dataset: Dataset,
 
 
158
  rng: jax.random.PRNGKey = None,
159
  ):
160
  """
161
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
162
  Shuffle batches if rng is set.
163
  """
 
 
 
164
  steps_per_epoch = len(dataset) // batch_size
165
 
166
  if rng is not None:
 
176
  for idx in batch_idx:
177
  batch = dataset[idx]
178
  batch = {k: jnp.array(v) for k, v in batch.items()}
 
 
 
 
 
179
  yield batch
180
 
181
  def _dataloader_datasets_streaming(
182
  dataset: Dataset,
 
 
 
183
  epoch: int,
184
  ):
185
  keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
 
191
  # For validation data we put the entire set on each host as we could lose
192
  # too many samples on pods
193
  if epoch is not None:
194
+ assert split == "train"
195
+ # reshuffle training data at each epoch
196
  dataset.set_epoch(epoch)
197
  epoch += 1
198
  for item in dataset:
199
  for k, v in item.items():
200
  batch[k].append(v)
201
+ if len(batch[keys[0]]) == batch_size:
 
 
 
 
 
 
 
202
  batch = {k: jnp.array(v) for k, v in batch.items()}
 
 
 
 
 
 
 
203
  yield batch
204
  batch = {k: [] for k in keys}
205
  first_loop = False
 
212
  raise ValueError(f'split must be "train" or "eval", got {split}')
213
 
214
  if self.streaming:
215
+ return _dataloader_datasets_streaming(ds, epoch)
 
 
216
  else:
217
  if split == "train":
218
  self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
219
+ return _dataloader_datasets_non_streaming(ds, input_rng)
 
 
220
 
221
  @property
222
  def length(self):
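Note on the data.py change above: the dataloader drops per_device_batch_size and gradient_accumulation_steps and takes a single batch_size; any reshaping for accumulation now happens later in the train step. A hedged sketch of the calling side follows; the field names mirror tools/train/train.py further down, but this exact wiring is an assumption since the new call site is only partly visible in this diff.

# Hedged sketch of the new calling convention (names assumed from tools/train/train.py).
minibatch_size = training_args.per_device_train_batch_size * training_args.dp_devices
batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps

train_loader = dataset.dataloader("train", batch_size_per_node, epoch)
for batch in train_loader:
    pass  # each yielded batch already holds the full node-level batch; no per-device reshape here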
src/dalle_mini/model/modeling.py CHANGED
@@ -312,7 +312,7 @@ class FlaxBartPreTrainedModel(FlaxBartPreTrainedModel):
312
  seed: int = 0,
313
  dtype: jnp.dtype = jnp.float32,
314
  abstract_init: bool = False,
315
- load_on_cpu: bool = True,
316
  **kwargs,
317
  ):
318
  module = self.module_class(config=config, dtype=dtype, **kwargs)
 
312
  seed: int = 0,
313
  dtype: jnp.dtype = jnp.float32,
314
  abstract_init: bool = False,
315
+ load_on_cpu: bool = False,
316
  **kwargs,
317
  ):
318
  module = self.module_class(config=config, dtype=dtype, **kwargs)
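The only change in modeling.py flips the load_on_cpu default from True to False, so callers that want host-side initialization must now request it explicitly; the training script does exactly that (see the train.py hunks below). A minimal sketch of the opt-in call, using the arguments shown in this diff:

# Keep initial weights in host memory and defer device placement to pjit.
# Mirrors the call sites added in tools/train/train.py.
model = DalleBart(
    config,
    seed=training_args.seed_model,
    dtype=getattr(jnp, model_args.dtype),
    abstract_init=True,   # build param shapes without materializing full weights
    load_on_cpu=True,     # opt-in now that the default is False
)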
tools/train/distributed_shampoo.py CHANGED
@@ -1,7 +1,5 @@
1
- """File copied from https://github.com/google-research/google-research/edit/master/scalable_shampoo/optax/distributed_shampoo.py"""
2
-
3
  # coding=utf-8
4
- # Copyright 2021 The Google Research Authors.
5
  #
6
  # Licensed under the Apache License, Version 2.0 (the "License");
7
  # you may not use this file except in compliance with the License.
@@ -147,6 +145,12 @@ class QuantizedValue:
147
  return val
148
 
149
 
 
 
 
 
 
 
150
  # Per parameter optimizer state used in data-parallel training.
151
  class ParameterStats(NamedTuple):
152
  """State associated to each parameter of the model being trained."""
@@ -156,6 +160,7 @@ class ParameterStats(NamedTuple):
156
  preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
157
  diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
158
  momentum: QuantizedValue # Momentum for the shampoo preconditioner
 
159
 
160
 
161
  # For training extremely large model; We keep a global state with a concatenated
@@ -166,6 +171,7 @@ class ParameterStats(NamedTuple):
166
  class GlobalShardedParameterStats:
167
  statistics: chex.Array # Statistics
168
  preconditioners: chex.Array # Preconditioners
 
169
 
170
 
171
  # These are per-parameter local states; All statistics here mirror the parameter
@@ -177,12 +183,34 @@ class LocalShardedParameterStats:
177
  diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
178
  diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
179
  momentum: QuantizedValue # Momentum for the shampoo preconditioner
 
180
  index_start: np.int32 = struct.field(
181
  pytree_node=False
182
  ) # Index into global statistics array
183
  sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
184
 
185
 
186
  class ShardedShampooStats(NamedTuple):
187
  """Shampoo state in sharded mode."""
188
 
@@ -195,6 +223,12 @@ class ShampooState(NamedTuple):
195
  stats: Any
196
 
197
 
 
 
 
 
 
 
198
  class GraftingType(enum.IntEnum):
199
  SGD = 1
200
  ADAGRAD = 2
@@ -292,6 +326,8 @@ def matrix_inverse_pth_root(
292
  matrix^(-1/p)
293
  """
294
 
 
 
295
  # We use float32 for the matrix inverse pth root.
296
  # Switch to f64 if you have hardware that supports it.
297
  matrix_size = matrix.shape[0]
@@ -615,6 +651,7 @@ def _convert_to_parameter_stats(global_stats, local_stat):
615
  new_preconditioners,
616
  local_stat.diagonal_momentum,
617
  local_stat.momentum,
 
618
  )
619
 
620
 
@@ -624,11 +661,40 @@ def _convert_from_parameter_stats(parameter_stats, local_stats):
624
  parameter_stats.diagonal_statistics,
625
  parameter_stats.diagonal_momentum,
626
  parameter_stats.momentum,
 
627
  local_stats.index_start,
628
  local_stats.sizes,
629
  )
630
 
631
 
632
  def batch(x, num_devices):
633
  """Batch `x` so that so that leading axis is num_devices."""
634
  n = len(x)
@@ -670,7 +736,8 @@ def distributed_shampoo(
670
  batch_axis_name=None,
671
  ### Only set following 3 params in pjit/spmd mode.
672
  ### WARNING: Experimental
673
- mesh_axis_names=None,
 
674
  num_devices_for_pjit=None,
675
  shard_optimizer_states=False,
676
  ###
@@ -730,7 +797,8 @@ def distributed_shampoo(
730
  exponent_override: Override the exponent used in matrix inverse.
731
  batch_axis_name: labeled axis over pmap for data-parallel training the
732
  optimizer used for.
733
- mesh_axis_names: Axis names for the mesh (used in pjit).
 
734
  num_devices_for_pjit: Number of devices to parallelize over when using pjit.
735
  shard_optimizer_states: Shard optimizer states to save memory in model
736
  parallel training.
@@ -830,6 +898,11 @@ def distributed_shampoo(
830
  )
831
 
832
  def sharded_init_fn(params):
 
 
 
 
 
833
  params_flat, treedef = jax.tree_flatten(params)
834
  # Find max size to pad to.
835
  max_size = 0
@@ -845,6 +918,7 @@ def distributed_shampoo(
845
  padded_statistics = []
846
  padded_preconditioners = []
847
  local_stats_flat = []
 
848
  for param in params_flat:
849
  preconditioner = Preconditioner(
850
  param, block_size, best_effort_shape_interpretation
@@ -862,6 +936,12 @@ def distributed_shampoo(
862
  preconditioners = [jnp.eye(max_size) for s in shapes]
863
  padded_statistics.extend(statistics)
864
  padded_preconditioners.extend(preconditioners)
 
 
 
 
 
 
865
 
866
  diagonal_statistics = []
867
  if graft_type != GraftingType.SGD:
@@ -871,6 +951,7 @@ def distributed_shampoo(
871
  _quantize_diagonal_statistics(diagonal_statistics),
872
  _quantize_momentum(jnp.zeros_like(param)),
873
  _quantize_momentum(jnp.zeros_like(param)),
 
874
  index_start,
875
  sizes,
876
  )
@@ -888,14 +969,238 @@ def distributed_shampoo(
888
  padded_preconditioners.extend(
889
  [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
890
  )
 
891
  global_stats = GlobalShardedParameterStats(
892
- jnp.stack(padded_statistics), jnp.stack(padded_preconditioners)
 
 
893
  )
894
  return ShampooState(
895
  count=jnp.zeros([], jnp.int32),
896
  stats=ShardedShampooStats(global_stats, local_stats),
897
  )
898
899
  def sharded_update_fn(grads, state, params):
900
  """Transform the input gradient and update all statistics in sharded mode.
901
 
@@ -923,20 +1228,6 @@ def distributed_shampoo(
923
  params_flat,
924
  )
925
 
926
- exponents = []
927
- for stat, param in zip(new_stats_flat, params_flat):
928
- num_statistics = len(stat.statistics)
929
- if num_statistics > 0:
930
- preconditioner = Preconditioner(
931
- param, block_size, best_effort_shape_interpretation
932
- )
933
- exponent = (
934
- preconditioner.exponent_for_preconditioner()
935
- if exponent_override == 0
936
- else exponent_override
937
- )
938
- exponents.extend([exponent] * num_statistics)
939
-
940
  outputs = jax.tree_multimap(
941
  lambda g, s, p: _transform_grad(g, s, p, state.count),
942
  grads_flat,
@@ -951,7 +1242,6 @@ def distributed_shampoo(
951
  _convert_from_parameter_stats(new_stat, local_stat)
952
  for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
953
  ]
954
- new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
955
 
956
  max_size = global_stats.statistics.shape[1]
957
  new_padded_statistics = []
@@ -974,22 +1264,16 @@ def distributed_shampoo(
974
  for _ in range(to_pad)
975
  ]
976
  )
977
- exponents.extend([1 for _ in range(to_pad)])
978
  new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
979
- new_stacked_exponents = jnp.stack(exponents)
980
-
981
- def _matrix_inverse_pth_root_vmap(xs, ps):
982
- mi_pth_root = functools.partial(
983
- matrix_inverse_pth_root,
984
- ridge_epsilon=matrix_epsilon,
985
- precision=precision,
986
- )
987
- preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
988
- return preconditioners, errors
989
 
990
  def _internal_inverse_pth_root_all():
991
- preconditioners, errors = _matrix_inverse_pth_root_vmap(
992
- new_stacked_padded_statistics, new_stacked_exponents
 
 
993
  )
994
  return preconditioners, errors
995
 
@@ -1000,13 +1284,18 @@ def distributed_shampoo(
1000
  # shaped tensors. Note statistics will be ignored as we are passing in
1001
  # a large init value for error.
1002
  preconditioners_init = new_stacked_padded_statistics
1003
- errors_init = np.stack([inverse_failure_threshold] * len(exponents))
 
1004
  init_state = [preconditioners_init, errors_init]
1005
  perform_step = state.count % preconditioning_compute_steps == 0
1006
  new_preconditioners, errors = efficient_cond(
1007
  perform_step, _internal_inverse_pth_root_all, init_state
1008
  )
1009
 
 
 
 
 
1010
  errors = errors.reshape((-1, 1, 1))
1011
  predicate = jnp.logical_or(
1012
  jnp.isnan(errors), errors >= inverse_failure_threshold
@@ -1017,7 +1306,9 @@ def distributed_shampoo(
1017
  + (1.0 - predicate) * new_preconditioners
1018
  )
1019
  new_global_stats = GlobalShardedParameterStats(
1020
- new_stacked_padded_statistics, new_conditional_preconditioners
 
 
1021
  )
1022
  new_shampoo_state = ShampooState(
1023
  count=state.count + 1,
@@ -1048,6 +1339,7 @@ def distributed_shampoo(
1048
  _maybe_quantize_preconditioners(preconditioners),
1049
  _quantize_momentum(jnp.zeros_like(param)),
1050
  _quantize_momentum(jnp.zeros_like(param)),
 
1051
  )
1052
 
1053
  return ShampooState(
@@ -1092,6 +1384,7 @@ def distributed_shampoo(
1092
  state.preconditioners,
1093
  state.diagonal_momentum,
1094
  state.momentum,
 
1095
  )
1096
 
1097
  def _matrix_inverse_pth_root_vmap(xs, ps):
@@ -1115,33 +1408,27 @@ def distributed_shampoo(
1115
 
1116
  return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1117
 
1118
- def _matrix_inverse_pth_root_pjit(xs, ps):
1119
- mesh_axis_names_tuple = tuple(mesh_axis_names)
1120
  # Partition the concatenated statistics matrix across all cores.
1121
- partitioned_xs, partitioned_ps = pjit.pjit(
1122
- lambda x, y: (x, y),
1123
- in_axis_resources=None,
1124
- out_axis_resources=pjit.PartitionSpec(
1125
- mesh_axis_names_tuple,
1126
- ),
1127
- )(xs, ps)
1128
  # Run matrix inverse pth root on each shard.
1129
  partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1130
  partitioned_xs, partitioned_ps
1131
  )
 
 
 
 
 
1132
  # Recombine the outputs at each core.
1133
- preconditioners, errors = pjit.pjit(
1134
- lambda x, y: (x, y),
1135
- in_axis_resources=(
1136
- pjit.PartitionSpec(
1137
- mesh_axis_names_tuple,
1138
- ),
1139
- pjit.PartitionSpec(
1140
- mesh_axis_names_tuple,
1141
- ),
1142
- ),
1143
- out_axis_resources=(None, None),
1144
- )(partitioned_preconditioners, partitioned_errors)
1145
  return preconditioners, errors
1146
 
1147
  def _pmap_compute_preconditioners(
@@ -1223,31 +1510,54 @@ def distributed_shampoo(
1223
  )
1224
 
1225
  new_preconditioners_flat = []
 
1226
  for p, shape, prev_p, error in zip(
1227
  preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1228
  ):
1229
  new_preconditioners_flat.append(
1230
  _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1231
  )
 
1232
 
1233
  assert len(states) == len(num_statistics_per_state)
1234
  assert len(new_preconditioners_flat) == num_statistics
 
1235
 
1236
  # Add back empty preconditioners so we that we can set the optimizer state.
1237
  preconditioners_for_states = []
1238
  idx = 0
 
1239
  for num_statistics, state in zip(num_statistics_per_state, states):
1240
  if num_statistics == 0:
1241
  preconditioners_for_states.append([])
 
1242
  else:
1243
  preconditioners_for_state = new_preconditioners_flat[
1244
  idx : idx + num_statistics
1245
  ]
1246
  assert len(state.statistics) == len(preconditioners_for_state)
1247
  preconditioners_for_states.append(preconditioners_for_state)
 
 
 
 
 
 
 
1248
  idx += num_statistics
1249
  new_states = []
1250
- for state, new_preconditioners in zip(states, preconditioners_for_states):
 
 
 
 
 
 
 
 
 
 
 
1251
  new_states.append(
1252
  ParameterStats(
1253
  state.diagonal_statistics,
@@ -1255,6 +1565,7 @@ def distributed_shampoo(
1255
  new_preconditioners,
1256
  state.diagonal_momentum,
1257
  state.momentum,
 
1258
  )
1259
  )
1260
 
@@ -1413,6 +1724,7 @@ def distributed_shampoo(
1413
  new_quantized_preconditioners_flat = []
1414
  new_quantized_diagonals_flat = []
1415
  new_quantized_bucket_sizes_flat = []
 
1416
  for p, d, b, shape, prev_p, error in zip(
1417
  quantized_preconditioners_flat,
1418
  quantized_diagonals_flat,
@@ -1432,6 +1744,7 @@ def distributed_shampoo(
1432
  new_quantized_bucket_sizes_flat.append(
1433
  _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
1434
  )
 
1435
 
1436
  assert len(states) == len(num_statistics_per_state)
1437
  assert len(new_quantized_preconditioners_flat) == num_statistics
@@ -1440,10 +1753,12 @@ def distributed_shampoo(
1440
 
1441
  # Add back empty preconditioners so we that we can set the optimizer state.
1442
  preconditioners_for_states = []
 
1443
  idx = 0
1444
  for num_statistics, state in zip(num_statistics_per_state, states):
1445
  if num_statistics == 0:
1446
  preconditioners_for_states.append([])
 
1447
  else:
1448
  quantized_preconditioners_for_state = (
1449
  new_quantized_preconditioners_flat[idx : idx + num_statistics]
@@ -1454,10 +1769,14 @@ def distributed_shampoo(
1454
  quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1455
  idx : idx + num_statistics
1456
  ]
 
 
 
1457
 
1458
  assert len(state.statistics) == len(quantized_preconditioners_for_state)
1459
  assert len(state.statistics) == len(quantized_diagonals_for_state)
1460
  assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
 
1461
 
1462
  quantized_preconditioners = []
1463
  for qv, qd, qb in zip(
@@ -1469,9 +1788,21 @@ def distributed_shampoo(
1469
  QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
1470
  )
1471
  preconditioners_for_states.append(quantized_preconditioners)
 
1472
  idx += num_statistics
1473
  new_states = []
1474
- for state, new_preconditioners in zip(states, preconditioners_for_states):
 
 
 
 
 
 
 
 
 
 
 
1475
  new_states.append(
1476
  ParameterStats(
1477
  state.diagonal_statistics,
@@ -1479,6 +1810,7 @@ def distributed_shampoo(
1479
  new_preconditioners,
1480
  state.diagonal_momentum,
1481
  state.momentum,
 
1482
  )
1483
  )
1484
 
@@ -1560,31 +1892,53 @@ def distributed_shampoo(
1560
  )
1561
 
1562
  new_preconditioners_flat = []
 
1563
  for p, shape, prev_p, error in zip(
1564
  preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1565
  ):
1566
  new_preconditioners_flat.append(
1567
  _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1568
  )
 
1569
 
1570
  assert len(states) == len(num_statistics_per_state)
1571
  assert len(new_preconditioners_flat) == num_statistics
1572
 
1573
  # Add back empty preconditioners so we that we can set the optimizer state.
1574
  preconditioners_for_states = []
 
1575
  idx = 0
1576
  for num_statistics, state in zip(num_statistics_per_state, states):
1577
  if num_statistics == 0:
1578
  preconditioners_for_states.append([])
 
1579
  else:
1580
  preconditioners_for_state = new_preconditioners_flat[
1581
  idx : idx + num_statistics
1582
  ]
1583
  assert len(state.statistics) == len(preconditioners_for_state)
1584
  preconditioners_for_states.append(preconditioners_for_state)
 
 
 
 
 
 
1585
  idx += num_statistics
 
1586
  new_states = []
1587
- for state, new_preconditioners in zip(states, preconditioners_for_states):
 
 
 
 
 
 
 
 
 
 
 
1588
  new_states.append(
1589
  ParameterStats(
1590
  state.diagonal_statistics,
@@ -1592,6 +1946,7 @@ def distributed_shampoo(
1592
  new_preconditioners,
1593
  state.diagonal_momentum,
1594
  state.momentum,
 
1595
  )
1596
  )
1597
 
@@ -1778,7 +2133,9 @@ def distributed_shampoo(
1778
  state.preconditioners,
1779
  _quantize_momentum(grafting_update_with_wd_momentum),
1780
  _quantize_momentum(shampoo_update_with_wd_momentum),
 
1781
  )
 
1782
  return transformed_update, param_stats
1783
 
1784
  def update_fn(grads, state, params):
@@ -1821,6 +2178,15 @@ def distributed_shampoo(
1821
  return updates, new_state
1822
 
1823
  if shard_optimizer_states:
1824
- return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
 
 
 
 
 
 
 
 
 
1825
  else:
1826
  return optax.GradientTransformation(init_fn, update_fn)
 
 
 
1
  # coding=utf-8
2
+ # Copyright 2022 The Google Research Authors.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
145
  return val
146
 
147
 
148
+ @struct.dataclass
149
+ class TrainingMetrics:
150
+ inverse_pth_root_errors: chex.Array # Error for inverse-pth roots.
151
+ # TODO(rohananil): Add more important metrics to track during training.
152
+
153
+
154
  # Per parameter optimizer state used in data-parallel training.
155
  class ParameterStats(NamedTuple):
156
  """State associated to each parameter of the model being trained."""
 
160
  preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
161
  diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
162
  momentum: QuantizedValue # Momentum for the shampoo preconditioner
163
+ training_metrics: TrainingMetrics # Metrics (optional for training).
164
 
165
 
166
  # For training extremely large model; We keep a global state with a concatenated
 
171
  class GlobalShardedParameterStats:
172
  statistics: chex.Array # Statistics
173
  preconditioners: chex.Array # Preconditioners
174
+ exponents: chex.Array # exponents
175
 
176
 
177
  # These are per-parameter local states; All statistics here mirror the parameter
 
183
  diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
184
  diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
185
  momentum: QuantizedValue # Momentum for the shampoo preconditioner
186
+ training_metrics: TrainingMetrics # Metrics (optional for training).
187
  index_start: np.int32 = struct.field(
188
  pytree_node=False
189
  ) # Index into global statistics array
190
  sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
191
 
192
 
193
+ def init_training_metrics(num_statistics):
194
+ if num_statistics:
195
+ return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
196
+ else:
197
+ return TrainingMetrics([])
198
+
199
+
200
+ def init_training_metrics_shapes(num_statistics):
201
+ if num_statistics:
202
+ return TrainingMetrics([[num_statistics], jnp.float32])
203
+ else:
204
+ return TrainingMetrics([None, jnp.float32])
205
+
206
+
207
+ def init_training_metrics_pspec(num_statistics):
208
+ if num_statistics:
209
+ return TrainingMetrics(pjit.PartitionSpec())
210
+ else:
211
+ return TrainingMetrics(None)
212
+
213
+
214
  class ShardedShampooStats(NamedTuple):
215
  """Shampoo state in sharded mode."""
216
 
 
223
  stats: Any
224
 
225
 
226
+ class InitFnState(NamedTuple):
227
+ init_fn: Any
228
+ pspec_fn: Any
229
+ shape_and_dtype_fn: Any
230
+
231
+
232
  class GraftingType(enum.IntEnum):
233
  SGD = 1
234
  ADAGRAD = 2
 
326
  matrix^(-1/p)
327
  """
328
 
329
+ assert matrix.shape[0] == matrix.shape[1]
330
+
331
  # We use float32 for the matrix inverse pth root.
332
  # Switch to f64 if you have hardware that supports it.
333
  matrix_size = matrix.shape[0]
 
651
  new_preconditioners,
652
  local_stat.diagonal_momentum,
653
  local_stat.momentum,
654
+ local_stat.training_metrics,
655
  )
656
 
657
 
 
661
  parameter_stats.diagonal_statistics,
662
  parameter_stats.diagonal_momentum,
663
  parameter_stats.momentum,
664
+ parameter_stats.training_metrics,
665
  local_stats.index_start,
666
  local_stats.sizes,
667
  )
668
 
669
 
670
+ def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
671
+ """Adds errors back into local statistics."""
672
+ new_local_stats = []
673
+ for local_stat in local_stats:
674
+ index_start = int(local_stat.index_start)
675
+ index_end = int(len(local_stat.sizes)) + index_start
676
+ per_stat_error = errors[index_start:index_end]
677
+ if local_stat.sizes:
678
+ per_stat_error = jnp.where(
679
+ jnp.logical_and(
680
+ per_stat_error > 0.0, per_stat_error != inverse_failure_threshold
681
+ ),
682
+ per_stat_error,
683
+ local_stat.training_metrics.inverse_pth_root_errors,
684
+ )
685
+ new_local_stats.append(
686
+ LocalShardedParameterStats(
687
+ local_stat.diagonal_statistics,
688
+ local_stat.diagonal_momentum,
689
+ local_stat.momentum,
690
+ TrainingMetrics(per_stat_error),
691
+ local_stat.index_start,
692
+ local_stat.sizes,
693
+ )
694
+ )
695
+ return new_local_stats
696
+
697
+
698
  def batch(x, num_devices):
699
  """Batch `x` so that so that leading axis is num_devices."""
700
  n = len(x)
 
736
  batch_axis_name=None,
737
  ### Only set following 3 params in pjit/spmd mode.
738
  ### WARNING: Experimental
739
+ statistics_partition_spec=None,
740
+ preconditioner_partition_spec=None,
741
  num_devices_for_pjit=None,
742
  shard_optimizer_states=False,
743
  ###
 
797
  exponent_override: Override the exponent used in matrix inverse.
798
  batch_axis_name: labeled axis over pmap for data-parallel training the
799
  optimizer used for.
800
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
801
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
802
  num_devices_for_pjit: Number of devices to parallelize over when using pjit.
803
  shard_optimizer_states: Shard optimizer states to save memory in model
804
  parallel training.
 
898
  )
899
 
900
  def sharded_init_fn(params):
901
+ """Returns optimizer state (for PJIT mode).
902
+
903
+ Args:
904
+ params: the parameters that should be updated.
905
+ """
906
  params_flat, treedef = jax.tree_flatten(params)
907
  # Find max size to pad to.
908
  max_size = 0
 
918
  padded_statistics = []
919
  padded_preconditioners = []
920
  local_stats_flat = []
921
+ exponents = []
922
  for param in params_flat:
923
  preconditioner = Preconditioner(
924
  param, block_size, best_effort_shape_interpretation
 
936
  preconditioners = [jnp.eye(max_size) for s in shapes]
937
  padded_statistics.extend(statistics)
938
  padded_preconditioners.extend(preconditioners)
939
+ exponent = (
940
+ preconditioner.exponent_for_preconditioner()
941
+ if exponent_override == 0
942
+ else exponent_override
943
+ )
944
+ exponents.extend([exponent] * len(shapes))
945
 
946
  diagonal_statistics = []
947
  if graft_type != GraftingType.SGD:
 
951
  _quantize_diagonal_statistics(diagonal_statistics),
952
  _quantize_momentum(jnp.zeros_like(param)),
953
  _quantize_momentum(jnp.zeros_like(param)),
954
+ init_training_metrics(len(sizes)),
955
  index_start,
956
  sizes,
957
  )
 
969
  padded_preconditioners.extend(
970
  [jnp.eye(max_size, dtype=padded_statistics[0].dtype) for _ in range(to_pad)]
971
  )
972
+ exponents.extend([1 for _ in range(to_pad)])
973
  global_stats = GlobalShardedParameterStats(
974
+ jnp.stack(padded_statistics),
975
+ jnp.stack(padded_preconditioners),
976
+ jnp.stack(exponents),
977
  )
978
  return ShampooState(
979
  count=jnp.zeros([], jnp.int32),
980
  stats=ShardedShampooStats(global_stats, local_stats),
981
  )
982
 
983
+ def _max_statistics_size_from_params(params):
984
+ max_size = 0
985
+ for param in params:
986
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
987
+ preconditioner = Preconditioner(
988
+ param_clone, block_size, best_effort_shape_interpretation
989
+ )
990
+ if not _skip_preconditioning(param):
991
+ shapes = preconditioner.shapes_for_preconditioners()
992
+ sizes = [s[0] for s in shapes]
993
+ max_size = max(max(sizes), max_size)
994
+ return max_size
995
+
996
+ def _remove_leading_sharding_annotation(pspec):
997
+ """Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
998
+ # None and PSpec(None) are valid PSpecs.
999
+ if pspec and len(pspec) > 1:
1000
+ return pjit.PartitionSpec(*pspec[1:])
1001
+ else:
1002
+ return None
1003
+
1004
+ def sharded_init_partition_spec_fn(
1005
+ params, params_partition_spec, partition_spec_for_statistics
1006
+ ):
1007
+ """Returns a parallel state tree with PartitionSpec associated with state.
1008
+
1009
+
1010
+ Args:
1011
+ params: A pytree with params.
1012
+ params_partition_spec: A pytree with PartitionSpec for params.
1013
+ partition_spec_for_statistics: PartitionSpec for the statistics.
1014
+ """
1015
+ # Parallel lists of spec, and params.
1016
+ param_pspec_flat, _ = jax.tree_flatten(
1017
+ params_partition_spec, is_leaf=lambda x: x is None
1018
+ )
1019
+ params_flat, treedef = jax.tree_flatten(params)
1020
+ assert param_pspec_flat
1021
+ assert params_flat
1022
+ # Step is replicated across cores.
1023
+ # None means cores.
1024
+ local_stats_flat = []
1025
+ num_statistics = 0
1026
+ for param, param_pspec in zip(params_flat, param_pspec_flat):
1027
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1028
+ preconditioner = Preconditioner(
1029
+ param_clone, block_size, best_effort_shape_interpretation
1030
+ )
1031
+ shapes = preconditioner.shapes_for_preconditioners()
1032
+ sizes = []
1033
+
1034
+ index_start = num_statistics
1035
+ if not _skip_preconditioning(param):
1036
+ sizes = [s[0] for s in shapes]
1037
+ shapes = preconditioner.shapes_for_preconditioners()
1038
+ num_statistics += len(shapes)
1039
+
1040
+ diagonal_statistics_pspec = []
1041
+ diagonal_statistics_scale_pspec = []
1042
+ if graft_type != GraftingType.SGD:
1043
+ # Identically shaped param.
1044
+ diagonal_statistics_pspec = param_pspec
1045
+ if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
1046
+ diagonal_statistics_scale_pspec = (
1047
+ _remove_leading_sharding_annotation(param_pspec)
1048
+ )
1049
+
1050
+ m1_pspec = param_pspec
1051
+ m2_pspec = param_pspec
1052
+
1053
+ m1_scale_pspec = []
1054
+ m2_scale_pspec = []
1055
+
1056
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1057
+ m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
1058
+ m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
1059
+
1060
+ local_stats_flat.append(
1061
+ LocalShardedParameterStats(
1062
+ QuantizedValue(
1063
+ diagonal_statistics_pspec,
1064
+ [],
1065
+ diagonal_statistics_scale_pspec,
1066
+ quantized_dtype_for_diagonal_statistics_buffers(),
1067
+ False,
1068
+ list(param.shape),
1069
+ ),
1070
+ QuantizedValue(
1071
+ m1_pspec,
1072
+ [],
1073
+ m1_scale_pspec,
1074
+ quantized_dtype_for_momentum_buffers(),
1075
+ False,
1076
+ list(param.shape),
1077
+ ),
1078
+ QuantizedValue(
1079
+ m2_pspec,
1080
+ [],
1081
+ m2_scale_pspec,
1082
+ quantized_dtype_for_momentum_buffers(),
1083
+ False,
1084
+ list(param.shape),
1085
+ ),
1086
+ init_training_metrics_pspec(len(sizes)),
1087
+ index_start,
1088
+ sizes,
1089
+ )
1090
+ )
1091
+
1092
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1093
+ global_stats = GlobalShardedParameterStats(
1094
+ partition_spec_for_statistics,
1095
+ partition_spec_for_statistics,
1096
+ pjit.PartitionSpec(),
1097
+ )
1098
+ count_pspec = pjit.PartitionSpec()
1099
+ return ShampooState(
1100
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
1101
+ )
1102
+
1103
+ def sharded_init_shape_and_dtype_fn(params):
1104
+ """Returns a parallel state tree with shape, dtype associated with state.
1105
+
1106
+
1107
+ Args:
1108
+ params: A pytree with params.
1109
+ """
1110
+ # Parallel lists of spec, and params.
1111
+ params_flat, treedef = jax.tree_flatten(params)
1112
+ assert params_flat
1113
+ # Step is replicated across cores.
1114
+ # None means cores.
1115
+ local_stats_flat = []
1116
+ num_statistics = 0
1117
+ for param in params_flat:
1118
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
1119
+ preconditioner = Preconditioner(
1120
+ param_clone, block_size, best_effort_shape_interpretation
1121
+ )
1122
+ shapes = preconditioner.shapes_for_preconditioners()
1123
+ sizes = []
1124
+
1125
+ index_start = num_statistics
1126
+ if not _skip_preconditioning(param):
1127
+ sizes = [s[0] for s in shapes]
1128
+ shapes = preconditioner.shapes_for_preconditioners()
1129
+ num_statistics += len(shapes)
1130
+
1131
+ diagonal_statistics_shape_and_dtype = []
1132
+ diagonal_statistics_scale_shape_and_dtype = []
1133
+ if graft_type != GraftingType.SGD:
1134
+ diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
1135
+ qdtype = quantized_dtype_for_diagonal_statistics_buffers()
1136
+ if qdtype != jnp.float32:
1137
+ diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
1138
+ diagonal_statistics_scale_shape_and_dtype = [
1139
+ list(param.shape)[1:],
1140
+ param.dtype,
1141
+ ]
1142
+
1143
+ m1_shape_and_dtype = [list(param.shape), param.dtype]
1144
+ m2_shape_and_dtype = [list(param.shape), param.dtype]
1145
+
1146
+ m1_scale_shape_and_dtype = []
1147
+ m2_scale_shape_and_dtype = []
1148
+
1149
+ qdtype = quantized_dtype_for_momentum_buffers()
1150
+ if qdtype != jnp.float32:
1151
+ m1_shape_and_dtype = [list(param.shape), qdtype]
1152
+ m2_shape_and_dtype = [list(param.shape), qdtype]
1153
+
1154
+ m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1155
+ m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1156
+
1157
+ local_stats_flat.append(
1158
+ LocalShardedParameterStats(
1159
+ QuantizedValue(
1160
+ diagonal_statistics_shape_and_dtype,
1161
+ [],
1162
+ diagonal_statistics_scale_shape_and_dtype,
1163
+ quantized_dtype_for_diagonal_statistics_buffers(),
1164
+ False,
1165
+ list(param.shape),
1166
+ ),
1167
+ QuantizedValue(
1168
+ m1_shape_and_dtype,
1169
+ [],
1170
+ m1_scale_shape_and_dtype,
1171
+ quantized_dtype_for_momentum_buffers(),
1172
+ False,
1173
+ list(param.shape),
1174
+ ),
1175
+ QuantizedValue(
1176
+ m2_shape_and_dtype,
1177
+ [],
1178
+ m2_scale_shape_and_dtype,
1179
+ quantized_dtype_for_momentum_buffers(),
1180
+ False,
1181
+ list(param.shape),
1182
+ ),
1183
+ init_training_metrics_shapes(len(sizes)),
1184
+ index_start,
1185
+ sizes,
1186
+ )
1187
+ )
1188
+
1189
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1190
+ max_statistics_size = _max_statistics_size_from_params(params_flat)
1191
+ to_pad = -num_statistics % num_devices_for_pjit
1192
+ num_statistics += to_pad
1193
+ statistics_shape = [num_statistics, max_statistics_size, max_statistics_size]
1194
+ global_stats = GlobalShardedParameterStats(
1195
+ [statistics_shape, jnp.float32],
1196
+ [statistics_shape, jnp.float32],
1197
+ [[num_statistics], jnp.int32],
1198
+ )
1199
+ return ShampooState(
1200
+ count=[[], jnp.float32],
1201
+ stats=ShardedShampooStats(global_stats, local_stats),
1202
+ )
1203
+
1204
  def sharded_update_fn(grads, state, params):
1205
  """Transform the input gradient and update all statistics in sharded mode.
1206
 
 
1228
  params_flat,
1229
  )
1230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1231
  outputs = jax.tree_multimap(
1232
  lambda g, s, p: _transform_grad(g, s, p, state.count),
1233
  grads_flat,
 
1242
  _convert_from_parameter_stats(new_stat, local_stat)
1243
  for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
1244
  ]
 
1245
 
1246
  max_size = global_stats.statistics.shape[1]
1247
  new_padded_statistics = []
 
1264
  for _ in range(to_pad)
1265
  ]
1266
  )
 
1267
  new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
1268
+ new_stacked_padded_statistics = pjit.with_sharding_constraint(
1269
+ new_stacked_padded_statistics, statistics_partition_spec
1270
+ )
 
 
 
 
 
 
 
1271
 
1272
  def _internal_inverse_pth_root_all():
1273
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1274
+ new_stacked_padded_statistics,
1275
+ global_stats.exponents,
1276
+ statistics_partition_spec,
1277
  )
1278
  return preconditioners, errors
1279
 
 
1284
  # shaped tensors. Note statistics will be ignored as we are passing in
1285
  # a large init value for error.
1286
  preconditioners_init = new_stacked_padded_statistics
1287
+ n = new_stacked_padded_statistics.shape[0]
1288
+ errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
1289
  init_state = [preconditioners_init, errors_init]
1290
  perform_step = state.count % preconditioning_compute_steps == 0
1291
  new_preconditioners, errors = efficient_cond(
1292
  perform_step, _internal_inverse_pth_root_all, init_state
1293
  )
1294
 
1295
+ new_local_stats_flat = _add_error_into_local_stats(
1296
+ new_local_stats_flat, errors, inverse_failure_threshold
1297
+ )
1298
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
1299
  errors = errors.reshape((-1, 1, 1))
1300
  predicate = jnp.logical_or(
1301
  jnp.isnan(errors), errors >= inverse_failure_threshold
 
1306
  + (1.0 - predicate) * new_preconditioners
1307
  )
1308
  new_global_stats = GlobalShardedParameterStats(
1309
+ new_stacked_padded_statistics,
1310
+ new_conditional_preconditioners,
1311
+ global_stats.exponents,
1312
  )
1313
  new_shampoo_state = ShampooState(
1314
  count=state.count + 1,
 
1339
  _maybe_quantize_preconditioners(preconditioners),
1340
  _quantize_momentum(jnp.zeros_like(param)),
1341
  _quantize_momentum(jnp.zeros_like(param)),
1342
+ init_training_metrics(len(statistics)),
1343
  )
1344
 
1345
  return ShampooState(
 
1384
  state.preconditioners,
1385
  state.diagonal_momentum,
1386
  state.momentum,
1387
+ state.training_metrics,
1388
  )
1389
 
1390
  def _matrix_inverse_pth_root_vmap(xs, ps):
 
1408
 
1409
  return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1410
 
1411
+ def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
 
1412
  # Partition the concatenated statistics matrix across all cores.
1413
+ pspec_for_partition = preconditioner_partition_spec
1414
+ partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
1415
+ partitioned_ps = pjit.with_sharding_constraint(
1416
+ ps, pjit.PartitionSpec(preconditioner_partition_spec[0])
1417
+ )
 
 
1418
  # Run matrix inverse pth root on each shard.
1419
  partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1420
  partitioned_xs, partitioned_ps
1421
  )
1422
+ # Reshard output to have the same PSpec as input. This is required to avoid
1423
+ # vmap seeing the full set of statistics.
1424
+ partitioned_preconditioners = pjit.with_sharding_constraint(
1425
+ partitioned_preconditioners, pspec_for_partition
1426
+ )
1427
  # Recombine the outputs at each core.
1428
+ preconditioners = pjit.with_sharding_constraint(
1429
+ partitioned_preconditioners, statistics_partition_spec
1430
+ )
1431
+ errors = pjit.with_sharding_constraint(partitioned_errors, pjit.PartitionSpec())
 
 
 
 
 
 
 
 
1432
  return preconditioners, errors
1433
 
1434
  def _pmap_compute_preconditioners(
 
1510
  )
1511
 
1512
  new_preconditioners_flat = []
1513
+ new_errors_flat = []
1514
  for p, shape, prev_p, error in zip(
1515
  preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1516
  ):
1517
  new_preconditioners_flat.append(
1518
  _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1519
  )
1520
+ new_errors_flat.append(error)
1521
 
1522
  assert len(states) == len(num_statistics_per_state)
1523
  assert len(new_preconditioners_flat) == num_statistics
1524
+ assert len(new_errors_flat) == num_statistics
1525
 
1526
  # Add back empty preconditioners so we that we can set the optimizer state.
1527
  preconditioners_for_states = []
1528
  idx = 0
1529
+ errors_for_states = []
1530
  for num_statistics, state in zip(num_statistics_per_state, states):
1531
  if num_statistics == 0:
1532
  preconditioners_for_states.append([])
1533
+ errors_for_states.append([])
1534
  else:
1535
  preconditioners_for_state = new_preconditioners_flat[
1536
  idx : idx + num_statistics
1537
  ]
1538
  assert len(state.statistics) == len(preconditioners_for_state)
1539
  preconditioners_for_states.append(preconditioners_for_state)
1540
+
1541
+ errors_for_state = jnp.stack(
1542
+ new_errors_flat[idx : idx + num_statistics]
1543
+ )
1544
+ assert len(state.statistics) == len(errors_for_state)
1545
+ errors_for_states.append(errors_for_state)
1546
+
1547
  idx += num_statistics
1548
  new_states = []
1549
+ for state, new_preconditioners, new_errors in zip(
1550
+ states, preconditioners_for_states, errors_for_states
1551
+ ):
1552
+ if state.statistics:
1553
+ new_errors = jnp.where(
1554
+ jnp.logical_and(
1555
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1556
+ ),
1557
+ new_errors,
1558
+ state.training_metrics.inverse_pth_root_errors,
1559
+ )
1560
+ new_training_metrics = TrainingMetrics(new_errors)
1561
  new_states.append(
1562
  ParameterStats(
1563
  state.diagonal_statistics,
 
1565
  new_preconditioners,
1566
  state.diagonal_momentum,
1567
  state.momentum,
1568
+ new_training_metrics,
1569
  )
1570
  )
1571
 
 
1724
  new_quantized_preconditioners_flat = []
1725
  new_quantized_diagonals_flat = []
1726
  new_quantized_bucket_sizes_flat = []
1727
+ new_errors_flat = []
1728
  for p, d, b, shape, prev_p, error in zip(
1729
  quantized_preconditioners_flat,
1730
  quantized_diagonals_flat,
 
1744
  new_quantized_bucket_sizes_flat.append(
1745
  _select_preconditioner(error, b[: shape[0]], prev_p.bucket_size)
1746
  )
1747
+ new_errors_flat.append(error)
1748
 
1749
  assert len(states) == len(num_statistics_per_state)
1750
  assert len(new_quantized_preconditioners_flat) == num_statistics
 
1753
 
1754
  # Add back empty preconditioners so we that we can set the optimizer state.
1755
  preconditioners_for_states = []
1756
+ errors_for_states = []
1757
  idx = 0
1758
  for num_statistics, state in zip(num_statistics_per_state, states):
1759
  if num_statistics == 0:
1760
  preconditioners_for_states.append([])
1761
+ errors_for_states.append([])
1762
  else:
1763
  quantized_preconditioners_for_state = (
1764
  new_quantized_preconditioners_flat[idx : idx + num_statistics]
 
1769
  quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1770
  idx : idx + num_statistics
1771
  ]
1772
+ errors_for_state = jnp.stack(
1773
+ new_errors_flat[idx : idx + num_statistics]
1774
+ )
1775
 
1776
  assert len(state.statistics) == len(quantized_preconditioners_for_state)
1777
  assert len(state.statistics) == len(quantized_diagonals_for_state)
1778
  assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1779
+ assert len(state.statistics) == len(errors_for_state)
1780
 
1781
  quantized_preconditioners = []
1782
  for qv, qd, qb in zip(
 
1788
  QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape))
1789
  )
1790
  preconditioners_for_states.append(quantized_preconditioners)
1791
+ errors_for_states.append(errors_for_state)
1792
  idx += num_statistics
1793
  new_states = []
1794
+ for state, new_preconditioners, new_errors in zip(
1795
+ states, preconditioners_for_states, errors_for_states
1796
+ ):
1797
+ if state.statistics:
1798
+ new_errors = jnp.where(
1799
+ jnp.logical_and(
1800
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1801
+ ),
1802
+ new_errors,
1803
+ state.training_metrics.inverse_pth_root_errors,
1804
+ )
1805
+ new_training_metrics = TrainingMetrics(new_errors)
1806
  new_states.append(
1807
  ParameterStats(
1808
  state.diagonal_statistics,
 
1810
  new_preconditioners,
1811
  state.diagonal_momentum,
1812
  state.momentum,
1813
+ new_training_metrics,
1814
  )
1815
  )
1816
 
 
1892
  )
1893
 
1894
  new_preconditioners_flat = []
1895
+ new_errors_flat = []
1896
  for p, shape, prev_p, error in zip(
1897
  preconditioners_flat, original_shapes, prev_preconditioners, errors_flat
1898
  ):
1899
  new_preconditioners_flat.append(
1900
  _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
1901
  )
1902
+ new_errors_flat.append(error)
1903
 
1904
  assert len(states) == len(num_statistics_per_state)
1905
  assert len(new_preconditioners_flat) == num_statistics
1906
 
1907
  # Add back empty preconditioners so we that we can set the optimizer state.
1908
  preconditioners_for_states = []
1909
+ errors_for_states = []
1910
  idx = 0
1911
  for num_statistics, state in zip(num_statistics_per_state, states):
1912
  if num_statistics == 0:
1913
  preconditioners_for_states.append([])
1914
+ errors_for_states.append([])
1915
  else:
1916
  preconditioners_for_state = new_preconditioners_flat[
1917
  idx : idx + num_statistics
1918
  ]
1919
  assert len(state.statistics) == len(preconditioners_for_state)
1920
  preconditioners_for_states.append(preconditioners_for_state)
1921
+
1922
+ errors_for_state = jnp.stack(
1923
+ new_errors_flat[idx : idx + num_statistics]
1924
+ )
1925
+ assert len(state.statistics) == len(errors_for_state)
1926
+ errors_for_states.append(errors_for_state)
1927
  idx += num_statistics
1928
+
1929
  new_states = []
1930
+ for state, new_preconditioners, new_errors in zip(
1931
+ states, preconditioners_for_states, errors_for_states
1932
+ ):
1933
+ if state.statistics:
1934
+ new_errors = jnp.where(
1935
+ jnp.logical_and(
1936
+ new_errors > 0.0, new_errors != inverse_failure_threshold
1937
+ ),
1938
+ new_errors,
1939
+ state.training_metrics.inverse_pth_root_errors,
1940
+ )
1941
+ new_training_metrics = TrainingMetrics(new_errors)
1942
  new_states.append(
1943
  ParameterStats(
1944
  state.diagonal_statistics,
 
1946
  new_preconditioners,
1947
  state.diagonal_momentum,
1948
  state.momentum,
1949
+ new_training_metrics,
1950
  )
1951
  )
1952
 
 
2133
  state.preconditioners,
2134
  _quantize_momentum(grafting_update_with_wd_momentum),
2135
  _quantize_momentum(shampoo_update_with_wd_momentum),
2136
+ state.training_metrics,
2137
  )
2138
+
2139
  return transformed_update, param_stats
2140
 
2141
  def update_fn(grads, state, params):
 
2178
  return updates, new_state
2179
 
2180
  if shard_optimizer_states:
2181
+ # Hijacks the init_fn signature so we can return an OptState with
2182
+ # appropriate init_fns.
2183
+ def _init_fns(unused_params):
2184
+ return InitFnState(
2185
+ init_fn=sharded_init_fn,
2186
+ pspec_fn=sharded_init_partition_spec_fn,
2187
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
2188
+ )
2189
+
2190
+ return optax.GradientTransformation(_init_fns, sharded_update_fn)
2191
  else:
2192
  return optax.GradientTransformation(init_fn, update_fn)
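In sharded mode the transformation's init is hijacked: calling it returns an InitFnState whose three members are used at different stages of pjit setup (the real init runs inside pjit, pspec_fn yields the state's PartitionSpec tree, shape_and_dtype_fn yields abstract shapes for checkpoint restore). A hedged sketch of the consuming side follows; the partition-spec values mirror the ones added at the bottom of train.py, but the toy params, the learning rate, and the overall wiring are assumptions, since train.py's usage is only partly visible in this diff.

# Hedged sketch: consuming the sharded Shampoo optimizer (toy params/specs, assumed wiring).
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec
from distributed_shampoo import GraftingType, distributed_shampoo

params = {"w": jnp.zeros((8, 8))}              # toy params; the real code uses model.params
param_spec = {"w": PartitionSpec(None, None)}  # toy spec; the real code uses set_partitions(...)

optimizer = distributed_shampoo(
    0.001,                        # learning_rate (illustrative value)
    block_size=1024,
    graft_type=GraftingType.RMSPROP_NORMALIZED,
    shard_optimizer_states=True,
    num_devices_for_pjit=jax.device_count(),
    statistics_partition_spec=PartitionSpec(None, "batch", None),
    preconditioner_partition_spec=PartitionSpec("batch", None, None),
)
opt_fn = optimizer.init(None)    # returns InitFnState(init_fn, pspec_fn, shape_and_dtype_fn)
opt_state_shapes = opt_fn.shape_and_dtype_fn(params)
opt_state_spec = opt_fn.pspec_fn(params, param_spec, PartitionSpec(None, "batch", None))
# The materialized state is created later, inside pjit, via opt_fn.init_fn(params).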
tools/train/train.py CHANGED
@@ -25,7 +25,7 @@ import sys
25
  import time
26
  from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
28
- from typing import Callable, Optional
29
 
30
  import datasets
31
  import jax
@@ -36,12 +36,12 @@ import transformers
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
- from flax.core.frozen_dict import freeze
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot, stack_forest
43
  from jax.experimental import PartitionSpec, maps
44
- from jax.experimental.pjit import pjit
45
  from tqdm import tqdm
46
  from transformers import HfArgumentParser
47
 
@@ -248,6 +248,10 @@ class TrainingArguments:
248
  default=1024,
249
  metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
250
  )
 
 
 
 
251
  preconditioning_compute_steps: int = field(
252
  default=10, metadata={"help": "Number of steps to update preconditioner."}
253
  )
@@ -478,6 +482,7 @@ def main():
478
  artifact_dir,
479
  dtype=getattr(jnp, model_args.dtype),
480
  abstract_init=True,
 
481
  )
482
 
483
  # load tokenizer
@@ -501,12 +506,14 @@ def main():
501
  seed=training_args.seed_model,
502
  dtype=getattr(jnp, model_args.dtype),
503
  abstract_init=True,
 
504
  )
505
  else:
506
  model = DalleBart(
507
  config,
508
  seed=training_args.seed_model,
509
  dtype=getattr(jnp, model_args.dtype),
 
510
  )
511
 
512
  # Load tokenizer
@@ -520,6 +527,12 @@ def main():
520
  use_fast=True,
521
  )
522
 
 
 
 
 
 
 
523
  # Preprocessing the datasets.
524
  # We need to normalize and tokenize inputs and targets.
525
 
@@ -536,14 +549,14 @@ def main():
536
 
537
  # Store some constant
538
  num_epochs = training_args.num_train_epochs
539
- # batch size per node
540
- train_batch_size = (
541
- training_args.per_device_train_batch_size * jax.local_device_count()
542
  )
543
- batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
544
  batch_size_per_step = batch_size_per_node * jax.process_count()
545
  eval_batch_size = (
546
- training_args.per_device_eval_batch_size * jax.local_device_count()
547
  )
548
  len_train_dataset, len_eval_dataset = dataset.length
549
  steps_per_epoch = (
@@ -599,14 +612,17 @@ def main():
599
  beta2=training_args.beta2,
600
  diagonal_epsilon=1e-10,
601
  matrix_epsilon=1e-8,
602
- start_preconditioning_step=training_args.warmup_steps,
603
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
604
  statistics_compute_steps=1,
605
  best_effort_shape_interpretation=True,
606
  graft_type=GraftingType.RMSPROP_NORMALIZED,
607
  nesterov=False,
608
  exponent_override=0,
609
- batch_axis_name="batch",
 
 
 
610
  inverse_failure_threshold=0.1,
611
  moving_average_for_momentum=True,
612
  skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
@@ -614,6 +630,13 @@ def main():
614
  precision=jax.lax.Precision.HIGHEST,
615
  best_effort_memory_usage_reduction=training_args.optim_quantized,
616
  )
 
 
 
 
 
 
 
617
 
618
  elif training_args.optim == "adam":
619
  optimizer = optax.adamw(
@@ -630,31 +653,45 @@ def main():
630
  clipping_threshold=training_args.max_grad_norm,
631
  )
632
 
633
- # get opt_state shape without actual init
634
- opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
635
 
636
- # get PartitionSpec for model params
637
- param_spec = set_partitions(model.params)
 
638
 
639
- # create PartitionSpec for opt_state
640
- def opt_state_spec_per_leaf(x):
641
- if training_args.optim in ["adam", "adafactor"]:
642
- if isinstance(x, dict):
643
- # variables with same structure as params
644
- return param_spec
645
- else:
646
- # other variables such as count
647
- return None
648
  else:
649
- # TODO: create spec for Distributed Shampoo
650
  raise NotImplementedError
 
651
 
652
- opt_state_spec = jax.tree_map(
653
- opt_state_spec_per_leaf,
654
- opt_state_shape,
655
- # return None spec for empty elements
656
- is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
657
- )
658
 
659
  # create a mesh
660
  mesh_shape = (training_args.dp_devices, training_args.mp_devices)
@@ -674,51 +711,61 @@ def main():
674
  tx=optimizer,
675
  )
676
 
677
- opt_state, attr_state = None, None
678
- if training_args.resume_from_checkpoint is not None:
679
- # restore opt_state
680
- with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
681
- opt_state = from_bytes(opt_state_shape, f.read())
682
- # need to freeze dict for pjit
683
- opt_state = jax.tree_map(
684
- lambda x: freeze(x) if isinstance(x, dict) else x,
685
- opt_state,
686
- is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
687
- )
688
- # restore other attributes
689
- with (Path(artifact_dir) / "training_state.json").open("r") as f:
690
- attr_state = json.load(f)
691
-
692
  # create training state
693
- def init_state(params, opt_state):
694
  if training_args.resume_from_checkpoint is None:
695
- state = TrainState.create(
696
- apply_fn=model.__call__,
697
- tx=optimizer,
698
- params=freeze(params),
699
- dropout_rng=dropout_rng,
700
- )
 
 
 
 
 
 
 
 
 
 
701
  else:
702
- state = TrainState(
703
- apply_fn=model.__call__,
704
- tx=optimizer,
705
- params=freeze(params),
706
- opt_state=opt_state,
707
- dropout_rng=dropout_rng,
708
- **attr_state,
709
- )
710
- return state
 
 
 
 
 
 
 
 
711
 
712
- with maps.mesh(mesh.devices, mesh.axis_names):
713
- state = pjit(
714
- init_state,
715
- in_axis_resources=(param_spec, opt_state_spec),
716
- out_axis_resources=state_spec,
717
- donate_argnums=(0, 1),
718
- )(freeze(model.params), opt_state)
 
 
719
 
720
- # free memory from large parameters
721
- del model._params, opt_state
 
 
 
 
 
722
 
723
  # label smoothed cross entropy
724
  def loss_fn(logits, labels):
@@ -728,11 +775,24 @@ def main():
728
 
729
  # Define gradient update step fn
730
  def train_step(state, batch, delta_time):
731
- dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
732
- # use a different rng per node
733
- dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
 
 
 
 
 
 
 
 
 
 
 
734
 
735
- def compute_loss(params, minibatch):
 
 
736
  labels = minibatch.pop("labels")
737
  logits = state.apply_fn(
738
  **minibatch, params=params, dropout_rng=dropout_rng, train=True
@@ -741,36 +801,75 @@ def main():
741
 
742
  grad_fn = jax.value_and_grad(compute_loss)
743
 
744
  if training_args.gradient_accumulation_steps == 1:
745
- minibatch = jax.tree_map(lambda x: x[0], batch)
746
- loss, grads = grad_fn(state.params, minibatch)
747
- else:
748
 
749
- def _cumul_loss_grads(i, cumul_loss_grads):
750
- minibatch = jax.tree_map(lambda x: x[i], batch)
751
- return jax.tree_map(
752
- lambda x, y: x + y,
753
- cumul_loss_grads,
754
- grad_fn(state.params, minibatch),
755
- )
756
 
757
- init_loss_grads = (
 
 
 
758
  0.0,
759
  jax.tree_map(jnp.zeros_like, state.params),
760
  )
761
- loss, grads = jax.tree_map(
762
- lambda x: x / training_args.gradient_accumulation_steps,
763
- jax.lax.fori_loop(
764
- 0,
765
- training_args.gradient_accumulation_steps,
766
- _cumul_loss_grads,
767
- init_loss_grads,
768
- ),
 
 
 
 
 
 
 
 
 
 
 
 
769
  )
770
 
 
 
771
  state = state.apply_gradients(
772
  grads=grads,
773
- dropout_rng=new_dropout_rng,
774
  train_time=state.train_time + delta_time,
775
  train_samples=state.train_samples + batch_size_per_step,
776
  )
@@ -784,6 +883,7 @@ def main():
784
 
785
  # Define eval fn
786
  def eval_step(params, batch):
 
787
  labels = batch.pop("labels")
788
  logits = model(**batch, params=params, train=False)[0]
789
  loss = loss_fn(logits, labels)
@@ -795,13 +895,13 @@ def main():
795
  # Create parallel version of the train and eval step
796
  p_train_step = pjit(
797
  train_step,
798
- in_axis_resources=(state_spec, PartitionSpec("batch", None), None),
799
  out_axis_resources=(state_spec, None),
800
  donate_argnums=(0,),
801
  )
802
  p_eval_step = pjit(
803
  eval_step,
804
- in_axis_resources=(param_spec, PartitionSpec("batch", None)),
805
  out_axis_resources=None,
806
  )
807
 
@@ -842,9 +942,7 @@ def main():
842
  # ======================== Evaluating ==============================
843
  eval_metrics = []
844
  if training_args.do_eval:
845
- eval_loader = dataset.dataloader(
846
- "eval", training_args.per_device_eval_batch_size
847
- )
848
  eval_steps = (
849
  len_eval_dataset // eval_batch_size
850
  if len_eval_dataset is not None
@@ -857,8 +955,8 @@ def main():
857
  leave=False,
858
  total=eval_steps,
859
  ):
860
- # Model forward
861
- metrics = p_eval_step(state.params, batch)
862
  eval_metrics.append(metrics)
863
 
864
  # normalize eval metrics
@@ -962,8 +1060,7 @@ def main():
962
  # Generate an epoch by shuffling sampling indices from the train dataset
963
  train_loader = dataset.dataloader(
964
  "train",
965
- training_args.per_device_train_batch_size,
966
- training_args.gradient_accumulation_steps,
967
  epoch,
968
  )
969
  # train
@@ -974,15 +1071,27 @@ def main():
974
  leave=False,
975
  total=steps_per_epoch,
976
  ):
977
-
978
  # calculate delta time (we have a lag of one step but it's ok)
979
  new_time = time.perf_counter()
980
  delta_time = new_time - last_time
981
  last_time = new_time
982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
983
  # train step
984
- state, train_metrics = p_train_step(state, batch, delta_time)
985
- step = state.step
986
 
987
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
988
  all_metrics = metrics_logger.get_all_train_metrics(
 
25
  import time
26
  from dataclasses import asdict, dataclass, field
27
  from pathlib import Path
28
+ from typing import Any, Callable, NamedTuple, Optional
29
 
30
  import datasets
31
  import jax
 
36
  import wandb
37
  from datasets import Dataset
38
  from distributed_shampoo import GraftingType, distributed_shampoo
39
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.training import train_state
42
  from flax.training.common_utils import onehot, stack_forest
43
  from jax.experimental import PartitionSpec, maps
44
+ from jax.experimental.pjit import pjit, with_sharding_constraint
45
  from tqdm import tqdm
46
  from transformers import HfArgumentParser
47
 
 
248
  default=1024,
249
  metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
250
  )
251
+ start_preconditioning_step: int = field(
252
+ default=100,
253
+ metadata={"help": "Number of steps before starting to update preconditioner."},
254
+ )
255
  preconditioning_compute_steps: int = field(
256
  default=10, metadata={"help": "Number of steps to update preconditioner."}
257
  )
 
482
  artifact_dir,
483
  dtype=getattr(jnp, model_args.dtype),
484
  abstract_init=True,
485
+ load_on_cpu=True,
486
  )
487
 
488
  # load tokenizer
 
506
  seed=training_args.seed_model,
507
  dtype=getattr(jnp, model_args.dtype),
508
  abstract_init=True,
509
+ load_on_cpu=True,
510
  )
511
  else:
512
  model = DalleBart(
513
  config,
514
  seed=training_args.seed_model,
515
  dtype=getattr(jnp, model_args.dtype),
516
+ load_on_cpu=True,
517
  )
518
 
519
  # Load tokenizer
 
527
  use_fast=True,
528
  )
529
 
530
+ # get PartitionSpec for model params (required to be a dict)
531
+ param_spec = set_partitions(model.params)
532
+
533
+ # convert params to frozen dict
534
+ model._params = freeze(model.params)
535
+
536
  # Preprocessing the datasets.
537
  # We need to normalize and tokenize inputs and targets.
538
 
 
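# Illustrative sketch, not part of the PR: `set_partitions` returns a pytree of
# PartitionSpec objects mirroring the params, which pjit later uses to shard each
# weight over the mesh. The real function matches parameter names against rules;
# this toy stand-in decides purely by rank and assumes an "mp" mesh axis, using
# the same imports as the top of this file.
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import freeze
from jax.experimental import PartitionSpec

def toy_set_partitions(params):
    def per_leaf(param):
        # shard the last dim of matrices over the model-parallel axis, replicate the rest
        return PartitionSpec(None, "mp") if param.ndim == 2 else None
    return jax.tree_map(per_leaf, params)

toy_params = freeze({"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros(8)}})
toy_spec = toy_set_partitions(toy_params)
# -> {"dense": {"kernel": PartitionSpec(None, "mp"), "bias": None}}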
549
 
550
  # Store some constant
551
  num_epochs = training_args.num_train_epochs
552
+ # batch size
553
+ minibatch_size = (
554
+ training_args.per_device_train_batch_size * training_args.dp_devices
555
  )
556
+ batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
557
  batch_size_per_step = batch_size_per_node * jax.process_count()
558
  eval_batch_size = (
559
+ training_args.per_device_eval_batch_size * training_args.dp_devices
560
  )
561
  len_train_dataset, len_eval_dataset = dataset.length
562
  steps_per_epoch = (
 
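# Illustrative sketch, not part of the PR: how the batch-size constants above relate,
# with made-up values.
per_device_train_batch_size = 56   # samples per device per forward/backward pass
dp_devices = 8                     # data-parallel devices on this node
gradient_accumulation_steps = 4
process_count = 2                  # number of hosts, i.e. jax.process_count()

minibatch_size = per_device_train_batch_size * dp_devices            # 448, one gradient slice
batch_size_per_node = minibatch_size * gradient_accumulation_steps   # 1792, what the dataloader yields
batch_size_per_step = batch_size_per_node * process_count            # 3584, samples per optimizer update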
612
  beta2=training_args.beta2,
613
  diagonal_epsilon=1e-10,
614
  matrix_epsilon=1e-8,
615
+ start_preconditioning_step=training_args.start_preconditioning_step,
616
  preconditioning_compute_steps=training_args.preconditioning_compute_steps,
617
  statistics_compute_steps=1,
618
  best_effort_shape_interpretation=True,
619
  graft_type=GraftingType.RMSPROP_NORMALIZED,
620
  nesterov=False,
621
  exponent_override=0,
622
+ statistics_partition_spec=PartitionSpec(None, "batch", None),
623
+ preconditioner_partition_spec=PartitionSpec("batch", None, None),
624
+ num_devices_for_pjit=training_args.dp_devices,
625
+ shard_optimizer_states=True,
626
  inverse_failure_threshold=0.1,
627
  moving_average_for_momentum=True,
628
  skip_preconditioning_dim_size_gt=training_args.skip_preconditioning_dim_size_gt,
 
630
  precision=jax.lax.Precision.HIGHEST,
631
  best_effort_memory_usage_reduction=training_args.optim_quantized,
632
  )
633
+ # get the real optimizer and helper functions
634
+ update_fn = optimizer.update
635
+ optimizer = optimizer.init(model.params)
636
+ opt_fn = NamedTuple("opt_fn", pspec_fn=Any, shape_and_dtype_fn=Any)(
637
+ optimizer.pspec_fn, optimizer.shape_and_dtype_fn
638
+ )
639
+ optimizer = optax.GradientTransformation(optimizer.init_fn, update_fn)
640
 
641
  elif training_args.optim == "adam":
642
  optimizer = optax.adamw(
 
653
  clipping_threshold=training_args.max_grad_norm,
654
  )
655
 
656
+ # get PartitionSpec for optimizer state
657
+ def get_opt_state_spec_and_shape(param_spec):
658
+ if training_args.optim in ["adam", "adafactor"]:
659
+ # get opt_state shape without actual init
660
+ opt_state_shape = jax.eval_shape(optimizer.init, model.params)
661
+
662
+ if training_args.optim == "adam":
663
+
664
+ def _opt_state_spec_per_leaf(x):
665
+ if isinstance(x, FrozenDict):
666
+ # variables with same structure as params
667
+ return param_spec
668
+ else:
669
+ # other variables such as count
670
+ return None
671
+
672
+ opt_state_spec = jax.tree_map(
673
+ _opt_state_spec_per_leaf,
674
+ opt_state_shape,
675
+ # return None spec for empty elements
676
+ is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
677
+ )
678
 
679
+ elif training_args.optim == "adafactor":
680
+ # factorized state must be replicated (rank different than params)
681
+ opt_state_spec = None
682
 
683
+ elif training_args.optim == "distributed_shampoo":
684
+ opt_state_spec = opt_fn.pspec_fn(
685
+ params=model.params,
686
+ params_partition_spec=param_spec,
687
+ partition_spec_for_statistics=PartitionSpec(None, "batch", None),
688
+ )
689
+ opt_state_shape = opt_fn.shape_and_dtype_fn(model.params)
 
 
690
  else:
 
691
  raise NotImplementedError
692
+ return opt_state_spec, opt_state_shape
693
 
694
+ opt_state_spec, opt_state_shape = get_opt_state_spec_and_shape(param_spec)
695
 
696
  # create a mesh
697
  mesh_shape = (training_args.dp_devices, training_args.mp_devices)
 
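# Illustrative sketch, not part of the PR: the "adam" branch of
# get_opt_state_spec_and_shape relies on jax.eval_shape, which traces optimizer.init
# abstractly and returns the optimizer-state pytree as ShapeDtypeStructs without
# allocating memory; a spec can then be attached per leaf. Toy params and a string
# stand-in for the real PartitionSpec tree:
import jax
import jax.numpy as jnp
import optax
from flax.core.frozen_dict import FrozenDict, freeze

toy_params = freeze({"dense": {"kernel": jnp.ones((4, 8)), "bias": jnp.zeros(8)}})
toy_param_spec = "shard-like-params"  # placeholder for the params PartitionSpec tree

toy_optimizer = optax.adamw(learning_rate=1e-3)
toy_opt_state_shape = jax.eval_shape(toy_optimizer.init, toy_params)  # nothing allocated

toy_opt_state_spec = jax.tree_map(
    lambda x: toy_param_spec if isinstance(x, FrozenDict) else None,
    toy_opt_state_shape,
    # params-shaped sub-trees and empty states are treated as leaves
    is_leaf=lambda x: isinstance(x, (FrozenDict, optax.EmptyState)),
)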
711
  tx=optimizer,
712
  )
713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
  # create training state
715
+ with maps.mesh(mesh.devices, mesh.axis_names):
716
  if training_args.resume_from_checkpoint is None:
717
+
718
+ def init_state(params):
719
+ return TrainState.create(
720
+ apply_fn=model.__call__,
721
+ tx=optimizer,
722
+ params=params,
723
+ dropout_rng=dropout_rng,
724
+ )
725
+
726
+ state = pjit(
727
+ init_state,
728
+ in_axis_resources=(param_spec,),
729
+ out_axis_resources=state_spec,
730
+ donate_argnums=(0,),
731
+ )(model.params)
732
+
733
  else:
734
+ # restore opt_state
735
+ with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
736
+ opt_state = from_bytes(opt_state_shape, f.read())
737
+
738
+ # restore other attributes
739
+ with (Path(artifact_dir) / "training_state.json").open("r") as f:
740
+ attr_state = json.load(f)
741
+
742
+ def restore_state(params, opt_state):
743
+ return TrainState(
744
+ apply_fn=model.__call__,
745
+ tx=optimizer,
746
+ params=params,
747
+ opt_state=opt_state,
748
+ dropout_rng=dropout_rng,
749
+ **attr_state,
750
+ )
751
 
752
+ state = pjit(
753
+ restore_state,
754
+ in_axis_resources=(param_spec, opt_state_spec),
755
+ out_axis_resources=state_spec,
756
+ donate_argnums=(0, 1),
757
+ )(model.params, opt_state)
758
+
759
+ # remove opt_state from CPU
760
+ del opt_state
761
 
762
+ # free memory
763
+ del model._params
764
+
765
+ # define batch specs
766
+ keys = ["attention_mask", "decoder_input_ids", "input_ids", "labels"]
767
+ batch_spec = freeze({k: PartitionSpec("batch") for k in keys})
768
+ grad_batch_spec = freeze({k: PartitionSpec(None, "batch") for k in keys})
769
 
770
  # label smoothed cross entropy
771
  def loss_fn(logits, labels):
 
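# Illustrative sketch, not part of the PR: the body of loss_fn is elided by the diff
# context above. A typical label-smoothed cross entropy in JAX looks like the
# following; `smoothing` is a made-up parameter name, and this is not necessarily
# the exact implementation hidden by the diff.
import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

def label_smoothed_cross_entropy(logits, labels, smoothing=0.1):
    vocab_size = logits.shape[-1]
    confidence = 1.0 - smoothing
    low_confidence = smoothing / (vocab_size - 1)
    soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
    # per-token loss, then mean over batch and sequence
    return optax.softmax_cross_entropy(logits, soft_labels).mean()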
775
 
776
  # Define gradient update step fn
777
  def train_step(state, batch, delta_time):
778
+ # batch is (gradient_accumulation_steps, dp_devices, per_device_train_batch_size, ...)
779
+ # check correct batch shape during compilation
780
+ assert batch["labels"].shape[0:3] == (
781
+ training_args.gradient_accumulation_steps,
782
+ training_args.dp_devices,
783
+ training_args.per_device_train_batch_size,
784
+ ), f"Expected label batch of shape gradient_accumulation_steps x dp_devices x per_device_train_batch_size and got {batch['labels'].shape}"
785
+
786
+ # get a minibatch (one gradient accumulation slice)
787
+ def get_minibatch(batch, grad_idx):
788
+ return jax.tree_map(
789
+ lambda x: jax.lax.dynamic_index_in_dim(x, grad_idx, keepdims=False),
790
+ batch,
791
+ )
792
 
793
+ def compute_loss(params, minibatch, dropout_rng):
794
+ # minibatch has dim (batch_size, ...)
795
+ minibatch = unfreeze(minibatch)
796
  labels = minibatch.pop("labels")
797
  logits = state.apply_fn(
798
  **minibatch, params=params, dropout_rng=dropout_rng, train=True
 
801
 
802
  grad_fn = jax.value_and_grad(compute_loss)
803
 
804
+ def loss_and_grad(grad_idx, dropout_rng):
805
+ # minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
806
+ minibatch = get_minibatch(batch, grad_idx)
807
+ # ensure batch is sharded over devices
808
+ minibatch = jax.tree_map(
809
+ lambda x: with_sharding_constraint(x, PartitionSpec("batch")), minibatch
810
+ )
811
+ # calculate loss and grads independently per dp_device
812
+ loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
813
+ state.params, minibatch, dropout_rng
814
+ )
815
+ # ensure they are sharded over devices
816
+ loss_grads = jax.tree_map(
817
+ lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
818
+ loss_grads,
819
+ )
820
+
821
+ # average across all devices
822
+ loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
823
+
824
+ # return loss and grads
825
+ return loss_grads
826
+
827
+ # create a new rng
828
+ dropout_rng, _ = jax.random.split(state.dropout_rng)
829
+ # use a different rng per node
830
+ dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
831
+
832
  if training_args.gradient_accumulation_steps == 1:
833
 
834
+ def batch_step(dropout_rng):
835
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
836
+ loss_grad = loss_and_grad(0, dropout_rng)
837
+ return loss_grad, new_dropout_rng
838
 
839
+ loss_grad, dropout_rng = batch_step(dropout_rng)
840
+ else:
841
+ # create initial state for per_minibatch_step loop
842
+ init_cumul_loss_grad = (
843
  0.0,
844
  jax.tree_map(jnp.zeros_like, state.params),
845
  )
846
+ init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
847
+
848
+ # accumulate gradients
849
+ def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
850
+ cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
851
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
852
+ loss_grad = loss_and_grad(grad_idx, dropout_rng)
853
+ cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
854
+ return cumul_loss_grad, new_dropout_rng
855
+
856
+ # loop over gradients
857
+ loss_grad, dropout_rng = jax.lax.fori_loop(
858
+ 0,
859
+ training_args.gradient_accumulation_steps,
860
+ cumul_minibatch_step,
861
+ init_minibatch_step,
862
+ )
863
+ # sum -> mean
864
+ loss_grad = jax.tree_map(
865
+ lambda x: x / training_args.gradient_accumulation_steps, loss_grad
866
  )
867
 
868
+ # update state
869
+ loss, grads = loss_grad
870
  state = state.apply_gradients(
871
  grads=grads,
872
+ dropout_rng=dropout_rng,
873
  train_time=state.train_time + delta_time,
874
  train_samples=state.train_samples + batch_size_per_step,
875
  )
 
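# Illustrative sketch, not part of the PR: a self-contained toy version of the
# accumulation pattern in train_step above. It sums (loss, grads) over the
# gradient-accumulation axis with jax.lax.fori_loop, threads the rng through the
# loop carry, then divides by the number of slices. Model and data are made up.
import jax
import jax.numpy as jnp

grad_acc_steps = 4
toy_params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
# leading axis is the gradient-accumulation axis, like the reshaped batch later in this diff
xs = jnp.ones((grad_acc_steps, 8, 3))
ys = jnp.zeros((grad_acc_steps, 8))

def compute_loss(params, x, y, rng):
    noise = 0.01 * jax.random.normal(rng, y.shape)  # stands in for dropout randomness
    pred = x @ params["w"] + params["b"] + noise
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.value_and_grad(compute_loss)

def cumul_step(i, carry):
    (cumul_loss, cumul_grads), rng = carry
    rng, step_rng = jax.random.split(rng)
    loss, grads = grad_fn(toy_params, xs[i], ys[i], step_rng)
    return (cumul_loss + loss, jax.tree_map(jnp.add, cumul_grads, grads)), rng

init = ((jnp.zeros(()), jax.tree_map(jnp.zeros_like, toy_params)), jax.random.PRNGKey(0))
(loss_sum, grads_sum), _ = jax.lax.fori_loop(0, grad_acc_steps, cumul_step, init)
loss, grads = jax.tree_map(lambda x: x / grad_acc_steps, (loss_sum, grads_sum))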
883
 
884
  # Define eval fn
885
  def eval_step(params, batch):
886
+ batch = unfreeze(batch)
887
  labels = batch.pop("labels")
888
  logits = model(**batch, params=params, train=False)[0]
889
  loss = loss_fn(logits, labels)
 
895
  # Create parallel version of the train and eval step
896
  p_train_step = pjit(
897
  train_step,
898
+ in_axis_resources=(state_spec, grad_batch_spec, None),
899
  out_axis_resources=(state_spec, None),
900
  donate_argnums=(0,),
901
  )
902
  p_eval_step = pjit(
903
  eval_step,
904
+ in_axis_resources=(param_spec, batch_spec),
905
  out_axis_resources=None,
906
  )
907
 
 
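# Illustrative sketch, not part of the PR: how pjit's in_axis_resources /
# out_axis_resources map arguments onto the mesh, shown on a toy matmul. It assumes
# the jax.experimental.pjit API imported at the top of this file and a hypothetical
# host with 8 devices arranged as a (4, 2) mesh; the axis names ("dp", "mp") are
# made up for the example.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

devices = np.asarray(jax.devices()).reshape(4, 2)
mesh = maps.Mesh(devices, ("dp", "mp"))

def matmul(x, w):
    return x @ w

p_matmul = pjit(
    matmul,
    # rows of x are sharded over "dp", columns of w over "mp"; the output is sharded on both
    in_axis_resources=(PartitionSpec("dp", None), PartitionSpec(None, "mp")),
    out_axis_resources=PartitionSpec("dp", "mp"),
)

with maps.mesh(mesh.devices, mesh.axis_names):
    y = p_matmul(jnp.ones((8, 16)), jnp.ones((16, 4)))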
942
  # ======================== Evaluating ==============================
943
  eval_metrics = []
944
  if training_args.do_eval:
945
+ eval_loader = dataset.dataloader("eval", eval_batch_size)
946
  eval_steps = (
947
  len_eval_dataset // eval_batch_size
948
  if len_eval_dataset is not None
 
955
  leave=False,
956
  total=eval_steps,
957
  ):
958
+ # TODO: make this more efficient once training loop is fast
959
+ metrics = p_eval_step(state.params, freeze(batch))
960
  eval_metrics.append(metrics)
961
 
962
  # normalize eval metrics
 
1060
  # Generate an epoch by shuffling sampling indices from the train dataset
1061
  train_loader = dataset.dataloader(
1062
  "train",
1063
+ batch_size_per_node,
 
1064
  epoch,
1065
  )
1066
  # train
 
1071
  leave=False,
1072
  total=steps_per_epoch,
1073
  ):
 
1074
  # calculate delta time (we have a lag of one step but it's ok)
1075
  new_time = time.perf_counter()
1076
  delta_time = new_time - last_time
1077
  last_time = new_time
1078
 
1079
+ # reshape data into (gradient_accumulation_steps, dp_devices, per_device_train_batch_size, ...)
1080
+ batch = jax.tree_map(
1081
+ lambda x: x.reshape(
1082
+ (
1083
+ training_args.gradient_accumulation_steps,
1084
+ training_args.dp_devices,
1085
+ training_args.per_device_train_batch_size,
1086
+ )
1087
+ + x.shape[1:]
1088
+ ),
1089
+ batch,
1090
+ )
1091
+
1092
  # train step
1093
+ state, train_metrics = p_train_step(state, freeze(batch), delta_time)
1094
+ step = int(state.step)
1095
 
1096
  if step % training_args.logging_steps == 0 and jax.process_index() == 0:
1097
  all_metrics = metrics_logger.get_all_train_metrics(