versae committed
Commit
f072d39
1 Parent(s): 30c6599

Testing shampoo

.gitignore ADDED
@@ -0,0 +1,135 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+ /model/
+
+ # VS Code
+ .vscode
+
+ wandb
config.json ADDED
@@ -0,0 +1,26 @@
+ {
+   "_name_or_path": "roberta-base",
+   "architectures": [
+     "RobertaForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "roberta",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "transformers_version": "4.16.0.dev0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 50265
+ }
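
This is the standard roberta-base architecture configured for masked-language-model pretraining. A minimal sketch of how such a file is typically consumed, assuming the transformers library is installed and the path "config.json" is illustrative:

from transformers import RobertaConfig, FlaxRobertaForMaskedLM

# Load the architecture hyperparameters from the checked-in file.
config = RobertaConfig.from_json_file("config.json")
# Instantiate a randomly initialized Flax model with this architecture.
model = FlaxRobertaForMaskedLM(config, seed=0)
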
distributed_shampoo.py ADDED
@@ -0,0 +1,1609 @@
+ # coding=utf-8
+ # Copyright 2021 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # An implementation of distributed Shampoo optimizer from:
+ #
+ # Scalable Second Order Optimization for Deep Learning
+ # Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
+ # Preprint Paper: https://arxiv.org/abs/2002.09018
+ #
+ # This implementation moves computation of inverse pth root back to the
+ # accelerator (if higher precision is available).
+ #
+ # Authors: Rohan Anil (rohananil at google dot com)
+ #          & Vineet Gupta (vineet at google dot com)
+ #
+
+ """Distributed Shampoo Implementation."""
+
+ import enum
+ import functools
+ import itertools
+ from typing import Any, List, NamedTuple
+
+ import chex
+ from flax import struct
+ import jax
+ from jax import lax
+ import jax.experimental.pjit as pjit
+ import jax.numpy as jnp
+ import numpy as np
+ import optax
+
+
+ # pylint:disable=no-value-for-parameter
+ @struct.dataclass
+ class QuantizedValue:
+   """State associated with quantized value."""
+   quantized: chex.Array
+   diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
+   bucket_size: chex.Array
+   quantized_dtype: jnp.dtype = struct.field(
+       pytree_node=False)  # Dtype for the quantized value.
+   extract_diagonal: bool = struct.field(
+       pytree_node=False)  # In case its centered.
+   shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.
+
+   @classmethod
+   def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
+     if isinstance(fvalue, list) and not fvalue:
+       return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
+     quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
+         fvalue, quantized_dtype, extract_diagonal)
+     return QuantizedValue(quantized, diagonal_fvalue, bucket_size,
+                           quantized_dtype, extract_diagonal,
+                           list(quantized.shape))
+
+   # Quantization is from Lingvo JAX optimizers.
+   # We extend it for int16 quantization of PSD matrices.
+   @classmethod
+   def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
+     """Returns quantized value and the bucket."""
+     if quantized_dtype == jnp.float32:
+       return fvalue, [], []
+     elif quantized_dtype == jnp.bfloat16:
+       return fvalue.astype(jnp.bfloat16), [], []
+
+     float_dtype = fvalue.dtype
+     if quantized_dtype == jnp.int8:
+       # value -128 is not used.
+       num_buckets = jnp.array(127.0, dtype=float_dtype)
+     elif quantized_dtype == jnp.int16:
+       # value -32768 is not used.
+       num_buckets = jnp.array(32767.0, dtype=float_dtype)
+     else:
+       raise ValueError(f'Quantized dtype {quantized_dtype} not supported.')
+     # max value is mapped to num_buckets
+
+     if extract_diagonal and fvalue.ndim != 2:
+       raise ValueError(
+           f'Input array {fvalue} must be 2D to work with extract_diagonal.')
+
+     diagonal_fvalue = []
+     if extract_diagonal:
+       diagonal_fvalue = jnp.diag(fvalue)
+       # Remove the diagonal entries.
+       fvalue = fvalue - jnp.diag(diagonal_fvalue)
+
+     # TODO(rohananil): Extend this by making use of information about the blocks
+     # SM3 style which will be useful for diagonal statistics
+     # We first decide the scale.
+     if fvalue.ndim < 1:
+       raise ValueError(
+           f'Input array {fvalue} must have a strictly positive number of '
+           'dimensions.')
+
+     max_abs = jnp.max(jnp.abs(fvalue), axis=0)
+     bucket_size = max_abs / num_buckets
+     bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
+     # To avoid divide by 0.0
+     bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded,
+                            jnp.ones_like(bs_expanded))
+     ratio = fvalue / bs_nonzero
+     # We use rounding to remove bias.
+     quantized = jnp.round(ratio)
+     return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size
+
+   def to_float(self):
+     """Returns the float value."""
+     if isinstance(self.quantized, list) and not self.quantized:
+       return self.quantized
+
+     if self.quantized_dtype == jnp.float32:
+       return self.quantized
+
+     if self.quantized_dtype == jnp.bfloat16:
+       return self.quantized.astype(jnp.float32)
+
+     float_dtype = self.bucket_size.dtype
+     bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
+     val = self.quantized.astype(float_dtype) * bucket_size
+     if self.extract_diagonal:
+       val += jnp.diag(self.diagonal)
+     return val
+
+
+ # Per parameter optimizer state used in data-parallel training.
+ class ParameterStats(NamedTuple):
+   """State associated to each parameter of the model being trained."""
+   diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
+   statistics: List[Any]  # Statistics (QuantizedValue, chex.Array)
+   preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
+   diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
+   momentum: QuantizedValue  # Momentum for the shampoo preconditioner
+
+
+ # For training extremely large model; We keep a global state with a concatenated
+ # statistics and preconditioner states for all vars. This is so that we can
+ # annotate the leading axis to be sharded to save memory at the cost of
+ # communication.
+ @struct.dataclass
+ class GlobalShardedParameterStats:
+   statistics: chex.Array  # Statistics
+   preconditioners: chex.Array  # Preconditioners
+
+
+ # These are per-parameter local states; All statistics here mirror the parameter
+ # Thus the sharding is copied over from the param specification.
+ @struct.dataclass
+ class LocalShardedParameterStats:
+   """State associated to each parameter of the model being trained."""
+   diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
+   diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
+   momentum: QuantizedValue  # Momentum for the shampoo preconditioner
+   index_start: np.int32 = struct.field(
+       pytree_node=False)  # Index into global statistics array
+   sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
+
+
+ class ShardedShampooStats(NamedTuple):
+   """Shampoo state in sharded mode."""
+   global_stats: Any
+   local_stats: Any
+
+
+ class ShampooState(NamedTuple):
+   count: chex.Array
+   stats: Any
+
+
+ class GraftingType(enum.IntEnum):
+   SGD = 1
+   ADAGRAD = 2
+   RMSPROP = 3
+   RMSPROP_NORMALIZED = 4
+
+
+ def power_iteration(
+     matrix,
+     num_iters=100,
+     error_tolerance=1e-6,
+     precision=lax.Precision.HIGHEST):
+   r"""Power iteration algorithm.
+
+   The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
+   a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
+   of `A`, and a vector v, which is the corresponding eigenvector of `A`.
+
+   References:
+     [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
+
+   Args:
+     matrix: the symmetric PSD matrix.
+     num_iters: Number of iterations.
+     error_tolerance: Iterative exit condition.
+     precision: precision XLA related flag, the available options are:
+       a) lax.Precision.DEFAULT (better step time, but not precise)
+       b) lax.Precision.HIGH (increased precision, slower)
+       c) lax.Precision.HIGHEST (best possible precision, slowest)
+
+   Returns:
+     eigen vector, eigen value
+   """
+   matrix_size = matrix.shape[-1]
+   def _iter_condition(state):
+     i, unused_v, unused_s, unused_s_v, run_step = state
+     return jnp.logical_and(i < num_iters, run_step)
+
+   def _iter_body(state):
+     """One step of power iteration."""
+     i, new_v, s, s_v, unused_run_step = state
+     new_v = new_v / jnp.linalg.norm(new_v)
+
+     s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
+     s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
+     return (i + 1, s_v, s_new, s_v,
+             jnp.greater(jnp.abs(s_new - s), error_tolerance))
+
+   # Figure out how to use step as seed for random.
+   v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
+                                             matrix_size).astype(matrix.dtype)
+
+   init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
+   _, v_out, s_out, _, _ = lax.while_loop(
+       _iter_condition, _iter_body, init_state)
+   v_out = v_out / jnp.linalg.norm(v_out)
+   return v_out, s_out
+
+
+ def matrix_inverse_pth_root(
+     matrix,
+     p,
+     num_iters=100,
+     ridge_epsilon=1e-6,
+     error_tolerance=1e-6,
+     precision=lax.Precision.HIGHEST):
+   """Computes `matrix^(-1/p)`, where `p` is a positive integer.
+
+   This function uses the Coupled newton iterations algorithm for
+   the computation of a matrix's inverse pth root.
+
+
+   References:
+     [Functions of Matrices, Theory and Computation,
+      Nicholas J Higham, Pg 184, Eq 7.18](
+      https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
+
+   Args:
+     matrix: the symmetric PSD matrix whose power is to be computed
+     p: exponent, for p a positive integer.
+     num_iters: Maximum number of iterations.
+     ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+     error_tolerance: Error indicator, useful for early termination.
+     precision: precision XLA related flag, the available options are:
+       a) lax.Precision.DEFAULT (better step time, but not precise)
+       b) lax.Precision.HIGH (increased precision, slower)
+       c) lax.Precision.HIGHEST (best possible precision, slowest)
+
+   Returns:
+     matrix^(-1/p)
+   """
+
+   # We use float32 for the matrix inverse pth root.
+   # Switch to f64 if you have hardware that supports it.
+   matrix_size = matrix.shape[0]
+   alpha = jnp.asarray(-1.0 / p, jnp.float32)
+   identity = jnp.eye(matrix_size, dtype=jnp.float32)
+   _, max_ev = power_iteration(
+       matrix=matrix, num_iters=100,
+       error_tolerance=1e-6, precision=precision)
+   ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
+
+   def _unrolled_mat_pow_1(mat_m):
+     """Computes mat_m^1."""
+     return mat_m
+
+   def _unrolled_mat_pow_2(mat_m):
+     """Computes mat_m^2."""
+     return jnp.matmul(mat_m, mat_m, precision=precision)
+
+   def _unrolled_mat_pow_4(mat_m):
+     """Computes mat_m^4."""
+     mat_pow_2 = _unrolled_mat_pow_2(mat_m)
+     return jnp.matmul(
+         mat_pow_2, mat_pow_2, precision=precision)
+
+   def _unrolled_mat_pow_8(mat_m):
+     """Computes mat_m^8."""
+     mat_pow_4 = _unrolled_mat_pow_4(mat_m)
+     return jnp.matmul(
+         mat_pow_4, mat_pow_4, precision=precision)
+
+   def mat_power(mat_m, p):
+     """Computes mat_m^p, for p == 1, 2, 4 or 8.
+
+     Args:
+       mat_m: a square matrix
+       p: a positive integer
+
+     Returns:
+       mat_m^p
+     """
+     # We unrolled the loop for performance reasons.
+     exponent = jnp.round(jnp.log2(p))
+     return lax.switch(
+         jnp.asarray(exponent, jnp.int32), [
+             _unrolled_mat_pow_1,
+             _unrolled_mat_pow_2,
+             _unrolled_mat_pow_4,
+             _unrolled_mat_pow_8,
+         ], (mat_m))
+
+   def _iter_condition(state):
+     (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error,
+      run_step) = state
+     error_above_threshold = jnp.logical_and(
+         error > error_tolerance, run_step)
+     return jnp.logical_and(i < num_iters, error_above_threshold)
+
+   def _iter_body(state):
+     (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state
+     mat_m_i = (1 - alpha) * identity + alpha * mat_m
+     new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
+     new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
+     new_error = jnp.max(jnp.abs(new_mat_m - identity))
+     # sometimes error increases after an iteration before decreasing and
+     # converging. 1.2 factor is used to bound the maximal allowed increase.
+     return (i + 1, new_mat_m, new_mat_h, mat_h, new_error,
+             new_error < error * 1.2)
+
+   if matrix_size == 1:
+     resultant_mat_h = (matrix + ridge_epsilon)**alpha
+     error = 0
+   else:
+     damped_matrix = matrix + ridge_epsilon * identity
+
+     z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
+     new_mat_m_0 = damped_matrix * z
+     new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
+     new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
+     init_state = tuple(
+         [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True])
+     _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop(
+         _iter_condition, _iter_body, init_state)
+     error = jnp.max(jnp.abs(mat_m - identity))
+     is_converged = jnp.asarray(convergence, old_mat_h.dtype)
+     resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
+     resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype)
+   return resultant_mat_h, error
+
+
+ def merge_small_dims(shape_to_merge, max_dim):
+   """Merge small dimensions.
+
+   If there are some small dimensions, we collapse them:
+   e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
+        [1, 2, 768, 1, 2048] --> [2, 768, 2048]
+
+   Args:
+     shape_to_merge: Shape to merge small dimensions.
+     max_dim: Maximal dimension of output shape used in merging.
+
+   Returns:
+     Merged shape.
+   """
+   resulting_shape = []
+   product = 1
+   for d in shape_to_merge:
+     if product * d <= max_dim:
+       product *= d
+     else:
+       if product > 1:
+         resulting_shape.append(product)
+       product = d
+   if product > 1:
+     resulting_shape.append(product)
+   return resulting_shape
+
+
+ def pad_matrix(mat, max_size):
+   """Pad a matrix to a max_size.
+
+   Args:
+     mat: a matrix to pad.
+     max_size: matrix size requested.
+
+   Returns:
+     Given M returns [[M, 0], [0, I]]
+   """
+   size = mat.shape[0]
+   assert size <= max_size
+   if size == max_size:
+     return mat
+   pad_size = max_size - size
+   zs1 = jnp.zeros([size, pad_size], dtype=mat.dtype)
+   zs2 = jnp.zeros([pad_size, size], dtype=mat.dtype)
+   eye = jnp.eye(pad_size, dtype=mat.dtype)
+   mat = jnp.concatenate([mat, zs1], 1)
+   mat = jnp.concatenate([mat, jnp.concatenate([zs2, eye], 1)], 0)
+   return mat
+
+
+ def pad_vector(vec, max_size):
+   """Pad a vector to a max_size.
+
+   Args:
+     vec: a vector to pad.
+     max_size: matrix size requested.
+
+   Returns:
+     Given V returns [V, 0]
+   """
+   size = vec.shape[0]
+   assert size <= max_size
+   if size == max_size:
+     return vec
+   pad_size = max_size - size
+   zs1 = jnp.zeros([pad_size], dtype=vec.dtype)
+   return jnp.concatenate([vec, zs1], 0)
+
+
+ def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
+   """Avoids wasteful buffer allocation with XLA."""
+
+   def _iter_body(unused_state):
+     results = compute_fn(*args, **kwargs)
+     return tuple([False] + list(results))
+
+   def _iter_condition(state):
+     return state[0]
+
+   results = jax.lax.while_loop(_iter_condition, _iter_body,
+                                tuple([predicate] + init_state))
+   return tuple(results[1:])
+
+
+ class BlockPartitioner:
+   """Partitions a tensor into smaller tensors."""
+
+   def __init__(self, param, block_size):
+     self._shape = param.shape
+     self._splits = []
+     split_sizes = []
+     # We split params into smaller blocks. Here we store the metadata to make
+     # that split.
+     for i, d in enumerate(param.shape):
+       if 0 < block_size < d:
+         # d-1, otherwise split appends a 0-size array.
+         nsplit = (d - 1) // block_size
+         indices = (np.arange(nsplit, dtype=np.int32) + 1) * block_size
+         sizes = np.ones(nsplit + 1, dtype=np.int32) * block_size
+         sizes[-1] = d - indices[-1]
+         self._splits.append((i, indices))
+         split_sizes.append(sizes)
+       else:
+         split_sizes.append(np.array([d], dtype=np.int32))
+     self._num_splits = len(split_sizes)
+     self._preconditioner_shapes = []
+     for t in itertools.product(*split_sizes):
+       self._preconditioner_shapes.extend([[d, d] for d in t])
+
+   def shapes_for_preconditioners(self):
+     return self._preconditioner_shapes
+
+   def num_splits(self):
+     return self._num_splits
+
+   def partition(self, tensor):
+     """Partition tensor into blocks."""
+
+     assert tensor.shape == self._shape
+     tensors = [tensor]
+     for (i, indices) in self._splits:
+       tensors_local = []
+       for t in tensors:
+         tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
+       tensors = tensors_local
+     return tensors
+
+   def merge_partitions(self, partitions):
+     """Merge partitions back to original shape."""
+
+     for (i, indices) in reversed(self._splits):
+       n = len(indices) + 1
+       partial_merged_tensors = []
+       ind = 0
+       while ind < len(partitions):
+         partial_merged_tensors.append(
+             jnp.concatenate(partitions[ind:ind + n], axis=i))
+         ind += n
+       partitions = partial_merged_tensors
+     assert len(partitions) == 1
+     return partitions[0]
+
+
+ class Preconditioner:
+   """Compute statistics/shape from gradients for preconditioning."""
+
+   def __init__(self, param, block_size, best_effort_shape_interpretation):
+     self._original_shape = param.shape
+     self._transformed_shape = param.shape
+     if best_effort_shape_interpretation:
+       self._transformed_shape = merge_small_dims(self._original_shape,
+                                                  block_size)
+     reshaped_param = jnp.reshape(param, self._transformed_shape)
+     self._partitioner = BlockPartitioner(reshaped_param, block_size)
+
+   def statistics_from_grad(self, grad):
+     """Compute statistics from gradients.
+
+     Args:
+       grad: Gradient to compute statistics from.
+
+     Returns:
+       A list of gradient statistics for each partition.
+     """
+     reshaped_grad = jnp.reshape(grad, self._transformed_shape)
+     partitioned_grads = self._partitioner.partition(reshaped_grad)
+     stats = []
+     for g in partitioned_grads:
+       g_stats = []
+       rank = len(g.shape)
+       for i in range(rank):
+         axes = list(range(i)) + list(range(i + 1, rank))
+         stat = jnp.tensordot(g, g, axes=(axes, axes))
+         g_stats.append(stat)
+       stats.extend(g_stats)
+     return stats
+
+   def shapes_for_preconditioners(self):
+     """Returns shape from statistics."""
+     return self._partitioner.shapes_for_preconditioners()
+
+   def exponent_for_preconditioner(self):
+     """Returns exponent to use for inverse-pth root M^{-1/p}."""
+     return 2 * len(self._transformed_shape)
+
+   def preconditioned_grad(self, grad, preconditioners):
+     """Precondition the gradient.
+
+     Args:
+       grad: A gradient tensor to precondition.
+       preconditioners: A list of preconditioners to apply.
+
+     Returns:
+       A preconditioned gradient.
+     """
+
+     reshaped_grad = jnp.reshape(grad, self._transformed_shape)
+     partitioned_grads = self._partitioner.partition(reshaped_grad)
+     preconditioned_partitioned_grads = []
+     num_splits = self._partitioner.num_splits()
+     for i, g in enumerate(partitioned_grads):
+       preconditioners_for_grad = preconditioners[i * num_splits:(i + 1) *
+                                                  num_splits]
+       rank = len(g.shape)
+       precond_g = g
+       for j in range(rank):
+         precond_g = jnp.tensordot(
+             precond_g, preconditioners_for_grad[j], axes=[[0], [0]])
+       preconditioned_partitioned_grads.append(precond_g)
+     merged_grad = self._partitioner.merge_partitions(
+         preconditioned_partitioned_grads)
+     return jnp.reshape(merged_grad, self._original_shape)
+
+
+ def _convert_to_parameter_stats(global_stats, local_stat):
+   """Creates parameter stats from sharded stats."""
+   index_start = int(local_stat.index_start)
+   index_end = int(len(local_stat.sizes)) + index_start
+   statistics = global_stats.statistics[index_start:index_end, :, :]
+   preconditioners = global_stats.preconditioners[index_start:index_end, :, :]
+   new_statistics = []
+   new_preconditioners = []
+   for i, size in enumerate(local_stat.sizes):
+     new_statistics.append(statistics[i][:size, :size])
+     new_preconditioners.append(preconditioners[i][:size, :size])
+   return ParameterStats(local_stat.diagonal_statistics, new_statistics,
+                         new_preconditioners, local_stat.diagonal_momentum,
+                         local_stat.momentum)
+
+
+ def _convert_from_parameter_stats(parameter_stats, local_stats):
+   """Creates sharded stats from parameter stats."""
+   return LocalShardedParameterStats(parameter_stats.diagonal_statistics,
+                                     parameter_stats.diagonal_momentum,
+                                     parameter_stats.momentum,
+                                     local_stats.index_start, local_stats.sizes)
+
+
+ def batch(x, num_devices):
+   """Batch `x` so that leading axis is num_devices."""
+   n = len(x)
+   b = int(n / num_devices)
+   return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
+
+
+ def unbatch(batched_values):
+   """Unbatch values across leading axis and return a list of elements."""
+   b1, b2 = batched_values.shape[0], batched_values.shape[1]
+   results = []
+   for v_array in jnp.split(batched_values, indices_or_sections=b1, axis=0):
+     v_array = jnp.squeeze(v_array)
+     # b2 = batches (number of preconditioner computation) per core.
+     if b2 > 1:
+       for v in jnp.split(v_array, indices_or_sections=b2, axis=0):
+         results.append(jnp.squeeze(v))
+     else:
+       results.append(v_array)
+   return results
+
+
+ def distributed_shampoo(
+     learning_rate,
+     block_size,
+     beta1=0.9,
+     beta2=0.999,
+     diagonal_epsilon=1e-10,
+     matrix_epsilon=1e-6,
+     weight_decay=0.0,
+     start_preconditioning_step=5,
+     preconditioning_compute_steps=1,
+     statistics_compute_steps=1,
+     best_effort_shape_interpretation=True,
+     graft_type=GraftingType.SGD,
+     nesterov=True,
+     exponent_override=0,
+     # Pass pmap 'batch axis name' in pmap mode.
+     batch_axis_name=None,
+     ### Only set following 3 params in pjit/spmd mode.
+     ### WARNING: Experimental
+     mesh_axis_names=None,
+     num_devices_for_pjit=None,
+     shard_optimizer_states=False,
+     ###
+     ### Experimental memory reduction mode
+     best_effort_memory_usage_reduction=False,
+     ###
+     inverse_failure_threshold=0.1,
+     moving_average_for_momentum=False,
+     skip_preconditioning_dim_size_gt=4096,
+     clip_by_scaled_gradient_norm=None,
+     precision=lax.Precision.HIGHEST):
+   """Distributed Shampoo optimizer.
+
+   Distributed Shampoo is a second-order preconditioned method (concretely, a
+   variant of full-matrix Adagrad), that provides significant convergence and
+   wall-clock time improvements compared to conventional first-order methods,
+   and that has been shown to scale to large state-of-the-art deep learning
+   models.
+
+   References:
+     Scalable Second Order Optimization for Deep Learning,
+     Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
+
+     Preprint: https://arxiv.org/abs/2002.09018
+
+   Args:
+     learning_rate: the step size used to update the parameters.
+     block_size: Block size for large layers (if > 0). Preconditioning compute
+       operation is cubic in the dimension of the tensor. Block size allows us to
+       chunk the layers into sub-layers of maximal dimension dictated by this
+       value. Use 128 as default (increase if you have compute budget).
+     beta1: momentum parameter.
+     beta2: second moment averaging parameter.
+     diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
+       to AdaGrad is enabled).
+     matrix_epsilon: epsilon to add to statistics before computing inverse pth
+       root. If you are running in f32 precision for inverse pth root
+       (recommended today) this can go up to 1e-6. If you have latest hardware
+       with native f64 precision, set this up to 1e-12.
+     weight_decay: Weight decay for regularization.
+     start_preconditioning_step: When to start Shampoo update before which
+       diagonal update is used. This is because we don't have enough information
+       to do stable inverse.
+     preconditioning_compute_steps: How often to compute preconditioner.
+       Performance tuning params for controlling memory and compute requirements.
+       Ideally set this and statistics_compute_steps params to 1.
+     statistics_compute_steps: How often to compute statistics.
+     best_effort_shape_interpretation: If there are some small dimensions,
+       collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
+       block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
+     graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
+       optimizer. This allows us to plugin the Shampoo optimizer into settings
+       where SGD/AdaGrad is already well tuned. Available options are:
+       GraftingType.SGD and GraftingType.ADAGRAD.
+     nesterov: Nesterov momentum.
+     exponent_override: Override the exponent used in matrix inverse.
+     batch_axis_name: labeled axis over pmap for data-parallel training the
+       optimizer is used for.
+     mesh_axis_names: Axis names for the mesh (used in pjit).
+     num_devices_for_pjit: Number of devices to parallelize over when using pjit.
+     shard_optimizer_states: Shard optimizer states to save memory in model
+       parallel training.
+     best_effort_memory_usage_reduction: Best effort memory usage reduction.
+       diagonal_statistics -> jnp.bfloat16
+       momentum buffers (2x) -> jnp.int8
+       statistics, preconditioners -> jnp.int16 + diagonals
+     inverse_failure_threshold: numerics are hard and inverses fail sometimes; we
+       determine that using this threshold.
+     moving_average_for_momentum: Whether to use moving average for momentum
+       instead of exponential moving average.
+     skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
+       greater than this value.
+     clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
+       when using RMSProp Grafting).
+     precision: precision XLA related flag, the available options are: a)
+       lax.Precision.DEFAULT (better step time, but not precise) b)
+       lax.Precision.HIGH (increased precision, slower) c) lax.Precision.HIGHEST
+       (best possible precision, slowest)
+
+   Returns:
+     a GradientTransformation.
+   """
+
+   def quantized_dtype_for_momentum_buffers():
+     return jnp.int8 if best_effort_memory_usage_reduction else jnp.float32
+
+   # TODO(rohananil): Explore int8-16 quantization with non-linear bucket sizes.
+   def quantized_dtype_for_diagonal_statistics_buffers():
+     return jnp.bfloat16 if best_effort_memory_usage_reduction else jnp.float32
+
+   # Preconditioner and statistics are both stored as int16 in this mode.
+   # We take out the diagonal to make quantization easier.
+   def quantized_dtype_for_second_moment_statistics_buffers():
+     return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
+
+   # Preconditioner and statistics are both stored as int16 in this mode.
+   # We take out the diagonal to make quantization easier.
+   def quantized_dtype_for_second_moment_preconditioner_buffers():
+     return jnp.int16 if best_effort_memory_usage_reduction and batch_axis_name else jnp.float32
+
+   def _to_float(maybe_quantized):
+     if isinstance(maybe_quantized, QuantizedValue):
+       return maybe_quantized.to_float()
+     else:
+       return maybe_quantized
+
+   def _maybe_quantize_statistics(statistics_list):
+     return _maybe_quantize_matrices_with_dtype(
+         statistics_list, quantized_dtype_for_second_moment_statistics_buffers())
+
+   def _maybe_quantize_preconditioners(statistics_list):
+     return _maybe_quantize_matrices_with_dtype(
+         statistics_list,
+         quantized_dtype_for_second_moment_preconditioner_buffers())
+
+   def _maybe_quantize_matrices_with_dtype(statistics_list, quantized_dtype):
+     if quantized_dtype != jnp.float32:
+       return ([
+           QuantizedValue.from_float_value(
+               s, quantized_dtype, extract_diagonal=True)
+           for s in statistics_list
+       ])
+     else:
+       return statistics_list
+
+   def _maybe_dequantize_preconditioners(preconditioner_list):
+     return _maybe_dequantize_matrices_with_dtype(
+         preconditioner_list,
+         quantized_dtype_for_second_moment_preconditioner_buffers())
+
+   def _maybe_dequantize_matrices_with_dtype(statistics_list, quantized_dtype):
+     if quantized_dtype != jnp.float32:
+       return [s.to_float() for s in statistics_list]
+     else:
+       return statistics_list
+
+   def _quantize_diagonal_statistics(diagonal_statistics):
+     return QuantizedValue.from_float_value(
+         diagonal_statistics, quantized_dtype_for_diagonal_statistics_buffers())
+
+   def _quantize_momentum(momentum_statistics):
+     return QuantizedValue.from_float_value(
+         momentum_statistics, quantized_dtype_for_momentum_buffers())
+
+   def sharded_init_fn(params):
+     params_flat, treedef = jax.tree_flatten(params)
+     # Find max size to pad to.
+     max_size = 0
+     for param in params_flat:
+       preconditioner = Preconditioner(param, block_size,
+                                       best_effort_shape_interpretation)
+       if not _skip_preconditioning(param):
+         shapes = preconditioner.shapes_for_preconditioners()
+         sizes = [s[0] for s in shapes]
+         max_size = max(max(sizes), max_size)
+
+     padded_statistics = []
+     padded_preconditioners = []
+     local_stats_flat = []
+     for param in params_flat:
+       preconditioner = Preconditioner(param, block_size,
+                                       best_effort_shape_interpretation)
+       shapes = preconditioner.shapes_for_preconditioners()
+       sizes = []
+
+       statistics = []
+       preconditioners = []
+       index_start = len(padded_statistics)
+       if not _skip_preconditioning(param):
+         sizes = [s[0] for s in shapes]
+         shapes = preconditioner.shapes_for_preconditioners()
+         statistics = [matrix_epsilon * jnp.eye(max_size) for s in shapes]
+         preconditioners = [jnp.eye(max_size) for s in shapes]
+         padded_statistics.extend(statistics)
+         padded_preconditioners.extend(preconditioners)
+
+       diagonal_statistics = []
+       if graft_type != GraftingType.SGD:
+         diagonal_statistics = jnp.zeros_like(param)
+       local_stats_flat.append(
+           LocalShardedParameterStats(
+               _quantize_diagonal_statistics(diagonal_statistics),
+               _quantize_momentum(jnp.zeros_like(param)),
+               _quantize_momentum(jnp.zeros_like(param)), index_start, sizes))
+
+     local_stats = jax.tree_unflatten(treedef, local_stats_flat)
+     # Pad the statistics and preconditioner matrices to be a multiple of
+     # num devices.
+     # TODO(rohananil): Relax to only the size of the mesh axis where the dim
+     # is split on.
+     to_pad = -len(padded_statistics) % num_devices_for_pjit
+     padded_statistics.extend([
+         jnp.eye(max_size, dtype=padded_statistics[0].dtype)
+         for _ in range(to_pad)
+     ])
+     padded_preconditioners.extend([
+         jnp.eye(max_size, dtype=padded_statistics[0].dtype)
+         for _ in range(to_pad)
+     ])
+     global_stats = GlobalShardedParameterStats(
+         jnp.stack(padded_statistics), jnp.stack(padded_preconditioners))
+     return ShampooState(
+         count=jnp.zeros([], jnp.int32),
+         stats=ShardedShampooStats(global_stats, local_stats))
+
+   def sharded_update_fn(grads, state, params):
+     """Transform the input gradient and update all statistics in sharded mode.
+
+     Args:
+       grads: the gradient tensors for the parameters.
+       state: a named tuple containing the state of the optimizer
+       params: the parameters that should be updated.
+
+     Returns:
+       A tuple containing the new parameters and the new optimizer state.
+     """
+     params_flat, treedef = jax.tree_flatten(params)
+     grads_flat = treedef.flatten_up_to(grads)
+
+     global_stats = state.stats.global_stats
+     local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
+     stats_flat = [
+         _convert_to_parameter_stats(global_stats, local_stat)
+         for local_stat in local_stats_flat
+     ]
+     new_stats_flat = jax.tree_multimap(
+         lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
+         stats_flat, params_flat)
+
+     exponents = []
+     for stat, param in zip(new_stats_flat, params_flat):
+       num_statistics = len(stat.statistics)
+       if num_statistics > 0:
+         preconditioner = Preconditioner(param, block_size,
+                                         best_effort_shape_interpretation)
+         exponent = (
+             preconditioner.exponent_for_preconditioner()
+             if exponent_override == 0 else exponent_override)
+         exponents.extend([exponent] * num_statistics)
+
+     outputs = jax.tree_multimap(
+         lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
+         new_stats_flat, params_flat)
+     updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
+
+     updates = jax.tree_unflatten(treedef, updates_flat)
+     # Create new local_stats
+     new_local_stats_flat = [
+         _convert_from_parameter_stats(new_stat, local_stat)
+         for new_stat, local_stat in zip(new_stats_flat, local_stats_flat)
+     ]
+     new_local_stats = jax.tree_unflatten(treedef, new_local_stats_flat)
+
+     max_size = global_stats.statistics.shape[1]
+     new_padded_statistics = []
+     for stat in new_stats_flat:
+       new_padded_statistics.extend(
+           [pad_matrix(stat, max_size) for stat in stat.statistics])
+
+     # Create global stats
+     # TODO(rohananil): Preconditioner is not updated every step, so cost of
+     # stack/pad can be obviated away.
+     # Pad the statistics and preconditioner matrices to be a multiple of
+     # num devices.
+     # TODO(rohananil): Relax to only the size of the mesh axis where the dim
+     # is split on.
+     to_pad = -len(new_padded_statistics) % num_devices_for_pjit
+     new_padded_statistics.extend([
+         jnp.eye(max_size, dtype=new_padded_statistics[0].dtype)
+         for _ in range(to_pad)
+     ])
+     exponents.extend([1 for _ in range(to_pad)])
+     new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
+     new_stacked_exponents = jnp.stack(exponents)
+     def _matrix_inverse_pth_root_vmap(xs, ps):
+       mi_pth_root = functools.partial(
+           matrix_inverse_pth_root,
+           ridge_epsilon=matrix_epsilon,
+           precision=precision)
+       preconditioners, errors = jax.vmap(mi_pth_root)(xs, ps)
+       return preconditioners, errors
+
+     def _internal_inverse_pth_root_all():
+       preconditioners, errors = _matrix_inverse_pth_root_vmap(
+           new_stacked_padded_statistics, new_stacked_exponents)
+       return preconditioners, errors
+
+     if preconditioning_compute_steps == 1:
+       new_preconditioners, errors = _internal_inverse_pth_root_all()
+     else:
+       # Passing statistics instead of preconditioners as they are similarly
+       # shaped tensors. Note statistics will be ignored as we are passing in
+       # a large init value for error.
+       preconditioners_init = new_stacked_padded_statistics
+       errors_init = np.stack([inverse_failure_threshold] * len(exponents))
+       init_state = [preconditioners_init, errors_init]
+       perform_step = state.count % preconditioning_compute_steps == 0
+       new_preconditioners, errors = efficient_cond(
+           perform_step, _internal_inverse_pth_root_all, init_state)
+
+     errors = errors.reshape((-1, 1, 1))
+     predicate = jnp.logical_or(
+         jnp.isnan(errors),
+         errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+     # TODO(rohananil): Check for numerical instabilities.
+     new_conditional_preconditioners = (
+         predicate * global_stats.preconditioners +
+         (1.0 - predicate) * new_preconditioners)
+     new_global_stats = GlobalShardedParameterStats(
+         new_stacked_padded_statistics, new_conditional_preconditioners)
+     new_shampoo_state = ShampooState(
+         count=state.count + 1,
+         stats=ShardedShampooStats(new_global_stats, new_local_stats))
+     return updates, new_shampoo_state
+
+   def init_fn(params):
+     """Initialise the optimiser's state."""
+
+     def _init(param):
+       preconditioner = Preconditioner(param, block_size,
+                                       best_effort_shape_interpretation)
+       statistics = []
+       preconditioners = []
+       if not _skip_preconditioning(param):
+         shapes = preconditioner.shapes_for_preconditioners()
+         statistics = [matrix_epsilon * jnp.eye(s[0]) for s in shapes]
+         preconditioners = [jnp.eye(s[0]) for s in shapes]
+
+       diagonal_statistics = []
+       if graft_type != GraftingType.SGD:
+         diagonal_statistics = jnp.zeros_like(param)
+       return ParameterStats(
+           _quantize_diagonal_statistics(diagonal_statistics),
+           _maybe_quantize_statistics(statistics),
+           _maybe_quantize_preconditioners(preconditioners),
+           _quantize_momentum(jnp.zeros_like(param)),
+           _quantize_momentum(jnp.zeros_like(param)))
+     return ShampooState(
+         count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params))
+
+   def _skip_preconditioning(param):
+     return len(param.shape) < 1 or any(
+         [s > skip_preconditioning_dim_size_gt for s in param.shape])
+
+   def _compute_stats(grad, state, param, step):
+     """Compute per-parameter statistics."""
+     preconditioner = Preconditioner(param, block_size,
+                                     best_effort_shape_interpretation)
+     new_statistics = [[]] * len(state.statistics)
+     w1 = beta2
+     w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
+     if not _skip_preconditioning(param):
+
+       def compute_updated_statistics():
+         new_stats = preconditioner.statistics_from_grad(grad)
+         new_stats_accumulators = []
+         for stat, stat_accumulator in zip(new_stats, state.statistics):
+           new_stats_accumulators.append(w1 * _to_float(stat_accumulator) +
+                                         w2 * stat)
+         return _maybe_quantize_statistics(new_stats_accumulators)
+
+       if statistics_compute_steps > 1:
+         perform_step = step % statistics_compute_steps == 0
+         init_state = state.statistics
+         new_statistics = list(
+             efficient_cond(perform_step, compute_updated_statistics,
+                            init_state))
+       else:
+         new_statistics = compute_updated_statistics()
+     return ParameterStats(state.diagonal_statistics, new_statistics,
+                           state.preconditioners, state.diagonal_momentum,
+                           state.momentum)
+
+   def _matrix_inverse_pth_root_vmap(xs, ps):
+     mi_pth_root = functools.partial(
+         matrix_inverse_pth_root,
+         ridge_epsilon=matrix_epsilon,
+         precision=precision)
+     return jax.vmap(mi_pth_root)(xs, ps)
+
+   def _quantized_matrix_inverse_pth_root_vmap(qxs, qds, qbs, ps):
+
+     def _quantized_to_float(qx, qd, qb):
+       qv = QuantizedValue(qx, qd, qb, qx.dtype, True, list(qx.shape))
+       return qv.to_float()
+
+     def matrix_inverse_pth_root_wrapper(qx, qd, qb, p):
+       v = _quantized_to_float(qx, qd, qb)
+       preconditioner, error = matrix_inverse_pth_root(
+           v, p, ridge_epsilon=matrix_epsilon, precision=precision)
+       qp = QuantizedValue.from_float_value(preconditioner, qx.dtype, True)
+       return qp.quantized, qp.diagonal, qp.bucket_size, error
+
+     return jax.vmap(matrix_inverse_pth_root_wrapper)(qxs, qds, qbs, ps)
+
+   def _matrix_inverse_pth_root_pjit(xs, ps):
+     mesh_axis_names_tuple = tuple(mesh_axis_names)
+     # Partition the concatenated statistics matrix across all cores.
+     partitioned_xs, partitioned_ps = pjit.pjit(
+         lambda x, y: (x, y),
+         in_axis_resources=None,
+         out_axis_resources=pjit.PartitionSpec(mesh_axis_names_tuple,))(xs, ps)
+     # Run matrix inverse pth root on each shard.
+     partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
+         partitioned_xs, partitioned_ps)
+     # Recombine the outputs at each core.
+     preconditioners, errors = pjit.pjit(
+         lambda x, y: (x, y),
+         in_axis_resources=(pjit.PartitionSpec(mesh_axis_names_tuple,),
+                            pjit.PartitionSpec(mesh_axis_names_tuple,)),
+         out_axis_resources=(None, None))(partitioned_preconditioners,
+                                          partitioned_errors)
+     return preconditioners, errors
+
+   def _pmap_compute_preconditioners(states, step, statistics,
+                                     num_statistics_per_state, original_shapes,
+                                     exponents, max_size, prev_preconditioners):
+     """Computes preconditioners for given statistics in states in PMAP mode.
+
+     Args:
+       states: A list of optimizer states.
+       step: Current step number
+       statistics: A list of statistics for all variables (for every dim)
+       num_statistics_per_state: Number of statistics per state to reconstruct
+         output states.
+       original_shapes: A list of shapes of the statistics.
+       exponents: Exponent power to use for inverse-pth roots.
+       max_size: Maximum dim of the statistics to pad.
+       prev_preconditioners: Previously available preconditioner.
+
+     Returns:
+       New optimizer states after computing the preconditioner.
+     """
+     num_devices = lax.psum(1, batch_axis_name)
+     num_statistics = len(statistics)
+     # Pad statistics and exponents to next multiple of num_devices.
+     packed_statistics = [pad_matrix(stat, max_size) for stat in statistics]
+     to_pad = -num_statistics % num_devices
+     packed_statistics.extend([
+         jnp.eye(max_size, dtype=packed_statistics[0].dtype)
+         for _ in range(to_pad)
+     ])
+     exponents.extend([1 for _ in range(to_pad)])
+
+     if not packed_statistics:
+       return states
+
+     all_statistics = batch(packed_statistics, num_devices)
+     all_exponents = batch(exponents, num_devices)
+
+     def _internal_inverse_pth_root_all():
+       current_replica = lax.axis_index(batch_axis_name)
+       preconditioners, errors = _matrix_inverse_pth_root_vmap(
+           all_statistics[current_replica], all_exponents[current_replica])
+       preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
+       errors = jax.lax.all_gather(errors, batch_axis_name)
+       preconditioners_flat = unbatch(preconditioners)
+       errors_flat = unbatch(errors)
+       return preconditioners_flat, errors_flat
+
+     if preconditioning_compute_steps == 1:
+       preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
+     else:
+       # Passing statistics instead of preconditioners as they are similarly
+       # shaped tensors. Note statistics will be ignored as we are passing in
+       # a large init value for error.
+       preconditioners_init = packed_statistics
+       errors_init = ([inverse_failure_threshold] * len(packed_statistics))
+       init_state = [preconditioners_init, errors_init]
+       perform_step = step % preconditioning_compute_steps == 0
+       preconditioners_flat, errors_flat = efficient_cond(
+           perform_step, _internal_inverse_pth_root_all, init_state)
+
+     def _skip(error):
+       condition = jnp.logical_or(
+           jnp.isnan(error), error >= inverse_failure_threshold)
+       return condition.astype(error.dtype)
+
+     def _select_preconditioner(error, new_p, old_p):
+       return lax.cond(
+           _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
+
+     new_preconditioners_flat = []
+     for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
+                                        prev_preconditioners, errors_flat):
+       new_preconditioners_flat.append(
+           _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
+
+     assert len(states) == len(num_statistics_per_state)
+     assert len(new_preconditioners_flat) == num_statistics
+
+     # Add back empty preconditioners so that we can set the optimizer state.
+     preconditioners_for_states = []
+     idx = 0
+     for num_statistics, state in zip(num_statistics_per_state, states):
+       if num_statistics == 0:
+         preconditioners_for_states.append([])
+       else:
+         preconditioners_for_state = new_preconditioners_flat[idx:idx +
+                                                              num_statistics]
+         assert len(state.statistics) == len(preconditioners_for_state)
+         preconditioners_for_states.append(preconditioners_for_state)
+         idx += num_statistics
+     new_states = []
+     for state, new_preconditioners in zip(states, preconditioners_for_states):
+       new_states.append(
+           ParameterStats(state.diagonal_statistics, state.statistics,
+                          new_preconditioners, state.diagonal_momentum,
+                          state.momentum))
+
+     return new_states
+
1156
+ def _pmap_quantized_compute_preconditioners(states, step, statistics,
1157
+ num_statistics_per_state,
1158
+ original_shapes, exponents,
1159
+ max_size, prev_preconditioners):
1160
+ """Computes preconditioners for given statistics in states in PMAP mode.
1161
+
1162
+ For quantization, each statistic is represented by three values:
1163
+ quantized matrix, diagonal, and bucket sizes, we run inverse pth-roots
1164
+ without ever recreating the original matrix in f32.
1165
+
1166
+ Args:
1167
+ states: A list of optimizer states.
1168
+ step: Current step number
1169
+ statistics: A list of statistics for all variables (for every dim)
1170
+ num_statistics_per_state: Number of statistis per state to reconstruct
1171
+ output states.
1172
+ original_shapes: A list of shapes of the statistics.
1173
+ exponents: Exponent power to use for inverse-pth roots.
1174
+ max_size: Maximum dim of the statistics to pad.
1175
+ prev_preconditioners: Previously available preconditioner.
1176
+
1177
+ Returns:
1178
+ New optimizer states after computing the preconditioner.
1179
+ """
1180
+ num_devices = lax.psum(1, batch_axis_name)
1181
+ num_statistics = len(statistics)
1182
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1183
+ # Complexity here is around: shapes needing be statically shaped,
1184
+ # our custom quantization type requires a different type of packing.
1185
+
1186
+ # Parallel tensors:
1187
+ # quantized [dxd]
1188
+ # diagonals [d] f32
1189
+ # bucket_sizes [d] f32
1190
+ packed_quantized_statistics = [
1191
+ pad_matrix(stat.quantized, max_size) for stat in statistics
1192
+ ]
1193
+ packed_quantized_diagonals = [
1194
+ pad_vector(stat.diagonal, max_size) for stat in statistics
1195
+ ]
1196
+ packed_quantized_bucket_sizes = [
1197
+ pad_vector(stat.bucket_size, max_size) for stat in statistics
1198
+ ]
1199
+
1200
+ to_pad = -num_statistics % num_devices
1201
+ padded_eye = jnp.eye(max_size, dtype=jnp.float32)
1202
+ quantized_eye = QuantizedValue.from_float_value(padded_eye, quantized_dtype,
1203
+ True)
1204
+ packed_quantized_statistics.extend(
1205
+ [quantized_eye.quantized for _ in range(to_pad)])
1206
+ packed_quantized_diagonals.extend(
1207
+ [quantized_eye.diagonal for _ in range(to_pad)])
1208
+ packed_quantized_bucket_sizes.extend(
1209
+ [quantized_eye.bucket_size for _ in range(to_pad)])
1210
+ exponents.extend([1 for _ in range(to_pad)])
1211
+
1212
+ if not packed_quantized_statistics:
1213
+ return states
1214
+
1215
+ all_quantized_statistics = batch(packed_quantized_statistics, num_devices)
1216
+ all_quantized_diagonals = batch(packed_quantized_diagonals, num_devices)
1217
+ all_quantized_bucket_sizes = batch(packed_quantized_bucket_sizes,
1218
+ num_devices)
1219
+ all_exponents = batch(exponents, num_devices)
1220
+
1221
+ def _internal_inverse_pth_root_all():
1222
+ current_replica = lax.axis_index(batch_axis_name)
1223
+ quantized_preconditioners, quantized_diagonals, quantized_bucket_sizes, errors = (
1224
+ _quantized_matrix_inverse_pth_root_vmap(
1225
+ all_quantized_statistics[current_replica],
1226
+ all_quantized_diagonals[current_replica],
1227
+ all_quantized_bucket_sizes[current_replica],
1228
+ all_exponents[current_replica]))
1229
+ quantized_preconditioners = jax.lax.all_gather(quantized_preconditioners,
1230
+ batch_axis_name)
1231
+ quantized_diagonals = jax.lax.all_gather(quantized_diagonals,
1232
+ batch_axis_name)
1233
+ quantized_bucket_sizes = jax.lax.all_gather(quantized_bucket_sizes,
1234
+ batch_axis_name)
1235
+ errors = jax.lax.all_gather(errors, batch_axis_name)
1236
+ quantized_preconditioners_flat = unbatch(quantized_preconditioners)
1237
+ quantized_diagonals_flat = unbatch(quantized_diagonals)
1238
+ quantized_bucket_sizes_flat = unbatch(quantized_bucket_sizes)
1239
+ errors_flat = unbatch(errors)
1240
+ return (quantized_preconditioners_flat, quantized_diagonals_flat,
1241
+ quantized_bucket_sizes_flat, errors_flat)
1242
+
1243
+ if preconditioning_compute_steps == 1:
1244
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1245
+ quantized_bucket_sizes_flat, errors_flat) = (
1246
+ _internal_inverse_pth_root_all())
1247
+ else:
1248
+ # Passing statistics instead of preconditioners as they are similarly
1249
+ # shaped tensors. Note statistics will be ignored as we are passing in
1250
+ # a large init value for error.
1251
+ quantized_preconditioners_init = packed_quantized_statistics
1252
+ quantized_diagonals_init = packed_quantized_diagonals
1253
+ quantized_bucket_sizes_init = packed_quantized_bucket_sizes
1254
+ errors_init = ([inverse_failure_threshold] *
1255
+ len(quantized_preconditioners_init))
1256
+ init_state = [
1257
+ quantized_preconditioners_init, quantized_diagonals_init,
1258
+ quantized_bucket_sizes_init, errors_init
1259
+ ]
1260
+ perform_step = step % preconditioning_compute_steps == 0
1261
+ (quantized_preconditioners_flat, quantized_diagonals_flat,
1262
+ quantized_bucket_sizes_flat, errors_flat) = (
1263
+ efficient_cond(perform_step, _internal_inverse_pth_root_all,
1264
+ init_state))
1265
+
1266
+ def _skip(error):
1267
+ condition = jnp.logical_or(
1268
+ jnp.isnan(error), error >= inverse_failure_threshold)
1269
+ return condition.astype(error.dtype)
1270
+
1271
+ def _select_preconditioner(error, new_p, old_p):
1272
+ return lax.cond(
1273
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1274
+
1275
+ new_quantized_preconditioners_flat = []
1276
+ new_quantized_diagonals_flat = []
1277
+ new_quantized_bucket_sizes_flat = []
1278
+ for p, d, b, shape, prev_p, error in zip(quantized_preconditioners_flat,
1279
+ quantized_diagonals_flat,
1280
+ quantized_bucket_sizes_flat,
1281
+ original_shapes,
1282
+ prev_preconditioners, errors_flat):
1283
+ new_quantized_preconditioners_flat.append(
1284
+ _select_preconditioner(error, p[:shape[0], :shape[1]],
1285
+ prev_p.quantized))
1286
+ new_quantized_diagonals_flat.append(
1287
+ _select_preconditioner(error, d[:shape[0]], prev_p.diagonal))
1288
+ new_quantized_bucket_sizes_flat.append(
1289
+ _select_preconditioner(error, b[:shape[0]], prev_p.bucket_size))
1290
+
1291
+ assert len(states) == len(num_statistics_per_state)
1292
+ assert len(new_quantized_preconditioners_flat) == num_statistics
1293
+ assert len(new_quantized_diagonals_flat) == num_statistics
1294
+ assert len(new_quantized_bucket_sizes_flat) == num_statistics
1295
+
1296
+ # Add back empty preconditioners so that we can set the optimizer state.
1297
+ preconditioners_for_states = []
1298
+ idx = 0
1299
+ for num_statistics, state in zip(num_statistics_per_state, states):
1300
+ if num_statistics == 0:
1301
+ preconditioners_for_states.append([])
1302
+ else:
1303
+ quantized_preconditioners_for_state = new_quantized_preconditioners_flat[
1304
+ idx:idx + num_statistics]
1305
+ quantized_diagonals_for_state = new_quantized_diagonals_flat[
1306
+ idx:idx + num_statistics]
1307
+ quantized_bucket_sizes_for_state = new_quantized_bucket_sizes_flat[
1308
+ idx:idx + num_statistics]
1309
+
1310
+ assert len(state.statistics) == len(quantized_preconditioners_for_state)
1311
+ assert len(state.statistics) == len(quantized_diagonals_for_state)
1312
+ assert len(state.statistics) == len(quantized_bucket_sizes_for_state)
1313
+
1314
+ quantized_preconditioners = []
1315
+ for qv, qd, qb in zip(quantized_preconditioners_for_state,
1316
+ quantized_diagonals_for_state,
1317
+ quantized_bucket_sizes_for_state):
1318
+ quantized_preconditioners.append(
1319
+ QuantizedValue(qv, qd, qb, qv.dtype, True, list(qv.shape)))
1320
+ preconditioners_for_states.append(quantized_preconditioners)
1321
+ idx += num_statistics
1322
+ new_states = []
1323
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1324
+ new_states.append(
1325
+ ParameterStats(state.diagonal_statistics, state.statistics,
1326
+ new_preconditioners, state.diagonal_momentum,
1327
+ state.momentum))
1328
+
1329
+ return new_states
1330
+
1331
+ def _pjit_compute_preconditioners(states, step, statistics,
1332
+ num_statistics_per_state, original_shapes,
1333
+ exponents, max_size, prev_preconditioners):
1334
+ """Computes preconditioners for given statistics in states in PJIT mode.
1335
+
1336
+ Args:
1337
+ states: A list of optimizer states.
1338
+ step: Current step number
1339
+ statistics: A list of statistics for all variables (for every dim)
1340
+ num_statistics_per_state: Number of statistics per state to reconstruct
1341
+ output states.
1342
+ original_shapes: A list of shapes of the statistics.
1343
+ exponents: Exponent power to use for inverse-pth roots.
1344
+ max_size: Maximum dim of the statistics to pad.
1345
+ prev_preconditioners: Previously available preconditioners.
1346
+
1347
+ Returns:
1348
+ New optimizer states after computing the preconditioner.
1349
+ """
1350
+ num_statistics = len(statistics)
1351
+ to_pad = -num_statistics % num_devices_for_pjit
1352
+ padded_statistics = [pad_matrix(stat, max_size) for stat in statistics]
1353
+ padded_statistics.extend([
1354
+ jnp.eye(max_size, dtype=padded_statistics[0].dtype)
1355
+ for _ in range(to_pad)
1356
+ ])
1357
+ exponents.extend([1 for _ in range(to_pad)])
1358
+ all_statistics = jnp.stack(padded_statistics)
1359
+ all_exponents = jnp.stack(exponents)
1360
+
1361
+ def _internal_inverse_pth_root_all():
1362
+ preconditioners, errors = _matrix_inverse_pth_root_pjit(
1363
+ all_statistics, all_exponents)
1364
+ b1 = preconditioners.shape[0]
1365
+
1366
+ def split(batched_values):
1367
+ return [
1368
+ jnp.squeeze(v)
1369
+ for v in jnp.split(batched_values, indices_or_sections=b1, axis=0)
1370
+ ]
1371
+
1372
+ return split(preconditioners), split(errors)
1373
+
1374
+ if preconditioning_compute_steps == 1:
1375
+ preconditioners_flat, errors_flat = _internal_inverse_pth_root_all()
1376
+ else:
1377
+ # Passing statistics instead of preconditioners as they are similarly
1378
+ # shaped tensors. Note statistics will be ignored as we are passing in
1379
+ # a large init value for error.
1380
+ preconditioners_init = padded_statistics
1381
+ errors_init = [inverse_failure_threshold] * len(padded_statistics)
1382
+ init_state = [preconditioners_init, errors_init]
1383
+ perform_step = step % preconditioning_compute_steps == 0
1384
+ preconditioners_flat, errors_flat = efficient_cond(
1385
+ perform_step, _internal_inverse_pth_root_all, init_state)
1386
+
1387
+ def _skip(error):
1388
+ condition = jnp.logical_or(
1389
+ jnp.isnan(error), error >= inverse_failure_threshold)
1390
+ return condition.astype(error.dtype)
1391
+
1392
+ def _select_preconditioner(error, new_p, old_p):
1393
+ return lax.cond(
1394
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
1395
+
1396
+ new_preconditioners_flat = []
1397
+ for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
1398
+ prev_preconditioners, errors_flat):
1399
+ new_preconditioners_flat.append(
1400
+ _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
1401
+
1402
+ assert len(states) == len(num_statistics_per_state)
1403
+ assert len(new_preconditioners_flat) == num_statistics
1404
+
1405
+ # Add back empty preconditioners so that we can set the optimizer state.
1406
+ preconditioners_for_states = []
1407
+ idx = 0
1408
+ for num_statistics, state in zip(num_statistics_per_state, states):
1409
+ if num_statistics == 0:
1410
+ preconditioners_for_states.append([])
1411
+ else:
1412
+ preconditioners_for_state = new_preconditioners_flat[idx:idx +
1413
+ num_statistics]
1414
+ assert len(state.statistics) == len(preconditioners_for_state)
1415
+ preconditioners_for_states.append(preconditioners_for_state)
1416
+ idx += num_statistics
1417
+ new_states = []
1418
+ for state, new_preconditioners in zip(states, preconditioners_for_states):
1419
+ new_states.append(
1420
+ ParameterStats(state.diagonal_statistics, state.statistics,
1421
+ new_preconditioners, state.diagonal_momentum,
1422
+ state.momentum))
1423
+
1424
+ return new_states
1425
+
1426
+ def _compute_preconditioners(states, params, step):
1427
+ """Computes preconditioners for given statistics in states.
1428
+
1429
+ Args:
1430
+ states: A list of optimizer states.
1431
+ params: A list of params.
1432
+ step: Current step number
1433
+
1434
+ Returns:
1435
+ New optimizer states after computing the preconditioner.
1436
+ """
1437
+ statistics = []
1438
+ num_statistics_per_state = []
1439
+ original_shapes = []
1440
+ exponents = []
1441
+ max_size = 0
1442
+ prev_preconditioners = []
1443
+
1444
+ for state, param in zip(states, params):
1445
+ num_statistics = len(state.statistics)
1446
+ num_statistics_per_state.append(num_statistics)
1447
+ original_shapes_for_state = []
1448
+ if num_statistics > 0:
1449
+ preconditioner = Preconditioner(param, block_size,
1450
+ best_effort_shape_interpretation)
1451
+ for statistic in state.statistics:
1452
+ exponents.append(preconditioner.exponent_for_preconditioner(
1453
+ ) if exponent_override == 0 else exponent_override)
1454
+ original_shapes_for_state.append(statistic.shape)
1455
+ max_size = max(max_size, statistic.shape[0])
1456
+
1457
+ statistics.extend(state.statistics)
1458
+ prev_preconditioners.extend(state.preconditioners)
1459
+ original_shapes.extend(original_shapes_for_state)
1460
+
1461
+ if batch_axis_name:
1462
+ # Quantization is only enabled when batch_axis_name is set (the pmap path).
1463
+ quantized_dtype = quantized_dtype_for_second_moment_statistics_buffers()
1464
+
1465
+ if quantized_dtype == jnp.float32:
1466
+ return _pmap_compute_preconditioners(states, step, statistics,
1467
+ num_statistics_per_state,
1468
+ original_shapes, exponents,
1469
+ max_size, prev_preconditioners)
1470
+ else:
1471
+ return _pmap_quantized_compute_preconditioners(
1472
+ states, step, statistics, num_statistics_per_state, original_shapes,
1473
+ exponents, max_size, prev_preconditioners)
1474
+
1475
+ else:
1476
+ return _pjit_compute_preconditioners(states, step, statistics,
1477
+ num_statistics_per_state,
1478
+ original_shapes, exponents, max_size,
1479
+ prev_preconditioners)
1480
+
1481
+ def _transform_grad(grad, state, param, step):
1482
+ """Transform per-parameter gradients."""
1483
+ preconditioner = Preconditioner(param, block_size,
1484
+ best_effort_shape_interpretation)
1485
+ sgd_update = grad
1486
+ new_diagonal_statistics = state.diagonal_statistics.to_float()
1487
+ if graft_type == GraftingType.ADAGRAD:
1488
+ new_diagonal_statistics = state.diagonal_statistics.to_float(
1489
+ ) + jnp.square(grad)
1490
+ adagrad_update = grad / (
1491
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1492
+ grafting_update = adagrad_update
1493
+ elif (graft_type == GraftingType.RMSPROP or
1494
+ graft_type == GraftingType.RMSPROP_NORMALIZED):
1495
+
1496
+ scaled_grad = grad
1497
+ if graft_type == GraftingType.RMSPROP_NORMALIZED:
1498
+ scaled_grad = grad / jnp.linalg.norm(grad)
1499
+
1500
+ w1 = beta2
1501
+ w2 = beta2 if beta2 == 1.0 else (1.0 - beta2)
1502
+
1503
+ new_diagonal_statistics = (
1504
+ w1 * state.diagonal_statistics.to_float() +
1505
+ w2 * jnp.square(scaled_grad))
1506
+ rmsprop_update = scaled_grad / (
1507
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
1508
+
1509
+ if clip_by_scaled_gradient_norm:
1510
+ scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
1511
+ jnp.sqrt(float(rmsprop_update.size)))
1512
+ clipping_denom = jnp.maximum(
1513
+ 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
1514
+ rmsprop_update /= clipping_denom
1515
+
1516
+ grafting_update = rmsprop_update
1517
+ else:
1518
+ grafting_update = sgd_update
1519
+
1520
+ precond_grad = grad
1521
+ if not _skip_preconditioning(param):
1522
+ precond_grad = preconditioner.preconditioned_grad(
1523
+ precond_grad,
1524
+ _maybe_dequantize_preconditioners(state.preconditioners))
1525
+ else:
1526
+ precond_grad = grafting_update
1527
+
1528
+ grafting_update_norm = jnp.linalg.norm(grafting_update)
1529
+ precond_grad_norm = jnp.linalg.norm(precond_grad)
1530
+
1531
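+ # Grafting: rescale the preconditioned (Shampoo) direction to the norm of the grafted
+ # update, so the grafted method sets the step size while Shampoo sets the direction;
+ # the 1e-16 term guards against division by zero.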
+ multiplier = (grafting_update_norm / (precond_grad_norm + 1e-16))
1532
+ shampoo_update = precond_grad * multiplier
1533
+
1534
+ shampoo_update_with_wd = shampoo_update
1535
+ grafting_update_with_wd = grafting_update
1536
+ if weight_decay != 0:
1537
+ shampoo_update_with_wd = shampoo_update + weight_decay * param
1538
+ grafting_update_with_wd = grafting_update + weight_decay * param
1539
+
1540
+ w = (1.0 - beta1) if moving_average_for_momentum else 1.0
1541
+ shampoo_update_with_wd_momentum = (
1542
+ state.momentum.to_float() * beta1 + w * shampoo_update_with_wd)
1543
+ grafting_update_with_wd_momentum = (
1544
+ state.diagonal_momentum.to_float() * beta1 +
1545
+ w * grafting_update_with_wd)
1546
+
1547
+ run_shampoo = (step >= start_preconditioning_step).astype(
1548
+ grafting_update_with_wd_momentum.dtype)
1549
+
1550
+ momentum_update = (
1551
+ run_shampoo * shampoo_update_with_wd_momentum +
1552
+ (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
1553
+
1554
+ wd_update = (
1555
+ run_shampoo * shampoo_update_with_wd +
1556
+ (1.0 - run_shampoo) * grafting_update_with_wd)
1557
+
1558
+ if nesterov:
1559
+ momentum_update = w * wd_update + beta1 * momentum_update
1560
+
1561
+ lr = learning_rate
1562
+ if callable(learning_rate):
1563
+ lr = learning_rate(step)
1564
+ transformed_update = -1.0 * lr * momentum_update
1565
+
1566
+ param_stats = ParameterStats(
1567
+ _quantize_diagonal_statistics(new_diagonal_statistics),
1568
+ state.statistics, state.preconditioners,
1569
+ _quantize_momentum(grafting_update_with_wd_momentum),
1570
+ _quantize_momentum(shampoo_update_with_wd_momentum))
1571
+ return transformed_update, param_stats
1572
+
1573
+ def update_fn(grads, state, params):
1574
+ """Transform the input gradient and update all statistics.
1575
+
1576
+ Args:
1577
+ grads: the gradient tensors for the parameters.
1578
+ state: a named tuple containing the state of the optimizer
1579
+ params: the parameters that should be updated.
1580
+
1581
+ Returns:
1582
+ A tuple containing the new parameters and the new optimizer state.
1583
+ """
1584
+ params_flat, treedef = jax.tree_flatten(params)
1585
+ stats_flat = treedef.flatten_up_to(state.stats)
1586
+ grads_flat = treedef.flatten_up_to(grads)
1587
+
1588
+ new_stats_flat = jax.tree_multimap(
1589
+ lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat,
1590
+ stats_flat, params_flat)
1591
+ new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat,
1592
+ state.count)
1593
+
1594
+ outputs = jax.tree_multimap(
1595
+ lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat,
1596
+ new_stats_flat, params_flat)
1597
+ updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
1598
+
1599
+ updates = jax.tree_unflatten(treedef, updates_flat)
1600
+ new_stats = jax.tree_unflatten(treedef, new_stats_flat)
1601
+
1602
+ new_state = ShampooState(
1603
+ count=state.count+1, stats=new_stats)
1604
+ return updates, new_state
1605
+
1606
+ if shard_optimizer_states:
1607
+ return optax.GradientTransformation(sharded_init_fn, sharded_update_fn)
1608
+ else:
1609
+ return optax.GradientTransformation(init_fn, update_fn)
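
The `_skip` / `_select_preconditioner` pair above is the optimizer's fallback mechanism: whenever the inverse-pth-root of a statistic fails (the reported error is NaN or exceeds `inverse_failure_threshold`), the previously computed preconditioner is kept. A minimal, self-contained sketch of that pattern follows; the threshold value is illustrative and not taken from this file's configuration:

    import jax.numpy as jnp
    from jax import lax

    INVERSE_FAILURE_THRESHOLD = 0.1  # illustrative stand-in for inverse_failure_threshold

    def select_preconditioner(error, new_p, old_p):
        # Keep the previous preconditioner when the inverse-pth-root error is NaN or too large.
        skip = jnp.logical_or(jnp.isnan(error), error >= INVERSE_FAILURE_THRESHOLD)
        return lax.cond(skip, lambda _: old_p, lambda _: new_p, operand=None)

    old = jnp.eye(2)
    new = 2.0 * jnp.eye(2)
    print(select_preconditioner(jnp.array(0.01), new, old))     # low error: the new preconditioner is accepted
    print(select_preconditioner(jnp.array(jnp.nan), new, old))  # NaN error: falls back to the old one

Because the selection runs under `lax.cond`, both branches stay traceable inside `pmap`/`jit`, which is why the code above prefers it over a Python-level `if`.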
merges.txt ADDED
The diff for this file is too large to render. See raw diff
run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replacing `Enum` by their values (for JSON serialization support). It obfuscates
117
+ the token values by replacing them with a placeholder.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+
130
+ @dataclass
131
+ class ModelArguments:
132
+ """
133
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
134
+ """
135
+
136
+ model_name_or_path: Optional[str] = field(
137
+ default=None,
138
+ metadata={
139
+ "help": "The model checkpoint for weights initialization."
140
+ "Don't set if you want to train a model from scratch."
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
164
+ },
165
+ )
166
+
167
+
168
+ @dataclass
169
+ class DataTrainingArguments:
170
+ """
171
+ Arguments pertaining to what data we are going to input our model for training and eval.
172
+ """
173
+
174
+ dataset_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_config_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
+ )
180
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
181
+ validation_file: Optional[str] = field(
182
+ default=None,
183
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
184
+ )
185
+ train_ref_file: Optional[str] = field(
186
+ default=None,
187
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
188
+ )
189
+ validation_ref_file: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
192
+ )
193
+ overwrite_cache: bool = field(
194
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
195
+ )
196
+ validation_split_percentage: Optional[int] = field(
197
+ default=5,
198
+ metadata={
199
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
200
+ },
201
+ )
202
+ max_seq_length: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
206
+ "than this will be truncated. Default to the max input length of the model."
207
+ },
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ mlm_probability: float = field(
214
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
215
+ )
216
+ pad_to_max_length: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to pad all samples to `max_seq_length`. "
220
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
221
+ },
222
+ )
223
+ line_by_line: bool = field(
224
+ default=False,
225
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
226
+ )
227
+
228
+ def __post_init__(self):
229
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
230
+ raise ValueError("Need either a dataset name or a training/validation file.")
231
+ else:
232
+ if self.train_file is not None:
233
+ extension = self.train_file.split(".")[-1]
234
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
235
+ if self.validation_file is not None:
236
+ extension = self.validation_file.split(".")[-1]
237
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
238
+
239
+
240
+ @flax.struct.dataclass
241
+ class FlaxDataCollatorForLanguageModeling:
242
+ """
243
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
244
+ are not all of the same length.
245
+
246
+ Args:
247
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
248
+ The tokenizer used for encoding the data.
249
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
250
+ The probability with which to (randomly) mask tokens in the input.
251
+
252
+ .. note::
253
+
254
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
255
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
256
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
257
+ argument :obj:`return_special_tokens_mask=True`.
258
+ """
259
+
260
+ tokenizer: PreTrainedTokenizerBase
261
+ mlm_probability: float = 0.15
262
+
263
+ def __post_init__(self):
264
+ if self.tokenizer.mask_token is None:
265
+ raise ValueError(
266
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
267
+ "You should pass `mlm=False` to train on causal language modeling instead."
268
+ )
269
+
270
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
271
+ # Handle dict or lists with proper padding and conversion to tensor.
272
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
273
+
274
+ # If special token mask has been preprocessed, pop it from the dict.
275
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
276
+
277
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
278
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
279
+ )
280
+ return batch
281
+
282
+ def mask_tokens(
283
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
284
+ ) -> Tuple[np.ndarray, np.ndarray]:
285
+ """
286
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
287
+ """
288
+ labels = inputs.copy()
289
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
290
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
291
+ special_tokens_mask = special_tokens_mask.astype("bool")
292
+
293
+ probability_matrix[special_tokens_mask] = 0.0
294
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
295
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
296
+
297
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
298
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
299
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
300
+
301
+ # 10% of the time, we replace masked input tokens with random word
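+ # (p=0.5 of the remaining ~20% of masked positions yields the 10% random-word share.)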
302
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
303
+ indices_random &= masked_indices & ~indices_replaced
304
+
305
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
306
+ inputs[indices_random] = random_words[indices_random]
307
+
308
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
309
+ return inputs, labels
310
+
311
+
312
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
313
+ num_samples = len(samples_idx)
314
+ samples_to_remove = num_samples % batch_size
315
+
316
+ if samples_to_remove != 0:
317
+ samples_idx = samples_idx[:-samples_to_remove]
318
+ sections_split = num_samples // batch_size
319
+ batch_idx = np.split(samples_idx, sections_split)
320
+ return batch_idx
321
+
322
+
323
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
324
+ summary_writer.scalar("train_time", train_time, step)
325
+
326
+ train_metrics = get_metrics(train_metrics)
327
+ for key, vals in train_metrics.items():
328
+ tag = f"train_{key}"
329
+ for i, val in enumerate(vals):
330
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
331
+
332
+
333
+ def write_eval_metric(summary_writer, eval_metrics, step):
334
+ for metric_name, value in eval_metrics.items():
335
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
336
+
337
+
338
+ def main():
339
+ # See all possible arguments in src/transformers/training_args.py
340
+ # or by passing the --help flag to this script.
341
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
342
+
343
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
344
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
345
+ # If we pass only one argument to the script and it's the path to a json file,
346
+ # let's parse it to get our arguments.
347
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
348
+ else:
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ if (
352
+ os.path.exists(training_args.output_dir)
353
+ and os.listdir(training_args.output_dir)
354
+ and training_args.do_train
355
+ and not training_args.overwrite_output_dir
356
+ ):
357
+ raise ValueError(
358
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
359
+ "Use --overwrite_output_dir to overcome."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ level=logging.INFO,
366
+ datefmt="[%X]",
367
+ )
368
+
369
+ # Log on each process the small summary:
370
+ logger = logging.getLogger(__name__)
371
+
372
+ # Set the verbosity to info of the Transformers logger (on main process only):
373
+ logger.info(f"Training/evaluation parameters {training_args}")
374
+
375
+ # Set seed before initializing model.
376
+ set_seed(training_args.seed)
377
+
378
+ # Handle the repository creation
379
+ if training_args.push_to_hub:
380
+ if training_args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(
382
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
383
+ )
384
+ else:
385
+ repo_name = training_args.hub_model_id
386
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
387
+
388
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
389
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
390
+ # (the dataset will be downloaded automatically from the datasets Hub).
391
+ #
392
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
393
+ # 'text' is found. You can easily tweak this behavior (see below).
394
+ #
395
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
396
+ # download the dataset.
397
+ if data_args.dataset_name is not None:
398
+ # Downloading and loading a dataset from the hub.
399
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
400
+
401
+ if "validation" not in datasets.keys():
402
+ datasets["validation"] = load_dataset(
403
+ data_args.dataset_name,
404
+ data_args.dataset_config_name,
405
+ split=f"train[:{data_args.validation_split_percentage}%]",
406
+ cache_dir=model_args.cache_dir,
407
+ )
408
+ datasets["train"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[{data_args.validation_split_percentage}%:]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ if data_args.validation_file is not None:
419
+ data_files["validation"] = data_args.validation_file
420
+ extension = data_args.train_file.split(".")[-1]
421
+ if extension == "txt":
422
+ extension = "text"
423
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
424
+
425
+ if "validation" not in datasets.keys():
426
+ datasets["validation"] = load_dataset(
427
+ extension,
428
+ data_files=data_files,
429
+ split=f"train[:{data_args.validation_split_percentage}%]",
430
+ cache_dir=model_args.cache_dir,
431
+ )
432
+ datasets["train"] = load_dataset(
433
+ extension,
434
+ data_files=data_files,
435
+ split=f"train[{data_args.validation_split_percentage}%:]",
436
+ cache_dir=model_args.cache_dir,
437
+ )
438
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
439
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
440
+
441
+ # Load pretrained model and tokenizer
442
+
443
+ # Distributed training:
444
+ # The .from_pretrained methods guarantee that only one local process can concurrently
445
+ # download model & vocab.
446
+ if model_args.config_name:
447
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
448
+ elif model_args.model_name_or_path:
449
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
450
+ else:
451
+ config = CONFIG_MAPPING[model_args.model_type]()
452
+ logger.warning("You are instantiating a new config instance from scratch.")
453
+
454
+ if model_args.tokenizer_name:
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
457
+ )
458
+ elif model_args.model_name_or_path:
459
+ tokenizer = AutoTokenizer.from_pretrained(
460
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
461
+ )
462
+ else:
463
+ raise ValueError(
464
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
465
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
466
+ )
467
+
468
+ # Preprocessing the datasets.
469
+ # First we tokenize all the texts.
470
+ if training_args.do_train:
471
+ column_names = datasets["train"].column_names
472
+ else:
473
+ column_names = datasets["validation"].column_names
474
+ text_column_name = "text" if "text" in column_names else column_names[0]
475
+
476
+ # Fall back to the tokenizer's maximum length when --max_seq_length is not given.
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) if data_args.max_seq_length else tokenizer.model_max_length
477
+
478
+ if data_args.line_by_line:
479
+ # When using line_by_line, we just tokenize each nonempty line.
480
+ padding = "max_length" if data_args.pad_to_max_length else False
481
+
482
+ def tokenize_function(examples):
483
+ # Remove empty lines
484
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
485
+ return tokenizer(
486
+ examples,
487
+ return_special_tokens_mask=True,
488
+ padding=padding,
489
+ truncation=True,
490
+ max_length=max_seq_length,
491
+ )
492
+
493
+ tokenized_datasets = datasets.map(
494
+ tokenize_function,
495
+ input_columns=[text_column_name],
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ remove_columns=column_names,
499
+ load_from_cache_file=not data_args.overwrite_cache,
500
+ )
501
+
502
+ else:
503
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
504
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
505
+ # efficient when it receives the `special_tokens_mask`.
506
+ def tokenize_function(examples):
507
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
508
+
509
+ tokenized_datasets = datasets.map(
510
+ tokenize_function,
511
+ batched=True,
512
+ num_proc=data_args.preprocessing_num_workers,
513
+ remove_columns=column_names,
514
+ load_from_cache_file=not data_args.overwrite_cache,
515
+ )
516
+
517
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
518
+ # max_seq_length.
519
+ def group_texts(examples):
520
+ # Concatenate all texts.
521
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
522
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
523
+ # We drop the small remainder; we could add padding instead of dropping if the model supported it. You can
524
+ # customize this part to your needs.
525
+ if total_length >= max_seq_length:
526
+ total_length = (total_length // max_seq_length) * max_seq_length
527
+ # Split by chunks of max_len.
528
+ result = {
529
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
530
+ for k, t in concatenated_examples.items()
531
+ }
532
+ return result
533
+
534
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
535
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
536
+ # might be slower to preprocess.
537
+ #
538
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
539
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
540
+ tokenized_datasets = tokenized_datasets.map(
541
+ group_texts,
542
+ batched=True,
543
+ num_proc=data_args.preprocessing_num_workers,
544
+ load_from_cache_file=not data_args.overwrite_cache,
545
+ )
546
+
547
+ # Enable tensorboard only on the master node
548
+ has_tensorboard = is_tensorboard_available()
549
+ if has_tensorboard and jax.process_index() == 0:
550
+ try:
551
+ # Enable Weights & Biases
552
+ import wandb
553
+ wandb.init(
554
+ entity='versae',
555
+ project='roberta-base-ncc',
556
+ sync_tensorboard=False,
557
+ )
558
+ wandb.config.update(training_args)
559
+ wandb.config.update(model_args)
560
+ wandb.config.update(data_args)
561
+
562
+ from flax.metrics.tensorboard import SummaryWriter
563
+
564
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
565
+ except ImportError as ie:
566
+ has_tensorboard = False
567
+ logger.warning(
568
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
569
+ )
570
+ else:
571
+ logger.warning(
572
+ "Unable to display metrics through TensorBoard because the package is not installed: "
573
+ "Please run pip install tensorboard to enable."
574
+ )
575
+
576
+ # Data collator
577
+ # This one will take care of randomly masking the tokens.
578
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
579
+
580
+ # Initialize our training
581
+ rng = jax.random.PRNGKey(training_args.seed)
582
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
583
+
584
+ if model_args.model_name_or_path:
585
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
586
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
587
+ )
588
+ else:
589
+ model = FlaxAutoModelForMaskedLM.from_config(
590
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
591
+ )
592
+
593
+ # Store some constant
594
+ num_epochs = int(training_args.num_train_epochs)
595
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
596
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
+
598
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
599
+
600
+ # Create learning rate schedule
601
+ warmup_fn = optax.linear_schedule(
602
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
603
+ )
604
+ decay_fn = optax.linear_schedule(
605
+ init_value=training_args.learning_rate,
606
+ end_value=0,
607
+ transition_steps=num_train_steps - training_args.warmup_steps,
608
+ )
609
+ linear_decay_lr_schedule_fn = optax.join_schedules(
610
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
611
+ )
612
+
613
+ # We use Optax's "masking" functionality to not apply weight decay
614
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
615
+ # mask boolean with the same structure as the parameters.
616
+ # The mask is True for parameters that should be decayed.
617
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
618
+ # For other models, one should correct the layer norm parameter naming
619
+ # accordingly.
620
+ def decay_mask_fn(params):
621
+ flat_params = traverse_util.flatten_dict(params)
622
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
623
+ return traverse_util.unflatten_dict(flat_mask)
624
+
625
+ # create the optimizer (Adafactor or AdamW)
626
+ if training_args.adafactor:
627
+ # We use the default parameters here to initialize adafactor,
628
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
629
+ optimizer = optax.adafactor(
630
+ learning_rate=linear_decay_lr_schedule_fn,
631
+ )
632
+ else:
633
+ optimizer = optax.adamw(
634
+ learning_rate=linear_decay_lr_schedule_fn,
635
+ b1=training_args.adam_beta1,
636
+ b2=training_args.adam_beta2,
637
+ eps=training_args.adam_epsilon,
638
+ weight_decay=training_args.weight_decay,
639
+ mask=decay_mask_fn,
640
+ )
641
+
642
+ # Setup train state
643
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
644
+
645
+ # Define gradient update step fn
646
+ def train_step(state, batch, dropout_rng):
647
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
648
+
649
+ def loss_fn(params):
650
+ labels = batch.pop("labels")
651
+
652
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
653
+
654
+ # compute loss only on masked tokens (all other positions carry the -100 label)
655
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
656
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
657
+
658
+ # take average
659
+ loss = loss.sum() / label_mask.sum()
660
+
661
+ return loss
662
+
663
+ grad_fn = jax.value_and_grad(loss_fn)
664
+ loss, grad = grad_fn(state.params)
665
+ grad = jax.lax.pmean(grad, "batch")
666
+ new_state = state.apply_gradients(grads=grad)
667
+
668
+ metrics = jax.lax.pmean(
669
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
670
+ )
671
+
672
+ return new_state, metrics, new_dropout_rng
673
+
674
+ # Create parallel version of the train step
675
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
676
+
677
+ # Define eval fn
678
+ def eval_step(params, batch):
679
+ labels = batch.pop("labels")
680
+
681
+ logits = model(**batch, params=params, train=False)[0]
682
+
683
+ # compute loss only on masked tokens (all other positions carry the -100 label)
684
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
685
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
686
+
687
+ # compute accuracy
688
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
689
+
690
+ # summarize metrics
691
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
692
+ metrics = jax.lax.psum(metrics, axis_name="batch")
693
+
694
+ return metrics
695
+
696
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
697
+
698
+ # Replicate the train state on each device
699
+ state = jax_utils.replicate(state)
700
+
701
+ train_time = 0
702
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
+ for epoch in epochs:
704
+ # ======================== Training ================================
705
+ train_start = time.time()
706
+ train_metrics = []
707
+
708
+ # Create sampling rng
709
+ rng, input_rng = jax.random.split(rng)
710
+
711
+ # Generate an epoch by shuffling sampling indices from the train dataset
712
+ num_train_samples = len(tokenized_datasets["train"])
713
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
714
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
715
+
716
+ # Gather the indexes for creating the batch and do a training step
717
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
718
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
719
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
+
721
+ # Model forward
722
+ model_inputs = shard(model_inputs.data)
723
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
724
+ train_metrics.append(train_metric)
725
+
726
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
727
+
728
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
729
+ # Save metrics
730
+ train_metric = jax_utils.unreplicate(train_metric)
731
+ train_time += time.time() - train_start
732
+ if has_tensorboard and jax.process_index() == 0:
733
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
734
+
735
+ epochs.write(
736
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
737
+ )
738
+
739
+ train_metrics = []
740
+
741
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
742
+ # ======================== Evaluating ==============================
743
+ num_eval_samples = len(tokenized_datasets["validation"])
744
+ eval_samples_idx = jnp.arange(num_eval_samples)
745
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
746
+
747
+ eval_metrics = []
748
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
749
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ metrics = p_eval_step(state.params, model_inputs)
755
+ eval_metrics.append(metrics)
756
+
757
+ # normalize eval metrics
758
+ eval_metrics = get_metrics(eval_metrics)
759
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
760
+ eval_normalizer = eval_metrics.pop("normalizer")
761
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
762
+
763
+ # Update progress bar
764
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
765
+
766
+ # Save metrics
767
+ if has_tensorboard and jax.process_index() == 0:
768
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
769
+
770
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
771
+ # save checkpoint after each epoch and push checkpoint to the hub
772
+ if jax.process_index() == 0:
773
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
774
+ model.save_pretrained(training_args.output_dir, params=params)
775
+ tokenizer.save_pretrained(training_args.output_dir)
776
+ if training_args.push_to_hub:
777
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
778
+
779
+ # Eval after training
780
+ if training_args.do_eval:
781
+ num_eval_samples = len(tokenized_datasets["validation"])
782
+ eval_samples_idx = jnp.arange(num_eval_samples)
783
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
784
+
785
+ eval_metrics = []
786
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
787
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
788
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
789
+
790
+ # Model forward
791
+ model_inputs = shard(model_inputs.data)
792
+ metrics = p_eval_step(state.params, model_inputs)
793
+ eval_metrics.append(metrics)
794
+
795
+ # normalize eval metrics
796
+ eval_metrics = get_metrics(eval_metrics)
797
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
798
+ eval_normalizer = eval_metrics.pop("normalizer")
799
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
800
+
801
+ try:
802
+ perplexity = math.exp(eval_metrics["loss"])
803
+ except OverflowError:
804
+ perplexity = float("inf")
805
+ eval_metrics["perplexity"] = perplexity
806
+
807
+ if jax.process_index() == 0:
808
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
809
+ path = os.path.join(training_args.output_dir, "eval_results.json")
810
+ with open(path, "w") as f:
811
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
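
The commit is titled "Testing shampoo", yet the script above still constructs `optax.adafactor` or `optax.adamw`. Below is a hedged sketch of how the Shampoo gradient transformation added earlier in this commit might be wired in instead, at the point in `main()` where the optimizer is created. The module name `distributed_shampoo`, the factory function of the same name, the `GraftingType` enum, and every keyword argument below are assumptions inferred from the optimizer code in this commit, not a verified API; `linear_decay_lr_schedule_fn`, `training_args`, `model`, and `train_state` are the names already defined in the script above.

    # Hypothetical replacement for the optax.adamw(...) branch above (all names assumed).
    from distributed_shampoo import GraftingType, distributed_shampoo

    optimizer = distributed_shampoo(
        learning_rate=linear_decay_lr_schedule_fn,  # the Shampoo code above accepts a callable schedule
        block_size=1024,                            # illustrative blocking of large dimensions
        beta1=training_args.adam_beta1,
        beta2=training_args.adam_beta2,
        weight_decay=training_args.weight_decay,
        start_preconditioning_step=1000,
        preconditioning_compute_steps=10,
        graft_type=GraftingType.RMSPROP_NORMALIZED,
        batch_axis_name="batch",                    # same axis name as the pmap in train_step
    )

    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

With `batch_axis_name="batch"`, preconditioner computation takes the pmap path and therefore has to run inside the pmapped `train_step`; that already holds here, since `state.apply_gradients` (and hence the optimizer's `update_fn`) executes inside it.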
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "trim_offsets": true, "special_tokens_map_file": null, "name_or_path": "NbAiLab/nb-roberta-base", "tokenizer_class": "RobertaTokenizer"}
train.128.sh ADDED
@@ -0,0 +1,25 @@
1
+ python run_mlm_flax.py \
2
+ --output_dir="./" \
3
+ --model_type="roberta" \
4
+ --config_name="roberta-base" \
5
+ --tokenizer_name="NbAiLab/nb-roberta-base" \
6
+ --dataset_name="NbAiLab/NCC" \
7
+ --max_seq_length="128" \
8
+ --weight_decay="0.01" \
9
+ --per_device_train_batch_size="232" \
10
+ --per_device_eval_batch_size="232" \
11
+ --pad_to_max_length \
12
+ --learning_rate="6e-4" \
13
+ --warmup_steps="10000" \
14
+ --overwrite_output_dir \
15
+ --num_train_epochs="3" \
16
+ --adam_beta1="0.9" \
17
+ --adam_beta2="0.98" \
18
+ --adam_epsilon="1e-6" \
19
+ --logging_steps="1000" \
20
+ --save_steps="1000" \
21
+ --eval_steps="1000" \
22
+ --do_train \
23
+ --do_eval \
24
+ --dtype="bfloat16" \
25
+ --push_to_hub
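
For a rough sense of the scale implied by these flags, assuming a single 8-device TPU host (an assumption; the script multiplies by `jax.device_count()` at runtime):

    per_device_train_batch_size = 232
    device_count = 8                                                  # assumed, e.g. a TPU v3-8
    global_batch_size = per_device_train_batch_size * device_count   # 1856 sequences per update
    tokens_per_step = global_batch_size * 128                        # max_seq_length=128 -> 237,568 tokens
    print(global_batch_size, tokens_per_step)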
vocab.json ADDED
The diff for this file is too large to render. See raw diff