boris committed on
Commit
0b87452
1 Parent(s): e1555d4

feat: add shampoo optimizer

tools/train/distributed_shampoo.py ADDED
@@ -0,0 +1,1134 @@
1
+ """File copied from https://github.com/google-research/google-research/edit/master/scalable_shampoo/optax/distributed_shampoo.py"""
2
+
3
+ # coding=utf-8
4
+ # Copyright 2021 The Google Research Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # An implementation of distributed Shampoo optimizer from:
19
+ #
20
+ # Scalable Second Order Optimization for Deep Learning
21
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
22
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
23
+ #
24
+ # This implementation moves computation of inverse pth root back to the
25
+ # accelerator (if higher precision is available).
26
+ #
27
+ # Authors: Rohan Anil (rohananil at google dot com)
28
+ # & Vineet Gupta (vineet at google dot com)
29
+ #
30
+
31
+ """Distributed Shampoo Implementation."""
32
+
33
+ import enum
34
+ import functools
35
+ import itertools
36
+ from typing import Any, NamedTuple
37
+
38
+ import chex
39
+ from flax import struct
40
+ import jax
41
+ from jax import lax
42
+ import jax.experimental.pjit as pjit
43
+ import jax.numpy as jnp
44
+ import numpy as np
45
+ import optax
46
+
47
+
48
+ # pylint:disable=no-value-for-parameter
49
+
50
+
51
+ # Per parameter optimizer state used in data-parallel training.
52
+ class ParameterStats(NamedTuple):
53
+ """State associated to each parameter of the model being trained."""
54
+ diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
55
+ statistics: chex.Array # Statistics
56
+ preconditioners: chex.Array # Preconditioners
57
+ diagonal_momentum: chex.Array # Momentum for the diagonal preconditioner
58
+ momentum: chex.Array # Momentum for the shampoo preconditioner
59
+
60
+
61
+ # For training extremely large models, we keep a global state with concatenated
62
+ # statistics and preconditioner states for all vars. This is so that we can
63
+ # annotate the leading axis to be sharded to save memory at the cost of
64
+ # communication.
65
+ @struct.dataclass
66
+ class GlobalShardedParameterStats:
67
+ statistics: chex.Array # Statistics
68
+ preconditioners: chex.Array # Preconditioners
69
+
70
+
71
+ # These are per-parameter local states; all statistics here mirror the parameter.
72
+ # Thus the sharding is copied over from the param specification.
73
+ @struct.dataclass
74
+ class LocalShardedParameterStats:
75
+ """State associated to each parameter of the model being trained."""
76
+ diagonal_statistics: chex.Array # Accumulator for diagonal preconditioner
77
+ diagonal_momentum: chex.Array # Momentum for the diagonal preconditioner
78
+ momentum: chex.Array # Momentum for the shampoo preconditioner
79
+ index_start: np.int32 = struct.field(
80
+ pytree_node=False) # Index into global statistics array
81
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
82
+
83
+
84
+ class ShardedShampooStats(NamedTuple):
85
+ """Shampoo state in sharded mode."""
86
+ global_stats: Any
87
+ local_stats: Any
88
+
89
+
90
+ class ShampooState(NamedTuple):
91
+ count: chex.Array
92
+ stats: Any
93
+
94
+
95
+ class GraftingType(enum.IntEnum):
96
+ SGD = 1
97
+ ADAGRAD = 2
98
+ RMSPROP = 3
99
+ RMSPROP_NORMALIZED = 4
100
+
101
+
102
+ def power_iteration(
103
+ matrix,
104
+ num_iters=100,
105
+ error_tolerance=1e-6,
106
+ precision=lax.Precision.HIGHEST):
107
+ r"""Power iteration algorithm.
108
+
109
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
110
+ a scalar `\lambda`, which is the greatest (in absolute value) eigenvalue
111
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
112
+
113
+ References:
114
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
115
+
116
+ Args:
117
+ matrix: the symmetric PSD matrix.
118
+ num_iters: Number of iterations.
119
+ error_tolerance: Iterative exit condition.
120
+ precision: precision XLA related flag, the available options are:
121
+ a) lax.Precision.DEFAULT (better step time, but not precise)
122
+ b) lax.Precision.HIGH (increased precision, slower)
123
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
124
+
125
+ Returns:
126
+ eigenvector, eigenvalue
127
+ """
128
+ matrix_size = matrix.shape[-1]
129
+ def _iter_condition(state):
130
+ i, unused_v, unused_s, unused_s_v, run_step = state
131
+ return jnp.logical_and(i < num_iters, run_step)
132
+
133
+ def _iter_body(state):
134
+ """One step of power iteration."""
135
+ i, new_v, s, s_v, unused_run_step = state
136
+ new_v = new_v / jnp.linalg.norm(new_v)
137
+
138
+ s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
139
+ s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
140
+ return (i + 1, s_v, s_new, s_v,
141
+ jnp.greater(jnp.abs(s_new - s), error_tolerance))
142
+
143
+ # Figure out how to use step as seed for random.
144
+ v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
145
+
146
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
147
+ _, v_out, s_out, _, _ = lax.while_loop(
148
+ _iter_condition, _iter_body, init_state)
149
+ v_out = v_out / jnp.linalg.norm(v_out)
150
+ return v_out, s_out
151
+
152
+
153
+ def matrix_inverse_pth_root(
154
+ matrix,
155
+ p,
156
+ num_iters=100,
157
+ ridge_epsilon=1e-6,
158
+ error_tolerance=1e-6,
159
+ precision=lax.Precision.HIGHEST):
160
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
161
+
162
+ This function uses the Coupled newton iterations algorithm for
163
+ the computation of a matrix's inverse pth root.
164
+
165
+
166
+ References:
167
+ [Functions of Matrices, Theory and Computation,
168
+ Nicholas J Higham, Pg 184, Eq 7.18](
169
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
170
+
171
+ Args:
172
+ matrix: the symmetric PSD matrix whose power is to be computed.
173
+ p: exponent, for p a positive integer.
174
+ num_iters: Maximum number of iterations.
175
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
176
+ error_tolerance: Error indicator, useful for early termination.
177
+ precision: precision XLA related flag, the available options are:
178
+ a) lax.Precision.DEFAULT (better step time, but not precise)
179
+ b) lax.Precision.HIGH (increased precision, slower)
180
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
181
+
182
+ Returns:
183
+ matrix^(-1/p)
184
+ """
185
+
186
+ # We use float32 for the matrix inverse pth root.
187
+ # Switch to f64 if you have hardware that supports it.
188
+ matrix_size = matrix.shape[0]
189
+ alpha = jnp.asarray(-1.0 / p, jnp.float32)
190
+ identity = jnp.eye(matrix_size, dtype=jnp.float32)
191
+ _, max_ev = power_iteration(
192
+ matrix=matrix, num_iters=100,
193
+ error_tolerance=1e-6, precision=precision)
194
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
195
+
196
+ def _unrolled_mat_pow_1(mat_m):
197
+ """Computes mat_m^1."""
198
+ return mat_m
199
+
200
+ def _unrolled_mat_pow_2(mat_m):
201
+ """Computes mat_m^2."""
202
+ return jnp.matmul(mat_m, mat_m, precision=precision)
203
+
204
+ def _unrolled_mat_pow_4(mat_m):
205
+ """Computes mat_m^4."""
206
+ mat_pow_2 = _unrolled_mat_pow_2(mat_m)
207
+ return jnp.matmul(
208
+ mat_pow_2, mat_pow_2, precision=precision)
209
+
210
+ def _unrolled_mat_pow_8(mat_m):
211
+ """Computes mat_m^4."""
212
+ mat_pow_4 = _unrolled_mat_pow_4(mat_m)
213
+ return jnp.matmul(
214
+ mat_pow_4, mat_pow_4, precision=precision)
215
+
216
+ def mat_power(mat_m, p):
217
+ """Computes mat_m^p, for p == 1, 2, 4 or 8.
218
+
219
+ Args:
220
+ mat_m: a square matrix
221
+ p: a positive integer
222
+
223
+ Returns:
224
+ mat_m^p
225
+ """
226
+ # We unrolled the loop for performance reasons.
227
+ exponent = jnp.round(jnp.log2(p))
228
+ return lax.switch(
229
+ jnp.asarray(exponent, jnp.int32), [
230
+ _unrolled_mat_pow_1,
231
+ _unrolled_mat_pow_2,
232
+ _unrolled_mat_pow_4,
233
+ _unrolled_mat_pow_8,
234
+ ], (mat_m))
235
+
236
+ def _iter_condition(state):
237
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
238
+ run_step) = state
239
+ error_above_threshold = jnp.logical_and(
240
+ error > error_tolerance, run_step)
241
+ return jnp.logical_and(i < num_iters, error_above_threshold)
242
+
243
+ def _iter_body(state):
244
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
245
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
246
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
247
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
248
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
249
+ # sometimes error increases after an iteration before decreasing and
250
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
251
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
252
+ new_error < error * 1.2)
253
+
254
+ if matrix_size == 1:
255
+ resultant_mat_h = (matrix + ridge_epsilon)**alpha
256
+ error = 0
257
+ else:
258
+ damped_matrix = matrix + ridge_epsilon * identity
259
+
260
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
261
+ new_mat_m_0 = damped_matrix * z
262
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
263
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
264
+ init_state = tuple(
265
+ [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
266
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
267
+ _iter_condition, _iter_body, init_state)
268
+ error = jnp.max(jnp.abs(mat_m - identity))
269
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
270
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
271
+ resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
272
+ return resultant_mat_h, error
273
+
274
+
275
+ def merge_small_dims(shape_to_merge, max_dim):
276
+ """Merge small dimensions.
277
+
278
+ If there are some small dimensions, we collapse them:
279
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
280
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
281
+
282
+ Args:
283
+ shape_to_merge: Shape to merge small dimensions.
284
+ max_dim: Maximal dimension of output shape used in merging.
285
+
286
+ Returns:
287
+ Merged shape.
288
+ """
289
+ resulting_shape = []
290
+ product = 1
291
+ for d in shape_to_merge:
292
+ if product * d <= max_dim:
293
+ product *= d
294
+ else:
295
+ if product > 1:
296
+ resulting_shape.append(product)
297
+ product = d
298
+ if product > 1:
299
+ resulting_shape.append(product)
300
+ return resulting_shape
301
+
302
+
303
+ def pad_matrix(mat, max_size):
304
+ """Pad a matrix to a max_size.
305
+
306
+ Args:
307
+ mat: a matrix to pad.
308
+ max_size: matrix size requested.
309
+
310
+ Returns:
311
+ Given M returns [[M, 0], [0, I]]
312
+ """
313
+ size = mat.shape[0]
314
+ assert size <= max_size
315
+ if size == max_size:
316
+ return mat
317
+ pad_size = max_size - size
318
+ zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
319
+ zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
320
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
321
+ mat = jnp.concatenate([mat, zs1], 1)
322
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
323
+ return mat
324
+
325
+
326
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
327
+ """Avoids wasteful buffer allocation with XLA."""
328
+
329
+ def _iter_body(unused_state):
330
+ results = compute_fn(*args, **kwargs)
331
+ return tuple([False] + list(results))
332
+
333
+ def _iter_condition(state):
334
+ return state[0]
335
+
336
+ results = jax.lax.while_loop(_iter_condition, _iter_body,
337
+ tuple([predicate] + init_state))
338
+ return tuple(results[1:])
339
+
340
+
341
+ class BlockPartitioner:
342
+ """Partitions a tensor into smaller tensors."""
343
+
344
+ def __init__(self, param, block_size):
345
+ self._shape = param.shape
346
+ self._splits = []
347
+ split_sizes = []
348
+ # We split params into smaller blocks. Here we store the metadata to make
349
+ # that split.
350
+ for i, d in enumerate(param.shape):
351
+ if 0 < block_size < d:
352
+ # d-1, otherwise split appends a 0-size array.
353
+ nsplit = (d - 1) // block_size
354
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
355
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
356
+ sizes[-1] = d - indices[-1]
357
+ self._splits.append((i, indices))
358
+ split_sizes.append(sizes)
359
+ else:
360
+ split_sizes.append(np.array([d], dtype=np.int32))
361
+ self._num_splits = len(split_sizes)
362
+ self._preconditioner_shapes = []
363
+ for t in itertools.product(*split_sizes):
364
+ self._preconditioner_shapes.extend([[d, d] for d in t])
365
+
366
+ def shapes_for_preconditioners(self):
367
+ return self._preconditioner_shapes
368
+
369
+ def num_splits(self):
370
+ return self._num_splits
371
+
372
+ def partition(self, tensor):
373
+ """Partition tensor into blocks."""
374
+
375
+ assert tensor.shape == self._shape
376
+ tensors = [tensor]
377
+ for (i, indices) in self._splits:
378
+ tensors_local = []
379
+ for t in tensors:
380
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
381
+ tensors = tensors_local
382
+ return tensors
383
+
384
+ def merge_partitions(self, partitions):
385
+ """Merge partitions back to original shape."""
386
+
387
+ for (i, indices) in reversed(self._splits):
388
+ n = len(indices) + 1
389
+ partial_merged_tensors = []
390
+ ind = 0
391
+ while ind < len(partitions):
392
+ partial_merged_tensors.append(
393
+ jnp.concatenate(partitions[ind:ind + n], axis=i))
394
+ ind += n
395
+ partitions = partial_merged_tensors
396
+ assert len(partitions) == 1
397
+ return partitions[0]
398
+
399
+
400
+ class Preconditioner:
401
+ """Compute statistics/shape from gradients for preconditioning."""
402
+
403
+ def __init__(self, param, block_size, best_effort_shape_interpretation):
404
+ self._original_shape = param.shape
405
+ self._transformed_shape = param.shape
406
+ if best_effort_shape_interpretation:
407
+ self._transformed_shape = merge_small_dims(self._original_shape,
408
+ block_size)
409
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
410
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
411
+
412
+ def statistics_from_grad(self, grad):
413
+ """Compute statistics from gradients.
414
+
415
+ Args:
416
+ grad: Gradient to compute statistics from.
417
+
418
+ Returns:
419
+ A list of gradient statistics for each partition.
420
+ """
421
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
422
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
423
+ stats = []
424
+ for g in partitioned_grads:
425
+ g_stats = []
426
+ rank = len(g.shape)
427
+ for i in range(rank):
428
+ axes = list(range(i)) + list(range(i + 1, rank))
429
+ stat = jnp.tensordot(g, g, axes=(axes, axes))
430
+ g_stats.append(stat)
431
+ stats.extend(g_stats)
432
+ return stats
433
+
434
+ def shapes_for_preconditioners(self):
435
+ """Returns shape from statistics."""
436
+ return self._partitioner.shapes_for_preconditioners()
437
+
438
+ def exponent_for_preconditioner(self):
439
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
440
+ return 2 * len(self._transformed_shape)
441
+
442
+ def preconditioned_grad(self, grad, preconditioners):
443
+ """Precondition the gradient.
444
+
445
+ Args:
446
+ grad: A gradient tensor to precondition.
447
+ preconditioners: A list of preconditioners to apply.
448
+
449
+ Returns:
450
+ A preconditioned gradient.
451
+ """
452
+
453
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
454
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
455
+ preconditioned_partitioned_grads = []
456
+ num_splits = self._partitioner.num_splits()
457
+ for i, g in enumerate(partitioned_grads):
458
+ preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
459
+ num_splits]
460
+ rank = len(g.shape)
461
+ precond_g = g
462
+ for j in range(rank):
463
+ precond_g = jnp.tensordot(
464
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]])
465
+ preconditioned_partitioned_grads.append(precond_g)
466
+ merged_grad = self._partitioner.merge_partitions(
467
+ preconditioned_partitioned_grads)
468
+ return jnp.reshape(merged_grad, self._original_shape)
469
+
470
+
471
+ def _convert_to_parameter_stats(global_stats, local_stat):
472
+ """Creates parameter stats from sharded stats."""
473
+ index_start = int(local_stat.index_start)
474
+ index_end = int(len(local_stat.sizes)) + index_start
475
+ statistics = global_stats.statistics[index_start:index_end, :, :]
476
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
477
+ new_statistics = []
478
+ new_preconditioners = []
479
+ for i, size in enumerate(local_stat.sizes):
480
+ new_statistics.append(statistics[i][:size, :size])
481
+ new_preconditioners.append(preconditioners[i][:size, :size])
482
+ return ParameterStats(local_stat.diagonal_statistics, new_statistics,
483
+ new_preconditioners, local_stat.diagonal_momentum,
484
+ local_stat.momentum)
485
+
486
+
487
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
488
+ """Creates sharded stats from paramter stats."""
489
+ return LocalShardedParameterStats(parameter_stats.diagonal_statistics,
490
+ parameter_stats.diagonal_momentum,
491
+ parameter_stats.momentum,
492
+ local_stats.index_start, local_stats.sizes)
493
+
494
+
495
+ def distributed_shampoo(learning_rate,
496
+ block_size,
497
+ beta1=0.9,
498
+ beta2=0.999,
499
+ diagonal_epsilon=1e-10,
500
+ matrix_epsilon=1e-6,
501
+ weight_decay=0.0,
502
+ start_preconditioning_step=5,
503
+ preconditioning_compute_steps=1,
504
+ statistics_compute_steps=1,
505
+ best_effort_shape_interpretation=True,
506
+ graft_type=GraftingType.SGD,
507
+ nesterov=True,
508
+ exponent_override=0,
509
+ # Pass pmap 'batch axis name' in pmap mode.
510
+ batch_axis_name=None,
511
+ ### Only set following 3 params in pjit/spmd mode.
512
+ ### WARNING: Experimental
513
+ mesh_axis_names=None,
514
+ num_devices_for_pjit=None,
515
+ shard_optimizer_states=False,
516
+ ###
517
+ inverse_failure_threshold=0.1,
518
+ moving_average_for_momentum=False,
519
+ skip_preconditioning_dim_size_gt=4096,
520
+ clip_by_scaled_gradient_norm=None,
521
+ precision=lax.Precision.HIGHEST):
522
+ """Distributed Shampoo optimizer.
523
+
524
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
525
+ variant of full-matrix Adagrad), that provides significant convergence and
526
+ wall-clock time improvements compared to conventional first-order methods,
527
+ and that has been shown to scale to large state-of-the-art deep learning
528
+ models.
529
+
530
+ References:
531
+ Scalable Second Order Optimization for Deep Learning,
532
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
533
+
534
+ Preprint: https://arxiv.org/abs/2002.09018
535
+
536
+ Args:
537
+ learning_rate: the step size used to update the parameters.
538
+ block_size: Block size for large layers (if > 0). Preconditioning compute
539
+ operation is cubic in the dimension of the tensor. Block size allows us to
540
+ chunk the layers into sub-layers of maximal dimension dictated by this
541
+ value. Use 128 as default (increase if you have compute budget).
542
+ beta1: momentum parameter.
543
+ beta2: second moment averaging parameter.
544
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
545
+ to AdaGrad is enabled).
546
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
547
+ root. If you are running in f32 precision for inverse pth root
548
+ (recommended today) this can go up to 1e-6. If you have the latest hardware
549
+ with native f64 precision, set this up to 1e-12.
550
+ weight_decay: Weight decay for regularization.
551
+ start_preconditioning_step: When to start Shampoo update before which
552
+ diagonal update is used. This is because we don't have enough information
553
+ to do stable inverse.
554
+ preconditioning_compute_steps: How often to compute preconditioner.
555
+ Performance tuning params for controlling memory and compute requirements.
556
+ Ideally set this and statistics_compute_steps params to 1.
557
+ statistics_compute_steps: How often to compute statistics.
558
+ best_effort_shape_interpretation: If there are some small dimensions,
559
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
560
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
561
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
562
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
563
+ where SGD/AdaGrad is already well tuned. Available options are:
564
+ GraftingType.SGD and GraftingType.ADAGRAD.
565
+ nesterov: Nesterov momentum.
566
+ exponent_override: Override the exponent used in matrix inverse.
567
+ batch_axis_name: labeled axis over pmap for data-parallel training that the
568
+ optimizer is used for.
569
+ mesh_axis_names: Axis names for the mesh (used in pjit).
570
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
571
+ shard_optimizer_states: Shard optimizer states to save memory in model
572
+ parallel training.
573
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
574
+ determine that using this threshold.
575
+ moving_average_for_momentum: Whether to use moving average for momentum
576
+ instead of exponential moving average.
577
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
578
+ greater than this value.
579
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
580
+ when using RMSProp Grafting).
581
+ precision: precision XLA related flag, the available options are: a)
582
+ lax.Precision.DEFAULT (better step time, but not precise) b)
583
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
584
+ (best possible precision, slowest)
585
+
586
+ Returns:
587
+ a GradientTransformation.
588
+ """
589
+
590
+ def sharded_init_fn(params):
591
+ params_flat, treedef = jax.tree_flatten(params)
592
+ # Find max size to pad to.
593
+ max_size = 0
594
+ for param in params_flat:
595
+ preconditioner = Preconditioner(param, block_size,
596
+ best_effort_shape_interpretation)
597
+ if not _skip_preconditioning(param):
598
+ shapes = preconditioner.shapes_for_preconditioners()
599
+ sizes = [s[0] for s in shapes]
600
+ max_size = max(max(sizes), max_size)
601
+
602
+ padded_statistics = []
603
+ padded_preconditioners = []
604
+ local_stats_flat = []
605
+ for param in params_flat:
606
+ preconditioner = Preconditioner(param, block_size,
607
+ best_effort_shape_interpretation)
608
+ shapes = preconditioner.shapes_for_preconditioners()
609
+ sizes = []
610
+
611
+ statistics = []
612
+ preconditioners = []
613
+ index_start = len(padded_statistics)
614
+ if not _skip_preconditioning(param):
615
+ sizes = [s[0] for s in shapes]
616
+ shapes = preconditioner.shapes_for_preconditioners()
617
+ statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
618
+ preconditioners = [jnp.eye(max_size) for s in shapes]
619
+ padded_statistics.extend(statistics)
620
+ padded_preconditioners.extend(preconditioners)
621
+
622
+ adagrad_statistics = []
623
+ if graft_type != GraftingType.SGD:
624
+ adagrad_statistics = jnp.zeros_like(param)
625
+ local_stats_flat.append(
626
+ LocalShardedParameterStats(adagrad_statistics, jnp.zeros_like(param),
627
+ jnp.zeros_like(param), index_start, sizes))
628
+
629
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
630
+ # Pad the statistics and preconditioner matrices to be a multiple of
631
+ # num devices.
632
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
633
+ # is split on.
634
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
635
+ padded_statistics.extend([
636
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
637
+ for _ in range(to_pad)
638
+ ])
639
+ padded_preconditioners.extend([
640
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
641
+ for _ in range(to_pad)
642
+ ])
643
+ global_stats = GlobalShardedParameterStats(
644
+ jnp.stack(padded_statistics), jnp.stack(padded_preconditioners))
645
+ return ShampooState(
646
+ count=jnp.zeros([], jnp.int32),
647
+ stats=ShardedShampooStats(global_stats, local_stats))
648
+
649
+ def sharded_update_fn(grads, state, params):
650
+ """Transform the input gradient and update all statistics in sharded mode.
651
+
652
+ Args:
653
+ grads: the gradient tensors for the parameters.
654
+ state: a named tuple containing the state of the optimizer
655
+ params: the parameters that should be updated.
656
+
657
+ Returns:
658
+ A tuple containing the new parameters and the new optimizer state.
659
+ """
660
+ params_flat, treedef = jax.tree_flatten(params)
661
+ grads_flat = treedef.flatten_up_to(grads)
662
+
663
+ global_stats = state.stats.global_stats
664
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
665
+ stats_flat = [
666
+ _convert_to_parameter_stats(global_stats, local_stat)
667
+ for local_stat in local_stats_flat
668
+ ]
669
+ new_stats_flat = jax.tree_multimap(
670
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
671
+ stats_flat, params_flat)
672
+
673
+ exponents = []
674
+ for stat, param in zip(new_stats_flat, params_flat):
675
+ num_statistics = len(stat.statistics)
676
+ if num_statistics > 0:
677
+ preconditioner = Preconditioner(param, block_size,
678
+ best_effort_shape_interpretation)
679
+ exponent = (
680
+ preconditioner.exponent_for_preconditioner()
681
+ if exponent_override == 0 else exponent_override)
682
+ exponents.extend([exponent] * num_statistics)
683
+
684
+ outputs = jax.tree_multimap(
685
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
686
+ new_stats_flat, params_flat)
687
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
688
+
689
+ updates = jax.tree_unflatten(treedef, updates_flat)
690
+ # Create new local_stats
691
+ new_local_stats_flat = [
692
+ _convert_from_parameter_stats(new_stat, local_stat)
693
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
694
+ ]
695
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
696
+
697
+ max_size = global_stats.statistics.shape[1]
698
+ new_padded_statistics = []
699
+ for stat in new_stats_flat:
700
+ new_padded_statistics.extend(
701
+ [pad_matrix(stat, max_size) for stat in stat.statistics])
702
+
703
+ # Create global stats
704
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
705
+ # stack/pad can be obviated away.
706
+ # Pad the statistics and preconditioner matrices to be a multiple of
707
+ # num devices.
708
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
709
+ # is split on.
710
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
711
+ new_padded_statistics.extend([
712
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
713
+ for _ in range(to_pad)
714
+ ])
715
+ exponents.extend([1 for _ in range(to_pad)])
716
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
717
+ new_stacked_exponents = jnp.stack(exponents)
718
+ def _matrix_inverse_pth_root_vmap(xs, ps):
719
+ mi_pth_root = functools.partial(
720
+ matrix_inverse_pth_root,
721
+ ridge_epsilon=matrix_epsilon,
722
+ precision=precision)
723
+ preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
724
+ return preconditioners, errors
725
+
726
+ def _internal_inverse_pth_root_all():
727
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
728
+ new_stacked_padded_statistics, new_stacked_exponents)
729
+ return preconditioners, errors
730
+
731
+ if preconditioning_compute_steps == 1:
732
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
733
+ else:
734
+ # Passing statistics instead of preconditioners as they are similarly
735
+ # shaped tensors. Note statistics will be ignored as we are passing in
736
+ # a large init value for error.
737
+ preconditioners_init = new_stacked_padded_statistics
738
+ errors_init = np.stack([inverse_failure_threshold] * len(exponents))
739
+ init_state = [preconditioners_init, errors_init]
740
+ perform_step = state.count % preconditioning_compute_steps == 0
741
+ new_preconditioners, errors = efficient_cond(
742
+ perform_step, _internal_inverse_pth_root_all, init_state)
743
+
744
+ errors = errors.reshape((-1, 1, 1))
745
+ predicate = jnp.logical_or(
746
+ jnp.isnan(errors),
747
+ errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
748
+ # TODO(rohananil): Check for numerical instabilities.
749
+ new_conditional_preconditioners = (
750
+ predicate * global_stats.preconditioners +
751
+ (1.0 - predicate) * new_preconditioners)
752
+ new_global_stats = GlobalShardedParameterStats(
753
+ new_stacked_padded_statistics, new_conditional_preconditioners)
754
+ new_shampoo_state = ShampooState(
755
+ count=state.count + 1,
756
+ stats=ShardedShampooStats(new_global_stats, new_local_stats))
757
+ return updates, new_shampoo_state
758
+
759
+ def init_fn(params):
760
+ """Initialise the optimiser's state."""
761
+
762
+ def _init(param):
763
+ preconditioner = Preconditioner(param, block_size,
764
+ best_effort_shape_interpretation)
765
+ statistics = []
766
+ preconditioners = []
767
+ if not _skip_preconditioning(param):
768
+ shapes = preconditioner.shapes_for_preconditioners()
769
+ statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
770
+ preconditioners = [jnp.eye(s[0]) for s in shapes]
771
+
772
+ adagrad_statistics = []
773
+ if graft_type != GraftingType.SGD:
774
+ adagrad_statistics = jnp.zeros_like(param)
775
+ return ParameterStats(adagrad_statistics, statistics, preconditioners,
776
+ jnp.zeros_like(param), jnp.zeros_like(param))
777
+
778
+ return ShampooState(
779
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
780
+
781
+ def _skip_preconditioning(param):
782
+ return len(param.shape) < 1 or any(
783
+ [s > skip_preconditioning_dim_size_gt for s in param.shape])
784
+
785
+ def _compute_stats(grad, state, param, step):
786
+ """Compute per-parameter statistics."""
787
+ preconditioner = Preconditioner(param, block_size,
788
+ best_effort_shape_interpretation)
789
+ new_statistics = [[]] * len(state.statistics)
790
+ w1 = beta2
791
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
792
+ if not _skip_preconditioning(param):
793
+
794
+ def compute_updated_statistics():
795
+ new_stats = preconditioner.statistics_from_grad(grad)
796
+ new_stats_accumulators = []
797
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
798
+ new_stats_accumulators.append(w1 * stat_accumulator + w2 * stat)
799
+ return new_stats_accumulators
800
+
801
+ if statistics_compute_steps > 1:
802
+ perform_step = step % statistics_compute_steps == 0
803
+ init_state = state.statistics
804
+ new_statistics = list(
805
+ efficient_cond(perform_step, compute_updated_statistics,
806
+ init_state))
807
+ else:
808
+ new_statistics = compute_updated_statistics()
809
+ return ParameterStats(state.diagonal_statistics, new_statistics,
810
+ state.preconditioners, state.diagonal_momentum,
811
+ state.momentum)
812
+
813
+ def _compute_preconditioners(states, params, step):
814
+ """Compute preconditioners for statistics."""
815
+ statistics = []
816
+ num_statistics_per_state = []
817
+ original_shapes = []
818
+ exponents = []
819
+ max_size = 0
820
+ prev_preconditioners = []
821
+ for state, param in zip(states, params):
822
+ num_statistics = len(state.statistics)
823
+ num_statistics_per_state.append(num_statistics)
824
+ original_shapes_for_state = []
825
+ if num_statistics > 0:
826
+ preconditioner = Preconditioner(param, block_size,
827
+ best_effort_shape_interpretation)
828
+ for statistic in state.statistics:
829
+ exponents.append(preconditioner.exponent_for_preconditioner(
830
+ ) if exponent_override == 0 else exponent_override)
831
+ original_shapes_for_state.append(statistic.shape)
832
+ max_size = max(max_size, statistic.shape[0])
833
+ statistics.extend(state.statistics)
834
+ prev_preconditioners.extend(state.preconditioners)
835
+ original_shapes.extend(original_shapes_for_state)
836
+ num_statistics = len(statistics)
837
+
838
+ if batch_axis_name:
839
+ num_devices = lax.psum(1, batch_axis_name)
840
+
841
+ # Pad statistics and exponents to next multiple of num_devices.
842
+ packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
843
+ to_pad = -num_statistics % num_devices
844
+ packed_statistics.extend([
845
+ jnp.eye(max_size, dtype=packed_statistics[0].dtype)
846
+ for _ in range(to_pad)
847
+ ])
848
+ exponents.extend([1 for _ in range(to_pad)])
849
+
850
+ if not packed_statistics:
851
+ return states
852
+ # Batch statistics and exponents so that the leading axis is
853
+ # num_devices.
854
+ def _batch(statistics, exponents, num_devices):
855
+ assert len(statistics) == len(exponents)
856
+ n = len(statistics)
857
+ b = int(n / num_devices)
858
+ batched_statistics = [
859
+ jnp.stack(statistics[idx:idx + b]) for idx in range(0, n, b)
860
+ ]
861
+ batched_exponents = [
862
+ jnp.stack(exponents[idx:idx + b]) for idx in range(0, n, b)
863
+ ]
864
+ return jnp.stack(batched_statistics), jnp.stack(batched_exponents)
865
+
866
+ # Unbatch values across leading axis and return a list of elements.
867
+ def _unbatch(batched_values):
868
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
869
+ results = []
870
+ for v_array in jnp.split(
871
+ batched_values, indices_or_sections=b1, axis=0):
872
+ v_array = jnp.squeeze(v_array)
873
+ # b2 = batches (number of preconditioner computation) per core.
874
+ if b2 > 1:
875
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
876
+ results.append(jnp.squeeze(v))
877
+ else:
878
+ results.append(v_array)
879
+ return results
880
+
881
+ all_statistics, all_exponents = _batch(packed_statistics, exponents,
882
+ num_devices)
883
+ else:
884
+ to_pad = -num_statistics % num_devices_for_pjit
885
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
886
+ padded_statistics.extend([
887
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
888
+ for _ in range(to_pad)
889
+ ])
890
+ exponents.extend([1 for _ in range(to_pad)])
891
+ all_statistics = jnp.stack(padded_statistics)
892
+ all_exponents = jnp.stack(exponents)
893
+
894
+ def _matrix_inverse_pth_root_vmap(xs, ps):
895
+ mi_pth_root = functools.partial(
896
+ matrix_inverse_pth_root,
897
+ ridge_epsilon=matrix_epsilon,
898
+ precision=precision)
899
+ preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
900
+ return preconditioners, errors
901
+
902
+ def _matrix_inverse_pth_root_pjit(xs, ps):
903
+ mesh_axis_names_tuple = tuple(mesh_axis_names)
904
+ # Partition the concatenated statistics matrix across all cores.
905
+ partitioned_xs, partitioned_ps = pjit.pjit(
906
+ lambda x, y: (x, y),
907
+ in_axis_resources=None,
908
+ out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
909
+ # Run matrix inverse pth root on each shard.
910
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
911
+ partitioned_xs, partitioned_ps)
912
+ # Recombine the outputs at each core.
913
+ preconditioners, errors = pjit.pjit(
914
+ lambda x, y: (x, y),
915
+ in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
916
+ pjit.PartitionSpec(mesh_axis_names_tuple,)),
917
+ out_axis_resources=(None, None))(partitioned_preconditioners,
918
+ partitioned_errors)
919
+ return preconditioners, errors
920
+
921
+ if not batch_axis_name:
922
+ def _internal_inverse_pth_root_all():
923
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
924
+ all_statistics, all_exponents)
925
+ b1 = preconditioners.shape[0]
926
+ def split(batched_values):
927
+ return [
928
+ jnp.squeeze(v) for v in jnp.split(
929
+ batched_values, indices_or_sections=b1, axis=0)
930
+ ]
931
+
932
+ return split(preconditioners), split(errors)
933
+
934
+ if preconditioning_compute_steps == 1:
935
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
936
+ else:
937
+ # Passing statistics instead of preconditioners as they are similarly
938
+ # shaped tensors. Note statistics will be ignored as we are passing in
939
+ # a large init value for error.
940
+ preconditioners_init = padded_statistics
941
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
942
+ init_state = [preconditioners_init, errors_init]
943
+ perform_step = step % preconditioning_compute_steps == 0
944
+ preconditioners_flat, errors_flat = efficient_cond(
945
+ perform_step, _internal_inverse_pth_root_all, init_state)
946
+ else:
947
+
948
+ def _internal_inverse_pth_root_all():
949
+ preconditioners = jnp.array(all_statistics)
950
+ current_replica = lax.axis_index(batch_axis_name)
951
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
952
+ all_statistics[current_replica], all_exponents[current_replica])
953
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
954
+ errors = jax.lax.all_gather(errors, batch_axis_name)
955
+ preconditioners_flat = _unbatch(preconditioners)
956
+ errors_flat = _unbatch(errors)
957
+ return preconditioners_flat, errors_flat
958
+
959
+ if preconditioning_compute_steps == 1:
960
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
961
+ else:
962
+ # Passing statistics instead of preconditioners as they are similarly
963
+ # shaped tensors. Note statistics will be ignored as we are passing in
964
+ # a large init value for error.
965
+ preconditioners_init = packed_statistics
966
+ errors_init = ([inverse_failure_threshold] * len(packed_statistics))
967
+ init_state = [preconditioners_init, errors_init]
968
+ perform_step = step % preconditioning_compute_steps == 0
969
+ preconditioners_flat, errors_flat = efficient_cond(
970
+ perform_step, _internal_inverse_pth_root_all, init_state)
971
+
972
+ def _skip(error):
973
+ condition = jnp.logical_or(
974
+ jnp.isnan(error), error >= inverse_failure_threshold)
975
+ return condition.astype(error.dtype)
976
+
977
+ def _select_preconditioner(error, new_p, old_p):
978
+ return lax.cond(
979
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
980
+
981
+ new_preconditioners_flat = []
982
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
983
+ prev_preconditioners, errors_flat):
984
+ new_preconditioners_flat.append(
985
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
986
+
987
+ assert len(states) == len(num_statistics_per_state)
988
+ assert len(new_preconditioners_flat) == num_statistics
989
+
990
+ # Add back empty preconditioners so that we can set the optimizer state.
991
+ preconditioners_for_states = []
992
+ idx = 0
993
+ for num_statistics, state in zip(num_statistics_per_state, states):
994
+ if num_statistics == 0:
995
+ preconditioners_for_states.append([])
996
+ else:
997
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
998
+ num_statistics]
999
+ assert len(state.statistics) == len(preconditioners_for_state)
1000
+ preconditioners_for_states.append(preconditioners_for_state)
1001
+ idx += num_statistics
1002
+ new_states = []
1003
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1004
+ new_states.append(
1005
+ ParameterStats(state.diagonal_statistics, state.statistics,
1006
+ new_preconditioners, state.diagonal_momentum,
1007
+ state.momentum))
1008
+
1009
+ return new_states
1010
+
1011
+ def _transform_grad(grad, state, param, step):
1012
+ """Transform per-parameter gradients."""
1013
+ preconditioner = Preconditioner(param, block_size,
1014
+ best_effort_shape_interpretation)
1015
+ sgd_update = grad
1016
+ new_diagonal_statistics = state.diagonal_statistics
1017
+ if graft_type == GraftingType.ADAGRAD:
1018
+ new_diagonal_statistics = state.diagonal_statistics + jnp.square(grad)
1019
+ adagrad_update = grad / (
1020
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1021
+ grafting_update = adagrad_update
1022
+ elif (graft_type == GraftingType.RMSPROP or
1023
+ graft_type == GraftingType.RMSPROP_NORMALIZED):
1024
+
1025
+ scaled_grad = grad
1026
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
1027
+ scaled_grad = grad / jnp.linalg.norm(grad)
1028
+
1029
+ w1 = beta2
1030
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1031
+
1032
+ new_diagonal_statistics = (
1033
+ w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad))
1034
+ rmsprop_update = scaled_grad / (
1035
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1036
+
1037
+ if clip_by_scaled_gradient_norm:
1038
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
1039
+ jnp.sqrt(float(rmsprop_update.size)))
1040
+ clipping_denom = jnp.maximum(
1041
+ 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
1042
+ rmsprop_update /= clipping_denom
1043
+
1044
+ grafting_update = rmsprop_update
1045
+ else:
1046
+ grafting_update = sgd_update
1047
+
1048
+ precond_grad = grad
1049
+ if not _skip_preconditioning(param):
1050
+ precond_grad = preconditioner.preconditioned_grad(precond_grad,
1051
+ state.preconditioners)
1052
+ else:
1053
+ precond_grad = grafting_update
1054
+
1055
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
1056
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
1057
+
1058
+ multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
1059
+ shampoo_update = precond_grad * multiplier
1060
+
1061
+ shampoo_update_with_wd = shampoo_update
1062
+ grafting_update_with_wd = grafting_update
1063
+ if weight_decay != 0:
1064
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
1065
+ grafting_update_with_wd = grafting_update + weight_decay * param
1066
+
1067
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1068
+ shampoo_update_with_wd_momentum = (
1069
+ state.momentum * beta1 + w * shampoo_update_with_wd)
1070
+ grafting_update_with_wd_momentum = (
1071
+ state.diagonal_momentum * beta1 + w * grafting_update_with_wd)
1072
+
1073
+ run_shampoo = (step >= start_preconditioning_step).astype(
1074
+ grafting_update_with_wd_momentum.dtype)
1075
+
1076
+ momentum_update = (
1077
+ run_shampoo * shampoo_update_with_wd_momentum +
1078
+ (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
1079
+
1080
+ wd_update = (
1081
+ run_shampoo * shampoo_update_with_wd +
1082
+ (1.0 - run_shampoo) * grafting_update_with_wd)
1083
+
1084
+ if nesterov:
1085
+ momentum_update = w * wd_update + beta1 * momentum_update
1086
+
1087
+ lr = learning_rate
1088
+ if callable(learning_rate):
1089
+ lr = learning_rate(step)
1090
+ transformed_update = -1.0 * lr * momentum_update
1091
+
1092
+ param_stats = ParameterStats(new_diagonal_statistics, state.statistics,
1093
+ state.preconditioners,
1094
+ grafting_update_with_wd_momentum,
1095
+ shampoo_update_with_wd_momentum)
1096
+ return transformed_update, param_stats
1097
+
1098
+ def update_fn(grads, state, params):
1099
+ """Transform the input gradient and update all statistics.
1100
+
1101
+ Args:
1102
+ grads: the gradient tensors for the parameters.
1103
+ state: a named tuple containing the state of the optimizer
1104
+ params: the parameters that should be updated.
1105
+
1106
+ Returns:
1107
+ A tuple containing the new parameters and the new optimizer state.
1108
+ """
1109
+ params_flat, treedef = jax.tree_flatten(params)
1110
+ stats_flat = treedef.flatten_up_to(state.stats)
1111
+ grads_flat = treedef.flatten_up_to(grads)
1112
+
1113
+ new_stats_flat = jax.tree_multimap(
1114
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
1115
+ stats_flat, params_flat)
1116
+ new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
1117
+ state.count)
1118
+
1119
+ outputs = jax.tree_multimap(
1120
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
1121
+ new_stats_flat, params_flat)
1122
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1123
+
1124
+ updates = jax.tree_unflatten(treedef, updates_flat)
1125
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
1126
+
1127
+ new_state = ShampooState(
1128
+ count=state.count+1, stats=new_stats)
1129
+ return updates, new_state
1130
+
1131
+ if shard_optimizer_states:
1132
+ return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
1133
+ else:
1134
+ return optax.GradientTransformation(init_fn, update_fn)
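
For reference only (not part of the commit): a minimal sketch of how the returned optax.GradientTransformation can be exercised the way train.py does, i.e. inside a pmapped train step whose axis name matches batch_axis_name. Parameter names, shapes, the toy loss and the hyperparameters below are illustrative assumptions, not the project's configuration.

# Illustrative sketch: toy params and toy loss, data-parallel over local devices.
import jax
import jax.numpy as jnp
import optax

from distributed_shampoo import distributed_shampoo, GraftingType

optimizer = distributed_shampoo(
    learning_rate=1e-3,
    block_size=128,
    graft_type=GraftingType.RMSPROP_NORMALIZED,
    batch_axis_name="batch",  # must match the pmap axis_name below
)

params = {"w": jnp.ones((8, 4)), "b": jnp.zeros((4,))}
opt_state = optimizer.init(params)

def loss_fn(p, x):
    return jnp.mean(x @ p["w"] + p["b"])

def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, "batch")  # average gradients across devices
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

devices = jax.local_devices()
p_params = jax.device_put_replicated(params, devices)
p_opt_state = jax.device_put_replicated(opt_state, devices)
batch = jnp.ones((len(devices), 2, 8))  # one shard per device
p_params, p_opt_state = jax.pmap(train_step, axis_name="batch")(
    p_params, p_opt_state, batch
)
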
tools/train/train.py CHANGED
@@ -45,6 +45,8 @@ from transformers import AutoTokenizer, HfArgumentParser
from dalle_mini.data import Dataset
from dalle_mini.model import DalleBart, DalleBartConfig

+ from distributed_shampoo import distributed_shampoo, GraftingType
+
logger = logging.getLogger(__name__)


@@ -214,6 +216,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether or not to replace AdamW by Adafactor."},
)
+ shampoo: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to replace AdamW by Distributed Shampoo."},
+ )
weight_decay: float = field(
default=None, metadata={"help": "Weight decay if we apply some."}
)
@@ -560,6 +566,33 @@ def main():
weight_decay_mask=decay_mask_fn,
clipping_threshold=training_args.max_grad_norm,
)
+ elif training_args.shampoo:
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
+ # Notes:
+ # - mask for weight decay is not implemented so we don't use it
+ optimizer = distributed_shampoo(
+ learning_rate_fn,
+ block_size=1024, # recommended default for large LM is 1536
+ beta1=0.9,
+ beta2=0.999,
+ diagonal_epsilon=1e-10,
+ matrix_epsilon=1e-8,
+ weight_decay=0.0,
+ start_preconditioning_step=51,
+ preconditioning_compute_steps=50,
+ statistics_compute_steps=1,
+ best_effort_shape_interpretation=True,
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
+ nesterov=False,
+ exponent_override=0,
+ batch_axis_name="batch",
+ inverse_failure_threshold=0.1,
+ moving_average_for_momentum=True,
+ skip_preconditioning_dim_size_gt=4096,
+ clip_by_scaled_gradient_norm=None,
+ precision=jax.lax.Precision.HIGHEST,
+ )
+
else:
optimizer = optax.adamw(
learning_rate=learning_rate_fn,
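
A note on the configuration above, as read from the new optimizer code (editorial commentary, not part of the commit): with start_preconditioning_step=51 the grafted RMSProp-normalized direction drives the update until step 51, preconditioning_compute_steps=50 means the inverse-pth roots are only refreshed every 50 steps, and once Shampoo kicks in its preconditioned direction is rescaled to the norm of the grafted update. A toy sketch of that norm matching, with made-up numbers:

# Toy sketch (illustrative) of the grafting norm-matching done in _transform_grad:
# the Shampoo-preconditioned direction is rescaled to the norm of the grafted update.
import jax.numpy as jnp

grafting_update = jnp.array([0.3, -0.4])  # e.g. an RMSProp-style step, norm 0.5
precond_grad = jnp.array([10.0, 5.0])     # Shampoo-preconditioned gradient

multiplier = jnp.linalg.norm(grafting_update) / (jnp.linalg.norm(precond_grad) + 1e-16)
shampoo_update = precond_grad * multiplier

print(jnp.linalg.norm(shampoo_update))    # ~0.5, i.e. the grafted update's norm
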