boris committed on
Commit
b90198c
1 Parent(s): 7143593

feat: update distributed_shampoo

Files changed (1)
  1. tools/train/distributed_shampoo.py +684 -207
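Editor's note: the diff below adds quantized optimizer-state buffers behind a new experimental `best_effort_memory_usage_reduction` flag. As orientation, a minimal sketch of how the updated optimizer might be constructed (assumptions: the module is importable from its repo path, training runs under pmap with an axis named "batch", and the schedule and hyperparameters are illustrative rather than taken from this commit):

    import optax
    from tools.train.distributed_shampoo import GraftingType, distributed_shampoo

    # Any callable mapping step -> learning rate works; a cosine schedule is one example.
    learning_rate = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=100_000)

    optimizer = distributed_shampoo(
        learning_rate,
        block_size=1024,                          # partition large dims into 1024-sized blocks
        graft_type=GraftingType.SGD,              # default grafting type in the new signature
        batch_axis_name="batch",                  # assumed pmap axis name
        best_effort_memory_usage_reduction=True,  # flag introduced by this commit
    )
    # The result is an optax GradientTransformation:
    #   opt_state = optimizer.init(params)
    #   updates, opt_state = optimizer.update(grads, opt_state, params)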
tools/train/distributed_shampoo.py CHANGED
@@ -33,7 +33,7 @@
33
  import enum
34
  import functools
35
  import itertools
36
- from typing import Any, NamedTuple
37
 
38
  import chex
39
  from flax import struct
@@ -46,16 +46,105 @@ import optax
46
 
47
 
48
  # pylint:disable=no-value-for-parameter
49
 
50
 
51
  # Per parameter optimizer state used in data-parallel training.
52
  class ParameterStats(NamedTuple):
53
  """State associated to each parameter of the model being trained."""
54
- diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
55
- statistics: chex.Array # Statistics
56
- preconditioners: chex.Array # Preconditioners
57
- diagonal_momentum: chex.Array # Momentum for the diagonal preconditioner
58
- momentum: chex.Array # Momentum for the shampoo preconditioner
59
 
60
 
61
  # For training extremely large models, we keep a global state with a concatenated
@@ -73,9 +162,9 @@ class GlobalShardedParameterStats:
73
  @struct.dataclass
74
  class LocalShardedParameterStats:
75
  """State associated to each parameter of the model being trained."""
76
- diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
77
- diagonal_momentum: chex.Array # Momentum for the diagonal preconditioner
78
- momentum: chex.Array # Momentum for the shampoo preconditioner
79
  index_start: np.int32 = struct.field(
80
  pytree_node=False) # Index into global statistics array
81
  sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
@@ -141,7 +230,8 @@ def power_iteration(
141
  jnp.greater(jnp.abs(s_new - s), error_tolerance))
142
 
143
  # Figure out how to use step as seed for random.
144
- v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
 
145
 
146
  init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
147
  _, v_out, s_out, _, _ = lax.while_loop(
@@ -323,6 +413,25 @@ def pad_matrix(mat, max_size):
323
  return mat
324
 
325
 
326
  def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
327
  """Avoids wasteful buffer allocation with XLA."""
328
 
@@ -492,33 +601,59 @@ def _convert_from_parameter_stats(parameter_stats, local_stats):
492
  local_stats.index_start, local_stats.sizes)
493
 
494
 
495
- def distributed_shampoo(learning_rate,
496
- block_size,
497
- beta1=0.9,
498
- beta2=0.999,
499
- diagonal_epsilon=1e-10,
500
- matrix_epsilon=1e-6,
501
- weight_decay=0.0,
502
- start_preconditioning_step=5,
503
- preconditioning_compute_steps=1,
504
- statistics_compute_steps=1,
505
- best_effort_shape_interpretation=True,
506
- graft_type=GraftingType.SGD,
507
- nesterov=True,
508
- exponent_override=0,
509
- # Pass pmap 'batch axis name' in pmap mode.
510
- batch_axis_name=None,
511
- ### Only set following 3 params in pjit/spmd mode.
512
- ### WARNING: Experimental
513
- mesh_axis_names=None,
514
- num_devices_for_pjit=None,
515
- shard_optimizer_states=False,
516
- ###
517
- inverse_failure_threshold=0.1,
518
- moving_average_for_momentum=False,
519
- skip_preconditioning_dim_size_gt=4096,
520
- clip_by_scaled_gradient_norm=None,
521
- precision=lax.Precision.HIGHEST):
522
  """Distributed Shampoo optimizer.
523
 
524
  Distributed Shampoo is a second-order preconditioned method (concretely, a
@@ -570,6 +705,10 @@ def distributed_shampoo(learning_rate,
570
  num_devices_for_pjit: Number of devices to parallelize over when using pjit.
571
  shard_optimizer_states: Shard optimizer states to save memory in model
572
  parallel training.
573
  inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
574
  determine that using this threshold.
575
  moving_average_for_momentum: Whether to use moving average for momentum
@@ -587,6 +726,67 @@ def distributed_shampoo(learning_rate,
587
  a GradientTransformation.
588
  """
589
 
590
  def sharded_init_fn(params):
591
  params_flat, treedef = jax.tree_flatten(params)
592
  # Find max size to pad to.
@@ -619,12 +819,14 @@ def distributed_shampoo(learning_rate,
619
  padded_statistics.extend(statistics)
620
  padded_preconditioners.extend(preconditioners)
621
 
622
- adagrad_statistics = []
623
  if graft_type != GraftingType.SGD:
624
- adagrad_statistics = jnp.zeros_like(param)
625
  local_stats_flat.append(
626
- LocalShardedParameterStats(adagrad_statistics, jnp.zeros_like(param),
627
- jnp.zeros_like(param), index_start, sizes))
628
 
629
  local_stats = jax.tree_unflatten(treedef, local_stats_flat)
630
  # Pad the statistics and preconditioner matrices to be a multiple of
@@ -769,12 +971,15 @@ def distributed_shampoo(learning_rate,
769
  statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
770
  preconditioners = [jnp.eye(s[0]) for s in shapes]
771
 
772
- adagrad_statistics = []
773
  if graft_type != GraftingType.SGD:
774
- adagrad_statistics = jnp.zeros_like(param)
775
- return ParameterStats(adagrad_statistics, statistics, preconditioners,
776
- jnp.zeros_like(param), jnp.zeros_like(param))
777
-
778
  return ShampooState(
779
  count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
780
 
@@ -795,8 +1000,9 @@ def distributed_shampoo(learning_rate,
795
  new_stats = preconditioner.statistics_from_grad(grad)
796
  new_stats_accumulators = []
797
  for stat, stat_accumulator in zip(new_stats, state.statistics):
798
- new_stats_accumulators.append(w1 * stat_accumulator + w2 * stat)
799
- return new_stats_accumulators
 
800
 
801
  if statistics_compute_steps > 1:
802
  perform_step = step % statistics_compute_steps == 0
@@ -810,164 +1016,375 @@ def distributed_shampoo(learning_rate,
810
  state.preconditioners, state.diagonal_momentum,
811
  state.momentum)
812
 
813
- def _compute_preconditioners(states, params, step):
814
- """Compute preconditioners for statistics."""
815
- statistics = []
816
- num_statistics_per_state = []
817
- original_shapes = []
818
- exponents = []
819
- max_size = 0
820
- prev_preconditioners = []
821
- for state, param in zip(states, params):
822
- num_statistics = len(state.statistics)
823
- num_statistics_per_state.append(num_statistics)
824
- original_shapes_for_state = []
825
- if num_statistics > 0:
826
- preconditioner = Preconditioner(param, block_size,
827
- best_effort_shape_interpretation)
828
- for statistic in state.statistics:
829
- exponents.append(preconditioner.exponent_for_preconditioner(
830
- ) if exponent_override == 0 else exponent_override)
831
- original_shapes_for_state.append(statistic.shape)
832
- max_size = max(max_size, statistic.shape[0])
833
- statistics.extend(state.statistics)
834
- prev_preconditioners.extend(state.preconditioners)
835
- original_shapes.extend(original_shapes_for_state)
836
  num_statistics = len(statistics)
837
 
838
- if batch_axis_name:
839
- num_devices = lax.psum(1, batch_axis_name)
840
-
841
- # Pad statistics and exponents to next multiple of num_devices.
842
- packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
843
- to_pad = -num_statistics % num_devices
844
- packed_statistics.extend([
845
- jnp.eye(max_size, dtype=packed_statistics[0].dtype)
846
- for _ in range(to_pad)
847
- ])
848
- exponents.extend([1 for _ in range(to_pad)])
849
-
850
- if not packed_statistics:
851
- return states
852
- # Batch statistics and exponents so that so that leading axis is
853
- # num_devices.
854
- def _batch(statistics, exponents, num_devices):
855
- assert len(statistics) == len(exponents)
856
- n = len(statistics)
857
- b = int(n / num_devices)
858
- batched_statistics = [
859
- jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
860
- ]
861
- batched_exponents = [
862
- jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
863
- ]
864
- return jnp.stack(batched_statistics), jnp.stack(batched_exponents)
865
-
866
- # Unbatch values across leading axis and return a list of elements.
867
- def _unbatch(batched_values):
868
- b1, b2 = batched_values.shape[0], batched_values.shape[1]
869
- results = []
870
- for v_array in jnp.split(
871
- batched_values, indices_or_sections=b1, axis=0):
872
- v_array = jnp.squeeze(v_array)
873
- # b2 = batches (number of preconditioner computation) per core.
874
- if b2 > 1:
875
- for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
876
- results.append(jnp.squeeze(v))
877
- else:
878
- results.append(v_array)
879
- return results
880
-
881
- all_statistics, all_exponents = _batch(packed_statistics, exponents,
882
- num_devices)
883
  else:
884
- to_pad = -num_statistics % num_devices_for_pjit
885
- padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
886
- padded_statistics.extend([
887
- jnp.eye(max_size, dtype=padded_statistics[0].dtype)
888
- for _ in range(to_pad)
889
- ])
890
- exponents.extend([1 for _ in range(to_pad)])
891
- all_statistics = jnp.stack(padded_statistics)
892
- all_exponents = jnp.stack(exponents)
893
 
894
- def _matrix_inverse_pth_root_vmap(xs, ps):
895
- mi_pth_root = functools.partial(
896
- matrix_inverse_pth_root,
897
- ridge_epsilon=matrix_epsilon,
898
- precision=precision)
899
- preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
900
- return preconditioners, errors
901
 
902
- def _matrix_inverse_pth_root_pjit(xs, ps):
903
- mesh_axis_names_tuple = tuple(mesh_axis_names)
904
- # Partition the concatenated statistics matrix across all cores.
905
- partitioned_xs, partitioned_ps = pjit.pjit(
906
- lambda x, y: (x, y),
907
- in_axis_resources=None,
908
- out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
909
- # Run matrix inverse pth root on each shard.
910
- partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
911
- partitioned_xs, partitioned_ps)
912
- # Recombine the outputs at each core.
913
- preconditioners, errors = pjit.pjit(
914
- lambda x, y: (x, y),
915
- in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
916
- pjit.PartitionSpec(mesh_axis_names_tuple,)),
917
- out_axis_resources=(None, None))(partitioned_preconditioners,
918
- partitioned_errors)
919
- return preconditioners, errors
920
 
921
- if not batch_axis_name:
922
- def _internal_inverse_pth_root_all():
923
- preconditioners, errors = _matrix_inverse_pth_root_pjit(
924
- all_statistics, all_exponents)
925
- b1 = preconditioners.shape[0]
926
- def split(batched_values):
927
- return [
928
- jnp.squeeze(v) for v in jnp.split(
929
- batched_values, indices_or_sections=b1, axis=0)
930
- ]
931
-
932
- return split(preconditioners), split(errors)
933
-
934
- if preconditioning_compute_steps == 1:
935
- preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
936
  else:
937
- # Passing statistics instead of preconditioners as they are similarly
938
- # shaped tensors. Note statistics will be ignored as we are passing in
939
- # a large init value for error.
940
- preconditioners_init = padded_statistics
941
- errors_init = [inverse_failure_threshold] * len(padded_statistics)
942
- init_state = [preconditioners_init, errors_init]
943
- perform_step = step % preconditioning_compute_steps == 0
944
- preconditioners_flat, errors_flat = efficient_cond(
945
- perform_step, _internal_inverse_pth_root_all, init_state)
946
  else:
947
 
948
- def _internal_inverse_pth_root_all():
949
- preconditioners = jnp.array(all_statistics)
950
- current_replica = lax.axis_index(batch_axis_name)
951
- preconditioners, errors = _matrix_inverse_pth_root_vmap(
952
- all_statistics[current_replica], all_exponents[current_replica])
953
- preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
954
- errors = jax.lax.all_gather(errors, batch_axis_name)
955
- preconditioners_flat = _unbatch(preconditioners)
956
- errors_flat = _unbatch(errors)
957
- return preconditioners_flat, errors_flat
958
-
959
- if preconditioning_compute_steps == 1:
960
- preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
961
  else:
962
- # Passing statistics instead of preconditioners as they are similarly
963
- # shaped tensors. Note statistics will be ignored as we are passing in
964
- # a large init value for error.
965
- preconditioners_init = packed_statistics
966
- errors_init = ([inverse_failure_threshold] * len(packed_statistics))
967
- init_state = [preconditioners_init, errors_init]
968
- perform_step = step % preconditioning_compute_steps == 0
969
- preconditioners_flat, errors_flat = efficient_cond(
970
- perform_step, _internal_inverse_pth_root_all, init_state)
971
 
972
  def _skip(error):
973
  condition = jnp.logical_or(
@@ -1008,14 +1425,70 @@ def distributed_shampoo(learning_rate,
1008
 
1009
  return new_states
1010
 
1011
  def _transform_grad(grad, state, param, step):
1012
  """Transform per-parameter gradients."""
1013
  preconditioner = Preconditioner(param, block_size,
1014
  best_effort_shape_interpretation)
1015
  sgd_update = grad
1016
- new_diagonal_statistics = state.diagonal_statistics
1017
  if graft_type == GraftingType.ADAGRAD:
1018
- new_diagonal_statistics = state.diagonal_statistics + jnp.square(grad)
 
1019
  adagrad_update = grad / (
1020
  jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1021
  grafting_update = adagrad_update
@@ -1030,7 +1503,8 @@ def distributed_shampoo(learning_rate,
1030
  w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1031
 
1032
  new_diagonal_statistics = (
1033
- w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad))
 
1034
  rmsprop_update = scaled_grad / (
1035
  jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1036
 
@@ -1047,8 +1521,9 @@ def distributed_shampoo(learning_rate,
1047
 
1048
  precond_grad = grad
1049
  if not _skip_preconditioning(param):
1050
- precond_grad = preconditioner.preconditioned_grad(precond_grad,
1051
- state.preconditioners)
 
1052
  else:
1053
  precond_grad = grafting_update
1054
 
@@ -1066,9 +1541,10 @@ def distributed_shampoo(learning_rate,
1066
 
1067
  w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1068
  shampoo_update_with_wd_momentum = (
1069
- state.momentum * beta1 + w * shampoo_update_with_wd)
1070
  grafting_update_with_wd_momentum = (
1071
- state.diagonal_momentum * beta1 + w * grafting_update_with_wd)
 
1072
 
1073
  run_shampoo = (step >= start_preconditioning_step).astype(
1074
  grafting_update_with_wd_momentum.dtype)
@@ -1089,10 +1565,11 @@ def distributed_shampoo(learning_rate,
1089
  lr = learning_rate(step)
1090
  transformed_update = -1.0 * lr * momentum_update
1091
 
1092
- param_stats = ParameterStats(new_diagonal_statistics, state.statistics,
1093
- state.preconditioners,
1094
- grafting_update_with_wd_momentum,
1095
- shampoo_update_with_wd_momentum)
 
1096
  return transformed_update, param_stats
1097
 
1098
  def update_fn(grads, state, params):
 
33
  import enum
34
  import functools
35
  import itertools
36
+ from typing import Any, List, NamedTuple
37
 
38
  import chex
39
  from flax import struct
 
46
 
47
 
48
  # pylint:disable=no-value-for-parameter
49
+ @struct.dataclass
50
+ class QuantizedValue:
51
+ """State associated with quantized value."""
52
+ quantized: chex.Array
53
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
54
+ bucket_size: chex.Array
55
+ quantized_dtype: jnp.dtype = struct.field(
56
+ pytree_node=False) # Dtype for the quantized value.
57
+ extract_diagonal: bool = struct.field(
58
+ pytree_node=False) # In case it's centered.
59
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
60
+
61
+ @classmethod
62
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
63
+ if isinstance(fvalue, list) and not fvalue:
64
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
65
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
66
+ fvalue, quantized_dtype, extract_diagonal)
67
+ return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
68
+ quantized_dtype, extract_diagonal,
69
+ list(quantized.shape))
70
+
71
+ # Quantization is from Lingvo JAX optimizers.
72
+ # We extend it for int16 quantization of PSD matrices.
73
+ @classmethod
74
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
75
+ """Returns quantized value and the bucket."""
76
+ if quantized_dtype == jnp.float32:
77
+ return fvalue, [], []
78
+ elif quantized_dtype == jnp.bfloat16:
79
+ return fvalue.astype(jnp.bfloat16), [], []
80
+
81
+ float_dtype = fvalue.dtype
82
+ if quantized_dtype == jnp.int8:
83
+ # value -128 is not used.
84
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
85
+ elif quantized_dtype == jnp.int16:
86
+ # value -32768 is not used.
87
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
88
+ else:
89
+ raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
90
+ # max value is mapped to num_buckets
91
+
92
+ if extract_diagonal and fvalue.ndim != 2:
93
+ raise ValueError(
94
+ f'Input array {fvalue} must be 2D to work with extract_diagonal.')
95
+
96
+ diagonal_fvalue = []
97
+ if extract_diagonal:
98
+ diagonal_fvalue = jnp.diag(fvalue)
99
+ # Remove the diagonal entries.
100
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
101
+
102
+ # TODO(rohananil): Extend this by making use of information about the blocks
103
+ # SM3 style which will be useful for diagonal statistics
104
+ # We first decide the scale.
105
+ if fvalue.ndim < 1:
106
+ raise ValueError(
107
+ f'Input array {fvalue} must have a strictly positive number of '
108
+ 'dimensions.')
109
+
110
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
111
+ bucket_size = max_abs / num_buckets
112
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
113
+ # To avoid divide by 0.0
114
+ bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
115
+ jnp.ones_like(bs_expanded))
116
+ ratio = fvalue / bs_nonzero
117
+ # We use rounding to remove bias.
118
+ quantized = jnp.round(ratio)
119
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
120
+
121
+ def to_float(self):
122
+ """Returns the float value."""
123
+ if isinstance(self.quantized, list) and not self.quantized:
124
+ return self.quantized
125
+
126
+ if self.quantized_dtype == jnp.float32:
127
+ return self.quantized
128
+
129
+ if self.quantized_dtype == jnp.bfloat16:
130
+ return self.quantized.astype(jnp.float32)
131
+
132
+ float_dtype = self.bucket_size.dtype
133
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
134
+ val = self.quantized.astype(float_dtype) * bucket_size
135
+ if self.extract_diagonal:
136
+ val += jnp.diag(self.diagonal)
137
+ return val
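Editor's sketch (not part of the commit): a minimal quantize/dequantize round trip with the class above, assuming `QuantizedValue` is importable from this module; the test matrix is purely illustrative.

    import jax.numpy as jnp
    from tools.train.distributed_shampoo import QuantizedValue  # assumed import path

    m = jnp.arange(16.0).reshape(4, 4)
    m = (m + m.T) / 2.0                                  # small symmetric test matrix
    qv = QuantizedValue.from_float_value(m, jnp.int8, extract_diagonal=True)
    approx = qv.to_float()                               # dequantized approximation of m
    max_err = jnp.max(jnp.abs(approx - m))               # at most ~0.5 * bucket_size per column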
138
 
139
 
140
  # Per parameter optimizer state used in data-parallel training.
141
  class ParameterStats(NamedTuple):
142
  """State associated to each parameter of the model being trained."""
143
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
144
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
145
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
146
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
147
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
148
 
149
 
150
  # For training extremely large models, we keep a global state with a concatenated
 
162
  @struct.dataclass
163
  class LocalShardedParameterStats:
164
  """State associated to each parameter of the model being trained."""
165
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
166
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
167
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
168
  index_start: np.int32 = struct.field(
169
  pytree_node=False) # Index into global statistics array
170
  sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
 
230
  jnp.greater(jnp.abs(s_new - s), error_tolerance))
231
 
232
  # Figure out how to use step as seed for random.
233
+ v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
234
+ matrix_size).astype(matrix.dtype)
235
 
236
  init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
237
  _, v_out, s_out, _, _ = lax.while_loop(
 
413
  return mat
414
 
415
 
416
+ def pad_vector(vec, max_size):
417
+ """Pad a vector to a max_size.
418
+
419
+ Args:
420
+ vec: a vector to pad.
421
+ max_size: matrix size requested.
422
+
423
+ Returns:
424
+ Given V returns [V, 0]
425
+ """
426
+ size = vec.shape[0]
427
+ assert size <= max_size
428
+ if size == max_size:
429
+ return vec
430
+ pad_size = max_size - size
431
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
432
+ return jnp.concatenate([vec, zs1], 0)
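A one-line illustration of the new helper (editor's sketch; assumes `pad_vector` is in scope):

    import jax.numpy as jnp
    padded = pad_vector(jnp.array([1.0, 2.0, 3.0]), 5)   # -> [1., 2., 3., 0., 0.]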
433
+
434
+
435
  def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
436
  """Avoids wasteful buffer allocation with XLA."""
437
 
 
601
  local_stats.index_start, local_stats.sizes)
602
 
603
 
604
+ def batch(x, num_devices):
605
+ """Batch `x` so that so that leading axis is num_devices."""
606
+ n = len(x)
607
+ b = int(n / num_devices)
608
+ return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
609
+
610
+
611
+ def unbatch(batched_values):
612
+ """Unbatch values across leading axis and return a list of elements."""
613
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
614
+ results = []
615
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
616
+ v_array = jnp.squeeze(v_array)
617
+ # b2 = batches (number of preconditioner computation) per core.
618
+ if b2 > 1:
619
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
620
+ results.append(jnp.squeeze(v))
621
+ else:
622
+ results.append(v_array)
623
+ return results
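The two helpers above are inverses for evenly divisible inputs; a minimal sketch (editor's addition, assuming `batch` and `unbatch` are in scope and eight same-shaped statistics are spread over four devices):

    import jax.numpy as jnp

    stats = [jnp.eye(3) * float(i) for i in range(8)]  # 8 statistics matrices of shape (3, 3)
    stacked = batch(stats, 4)                          # shape (4, 2, 3, 3): 4 devices, 2 per device
    recovered = unbatch(stacked)                       # flat list of 8 (3, 3) matrices again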
624
+
625
+
626
+ def distributed_shampoo(
627
+ learning_rate,
628
+ block_size,
629
+ beta1=0.9,
630
+ beta2=0.999,
631
+ diagonal_epsilon=1e-10,
632
+ matrix_epsilon=1e-6,
633
+ weight_decay=0.0,
634
+ start_preconditioning_step=5,
635
+ preconditioning_compute_steps=1,
636
+ statistics_compute_steps=1,
637
+ best_effort_shape_interpretation=True,
638
+ graft_type=GraftingType.SGD,
639
+ nesterov=True,
640
+ exponent_override=0,
641
+ # Pass pmap 'batch axis name' in pmap mode.
642
+ batch_axis_name=None,
643
+ ### Only set following 3 params in pjit/spmd mode.
644
+ ### WARNING: Experimental
645
+ mesh_axis_names=None,
646
+ num_devices_for_pjit=None,
647
+ shard_optimizer_states=False,
648
+ ###
649
+ ### Experimental memory reduction mode
650
+ best_effort_memory_usage_reduction=False,
651
+ ###
652
+ inverse_failure_threshold=0.1,
653
+ moving_average_for_momentum=False,
654
+ skip_preconditioning_dim_size_gt=4096,
655
+ clip_by_scaled_gradient_norm=None,
656
+ precision=lax.Precision.HIGHEST):
657
  """Distributed Shampoo optimizer.
658
 
659
  Distributed Shampoo is a second-order preconditioned method (concretely, a
 
705
  num_devices_for_pjit: Number of devices to parallelize over when using pjit.
706
  shard_optimizer_states: Shard optimizer states to save memory in model
707
  parallel training.
708
+ best_effort_memory_usage_reduction: Best effort memory usage reduction.
709
+ diagonal_statistics -> jnp.bfloat16
710
+ momentum buffers (2x) -> jnp.int8
711
+ statistics, preconditioners -> jnp.int16 + diagonals
712
  inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
713
  determine that using this threshold.
714
  moving_average_for_momentum: Whether to use moving average for momentum
 
726
  a GradientTransformation.
727
  """
728
 
729
+ def quantized_dtype_for_momentum_buffers():
730
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
731
+
732
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
733
+ def quantized_dtype_for_diagonal_statistics_buffers():
734
+ return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
735
+
736
+ # Preconditioner and statistics are both stored as int16 in this mode.
737
+ # We take out the diagonal to make quantization easier.
738
+ def quantized_dtype_for_second_moment_statistics_buffers():
739
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
740
+
741
+ # Preconditioner and statistics are both stored as int16 in this mode.
742
+ # We take out the diagonal to make quantization easier.
743
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
744
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
745
+
746
+ def _to_float(maybe_quantized):
747
+ if isinstance(maybe_quantized, QuantizedValue):
748
+ return maybe_quantized.to_float()
749
+ else:
750
+ return maybe_quantized
751
+
752
+ def _maybe_quantize_statistics(statistics_list):
753
+ return _maybe_quantize_matrices_with_dtype(
754
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers())
755
+
756
+ def _maybe_quantize_preconditioners(statistics_list):
757
+ return _maybe_quantize_matrices_with_dtype(
758
+ statistics_list,
759
+ quantized_dtype_for_second_moment_preconditioner_buffers())
760
+
761
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
762
+ if quantized_dtype != jnp.float32:
763
+ return ([
764
+ QuantizedValue.from_float_value(
765
+ s, quantized_dtype, extract_diagonal=True)
766
+ for s in statistics_list
767
+ ])
768
+ else:
769
+ return statistics_list
770
+
771
+ def _maybe_dequantize_preconditioners(preconditioner_list):
772
+ return _maybe_dequantize_matrices_with_dtype(
773
+ preconditioner_list,
774
+ quantized_dtype_for_second_moment_preconditioner_buffers())
775
+
776
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
777
+ if quantized_dtype != jnp.float32:
778
+ return [s.to_float() for s in statistics_list]
779
+ else:
780
+ return statistics_list
781
+
782
+ def _quantize_diagonal_statistics(diagonal_statistics):
783
+ return QuantizedValue.from_float_value(
784
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers())
785
+
786
+ def _quantize_momentum(momentum_statistics):
787
+ return QuantizedValue.from_float_value(
788
+ momentum_statistics, quantized_dtype_for_momentum_buffers())
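The four dtype helpers above encode the memory-reduction policy from the docstring. Restated as a standalone table for readability (editor's sketch; `memory_reduction` and `pmap_mode` are hypothetical stand-ins for `best_effort_memory_usage_reduction` and a truthy `batch_axis_name`):

    import jax.numpy as jnp

    def sketch_buffer_dtypes(memory_reduction: bool, pmap_mode: bool) -> dict:
        return {
            "momentum": jnp.int8 if memory_reduction else jnp.float32,
            "diagonal_statistics": jnp.bfloat16 if memory_reduction else jnp.float32,
            # Second-moment statistics and preconditioners are only quantized under pmap.
            "second_moment_buffers": jnp.int16 if (memory_reduction and pmap_mode) else jnp.float32,
        }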
789
+
790
  def sharded_init_fn(params):
791
  params_flat, treedef = jax.tree_flatten(params)
792
  # Find max size to pad to.
 
819
  padded_statistics.extend(statistics)
820
  padded_preconditioners.extend(preconditioners)
821
 
822
+ diagonal_statistics = []
823
  if graft_type != GraftingType.SGD:
824
+ diagonal_statistics = jnp.zeros_like(param)
825
  local_stats_flat.append(
826
+ LocalShardedParameterStats(
827
+ _quantize_diagonal_statistics(diagonal_statistics),
828
+ _quantize_momentum(jnp.zeros_like(param)),
829
+ _quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
830
 
831
  local_stats = jax.tree_unflatten(treedef, local_stats_flat)
832
  # Pad the statistics and preconditioner matrices to be a multiple of
 
971
  statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
972
  preconditioners = [jnp.eye(s[0]) for s in shapes]
973
 
974
+ diagonal_statistics = []
975
  if graft_type != GraftingType.SGD:
976
+ diagonal_statistics = jnp.zeros_like(param)
977
+ return ParameterStats(
978
+ _quantize_diagonal_statistics(diagonal_statistics),
979
+ _maybe_quantize_statistics(statistics),
980
+ _maybe_quantize_preconditioners(preconditioners),
981
+ _quantize_momentum(jnp.zeros_like(param)),
982
+ _quantize_momentum(jnp.zeros_like(param)))
983
  return ShampooState(
984
  count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
985
 
 
1000
  new_stats = preconditioner.statistics_from_grad(grad)
1001
  new_stats_accumulators = []
1002
  for stat, stat_accumulator in zip(new_stats, state.statistics):
1003
+ new_stats_accumulators.append(w1 * _to_float(stat_accumulator) +
1004
+ w2 * stat)
1005
+ return _maybe_quantize_statistics(new_stats_accumulators)
1006
 
1007
  if statistics_compute_steps > 1:
1008
  perform_step = step % statistics_compute_steps == 0
 
1016
  state.preconditioners, state.diagonal_momentum,
1017
  state.momentum)
1018
 
1019
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1020
+ mi_pth_root = functools.partial(
1021
+ matrix_inverse_pth_root,
1022
+ ridge_epsilon=matrix_epsilon,
1023
+ precision=precision)
1024
+ return jax.vmap(mi_pth_root)(xs, ps)
1025
+
1026
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1027
+
1028
+ def _quantized_to_float(qx, qd, qb):
1029
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1030
+ return qv.to_float()
1031
+
1032
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1033
+ v = _quantized_to_float(qx, qd, qb)
1034
+ preconditioner, error = matrix_inverse_pth_root(
1035
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision)
1036
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1037
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1038
+
1039
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1040
+
1041
+ def _matrix_inverse_pth_root_pjit(xs, ps):
1042
+ mesh_axis_names_tuple = tuple(mesh_axis_names)
1043
+ # Partition the concatenated statistics matrix across all cores.
1044
+ partitioned_xs, partitioned_ps = pjit.pjit(
1045
+ lambda x, y: (x, y),
1046
+ in_axis_resources=None,
1047
+ out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
1048
+ # Run matrix inverse pth root on each shard.
1049
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1050
+ partitioned_xs, partitioned_ps)
1051
+ # Recombine the outputs at each core.
1052
+ preconditioners, errors = pjit.pjit(
1053
+ lambda x, y: (x, y),
1054
+ in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
1055
+ pjit.PartitionSpec(mesh_axis_names_tuple,)),
1056
+ out_axis_resources=(None, None))(partitioned_preconditioners,
1057
+ partitioned_errors)
1058
+ return preconditioners, errors
1059
+
1060
+ def _pmap_compute_preconditioners(states, step, statistics,
1061
+ num_statistics_per_state, original_shapes,
1062
+ exponents, max_size, prev_preconditioners):
1063
+ """Computes preconditioners for given statistics in states in PMAP mode.
1064
+
1065
+ Args:
1066
+ states: A list of optimizer states.
1067
+ step: Current step number
1068
+ statistics: A list of statistics for all variables (for every dim)
1069
+ num_statistics_per_state: Number of statistics per state to reconstruct
1070
+ output states.
1071
+ original_shapes: A list of shapes of the statistics.
1072
+ exponents: Exponent power to use for inverse-pth roots.
1073
+ max_size: Maximum dim of the statistics to pad.
1074
+ prev_preconditioners: Previously available preconditioner.
1075
+
1076
+ Returns:
1077
+ New optimizer states after computing the preconditioner.
1078
+ """
1079
+ num_devices = lax.psum(1, batch_axis_name)
1080
  num_statistics = len(statistics)
1081
+ # Pad statistics and exponents to next multiple of num_devices.
1082
+ packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1083
+ to_pad = -num_statistics % num_devices
1084
+ packed_statistics.extend([
1085
+ jnp.eye(max_size, dtype=packed_statistics[0].dtype)
1086
+ for _ in range(to_pad)
1087
+ ])
1088
+ exponents.extend([1 for _ in range(to_pad)])
1089
 
1090
+ if not packed_statistics:
1091
+ return states
1092
+
1093
+ all_statistics = batch(packed_statistics, num_devices)
1094
+ all_exponents = batch(exponents, num_devices)
1095
+
1096
+ def _internal_inverse_pth_root_all():
1097
+ current_replica = lax.axis_index(batch_axis_name)
1098
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1099
+ all_statistics[current_replica], all_exponents[current_replica])
1100
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1101
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1102
+ preconditioners_flat = unbatch(preconditioners)
1103
+ errors_flat = unbatch(errors)
1104
+ return preconditioners_flat, errors_flat
1105
+
1106
+ if preconditioning_compute_steps == 1:
1107
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1108
  else:
1109
+ # Passing statistics instead of preconditioners as they are similarly
1110
+ # shaped tensors. Note statistics will be ignored as we are passing in
1111
+ # a large init value for error.
1112
+ preconditioners_init = packed_statistics
1113
+ errors_init = ([inverse_failure_threshold] * len(packed_statistics))
1114
+ init_state = [preconditioners_init, errors_init]
1115
+ perform_step = step % preconditioning_compute_steps == 0
1116
+ preconditioners_flat, errors_flat = efficient_cond(
1117
+ perform_step, _internal_inverse_pth_root_all, init_state)
1118
 
1119
+ def _skip(error):
1120
+ condition = jnp.logical_or(
1121
+ jnp.isnan(error), error >= inverse_failure_threshold)
1122
+ return condition.astype(error.dtype)
1123
 
1124
+ def _select_preconditioner(error, new_p, old_p):
1125
+ return lax.cond(
1126
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1127
+
1128
+ new_preconditioners_flat = []
1129
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1130
+ prev_preconditioners, errors_flat):
1131
+ new_preconditioners_flat.append(
1132
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1133
+
1134
+ assert len(states) == len(num_statistics_per_state)
1135
+ assert len(new_preconditioners_flat) == num_statistics
1136
 
1137
+ # Add back empty preconditioners so that we can set the optimizer state.
1138
+ preconditioners_for_states = []
1139
+ idx = 0
1140
+ for num_statistics, state in zip(num_statistics_per_state, states):
1141
+ if num_statistics == 0:
1142
+ preconditioners_for_states.append([])
1143
  else:
1144
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1145
+ num_statistics]
1146
+ assert len(state.statistics) == len(preconditioners_for_state)
1147
+ preconditioners_for_states.append(preconditioners_for_state)
1148
+ idx += num_statistics
1149
+ new_states = []
1150
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1151
+ new_states.append(
1152
+ ParameterStats(state.diagonal_statistics, state.statistics,
1153
+ new_preconditioners, state.diagonal_momentum,
1154
+ state.momentum))
1155
+
1156
+ return new_states
1157
+
1158
+ def _pmap_quantized_compute_preconditioners(states, step, statistics,
1159
+ num_statistics_per_state,
1160
+ original_shapes, exponents,
1161
+ max_size, prev_preconditioners):
1162
+ """Computes preconditioners for given statistics in states in PMAP mode.
1163
+
1164
+ For quantization, each statistic is represented by three values:
1165
+ quantized matrix, diagonal, and bucket sizes; we run inverse pth-roots
1166
+ without ever recreating the original matrix in f32.
1167
+
1168
+ Args:
1169
+ states: A list of optimizer states.
1170
+ step: Current step number
1171
+ statistics: A list of statistics for all variables (for every dim)
1172
+ num_statistics_per_state: Number of statistics per state to reconstruct
1173
+ output states.
1174
+ original_shapes: A list of shapes of the statistics.
1175
+ exponents: Exponent power to use for inverse-pth roots.
1176
+ max_size: Maximum dim of the statistics to pad.
1177
+ prev_preconditioners: Previously available preconditioner.
1178
+
1179
+ Returns:
1180
+ New optimizer states after computing the preconditioner.
1181
+ """
1182
+ num_devices = lax.psum(1, batch_axis_name)
1183
+ num_statistics = len(statistics)
1184
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1185
+ # Complexity here is around: shapes needing to be statically shaped,
1186
+ # our custom quantization type requires a different type of packing.
1187
+
1188
+ # Parallel tensors:
1189
+ # quantized [dxd]
1190
+ # diagonals [d] f32
1191
+ # bucket_sizes [d] f32
1192
+ packed_quantized_statistics = [
1193
+ pad_matrix(stat.quantized, max_size) for stat in statistics
1194
+ ]
1195
+ packed_quantized_diagonals = [
1196
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1197
+ ]
1198
+ packed_quantized_bucket_sizes = [
1199
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1200
+ ]
1201
+
1202
+ to_pad = -num_statistics % num_devices
1203
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1204
+ quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
1205
+ True)
1206
+ packed_quantized_statistics.extend(
1207
+ [quantized_eye.quantized for _ in range(to_pad)])
1208
+ packed_quantized_diagonals.extend(
1209
+ [quantized_eye.diagonal for _ in range(to_pad)])
1210
+ packed_quantized_bucket_sizes.extend(
1211
+ [quantized_eye.bucket_size for _ in range(to_pad)])
1212
+ exponents.extend([1 for _ in range(to_pad)])
1213
+
1214
+ if not packed_quantized_statistics:
1215
+ return states
1216
+
1217
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1218
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1219
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
1220
+ num_devices)
1221
+ all_exponents = batch(exponents, num_devices)
1222
+
1223
+ def _internal_inverse_pth_root_all():
1224
+ current_replica = lax.axis_index(batch_axis_name)
1225
+ quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors = (
1226
+ _quantized_matrix_inverse_pth_root_vmap(
1227
+ all_quantized_statistics[current_replica],
1228
+ all_quantized_diagonals[current_replica],
1229
+ all_quantized_bucket_sizes[current_replica],
1230
+ all_exponents[current_replica]))
1231
+ quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
1232
+ batch_axis_name)
1233
+ quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
1234
+ batch_axis_name)
1235
+ quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
1236
+ batch_axis_name)
1237
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1238
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1239
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1240
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1241
+ errors_flat = unbatch(errors)
1242
+ return (quantized_preconditioners_flat, quantized_diagonals_flat,
1243
+ quantized_bucket_sizes_flat, errors_flat)
1244
+
1245
+ if preconditioning_compute_steps == 1:
1246
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1247
+ quantized_bucket_sizes_flat, errors_flat) = (
1248
+ _internal_inverse_pth_root_all())
1249
  else:
1250
+ # Passing statistics instead of preconditioners as they are similarly
1251
+ # shaped tensors. Note statistics will be ignored as we are passing in
1252
+ # a large init value for error.
1253
+ quantized_preconditioners_init = packed_quantized_statistics
1254
+ quantized_diagonals_init = packed_quantized_diagonals
1255
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1256
+ errors_init = ([inverse_failure_threshold] *
1257
+ len(quantized_preconditioners_init))
1258
+ init_state = [
1259
+ quantized_preconditioners_init, quantized_diagonals_init,
1260
+ quantized_bucket_sizes_init, errors_init
1261
+ ]
1262
+ perform_step = step % preconditioning_compute_steps == 0
1263
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1264
+ quantized_bucket_sizes_flat, errors_flat) = (
1265
+ efficient_cond(perform_step, _internal_inverse_pth_root_all,
1266
+ init_state))
1267
 
1268
+ def _skip(error):
1269
+ condition = jnp.logical_or(
1270
+ jnp.isnan(error), error >= inverse_failure_threshold)
1271
+ return condition.astype(error.dtype)
1272
+
1273
+ def _select_preconditioner(error, new_p, old_p):
1274
+ return lax.cond(
1275
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1276
+
1277
+ new_quantized_preconditioners_flat = []
1278
+ new_quantized_diagonals_flat = []
1279
+ new_quantized_bucket_sizes_flat = []
1280
+ for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
1281
+ quantized_diagonals_flat,
1282
+ quantized_bucket_sizes_flat,
1283
+ original_shapes,
1284
+ prev_preconditioners, errors_flat):
1285
+ new_quantized_preconditioners_flat.append(
1286
+ _select_preconditioner(error, p[:shape[0], :shape[1]],
1287
+ prev_p.quantized))
1288
+ new_quantized_diagonals_flat.append(
1289
+ _select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
1290
+ new_quantized_bucket_sizes_flat.append(
1291
+ _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))
1292
+
1293
+ assert len(states) == len(num_statistics_per_state)
1294
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1295
+ assert len(new_quantized_diagonals_flat) == num_statistics
1296
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1297
+
1298
+ # Add back empty preconditioners so that we can set the optimizer state.
1299
+ preconditioners_for_states = []
1300
+ idx = 0
1301
+ for num_statistics, state in zip(num_statistics_per_state, states):
1302
+ if num_statistics == 0:
1303
+ preconditioners_for_states.append([])
1304
  else:
1305
+ quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
1306
+ idx:idx + num_statistics]
1307
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1308
+ idx:idx + num_statistics]
1309
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1310
+ idx:idx + num_statistics]
1311
+
1312
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1313
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1314
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1315
+
1316
+ quantized_preconditioners = []
1317
+ for qv, qd, qb in zip(quantized_preconditioners_for_state,
1318
+ quantized_diagonals_for_state,
1319
+ quantized_bucket_sizes_for_state):
1320
+ quantized_preconditioners.append(
1321
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)))
1322
+ preconditioners_for_states.append(quantized_preconditioners)
1323
+ idx += num_statistics
1324
+ new_states = []
1325
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1326
+ new_states.append(
1327
+ ParameterStats(state.diagonal_statistics, state.statistics,
1328
+ new_preconditioners, state.diagonal_momentum,
1329
+ state.momentum))
1330
+
1331
+ return new_states
1332
+
1333
+ def _pjit_compute_preconditioners(states, step, statistics,
1334
+ num_statistics_per_state, original_shapes,
1335
+ exponents, max_size, prev_preconditioners):
1336
+ """Computes preconditioners for given statistics in states in PJIT mode.
1337
+
1338
+ Args:
1339
+ states: A list of optimizer states.
1340
+ step: Current step number
1341
+ statistics: A list of statistics for all variables (for every dim)
1342
+ num_statistics_per_state: Number of statistics per state to reconstruct
1343
+ output states.
1344
+ original_shapes: A list of shapes of the statistics.
1345
+ exponents: Exponent power to use for inverse-pth roots.
1346
+ max_size: Maximum dim of the statistics to pad.
1347
+ prev_preconditioners: Previously available preconditioner.
1348
+
1349
+ Returns:
1350
+ New optimizer states after computing the preconditioner.
1351
+ """
1352
+ num_statistics = len(statistics)
1353
+ to_pad = -num_statistics % num_devices_for_pjit
1354
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1355
+ padded_statistics.extend([
1356
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
1357
+ for _ in range(to_pad)
1358
+ ])
1359
+ exponents.extend([1 for _ in range(to_pad)])
1360
+ all_statistics = jnp.stack(padded_statistics)
1361
+ all_exponents = jnp.stack(exponents)
1362
+
1363
+ def _internal_inverse_pth_root_all():
1364
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1365
+ all_statistics, all_exponents)
1366
+ b1 = preconditioners.shape[0]
1367
+
1368
+ def split(batched_values):
1369
+ return [
1370
+ jnp.squeeze(v)
1371
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1372
+ ]
1373
+
1374
+ return split(preconditioners), split(errors)
1375
+
1376
+ if preconditioning_compute_steps == 1:
1377
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1378
+ else:
1379
+ # Passing statistics instead of preconditioners as they are similarly
1380
+ # shaped tensors. Note statistics will be ignored as we are passing in
1381
+ # a large init value for error.
1382
+ preconditioners_init = padded_statistics
1383
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1384
+ init_state = [preconditioners_init, errors_init]
1385
+ perform_step = step % preconditioning_compute_steps == 0
1386
+ preconditioners_flat, errors_flat = efficient_cond(
1387
+ perform_step, _internal_inverse_pth_root_all, init_state)
1388
 
1389
  def _skip(error):
1390
  condition = jnp.logical_or(
 
1425
 
1426
  return new_states
1427
 
1428
+ def _compute_preconditioners(states, params, step):
1429
+ """Computes preconditioners for given statistics in states.
1430
+
1431
+ Args:
1432
+ states: A list of optimizer states.
1433
+ params: A list of params.
1434
+ step: Current step number
1435
+
1436
+ Returns:
1437
+ New optimizer states after computing the preconditioner.
1438
+ """
1439
+ statistics = []
1440
+ num_statistics_per_state = []
1441
+ original_shapes = []
1442
+ exponents = []
1443
+ max_size = 0
1444
+ prev_preconditioners = []
1445
+
1446
+ for state, param in zip(states, params):
1447
+ num_statistics = len(state.statistics)
1448
+ num_statistics_per_state.append(num_statistics)
1449
+ original_shapes_for_state = []
1450
+ if num_statistics > 0:
1451
+ preconditioner = Preconditioner(param, block_size,
1452
+ best_effort_shape_interpretation)
1453
+ for statistic in state.statistics:
1454
+ exponents.append(preconditioner.exponent_for_preconditioner(
1455
+ ) if exponent_override == 0 else exponent_override)
1456
+ original_shapes_for_state.append(statistic.shape)
1457
+ max_size = max(max_size, statistic.shape[0])
1458
+
1459
+ statistics.extend(state.statistics)
1460
+ prev_preconditioners.extend(state.preconditioners)
1461
+ original_shapes.extend(original_shapes_for_state)
1462
+
1463
+ if batch_axis_name:
1464
+ # Quantization is only enabled if batch_axis_name is set.
1465
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1466
+
1467
+ if quantized_dtype == jnp.float32:
1468
+ return _pmap_compute_preconditioners(states, step, statistics,
1469
+ num_statistics_per_state,
1470
+ original_shapes, exponents,
1471
+ max_size, prev_preconditioners)
1472
+ else:
1473
+ return _pmap_quantized_compute_preconditioners(
1474
+ states, step, statistics, num_statistics_per_state, original_shapes,
1475
+ exponents, max_size, prev_preconditioners)
1476
+
1477
+ else:
1478
+ return _pjit_compute_preconditioners(states, step, statistics,
1479
+ num_statistics_per_state,
1480
+ original_shapes, exponents, max_size,
1481
+ prev_preconditioners)
1482
+
1483
  def _transform_grad(grad, state, param, step):
1484
  """Transform per-parameter gradients."""
1485
  preconditioner = Preconditioner(param, block_size,
1486
  best_effort_shape_interpretation)
1487
  sgd_update = grad
1488
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
1489
  if graft_type == GraftingType.ADAGRAD:
1490
+ new_diagonal_statistics = state.diagonal_statistics.to_float(
1491
+ ) + jnp.square(grad)
1492
  adagrad_update = grad / (
1493
  jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1494
  grafting_update = adagrad_update
 
1503
  w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1504
 
1505
  new_diagonal_statistics = (
1506
+ w1 * state.diagonal_statistics.to_float() +
1507
+ w2 * jnp.square(scaled_grad))
1508
  rmsprop_update = scaled_grad / (
1509
  jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1510
 
 
1521
 
1522
  precond_grad = grad
1523
  if not _skip_preconditioning(param):
1524
+ precond_grad = preconditioner.preconditioned_grad(
1525
+ precond_grad,
1526
+ _maybe_dequantize_preconditioners(state.preconditioners))
1527
  else:
1528
  precond_grad = grafting_update
1529
 
 
1541
 
1542
  w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1543
  shampoo_update_with_wd_momentum = (
1544
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
1545
  grafting_update_with_wd_momentum = (
1546
+ state.diagonal_momentum.to_float() * beta1 +
1547
+ w * grafting_update_with_wd)
1548
 
1549
  run_shampoo = (step >= start_preconditioning_step).astype(
1550
  grafting_update_with_wd_momentum.dtype)
 
1565
  lr = learning_rate(step)
1566
  transformed_update = -1.0 * lr * momentum_update
1567
 
1568
+ param_stats = ParameterStats(
1569
+ _quantize_diagonal_statistics(new_diagonal_statistics),
1570
+ state.statistics, state.preconditioners,
1571
+ _quantize_momentum(grafting_update_with_wd_momentum),
1572
+ _quantize_momentum(shampoo_update_with_wd_momentum))
1573
  return transformed_update, param_stats
1574
 
1575
  def update_fn(grads, state, params):