aapot commited on
Commit
cd57f41
1 Parent(s): 0590843

Saving weights and logs of step 10000

Browse files
added_tokens.json ADDED
@@ -0,0 +1 @@
 
1
+ {"<|endoftext|>": 50257}
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "bos_token_id": 50256,
9
+ "embd_pdrop": 0.1,
10
+ "eos_token_id": 50256,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 1024,
15
+ "n_embd": 1024,
16
+ "n_head": 16,
17
+ "n_inner": null,
18
+ "n_layer": 24,
19
+ "n_positions": 1024,
20
+ "n_special": 0,
21
+ "predict_special_tokens": true,
22
+ "reorder_and_upcast_attn": false,
23
+ "resid_pdrop": 0.1,
24
+ "scale_attn_by_inverse_layer_idx": false,
25
+ "scale_attn_weights": true,
26
+ "summary_activation": null,
27
+ "summary_first_dropout": 0.1,
28
+ "summary_proj_to_labels": true,
29
+ "summary_type": "cls_index",
30
+ "summary_use_proj": true,
31
+ "task_specific_params": {
32
+ "text-generation": {
33
+ "do_sample": true,
34
+ "max_length": 50
35
+ }
36
+ },
37
+ "transformers_version": "4.16.0.dev0",
38
+ "use_cache": true,
39
+ "vocab_size": 50257
40
+ }
distributed_shampoo.py ADDED
@@ -0,0 +1,1801 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from https://github.com/google-research/google-research/blob/master/scalable_shampoo/optax/distributed_shampoo.py
2
+
3
+ # coding=utf-8
4
+ # Copyright 2022 The Google Research Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ # An implementation of distributed Shampoo optimizer from:
19
+ #
20
+ # Scalable Second Order Optimization for Deep Learning
21
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
22
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
23
+ #
24
+ # This implementation moves computation of inverse pth root back to the
25
+ # accelerator (if higher precision is available).
26
+ #
27
+ # Authors: Rohan Anil (rohananil at google dot com)
28
+ # & Vineet Gupta (vineet at google dot com)
29
+ #
30
+
31
+ """Distributed Shampoo Implementation."""
32
+
33
+ import enum
34
+ import functools
35
+ import itertools
36
+ from typing import Any, List, NamedTuple
37
+
38
+ import chex
39
+ from flax import struct
40
+ import jax
41
+ from jax import lax
42
+ import jax.experimental.pjit as pjit
43
+ import jax.numpy as jnp
44
+ import numpy as np
45
+ import optax
46
+
47
+ PartitionSpec = pjit.PartitionSpec
48
+
49
+
50
+ # pylint:disable=no-value-for-parameter
51
+ @struct.dataclass
52
+ class QuantizedValue:
53
+ """State associated with quantized value."""
54
+ quantized: chex.Array
55
+ diagonal: chex.Array # Diagonal (if extract_diagonal is set)
56
+ bucket_size: chex.Array
57
+ quantized_dtype: jnp.dtype = struct.field(
58
+ pytree_node=False) # Dtype for the quantized value.
59
+ extract_diagonal: bool = struct.field(
60
+ pytree_node=False) # In case its centered.
61
+ shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
62
+
63
+ @classmethod
64
+ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
65
+ if isinstance(fvalue, list) and not fvalue:
66
+ return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
67
+ quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
68
+ fvalue, quantized_dtype, extract_diagonal)
69
+ return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
70
+ quantized_dtype, extract_diagonal,
71
+ list(quantized.shape))
72
+
73
+ # Quantization is from Lingvo JAX optimizers.
74
+ # We extend it for int16 quantization of PSD matrices.
75
+ @classmethod
76
+ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
77
+ """Returns quantized value and the bucket."""
78
+ if quantized_dtype == jnp.float32:
79
+ return fvalue, [], []
80
+ elif quantized_dtype == jnp.bfloat16:
81
+ return fvalue.astype(jnp.bfloat16), [], []
82
+
83
+ float_dtype = fvalue.dtype
84
+ if quantized_dtype == jnp.int8:
85
+ # value -128 is not used.
86
+ num_buckets = jnp.array(127.0, dtype=float_dtype)
87
+ elif quantized_dtype == jnp.int16:
88
+ # value -32768 is not used.
89
+ num_buckets = jnp.array(32767.0, dtype=float_dtype)
90
+ else:
91
+ raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
92
+ # max value is mapped to num_buckets
93
+
94
+ if extract_diagonal and fvalue.ndim != 2:
95
+ raise ValueError(
96
+ f'Input array {fvalue} must be 2D to work with extract_diagonal.')
97
+
98
+ diagonal_fvalue = []
99
+ if extract_diagonal:
100
+ diagonal_fvalue = jnp.diag(fvalue)
101
+ # Remove the diagonal entries.
102
+ fvalue = fvalue - jnp.diag(diagonal_fvalue)
103
+
104
+ # TODO(rohananil): Extend this by making use of information about the blocks
105
+ # SM3 style which will be useful for diagonal statistics
106
+ # We first decide the scale.
107
+ if fvalue.ndim < 1:
108
+ raise ValueError(
109
+ f'Input array {fvalue} must have a strictly positive number of '
110
+ 'dimensions.')
111
+
112
+ max_abs = jnp.max(jnp.abs(fvalue), axis=0)
113
+ bucket_size = max_abs / num_buckets
114
+ bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
115
+ # To avoid divide by 0.0
116
+ bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
117
+ jnp.ones_like(bs_expanded))
118
+ ratio = fvalue / bs_nonzero
119
+ # We use rounding to remove bias.
120
+ quantized = jnp.round(ratio)
121
+ return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
122
+
123
+ def to_float(self):
124
+ """Returns the float value."""
125
+ if isinstance(self.quantized, list) and not self.quantized:
126
+ return self.quantized
127
+
128
+ if self.quantized_dtype == jnp.float32:
129
+ return self.quantized
130
+
131
+ if self.quantized_dtype == jnp.bfloat16:
132
+ return self.quantized.astype(jnp.float32)
133
+
134
+ float_dtype = self.bucket_size.dtype
135
+ bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
136
+ val = self.quantized.astype(float_dtype) * bucket_size
137
+ if self.extract_diagonal:
138
+ val += jnp.diag(self.diagonal)
139
+ return val
140
+
141
+
142
+ # Per parameter optimizer state used in data-parallel training.
143
+ class ParameterStats(NamedTuple):
144
+ """State associated to each parameter of the model being trained."""
145
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
146
+ statistics: List[Any] # Statistics (QuantizedValue, chex.Array)
147
+ preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
148
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
149
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
150
+
151
+
152
+ # For training extremely large model; We keep a global state with a concatenated
153
+ # statistics and preconditioner states for all vars. This is so that we can
154
+ # annotate the leading axis to be sharded to save memory at the cost of
155
+ # communication.
156
+ @struct.dataclass
157
+ class GlobalShardedParameterStats:
158
+ statistics: chex.Array # Statistics
159
+ preconditioners: chex.Array # Preconditioners
160
+ exponents: chex.Array # exponents
161
+
162
+
163
+ # These are per-parameter local states; All statistics here mirror the parameter
164
+ # Thus the sharding is copied over from the param specification.
165
+ @struct.dataclass
166
+ class LocalShardedParameterStats:
167
+ """State associated to each parameter of the model being trained."""
168
+ diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
169
+ diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
170
+ momentum: QuantizedValue # Momentum for the shampoo preconditioner
171
+ index_start: np.int32 = struct.field(
172
+ pytree_node=False) # Index into global statistics array
173
+ sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
174
+
175
+
176
+ class ShardedShampooStats(NamedTuple):
177
+ """Shampoo state in sharded mode."""
178
+ global_stats: Any
179
+ local_stats: Any
180
+
181
+
182
+ class ShampooState(NamedTuple):
183
+ count: chex.Array
184
+ stats: Any
185
+
186
+
187
+ class InitFnState(NamedTuple):
188
+ init_fn: Any
189
+ pspec_fn: Any
190
+ shape_and_dtype_fn: Any
191
+
192
+
193
+ class GraftingType(enum.IntEnum):
194
+ SGD = 1
195
+ ADAGRAD = 2
196
+ RMSPROP = 3
197
+ RMSPROP_NORMALIZED = 4
198
+
199
+
200
+ def power_iteration(
201
+ matrix,
202
+ num_iters=100,
203
+ error_tolerance=1e-6,
204
+ precision=lax.Precision.HIGHEST):
205
+ r"""Power iteration algorithm.
206
+
207
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
208
+ a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
209
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
210
+
211
+ References:
212
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
213
+
214
+ Args:
215
+ matrix: the symmetric PSD matrix.
216
+ num_iters: Number of iterations.
217
+ error_tolerance: Iterative exit condition.
218
+ precision: precision XLA related flag, the available options are:
219
+ a) lax.Precision.DEFAULT (better step time, but not precise)
220
+ b) lax.Precision.HIGH (increased precision, slower)
221
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
222
+
223
+ Returns:
224
+ eigen vector, eigen value
225
+ """
226
+ matrix_size = matrix.shape[-1]
227
+ def _iter_condition(state):
228
+ i, unused_v, unused_s, unused_s_v, run_step = state
229
+ return jnp.logical_and(i < num_iters, run_step)
230
+
231
+ def _iter_body(state):
232
+ """One step of power iteration."""
233
+ i, new_v, s, s_v, unused_run_step = state
234
+ new_v = new_v / jnp.linalg.norm(new_v)
235
+
236
+ s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
237
+ s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
238
+ return (i + 1, s_v, s_new, s_v,
239
+ jnp.greater(jnp.abs(s_new - s), error_tolerance))
240
+
241
+ # Figure out how to use step as seed for random.
242
+ v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
243
+ matrix_size).astype(matrix.dtype)
244
+
245
+ init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
246
+ _, v_out, s_out, _, _ = lax.while_loop(
247
+ _iter_condition, _iter_body, init_state)
248
+ v_out = v_out / jnp.linalg.norm(v_out)
249
+ return v_out, s_out
250
+
251
+
252
+ def matrix_inverse_pth_root(
253
+ matrix,
254
+ p,
255
+ num_iters=100,
256
+ ridge_epsilon=1e-6,
257
+ error_tolerance=1e-6,
258
+ precision=lax.Precision.HIGHEST):
259
+ """Computes `matrix^(-1/p)`, where `p` is a positive integer.
260
+
261
+ This function uses the Coupled newton iterations algorithm for
262
+ the computation of a matrix's inverse pth root.
263
+
264
+
265
+ References:
266
+ [Functions of Matrices, Theory and Computation,
267
+ Nicholas J Higham, Pg 184, Eq 7.18](
268
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
269
+
270
+ Args:
271
+ matrix: the symmetric PSD matrix whose power it to be computed
272
+ p: exponent, for p a positive integer.
273
+ num_iters: Maximum number of iterations.
274
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
275
+ error_tolerance: Error indicator, useful for early termination.
276
+ precision: precision XLA related flag, the available options are:
277
+ a) lax.Precision.DEFAULT (better step time, but not precise)
278
+ b) lax.Precision.HIGH (increased precision, slower)
279
+ c) lax.Precision.HIGHEST (best possible precision, slowest)
280
+
281
+ Returns:
282
+ matrix^(-1/p)
283
+ """
284
+
285
+ assert matrix.shape[0] == matrix.shape[1]
286
+
287
+ # We use float32 for the matrix inverse pth root.
288
+ # Switch to f64 if you have hardware that supports it.
289
+ matrix_size = matrix.shape[0]
290
+ alpha = jnp.asarray(-1.0 / p, jnp.float32)
291
+ identity = jnp.eye(matrix_size, dtype=jnp.float32)
292
+ _, max_ev = power_iteration(
293
+ matrix=matrix, num_iters=100,
294
+ error_tolerance=1e-6, precision=precision)
295
+ ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
296
+
297
+ def _unrolled_mat_pow_1(mat_m):
298
+ """Computes mat_m^1."""
299
+ return mat_m
300
+
301
+ def _unrolled_mat_pow_2(mat_m):
302
+ """Computes mat_m^2."""
303
+ return jnp.matmul(mat_m, mat_m, precision=precision)
304
+
305
+ def _unrolled_mat_pow_4(mat_m):
306
+ """Computes mat_m^4."""
307
+ mat_pow_2 = _unrolled_mat_pow_2(mat_m)
308
+ return jnp.matmul(
309
+ mat_pow_2, mat_pow_2, precision=precision)
310
+
311
+ def _unrolled_mat_pow_8(mat_m):
312
+ """Computes mat_m^4."""
313
+ mat_pow_4 = _unrolled_mat_pow_4(mat_m)
314
+ return jnp.matmul(
315
+ mat_pow_4, mat_pow_4, precision=precision)
316
+
317
+ def mat_power(mat_m, p):
318
+ """Computes mat_m^p, for p == 1, 2, 4 or 8.
319
+
320
+ Args:
321
+ mat_m: a square matrix
322
+ p: a positive integer
323
+
324
+ Returns:
325
+ mat_m^p
326
+ """
327
+ # We unrolled the loop for performance reasons.
328
+ exponent = jnp.round(jnp.log2(p))
329
+ return lax.switch(
330
+ jnp.asarray(exponent, jnp.int32), [
331
+ _unrolled_mat_pow_1,
332
+ _unrolled_mat_pow_2,
333
+ _unrolled_mat_pow_4,
334
+ _unrolled_mat_pow_8,
335
+ ], (mat_m))
336
+
337
+ def _iter_condition(state):
338
+ (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
339
+ run_step) = state
340
+ error_above_threshold = jnp.logical_and(
341
+ error > error_tolerance, run_step)
342
+ return jnp.logical_and(i < num_iters, error_above_threshold)
343
+
344
+ def _iter_body(state):
345
+ (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
346
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
347
+ new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
348
+ new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
349
+ new_error = jnp.max(jnp.abs(new_mat_m - identity))
350
+ # sometimes error increases after an iteration before decreasing and
351
+ # converging. 1.2 factor is used to bound the maximal allowed increase.
352
+ return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
353
+ new_error < error * 1.2)
354
+
355
+ if matrix_size == 1:
356
+ resultant_mat_h = (matrix + ridge_epsilon)**alpha
357
+ error = 0
358
+ else:
359
+ damped_matrix = matrix + ridge_epsilon * identity
360
+
361
+ z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
362
+ new_mat_m_0 = damped_matrix * z
363
+ new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
364
+ new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
365
+ init_state = tuple(
366
+ [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
367
+ _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
368
+ _iter_condition, _iter_body, init_state)
369
+ error = jnp.max(jnp.abs(mat_m - identity))
370
+ is_converged = jnp.asarray(convergence, old_mat_h.dtype)
371
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
372
+ resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
373
+ return resultant_mat_h, error
374
+
375
+
376
+ def merge_small_dims(shape_to_merge, max_dim):
377
+ """Merge small dimensions.
378
+
379
+ If there are some small dimensions, we collapse them:
380
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
381
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
382
+
383
+ Args:
384
+ shape_to_merge: Shape to merge small dimensions.
385
+ max_dim: Maximal dimension of output shape used in merging.
386
+
387
+ Returns:
388
+ Merged shape.
389
+ """
390
+ resulting_shape = []
391
+ product = 1
392
+ for d in shape_to_merge:
393
+ if product * d <= max_dim:
394
+ product *= d
395
+ else:
396
+ if product > 1:
397
+ resulting_shape.append(product)
398
+ product = d
399
+ if product > 1:
400
+ resulting_shape.append(product)
401
+ return resulting_shape
402
+
403
+
404
+ def pad_matrix(mat, max_size):
405
+ """Pad a matrix to a max_size.
406
+
407
+ Args:
408
+ mat: a matrix to pad.
409
+ max_size: matrix size requested.
410
+
411
+ Returns:
412
+ Given M returns [[M, 0], [0, I]]
413
+ """
414
+ size = mat.shape[0]
415
+ assert size <= max_size
416
+ if size == max_size:
417
+ return mat
418
+ pad_size = max_size - size
419
+ zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
420
+ zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
421
+ eye = jnp.eye(pad_size, dtype=mat.dtype)
422
+ mat = jnp.concatenate([mat, zs1], 1)
423
+ mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
424
+ return mat
425
+
426
+
427
+ def pad_vector(vec, max_size):
428
+ """Pad a vector to a max_size.
429
+
430
+ Args:
431
+ vec: a vector to pad.
432
+ max_size: matrix size requested.
433
+
434
+ Returns:
435
+ Given V returns [V, 0]
436
+ """
437
+ size = vec.shape[0]
438
+ assert size <= max_size
439
+ if size == max_size:
440
+ return vec
441
+ pad_size = max_size - size
442
+ zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
443
+ return jnp.concatenate([vec, zs1], 0)
444
+
445
+
446
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
447
+ """Avoids wasteful buffer allocation with XLA."""
448
+
449
+ def _iter_body(unused_state):
450
+ results = compute_fn(*args, **kwargs)
451
+ return tuple([False] + list(results))
452
+
453
+ def _iter_condition(state):
454
+ return state[0]
455
+
456
+ results = jax.lax.while_loop(_iter_condition, _iter_body,
457
+ tuple([predicate] + init_state))
458
+ return tuple(results[1:])
459
+
460
+
461
+ class BlockPartitioner:
462
+ """Partitions a tensor into smaller tensors."""
463
+
464
+ def __init__(self, param, block_size):
465
+ self._shape = param.shape
466
+ self._splits = []
467
+ split_sizes = []
468
+ # We split params into smaller blocks. Here we store the metadata to make
469
+ # that split.
470
+ for i, d in enumerate(param.shape):
471
+ if 0 < block_size < d:
472
+ # d-1, otherwise split appends a 0-size array.
473
+ nsplit = (d - 1) // block_size
474
+ indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
475
+ sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
476
+ sizes[-1] = d - indices[-1]
477
+ self._splits.append((i, indices))
478
+ split_sizes.append(sizes)
479
+ else:
480
+ split_sizes.append(np.array([d], dtype=np.int32))
481
+ self._num_splits = len(split_sizes)
482
+ self._preconditioner_shapes = []
483
+ for t in itertools.product(*split_sizes):
484
+ self._preconditioner_shapes.extend([[d, d] for d in t])
485
+
486
+ def shapes_for_preconditioners(self):
487
+ return self._preconditioner_shapes
488
+
489
+ def num_splits(self):
490
+ return self._num_splits
491
+
492
+ def partition(self, tensor):
493
+ """Partition tensor into blocks."""
494
+
495
+ assert tensor.shape == self._shape
496
+ tensors = [tensor]
497
+ for (i, indices) in self._splits:
498
+ tensors_local = []
499
+ for t in tensors:
500
+ tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
501
+ tensors = tensors_local
502
+ return tensors
503
+
504
+ def merge_partitions(self, partitions):
505
+ """Merge partitions back to original shape."""
506
+
507
+ for (i, indices) in reversed(self._splits):
508
+ n = len(indices) + 1
509
+ partial_merged_tensors = []
510
+ ind = 0
511
+ while ind < len(partitions):
512
+ partial_merged_tensors.append(
513
+ jnp.concatenate(partitions[ind:ind + n], axis=i))
514
+ ind += n
515
+ partitions = partial_merged_tensors
516
+ assert len(partitions) == 1
517
+ return partitions[0]
518
+
519
+
520
+ class Preconditioner:
521
+ """Compute statistics/shape from gradients for preconditioning."""
522
+
523
+ def __init__(self, param, block_size, best_effort_shape_interpretation):
524
+ self._original_shape = param.shape
525
+ self._transformed_shape = param.shape
526
+ if best_effort_shape_interpretation:
527
+ self._transformed_shape = merge_small_dims(self._original_shape,
528
+ block_size)
529
+ reshaped_param = jnp.reshape(param, self._transformed_shape)
530
+ self._partitioner = BlockPartitioner(reshaped_param, block_size)
531
+
532
+ def statistics_from_grad(self, grad):
533
+ """Compute statistics from gradients.
534
+
535
+ Args:
536
+ grad: Gradient to compute statistics from.
537
+
538
+ Returns:
539
+ A list of gradient statistics for each partition.
540
+ """
541
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
542
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
543
+ stats = []
544
+ for g in partitioned_grads:
545
+ g_stats = []
546
+ rank = len(g.shape)
547
+ for i in range(rank):
548
+ axes = list(range(i)) + list(range(i + 1, rank))
549
+ stat = jnp.tensordot(g, g, axes=(axes, axes))
550
+ g_stats.append(stat)
551
+ stats.extend(g_stats)
552
+ return stats
553
+
554
+ def shapes_for_preconditioners(self):
555
+ """Returns shape from statistics."""
556
+ return self._partitioner.shapes_for_preconditioners()
557
+
558
+ def exponent_for_preconditioner(self):
559
+ """Returns exponent to use for inverse-pth root M^{-1/p}."""
560
+ return 2 * len(self._transformed_shape)
561
+
562
+ def preconditioned_grad(self, grad, preconditioners):
563
+ """Precondition the gradient.
564
+
565
+ Args:
566
+ grad: A gradient tensor to precondition.
567
+ preconditioners: A list of preconditioners to apply.
568
+
569
+ Returns:
570
+ A preconditioned gradient.
571
+ """
572
+
573
+ reshaped_grad = jnp.reshape(grad, self._transformed_shape)
574
+ partitioned_grads = self._partitioner.partition(reshaped_grad)
575
+ preconditioned_partitioned_grads = []
576
+ num_splits = self._partitioner.num_splits()
577
+ for i, g in enumerate(partitioned_grads):
578
+ preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
579
+ num_splits]
580
+ rank = len(g.shape)
581
+ precond_g = g
582
+ for j in range(rank):
583
+ precond_g = jnp.tensordot(
584
+ precond_g, preconditioners_for_grad[j], axes=[[0], [0]])
585
+ preconditioned_partitioned_grads.append(precond_g)
586
+ merged_grad = self._partitioner.merge_partitions(
587
+ preconditioned_partitioned_grads)
588
+ return jnp.reshape(merged_grad, self._original_shape)
589
+
590
+
591
+ def _convert_to_parameter_stats(global_stats, local_stat):
592
+ """Creates parameter stats from sharded stats."""
593
+ index_start = int(local_stat.index_start)
594
+ index_end = int(len(local_stat.sizes)) + index_start
595
+ statistics = global_stats.statistics[index_start:index_end, :, :]
596
+ preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
597
+ new_statistics = []
598
+ new_preconditioners = []
599
+ for i, size in enumerate(local_stat.sizes):
600
+ new_statistics.append(statistics[i][:size, :size])
601
+ new_preconditioners.append(preconditioners[i][:size, :size])
602
+ return ParameterStats(local_stat.diagonal_statistics, new_statistics,
603
+ new_preconditioners, local_stat.diagonal_momentum,
604
+ local_stat.momentum)
605
+
606
+
607
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
608
+ """Creates sharded stats from paramter stats."""
609
+ return LocalShardedParameterStats(parameter_stats.diagonal_statistics,
610
+ parameter_stats.diagonal_momentum,
611
+ parameter_stats.momentum,
612
+ local_stats.index_start, local_stats.sizes)
613
+
614
+
615
+ def batch(x, num_devices):
616
+ """Batch `x` so that so that leading axis is num_devices."""
617
+ n = len(x)
618
+ b = int(n / num_devices)
619
+ return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
620
+
621
+
622
+ def unbatch(batched_values):
623
+ """Unbatch values across leading axis and return a list of elements."""
624
+ b1, b2 = batched_values.shape[0], batched_values.shape[1]
625
+ results = []
626
+ for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
627
+ v_array = jnp.squeeze(v_array)
628
+ # b2 = batches (number of preconditioner computation) per core.
629
+ if b2 > 1:
630
+ for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
631
+ results.append(jnp.squeeze(v))
632
+ else:
633
+ results.append(v_array)
634
+ return results
635
+
636
+
637
+ def distributed_shampoo(
638
+ learning_rate,
639
+ block_size,
640
+ beta1=0.9,
641
+ beta2=0.999,
642
+ diagonal_epsilon=1e-10,
643
+ matrix_epsilon=1e-6,
644
+ weight_decay=0.0,
645
+ start_preconditioning_step=5,
646
+ preconditioning_compute_steps=1,
647
+ statistics_compute_steps=1,
648
+ best_effort_shape_interpretation=True,
649
+ graft_type=GraftingType.SGD,
650
+ nesterov=True,
651
+ exponent_override=0,
652
+ # Pass pmap 'batch axis name' in pmap mode.
653
+ batch_axis_name=None,
654
+ ### Only set following 3 params in pjit/spmd mode.
655
+ ### WARNING: Experimental
656
+ statistics_partition_spec=None,
657
+ preconditioner_partition_spec=None,
658
+ num_devices_for_pjit=None,
659
+ shard_optimizer_states=False,
660
+ ###
661
+ ### Experimental memory reduction mode
662
+ best_effort_memory_usage_reduction=False,
663
+ ###
664
+ inverse_failure_threshold=0.1,
665
+ moving_average_for_momentum=False,
666
+ skip_preconditioning_dim_size_gt=4096,
667
+ clip_by_scaled_gradient_norm=None,
668
+ precision=lax.Precision.HIGHEST):
669
+ """Distributed Shampoo optimizer.
670
+
671
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
672
+ variant of full-matrix Adagrad), that provides significant convergence and
673
+ wall-clock time improvements compared to conventional first-order methods,
674
+ and that has been shown to scale to large state-of-the-art deep learning
675
+ models.
676
+
677
+ References:
678
+ Scalable Second Order Optimization for Deep Learning,
679
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
680
+
681
+ Preprint: https://arxiv.org/abs/2002.09018
682
+
683
+ Args:
684
+ learning_rate: the step size used to update the parameters.
685
+ block_size: Block size for large layers (if > 0). Preconditioning compute
686
+ operation is cubic in the dimension of the tensor. Block size allows us to
687
+ chunk the layers into sub-layers of maximal dimension dictated by this
688
+ value. Use 128 as default (increase if you have compute budget).
689
+ beta1: momentum parameter.
690
+ beta2: second moment averaging parameter.
691
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
692
+ to AdaGrad is enabled).
693
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
694
+ root. If you are running in f32 precision for inverse pth root
695
+ (recommended today) this can go upto 1e-6. If you have latest hardware
696
+ with native f64 precision, set this upto 1e-12.
697
+ weight_decay: Weight decay for regularization.
698
+ start_preconditioning_step: When to start Shampoo update before which
699
+ diagonal update is used. This is because we dont have enough information
700
+ to do stable inverse.
701
+ preconditioning_compute_steps: How often to compute preconditioner.
702
+ Performance tuning params for controlling memory and compute requirements.
703
+ Ideally set this and statistics_compute_steps params to 1.
704
+ statistics_compute_steps: How often to compute statistics.
705
+ best_effort_shape_interpretation: If there are some small dimensions,
706
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
707
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
708
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
709
+ optimizer. This allows us to plugin the Shampoo optimizer into settings
710
+ where SGD/AdaGrad is already well tuned. Available options are:
711
+ GraftingType.SGD and GraftingType.ADAGRAD.
712
+ nesterov: Nesterov momentum.
713
+ exponent_override: Override the exponent used in matrix inverse.
714
+ batch_axis_name: labeled axis over pmap for data-parallel training the
715
+ optimizer used for.
716
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
717
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
718
+ num_devices_for_pjit: Number of devices to parallelize over when using pjit.
719
+ shard_optimizer_states: Shard optimizer states to save memory in model
720
+ parallel training.
721
+ best_effort_memory_usage_reduction: Best effort memory usage reduction.
722
+ diagonal_statistics -> jnp.bfloat16
723
+ momentum buffers (2x) -> jnp.int8
724
+ statistics, preconditioners -> jnp.int16 + diagonals
725
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
726
+ determine that using this threshold.
727
+ moving_average_for_momentum: Whether to use moving average for momentum
728
+ instead of exponential moving average.
729
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
730
+ greater than this value.
731
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
732
+ when using RMSProp Grafting).
733
+ precision: precision XLA related flag, the available options are: a)
734
+ lax.Precision.DEFAULT (better step time, but not precise) b)
735
+ lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
736
+ (best possible precision, slowest)
737
+
738
+ Returns:
739
+ a GradientTransformation.
740
+ """
741
+
742
+ def quantized_dtype_for_momentum_buffers():
743
+ return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
744
+
745
+ # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
746
+ def quantized_dtype_for_diagonal_statistics_buffers():
747
+ return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
748
+
749
+ # Preconditioner and statistics are both stores as int16 in this mode.
750
+ # We take out the diagonal to make quantization easier.
751
+ def quantized_dtype_for_second_moment_statistics_buffers():
752
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
753
+
754
+ # Preconditioner and statistics are both stores as int16 in this mode.
755
+ # We take out the diagonal to make quantization easier.
756
+ def quantized_dtype_for_second_moment_preconditioner_buffers():
757
+ return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
758
+
759
+ def _to_float(maybe_quantized):
760
+ if isinstance(maybe_quantized, QuantizedValue):
761
+ return maybe_quantized.to_float()
762
+ else:
763
+ return maybe_quantized
764
+
765
+ def _maybe_quantize_statistics(statistics_list):
766
+ return _maybe_quantize_matrices_with_dtype(
767
+ statistics_list, quantized_dtype_for_second_moment_statistics_buffers())
768
+
769
+ def _maybe_quantize_preconditioners(statistics_list):
770
+ return _maybe_quantize_matrices_with_dtype(
771
+ statistics_list,
772
+ quantized_dtype_for_second_moment_preconditioner_buffers())
773
+
774
+ def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
775
+ if quantized_dtype != jnp.float32:
776
+ return ([
777
+ QuantizedValue.from_float_value(
778
+ s, quantized_dtype, extract_diagonal=True)
779
+ for s in statistics_list
780
+ ])
781
+ else:
782
+ return statistics_list
783
+
784
+ def _maybe_dequantize_preconditioners(preconditioner_list):
785
+ return _maybe_dequantize_matrices_with_dtype(
786
+ preconditioner_list,
787
+ quantized_dtype_for_second_moment_preconditioner_buffers())
788
+
789
+ def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
790
+ if quantized_dtype != jnp.float32:
791
+ return [s.to_float() for s in statistics_list]
792
+ else:
793
+ return statistics_list
794
+
795
+ def _quantize_diagonal_statistics(diagonal_statistics):
796
+ return QuantizedValue.from_float_value(
797
+ diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers())
798
+
799
+ def _quantize_momentum(momentum_statistics):
800
+ return QuantizedValue.from_float_value(
801
+ momentum_statistics, quantized_dtype_for_momentum_buffers())
802
+
803
+ def sharded_init_fn(params):
804
+ """Returns optimizer state (for PJIT mode).
805
+
806
+ Args:
807
+ params: the parameters that should be updated.
808
+ """
809
+ params_flat, treedef = jax.tree_flatten(params)
810
+ # Find max size to pad to.
811
+ max_size = 0
812
+ for param in params_flat:
813
+ preconditioner = Preconditioner(param, block_size,
814
+ best_effort_shape_interpretation)
815
+ if not _skip_preconditioning(param):
816
+ shapes = preconditioner.shapes_for_preconditioners()
817
+ sizes = [s[0] for s in shapes]
818
+ max_size = max(max(sizes), max_size)
819
+
820
+ padded_statistics = []
821
+ padded_preconditioners = []
822
+ local_stats_flat = []
823
+ exponents = []
824
+ for param in params_flat:
825
+ preconditioner = Preconditioner(param, block_size,
826
+ best_effort_shape_interpretation)
827
+ shapes = preconditioner.shapes_for_preconditioners()
828
+ sizes = []
829
+
830
+ statistics = []
831
+ preconditioners = []
832
+ index_start = len(padded_statistics)
833
+ if not _skip_preconditioning(param):
834
+ sizes = [s[0] for s in shapes]
835
+ shapes = preconditioner.shapes_for_preconditioners()
836
+ statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
837
+ preconditioners = [jnp.eye(max_size) for s in shapes]
838
+ padded_statistics.extend(statistics)
839
+ padded_preconditioners.extend(preconditioners)
840
+ exponent = (
841
+ preconditioner.exponent_for_preconditioner()
842
+ if exponent_override == 0 else exponent_override)
843
+ exponents.extend([exponent] * len(shapes))
844
+
845
+ diagonal_statistics = []
846
+ if graft_type != GraftingType.SGD:
847
+ diagonal_statistics = jnp.zeros_like(param)
848
+ local_stats_flat.append(
849
+ LocalShardedParameterStats(
850
+ _quantize_diagonal_statistics(diagonal_statistics),
851
+ _quantize_momentum(jnp.zeros_like(param)),
852
+ _quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
853
+
854
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
855
+ # Pad the statistics and preconditioner matrices to be a multiple of
856
+ # num devices.
857
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
858
+ # is split on.
859
+ to_pad = -len(padded_statistics) % num_devices_for_pjit
860
+ padded_statistics.extend([
861
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
862
+ for _ in range(to_pad)
863
+ ])
864
+ padded_preconditioners.extend([
865
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
866
+ for _ in range(to_pad)
867
+ ])
868
+ exponents.extend([1 for _ in range(to_pad)])
869
+ global_stats = GlobalShardedParameterStats(
870
+ jnp.stack(padded_statistics), jnp.stack(padded_preconditioners),
871
+ jnp.stack(exponents))
872
+ return ShampooState(
873
+ count=jnp.zeros([], jnp.int32),
874
+ stats=ShardedShampooStats(global_stats, local_stats))
875
+
876
+ def _max_statistics_size_from_params(params):
877
+ max_size = 0
878
+ for param in params:
879
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
880
+ preconditioner = Preconditioner(param_clone, block_size,
881
+ best_effort_shape_interpretation)
882
+ if not _skip_preconditioning(param):
883
+ shapes = preconditioner.shapes_for_preconditioners()
884
+ sizes = [s[0] for s in shapes]
885
+ max_size = max(max(sizes), max_size)
886
+ return max_size
887
+
888
+ def _remove_leading_sharding_annotation(pspec):
889
+ """Mapping from N-d to (N-1)-d, used for quantization, factoring etc."""
890
+ # None and PSpec(None) are valid PSpecs.
891
+ if pspec and len(pspec) > 1:
892
+ return PartitionSpec(*pspec[1:])
893
+ else:
894
+ return pspec
895
+
896
+ def sharded_init_partition_spec_fn(params, params_partition_spec,
897
+ partition_spec_for_statistics):
898
+ """Returns a parallel state tree with PartitionSpec associated with state.
899
+
900
+
901
+ Args:
902
+ params: A pytree with params.
903
+ params_partition_spec: A pytree with PartitionSpec for params.
904
+ partition_spec_for_statistics: PartitionSpec for the statistics.
905
+ """
906
+ # Parallel lists of spec, and params.
907
+ param_pspec_flat, _ = jax.tree_flatten(params_partition_spec)
908
+ params_flat, treedef = jax.tree_flatten(params)
909
+ assert param_pspec_flat
910
+ assert params_flat
911
+ # Step is replicated across cores.
912
+ # None means cores.
913
+ local_stats_flat = []
914
+ num_statistics = 0
915
+ for param, param_pspec in zip(params_flat, param_pspec_flat):
916
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
917
+ preconditioner = Preconditioner(param_clone, block_size,
918
+ best_effort_shape_interpretation)
919
+ shapes = preconditioner.shapes_for_preconditioners()
920
+ sizes = []
921
+
922
+ index_start = num_statistics
923
+ if not _skip_preconditioning(param):
924
+ sizes = [s[0] for s in shapes]
925
+ shapes = preconditioner.shapes_for_preconditioners()
926
+ num_statistics += len(shapes)
927
+
928
+ diagonal_statistics_pspec = []
929
+ diagonal_statistics_scale_pspec = []
930
+ if graft_type != GraftingType.SGD:
931
+ # Identically shaped param.
932
+ diagonal_statistics_pspec = param_pspec
933
+ if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
934
+ diagonal_statistics_scale_pspec = _remove_leading_sharding_annotation(
935
+ param_pspec)
936
+
937
+ m1_pspec = param_pspec
938
+ m2_pspec = param_pspec
939
+
940
+ m1_scale_pspec = []
941
+ m2_scale_pspec = []
942
+
943
+ if quantized_dtype_for_momentum_buffers() != jnp.float32:
944
+ m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
945
+ m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
946
+
947
+ local_stats_flat.append(
948
+ LocalShardedParameterStats(
949
+ QuantizedValue(diagonal_statistics_pspec, [],
950
+ diagonal_statistics_scale_pspec,
951
+ quantized_dtype_for_diagonal_statistics_buffers(),
952
+ False, list(param.shape)),
953
+ QuantizedValue(m1_pspec, [], m1_scale_pspec,
954
+ quantized_dtype_for_momentum_buffers(), False,
955
+ list(param.shape)),
956
+ QuantizedValue(m2_pspec, [], m2_scale_pspec,
957
+ quantized_dtype_for_momentum_buffers(), False,
958
+ list(param.shape)), index_start, sizes))
959
+
960
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
961
+ global_stats = GlobalShardedParameterStats(partition_spec_for_statistics,
962
+ partition_spec_for_statistics,
963
+ PartitionSpec())
964
+ count_pspec = PartitionSpec()
965
+ return ShampooState(
966
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats))
967
+
968
+ def sharded_init_shape_and_dtype_fn(params):
969
+ """Returns a parallel state tree with shape, dtype associated with state.
970
+
971
+
972
+ Args:
973
+ params: A pytree with params.
974
+ """
975
+ # Parallel lists of spec, and params.
976
+ params_flat, treedef = jax.tree_flatten(params)
977
+ assert params_flat
978
+ # Step is replicated across cores.
979
+ # None means cores.
980
+ local_stats_flat = []
981
+ num_statistics = 0
982
+ for param in params_flat:
983
+ param_clone = jnp.zeros(param.shape, dtype=param.dtype)
984
+ preconditioner = Preconditioner(param_clone, block_size,
985
+ best_effort_shape_interpretation)
986
+ shapes = preconditioner.shapes_for_preconditioners()
987
+ sizes = []
988
+
989
+ index_start = num_statistics
990
+ if not _skip_preconditioning(param):
991
+ sizes = [s[0] for s in shapes]
992
+ shapes = preconditioner.shapes_for_preconditioners()
993
+ num_statistics += len(shapes)
994
+
995
+ diagonal_statistics_shape_and_dtype = []
996
+ diagonal_statistics_scale_shape_and_dtype = []
997
+ if graft_type != GraftingType.SGD:
998
+ diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
999
+ qdtype = quantized_dtype_for_diagonal_statistics_buffers()
1000
+ if qdtype != jnp.float32:
1001
+ diagonal_statistics_shape_and_dtype = [list(param.shape), qdtype]
1002
+ diagonal_statistics_scale_shape_and_dtype = [
1003
+ list(param.shape)[1:], param.dtype
1004
+ ]
1005
+
1006
+ m1_shape_and_dtype = [list(param.shape), param.dtype]
1007
+ m2_shape_and_dtype = [list(param.shape), param.dtype]
1008
+
1009
+ m1_scale_shape_and_dtype = []
1010
+ m2_scale_shape_and_dtype = []
1011
+
1012
+ qdtype = quantized_dtype_for_momentum_buffers()
1013
+ if qdtype != jnp.float32:
1014
+ m1_shape_and_dtype = [list(param.shape), qdtype]
1015
+ m2_shape_and_dtype = [list(param.shape), qdtype]
1016
+
1017
+ m1_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1018
+ m2_scale_shape_and_dtype = [list(param.shape)[1:], qdtype]
1019
+
1020
+ local_stats_flat.append(
1021
+ LocalShardedParameterStats(
1022
+ QuantizedValue(diagonal_statistics_shape_and_dtype, [],
1023
+ diagonal_statistics_scale_shape_and_dtype,
1024
+ quantized_dtype_for_diagonal_statistics_buffers(),
1025
+ False, list(param.shape)),
1026
+ QuantizedValue(m1_shape_and_dtype, [], m1_scale_shape_and_dtype,
1027
+ quantized_dtype_for_momentum_buffers(), False,
1028
+ list(param.shape)),
1029
+ QuantizedValue(m2_shape_and_dtype, [], m2_scale_shape_and_dtype,
1030
+ quantized_dtype_for_momentum_buffers(), False,
1031
+ list(param.shape)), index_start, sizes))
1032
+
1033
+ local_stats = jax.tree_unflatten(treedef, local_stats_flat)
1034
+ max_statistics_size = _max_statistics_size_from_params(params_flat)
1035
+ to_pad = -num_statistics % num_devices_for_pjit
1036
+ num_statistics += to_pad
1037
+ statistics_shape = [
1038
+ num_statistics, max_statistics_size, max_statistics_size
1039
+ ]
1040
+ global_stats = GlobalShardedParameterStats([statistics_shape, jnp.float32],
1041
+ [statistics_shape, jnp.float32],
1042
+ [[num_statistics], jnp.int32])
1043
+ return ShampooState(
1044
+ count=[[], jnp.float32],
1045
+ stats=ShardedShampooStats(global_stats, local_stats))
1046
+
1047
+ def sharded_update_fn(grads, state, params):
1048
+ """Transform the input gradient and update all statistics in sharded mode.
1049
+
1050
+ Args:
1051
+ grads: the gradient tensors for the parameters.
1052
+ state: a named tuple containing the state of the optimizer
1053
+ params: the parameters that should be updated.
1054
+
1055
+ Returns:
1056
+ A tuple containing the new parameters and the new optimizer state.
1057
+ """
1058
+ params_flat, treedef = jax.tree_flatten(params)
1059
+ grads_flat = treedef.flatten_up_to(grads)
1060
+
1061
+ global_stats = state.stats.global_stats
1062
+ local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
1063
+ stats_flat = [
1064
+ _convert_to_parameter_stats(global_stats, local_stat)
1065
+ for local_stat in local_stats_flat
1066
+ ]
1067
+ new_stats_flat = jax.tree_multimap(
1068
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
1069
+ stats_flat, params_flat)
1070
+
1071
+ outputs = jax.tree_multimap(
1072
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
1073
+ new_stats_flat, params_flat)
1074
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1075
+
1076
+ updates = jax.tree_unflatten(treedef, updates_flat)
1077
+ # Create new local_stats
1078
+ new_local_stats_flat = [
1079
+ _convert_from_parameter_stats(new_stat, local_stat)
1080
+ for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
1081
+ ]
1082
+ new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
1083
+
1084
+ max_size = global_stats.statistics.shape[1]
1085
+ new_padded_statistics = []
1086
+ for stat in new_stats_flat:
1087
+ new_padded_statistics.extend(
1088
+ [pad_matrix(stat, max_size) for stat in stat.statistics])
1089
+
1090
+ # Create global stats
1091
+ # TODO(rohananil): Preconditioner is not updated every step, so cost of
1092
+ # stack/pad can be obviated away.
1093
+ # Pad the statistics and preconditioner matrices to be a multiple of
1094
+ # num devices.
1095
+ # TODO(rohananil): Relax to only the size of the mesh axis where the dim
1096
+ # is split on.
1097
+ to_pad = -len(new_padded_statistics) % num_devices_for_pjit
1098
+ new_padded_statistics.extend([
1099
+ jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
1100
+ for _ in range(to_pad)
1101
+ ])
1102
+ new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
1103
+ new_stacked_padded_statistics = pjit.with_sharding_constraint(
1104
+ new_stacked_padded_statistics, statistics_partition_spec)
1105
+ def _internal_inverse_pth_root_all():
1106
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1107
+ new_stacked_padded_statistics, global_stats.exponents,
1108
+ statistics_partition_spec)
1109
+ return preconditioners, errors
1110
+
1111
+ if preconditioning_compute_steps == 1:
1112
+ new_preconditioners, errors = _internal_inverse_pth_root_all()
1113
+ else:
1114
+ # Passing statistics instead of preconditioners as they are similarly
1115
+ # shaped tensors. Note statistics will be ignored as we are passing in
1116
+ # a large init value for error.
1117
+ preconditioners_init = new_stacked_padded_statistics
1118
+ n = new_stacked_padded_statistics.shape[0]
1119
+ errors_init = jnp.ones([n], jnp.float32) * inverse_failure_threshold
1120
+ init_state = [preconditioners_init, errors_init]
1121
+ perform_step = state.count % preconditioning_compute_steps == 0
1122
+ new_preconditioners, errors = efficient_cond(
1123
+ perform_step, _internal_inverse_pth_root_all, init_state)
1124
+
1125
+ errors = errors.reshape((-1, 1, 1))
1126
+ predicate = jnp.logical_or(
1127
+ jnp.isnan(errors),
1128
+ errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
1129
+ # TODO(rohananil): Check for numerical instabilities.
1130
+ new_conditional_preconditioners = (
1131
+ predicate * global_stats.preconditioners +
1132
+ (1.0 - predicate) * new_preconditioners)
1133
+ new_global_stats = GlobalShardedParameterStats(
1134
+ new_stacked_padded_statistics, new_conditional_preconditioners,
1135
+ global_stats.exponents)
1136
+ new_shampoo_state = ShampooState(
1137
+ count=state.count + 1,
1138
+ stats=ShardedShampooStats(new_global_stats, new_local_stats))
1139
+ return updates, new_shampoo_state
1140
+
1141
+ def init_fn(params):
1142
+ """Initialise the optimiser's state."""
1143
+
1144
+ def _init(param):
1145
+ preconditioner = Preconditioner(param, block_size,
1146
+ best_effort_shape_interpretation)
1147
+ statistics = []
1148
+ preconditioners = []
1149
+ if not _skip_preconditioning(param):
1150
+ shapes = preconditioner.shapes_for_preconditioners()
1151
+ statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
1152
+ preconditioners = [jnp.eye(s[0]) for s in shapes]
1153
+
1154
+ diagonal_statistics = []
1155
+ if graft_type != GraftingType.SGD:
1156
+ diagonal_statistics = jnp.zeros_like(param)
1157
+ return ParameterStats(
1158
+ _quantize_diagonal_statistics(diagonal_statistics),
1159
+ _maybe_quantize_statistics(statistics),
1160
+ _maybe_quantize_preconditioners(preconditioners),
1161
+ _quantize_momentum(jnp.zeros_like(param)),
1162
+ _quantize_momentum(jnp.zeros_like(param)))
1163
+ return ShampooState(
1164
+ count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
1165
+
1166
+ def _skip_preconditioning(param):
1167
+ return len(param.shape) < 1 or any(
1168
+ [s > skip_preconditioning_dim_size_gt for s in param.shape])
1169
+
1170
+ def _compute_stats(grad, state, param, step):
1171
+ """Compute per-parameter statistics."""
1172
+ preconditioner = Preconditioner(param, block_size,
1173
+ best_effort_shape_interpretation)
1174
+ new_statistics = [[]] * len(state.statistics)
1175
+ w1 = beta2
1176
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1177
+ if not _skip_preconditioning(param):
1178
+
1179
+ def compute_updated_statistics():
1180
+ new_stats = preconditioner.statistics_from_grad(grad)
1181
+ new_stats_accumulators = []
1182
+ for stat, stat_accumulator in zip(new_stats, state.statistics):
1183
+ new_stats_accumulators.append(w1 * _to_float(stat_accumulator) +
1184
+ w2 * stat)
1185
+ return _maybe_quantize_statistics(new_stats_accumulators)
1186
+
1187
+ if statistics_compute_steps > 1:
1188
+ perform_step = step % statistics_compute_steps == 0
1189
+ init_state = state.statistics
1190
+ new_statistics = list(
1191
+ efficient_cond(perform_step, compute_updated_statistics,
1192
+ init_state))
1193
+ else:
1194
+ new_statistics = compute_updated_statistics()
1195
+ return ParameterStats(state.diagonal_statistics, new_statistics,
1196
+ state.preconditioners, state.diagonal_momentum,
1197
+ state.momentum)
1198
+
1199
+ def _matrix_inverse_pth_root_vmap(xs, ps):
1200
+ mi_pth_root = functools.partial(
1201
+ matrix_inverse_pth_root,
1202
+ ridge_epsilon=matrix_epsilon,
1203
+ precision=precision)
1204
+ return jax.vmap(mi_pth_root)(xs, ps)
1205
+
1206
+ def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
1207
+
1208
+ def _quantized_to_float(qx, qd, qb):
1209
+ qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
1210
+ return qv.to_float()
1211
+
1212
+ def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
1213
+ v = _quantized_to_float(qx, qd, qb)
1214
+ preconditioner, error = matrix_inverse_pth_root(
1215
+ v, p, ridge_epsilon=matrix_epsilon, precision=precision)
1216
+ qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
1217
+ return qp.quantized, qp.diagonal, qp.bucket_size, error
1218
+
1219
+ return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
1220
+
1221
+ def _matrix_inverse_pth_root_pjit(xs, ps, statistics_partition_spec=None):
1222
+ # Partition the concatenated statistics matrix across all cores.
1223
+ pspec_for_partition = preconditioner_partition_spec
1224
+ partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
1225
+ partitioned_ps = pjit.with_sharding_constraint(
1226
+ ps, pjit.PartitionSpec(preconditioner_partition_spec[0]))
1227
+ # Run matrix inverse pth root on each shard.
1228
+ partitioned_preconditioners, partitioned_errors = (
1229
+ _matrix_inverse_pth_root_vmap(partitioned_xs, partitioned_ps))
1230
+ # Reshard output to have the same PSpec as input. This is required to avoid
1231
+ # vmap seeing the full set of statistics.
1232
+ partitioned_preconditioners = pjit.with_sharding_constraint(
1233
+ partitioned_preconditioners, pspec_for_partition)
1234
+ # Recombine the outputs at each core.
1235
+ preconditioners = pjit.with_sharding_constraint(partitioned_preconditioners,
1236
+ statistics_partition_spec)
1237
+ errors = pjit.with_sharding_constraint(partitioned_errors,
1238
+ pjit.PartitionSpec())
1239
+ return preconditioners, errors
1240
+
1241
+ def _pmap_compute_preconditioners(states, step, statistics,
1242
+ num_statistics_per_state, original_shapes,
1243
+ exponents, max_size, prev_preconditioners):
1244
+ """Computes preconditioners for given statistics in states in PMAP mode.
1245
+
1246
+ Args:
1247
+ states: A list of optimizer states.
1248
+ step: Current step number
1249
+ statistics: A list of statistics for all variables (for every dim)
1250
+ num_statistics_per_state: Number of statistis per state to reconstruct
1251
+ output states.
1252
+ original_shapes: A list of shapes of the statistics.
1253
+ exponents: Exponent power to use for inverse-pth roots.
1254
+ max_size: Maximum dim of the statistics to pad.
1255
+ prev_preconditioners: Previously available preconditioner.
1256
+
1257
+ Returns:
1258
+ New optimizer states after computing the preconditioner.
1259
+ """
1260
+ num_devices = lax.psum(1, batch_axis_name)
1261
+ num_statistics = len(statistics)
1262
+ # Pad statistics and exponents to next multiple of num_devices.
1263
+ packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1264
+ to_pad = -num_statistics % num_devices
1265
+ packed_statistics.extend([
1266
+ jnp.eye(max_size, dtype=packed_statistics[0].dtype)
1267
+ for _ in range(to_pad)
1268
+ ])
1269
+ exponents.extend([1 for _ in range(to_pad)])
1270
+
1271
+ if not packed_statistics:
1272
+ return states
1273
+
1274
+ all_statistics = batch(packed_statistics, num_devices)
1275
+ all_exponents = batch(exponents, num_devices)
1276
+
1277
+ def _internal_inverse_pth_root_all():
1278
+ current_replica = lax.axis_index(batch_axis_name)
1279
+ preconditioners, errors = _matrix_inverse_pth_root_vmap(
1280
+ all_statistics[current_replica], all_exponents[current_replica])
1281
+ preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
1282
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1283
+ preconditioners_flat = unbatch(preconditioners)
1284
+ errors_flat = unbatch(errors)
1285
+ return preconditioners_flat, errors_flat
1286
+
1287
+ if preconditioning_compute_steps == 1:
1288
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1289
+ else:
1290
+ # Passing statistics instead of preconditioners as they are similarly
1291
+ # shaped tensors. Note statistics will be ignored as we are passing in
1292
+ # a large init value for error.
1293
+ preconditioners_init = packed_statistics
1294
+ errors_init = ([inverse_failure_threshold] * len(packed_statistics))
1295
+ init_state = [preconditioners_init, errors_init]
1296
+ perform_step = step % preconditioning_compute_steps == 0
1297
+ preconditioners_flat, errors_flat = efficient_cond(
1298
+ perform_step, _internal_inverse_pth_root_all, init_state)
1299
+
1300
+ def _skip(error):
1301
+ condition = jnp.logical_or(
1302
+ jnp.isnan(error), error >= inverse_failure_threshold)
1303
+ return condition.astype(error.dtype)
1304
+
1305
+ def _select_preconditioner(error, new_p, old_p):
1306
+ return lax.cond(
1307
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1308
+
1309
+ new_preconditioners_flat = []
1310
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1311
+ prev_preconditioners, errors_flat):
1312
+ new_preconditioners_flat.append(
1313
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1314
+
1315
+ assert len(states) == len(num_statistics_per_state)
1316
+ assert len(new_preconditioners_flat) == num_statistics
1317
+
1318
+ # Add back empty preconditioners so that we can set the optimizer state.
1319
+ preconditioners_for_states = []
1320
+ idx = 0
1321
+ for num_statistics, state in zip(num_statistics_per_state, states):
1322
+ if num_statistics == 0:
1323
+ preconditioners_for_states.append([])
1324
+ else:
1325
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1326
+ num_statistics]
1327
+ assert len(state.statistics) == len(preconditioners_for_state)
1328
+ preconditioners_for_states.append(preconditioners_for_state)
1329
+ idx += num_statistics
1330
+ new_states = []
1331
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1332
+ new_states.append(
1333
+ ParameterStats(state.diagonal_statistics, state.statistics,
1334
+ new_preconditioners, state.diagonal_momentum,
1335
+ state.momentum))
1336
+
1337
+ return new_states
1338
+
1339
+ def _pmap_quantized_compute_preconditioners(states, step, statistics,
1340
+ num_statistics_per_state,
1341
+ original_shapes, exponents,
1342
+ max_size, prev_preconditioners):
1343
+ """Computes preconditioners for given statistics in states in PMAP mode.
1344
+
1345
+ For quantization, each statistic is represented by three values:
1346
+ quantized matrix, diagonal, and bucket sizes; we run inverse pth-roots
1347
+ without ever recreating the original matrix in f32.
1348
+
1349
+ Args:
1350
+ states: A list of optimizer states.
1351
+ step: Current step number
1352
+ statistics: A list of statistics for all variables (for every dim)
1353
+ num_statistics_per_state: Number of statistics per state to reconstruct
1354
+ output states.
1355
+ original_shapes: A list of shapes of the statistics.
1356
+ exponents: Exponent power to use for inverse-pth roots.
1357
+ max_size: Maximum dim of the statistics to pad.
1358
+ prev_preconditioners: Previously available preconditioner.
1359
+
1360
+ Returns:
1361
+ New optimizer states after computing the preconditioner.
1362
+ """
1363
+ num_devices = lax.psum(1, batch_axis_name)
1364
+ num_statistics = len(statistics)
1365
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1366
+ # Complexity here comes from shapes needing to be statically shaped, and from
1367
+ # our custom quantization type requiring a different type of packing.
1368
+
1369
+ # Parallel tensors:
1370
+ # quantized [dxd]
1371
+ # diagonals [d] f32
1372
+ # bucket_sizes [d] f32
1373
+ packed_quantized_statistics = [
1374
+ pad_matrix(stat.quantized, max_size) for stat in statistics
1375
+ ]
1376
+ packed_quantized_diagonals = [
1377
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1378
+ ]
1379
+ packed_quantized_bucket_sizes = [
1380
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1381
+ ]
1382
+
1383
+ to_pad = -num_statistics % num_devices
1384
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1385
+ quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
1386
+ True)
1387
+ packed_quantized_statistics.extend(
1388
+ [quantized_eye.quantized for _ in range(to_pad)])
1389
+ packed_quantized_diagonals.extend(
1390
+ [quantized_eye.diagonal for _ in range(to_pad)])
1391
+ packed_quantized_bucket_sizes.extend(
1392
+ [quantized_eye.bucket_size for _ in range(to_pad)])
1393
+ exponents.extend([1 for _ in range(to_pad)])
1394
+
1395
+ if not packed_quantized_statistics:
1396
+ return states
1397
+
1398
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1399
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1400
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
1401
+ num_devices)
1402
+ all_exponents = batch(exponents, num_devices)
1403
+
1404
+ def _internal_inverse_pth_root_all():
1405
+ current_replica = lax.axis_index(batch_axis_name)
1406
+ (quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes,
1407
+ errors) = (
1408
+ _quantized_matrix_inverse_pth_root_vmap(
1409
+ all_quantized_statistics[current_replica],
1410
+ all_quantized_diagonals[current_replica],
1411
+ all_quantized_bucket_sizes[current_replica],
1412
+ all_exponents[current_replica]))
1413
+ quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
1414
+ batch_axis_name)
1415
+ quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
1416
+ batch_axis_name)
1417
+ quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
1418
+ batch_axis_name)
1419
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1420
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1421
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1422
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1423
+ errors_flat = unbatch(errors)
1424
+ return (quantized_preconditioners_flat, quantized_diagonals_flat,
1425
+ quantized_bucket_sizes_flat, errors_flat)
1426
+
1427
+ if preconditioning_compute_steps == 1:
1428
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1429
+ quantized_bucket_sizes_flat, errors_flat) = (
1430
+ _internal_inverse_pth_root_all())
1431
+ else:
1432
+ # Passing statistics instead of preconditioners as they are similarly
1433
+ # shaped tensors. Note statistics will be ignored as we are passing in
1434
+ # a large init value for error.
1435
+ quantized_preconditioners_init = packed_quantized_statistics
1436
+ quantized_diagonals_init = packed_quantized_diagonals
1437
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1438
+ errors_init = ([inverse_failure_threshold] *
1439
+ len(quantized_preconditioners_init))
1440
+ init_state = [
1441
+ quantized_preconditioners_init, quantized_diagonals_init,
1442
+ quantized_bucket_sizes_init, errors_init
1443
+ ]
1444
+ perform_step = step % preconditioning_compute_steps == 0
1445
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1446
+ quantized_bucket_sizes_flat, errors_flat) = (
1447
+ efficient_cond(perform_step, _internal_inverse_pth_root_all,
1448
+ init_state))
1449
+
1450
+ def _skip(error):
1451
+ condition = jnp.logical_or(
1452
+ jnp.isnan(error), error >= inverse_failure_threshold)
1453
+ return condition.astype(error.dtype)
1454
+
1455
+ def _select_preconditioner(error, new_p, old_p):
1456
+ return lax.cond(
1457
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1458
+
1459
+ new_quantized_preconditioners_flat = []
1460
+ new_quantized_diagonals_flat = []
1461
+ new_quantized_bucket_sizes_flat = []
1462
+ for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
1463
+ quantized_diagonals_flat,
1464
+ quantized_bucket_sizes_flat,
1465
+ original_shapes,
1466
+ prev_preconditioners, errors_flat):
1467
+ new_quantized_preconditioners_flat.append(
1468
+ _select_preconditioner(error, p[:shape[0], :shape[1]],
1469
+ prev_p.quantized))
1470
+ new_quantized_diagonals_flat.append(
1471
+ _select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
1472
+ new_quantized_bucket_sizes_flat.append(
1473
+ _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))
1474
+
1475
+ assert len(states) == len(num_statistics_per_state)
1476
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1477
+ assert len(new_quantized_diagonals_flat) == num_statistics
1478
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1479
+
1480
+ # Add back empty preconditioners so that we can set the optimizer state.
1481
+ preconditioners_for_states = []
1482
+ idx = 0
1483
+ for num_statistics, state in zip(num_statistics_per_state, states):
1484
+ if num_statistics == 0:
1485
+ preconditioners_for_states.append([])
1486
+ else:
1487
+ quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
1488
+ idx:idx + num_statistics]
1489
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1490
+ idx:idx + num_statistics]
1491
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1492
+ idx:idx + num_statistics]
1493
+
1494
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1495
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1496
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1497
+
1498
+ quantized_preconditioners = []
1499
+ for qv, qd, qb in zip(quantized_preconditioners_for_state,
1500
+ quantized_diagonals_for_state,
1501
+ quantized_bucket_sizes_for_state):
1502
+ quantized_preconditioners.append(
1503
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)))
1504
+ preconditioners_for_states.append(quantized_preconditioners)
1505
+ idx += num_statistics
1506
+ new_states = []
1507
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1508
+ new_states.append(
1509
+ ParameterStats(state.diagonal_statistics, state.statistics,
1510
+ new_preconditioners, state.diagonal_momentum,
1511
+ state.momentum))
1512
+
1513
+ return new_states
1514
+
1515
+ def _pjit_compute_preconditioners(states, step, statistics,
1516
+ num_statistics_per_state, original_shapes,
1517
+ exponents, max_size, prev_preconditioners):
1518
+ """Computes preconditioners for given statistics in states in PJIT mode.
1519
+
1520
+ Args:
1521
+ states: A list of optimizer states.
1522
+ step: Current step number
1523
+ statistics: A list of statistics for all variables (for every dim)
1524
+ num_statistics_per_state: Number of statistics per state to reconstruct
1525
+ output states.
1526
+ original_shapes: A list of shapes of the statistics.
1527
+ exponents: Exponent power to use for inverse-pth roots.
1528
+ max_size: Maximum dim of the statistics to pad.
1529
+ prev_preconditioners: Previously available preconditioner.
1530
+
1531
+ Returns:
1532
+ New optimizer states after computing the preconditioner.
1533
+ """
1534
+ num_statistics = len(statistics)
1535
+ to_pad = -num_statistics % num_devices_for_pjit
1536
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1537
+ padded_statistics.extend([
1538
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
1539
+ for _ in range(to_pad)
1540
+ ])
1541
+ exponents.extend([1 for _ in range(to_pad)])
1542
+ all_statistics = jnp.stack(padded_statistics)
1543
+ all_exponents = jnp.stack(exponents)
1544
+
1545
+ def _internal_inverse_pth_root_all():
1546
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1547
+ all_statistics, all_exponents)
1548
+ b1 = preconditioners.shape[0]
1549
+
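+ # Split the stacked [b1, ...] outputs back into a list of b1 per-statistic arrays.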
1550
+ def split(batched_values):
1551
+ return [
1552
+ jnp.squeeze(v)
1553
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1554
+ ]
1555
+
1556
+ return split(preconditioners), split(errors)
1557
+
1558
+ if preconditioning_compute_steps == 1:
1559
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1560
+ else:
1561
+ # Passing statistics instead of preconditioners as they are similarly
1562
+ # shaped tensors. Note statistics will be ignored as we are passing in
1563
+ # a large init value for error.
1564
+ preconditioners_init = padded_statistics
1565
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1566
+ init_state = [preconditioners_init, errors_init]
1567
+ perform_step = step % preconditioning_compute_steps == 0
1568
+ preconditioners_flat, errors_flat = efficient_cond(
1569
+ perform_step, _internal_inverse_pth_root_all, init_state)
1570
+
1571
+ def _skip(error):
1572
+ condition = jnp.logical_or(
1573
+ jnp.isnan(error), error >= inverse_failure_threshold)
1574
+ return condition.astype(error.dtype)
1575
+
1576
+ def _select_preconditioner(error, new_p, old_p):
1577
+ return lax.cond(
1578
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1579
+
1580
+ new_preconditioners_flat = []
1581
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1582
+ prev_preconditioners, errors_flat):
1583
+ new_preconditioners_flat.append(
1584
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1585
+
1586
+ assert len(states) == len(num_statistics_per_state)
1587
+ assert len(new_preconditioners_flat) == num_statistics
1588
+
1589
+ # Add back empty preconditioners so that we can set the optimizer state.
1590
+ preconditioners_for_states = []
1591
+ idx = 0
1592
+ for num_statistics, state in zip(num_statistics_per_state, states):
1593
+ if num_statistics == 0:
1594
+ preconditioners_for_states.append([])
1595
+ else:
1596
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1597
+ num_statistics]
1598
+ assert len(state.statistics) == len(preconditioners_for_state)
1599
+ preconditioners_for_states.append(preconditioners_for_state)
1600
+ idx += num_statistics
1601
+ new_states = []
1602
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1603
+ new_states.append(
1604
+ ParameterStats(state.diagonal_statistics, state.statistics,
1605
+ new_preconditioners, state.diagonal_momentum,
1606
+ state.momentum))
1607
+
1608
+ return new_states
1609
+
1610
+ def _compute_preconditioners(states, params, step):
1611
+ """Computes preconditioners for given statistics in states.
1612
+
1613
+ Args:
1614
+ states: A list of optimizer states.
1615
+ params: A list of params.
1616
+ step: Current step number
1617
+
1618
+ Returns:
1619
+ New optimizer states after computing the preconditioner.
1620
+ """
1621
+ statistics = []
1622
+ num_statistics_per_state = []
1623
+ original_shapes = []
1624
+ exponents = []
1625
+ max_size = 0
1626
+ prev_preconditioners = []
1627
+
1628
+ for state, param in zip(states, params):
1629
+ num_statistics = len(state.statistics)
1630
+ num_statistics_per_state.append(num_statistics)
1631
+ original_shapes_for_state = []
1632
+ if num_statistics > 0:
1633
+ preconditioner = Preconditioner(param, block_size,
1634
+ best_effort_shape_interpretation)
1635
+ for statistic in state.statistics:
1636
+ exponents.append(preconditioner.exponent_for_preconditioner(
1637
+ ) if exponent_override == 0 else exponent_override)
1638
+ original_shapes_for_state.append(statistic.shape)
1639
+ max_size = max(max_size, statistic.shape[0])
1640
+
1641
+ statistics.extend(state.statistics)
1642
+ prev_preconditioners.extend(state.preconditioners)
1643
+ original_shapes.extend(original_shapes_for_state)
1644
+
1645
+ if batch_axis_name:
1646
+ # Quantization is only supported in the PMAP path (when batch_axis_name is set).
1647
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1648
+
1649
+ if quantized_dtype == jnp.float32:
1650
+ return _pmap_compute_preconditioners(states, step, statistics,
1651
+ num_statistics_per_state,
1652
+ original_shapes, exponents,
1653
+ max_size, prev_preconditioners)
1654
+ else:
1655
+ return _pmap_quantized_compute_preconditioners(
1656
+ states, step, statistics, num_statistics_per_state, original_shapes,
1657
+ exponents, max_size, prev_preconditioners)
1658
+
1659
+ else:
1660
+ return _pjit_compute_preconditioners(states, step, statistics,
1661
+ num_statistics_per_state,
1662
+ original_shapes, exponents, max_size,
1663
+ prev_preconditioners)
1664
+
1665
+ def _transform_grad(grad, state, param, step):
1666
+ """Transform per-parameter gradients."""
1667
+ preconditioner = Preconditioner(param, block_size,
1668
+ best_effort_shape_interpretation)
1669
+ sgd_update = grad
1670
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
1671
+ if graft_type == GraftingType.ADAGRAD:
1672
+ new_diagonal_statistics = state.diagonal_statistics.to_float(
1673
+ ) + jnp.square(grad)
1674
+ adagrad_update = grad / (
1675
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1676
+ grafting_update = adagrad_update
1677
+ elif (graft_type == GraftingType.RMSPROP or
1678
+ graft_type == GraftingType.RMSPROP_NORMALIZED):
1679
+
1680
+ scaled_grad = grad
1681
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
1682
+ scaled_grad = grad / jnp.linalg.norm(grad)
1683
+
1684
+ w1 = beta2
1685
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1686
+
1687
+ new_diagonal_statistics = (
1688
+ w1 * state.diagonal_statistics.to_float() +
1689
+ w2 * jnp.square(scaled_grad))
1690
+ rmsprop_update = scaled_grad / (
1691
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1692
+
1693
+ if clip_by_scaled_gradient_norm:
1694
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
1695
+ jnp.sqrt(float(rmsprop_update.size)))
1696
+ clipping_denom = jnp.maximum(
1697
+ 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
1698
+ rmsprop_update /= clipping_denom
1699
+
1700
+ grafting_update = rmsprop_update
1701
+ else:
1702
+ grafting_update = sgd_update
1703
+
1704
+ precond_grad = grad
1705
+ if not _skip_preconditioning(param):
1706
+ precond_grad = preconditioner.preconditioned_grad(
1707
+ precond_grad,
1708
+ _maybe_dequantize_preconditioners(state.preconditioners))
1709
+ else:
1710
+ precond_grad = grafting_update
1711
+
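+ # Grafting: keep the direction of the Shampoo-preconditioned gradient but
+ # rescale it to the norm of the grafted (SGD/Adagrad/RMSProp) update.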
1712
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
1713
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
1714
+
1715
+ multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
1716
+ shampoo_update = precond_grad * multiplier
1717
+
1718
+ shampoo_update_with_wd = shampoo_update
1719
+ grafting_update_with_wd = grafting_update
1720
+ if weight_decay != 0:
1721
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
1722
+ grafting_update_with_wd = grafting_update + weight_decay * param
1723
+
1724
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1725
+ shampoo_update_with_wd_momentum = (
1726
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
1727
+ grafting_update_with_wd_momentum = (
1728
+ state.diagonal_momentum.to_float() * beta1 +
1729
+ w * grafting_update_with_wd)
1730
+
1731
+ run_shampoo = (step >= start_preconditioning_step).astype(
1732
+ grafting_update_with_wd_momentum.dtype)
1733
+
1734
+ momentum_update = (
1735
+ run_shampoo * shampoo_update_with_wd_momentum +
1736
+ (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
1737
+
1738
+ wd_update = (
1739
+ run_shampoo * shampoo_update_with_wd +
1740
+ (1.0 - run_shampoo) * grafting_update_with_wd)
1741
+
1742
+ if nesterov:
1743
+ momentum_update = w * wd_update + beta1 * momentum_update
1744
+
1745
+ lr = learning_rate
1746
+ if callable(learning_rate):
1747
+ lr = learning_rate(step)
1748
+ transformed_update = -1.0 * lr * momentum_update
1749
+
1750
+ param_stats = ParameterStats(
1751
+ _quantize_diagonal_statistics(new_diagonal_statistics),
1752
+ state.statistics, state.preconditioners,
1753
+ _quantize_momentum(grafting_update_with_wd_momentum),
1754
+ _quantize_momentum(shampoo_update_with_wd_momentum))
1755
+ return transformed_update, param_stats
1756
+
1757
+ def update_fn(grads, state, params):
1758
+ """Transform the input gradient and update all statistics.
1759
+
1760
+ Args:
1761
+ grads: the gradient tensors for the parameters.
1762
+ state: a named tuple containing the state of the optimizer
1763
+ params: the parameters that should be updated.
1764
+
1765
+ Returns:
1766
+ A tuple containing the new parameters and the new optimizer state.
1767
+ """
1768
+ params_flat, treedef = jax.tree_flatten(params)
1769
+ stats_flat = treedef.flatten_up_to(state.stats)
1770
+ grads_flat = treedef.flatten_up_to(grads)
1771
+
1772
+ new_stats_flat = jax.tree_multimap(
1773
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
1774
+ stats_flat, params_flat)
1775
+ new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
1776
+ state.count)
1777
+
1778
+ outputs = jax.tree_multimap(
1779
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
1780
+ new_stats_flat, params_flat)
1781
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1782
+
1783
+ updates = jax.tree_unflatten(treedef, updates_flat)
1784
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
1785
+
1786
+ new_state = ShampooState(
1787
+ count=state.count+1, stats=new_stats)
1788
+ return updates, new_state
1789
+
1790
+ if shard_optimizer_states:
1791
+ # Hijacks the init_fn signature so we can return an OptState with
1792
+ # appropriate init_fns.
1793
+ def _init_fns(unused_params):
1794
+ return InitFnState(
1795
+ init_fn=sharded_init_fn,
1796
+ pspec_fn=sharded_init_partition_spec_fn,
1797
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn)
1798
+
1799
+ return optax.GradientTransformation(_init_fns, sharded_update_fn)
1800
+ else:
1801
+ return optax.GradientTransformation(init_fn, update_fn)
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8fc602f9bddd0cad0b464f75463b0329adea24e06851d5740ce11781d6cc4ba
3
+ size 1419302302
flax_model_to_pytorch.py ADDED
@@ -0,0 +1,27 @@
1
+ from transformers import AutoModelForCausalLM, FlaxAutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ def to_f32(t):
8
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
9
+
10
+ jax.config.update('jax_platform_name', 'cpu')
11
+ MODEL_PATH = "./"
12
+ model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_PATH)
13
+ model.params = to_f32(model.params)
14
+ model.save_pretrained(MODEL_PATH)
15
+
16
+ pt_model = AutoModelForCausalLM.from_pretrained(
17
+ MODEL_PATH, from_flax=True).to('cpu')
18
+
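+ # Quick sanity check: print logits from both the PyTorch and Flax models on a
+ # dummy batch; the two outputs should match closely if the conversion succeeded.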
19
+ input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
20
+ input_ids_pt = torch.tensor(input_ids)
21
+
22
+ logits_pt = pt_model(input_ids_pt).logits
23
+ print(logits_pt)
24
+ logits_fx = model(input_ids).logits
25
+ print(logits_fx)
26
+
27
+ pt_model.save_pretrained(MODEL_PATH)
merges.txt ADDED
The diff for this file is too large to render. See raw diff
replace_token_script.py ADDED
@@ -0,0 +1,80 @@
1
+ '''This script was used to replace the final index of tokenizer.json and vocab.json
2
+ with the "<|endoftext|>" token. It also reassociates the corresponding merges.'''
3
+
4
+ import json
5
+
6
+ tokenizer_path = 'tokenizer.json'
7
+ model_config_path = 'config.json'
8
+ vocab_path = 'vocab.json'
9
+
10
+ with open(vocab_path, "r") as f:
11
+ vocab_data = json.load(f)
12
+
13
+ with open(tokenizer_path, "r") as f:
14
+ tokenizer_data = json.load(f)
15
+
16
+ with open(model_config_path, "r") as f:
17
+ model_config = json.load(f)
18
+
19
+ model_vocab_size = model_config['vocab_size']
20
+ tokenizer_vocab = tokenizer_data['model']['vocab']
21
+
22
+ mergeslength = len(tokenizer_data['model']['merges'])
23
+
24
+ #readjust added_tokens 'id' to model_vocab_size - 1
25
+ tokenizer_data['added_tokens'][-1]['id'] = model_vocab_size - 1
26
+
27
+ final_index = model_vocab_size - 1
28
+ eos = '<|endoftext|>'
29
+
30
+ #retrieve the key of final index
31
+ old_key_final_index_tokenizer = list(tokenizer_data['model']['vocab'].keys())[final_index]
32
+ old_key_final_index_vocab = list(vocab_data.keys())[final_index]
33
+ old_key_final_index_vocab_min2 = list(vocab_data.keys())[final_index - 1]
34
+ old_key_final_index_tokenizer_merges = tokenizer_data['model']['merges'][mergeslength - 1]
35
+
36
+ print(f"old_key_final_index_tokenizer = {old_key_final_index_tokenizer}")
37
+ print(f"old_key_final_index_vocab = {old_key_final_index_vocab}")
38
+ print(f"old_key_final_index_vocab_min2 = {old_key_final_index_vocab_min2}")
39
+ print(f"old_key_final_index_tokenizer_merges = {old_key_final_index_tokenizer_merges}")
40
+
41
+ #replace old key with new key
42
+ tokenizer_data['model']['vocab']['<|endoftext|>'] = tokenizer_data['model']['vocab'][old_key_final_index_tokenizer]
43
+ vocab_data[eos] = vocab_data[old_key_final_index_vocab]
44
+
45
+ #drop the final merge entry so the merges stay consistent with the replaced final token
46
+ tokenizer_data['model']['merges'] = tokenizer_data['model']['merges'][: mergeslength - 1]
47
+
48
+
49
+ #delete old key
50
+ del tokenizer_data['model']['vocab'][old_key_final_index_tokenizer]
51
+ del vocab_data[old_key_final_index_vocab]
52
+
53
+ #check updated key
54
+ old_key_final_index_tokenizer = list(tokenizer_data['model']['vocab'].keys())[final_index]
55
+ old_key_final_index_vocab = list(vocab_data.keys())[final_index]
56
+ old_key_final_index_tokenizer_merges = tokenizer_data['model']['merges'][mergeslength - 2]
57
+
58
+ print(len(tokenizer_data['model']['merges']))
59
+ print()
60
+ print(f"updated old_key_final_index_tokenizer = {old_key_final_index_tokenizer}")
61
+ print(f"updated old_key_final_index_vocab = {old_key_final_index_vocab}")
62
+ print(f"updated old_key_final_index_tokenizer_merges = {old_key_final_index_tokenizer_merges}")
63
+
64
+ with open(tokenizer_path, "w")as f:
65
+ json.dump(tokenizer_data, f)
66
+
67
+ with open(vocab_path, "w")as f:
68
+ json.dump(vocab_data, f)
69
+
70
+ with open('merges.txt') as f:
71
+ lines = f.readlines()
72
+
73
+ with open("merges.txt", "w") as f:
74
+ for i in range(len(lines) - 1):
75
+ f.write(lines[i])
76
+
77
+ with open('merges.txt') as f:
78
+ newlines = f.readlines()
79
+
80
+ print(f"newlines[len(newlines) - 1] = {newlines[len(newlines) - 1]}")
run_clm_flax.py ADDED
@@ -0,0 +1,892 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=text-generation
21
+ """
22
+ # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
+
24
+ import json
25
+ import logging
26
+ import math
27
+ import os
28
+ import sys
29
+ import time
30
+ import gc
31
+ from dataclasses import asdict, dataclass, field
32
+ from enum import Enum
33
+ from itertools import chain
34
+ from pathlib import Path
35
+ from typing import Callable, Optional
36
+
37
+ import datasets
38
+ import numpy as np
39
+ from datasets import Dataset, load_dataset, load_from_disk
40
+ from tqdm import tqdm
41
+
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ import transformers
46
+ from flax import jax_utils, traverse_util
47
+ from flax.jax_utils import unreplicate
48
+ from flax.training import train_state
49
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
50
+ from huggingface_hub import Repository
51
+ from transformers import (
52
+ CONFIG_MAPPING,
53
+ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
54
+ AutoConfig,
55
+ AutoTokenizer,
56
+ FlaxAutoModelForCausalLM,
57
+ HfArgumentParser,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+ from transformers.testing_utils import CaptureLogger
63
+
64
+ from distributed_shampoo import distributed_shampoo, GraftingType
65
+
66
+
67
+ logger = logging.getLogger(__name__)
68
+
69
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
70
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
71
+
72
+
73
+ @dataclass
74
+ class TrainingArguments:
75
+ output_dir: str = field(
76
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
77
+ )
78
+ overwrite_output_dir: bool = field(
79
+ default=False,
80
+ metadata={
81
+ "help": (
82
+ "Overwrite the content of the output directory. "
83
+ "Use this to continue training if output_dir points to a checkpoint directory."
84
+ )
85
+ },
86
+ )
87
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
88
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
89
+ per_device_train_batch_size: int = field(
90
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
91
+ )
92
+ per_device_eval_batch_size: int = field(
93
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
94
+ )
95
+ gradient_accumulation_steps: int = field(
96
+ default=1,
97
+ metadata={
98
+ "help": "Number of updates steps to accumulate before performing a backward/update pass."
99
+ },
100
+ )
101
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
102
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
103
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
104
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
105
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
106
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
107
+ distributed_shampoo: bool = field(
108
+ default=False, metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
109
+ )
110
+ quantize_shampoo: bool = field(
111
+ default=False, metadata={"help": "Quantize Distributed Shampoo optimizer."},
112
+ )
113
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
114
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
115
+ warmup_ratio: float = field(default=0.0, metadata={"help": "Linear warmup ratio of total train steps."})
116
+ cosine_decay: bool = field(
117
+ default=False, metadata={"help": "Whether or not to use cosine decay instead of the basic linear decay schedule."}
118
+ )
119
+ gradient_clipping: bool = field(
120
+ default=False, metadata={"help": "Whether or not to use gradient clipping."}
121
+ )
122
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
123
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
124
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
125
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
126
+ push_to_hub: bool = field(
127
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
128
+ )
129
+ hub_model_id: str = field(
130
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
131
+ )
132
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
133
+
134
+ def __post_init__(self):
135
+ if self.output_dir is not None:
136
+ self.output_dir = os.path.expanduser(self.output_dir)
137
+
138
+ def to_dict(self):
139
+ """
140
+ Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It obfuscates
141
+ the token values by replacing them with placeholders.
142
+ """
143
+ d = asdict(self)
144
+ for k, v in d.items():
145
+ if isinstance(v, Enum):
146
+ d[k] = v.value
147
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
148
+ d[k] = [x.value for x in v]
149
+ if k.endswith("_token"):
150
+ d[k] = f"<{k.upper()}>"
151
+ return d
152
+
153
+
154
+ @dataclass
155
+ class ModelArguments:
156
+ """
157
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
158
+ """
159
+
160
+ model_name_or_path: Optional[str] = field(
161
+ default=None,
162
+ metadata={
163
+ "help": "The model checkpoint for weights initialization."
164
+ "Don't set if you want to train a model from scratch."
165
+ },
166
+ )
167
+ model_type: Optional[str] = field(
168
+ default=None,
169
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
170
+ )
171
+ config_name: Optional[str] = field(
172
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
173
+ )
174
+ tokenizer_name: Optional[str] = field(
175
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
176
+ )
177
+ cache_dir: Optional[str] = field(
178
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
179
+ )
180
+ use_fast_tokenizer: bool = field(
181
+ default=True,
182
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
183
+ )
184
+ dtype: Optional[str] = field(
185
+ default="float32",
186
+ metadata={
187
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
188
+ },
189
+ )
190
+
191
+
192
+ @dataclass
193
+ class DataTrainingArguments:
194
+ """
195
+ Arguments pertaining to what data we are going to input our model for training and eval.
196
+ """
197
+
198
+ dataset_name: Optional[str] = field(
199
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
200
+ )
201
+ dataset_config_name: Optional[str] = field(
202
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
203
+ )
204
+ dataset_filepath: Optional[str] = field(
205
+ default=None, metadata={"help": "Filepath to locally saved HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
206
+ )
207
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
208
+ validation_file: Optional[str] = field(
209
+ default=None,
210
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
211
+ )
212
+ max_train_samples: Optional[int] = field(
213
+ default=None,
214
+ metadata={
215
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
216
+ "value if set."
217
+ },
218
+ )
219
+ max_eval_samples: Optional[int] = field(
220
+ default=None,
221
+ metadata={
222
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
223
+ "value if set."
224
+ },
225
+ )
226
+ overwrite_cache: bool = field(
227
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
228
+ )
229
+ validation_split_percentage: Optional[int] = field(
230
+ default=5,
231
+ metadata={
232
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
233
+ },
234
+ )
235
+ block_size: Optional[int] = field(
236
+ default=None,
237
+ metadata={
238
+ "help": "Optional input sequence length after tokenization. "
239
+ "The training dataset will be truncated in block of this size for training. "
240
+ "Default to the model max input length for single sentence inputs (take into account special tokens)."
241
+ },
242
+ )
243
+ overwrite_cache: bool = field(
244
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
245
+ )
246
+ preprocessing_num_workers: Optional[int] = field(
247
+ default=None,
248
+ metadata={"help": "The number of processes to use for the preprocessing."},
249
+ )
250
+ keep_linebreaks: bool = field(
251
+ default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
252
+ )
253
+
254
+ def __post_init__(self):
255
+ if self.dataset_name is None and self.train_file is None and self.dataset_filepath is None and self.validation_file is None:
256
+ raise ValueError("Need either a dataset name or a training/validation file.")
257
+ else:
258
+ if self.train_file is not None:
259
+ extension = self.train_file.split(".")[-1]
260
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
261
+ if self.validation_file is not None:
262
+ extension = self.validation_file.split(".")[-1]
263
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
264
+
265
+
266
+ class TrainState(train_state.TrainState):
267
+ dropout_rng: jnp.ndarray
268
+
269
+ def replicate(self):
270
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
271
+
272
+
273
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
274
+ """
275
+ Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
276
+ Shuffle batches if `shuffle` is `True`.
277
+ """
278
+ steps_per_epoch = len(dataset) // batch_size
279
+
280
+ if shuffle:
281
+ batch_idx = jax.random.permutation(rng, len(dataset))
282
+ else:
283
+ batch_idx = jnp.arange(len(dataset))
284
+
285
+ batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
286
+ batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
287
+
288
+ for idx in batch_idx:
289
+ batch = dataset[idx]
290
+ batch = {k: np.array(v) for k, v in batch.items()}
291
+
292
+ yield batch
293
+
294
+
295
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
296
+ summary_writer.scalar("train_time", train_time, step)
297
+
298
+ train_metrics = get_metrics(train_metrics)
299
+ for key, vals in train_metrics.items():
300
+ tag = f"train_{key}"
301
+ for i, val in enumerate(vals):
302
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
303
+
304
+
305
+ def write_eval_metric(summary_writer, eval_metrics, step):
306
+ for metric_name, value in eval_metrics.items():
307
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
308
+
309
+
310
+ def create_linear_learning_rate_fn(
311
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
312
+ ) -> Callable[[int], jnp.array]:
313
+ """Returns a linear warmup, linear decay learning rate function."""
314
+ steps_per_epoch = train_ds_size // train_batch_size
315
+ num_train_steps = steps_per_epoch * num_train_epochs
316
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
317
+ decay_fn = optax.linear_schedule(
318
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
319
+ )
320
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
321
+ return schedule_fn
322
+
323
+ def create_cosine_learning_rate_fn(
324
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
325
+ ) -> Callable[[int], jnp.array]:
326
+ """Returns a linear warmup, cosine decay learning rate function."""
327
+ steps_per_epoch = train_ds_size // train_batch_size
328
+ num_train_steps = steps_per_epoch * num_train_epochs
329
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
330
+ decay_fn = optax.cosine_decay_schedule(
331
+ init_value=learning_rate, decay_steps=num_train_steps - num_warmup_steps, alpha=0.1
332
+ )
333
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
334
+ return schedule_fn
335
+
336
+
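+ # A minimal sketch of the warmup + cosine schedule built above, with assumed
+ # example values (peak lr 3e-4, 1000 warmup steps, 10000 total train steps):
+ #   warmup = optax.linear_schedule(init_value=0.0, end_value=3e-4, transition_steps=1000)
+ #   decay = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=9000, alpha=0.1)
+ #   schedule = optax.join_schedules(schedules=[warmup, decay], boundaries=[1000])
+ #   schedule(0) == 0.0, schedule(1000) == 3e-4, schedule(10000) == 3e-5 (alpha * peak)
+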
337
+ def main():
338
+ # See all possible arguments in src/transformers/training_args.py
339
+ # or by passing the --help flag to this script.
340
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
341
+
342
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
343
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
344
+ # If we pass only one argument to the script and it's the path to a json file,
345
+ # let's parse it to get our arguments.
346
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
347
+ else:
348
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
349
+
350
+ if (
351
+ os.path.exists(training_args.output_dir)
352
+ and os.listdir(training_args.output_dir)
353
+ and training_args.do_train
354
+ and not training_args.overwrite_output_dir
355
+ ):
356
+ raise ValueError(
357
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
358
+ "Use --overwrite_output_dir to overcome."
359
+ )
360
+
361
+ # Make one log on every process with the configuration for debugging.
362
+ logging.basicConfig(
363
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
364
+ datefmt="%m/%d/%Y %H:%M:%S",
365
+ level=logging.INFO,
366
+ )
367
+ # Setup logging, we only want one process per machine to log things on the screen.
368
+ logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
369
+ if jax.process_index() == 0:
370
+ datasets.utils.logging.set_verbosity_warning()
371
+ transformers.utils.logging.set_verbosity_info()
372
+ else:
373
+ datasets.utils.logging.set_verbosity_error()
374
+ transformers.utils.logging.set_verbosity_error()
375
+
376
+ # Set the verbosity to info of the Transformers logger (on main process only):
377
+ logger.info(f"Training/evaluation parameters {training_args}")
378
+
379
+ # Set seed before initializing model.
380
+ set_seed(training_args.seed)
381
+
382
+ # Handle the repository creation
383
+ if training_args.push_to_hub:
384
+ if training_args.hub_model_id is None:
385
+ repo_name = get_full_repo_name(
386
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
387
+ )
388
+ else:
389
+ repo_name = training_args.hub_model_id
390
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
391
+
392
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
393
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
394
+ # (the dataset will be downloaded automatically from the datasets Hub).
395
+ #
396
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
397
+ # 'text' is found. You can easily tweak this behavior (see below).
398
+ #
399
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
400
+ # download the dataset.
401
+ if data_args.dataset_name is not None:
402
+ # Downloading and loading a dataset from the hub.
403
+ dataset = load_dataset(
404
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
405
+ )
406
+
407
+ if "validation" not in dataset.keys():
408
+ dataset["validation"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[:{data_args.validation_split_percentage}%]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ dataset["train"] = load_dataset(
415
+ data_args.dataset_name,
416
+ data_args.dataset_config_name,
417
+ split=f"train[{data_args.validation_split_percentage}%:]",
418
+ cache_dir=model_args.cache_dir,
419
+ )
420
+
421
+ elif data_args.dataset_filepath is not None:
422
+ # Loading a dataset from local file.
423
+ dataset = load_from_disk(data_args.dataset_filepath)
424
+ if "validation" not in dataset.keys():
425
+ dataset = dataset["train"].train_test_split(test_size=data_args.validation_split_percentage/100)
426
+ dataset["validation"] = dataset["test"]
427
+ del dataset["test"]
428
+
429
+ else:
430
+ data_files = {}
431
+ dataset_args = {}
432
+ if data_args.train_file is not None:
433
+ data_files["train"] = data_args.train_file
434
+ if data_args.validation_file is not None:
435
+ data_files["validation"] = data_args.validation_file
436
+ extension = data_args.train_file.split(".")[-1]
437
+ if extension == "txt":
438
+ extension = "text"
439
+ dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
440
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args)
441
+
442
+ if "validation" not in dataset.keys():
443
+ dataset["validation"] = load_dataset(
444
+ extension,
445
+ data_files=data_files,
446
+ split=f"train[:{data_args.validation_split_percentage}%]",
447
+ cache_dir=model_args.cache_dir,
448
+ **dataset_args,
449
+ )
450
+ dataset["train"] = load_dataset(
451
+ extension,
452
+ data_files=data_files,
453
+ split=f"train[{data_args.validation_split_percentage}%:]",
454
+ cache_dir=model_args.cache_dir,
455
+ **dataset_args,
456
+ )
457
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
458
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
459
+
460
+ # Load pretrained model and tokenizer
461
+
462
+ # Distributed training:
463
+ # The .from_pretrained methods guarantee that only one local process can concurrently
464
+ # download model & vocab.
465
+ if model_args.config_name:
466
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
467
+ elif model_args.model_name_or_path:
468
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
469
+ else:
470
+ config = CONFIG_MAPPING[model_args.model_type]()
471
+ logger.warning("You are instantiating a new config instance from scratch.")
472
+
473
+ if model_args.tokenizer_name:
474
+ tokenizer = AutoTokenizer.from_pretrained(
475
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
476
+ )
477
+ elif model_args.model_name_or_path:
478
+ tokenizer = AutoTokenizer.from_pretrained(
479
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
480
+ )
481
+ else:
482
+ raise ValueError(
483
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
484
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
485
+ )
486
+
487
+ if tokenizer.pad_token is None:
488
+ tokenizer.pad_token = tokenizer.eos_token
489
+
490
+ if model_args.model_name_or_path:
491
+ model = FlaxAutoModelForCausalLM.from_pretrained(
492
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
493
+ )
494
+ else:
495
+ model = FlaxAutoModelForCausalLM.from_config(
496
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
497
+ )
498
+
499
+ # Preprocessing the datasets.
500
+ # First we tokenize all the texts.
501
+ if training_args.do_train:
502
+ column_names = dataset["train"].column_names
503
+ else:
504
+ column_names = dataset["validation"].column_names
505
+ text_column_name = "text" if "text" in column_names else column_names[0]
506
+
507
+ # Since tokenize_function will be pickled, force logger loading beforehand to avoid a _LazyModule error in Hasher.
508
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
509
+
510
+ def tokenize_function(examples):
511
+ with CaptureLogger(tok_logger) as cl:
512
+ output = tokenizer(examples[text_column_name])
513
+ # clm input could be much much longer than block_size
514
+ if "Token indices sequence length is longer than the" in cl.out:
515
+ tok_logger.warning(
516
+ "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
517
+ )
518
+ return output
519
+
520
+ tokenized_datasets = dataset.map(
521
+ tokenize_function,
522
+ batched=True,
523
+ num_proc=data_args.preprocessing_num_workers,
524
+ remove_columns=column_names,
525
+ load_from_cache_file=not data_args.overwrite_cache,
526
+ )
527
+
528
+ if data_args.block_size is None:
529
+ block_size = tokenizer.model_max_length
530
+ if block_size > config.max_position_embeddings:
531
+ logger.warning(
532
+ f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
533
+ "Picking 1024 instead. You can change that default value by passing --block_size xxx."
534
+ )
535
+ block_size = 1024
536
+ else:
537
+ if data_args.block_size > tokenizer.model_max_length:
538
+ logger.warning(
539
+ f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
540
+ f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
541
+ )
542
+ block_size = min(data_args.block_size, tokenizer.model_max_length)
543
+
544
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
545
+ def group_texts(examples):
546
+ # Concatenate all texts.
547
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
548
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
549
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
550
+ # customize this part to your needs.
551
+ if total_length >= block_size:
552
+ total_length = (total_length // block_size) * block_size
553
+ # Split by chunks of max_len.
554
+ result = {
555
+ k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
556
+ for k, t in concatenated_examples.items()
557
+ }
558
+ result["labels"] = result["input_ids"].copy()
559
+ return result
560
+
561
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
562
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
563
+ # to preprocess.
564
+ #
565
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
566
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
567
+
568
+ lm_datasets = tokenized_datasets.map(
569
+ group_texts,
570
+ batched=True,
571
+ num_proc=data_args.preprocessing_num_workers,
572
+ load_from_cache_file=not data_args.overwrite_cache,
573
+ )
574
+
575
+ if training_args.do_train:
576
+ if "train" not in tokenized_datasets:
577
+ raise ValueError("--do_train requires a train dataset")
578
+ train_dataset = lm_datasets["train"]
579
+ if data_args.max_train_samples is not None:
580
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
581
+
582
+ # test to see that tokenization worked
583
+ detokenized_example = tokenizer.decode(train_dataset[0]["input_ids"])
584
+ logger.info(f"Detokenized example: {detokenized_example}")
585
+ detokenized_example = tokenizer.decode(train_dataset[-1]["input_ids"])
586
+ logger.info(f"Detokenized example 2: {detokenized_example}")
587
+
588
+ if training_args.do_eval:
589
+ if "validation" not in tokenized_datasets:
590
+ raise ValueError("--do_eval requires a validation dataset")
591
+ eval_dataset = lm_datasets["validation"]
592
+ if data_args.max_eval_samples is not None:
593
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
594
+
595
+ # Enable tensorboard only on the master node
596
+ has_tensorboard = is_tensorboard_available()
597
+ if has_tensorboard and jax.process_index() == 0:
598
+ try:
599
+ from flax.metrics.tensorboard import SummaryWriter
600
+
601
+ summary_writer = SummaryWriter(log_dir=Path(os.path.join(training_args.output_dir, "runs")))
602
+ except ImportError as ie:
603
+ has_tensorboard = False
604
+ logger.warning(
605
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
606
+ )
607
+ else:
608
+ logger.warning(
609
+ "Unable to display metrics through TensorBoard because the package is not installed: "
610
+ "Please run pip install tensorboard to enable."
611
+ )
612
+
613
+ # Initialize our training
614
+ rng = jax.random.PRNGKey(training_args.seed)
615
+ rng, dropout_rng = jax.random.split(rng)
616
+
617
+ # Store some constant
618
+ num_epochs = int(training_args.num_train_epochs)
619
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
620
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
621
+ steps_per_epoch = len(train_dataset) // train_batch_size
622
+ total_train_steps = steps_per_epoch * num_epochs
623
+
624
+ if training_args.warmup_ratio > 0:
625
+ warmup_steps = int(total_train_steps * training_args.warmup_ratio)
626
+ else:
627
+ warmup_steps = training_args.warmup_steps
628
+
629
+ # Create learning rate schedule
630
+ if training_args.cosine_decay:
631
+ lr_schedule_fn = create_cosine_learning_rate_fn(
632
+ len(train_dataset),
633
+ train_batch_size,
634
+ training_args.num_train_epochs,
635
+ warmup_steps,
636
+ training_args.learning_rate,
637
+ )
638
+ else:
639
+ lr_schedule_fn = create_linear_learning_rate_fn(
640
+ len(train_dataset),
641
+ train_batch_size,
642
+ training_args.num_train_epochs,
643
+ warmup_steps,
644
+ training_args.learning_rate,
645
+ )
646
+
647
+ # We use Optax's "masking" functionality to not apply weight decay
648
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
649
+ # mask boolean with the same structure as the parameters.
650
+ # The mask is True for parameters that should be decayed.
651
+ # Note that this mask is specifically adapted for FlaxGPT2.
652
+ # For other models, one should correct the layer norm parameter naming
653
+ # accordingly.
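+ # For example, params["transformer"]["h"]["0"]["ln_1"]["scale"] flattens to the
+ # path ("transformer", "h", "0", "ln_1", "scale"); its last two entries match
+ # ("ln_1", "scale"), so the mask is False and no weight decay is applied there.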
654
+ def decay_mask_fn(params):
655
+ flat_params = traverse_util.flatten_dict(params)
656
+ flat_mask = {
657
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
658
+ for path in flat_params
659
+ }
660
+ return traverse_util.unflatten_dict(flat_mask)
661
+
662
+ # create adam optimizer
663
+ if training_args.adafactor:
664
+ # We use the default parameters here to initialize adafactor,
665
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
666
+ optimizer = optax.adafactor(
667
+ learning_rate=lr_schedule_fn,
668
+ )
669
+
670
+ elif training_args.distributed_shampoo:
671
+ # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
672
+ # Notes:
673
+ # - mask for weight decay is not implemented but we don't use it anyway
674
+ optimizer = distributed_shampoo(
675
+ lr_schedule_fn,
676
+ block_size=1536, # recommended default for large LM is 1536
677
+ beta1=0.9,
678
+ beta2=0.999,
679
+ diagonal_epsilon=1e-10,
680
+ matrix_epsilon=1e-8,
681
+ weight_decay=0.0,
682
+ start_preconditioning_step=1001,
683
+ preconditioning_compute_steps=10,
684
+ statistics_compute_steps=1,
685
+ best_effort_shape_interpretation=True,
686
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
687
+ nesterov=False,
688
+ exponent_override=0,
689
+ batch_axis_name="batch",
690
+ inverse_failure_threshold=0.1,
691
+ moving_average_for_momentum=True,
692
+ skip_preconditioning_dim_size_gt=4096,
693
+ clip_by_scaled_gradient_norm=None,
694
+ precision=jax.lax.Precision.HIGHEST,
695
+ best_effort_memory_usage_reduction=training_args.quantize_shampoo,
696
+ )
697
+ else:
698
+ optimizer = optax.adamw(
699
+ learning_rate=lr_schedule_fn,
700
+ b1=training_args.adam_beta1,
701
+ b2=training_args.adam_beta2,
702
+ eps=training_args.adam_epsilon,
703
+ weight_decay=training_args.weight_decay,
704
+ mask=decay_mask_fn,
705
+ )
706
+ if training_args.gradient_clipping:
707
+ optimizer = optax.chain(
708
+ optax.clip_by_global_norm(1.),
709
+ optimizer
710
+ )
711
+
712
+ # add gradient accumulation
713
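+ # optax.apply_every collects gradients over N micro-steps and releases the accumulated gradient to the optimizer only on every Nth step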
+ if training_args.gradient_accumulation_steps > 1:
+ optimizer = optax.chain(
+ optax.apply_every(training_args.gradient_accumulation_steps), optimizer
+ )
+
+ # Setup train state
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
+
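+ # Causal LM loss: logits and labels are shifted by one position so each token predicts the next one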
+ def loss_fn(logits, labels):
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
+ return loss.mean()
+
+ # Define gradient update step fn
+ def train_step(state, batch):
+ dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+ def compute_loss(params):
+ labels = batch.pop("labels")
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+ loss = loss_fn(logits, labels)
+ return loss
+
+ grad_fn = jax.value_and_grad(compute_loss)
+ loss, grad = grad_fn(state.params)
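+ # average the gradients across devices before applying the optimizer update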
+ grad = jax.lax.pmean(grad, "batch")
+
+ new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+ metrics = {"loss": loss, "learning_rate": lr_schedule_fn(state.step // training_args.gradient_accumulation_steps)}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+ return new_state, metrics
+
+ # Define eval fn
+ def eval_step(params, batch):
+ labels = batch.pop("labels")
+ logits = model(**batch, params=params, train=False)[0]
+ loss = loss_fn(logits, labels)
+
+ # summarize metrics
+ metrics = {"loss": loss}
+ metrics = jax.lax.pmean(metrics, axis_name="batch")
+ return metrics
+
+ # Create parallel version of the train and eval step
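+ # donate_argnums=(0,) donates the old train state's buffers to XLA so the memory can be reused for the new state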
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+ p_eval_step = jax.pmap(eval_step, "batch")
+
+ # Replicate the train state on each device
+ state = state.replicate()
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {num_epochs}")
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+ logger.info(f" Total optimization steps = {total_train_steps}")
+
+ train_time = 0
+ train_metrics = []
+ epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
+ for epoch in epochs:
+ # ======================== Training ================================
+ train_start = time.time()
+
+ # Create sampling rng
+ rng, input_rng = jax.random.split(rng)
+
+ # Generate an epoch by shuffling sampling indices from the train dataset
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size // training_args.gradient_accumulation_steps, shuffle=True)
+ steps_per_epoch = len(train_dataset) // train_batch_size
+ # train
+ steps_trained_progress_bar = tqdm(range(steps_per_epoch), desc="Training...", position=1,
+ leave=False)
+ for step in range(steps_per_epoch * training_args.gradient_accumulation_steps):
+ batch = next(train_loader)
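+ # shard splits the batch along its leading axis so every local device gets an equal slice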
+ batch = shard(batch)
+ state, train_metric = p_train_step(state, batch)
+ train_metrics.append(train_metric)
+
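+ # cur_step counts micro-batches, i.e. optimizer steps times gradient_accumulation_steps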
+ cur_step = epoch * (steps_per_epoch*training_args.gradient_accumulation_steps) + step
+
+ if step % training_args.gradient_accumulation_steps == 0:
+ steps_trained_progress_bar.update(1)
+
+ if cur_step % (training_args.logging_steps * training_args.gradient_accumulation_steps) == 0 and cur_step > 0:
+ # Save metrics
+ train_metric = unreplicate(train_metric)
+ train_time += time.time() - train_start
+ if has_tensorboard and jax.process_index() == 0:
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+ epochs.write(
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+ )
+
+ train_metrics = []
+
+ if cur_step % (training_args.eval_steps * training_args.gradient_accumulation_steps) == 0 and cur_step > 0:
+ # ======================== Evaluating ==============================
+ eval_metrics = []
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+ eval_steps = len(eval_dataset) // eval_batch_size
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+ # Model forward
+ batch = next(eval_loader)
+ batch = shard(batch)
+ metrics = p_eval_step(state.params, batch)
+ eval_metrics.append(metrics)
+
+ # normalize eval metrics
+ eval_metrics = get_metrics(eval_metrics)
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
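+ # perplexity is the exponential of the mean cross-entropy loss; guard against overflow for very large losses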
+ try:
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
+ except OverflowError:
+ eval_metrics["perplexity"] = float("inf")
+
+ # Print metrics and update progress bar
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
+ epochs.write(desc)
+ epochs.desc = desc
+
+ # Save metrics
+ if has_tensorboard and jax.process_index() == 0:
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
+
+ if cur_step % (training_args.save_steps * training_args.gradient_accumulation_steps) == 0 and cur_step > 0:
+ # save checkpoint every save_steps and push it to the hub
+ if jax.process_index() == 0:
+ params = jax.device_get(unreplicate(state.params))
+ model.save_pretrained(training_args.output_dir, params=params)
+ tokenizer.save_pretrained(training_args.output_dir)
+ if training_args.push_to_hub:
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+
+ # also save at the end of each epoch
+ try:
+ if jax.process_index() == 0:
+ params = jax.device_get(unreplicate(state.params))
+ model.save_pretrained(training_args.output_dir, params=params)
+ tokenizer.save_pretrained(training_args.output_dir)
+ if training_args.push_to_hub:
+ repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
+ except:
+ # pushing to the hub would fail the whole script if there is nothing new to commit
+ pass
+
+ # Eval after training
+ if training_args.do_eval:
+ eval_metrics = []
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+ eval_steps = len(eval_dataset) // eval_batch_size
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+ # Model forward
+ batch = shard(next(eval_loader))
+ metrics = p_eval_step(state.params, batch)
+ eval_metrics.append(metrics)
+
+ # normalize eval metrics
+ eval_metrics = get_metrics(eval_metrics)
+ eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
+
+ try:
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
+ except OverflowError:
+ eval_metrics["perplexity"] = float("inf")
+
+ if jax.process_index() == 0:
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
+ path = os.path.join(training_args.output_dir, "eval_results.json")
+ with open(path, "w") as f:
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
+
+
+ if __name__ == "__main__":
+ main()
runs/events.out.tfevents.1642710569.t1v-n-42145f73-w-0.1403347.0.v2 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b72568b9aa0ed78b38fc232c126bf59229b152bbcacaae8358612edc7507a6ed
+ size 1471449
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>", "pad_token": "<|endoftext|>"}
start_train.sh ADDED
@@ -0,0 +1,29 @@
+ # set train hyperparams
+ unset LD_PRELOAD
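+ # cache datasets on the research disk and keep transformers from picking up the PyTorch backend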
+ export HF_DATASETS_CACHE="/researchdisk/datasets_cache"
+ export USE_TORCH=0
+ python3 run_clm_flax.py \
+ --output_dir="./" \
+ --model_type="gpt2" \
+ --config_name="./" \
+ --tokenizer_name="./" \
+ --dataset_filepath="/researchdisk/training_dataset_full_deduplicated" \
+ --do_train --do_eval \
+ --block_size="512" \
+ --per_device_train_batch_size="16" \
+ --per_device_eval_batch_size="16" \
+ --preprocessing_num_workers="1" \
+ --adam_beta1="0.9" \
+ --adam_beta2="0.98" \
+ --learning_rate="1e-4" \
+ --weight_decay="0.01" \
+ --warmup_steps="4000" \
+ --cosine_decay \
+ --overwrite_output_dir \
+ --logging_steps="500" \
+ --eval_steps="10000" \
+ --save_steps="10000" \
+ --num_train_epochs="10" \
+ --dtype="bfloat16" \
+ --push_to_hub \
+ --hub_model_id="Finnish-NLP/gpt2-medium-finnish"
tokenizer.json ADDED
The diff for this file is too large to render.
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "GPT2Tokenizer"}
train_tokenizer.py ADDED
@@ -0,0 +1,30 @@
+ from datasets import load_from_disk
+ from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
+ from transformers import AutoTokenizer
+
+
+ model_dir = "./"
+
+ # load dataset
+ dataset = load_from_disk("/researchdisk/training_dataset_full_deduplicated")
+ dataset = dataset["train"]
+
+ # Instantiate tokenizer
+ tokenizer = ByteLevelBPETokenizer()
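+ # feed the trainer batches of raw text instead of materializing the whole corpus as one list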
+ def batch_iterator(batch_size=1000):
+ for i in range(0, len(dataset), batch_size):
+ yield dataset[i: i + batch_size]["text"]
+
+ # Customized training
+ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=2, special_tokens=[
+ "<s>",
+ "<pad>",
+ "</s>",
+ "<unk>",
+ "<mask>",
+ ])
+
+ # Save files to disk
+ tokenizer.save(f"{model_dir}/tokenizer.json")
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ tokenizer.save_pretrained(model_dir)
vocab.json ADDED
The diff for this file is too large to render.