boris commited on
Commit
5996680
·
1 Parent(s): fa72aa7

feat: update distributed_shampoo

Browse files
Files changed (1) hide show
  1. tools/train/distributed_shampoo.py +86 -38
tools/train/distributed_shampoo.py CHANGED
@@ -34,13 +34,13 @@ import itertools
34
  from typing import Any, List, NamedTuple
35
 
36
  import chex
 
37
  import jax
 
38
  import jax.experimental.pjit as pjit
39
  import jax.numpy as jnp
40
  import numpy as np
41
  import optax
42
- from flax import struct
43
- from jax import lax
44
 
45
 
46
  # pylint:disable=no-value-for-parameter
@@ -234,6 +234,8 @@ class GraftingType(enum.IntEnum):
234
  ADAGRAD = 2
235
  RMSPROP = 3
236
  RMSPROP_NORMALIZED = 4
 
 
237
 
238
 
239
  def power_iteration(
@@ -336,7 +338,7 @@ def matrix_inverse_pth_root(
336
  _, max_ev = power_iteration(
337
  matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
338
  )
339
- ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
340
 
341
  def _unrolled_mat_pow_1(mat_m):
342
  """Computes mat_m^1."""
@@ -791,8 +793,7 @@ def distributed_shampoo(
791
  block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
792
  graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
793
  optimizer. This allows us to plugin the Shampoo optimizer into settings
794
- where SGD/AdaGrad is already well tuned. Available options are:
795
- GraftingType.SGD and GraftingType.ADAGRAD.
796
  nesterov: Nesterov momentum.
797
  exponent_override: Override the exponent used in matrix inverse.
798
  batch_axis_name: labeled axis over pmap for data-parallel training the
@@ -823,12 +824,20 @@ def distributed_shampoo(
823
  a GradientTransformation.
824
  """
825
 
 
 
 
 
 
 
 
 
826
  def quantized_dtype_for_momentum_buffers():
827
  return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
828
 
829
  # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
830
  def quantized_dtype_for_diagonal_statistics_buffers():
831
- return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
832
 
833
  # Preconditioner and statistics are both stores as int16 in this mode.
834
  # We take out the diagonal to make quantization easier.
@@ -944,13 +953,19 @@ def distributed_shampoo(
944
  exponents.extend([exponent] * len(shapes))
945
 
946
  diagonal_statistics = []
947
- if graft_type != GraftingType.SGD:
948
  diagonal_statistics = jnp.zeros_like(param)
 
 
 
 
 
 
949
  local_stats_flat.append(
950
  LocalShardedParameterStats(
951
  _quantize_diagonal_statistics(diagonal_statistics),
952
- _quantize_momentum(jnp.zeros_like(param)),
953
- _quantize_momentum(jnp.zeros_like(param)),
954
  init_training_metrics(len(sizes)),
955
  index_start,
956
  sizes,
@@ -1039,7 +1054,7 @@ def distributed_shampoo(
1039
 
1040
  diagonal_statistics_pspec = []
1041
  diagonal_statistics_scale_pspec = []
1042
- if graft_type != GraftingType.SGD:
1043
  # Identically shaped param.
1044
  diagonal_statistics_pspec = param_pspec
1045
  if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
@@ -1047,14 +1062,16 @@ def distributed_shampoo(
1047
  _remove_leading_sharding_annotation(param_pspec)
1048
  )
1049
 
1050
- m1_pspec = param_pspec
1051
- m2_pspec = param_pspec
1052
-
1053
  m1_scale_pspec = []
1054
- m2_scale_pspec = []
 
 
 
1055
 
 
 
1056
  if quantized_dtype_for_momentum_buffers() != jnp.float32:
1057
- m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
1058
  m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
1059
 
1060
  local_stats_flat.append(
@@ -1130,7 +1147,7 @@ def distributed_shampoo(
1130
 
1131
  diagonal_statistics_shape_and_dtype = []
1132
  diagonal_statistics_scale_shape_and_dtype = []
1133
- if graft_type != GraftingType.SGD:
1134
  diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
1135
  qdtype = quantized_dtype_for_diagonal_statistics_buffers()
1136
  if qdtype != jnp.float32:
@@ -1140,18 +1157,18 @@ def distributed_shampoo(
1140
  param.dtype,
1141
  ]
1142
 
1143
- m1_shape_and_dtype = [list(param.shape), param.dtype]
1144
- m2_shape_and_dtype = [list(param.shape), param.dtype]
1145
-
1146
  m1_scale_shape_and_dtype = []
1147
- m2_scale_shape_and_dtype = []
 
 
 
1148
 
1149
- qdtype = quantized_dtype_for_momentum_buffers()
 
1150
  if qdtype != jnp.float32:
1151
- m1_shape_and_dtype = [list(param.shape), qdtype]
1152
  m2_shape_and_dtype = [list(param.shape), qdtype]
1153
-
1154
- m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1155
  m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1156
 
1157
  local_stats_flat.append(
@@ -1331,14 +1348,20 @@ def distributed_shampoo(
1331
  preconditioners = [jnp.eye(s[0]) for s in shapes]
1332
 
1333
  diagonal_statistics = []
1334
- if graft_type != GraftingType.SGD:
1335
  diagonal_statistics = jnp.zeros_like(param)
 
 
 
 
 
 
1336
  return ParameterStats(
1337
  _quantize_diagonal_statistics(diagonal_statistics),
1338
  _maybe_quantize_statistics(statistics),
1339
  _maybe_quantize_preconditioners(preconditioners),
1340
- _quantize_momentum(jnp.zeros_like(param)),
1341
- _quantize_momentum(jnp.zeros_like(param)),
1342
  init_training_metrics(len(statistics)),
1343
  )
1344
 
@@ -2037,11 +2060,19 @@ def distributed_shampoo(
2037
  )
2038
  sgd_update = grad
2039
  new_diagonal_statistics = state.diagonal_statistics.to_float()
2040
- if graft_type == GraftingType.ADAGRAD:
 
 
 
 
 
 
 
 
2041
  new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
2042
- grad
2043
  )
2044
- adagrad_update = grad / (
2045
  jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2046
  )
2047
  grafting_update = adagrad_update
@@ -2074,8 +2105,10 @@ def distributed_shampoo(
2074
  rmsprop_update /= clipping_denom
2075
 
2076
  grafting_update = rmsprop_update
2077
- else:
2078
  grafting_update = sgd_update
 
 
2079
 
2080
  precond_grad = grad
2081
  if not _skip_preconditioning(param):
@@ -2098,12 +2131,20 @@ def distributed_shampoo(
2098
  grafting_update_with_wd = grafting_update + weight_decay * param
2099
 
2100
  w = (1.0 - beta1) if moving_average_for_momentum else 1.0
 
2101
  shampoo_update_with_wd_momentum = (
2102
  state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
2103
  )
2104
- grafting_update_with_wd_momentum = (
2105
- state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
2106
- )
 
 
 
 
 
 
 
2107
 
2108
  run_shampoo = (step >= start_preconditioning_step).astype(
2109
  grafting_update_with_wd_momentum.dtype
@@ -2119,20 +2160,27 @@ def distributed_shampoo(
2119
  + (1.0 - run_shampoo) * grafting_update_with_wd
2120
  )
2121
 
 
2122
  if nesterov:
2123
- momentum_update = w * wd_update + beta1 * momentum_update
2124
 
2125
  lr = learning_rate
2126
  if callable(learning_rate):
2127
  lr = learning_rate(step)
2128
- transformed_update = -1.0 * lr * momentum_update
 
 
 
 
 
 
2129
 
2130
  param_stats = ParameterStats(
2131
  _quantize_diagonal_statistics(new_diagonal_statistics),
2132
  state.statistics,
2133
  state.preconditioners,
2134
- _quantize_momentum(grafting_update_with_wd_momentum),
2135
- _quantize_momentum(shampoo_update_with_wd_momentum),
2136
  state.training_metrics,
2137
  )
2138
 
 
34
  from typing import Any, List, NamedTuple
35
 
36
  import chex
37
+ from flax import struct
38
  import jax
39
+ from jax import lax
40
  import jax.experimental.pjit as pjit
41
  import jax.numpy as jnp
42
  import numpy as np
43
  import optax
 
 
44
 
45
 
46
  # pylint:disable=no-value-for-parameter
 
234
  ADAGRAD = 2
235
  RMSPROP = 3
236
  RMSPROP_NORMALIZED = 4
237
+ SQRT_N = 5
238
+ ADAGRAD_NORMALIZED = 5
239
 
240
 
241
  def power_iteration(
 
338
  _, max_ev = power_iteration(
339
  matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision
340
  )
341
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-6)
342
 
343
  def _unrolled_mat_pow_1(mat_m):
344
  """Computes mat_m^1."""
 
793
  block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
794
  graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
795
  optimizer. This allows us to plugin the Shampoo optimizer into settings
796
+ where SGD/AdaGrad is already well tuned.
 
797
  nesterov: Nesterov momentum.
798
  exponent_override: Override the exponent used in matrix inverse.
799
  batch_axis_name: labeled axis over pmap for data-parallel training the
 
824
  a GradientTransformation.
825
  """
826
 
827
+ def _graft_type_has_diagonal_statistics():
828
+ """Returns True if using diagonal firt order method for grafting."""
829
+ return graft_type != GraftingType.SGD and graft_type != GraftingType.SQRT_N
830
+
831
+ def _graft_type_has_diagonal_momentum_states():
832
+ """Returns False if using SQRT_N for grafting."""
833
+ return graft_type != GraftingType.SQRT_N
834
+
835
  def quantized_dtype_for_momentum_buffers():
836
  return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
837
 
838
  # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
839
  def quantized_dtype_for_diagonal_statistics_buffers():
840
+ return jnp.float32
841
 
842
  # Preconditioner and statistics are both stores as int16 in this mode.
843
  # We take out the diagonal to make quantization easier.
 
953
  exponents.extend([exponent] * len(shapes))
954
 
955
  diagonal_statistics = []
956
+ if _graft_type_has_diagonal_statistics():
957
  diagonal_statistics = jnp.zeros_like(param)
958
+
959
+ diagonal_momentum = _quantize_momentum([])
960
+ momentum = _quantize_momentum(jnp.zeros_like(param))
961
+ if _graft_type_has_diagonal_momentum_states():
962
+ diagonal_momentum = _quantize_momentum((jnp.zeros_like(param)))
963
+
964
  local_stats_flat.append(
965
  LocalShardedParameterStats(
966
  _quantize_diagonal_statistics(diagonal_statistics),
967
+ diagonal_momentum,
968
+ momentum,
969
  init_training_metrics(len(sizes)),
970
  index_start,
971
  sizes,
 
1054
 
1055
  diagonal_statistics_pspec = []
1056
  diagonal_statistics_scale_pspec = []
1057
+ if _graft_type_has_diagonal_statistics():
1058
  # Identically shaped param.
1059
  diagonal_statistics_pspec = param_pspec
1060
  if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
 
1062
  _remove_leading_sharding_annotation(param_pspec)
1063
  )
1064
 
1065
+ m1_pspec = []
 
 
1066
  m1_scale_pspec = []
1067
+ if _graft_type_has_diagonal_momentum_states():
1068
+ m1_pspec = param_pspec
1069
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1070
+ m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
1071
 
1072
+ m2_pspec = param_pspec
1073
+ m2_scale_pspec = []
1074
  if quantized_dtype_for_momentum_buffers() != jnp.float32:
 
1075
  m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
1076
 
1077
  local_stats_flat.append(
 
1147
 
1148
  diagonal_statistics_shape_and_dtype = []
1149
  diagonal_statistics_scale_shape_and_dtype = []
1150
+ if _graft_type_has_diagonal_statistics():
1151
  diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
1152
  qdtype = quantized_dtype_for_diagonal_statistics_buffers()
1153
  if qdtype != jnp.float32:
 
1157
  param.dtype,
1158
  ]
1159
 
1160
+ qdtype = quantized_dtype_for_momentum_buffers()
1161
+ m1_shape_and_dtype = []
 
1162
  m1_scale_shape_and_dtype = []
1163
+ if _graft_type_has_diagonal_momentum_states():
1164
+ m1_shape_and_dtype = [list(param.shape), qdtype]
1165
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
1166
+ m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1167
 
1168
+ m2_shape_and_dtype = [list(param.shape), param.dtype]
1169
+ m2_scale_shape_and_dtype = []
1170
  if qdtype != jnp.float32:
 
1171
  m2_shape_and_dtype = [list(param.shape), qdtype]
 
 
1172
  m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1173
 
1174
  local_stats_flat.append(
 
1348
  preconditioners = [jnp.eye(s[0]) for s in shapes]
1349
 
1350
  diagonal_statistics = []
1351
+ if _graft_type_has_diagonal_statistics():
1352
  diagonal_statistics = jnp.zeros_like(param)
1353
+
1354
+ diagonal_momentum = _quantize_momentum([])
1355
+ momentum = _quantize_momentum(jnp.zeros_like(param))
1356
+ if _graft_type_has_diagonal_momentum_states():
1357
+ diagonal_momentum = _quantize_momentum(jnp.zeros_like(param))
1358
+
1359
  return ParameterStats(
1360
  _quantize_diagonal_statistics(diagonal_statistics),
1361
  _maybe_quantize_statistics(statistics),
1362
  _maybe_quantize_preconditioners(preconditioners),
1363
+ diagonal_momentum,
1364
+ momentum,
1365
  init_training_metrics(len(statistics)),
1366
  )
1367
 
 
2060
  )
2061
  sgd_update = grad
2062
  new_diagonal_statistics = state.diagonal_statistics.to_float()
2063
+ if (
2064
+ graft_type == GraftingType.ADAGRAD
2065
+ or graft_type == GraftingType.ADAGRAD_NORMALIZED
2066
+ ):
2067
+
2068
+ scaled_grad = grad
2069
+ if graft_type == GraftingType.ADAGRAD_NORMALIZED:
2070
+ scaled_grad = grad / jnp.linalg.norm(grad)
2071
+
2072
  new_diagonal_statistics = state.diagonal_statistics.to_float() + jnp.square(
2073
+ scaled_grad
2074
  )
2075
+ adagrad_update = scaled_grad / (
2076
  jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
2077
  )
2078
  grafting_update = adagrad_update
 
2105
  rmsprop_update /= clipping_denom
2106
 
2107
  grafting_update = rmsprop_update
2108
+ elif graft_type == GraftingType.SGD:
2109
  grafting_update = sgd_update
2110
+ else:
2111
+ grafting_update = jnp.ones_like(sgd_update) * jnp.sign(sgd_update)
2112
 
2113
  precond_grad = grad
2114
  if not _skip_preconditioning(param):
 
2131
  grafting_update_with_wd = grafting_update + weight_decay * param
2132
 
2133
  w = (1.0 - beta1) if moving_average_for_momentum else 1.0
2134
+
2135
  shampoo_update_with_wd_momentum = (
2136
  state.momentum.to_float() * beta1 + w * shampoo_update_with_wd
2137
  )
2138
+
2139
+ if _graft_type_has_diagonal_momentum_states():
2140
+ grafting_update_with_wd_momentum = (
2141
+ state.diagonal_momentum.to_float() * beta1 + w * grafting_update_with_wd
2142
+ )
2143
+ else:
2144
+ # Share the momentum buffer
2145
+ grafting_update_with_wd_momentum = (
2146
+ state.momentum.to_float() * beta1 + w * grafting_update_with_wd
2147
+ )
2148
 
2149
  run_shampoo = (step >= start_preconditioning_step).astype(
2150
  grafting_update_with_wd_momentum.dtype
 
2160
  + (1.0 - run_shampoo) * grafting_update_with_wd
2161
  )
2162
 
2163
+ nesterov_momentum_update = momentum_update
2164
  if nesterov:
2165
+ nesterov_momentum_update = w * wd_update + beta1 * momentum_update
2166
 
2167
  lr = learning_rate
2168
  if callable(learning_rate):
2169
  lr = learning_rate(step)
2170
+ transformed_update = -1.0 * lr * nesterov_momentum_update
2171
+
2172
+ new_diagonal_momentum = grafting_update_with_wd_momentum
2173
+ new_momentum = shampoo_update_with_wd_momentum
2174
+ if not _graft_type_has_diagonal_momentum_states():
2175
+ new_diagonal_momentum = []
2176
+ new_momentum = momentum_update
2177
 
2178
  param_stats = ParameterStats(
2179
  _quantize_diagonal_statistics(new_diagonal_statistics),
2180
  state.statistics,
2181
  state.preconditioners,
2182
+ _quantize_momentum(new_diagonal_momentum),
2183
+ _quantize_momentum(new_momentum),
2184
  state.training_metrics,
2185
  )
2186