versae committed on
Commit 30cb273
1 Parent(s): e50fbb8

Shampoo test 2

config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "_name_or_path": "roberta-base",
+   "architectures": [
+     "RobertaForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "roberta",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "transformers_version": "4.16.0.dev0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 50265
+ }
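
This is the standard RoBERTa-base configuration with a masked-LM head. As an illustrative sketch (not part of the commit), such a file can be loaded with the transformers library, assuming it sits in the working directory:

import jax
from transformers import RobertaConfig, FlaxRobertaForMaskedLM

config = RobertaConfig.from_json_file("config.json")   # path assumed
model = FlaxRobertaForMaskedLM(config, seed=0)          # randomly initialized RoBERTa-base MLM
print(jax.tree_util.tree_map(lambda x: x.shape, model.params)["roberta"]["embeddings"].keys())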
distributed_shampoo.py ADDED
@@ -0,0 +1,1609 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # An implementation of distributed Shampoo optimizer from:
17
+ #
18
+ # Scalable Second Order Optimization for Deep Learning
19
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
20
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
21
+ #
22
+ # This implementation moves computation of inverse pth root back to the
23
+ # accelerator (if higher precision is available).
24
+ #
25
+ # Authors: Rohan Anil (rohananil at google dot com)
26
+ # & Vineet Gupta (vineet at google dot com)
27
+ #
28
+
29
+ """Distributed Shampoo Implementation."""
30
+
31
+ import enum
32
+ import functools
33
+ import itertools
34
+ from typing import Any, List, NamedTuple
35
+
36
+ import chex
37
+ from flax import struct
38
+ import jax
39
+ from jax import lax
40
+ import jax.experimental.pjit as pjit
41
+ import jax.numpy as jnp
42
+ import numpy as np
43
+ import optax
44
+
45
+
46
+ # pylint:disable=no-value-for-parameter
47
+ @struct.dataclass
48
+ class QuantizedValue:
49
+ """State associated with quantized value."""
50
+ quantized: chex.Array
51
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
52
+ bucket_size: chex.Array
53
+ quantized_dtype: jnp.dtype = struct.field(
54
+ pytree_node=False) # Dtype for the quantized value.
55
+ extract_diagonal: bool = struct.field(
56
+ pytree_node=False) # In case it's centered.
57
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
58
+
59
+ @classmethod
60
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
61
+ if isinstance(fvalue, list) and not fvalue:
62
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
63
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
64
+ fvalue, quantized_dtype, extract_diagonal)
65
+ return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
66
+ quantized_dtype, extract_diagonal,
67
+ list(quantized.shape))
68
+
69
+ # Quantization is from Lingvo JAX optimizers.
70
+ # We extend it for int16 quantization of PSD matrices.
71
+ @classmethod
72
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
73
+ """Returns quantized value and the bucket."""
74
+ if quantized_dtype == jnp.float32:
75
+ return fvalue, [], []
76
+ elif quantized_dtype == jnp.bfloat16:
77
+ return fvalue.astype(jnp.bfloat16), [], []
78
+
79
+ float_dtype = fvalue.dtype
80
+ if quantized_dtype == jnp.int8:
81
+ # value -128 is not used.
82
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
83
+ elif quantized_dtype == jnp.int16:
84
+ # value -32768 is not used.
85
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
86
+ else:
87
+ raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
88
+ # max value is mapped to num_buckets
89
+
90
+ if extract_diagonal and fvalue.ndim != 2:
91
+ raise ValueError(
92
+ f'Input array {fvalue} must be 2D to work with extract_diagonal.')
93
+
94
+ diagonal_fvalue = []
95
+ if extract_diagonal:
96
+ diagonal_fvalue = jnp.diag(fvalue)
97
+ # Remove the diagonal entries.
98
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
99
+
100
+ # TODO(rohananil): Extend this by making use of information about the blocks
101
+ # SM3 style which will be useful for diagonal statistics
102
+ # We first decide the scale.
103
+ if fvalue.ndim < 1:
104
+ raise ValueError(
105
+ f'Input array {fvalue} must have a strictly positive number of '
106
+ 'dimensions.')
107
+
108
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
109
+ bucket_size = max_abs / num_buckets
110
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
111
+ # To avoid divide by 0.0
112
+ bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
113
+ jnp.ones_like(bs_expanded))
114
+ ratio = fvalue / bs_nonzero
115
+ # We use rounding to remove bias.
116
+ quantized = jnp.round(ratio)
117
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
118
+
119
+ def to_float(self):
120
+ """Returns the float value."""
121
+ if isinstance(self.quantized, list) and not self.quantized:
122
+ return self.quantized
123
+
124
+ if self.quantized_dtype == jnp.float32:
125
+ return self.quantized
126
+
127
+ if self.quantized_dtype == jnp.bfloat16:
128
+ return self.quantized.astype(jnp.float32)
129
+
130
+ float_dtype = self.bucket_size.dtype
131
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
132
+ val = self.quantized.astype(float_dtype) * bucket_size
133
+ if self.extract_diagonal:
134
+ val += jnp.diag(self.diagonal)
135
+ return val
136
+
137
+
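
As an illustration of the quantization scheme above, a small round-trip sketch (assuming this module is importable; recovered values are approximate because of the int8 bucketing):

import jax.numpy as jnp

m = jnp.arange(16.0).reshape(4, 4)
qv = QuantizedValue.from_float_value(m, jnp.int8, extract_diagonal=True)
# qv.quantized is int8, qv.diagonal stores the exact diagonal, qv.bucket_size is per-column.
recovered = qv.to_float()   # close to m, up to ~bucket_size/2 rounding error per entry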
138
+ # Per parameter optimizer state used in data-parallel training.
139
+ class ParameterStats(NamedTuple):
140
+ """State associated to each parameter of the model being trained."""
141
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
142
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
143
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
144
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
145
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
146
+
147
+
148
+ # For training extremely large models, we keep a global state with concatenated
149
+ # statistics and preconditioner states for all vars. This is so that we can
150
+ # annotate the leading axis to be sharded to save memory at the cost of
151
+ # communication.
152
+ @struct.dataclass
153
+ class GlobalShardedParameterStats:
154
+ statistics: chex.Array # Statistics
155
+ preconditioners: chex.Array # Preconditioners
156
+
157
+
158
+ # These are per-parameter local states; All statistics here mirror the parameter
159
+ # Thus the sharding is copied over from the param specification.
160
+ @struct.dataclass
161
+ class LocalShardedParameterStats:
162
+ """State associated to each parameter of the model being trained."""
163
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
164
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
165
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
166
+ index_start: np.int32 = struct.field(
167
+ pytree_node=False) # Index into global statistics array
168
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
169
+
170
+
171
+ class ShardedShampooStats(NamedTuple):
172
+ """Shampoo state in sharded mode."""
173
+ global_stats: Any
174
+ local_stats: Any
175
+
176
+
177
+ class ShampooState(NamedTuple):
178
+ count: chex.Array
179
+ stats: Any
180
+
181
+
182
+ class GraftingType(enum.IntEnum):
183
+ SGD = 1
184
+ ADAGRAD = 2
185
+ RMSPROP = 3
186
+ RMSPROP_NORMALIZED = 4
187
+
188
+
189
+ def power_iteration(
190
+ matrix,
191
+ num_iters=100,
192
+ error_tolerance=1e-6,
193
+ precision=lax.Precision.HIGHEST):
194
+ r"""Power iteration algorithm.
195
+
196
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
197
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
198
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
199
+
200
+ References:
201
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
202
+
203
+ Args:
204
+ matrix: the symmetric PSD matrix.
205
+ num_iters: Number of iterations.
206
+ error_tolerance: Iterative exit condition.
207
+ precision: precision XLA related flag, the available options are:
208
+ a) lax.Precision.DEFAULT (better step time, but not precise)
209
+ b) lax.Precision.HIGH (increased precision, slower)
210
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
211
+
212
+ Returns:
213
+ eigen vector, eigen value
214
+ """
215
+ matrix_size = matrix.shape[-1]
216
+ def _iter_condition(state):
217
+ i, unused_v, unused_s, unused_s_v, run_step = state
218
+ return jnp.logical_and(i < num_iters, run_step)
219
+
220
+ def _iter_body(state):
221
+ """One step of power iteration."""
222
+ i, new_v, s, s_v, unused_run_step = state
223
+ new_v = new_v / jnp.linalg.norm(new_v)
224
+
225
+ s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
226
+ s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
227
+ return (i + 1, s_v, s_new, s_v,
228
+ jnp.greater(jnp.abs(s_new - s), error_tolerance))
229
+
230
+ # Figure out how to use step as seed for random.
231
+ v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
232
+ matrix_size).astype(matrix.dtype)
233
+
234
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
235
+ _, v_out, s_out, _, _ = lax.while_loop(
236
+ _iter_condition, _iter_body, init_state)
237
+ v_out = v_out / jnp.linalg.norm(v_out)
238
+ return v_out, s_out
239
+
240
+
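
A quick sanity check of power_iteration on a tiny PSD matrix (illustrative only; the eigenvalues of this matrix are (5 ± √5)/2):

import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])
v, s = power_iteration(a)
# s should be close to max(jnp.linalg.eigh(a)[0]) ≈ 3.618, and a @ v ≈ s * v.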
241
+ def matrix_inverse_pth_root(
242
+ matrix,
243
+ p,
244
+ num_iters=100,
245
+ ridge_epsilon=1e-6,
246
+ error_tolerance=1e-6,
247
+ precision=lax.Precision.HIGHEST):
248
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
249
+
250
+ This function uses the Coupled newton iterations algorithm for
251
+ the computation of a matrix's inverse pth root.
252
+
253
+
254
+ References:
255
+ [Functions of Matrices, Theory and Computation,
256
+ Nicholas J Higham, Pg 184, Eq 7.18](
257
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
258
+
259
+ Args:
260
+ matrix: the symmetric PSD matrix whose power is to be computed.
261
+ p: exponent, for p a positive integer.
262
+ num_iters: Maximum number of iterations.
263
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
264
+ error_tolerance: Error indicator, useful for early termination.
265
+ precision: precision XLA related flag, the available options are:
266
+ a) lax.Precision.DEFAULT (better step time, but not precise)
267
+ b) lax.Precision.HIGH (increased precision, slower)
268
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
269
+
270
+ Returns:
271
+ matrix^(-1/p)
272
+ """
273
+
274
+ # We use float32 for the matrix inverse pth root.
275
+ # Switch to f64 if you have hardware that supports it.
276
+ matrix_size = matrix.shape[0]
277
+ alpha = jnp.asarray(-1.0 / p, jnp.float32)
278
+ identity = jnp.eye(matrix_size, dtype=jnp.float32)
279
+ _, max_ev = power_iteration(
280
+ matrix=matrix, num_iters=100,
281
+ error_tolerance=1e-6, precision=precision)
282
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
283
+
284
+ def _unrolled_mat_pow_1(mat_m):
285
+ """Computes mat_m^1."""
286
+ return mat_m
287
+
288
+ def _unrolled_mat_pow_2(mat_m):
289
+ """Computes mat_m^2."""
290
+ return jnp.matmul(mat_m, mat_m, precision=precision)
291
+
292
+ def _unrolled_mat_pow_4(mat_m):
293
+ """Computes mat_m^4."""
294
+ mat_pow_2 = _unrolled_mat_pow_2(mat_m)
295
+ return jnp.matmul(
296
+ mat_pow_2, mat_pow_2, precision=precision)
297
+
298
+ def _unrolled_mat_pow_8(mat_m):
299
+ """Computes mat_m^4."""
300
+ mat_pow_4 = _unrolled_mat_pow_4(mat_m)
301
+ return jnp.matmul(
302
+ mat_pow_4, mat_pow_4, precision=precision)
303
+
304
+ def mat_power(mat_m, p):
305
+ """Computes mat_m^p, for p == 1, 2, 4 or 8.
306
+
307
+ Args:
308
+ mat_m: a square matrix
309
+ p: a positive integer
310
+
311
+ Returns:
312
+ mat_m^p
313
+ """
314
+ # We unrolled the loop for performance reasons.
315
+ exponent = jnp.round(jnp.log2(p))
316
+ return lax.switch(
317
+ jnp.asarray(exponent, jnp.int32), [
318
+ _unrolled_mat_pow_1,
319
+ _unrolled_mat_pow_2,
320
+ _unrolled_mat_pow_4,
321
+ _unrolled_mat_pow_8,
322
+ ], (mat_m))
323
+
324
+ def _iter_condition(state):
325
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
326
+ run_step) = state
327
+ error_above_threshold = jnp.logical_and(
328
+ error > error_tolerance, run_step)
329
+ return jnp.logical_and(i < num_iters, error_above_threshold)
330
+
331
+ def _iter_body(state):
332
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
333
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
334
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
335
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
336
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
337
+ # sometimes error increases after an iteration before decreasing and
338
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
339
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
340
+ new_error < error * 1.2)
341
+
342
+ if matrix_size == 1:
343
+ resultant_mat_h = (matrix + ridge_epsilon)**alpha
344
+ error = 0
345
+ else:
346
+ damped_matrix = matrix + ridge_epsilon * identity
347
+
348
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
349
+ new_mat_m_0 = damped_matrix * z
350
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
351
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
352
+ init_state = tuple(
353
+ [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
354
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
355
+ _iter_condition, _iter_body, init_state)
356
+ error = jnp.max(jnp.abs(mat_m - identity))
357
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
358
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
359
+ resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
360
+ return resultant_mat_h, error
361
+
362
+
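
For intuition, on a diagonal matrix the inverse pth root can be checked directly (illustrative sketch; the small ridge term shifts the result slightly):

import jax.numpy as jnp

m = jnp.diag(jnp.array([4.0, 9.0]))
root, err = matrix_inverse_pth_root(m, p=2)
# Expect root ≈ diag(1/2, 1/3), i.e. m^(-1/2); err reports the final iteration residual.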
363
+ def merge_small_dims(shape_to_merge, max_dim):
364
+ """Merge small dimensions.
365
+
366
+ If there are some small dimensions, we collapse them:
367
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
368
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
369
+
370
+ Args:
371
+ shape_to_merge: Shape to merge small dimensions.
372
+ max_dim: Maximal dimension of output shape used in merging.
373
+
374
+ Returns:
375
+ Merged shape.
376
+ """
377
+ resulting_shape = []
378
+ product = 1
379
+ for d in shape_to_merge:
380
+ if product * d <= max_dim:
381
+ product *= d
382
+ else:
383
+ if product > 1:
384
+ resulting_shape.append(product)
385
+ product = d
386
+ if product > 1:
387
+ resulting_shape.append(product)
388
+ return resulting_shape
389
+
390
+
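
The docstring examples can be reproduced directly:

merge_small_dims([1, 2, 512, 1, 2048, 1, 3, 4], 1024)  # -> [1024, 2048, 12]
merge_small_dims([1, 2, 768, 1, 2048], 1024)           # -> [2, 768, 2048]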
391
+ def pad_matrix(mat, max_size):
392
+ """Pad a matrix to a max_size.
393
+
394
+ Args:
395
+ mat: a matrix to pad.
396
+ max_size: matrix size requested.
397
+
398
+ Returns:
399
+ Given M returns [[M, 0], [0, I]]
400
+ """
401
+ size = mat.shape[0]
402
+ assert size <= max_size
403
+ if size == max_size:
404
+ return mat
405
+ pad_size = max_size - size
406
+ zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
407
+ zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
408
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
409
+ mat = jnp.concatenate([mat, zs1], 1)
410
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
411
+ return mat
412
+
413
+
414
+ def pad_vector(vec, max_size):
415
+ """Pad a vector to a max_size.
416
+
417
+ Args:
418
+ vec: a vector to pad.
419
+ max_size: matrix size requested.
420
+
421
+ Returns:
422
+ Given V returns [V, 0]
423
+ """
424
+ size = vec.shape[0]
425
+ assert size <= max_size
426
+ if size == max_size:
427
+ return vec
428
+ pad_size = max_size - size
429
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
430
+ return jnp.concatenate([vec, zs1], 0)
431
+
432
+
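
A small example of the padding layout [[M, 0], [0, I]], which keeps padded statistics invertible (illustrative):

import jax.numpy as jnp

m = jnp.ones((2, 2))
pad_matrix(m, 4)             # 4x4: ones in the top-left 2x2 block, identity in the bottom-right
pad_vector(jnp.ones(2), 4)   # [1., 1., 0., 0.]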
433
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
434
+ """Avoids wasteful buffer allocation with XLA."""
435
+
436
+ def _iter_body(unused_state):
437
+ results = compute_fn(*args, **kwargs)
438
+ return tuple([False] + list(results))
439
+
440
+ def _iter_condition(state):
441
+ return state[0]
442
+
443
+ results = jax.lax.while_loop(_iter_condition, _iter_body,
444
+ tuple([predicate] + init_state))
445
+ return tuple(results[1:])
446
+
447
+
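
A sketch of how efficient_cond is meant to be used: compute_fn runs only when the predicate is true, otherwise the init_state values pass through unchanged (names here are illustrative, with `step` assumed to come from the caller):

import jax.numpy as jnp

stats = [jnp.zeros((4, 4))]
updated = efficient_cond(step % 10 == 0,            # predicate
                         lambda: [stats[0] + 1.0],  # expensive update, run conditionally
                         stats)
# `updated` is a tuple holding either the refreshed or the original statistics.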
448
+ class BlockPartitioner:
449
+ """Partitions a tensor into smaller tensors."""
450
+
451
+ def __init__(self, param, block_size):
452
+ self._shape = param.shape
453
+ self._splits = []
454
+ split_sizes = []
455
+ # We split params into smaller blocks. Here we store the metadata to make
456
+ # that split.
457
+ for i, d in enumerate(param.shape):
458
+ if 0 < block_size < d:
459
+ # d-1, otherwise split appends a 0-size array.
460
+ nsplit = (d - 1) // block_size
461
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
462
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
463
+ sizes[-1] = d - indices[-1]
464
+ self._splits.append((i, indices))
465
+ split_sizes.append(sizes)
466
+ else:
467
+ split_sizes.append(np.array([d], dtype=np.int32))
468
+ self._num_splits = len(split_sizes)
469
+ self._preconditioner_shapes = []
470
+ for t in itertools.product(*split_sizes):
471
+ self._preconditioner_shapes.extend([[d, d] for d in t])
472
+
473
+ def shapes_for_preconditioners(self):
474
+ return self._preconditioner_shapes
475
+
476
+ def num_splits(self):
477
+ return self._num_splits
478
+
479
+ def partition(self, tensor):
480
+ """Partition tensor into blocks."""
481
+
482
+ assert tensor.shape == self._shape
483
+ tensors = [tensor]
484
+ for (i, indices) in self._splits:
485
+ tensors_local = []
486
+ for t in tensors:
487
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
488
+ tensors = tensors_local
489
+ return tensors
490
+
491
+ def merge_partitions(self, partitions):
492
+ """Merge partitions back to original shape."""
493
+
494
+ for (i, indices) in reversed(self._splits):
495
+ n = len(indices) + 1
496
+ partial_merged_tensors = []
497
+ ind = 0
498
+ while ind < len(partitions):
499
+ partial_merged_tensors.append(
500
+ jnp.concatenate(partitions[ind:ind + n], axis=i))
501
+ ind += n
502
+ partitions = partial_merged_tensors
503
+ assert len(partitions) == 1
504
+ return partitions[0]
505
+
506
+
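
For example, a 300 x 20 parameter with block_size=128 is split along the first dimension only (illustrative sketch):

import jax.numpy as jnp

x = jnp.zeros((300, 20))
part = BlockPartitioner(x, block_size=128)
blocks = part.partition(x)                # shapes: (128, 20), (128, 20), (44, 20)
part.shapes_for_preconditioners()         # [[128,128],[20,20],[128,128],[20,20],[44,44],[20,20]]
merged = part.merge_partitions(blocks)    # back to shape (300, 20)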
507
+ class Preconditioner:
508
+ """Compute statistics/shape from gradients for preconditioning."""
509
+
510
+ def __init__(self, param, block_size, best_effort_shape_interpretation):
511
+ self._original_shape = param.shape
512
+ self._transformed_shape = param.shape
513
+ if best_effort_shape_interpretation:
514
+ self._transformed_shape = merge_small_dims(self._original_shape,
515
+ block_size)
516
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
517
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
518
+
519
+ def statistics_from_grad(self, grad):
520
+ """Compute statistics from gradients.
521
+
522
+ Args:
523
+ grad: Gradient to compute statistics from.
524
+
525
+ Returns:
526
+ A list of gradient statistics for each partition.
527
+ """
528
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
529
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
530
+ stats = []
531
+ for g in partitioned_grads:
532
+ g_stats = []
533
+ rank = len(g.shape)
534
+ for i in range(rank):
535
+ axes = list(range(i)) + list(range(i + 1, rank))
536
+ stat = jnp.tensordot(g, g, axes=(axes, axes))
537
+ g_stats.append(stat)
538
+ stats.extend(g_stats)
539
+ return stats
540
+
541
+ def shapes_for_preconditioners(self):
542
+ """Returns shape from statistics."""
543
+ return self._partitioner.shapes_for_preconditioners()
544
+
545
+ def exponent_for_preconditioner(self):
546
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
547
+ return 2 * len(self._transformed_shape)
548
+
549
+ def preconditioned_grad(self, grad, preconditioners):
550
+ """Precondition the gradient.
551
+
552
+ Args:
553
+ grad: A gradient tensor to precondition.
554
+ preconditioners: A list of preconditioners to apply.
555
+
556
+ Returns:
557
+ A preconditioned gradient.
558
+ """
559
+
560
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
561
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
562
+ preconditioned_partitioned_grads = []
563
+ num_splits = self._partitioner.num_splits()
564
+ for i, g in enumerate(partitioned_grads):
565
+ preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
566
+ num_splits]
567
+ rank = len(g.shape)
568
+ precond_g = g
569
+ for j in range(rank):
570
+ precond_g = jnp.tensordot(
571
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]])
572
+ preconditioned_partitioned_grads.append(precond_g)
573
+ merged_grad = self._partitioner.merge_partitions(
574
+ preconditioned_partitioned_grads)
575
+ return jnp.reshape(merged_grad, self._original_shape)
576
+
577
+
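
Putting the pieces together for a single rank-2 gradient (sketch; block_size=0 disables blocking):

import jax.numpy as jnp

g = jnp.ones((8, 4))
pc = Preconditioner(g, block_size=0, best_effort_shape_interpretation=False)
stats = pc.statistics_from_grad(g)     # [g @ g.T (8x8), g.T @ g (4x4)]
pc.exponent_for_preconditioner()       # 4, so each preconditioner is stat^(-1/4)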
578
+ def _convert_to_parameter_stats(global_stats, local_stat):
579
+ """Creates parameter stats from sharded stats."""
580
+ index_start = int(local_stat.index_start)
581
+ index_end = int(len(local_stat.sizes)) + index_start
582
+ statistics = global_stats.statistics[index_start:index_end, :, :]
583
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
584
+ new_statistics = []
585
+ new_preconditioners = []
586
+ for i, size in enumerate(local_stat.sizes):
587
+ new_statistics.append(statistics[i][:size, :size])
588
+ new_preconditioners.append(preconditioners[i][:size, :size])
589
+ return ParameterStats(local_stat.diagonal_statistics, new_statistics,
590
+ new_preconditioners, local_stat.diagonal_momentum,
591
+ local_stat.momentum)
592
+
593
+
594
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
595
+ """Creates sharded stats from paramter stats."""
596
+ return LocalShardedParameterStats(parameter_stats.diagonal_statistics,
597
+ parameter_stats.diagonal_momentum,
598
+ parameter_stats.momentum,
599
+ local_stats.index_start, local_stats.sizes)
600
+
601
+
602
+ def batch(x, num_devices):
603
+ """Batch `x` so that so that leading axis is num_devices."""
604
+ n = len(x)
605
+ b = int(n / num_devices)
606
+ return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
607
+
608
+
609
+ def unbatch(batched_values):
610
+ """Unbatch values across leading axis and return a list of elements."""
611
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
612
+ results = []
613
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
614
+ v_array = jnp.squeeze(v_array)
615
+ # b2 = batches (number of preconditioner computation) per core.
616
+ if b2 > 1:
617
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
618
+ results.append(jnp.squeeze(v))
619
+ else:
620
+ results.append(v_array)
621
+ return results
622
+
623
+
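
batch/unbatch are inverses of each other when the list length is a multiple of num_devices (sketch):

import jax.numpy as jnp

mats = [jnp.eye(3) for _ in range(8)]
stacked = batch(mats, num_devices=4)   # shape (4, 2, 3, 3): one leading slot per device
back = unbatch(stacked)                # list of 8 arrays of shape (3, 3) again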
624
+ def distributed_shampoo(
625
+ learning_rate,
626
+ block_size,
627
+ beta1=0.9,
628
+ beta2=0.999,
629
+ diagonal_epsilon=1e-10,
630
+ matrix_epsilon=1e-6,
631
+ weight_decay=0.0,
632
+ start_preconditioning_step=5,
633
+ preconditioning_compute_steps=1,
634
+ statistics_compute_steps=1,
635
+ best_effort_shape_interpretation=True,
636
+ graft_type=GraftingType.SGD,
637
+ nesterov=True,
638
+ exponent_override=0,
639
+ # Pass pmap 'batch axis name' in pmap mode.
640
+ batch_axis_name=None,
641
+ ### Only set following 3 params in pjit/spmd mode.
642
+ ### WARNING: Experimental
643
+ mesh_axis_names=None,
644
+ num_devices_for_pjit=None,
645
+ shard_optimizer_states=False,
646
+ ###
647
+ ### Experimental memory reduction mode
648
+ best_effort_memory_usage_reduction=False,
649
+ ###
650
+ inverse_failure_threshold=0.1,
651
+ moving_average_for_momentum=False,
652
+ skip_preconditioning_dim_size_gt=4096,
653
+ clip_by_scaled_gradient_norm=None,
654
+ precision=lax.Precision.HIGHEST):
655
+ """Distributed Shampoo optimizer.
656
+
657
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
658
+ variant of full-matrix Adagrad), that provides significant convergence and
659
+ wall-clock time improvements compared to conventional first-order methods,
660
+ and that has been shown to scale to large state-of-the-art deep learning
661
+ models.
662
+
663
+ References:
664
+ Scalable Second Order Optimization for Deep Learning,
665
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
666
+
667
+ Preprint: https://arxiv.org/abs/2002.09018
668
+
669
+ Args:
670
+ learning_rate: the step size used to update the parameters.
671
+ block_size: Block size for large layers (if > 0). Preconditioning compute
672
+ operation is cubic in the dimension of the tensor. Block size allows us to
673
+ chunk the layers into sub-layers of maximal dimension dictated by this
674
+ value. Use 128 as default (increase if you have compute budget).
675
+ beta1: momentum parameter.
676
+ beta2: second moment averaging parameter.
677
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
678
+ to AdaGrad is enabled).
679
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
680
+ root. If you are running in f32 precision for inverse pth root
681
+ (recommended today) this can go up to 1e-6. If you have the latest hardware
682
+ with native f64 precision, set this up to 1e-12.
683
+ weight_decay: Weight decay for regularization.
684
+ start_preconditioning_step: When to start Shampoo update before which
685
+ diagonal update is used. This is because we don't have enough information
686
+ to do stable inverse.
687
+ preconditioning_compute_steps: How often to compute preconditioner.
688
+ Performance tuning params for controlling memory and compute requirements.
689
+ Ideally set this and statistics_compute_steps params to 1.
690
+ statistics_compute_steps: How often to compute statistics.
691
+ best_effort_shape_interpretation: If there are some small dimensions,
692
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
693
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
694
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
695
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
696
+ where SGD/AdaGrad is already well tuned. Available options are:
697
+ GraftingType.SGD and GraftingType.ADAGRAD.
698
+ nesterov: Nesterov momentum.
699
+ exponent_override: Override the exponent used in matrix inverse.
700
+ batch_axis_name: name of the labeled pmap axis used when running the
701
+ optimizer in data-parallel (pmap) mode.
702
+ mesh_axis_names: Axis names for the mesh (used in pjit).
703
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
704
+ shard_optimizer_states: Shard optimizer states to save memory in model
705
+ parallel training.
706
+ best_effort_memory_usage_reduction: Best effort memory usage reduction.
707
+ diagonal_statistics -> jnp.bfloat16
708
+ momentum buffers (2x) -> jnp.int8
709
+ statistics, preconditioners -> jnp.int16 + diagonals
710
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
711
+ determine that using this threshold.
712
+ moving_average_for_momentum: Whether to use moving average for momentum
713
+ instead of exponential moving average.
714
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
715
+ greater than this value.
716
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
717
+ when using RMSProp Grafting).
718
+ precision: precision XLA related flag, the available options are: a)
719
+ lax.Precision.DEFAULT (better step time, but not precise) b)
720
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
721
+ (best possible precision, slowest)
722
+
723
+ Returns:
724
+ a GradientTransformation.
725
+ """
726
+
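
A minimal usage sketch under pmap, assuming this function returns an optax.GradientTransformation as documented (learning rate, block size, and the `params`/`grads` pytrees are placeholders supplied by the surrounding training loop):

import optax

opt = distributed_shampoo(learning_rate=1e-3, block_size=128,
                          batch_axis_name='batch')
opt_state = opt.init(params)                               # `params` assumed from the model
updates, opt_state = opt.update(grads, opt_state, params)  # inside the pmapped train step
params = optax.apply_updates(params, updates)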
727
+ def quantized_dtype_for_momentum_buffers():
728
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
729
+
730
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
731
+ def quantized_dtype_for_diagonal_statistics_buffers():
732
+ return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
733
+
734
+ # Preconditioner and statistics are both stored as int16 in this mode.
735
+ # We take out the diagonal to make quantization easier.
736
+ def quantized_dtype_for_second_moment_statistics_buffers():
737
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
738
+
739
+ # Preconditioner and statistics are both stored as int16 in this mode.
740
+ # We take out the diagonal to make quantization easier.
741
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
742
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
743
+
744
+ def _to_float(maybe_quantized):
745
+ if isinstance(maybe_quantized, QuantizedValue):
746
+ return maybe_quantized.to_float()
747
+ else:
748
+ return maybe_quantized
749
+
750
+ def _maybe_quantize_statistics(statistics_list):
751
+ return _maybe_quantize_matrices_with_dtype(
752
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers())
753
+
754
+ def _maybe_quantize_preconditioners(statistics_list):
755
+ return _maybe_quantize_matrices_with_dtype(
756
+ statistics_list,
757
+ quantized_dtype_for_second_moment_preconditioner_buffers())
758
+
759
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
760
+ if quantized_dtype != jnp.float32:
761
+ return ([
762
+ QuantizedValue.from_float_value(
763
+ s, quantized_dtype, extract_diagonal=True)
764
+ for s in statistics_list
765
+ ])
766
+ else:
767
+ return statistics_list
768
+
769
+ def _maybe_dequantize_preconditioners(preconditioner_list):
770
+ return _maybe_dequantize_matrices_with_dtype(
771
+ preconditioner_list,
772
+ quantized_dtype_for_second_moment_preconditioner_buffers())
773
+
774
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
775
+ if quantized_dtype != jnp.float32:
776
+ return [s.to_float() for s in statistics_list]
777
+ else:
778
+ return statistics_list
779
+
780
+ def _quantize_diagonal_statistics(diagonal_statistics):
781
+ return QuantizedValue.from_float_value(
782
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers())
783
+
784
+ def _quantize_momentum(momentum_statistics):
785
+ return QuantizedValue.from_float_value(
786
+ momentum_statistics, quantized_dtype_for_momentum_buffers())
787
+
788
+ def sharded_init_fn(params):
789
+ params_flat, treedef = jax.tree_flatten(params)
790
+ # Find max size to pad to.
791
+ max_size = 0
792
+ for param in params_flat:
793
+ preconditioner = Preconditioner(param, block_size,
794
+ best_effort_shape_interpretation)
795
+ if not _skip_preconditioning(param):
796
+ shapes = preconditioner.shapes_for_preconditioners()
797
+ sizes = [s[0] for s in shapes]
798
+ max_size = max(max(sizes), max_size)
799
+
800
+ padded_statistics = []
801
+ padded_preconditioners = []
802
+ local_stats_flat = []
803
+ for param in params_flat:
804
+ preconditioner = Preconditioner(param, block_size,
805
+ best_effort_shape_interpretation)
806
+ shapes = preconditioner.shapes_for_preconditioners()
807
+ sizes = []
808
+
809
+ statistics = []
810
+ preconditioners = []
811
+ index_start = len(padded_statistics)
812
+ if not _skip_preconditioning(param):
813
+ sizes = [s[0] for s in shapes]
814
+ shapes = preconditioner.shapes_for_preconditioners()
815
+ statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
816
+ preconditioners = [jnp.eye(max_size) for s in shapes]
817
+ padded_statistics.extend(statistics)
818
+ padded_preconditioners.extend(preconditioners)
819
+
820
+ diagonal_statistics = []
821
+ if graft_type != GraftingType.SGD:
822
+ diagonal_statistics = jnp.zeros_like(param)
823
+ local_stats_flat.append(
824
+ LocalShardedParameterStats(
825
+ _quantize_diagonal_statistics(diagonal_statistics),
826
+ _quantize_momentum(jnp.zeros_like(param)),
827
+ _quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
828
+
829
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
830
+ # Pad the statistics and preconditioner matrices to be a multiple of
831
+ # num devices.
832
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
833
+ # is split on.
834
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
835
+ padded_statistics.extend([
836
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
837
+ for _ in range(to_pad)
838
+ ])
839
+ padded_preconditioners.extend([
840
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
841
+ for _ in range(to_pad)
842
+ ])
843
+ global_stats = GlobalShardedParameterStats(
844
+ jnp.stack(padded_statistics), jnp.stack(padded_preconditioners))
845
+ return ShampooState(
846
+ count=jnp.zeros([], jnp.int32),
847
+ stats=ShardedShampooStats(global_stats, local_stats))
848
+
849
+ def sharded_update_fn(grads, state, params):
850
+ """Transform the input gradient and update all statistics in sharded mode.
851
+
852
+ Args:
853
+ grads: the gradient tensors for the parameters.
854
+ state: a named tuple containing the state of the optimizer
855
+ params: the parameters that should be updated.
856
+
857
+ Returns:
858
+ A tuple containing the new parameters and the new optimizer state.
859
+ """
860
+ params_flat, treedef = jax.tree_flatten(params)
861
+ grads_flat = treedef.flatten_up_to(grads)
862
+
863
+ global_stats = state.stats.global_stats
864
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
865
+ stats_flat = [
866
+ _convert_to_parameter_stats(global_stats, local_stat)
867
+ for local_stat in local_stats_flat
868
+ ]
869
+ new_stats_flat = jax.tree_multimap(
870
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
871
+ stats_flat, params_flat)
872
+
873
+ exponents = []
874
+ for stat, param in zip(new_stats_flat, params_flat):
875
+ num_statistics = len(stat.statistics)
876
+ if num_statistics > 0:
877
+ preconditioner = Preconditioner(param, block_size,
878
+ best_effort_shape_interpretation)
879
+ exponent = (
880
+ preconditioner.exponent_for_preconditioner()
881
+ if exponent_override == 0 else exponent_override)
882
+ exponents.extend([exponent] * num_statistics)
883
+
884
+ outputs = jax.tree_multimap(
885
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
886
+ new_stats_flat, params_flat)
887
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
888
+
889
+ updates = jax.tree_unflatten(treedef, updates_flat)
890
+ # Create new local_stats
891
+ new_local_stats_flat = [
892
+ _convert_from_parameter_stats(new_stat, local_stat)
893
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
894
+ ]
895
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
896
+
897
+ max_size = global_stats.statistics.shape[1]
898
+ new_padded_statistics = []
899
+ for stat in new_stats_flat:
900
+ new_padded_statistics.extend(
901
+ [pad_matrix(stat, max_size) for stat in stat.statistics])
902
+
903
+ # Create global stats
904
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
905
+ # stack/pad can be obviated away.
906
+ # Pad the statistics and preconditioner matrices to be a multiple of
907
+ # num devices.
908
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
909
+ # is split on.
910
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
911
+ new_padded_statistics.extend([
912
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
913
+ for _ in range(to_pad)
914
+ ])
915
+ exponents.extend([1 for _ in range(to_pad)])
916
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
917
+ new_stacked_exponents = jnp.stack(exponents)
918
+ def _matrix_inverse_pth_root_vmap(xs, ps):
919
+ mi_pth_root = functools.partial(
920
+ matrix_inverse_pth_root,
921
+ ridge_epsilon=matrix_epsilon,
922
+ precision=precision)
923
+ preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
924
+ return preconditioners, errors
925
+
926
+ def _internal_inverse_pth_root_all():
927
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
928
+ new_stacked_padded_statistics, new_stacked_exponents)
929
+ return preconditioners, errors
930
+
931
+ if preconditioning_compute_steps == 1:
932
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
933
+ else:
934
+ # Passing statistics instead of preconditioners as they are similarly
935
+ # shaped tensors. Note statistics will be ignored as we are passing in
936
+ # a large init value for error.
937
+ preconditioners_init = new_stacked_padded_statistics
938
+ errors_init = np.stack([inverse_failure_threshold] * len(exponents))
939
+ init_state = [preconditioners_init, errors_init]
940
+ perform_step = state.count % preconditioning_compute_steps == 0
941
+ new_preconditioners, errors = efficient_cond(
942
+ perform_step, _internal_inverse_pth_root_all, init_state)
943
+
944
+ errors = errors.reshape((-1, 1, 1))
945
+ predicate = jnp.logical_or(
946
+ jnp.isnan(errors),
947
+ errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
948
+ # TODO(rohananil): Check for numerical instabilities.
949
+ new_conditional_preconditioners = (
950
+ predicate * global_stats.preconditioners +
951
+ (1.0 - predicate) * new_preconditioners)
952
+ new_global_stats = GlobalShardedParameterStats(
953
+ new_stacked_padded_statistics, new_conditional_preconditioners)
954
+ new_shampoo_state = ShampooState(
955
+ count=state.count + 1,
956
+ stats=ShardedShampooStats(new_global_stats, new_local_stats))
957
+ return updates, new_shampoo_state
958
+
959
+ def init_fn(params):
960
+ """Initialise the optimiser's state."""
961
+
962
+ def _init(param):
963
+ preconditioner = Preconditioner(param, block_size,
964
+ best_effort_shape_interpretation)
965
+ statistics = []
966
+ preconditioners = []
967
+ if not _skip_preconditioning(param):
968
+ shapes = preconditioner.shapes_for_preconditioners()
969
+ statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
970
+ preconditioners = [jnp.eye(s[0]) for s in shapes]
971
+
972
+ diagonal_statistics = []
973
+ if graft_type != GraftingType.SGD:
974
+ diagonal_statistics = jnp.zeros_like(param)
975
+ return ParameterStats(
976
+ _quantize_diagonal_statistics(diagonal_statistics),
977
+ _maybe_quantize_statistics(statistics),
978
+ _maybe_quantize_preconditioners(preconditioners),
979
+ _quantize_momentum(jnp.zeros_like(param)),
980
+ _quantize_momentum(jnp.zeros_like(param)))
981
+ return ShampooState(
982
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
983
+
984
+ def _skip_preconditioning(param):
985
+ return len(param.shape) < 1 or any(
986
+ [s > skip_preconditioning_dim_size_gt for s in param.shape])
987
+
988
+ def _compute_stats(grad, state, param, step):
989
+ """Compute per-parameter statistics."""
990
+ preconditioner = Preconditioner(param, block_size,
991
+ best_effort_shape_interpretation)
992
+ new_statistics = [[]] * len(state.statistics)
993
+ w1 = beta2
994
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
995
+ if not _skip_preconditioning(param):
996
+
997
+ def compute_updated_statistics():
998
+ new_stats = preconditioner.statistics_from_grad(grad)
999
+ new_stats_accumulators = []
1000
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
1001
+ new_stats_accumulators.append(w1 * _to_float(stat_accumulator) +
1002
+ w2 * stat)
1003
+ return _maybe_quantize_statistics(new_stats_accumulators)
1004
+
1005
+ if statistics_compute_steps > 1:
1006
+ perform_step = step % statistics_compute_steps == 0
1007
+ init_state = state.statistics
1008
+ new_statistics = list(
1009
+ efficient_cond(perform_step, compute_updated_statistics,
1010
+ init_state))
1011
+ else:
1012
+ new_statistics = compute_updated_statistics()
1013
+ return ParameterStats(state.diagonal_statistics, new_statistics,
1014
+ state.preconditioners, state.diagonal_momentum,
1015
+ state.momentum)
1016
+
1017
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1018
+ mi_pth_root = functools.partial(
1019
+ matrix_inverse_pth_root,
1020
+ ridge_epsilon=matrix_epsilon,
1021
+ precision=precision)
1022
+ return jax.vmap(mi_pth_root)(xs, ps)
1023
+
1024
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1025
+
1026
+ def _quantized_to_float(qx, qd, qb):
1027
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1028
+ return qv.to_float()
1029
+
1030
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1031
+ v = _quantized_to_float(qx, qd, qb)
1032
+ preconditioner, error = matrix_inverse_pth_root(
1033
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision)
1034
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1035
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1036
+
1037
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1038
+
1039
+ def _matrix_inverse_pth_root_pjit(xs, ps):
1040
+ mesh_axis_names_tuple = tuple(mesh_axis_names)
1041
+ # Partition the concatenated statistics matrix across all cores.
1042
+ partitioned_xs, partitioned_ps = pjit.pjit(
1043
+ lambda x, y: (x, y),
1044
+ in_axis_resources=None,
1045
+ out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
1046
+ # Run matrix inverse pth root on each shard.
1047
+ partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
1048
+ partitioned_xs, partitioned_ps)
1049
+ # Recombine the outputs at each core.
1050
+ preconditioners, errors = pjit.pjit(
1051
+ lambda x, y: (x, y),
1052
+ in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
1053
+ pjit.PartitionSpec(mesh_axis_names_tuple,)),
1054
+ out_axis_resources=(None, None))(partitioned_preconditioners,
1055
+ partitioned_errors)
1056
+ return preconditioners, errors
1057
+
1058
+ def _pmap_compute_preconditioners(states, step, statistics,
1059
+ num_statistics_per_state, original_shapes,
1060
+ exponents, max_size, prev_preconditioners):
1061
+ """Computes preconditioners for given statistics in states in PMAP mode.
1062
+
1063
+ Args:
1064
+ states: A list of optimizer states.
1065
+ step: Current step number
1066
+ statistics: A list of statistics for all variables (for every dim)
1067
+ num_statistics_per_state: Number of statistics per state to reconstruct
1068
+ output states.
1069
+ original_shapes: A list of shapes of the statistics.
1070
+ exponents: Exponent power to use for inverse-pth roots.
1071
+ max_size: Maximum dim of the statistics to pad.
1072
+ prev_preconditioners: Previously available preconditioner.
1073
+
1074
+ Returns:
1075
+ New optimizer states after computing the preconditioner.
1076
+ """
1077
+ num_devices = lax.psum(1, batch_axis_name)
1078
+ num_statistics = len(statistics)
1079
+ # Pad statistics and exponents to next multiple of num_devices.
1080
+ packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1081
+ to_pad = -num_statistics % num_devices
1082
+ packed_statistics.extend([
1083
+ jnp.eye(max_size, dtype=packed_statistics[0].dtype)
1084
+ for _ in range(to_pad)
1085
+ ])
1086
+ exponents.extend([1 for _ in range(to_pad)])
1087
+
1088
+ if not packed_statistics:
1089
+ return states
1090
+
1091
+ all_statistics = batch(packed_statistics, num_devices)
1092
+ all_exponents = batch(exponents, num_devices)
1093
+
1094
+ def _internal_inverse_pth_root_all():
1095
+ current_replica = lax.axis_index(batch_axis_name)
1096
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1097
+ all_statistics[current_replica], all_exponents[current_replica])
1098
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1099
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1100
+ preconditioners_flat = unbatch(preconditioners)
1101
+ errors_flat = unbatch(errors)
1102
+ return preconditioners_flat, errors_flat
1103
+
1104
+ if preconditioning_compute_steps == 1:
1105
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1106
+ else:
1107
+ # Passing statistics instead of preconditioners as they are similarly
1108
+ # shaped tensors. Note statistics will be ignored as we are passing in
1109
+ # a large init value for error.
1110
+ preconditioners_init = packed_statistics
1111
+ errors_init = ([inverse_failure_threshold] * len(packed_statistics))
1112
+ init_state = [preconditioners_init, errors_init]
1113
+ perform_step = step % preconditioning_compute_steps == 0
1114
+ preconditioners_flat, errors_flat = efficient_cond(
1115
+ perform_step, _internal_inverse_pth_root_all, init_state)
1116
+
1117
+ def _skip(error):
1118
+ condition = jnp.logical_or(
1119
+ jnp.isnan(error), error >= inverse_failure_threshold)
1120
+ return condition.astype(error.dtype)
1121
+
1122
+ def _select_preconditioner(error, new_p, old_p):
1123
+ return lax.cond(
1124
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1125
+
1126
+ new_preconditioners_flat = []
1127
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1128
+ prev_preconditioners, errors_flat):
1129
+ new_preconditioners_flat.append(
1130
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1131
+
1132
+ assert len(states) == len(num_statistics_per_state)
1133
+ assert len(new_preconditioners_flat) == num_statistics
1134
+
1135
+ # Add back empty preconditioners so that we can set the optimizer state.
1136
+ preconditioners_for_states = []
1137
+ idx = 0
1138
+ for num_statistics, state in zip(num_statistics_per_state, states):
1139
+ if num_statistics == 0:
1140
+ preconditioners_for_states.append([])
1141
+ else:
1142
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1143
+ num_statistics]
1144
+ assert len(state.statistics) == len(preconditioners_for_state)
1145
+ preconditioners_for_states.append(preconditioners_for_state)
1146
+ idx += num_statistics
1147
+ new_states = []
1148
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1149
+ new_states.append(
1150
+ ParameterStats(state.diagonal_statistics, state.statistics,
1151
+ new_preconditioners, state.diagonal_momentum,
1152
+ state.momentum))
1153
+
1154
+ return new_states
1155
+
1156
+ def _pmap_quantized_compute_preconditioners(states, step, statistics,
1157
+ num_statistics_per_state,
1158
+ original_shapes, exponents,
1159
+ max_size, prev_preconditioners):
1160
+ """Computes preconditioners for given statistics in states in PMAP mode.
1161
+
1162
+ For quantization, each statistic is represented by three values:
1163
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1164
+ without ever recreating the original matrix in f32.
1165
+
1166
+ Args:
1167
+ states: A list of optimizer states.
1168
+ step: Current step number
1169
+ statistics: A list of statistics for all variables (for every dim)
1170
+ num_statistics_per_state: Number of statistics per state to reconstruct
1171
+ output states.
1172
+ original_shapes: A list of shapes of the statistics.
1173
+ exponents: Exponent power to use for inverse-pth roots.
1174
+ max_size: Maximum dim of the statistics to pad.
1175
+ prev_preconditioners: Previously available preconditioner.
1176
+
1177
+ Returns:
1178
+ New optimizer states after computing the preconditioner.
1179
+ """
1180
+ num_devices = lax.psum(1, batch_axis_name)
1181
+ num_statistics = len(statistics)
1182
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1183
+ # Complexity here is around: shapes needing to be statically shaped,
1184
+ # our custom quantization type requires a different type of packing.
1185
+
1186
+ # Parallel tensors:
1187
+ # quantized [dxd]
1188
+ # diagonals [d] f32
1189
+ # bucket_sizes [d] f32
1190
+ packed_quantized_statistics = [
1191
+ pad_matrix(stat.quantized, max_size) for stat in statistics
1192
+ ]
1193
+ packed_quantized_diagonals = [
1194
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1195
+ ]
1196
+ packed_quantized_bucket_sizes = [
1197
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1198
+ ]
1199
+
1200
+ to_pad = -num_statistics % num_devices
1201
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1202
+ quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
1203
+ True)
1204
+ packed_quantized_statistics.extend(
1205
+ [quantized_eye.quantized for _ in range(to_pad)])
1206
+ packed_quantized_diagonals.extend(
1207
+ [quantized_eye.diagonal for _ in range(to_pad)])
1208
+ packed_quantized_bucket_sizes.extend(
1209
+ [quantized_eye.bucket_size for _ in range(to_pad)])
1210
+ exponents.extend([1 for _ in range(to_pad)])
1211
+
1212
+ if not packed_quantized_statistics:
1213
+ return states
1214
+
1215
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1216
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1217
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
1218
+ num_devices)
1219
+ all_exponents = batch(exponents, num_devices)
1220
+
1221
+ def _internal_inverse_pth_root_all():
1222
+ current_replica = lax.axis_index(batch_axis_name)
1223
+ quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors = (
1224
+ _quantized_matrix_inverse_pth_root_vmap(
1225
+ all_quantized_statistics[current_replica],
1226
+ all_quantized_diagonals[current_replica],
1227
+ all_quantized_bucket_sizes[current_replica],
1228
+ all_exponents[current_replica]))
1229
+ quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
1230
+ batch_axis_name)
1231
+ quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
1232
+ batch_axis_name)
1233
+ quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
1234
+ batch_axis_name)
1235
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1236
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1237
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1238
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1239
+ errors_flat = unbatch(errors)
1240
+ return (quantized_preconditioners_flat, quantized_diagonals_flat,
1241
+ quantized_bucket_sizes_flat, errors_flat)
1242
+
1243
+ if preconditioning_compute_steps == 1:
1244
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1245
+ quantized_bucket_sizes_flat, errors_flat) = (
1246
+ _internal_inverse_pth_root_all())
1247
+ else:
1248
+ # Passing statistics instead of preconditioners as they are similarly
1249
+ # shaped tensors. Note statistics will be ignored as we are passing in
1250
+ # a large init value for error.
1251
+ quantized_preconditioners_init = packed_quantized_statistics
1252
+ quantized_diagonals_init = packed_quantized_diagonals
1253
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1254
+ errors_init = ([inverse_failure_threshold] *
1255
+ len(quantized_preconditioners_init))
1256
+ init_state = [
1257
+ quantized_preconditioners_init, quantized_diagonals_init,
1258
+ quantized_bucket_sizes_init, errors_init
1259
+ ]
1260
+ perform_step = step % preconditioning_compute_steps == 0
1261
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1262
+ quantized_bucket_sizes_flat, errors_flat) = (
1263
+ efficient_cond(perform_step, _internal_inverse_pth_root_all,
1264
+ init_state))
1265
+
1266
+ def _skip(error):
1267
+ condition = jnp.logical_or(
1268
+ jnp.isnan(error), error >= inverse_failure_threshold)
1269
+ return condition.astype(error.dtype)
1270
+
1271
+ def _select_preconditioner(error, new_p, old_p):
1272
+ return lax.cond(
1273
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1274
+
1275
+ new_quantized_preconditioners_flat = []
1276
+ new_quantized_diagonals_flat = []
1277
+ new_quantized_bucket_sizes_flat = []
1278
+ for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
1279
+ quantized_diagonals_flat,
1280
+ quantized_bucket_sizes_flat,
1281
+ original_shapes,
1282
+ prev_preconditioners, errors_flat):
1283
+ new_quantized_preconditioners_flat.append(
1284
+ _select_preconditioner(error, p[:shape[0], :shape[1]],
1285
+ prev_p.quantized))
1286
+ new_quantized_diagonals_flat.append(
1287
+ _select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
1288
+ new_quantized_bucket_sizes_flat.append(
1289
+ _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))
1290
+
1291
+ assert len(states) == len(num_statistics_per_state)
1292
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1293
+ assert len(new_quantized_diagonals_flat) == num_statistics
1294
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1295
+
1296
+ # Add back empty preconditioners so that we can set the optimizer state.
1297
+ preconditioners_for_states = []
1298
+ idx = 0
1299
+ for num_statistics, state in zip(num_statistics_per_state, states):
1300
+ if num_statistics == 0:
1301
+ preconditioners_for_states.append([])
1302
+ else:
1303
+ quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
1304
+ idx:idx + num_statistics]
1305
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1306
+ idx:idx + num_statistics]
1307
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1308
+ idx:idx + num_statistics]
1309
+
1310
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1311
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1312
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1313
+
1314
+ quantized_preconditioners = []
1315
+ for qv, qd, qb in zip(quantized_preconditioners_for_state,
1316
+ quantized_diagonals_for_state,
1317
+ quantized_bucket_sizes_for_state):
1318
+ quantized_preconditioners.append(
1319
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)))
1320
+ preconditioners_for_states.append(quantized_preconditioners)
1321
+ idx += num_statistics
1322
+ new_states = []
1323
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1324
+ new_states.append(
1325
+ ParameterStats(state.diagonal_statistics, state.statistics,
1326
+ new_preconditioners, state.diagonal_momentum,
1327
+ state.momentum))
1328
+
1329
+ return new_states
1330
+
1331
+ def _pjit_compute_preconditioners(states, step, statistics,
1332
+ num_statistics_per_state, original_shapes,
1333
+ exponents, max_size, prev_preconditioners):
1334
+ """Computes preconditioners for given statistics in states in PJIT mode.
1335
+
1336
+ Args:
1337
+ states: A list of optimizer states.
1338
+ step: Current step number
1339
+ statistics: A list of statistics for all variables (for every dim)
1340
+ num_statistics_per_state: Number of statistics per state to reconstruct
1341
+ output states.
1342
+ original_shapes: A list of shapes of the statistics.
1343
+ exponents: Exponent power to use for inverse-pth roots.
1344
+ max_size: Maximum dim of the statistics to pad.
1345
+ prev_preconditioners: Previously available preconditioners.
1346
+
1347
+ Returns:
1348
+ New optimizer states after computing the preconditioner.
1349
+ """
1350
+ num_statistics = len(statistics)
1351
+ to_pad = -num_statistics % num_devices_for_pjit
1352
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1353
+ padded_statistics.extend([
1354
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
1355
+ for _ in range(to_pad)
1356
+ ])
1357
+ exponents.extend([1 for _ in range(to_pad)])
1358
+ all_statistics = jnp.stack(padded_statistics)
1359
+ all_exponents = jnp.stack(exponents)
1360
+
1361
+ def _internal_inverse_pth_root_all():
1362
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1363
+ all_statistics, all_exponents)
1364
+ b1 = preconditioners.shape[0]
1365
+
1366
+ def split(batched_values):
1367
+ return [
1368
+ jnp.squeeze(v)
1369
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1370
+ ]
1371
+
1372
+ return split(preconditioners), split(errors)
1373
+
1374
+ if preconditioning_compute_steps == 1:
1375
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1376
+ else:
1377
+ # Passing statistics instead of preconditioners as they are similarly
1378
+ # shaped tensors. Note statistics will be ignored as we are passing in
1379
+ # a large init value for error.
1380
+ preconditioners_init = padded_statistics
1381
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1382
+ init_state = [preconditioners_init, errors_init]
1383
+ perform_step = step % preconditioning_compute_steps == 0
1384
+ preconditioners_flat, errors_flat = efficient_cond(
1385
+ perform_step, _internal_inverse_pth_root_all, init_state)
1386
+
1387
+ def _skip(error):
1388
+ condition = jnp.logical_or(
1389
+ jnp.isnan(error), error >= inverse_failure_threshold)
1390
+ return condition.astype(error.dtype)
1391
+
1392
+ def _select_preconditioner(error, new_p, old_p):
1393
+ return lax.cond(
1394
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1395
+
1396
+ new_preconditioners_flat = []
1397
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1398
+ prev_preconditioners, errors_flat):
1399
+ new_preconditioners_flat.append(
1400
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1401
+
1402
+ assert len(states) == len(num_statistics_per_state)
1403
+ assert len(new_preconditioners_flat) == num_statistics
1404
+
1405
+ # Add back empty preconditioners so that we can set the optimizer state.
1406
+ preconditioners_for_states = []
1407
+ idx = 0
1408
+ for num_statistics, state in zip(num_statistics_per_state, states):
1409
+ if num_statistics == 0:
1410
+ preconditioners_for_states.append([])
1411
+ else:
1412
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1413
+ num_statistics]
1414
+ assert len(state.statistics) == len(preconditioners_for_state)
1415
+ preconditioners_for_states.append(preconditioners_for_state)
1416
+ idx += num_statistics
1417
+ new_states = []
1418
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1419
+ new_states.append(
1420
+ ParameterStats(state.diagonal_statistics, state.statistics,
1421
+ new_preconditioners, state.diagonal_momentum,
1422
+ state.momentum))
1423
+
1424
+ return new_states
1425
+
1426
+ def _compute_preconditioners(states, params, step):
1427
+ """Computes preconditioners for given statistics in states.
1428
+
1429
+ Args:
1430
+ states: A list of optimizer states.
1431
+ params: A list of params.
1432
+ step: Current step number
1433
+
1434
+ Returns:
1435
+ New optimizer states after computing the preconditioner.
1436
+ """
1437
+ statistics = []
1438
+ num_statistics_per_state = []
1439
+ original_shapes = []
1440
+ exponents = []
1441
+ max_size = 0
1442
+ prev_preconditioners = []
1443
+
1444
+ for state, param in zip(states, params):
1445
+ num_statistics = len(state.statistics)
1446
+ num_statistics_per_state.append(num_statistics)
1447
+ original_shapes_for_state = []
1448
+ if num_statistics > 0:
1449
+ preconditioner = Preconditioner(param, block_size,
1450
+ best_effort_shape_interpretation)
1451
+ for statistic in state.statistics:
1452
+ exponents.append(preconditioner.exponent_for_preconditioner(
1453
+ ) if exponent_override == 0 else exponent_override)
1454
+ original_shapes_for_state.append(statistic.shape)
1455
+ max_size = max(max_size, statistic.shape[0])
1456
+
1457
+ statistics.extend(state.statistics)
1458
+ prev_preconditioners.extend(state.preconditioners)
1459
+ original_shapes.extend(original_shapes_for_state)
1460
+
1461
+ if batch_axis_name:
1462
+ # Quantization is only used on the pmap path, i.e. when batch_axis_name is set.
1463
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1464
+
1465
+ if quantized_dtype == jnp.float32:
1466
+ return _pmap_compute_preconditioners(states, step, statistics,
1467
+ num_statistics_per_state,
1468
+ original_shapes, exponents,
1469
+ max_size, prev_preconditioners)
1470
+ else:
1471
+ return _pmap_quantized_compute_preconditioners(
1472
+ states, step, statistics, num_statistics_per_state, original_shapes,
1473
+ exponents, max_size, prev_preconditioners)
1474
+
1475
+ else:
1476
+ return _pjit_compute_preconditioners(states, step, statistics,
1477
+ num_statistics_per_state,
1478
+ original_shapes, exponents, max_size,
1479
+ prev_preconditioners)
1480
+
1481
+ def _transform_grad(grad, state, param, step):
1482
+ """Transform per-parameter gradients."""
1483
+ preconditioner = Preconditioner(param, block_size,
1484
+ best_effort_shape_interpretation)
1485
+ sgd_update = grad
1486
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
1487
+ if graft_type == GraftingType.ADAGRAD:
1488
+ new_diagonal_statistics = state.diagonal_statistics.to_float(
1489
+ ) + jnp.square(grad)
1490
+ adagrad_update = grad / (
1491
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1492
+ grafting_update = adagrad_update
1493
+ elif (graft_type == GraftingType.RMSPROP or
1494
+ graft_type == GraftingType.RMSPROP_NORMALIZED):
1495
+
1496
+ scaled_grad = grad
1497
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
1498
+ scaled_grad = grad / jnp.linalg.norm(grad)
1499
+
1500
+ w1 = beta2
1501
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1502
+
1503
+ new_diagonal_statistics = (
1504
+ w1 * state.diagonal_statistics.to_float() +
1505
+ w2 * jnp.square(scaled_grad))
1506
+ rmsprop_update = scaled_grad / (
1507
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1508
+
1509
+ if clip_by_scaled_gradient_norm:
1510
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
1511
+ jnp.sqrt(float(rmsprop_update.size)))
1512
+ clipping_denom = jnp.maximum(
1513
+ 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
1514
+ rmsprop_update /= clipping_denom
1515
+
1516
+ grafting_update = rmsprop_update
1517
+ else:
1518
+ grafting_update = sgd_update
1519
+
1520
+ precond_grad = grad
1521
+ if not _skip_preconditioning(param):
1522
+ precond_grad = preconditioner.preconditioned_grad(
1523
+ precond_grad,
1524
+ _maybe_dequantize_preconditioners(state.preconditioners))
1525
+ else:
1526
+ precond_grad = grafting_update
1527
+
1528
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
1529
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
1530
+
1531
+ multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
1532
+ shampoo_update = precond_grad * multiplier
1533
+
1534
+ shampoo_update_with_wd = shampoo_update
1535
+ grafting_update_with_wd = grafting_update
1536
+ if weight_decay != 0:
1537
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
1538
+ grafting_update_with_wd = grafting_update + weight_decay * param
1539
+
1540
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1541
+ shampoo_update_with_wd_momentum = (
1542
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
1543
+ grafting_update_with_wd_momentum = (
1544
+ state.diagonal_momentum.to_float() * beta1 +
1545
+ w * grafting_update_with_wd)
1546
+
1547
+ run_shampoo = (step >= start_preconditioning_step).astype(
1548
+ grafting_update_with_wd_momentum.dtype)
1549
+
1550
+ momentum_update = (
1551
+ run_shampoo * shampoo_update_with_wd_momentum +
1552
+ (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
1553
+
1554
+ wd_update = (
1555
+ run_shampoo * shampoo_update_with_wd +
1556
+ (1.0 - run_shampoo) * grafting_update_with_wd)
1557
+
1558
+ if nesterov:
1559
+ momentum_update = w * wd_update + beta1 * momentum_update
1560
+
1561
+ lr = learning_rate
1562
+ if callable(learning_rate):
1563
+ lr = learning_rate(step)
1564
+ transformed_update = -1.0 * lr * momentum_update
1565
+
1566
+ param_stats = ParameterStats(
1567
+ _quantize_diagonal_statistics(new_diagonal_statistics),
1568
+ state.statistics, state.preconditioners,
1569
+ _quantize_momentum(grafting_update_with_wd_momentum),
1570
+ _quantize_momentum(shampoo_update_with_wd_momentum))
1571
+ return transformed_update, param_stats
1572
+
1573
+ def update_fn(grads, state, params):
1574
+ """Transform the input gradient and update all statistics.
1575
+
1576
+ Args:
1577
+ grads: the gradient tensors for the parameters.
1578
+ state: a named tuple containing the state of the optimizer
1579
+ params: the parameters that should be updated.
1580
+
1581
+ Returns:
1582
+ A tuple containing the new parameters and the new optimizer state.
1583
+ """
1584
+ params_flat, treedef = jax.tree_flatten(params)
1585
+ stats_flat = treedef.flatten_up_to(state.stats)
1586
+ grads_flat = treedef.flatten_up_to(grads)
1587
+
1588
+ new_stats_flat = jax.tree_multimap(
1589
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
1590
+ stats_flat, params_flat)
1591
+ new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
1592
+ state.count)
1593
+
1594
+ outputs = jax.tree_multimap(
1595
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
1596
+ new_stats_flat, params_flat)
1597
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1598
+
1599
+ updates = jax.tree_unflatten(treedef, updates_flat)
1600
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
1601
+
1602
+ new_state = ShampooState(
1603
+ count=state.count+1, stats=new_stats)
1604
+ return updates, new_state
1605
+
1606
+ if shard_optimizer_states:
1607
+ return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
1608
+ else:
1609
+ return optax.GradientTransformation(init_fn, update_fn)
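Quick orientation before the training script below: distributed_shampoo(...) returns a standard optax.GradientTransformation, so it drops into a Flax TrainState exactly like optax.adamw does. A minimal sketch under that assumption (the learning rate and block size here are placeholders; the real call with its full argument list appears in run_mlm_flax.py):

import optax  # the returned object is a plain optax.GradientTransformation
from distributed_shampoo import distributed_shampoo, GraftingType

optimizer = distributed_shampoo(
    1e-3,                        # placeholder learning rate (an optax schedule fn also works)
    block_size=1024,             # large parameter dimensions are preconditioned in blocks of this size
    graft_type=GraftingType.RMSPROP_NORMALIZED,
    batch_axis_name="batch",     # must match the axis name used by jax.pmap in the train loop
)
assert isinstance(optimizer, optax.GradientTransformation)

# It then plugs into Flax like any optax optimizer, e.g.:
# state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)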
merges.txt ADDED
The diff for this file is too large to render. See raw diff
run_mlm_flax.py ADDED
@@ -0,0 +1,846 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script to your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+ from distributed_shampoo import distributed_shampoo, GraftingType
64
+
65
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
66
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
67
+
68
+
69
+ @dataclass
70
+ class TrainingArguments:
71
+ output_dir: str = field(
72
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
73
+ )
74
+ overwrite_output_dir: bool = field(
75
+ default=False,
76
+ metadata={
77
+ "help": (
78
+ "Overwrite the content of the output directory. "
79
+ "Use this to continue training if output_dir points to a checkpoint directory."
80
+ )
81
+ },
82
+ )
83
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
84
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
85
+ per_device_train_batch_size: int = field(
86
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
87
+ )
88
+ per_device_eval_batch_size: int = field(
89
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
90
+ )
91
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
92
+ distributed_shampoo: bool = field(
93
+ default=False,
94
+ metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
95
+ )
96
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
97
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
98
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
99
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
100
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
101
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
102
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
103
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
104
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
105
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
106
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
107
+ push_to_hub: bool = field(
108
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
109
+ )
110
+ hub_model_id: str = field(
111
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
112
+ )
113
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
114
+
115
+ def __post_init__(self):
116
+ if self.output_dir is not None:
117
+ self.output_dir = os.path.expanduser(self.output_dir)
118
+
119
+ def to_dict(self):
120
+ """
121
+ Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It obfuscates
122
+ the token values by replacing them with placeholders.
123
+ """
124
+ d = asdict(self)
125
+ for k, v in d.items():
126
+ if isinstance(v, Enum):
127
+ d[k] = v.value
128
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
129
+ d[k] = [x.value for x in v]
130
+ if k.endswith("_token"):
131
+ d[k] = f"<{k.upper()}>"
132
+ return d
133
+
134
+
135
+ @dataclass
136
+ class ModelArguments:
137
+ """
138
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
139
+ """
140
+
141
+ model_name_or_path: Optional[str] = field(
142
+ default=None,
143
+ metadata={
144
+ "help": "The model checkpoint for weights initialization."
145
+ "Don't set if you want to train a model from scratch."
146
+ },
147
+ )
148
+ model_type: Optional[str] = field(
149
+ default=None,
150
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
151
+ )
152
+ config_name: Optional[str] = field(
153
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
154
+ )
155
+ tokenizer_name: Optional[str] = field(
156
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
157
+ )
158
+ cache_dir: Optional[str] = field(
159
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
160
+ )
161
+ use_fast_tokenizer: bool = field(
162
+ default=True,
163
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
164
+ )
165
+ dtype: Optional[str] = field(
166
+ default="float32",
167
+ metadata={
168
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
169
+ },
170
+ )
171
+
172
+
173
+ @dataclass
174
+ class DataTrainingArguments:
175
+ """
176
+ Arguments pertaining to what data we are going to input our model for training and eval.
177
+ """
178
+
179
+ dataset_name: Optional[str] = field(
180
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
181
+ )
182
+ dataset_config_name: Optional[str] = field(
183
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
184
+ )
185
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
186
+ validation_file: Optional[str] = field(
187
+ default=None,
188
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
189
+ )
190
+ train_ref_file: Optional[str] = field(
191
+ default=None,
192
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
193
+ )
194
+ validation_ref_file: Optional[str] = field(
195
+ default=None,
196
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
197
+ )
198
+ overwrite_cache: bool = field(
199
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
200
+ )
201
+ validation_split_percentage: Optional[int] = field(
202
+ default=5,
203
+ metadata={
204
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
205
+ },
206
+ )
207
+ max_seq_length: Optional[int] = field(
208
+ default=None,
209
+ metadata={
210
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
211
+ "than this will be truncated. Default to the max input length of the model."
212
+ },
213
+ )
214
+ preprocessing_num_workers: Optional[int] = field(
215
+ default=None,
216
+ metadata={"help": "The number of processes to use for the preprocessing."},
217
+ )
218
+ mlm_probability: float = field(
219
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
220
+ )
221
+ pad_to_max_length: bool = field(
222
+ default=False,
223
+ metadata={
224
+ "help": "Whether to pad all samples to `max_seq_length`. "
225
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
226
+ },
227
+ )
228
+ line_by_line: bool = field(
229
+ default=False,
230
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
231
+ )
232
+
233
+ def __post_init__(self):
234
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
235
+ raise ValueError("Need either a dataset name or a training/validation file.")
236
+ else:
237
+ if self.train_file is not None:
238
+ extension = self.train_file.split(".")[-1]
239
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
240
+ if self.validation_file is not None:
241
+ extension = self.validation_file.split(".")[-1]
242
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
243
+
244
+
245
+ @flax.struct.dataclass
246
+ class FlaxDataCollatorForLanguageModeling:
247
+ """
248
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
249
+ are not all of the same length.
250
+
251
+ Args:
252
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
253
+ The tokenizer used for encoding the data.
254
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
255
+ The probability with which to (randomly) mask tokens in the input.
256
+
257
+ .. note::
258
+
259
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
260
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
261
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
262
+ argument :obj:`return_special_tokens_mask=True`.
263
+ """
264
+
265
+ tokenizer: PreTrainedTokenizerBase
266
+ mlm_probability: float = 0.15
267
+
268
+ def __post_init__(self):
269
+ if self.tokenizer.mask_token is None:
270
+ raise ValueError(
271
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
272
+ "You should pass `mlm=False` to train on causal language modeling instead."
273
+ )
274
+
275
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
276
+ # Handle dict or lists with proper padding and conversion to tensor.
277
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
278
+
279
+ # If special token mask has been preprocessed, pop it from the dict.
280
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
281
+
282
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
283
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
284
+ )
285
+ return batch
286
+
287
+ def mask_tokens(
288
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
289
+ ) -> Tuple[np.ndarray, np.ndarray]:
290
+ """
291
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
292
+ """
293
+ labels = inputs.copy()
294
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
295
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
296
+ special_tokens_mask = special_tokens_mask.astype("bool")
297
+
298
+ probability_matrix[special_tokens_mask] = 0.0
299
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
300
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
301
+
302
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
303
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
304
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
305
+
306
+ # 10% of the time, we replace masked input tokens with random word
307
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
308
+ indices_random &= masked_indices & ~indices_replaced
309
+
310
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
311
+ inputs[indices_random] = random_words[indices_random]
312
+
313
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
314
+ return inputs, labels
315
+
316
+
317
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
318
+ num_samples = len(samples_idx)
319
+ samples_to_remove = num_samples % batch_size
320
+
321
+ if samples_to_remove != 0:
322
+ samples_idx = samples_idx[:-samples_to_remove]
323
+ sections_split = num_samples // batch_size
324
+ batch_idx = np.split(samples_idx, sections_split)
325
+ return batch_idx
326
+
327
+
328
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
329
+ summary_writer.scalar("train_time", train_time, step)
330
+
331
+ train_metrics = get_metrics(train_metrics)
332
+ for key, vals in train_metrics.items():
333
+ tag = f"train_{key}"
334
+ for i, val in enumerate(vals):
335
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
336
+
337
+
338
+ def write_eval_metric(summary_writer, eval_metrics, step):
339
+ for metric_name, value in eval_metrics.items():
340
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
341
+
342
+
343
+ def main():
344
+ # See all possible arguments in src/transformers/training_args.py
345
+ # or by passing the --help flag to this script.
346
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
347
+
348
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
349
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
350
+ # If we pass only one argument to the script and it's the path to a json file,
351
+ # let's parse it to get our arguments.
352
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
353
+ else:
354
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
355
+
356
+ if (
357
+ os.path.exists(training_args.output_dir)
358
+ and os.listdir(training_args.output_dir)
359
+ and training_args.do_train
360
+ and not training_args.overwrite_output_dir
361
+ ):
362
+ raise ValueError(
363
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
364
+ "Use --overwrite_output_dir to overcome."
365
+ )
366
+
367
+ # Setup logging
368
+ logging.basicConfig(
369
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
370
+ level=logging.INFO,
371
+ datefmt="[%X]",
372
+ )
373
+
374
+ # Log on each process the small summary:
375
+ logger = logging.getLogger(__name__)
376
+
377
+ # Set the verbosity to info of the Transformers logger (on main process only):
378
+ logger.info(f"Training/evaluation parameters {training_args}")
379
+
380
+ # Set seed before initializing model.
381
+ set_seed(training_args.seed)
382
+
383
+ # Handle the repository creation
384
+ if training_args.push_to_hub:
385
+ if training_args.hub_model_id is None:
386
+ repo_name = get_full_repo_name(
387
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
388
+ )
389
+ else:
390
+ repo_name = training_args.hub_model_id
391
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
392
+
393
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
394
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
395
+ # (the dataset will be downloaded automatically from the datasets Hub).
396
+ #
397
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
398
+ # 'text' is found. You can easily tweak this behavior (see below).
399
+ #
400
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
401
+ # download the dataset.
402
+ if data_args.dataset_name is not None:
403
+ # Downloading and loading a dataset from the hub.
404
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
405
+
406
+ if "validation" not in datasets.keys():
407
+ datasets["validation"] = load_dataset(
408
+ data_args.dataset_name,
409
+ data_args.dataset_config_name,
410
+ split=f"train[:{data_args.validation_split_percentage}%]",
411
+ cache_dir=model_args.cache_dir,
412
+ )
413
+ datasets["train"] = load_dataset(
414
+ data_args.dataset_name,
415
+ data_args.dataset_config_name,
416
+ split=f"train[{data_args.validation_split_percentage}%:]",
417
+ cache_dir=model_args.cache_dir,
418
+ )
419
+ else:
420
+ data_files = {}
421
+ if data_args.train_file is not None:
422
+ data_files["train"] = data_args.train_file
423
+ if data_args.validation_file is not None:
424
+ data_files["validation"] = data_args.validation_file
425
+ extension = data_args.train_file.split(".")[-1]
426
+ if extension == "txt":
427
+ extension = "text"
428
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
429
+
430
+ if "validation" not in datasets.keys():
431
+ datasets["validation"] = load_dataset(
432
+ extension,
433
+ data_files=data_files,
434
+ split=f"train[:{data_args.validation_split_percentage}%]",
435
+ cache_dir=model_args.cache_dir,
436
+ )
437
+ datasets["train"] = load_dataset(
438
+ extension,
439
+ data_files=data_files,
440
+ split=f"train[{data_args.validation_split_percentage}%:]",
441
+ cache_dir=model_args.cache_dir,
442
+ )
443
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
444
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
445
+
446
+ # Load pretrained model and tokenizer
447
+
448
+ # Distributed training:
449
+ # The .from_pretrained methods guarantee that only one local process can concurrently
450
+ # download model & vocab.
451
+ if model_args.config_name:
452
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
453
+ elif model_args.model_name_or_path:
454
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
455
+ else:
456
+ config = CONFIG_MAPPING[model_args.model_type]()
457
+ logger.warning("You are instantiating a new config instance from scratch.")
458
+
459
+ if model_args.tokenizer_name:
460
+ tokenizer = AutoTokenizer.from_pretrained(
461
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
462
+ )
463
+ elif model_args.model_name_or_path:
464
+ tokenizer = AutoTokenizer.from_pretrained(
465
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
466
+ )
467
+ else:
468
+ raise ValueError(
469
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
470
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
471
+ )
472
+
473
+ # Preprocessing the datasets.
474
+ # First we tokenize all the texts.
475
+ if training_args.do_train:
476
+ column_names = datasets["train"].column_names
477
+ else:
478
+ column_names = datasets["validation"].column_names
479
+ text_column_name = "text" if "text" in column_names else column_names[0]
480
+
481
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
482
+
483
+ if data_args.line_by_line:
484
+ # When using line_by_line, we just tokenize each nonempty line.
485
+ padding = "max_length" if data_args.pad_to_max_length else False
486
+
487
+ def tokenize_function(examples):
488
+ # Remove empty lines
489
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
490
+ return tokenizer(
491
+ examples,
492
+ return_special_tokens_mask=True,
493
+ padding=padding,
494
+ truncation=True,
495
+ max_length=max_seq_length,
496
+ )
497
+
498
+ tokenized_datasets = datasets.map(
499
+ tokenize_function,
500
+ input_columns=[text_column_name],
501
+ batched=True,
502
+ num_proc=data_args.preprocessing_num_workers,
503
+ remove_columns=column_names,
504
+ load_from_cache_file=not data_args.overwrite_cache,
505
+ )
506
+
507
+ else:
508
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
509
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
510
+ # efficient when it receives the `special_tokens_mask`.
511
+ def tokenize_function(examples):
512
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
513
+
514
+ tokenized_datasets = datasets.map(
515
+ tokenize_function,
516
+ batched=True,
517
+ num_proc=data_args.preprocessing_num_workers,
518
+ remove_columns=column_names,
519
+ load_from_cache_file=not data_args.overwrite_cache,
520
+ )
521
+
522
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
523
+ # max_seq_length.
524
+ def group_texts(examples):
525
+ # Concatenate all texts.
526
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
527
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
528
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
529
+ # customize this part to your needs.
530
+ if total_length >= max_seq_length:
531
+ total_length = (total_length // max_seq_length) * max_seq_length
532
+ # Split by chunks of max_len.
533
+ result = {
534
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
535
+ for k, t in concatenated_examples.items()
536
+ }
537
+ return result
538
+
539
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
540
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
541
+ # might be slower to preprocess.
542
+ #
543
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
544
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
545
+ tokenized_datasets = tokenized_datasets.map(
546
+ group_texts,
547
+ batched=True,
548
+ num_proc=data_args.preprocessing_num_workers,
549
+ load_from_cache_file=not data_args.overwrite_cache,
550
+ )
551
+
552
+ # Enable tensorboard only on the master node
553
+ has_tensorboard = is_tensorboard_available()
554
+ if has_tensorboard and jax.process_index() == 0:
555
+ try:
556
+ # Enable Weight&Biases
557
+ import wandb
558
+ wandb.init(
559
+ entity='versae',
560
+ project='roberta-base-ncc',
561
+ sync_tensorboard=True,
562
+ )
563
+ wandb.config.update(training_args)
564
+ wandb.config.update(model_args)
565
+ wandb.config.update(data_args)
566
+
567
+ from flax.metrics.tensorboard import SummaryWriter
568
+
569
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
570
+ except ImportError as ie:
571
+ has_tensorboard = False
572
+ logger.warning(
573
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
574
+ )
575
+ else:
576
+ logger.warning(
577
+ "Unable to display metrics through TensorBoard because the package is not installed: "
578
+ "Please run pip install tensorboard to enable."
579
+ )
580
+
581
+ # Data collator
582
+ # This one will take care of randomly masking the tokens.
583
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
584
+
585
+ # Initialize our training
586
+ rng = jax.random.PRNGKey(training_args.seed)
587
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
588
+
589
+ if model_args.model_name_or_path:
590
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
591
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
592
+ )
593
+ else:
594
+ model = FlaxAutoModelForMaskedLM.from_config(
595
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
596
+ )
597
+
598
+ # Store some constants
599
+ num_epochs = int(training_args.num_train_epochs)
600
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
601
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
602
+
603
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
604
+
605
+ # Create learning rate schedule
606
+ warmup_fn = optax.linear_schedule(
607
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
608
+ )
609
+ decay_fn = optax.linear_schedule(
610
+ init_value=training_args.learning_rate,
611
+ end_value=0,
612
+ transition_steps=num_train_steps - training_args.warmup_steps,
613
+ )
614
+ linear_decay_lr_schedule_fn = optax.join_schedules(
615
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
616
+ )
617
+
618
+ # We use Optax's "masking" functionality to not apply weight decay
619
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
620
+ # mask boolean with the same structure as the parameters.
621
+ # The mask is True for parameters that should be decayed.
622
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
623
+ # For other models, one should correct the layer norm parameter naming
624
+ # accordingly.
625
+ def decay_mask_fn(params):
626
+ flat_params = traverse_util.flatten_dict(params)
627
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
628
+ return traverse_util.unflatten_dict(flat_mask)
629
+
630
+ # create the optimizer (Adafactor, Distributed Shampoo, or AdamW depending on flags)
631
+ if training_args.adafactor:
632
+ # We use the default parameters here to initialize adafactor,
633
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
634
+ optimizer = optax.adafactor(
635
+ learning_rate=linear_decay_lr_schedule_fn,
636
+ )
637
+ elif training_args.distributed_shampoo:
638
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
639
+ # Notes:
640
+ # - mask for weight decay is not implemented but we don't use it anyway
641
+ optimizer = distributed_shampoo(
642
+ linear_decay_lr_schedule_fn,
643
+ block_size=1024, # recommended default for large LM is 1536
644
+ beta1=training_args.adam_beta1, # 0.9,
645
+ beta2=training_args.adam_beta2, # 0.999,
646
+ diagonal_epsilon=training_args.adam_epsilon, # 1e-10,
647
+ matrix_epsilon=1e-8,
648
+ weight_decay=training_args.weight_decay, # 0.0,
649
+ start_preconditioning_step=1001,
650
+ preconditioning_compute_steps=10,
651
+ statistics_compute_steps=1,
652
+ best_effort_shape_interpretation=True,
653
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
654
+ nesterov=False,
655
+ exponent_override=0,
656
+ batch_axis_name="batch",
657
+ inverse_failure_threshold=0.1,
658
+ moving_average_for_momentum=True,
659
+ skip_preconditioning_dim_size_gt=4096,
660
+ clip_by_scaled_gradient_norm=None,
661
+ precision=jax.lax.Precision.HIGHEST,
662
+ )
663
+ else:
664
+ optimizer = optax.adamw(
665
+ learning_rate=linear_decay_lr_schedule_fn,
666
+ b1=training_args.adam_beta1,
667
+ b2=training_args.adam_beta2,
668
+ eps=training_args.adam_epsilon,
669
+ weight_decay=training_args.weight_decay,
670
+ mask=decay_mask_fn,
671
+ )
672
+
673
+ # Setup train state
674
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
675
+
676
+ # Define gradient update step fn
677
+ def train_step(state, batch, dropout_rng):
678
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
679
+
680
+ def loss_fn(params):
681
+ labels = batch.pop("labels")
682
+
683
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
684
+
685
+ # compute loss, ignore padded input tokens
686
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
687
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
688
+
689
+ # take average
690
+ loss = loss.sum() / label_mask.sum()
691
+
692
+ return loss
693
+
694
+ grad_fn = jax.value_and_grad(loss_fn)
695
+ loss, grad = grad_fn(state.params)
696
+ grad = jax.lax.pmean(grad, "batch")
697
+ new_state = state.apply_gradients(grads=grad)
698
+
699
+ metrics = jax.lax.pmean(
700
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
701
+ )
702
+
703
+ return new_state, metrics, new_dropout_rng
704
+
705
+ # Create parallel version of the train step
706
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
707
+
708
+ # Define eval fn
709
+ def eval_step(params, batch):
710
+ labels = batch.pop("labels")
711
+
712
+ logits = model(**batch, params=params, train=False)[0]
713
+
714
+ # compute loss, ignore padded input tokens
715
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
716
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
717
+
718
+ # compute accuracy
719
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
720
+
721
+ # summarize metrics
722
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
723
+ metrics = jax.lax.psum(metrics, axis_name="batch")
724
+
725
+ return metrics
726
+
727
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
728
+
729
+ # Replicate the train state on each device
730
+ state = jax_utils.replicate(state)
731
+
732
+ train_time = 0
733
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
734
+ for epoch in epochs:
735
+ # ======================== Training ================================
736
+ train_start = time.time()
737
+ train_metrics = []
738
+
739
+ # Create sampling rng
740
+ rng, input_rng = jax.random.split(rng)
741
+
742
+ # Generate an epoch by shuffling sampling indices from the train dataset
743
+ num_train_samples = len(tokenized_datasets["train"])
744
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
745
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
746
+
747
+ # Gather the indexes for creating the batch and do a training step
748
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
749
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
755
+ train_metrics.append(train_metric)
756
+
757
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
758
+
759
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
760
+ # Save metrics
761
+ train_metric = jax_utils.unreplicate(train_metric)
762
+ train_time += time.time() - train_start
763
+ if has_tensorboard and jax.process_index() == 0:
764
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
765
+
766
+ epochs.write(
767
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
768
+ )
769
+
770
+ train_metrics = []
771
+
772
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
773
+ # ======================== Evaluating ==============================
774
+ num_eval_samples = len(tokenized_datasets["validation"])
775
+ eval_samples_idx = jnp.arange(num_eval_samples)
776
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
777
+
778
+ eval_metrics = []
779
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
780
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
781
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
782
+
783
+ # Model forward
784
+ model_inputs = shard(model_inputs.data)
785
+ metrics = p_eval_step(state.params, model_inputs)
786
+ eval_metrics.append(metrics)
787
+
788
+ # normalize eval metrics
789
+ eval_metrics = get_metrics(eval_metrics)
790
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
791
+ eval_normalizer = eval_metrics.pop("normalizer")
792
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
793
+
794
+ # Update progress bar
795
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
796
+
797
+ # Save metrics
798
+ if has_tensorboard and jax.process_index() == 0:
799
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
800
+
801
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
802
+ # save checkpoint every save_steps and push it to the hub
803
+ if jax.process_index() == 0:
804
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
805
+ model.save_pretrained(training_args.output_dir, params=params)
806
+ tokenizer.save_pretrained(training_args.output_dir)
807
+ if training_args.push_to_hub:
808
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
809
+
810
+ # Eval after training
811
+ if training_args.do_eval:
812
+ num_eval_samples = len(tokenized_datasets["validation"])
813
+ eval_samples_idx = jnp.arange(num_eval_samples)
814
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
815
+
816
+ eval_metrics = []
817
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
818
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
819
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
820
+
821
+ # Model forward
822
+ model_inputs = shard(model_inputs.data)
823
+ metrics = p_eval_step(state.params, model_inputs)
824
+ eval_metrics.append(metrics)
825
+
826
+ # normalize eval metrics
827
+ eval_metrics = get_metrics(eval_metrics)
828
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
829
+ eval_normalizer = eval_metrics.pop("normalizer")
830
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
831
+
832
+ try:
833
+ perplexity = math.exp(eval_metrics["loss"])
834
+ except OverflowError:
835
+ perplexity = float("inf")
836
+ eval_metrics["perplexity"] = perplexity
837
+
838
+ if jax.process_index() == 0:
839
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
840
+ path = os.path.join(training_args.output_dir, "eval_results.json")
841
+ with open(path, "w") as f:
842
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
843
+
844
+
845
+ if __name__ == "__main__":
846
+ main()
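Side note on FlaxDataCollatorForLanguageModeling.mask_tokens above: the 80/10/10 split can be checked in isolation with plain NumPy. In this sketch the vocabulary size and mask-token id are made-up stand-ins for the tokenizer's real values:

import numpy as np

mlm_probability = 0.15
vocab_size = 50265        # stand-in for tokenizer.vocab_size
mask_token_id = 4         # stand-in for the id of tokenizer.mask_token

inputs = np.random.randint(10, vocab_size, size=(2, 8), dtype="i4")
special_tokens_mask = np.zeros_like(inputs, dtype=bool)
labels = inputs.copy()

# Sample masked positions with probability mlm_probability, never on special tokens.
probability_matrix = np.full(labels.shape, mlm_probability)
probability_matrix[special_tokens_mask] = 0.0
masked_indices = np.random.binomial(1, probability_matrix).astype(bool)
labels[~masked_indices] = -100  # loss is only computed on masked positions

# 80% of the masked positions become the mask token...
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype(bool) & masked_indices
inputs[indices_replaced] = mask_token_id

# ...half of the remaining 20% (10% overall) become a random token; the rest stay unchanged.
indices_random = (np.random.binomial(1, np.full(labels.shape, 0.5)).astype(bool)
                  & masked_indices & ~indices_replaced)
inputs[indices_random] = np.random.randint(vocab_size, size=labels.shape, dtype="i4")[indices_random]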
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "trim_offsets": true, "special_tokens_map_file": null, "name_or_path": "NbAiLab/nb-roberta-base", "tokenizer_class": "RobertaTokenizer"}
train.128.sh ADDED
@@ -0,0 +1,26 @@
1
+ python run_mlm_flax.py \
2
+ --output_dir="./" \
3
+ --model_type="roberta" \
4
+ --config_name="roberta-base" \
5
+ --tokenizer_name="NbAiLab/nb-roberta-base" \
6
+ --dataset_name="NbAiLab/NCC" \
7
+ --max_seq_length="128" \
8
+ --weight_decay="0.0" \
9
+ --per_device_train_batch_size="232" \
10
+ --per_device_eval_batch_size="232" \
11
+ --pad_to_max_length \
12
+ --learning_rate="0.0003" \
13
+ --warmup_steps="10000" \
14
+ --overwrite_output_dir \
15
+ --num_train_epochs="3" \
16
+ --distributed_shampoo \
17
+ --adam_beta1="0.9" \
18
+ --adam_beta2="0.99" \
19
+ --adam_epsilon="1e-10" \
20
+ --logging_steps="1000" \
21
+ --save_steps="1000" \
22
+ --eval_steps="1000" \
23
+ --do_train \
24
+ --do_eval \
25
+ --dtype="bfloat16" \
26
+ --push_to_hub
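For reference, the flags above combine with the device count exactly as in run_mlm_flax.py (train_batch_size = per_device_train_batch_size * jax.device_count()); the 8 devices below are only an example for a single v3-8 host:

import jax

per_device_train_batch_size = 232            # from --per_device_train_batch_size above
train_batch_size = per_device_train_batch_size * jax.device_count()
# With 8 devices this gives 232 * 8 = 1856 sequences of length 128 per optimizer step.
print(train_batch_size)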
vocab.json ADDED
The diff for this file is too large to render. See raw diff