pere committed
Commit bf8b191
1 Parent(s): aed0a88
gpt-neo-1.3B/config.json ADDED
@@ -0,0 +1,72 @@
+ {
+   "activation_function": "gelu_new",
+   "architectures": [
+     "GPTNeoForCausalLM"
+   ],
+   "attention_dropout": 0,
+   "attention_layers": [
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local",
+     "global",
+     "local"
+   ],
+   "attention_types": [
+     [
+       [
+         "global",
+         "local"
+       ],
+       12
+     ]
+   ],
+   "bos_token_id": 50256,
+   "embed_dropout": 0,
+   "eos_token_id": 50256,
+   "gradient_checkpointing": false,
+   "hidden_size": 2048,
+   "initializer_range": 0.02,
+   "intermediate_size": null,
+   "layer_norm_epsilon": 1e-05,
+   "max_position_embeddings": 2048,
+   "model_type": "gpt_neo",
+   "num_heads": 16,
+   "num_layers": 24,
+   "resid_dropout": 0,
+   "summary_activation": null,
+   "summary_first_dropout": 0.1,
+   "summary_proj_to_labels": true,
+   "summary_type": "cls_index",
+   "summary_use_proj": true,
+   "task_specific_params": {
+     "text-generation": {
+       "do_sample": true,
+       "max_length": 50,
+       "temperature": 0.9
+     }
+   },
+   "tokenizer_class": "GPT2Tokenizer",
+   "transformers_version": "4.9.0.dev0",
+   "use_cache": true,
+   "vocab_size": 50264,
+   "window_size": 256
+ }
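
For reference, this configuration mirrors EleutherAI/gpt-neo-1.3B except that `vocab_size` is enlarged from GPT-2's 50257 to 50264; `run_clm_mp.py` and `setup_devices.py` below resize the vocabulary so the embedding matrix divides evenly across model-parallel partitions. A minimal loading sketch (assuming a standard `transformers` install and the local `gpt-neo-1.3B/` directory from this commit):

    from transformers import GPTNeoConfig

    # Reads gpt-neo-1.3B/config.json from the local checkpoint directory.
    config = GPTNeoConfig.from_pretrained("gpt-neo-1.3B")
    assert config.vocab_size == 50264  # padded vocab, see setup_devices.py below
    print(config.num_layers, config.hidden_size, config.num_heads)  # 24 2048 16
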
gpt-neo-1.3B/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c694052f126d1176eaa04c0dffd77b0c49a301f6677b83a9d39d4a61f3c59ccc
+ size 5262371934
partitions.py ADDED
@@ -0,0 +1,85 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The Google Research Authors and The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Utilities for constructing PyTrees of PartitionSpecs."""
+
+ # utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
+
+ import re
+
+ from flax.core.frozen_dict import freeze
+ from flax.traverse_util import flatten_dict, unflatten_dict
+ from jax.experimental import PartitionSpec as P
+
+
+ # Sentinels
+ _unmatched = object()
+
+ # For specifying empty leaf dict `{}`
+ empty_dict = object()
+
+
+ def _match(qs, ks):
+     """Return True if regexes in qs match any window of strings in tuple ks."""
+     # compile regexes and force complete match
+     qts = tuple(map(lambda x: re.compile(x + "$"), qs))
+     for i in range(len(ks) - len(qs) + 1):
+         matches = [x.match(y) for x, y in zip(qts, ks[i:])]
+         if matches and all(matches):
+             return True
+     return False
+
+
+ def _replacement_rules(rules):
+     def replace(key, val):
+         for rule, replacement in rules:
+             if _match(rule, key):
+                 return replacement
+         return val
+
+     return replace
+
+
+ # PartitionSpec for GPTNeo
+ # replicate the hidden dim and shard feed-forward and head dim
+ def _get_partition_rules():
+     return [
+         # embeddings
+         (("transformer", "wpe", "embedding"), P("mp", None)),
+         (("transformer", "wte", "embedding"), P("mp", None)),
+         # attention
+         (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
+         (("attention", "out_proj", "kernel"), P("mp", None)),
+         (("attention", "out_proj", "bias"), None),
+         # mlp
+         (("mlp", "c_fc", "kernel"), P(None, "mp")),
+         (("mlp", "c_fc", "bias"), P("mp")),
+         (("mlp", "c_proj", "kernel"), P("mp", None)),
+         (("mlp", "c_proj", "bias"), None),
+         # layer norms
+         ((r"ln_\d+", "bias"), None),
+         ((r"\d+", r"ln_\d+", "scale"), None),
+         (("ln_f", "bias"), None),
+         (("ln_f", "scale"), None),
+     ]
+
+
+ def set_partitions(in_dict):
+     rules = _get_partition_rules()
+     replace = _replacement_rules(rules)
+     initd = {k: _unmatched for k in flatten_dict(in_dict)}
+     result = {k: replace(k, v) for k, v in initd.items()}
+     assert _unmatched not in result.values(), "Incomplete partition spec."
+     return freeze(unflatten_dict(result))
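
For reference, a minimal sketch of how `set_partitions` is used (the nested dict below is a hypothetical, GPT-Neo-shaped fragment; in `run_clm_mp.py` the real tree comes from `unfreeze(model.params)`). Every flattened parameter path must match one of the rules above, otherwise the assertion fires:

    import numpy as np
    from partitions import set_partitions

    # Hypothetical toy parameter tree; shapes are irrelevant here, only the key paths matter.
    toy_params = {
        "transformer": {
            "wte": {"embedding": np.zeros((8, 4))},
            "wpe": {"embedding": np.zeros((8, 4))},
            "h": {
                "0": {
                    "attn": {"attention": {"q_proj": {"kernel": np.zeros((4, 4))}}},
                    "ln_1": {"scale": np.zeros(4), "bias": np.zeros(4)},
                },
            },
        },
    }

    spec = set_partitions(toy_params)
    # e.g. spec["transformer"]["wte"]["embedding"] == PartitionSpec("mp", None)
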
run_clm_mp.py ADDED
@@ -0,0 +1,648 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Pre-training/Fine-tuning the GPTNeo model for causal language modeling on a text file or a dataset using model parallelism.
+ """
+
+ import logging
+ import math
+ import os
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import datasets
+ import numpy as np
+ from datasets import Dataset, load_dataset
+ from tqdm import tqdm
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from flax.core.frozen_dict import freeze, unfreeze
+ from flax.training.common_utils import onehot, stack_forest
+ from jax.experimental.maps import mesh
+ from jax.experimental.pjit import pjit
+ from partitions import set_partitions
+ from transformers import (
+     CONFIG_MAPPING,
+     FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
+     AutoConfig,
+     AutoTokenizer,
+     FlaxAutoModelForCausalLM,
+     HfArgumentParser,
+     TrainingArguments,
+     is_tensorboard_available,
+ )
+ from transformers.testing_utils import CaptureLogger
+
+
+ logger = logging.getLogger(__name__)
+
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     model_name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization."
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     model_type: Optional[str] = field(
+         default=None,
+         metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+             "value if set."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     validation_split_percentage: Optional[int] = field(
+         default=5,
+         metadata={
+             "help": "The percentage of the train set used as validation set in case there's no validation split"
+         },
+     )
+     block_size: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "Optional input sequence length after tokenization. "
+             "The training dataset will be truncated in block of this size for training. "
+             "Default to the model max input length for single sentence inputs (take into account special tokens)."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+
+     def __post_init__(self):
+         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+
+
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
+     """
+     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
+     Shuffle batches if `shuffle` is `True`.
+     """
+     steps_per_epoch = len(dataset) // batch_size
+
+     if shuffle:
+         batch_idx = jax.random.permutation(rng, len(dataset))
+     else:
+         batch_idx = jnp.arange(len(dataset))
+
+     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+
+     for idx in batch_idx:
+         batch = dataset[idx]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+         yield batch
+
+
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = stack_forest(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+
+ def write_eval_metric(summary_writer, eval_metrics, step):
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear_decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty."
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         datasets.utils.logging.set_verbosity_warning()
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         datasets.utils.logging.set_verbosity_error()
+         transformers.utils.logging.set_verbosity_error()
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+     # (the dataset will be downloaded automatically from the datasets Hub).
+     #
+     # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+     # 'text' is found. You can easily tweak this behavior (see below).
+     if data_args.dataset_name is not None:
+         # Downloading and loading a dataset from the hub.
+         dataset = load_dataset(
+             data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
+         )
+
+         if "validation" not in dataset.keys():
+             dataset["validation"] = load_dataset(
+                 data_args.dataset_name,
+                 data_args.dataset_config_name,
+                 split=f"train[:{data_args.validation_split_percentage}%]",
+                 cache_dir=model_args.cache_dir,
+             )
+             dataset["train"] = load_dataset(
+                 data_args.dataset_name,
+                 data_args.dataset_config_name,
+                 split=f"train[{data_args.validation_split_percentage}%:]",
+                 cache_dir=model_args.cache_dir,
+             )
+     else:
+         data_files = {}
+         if data_args.train_file is not None:
+             data_files["train"] = data_args.train_file
+         if data_args.validation_file is not None:
+             data_files["validation"] = data_args.validation_file
+         extension = data_args.train_file.split(".")[-1]
+         if extension == "txt":
+             extension = "text"
+         dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+     # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+     # Load pretrained config and tokenizer
+     if model_args.config_name:
+         config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
+     elif model_args.model_name_or_path:
+         config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+     else:
+         config = CONFIG_MAPPING[model_args.model_type]()
+         logger.warning("You are instantiating a new config instance from scratch.")
+
+     if model_args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     elif model_args.model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     else:
+         raise ValueError(
+             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+         )
+
+     if training_args.do_train:
+         column_names = dataset["train"].column_names
+     else:
+         column_names = dataset["validation"].column_names
+     text_column_name = "text" if "text" in column_names else column_names[0]
+
+     # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
+     tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
+
+     def tokenize_function(examples):
+         with CaptureLogger(tok_logger) as cl:
+             output = tokenizer(examples[text_column_name])
+         # clm input could be much much longer than block_size
+         if "Token indices sequence length is longer than the" in cl.out:
+             tok_logger.warning(
+                 "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
+             )
+         return output
+
+     tokenized_datasets = dataset.map(
+         tokenize_function,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+         remove_columns=column_names,
+         load_from_cache_file=not data_args.overwrite_cache,
+     )
+
+     if data_args.block_size is None:
+         block_size = tokenizer.model_max_length
+         if block_size > config.max_position_embeddings:
+             logger.warning(
+                 f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                 "Picking 1024 instead. You can change that default value by passing --block_size xxx."
+             )
+             block_size = 1024
+     else:
+         if data_args.block_size > tokenizer.model_max_length:
+             logger.warning(
+                 f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
+                 f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
+             )
+         block_size = min(data_args.block_size, tokenizer.model_max_length)
+
+     # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+     def group_texts(examples):
+         # Concatenate all texts.
+         concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+         total_length = len(concatenated_examples[list(examples.keys())[0]])
+         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+         # customize this part to your needs.
+         if total_length >= block_size:
+             total_length = (total_length // block_size) * block_size
+         # Split by chunks of max_len.
+         result = {
+             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+             for k, t in concatenated_examples.items()
+         }
+         result["labels"] = result["input_ids"].copy()
+         return result
+
+     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+     # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+     # to preprocess.
+     #
+     # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+     # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+     lm_datasets = tokenized_datasets.map(
+         group_texts,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+         load_from_cache_file=not data_args.overwrite_cache,
+     )
+
+     if training_args.do_train:
+         if "train" not in tokenized_datasets:
+             raise ValueError("--do_train requires a train dataset")
+         train_dataset = lm_datasets["train"]
+         if data_args.max_train_samples is not None:
+             train_dataset = train_dataset.select(range(data_args.max_train_samples))
+
+     if training_args.do_eval:
+         if "validation" not in tokenized_datasets:
+             raise ValueError("--do_eval requires a validation dataset")
+         eval_dataset = lm_datasets["validation"]
+         if data_args.max_eval_samples is not None:
+             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+
+     # Enable tensorboard only on the master node
+     has_tensorboard = is_tensorboard_available()
+     if has_tensorboard and jax.process_index() == 0:
+         try:
+             from flax.metrics.tensorboard import SummaryWriter
+
+             summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+         except ImportError as ie:
+             has_tensorboard = False
+             logger.warning(
+                 f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
+             )
+     else:
+         logger.warning(
+             "Unable to display metrics through TensorBoard because the package is not installed: "
+             "Please run pip install tensorboard to enable."
+         )
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     rng, dropout_rng = jax.random.split(rng)
+
+     # Store some constant
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+     steps_per_epoch = len(train_dataset) // train_batch_size
+     total_train_steps = steps_per_epoch * num_epochs
+
+     # TODO: weights should be initialized in pjitted fun, this won't work for REALLY large models
+     # TODO: when loading from pre-trained model we need to make sure the vocab is divisible by num_partitions
+     # GPT2's vocab is odd, we need to resize it for fine-tuning
+     model = FlaxAutoModelForCausalLM.from_pretrained(
+         model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+     )
+
+     # Create learning rate schedule
+     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+         len(train_dataset),
+         train_batch_size,
+         training_args.num_train_epochs,
+         training_args.warmup_steps,
+         training_args.learning_rate,
+     )
+
+     optimizer = optax.adamw(
+         learning_rate=linear_decay_lr_schedule_fn,
+         b1=training_args.adam_beta1,
+         b2=training_args.adam_beta2,
+         eps=training_args.adam_epsilon,
+         weight_decay=training_args.weight_decay,
+     )
+
+     def get_initial_state(params):
+         state = optimizer.init(params)
+         return tuple(state), params
+
+     # Get PartitionSpec for model params
+     param_spec = set_partitions(unfreeze(model.params))
+
+     # Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
+     params_shapes = jax.tree_map(lambda x: x.shape, model.params)
+     state_shapes = jax.eval_shape(get_initial_state, params_shapes)
+
+     # get PartitionSpec for opt_state, this is very specific to adamw
+     # TODO: optax returns different state for different optimizers, how can we handle this generically ?
+     # or maybe we don't since in our examples we just use adamw or adafactor
+     def get_opt_spec(x):
+         if isinstance(x, dict):
+             return param_spec
+         return None
+
+     opt_state_spec, param_spec = jax.tree_map(
+         get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
+     )
+
+     # pjit the get_initial_state function to shard params and init
+     # optimizer state in sharded way
+     p_get_initial_state = pjit(
+         get_initial_state,
+         in_axis_resources=None,
+         out_axis_resources=(opt_state_spec, param_spec),
+     )
+
+     # hack: move the initial params to CPU to free up device memory
+     # TODO: allow loading weights on CPU in pre-trained model
+     model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
+
+     # mesh definition
+     mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())
+
+     # actually initialize the opt_state
+     with mesh(mesh_devices, ("dp", "mp")):
+         opt_state, params = p_get_initial_state(freeze(model.params))
+
+     # cross-entropy with z loss
+     def loss_fn(logits, labels, z_loss=0):
+         shift_logits = logits[..., :-1, :]
+         shift_labels = labels[..., 1:]
+
+         shift_labels = onehot(shift_labels, shift_logits.shape[-1])
+
+         shift_logits = shift_logits - jax.lax.stop_gradient(shift_logits.max(axis=-1, keepdims=True))
+         log_z = jnp.log(jnp.sum(jnp.exp(shift_logits), axis=-1, keepdims=True))
+         log_softmax = shift_logits - log_z
+         loss = -jnp.sum(shift_labels * log_softmax, axis=-1)
+
+         loss += (1e-4 * jnp.square(log_z.squeeze(-1))) * z_loss
+
+         return loss.mean()
+
+     # Define gradient update step fn
+     # TODO: try to use TrainState instead of passing params and opt_state individually
+     def train_step(params, opt_state, dropout_rng, batch, step):
+         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+         def compute_loss(params):
+             labels = batch.pop("labels")
+             logits = model(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = loss_fn(logits, labels, z_loss=1.0)
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+         loss, grads = grad_fn(params)
+
+         updates, new_opt_state = optimizer.update(grads, opt_state, params)
+         new_params = optax.apply_updates(params, updates)
+
+         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(step)}
+         return new_params, tuple(new_opt_state), new_dropout_rng, metrics, step + 1
+
+     # Define eval fn
+     def eval_step(input_ids, labels, params):
+         logits = model(input_ids=input_ids, params=params, train=False)[0]
+         loss = loss_fn(logits, labels)
+         # metrics
+         return {"loss": loss}
+
+     p_train_step = pjit(
+         train_step,
+         in_axis_resources=(param_spec, opt_state_spec, None, None, None),
+         out_axis_resources=(param_spec, opt_state_spec, None, None, None),
+         donate_argnums=(0, 1),
+     )
+
+     p_eval_step = pjit(
+         eval_step,
+         in_axis_resources=(None, None, param_spec),
+         out_axis_resources=None,
+     )
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {len(train_dataset)}")
+     logger.info(f" Num Epochs = {num_epochs}")
+     logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+     logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+     logger.info(f" Total optimization steps = {total_train_steps}")
+
+     train_time = 0
+     train_metrics = []
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     global_step = 0
+     # we are not doing 2D parallelism (yet!), this just does model parallelism
+     with mesh(mesh_devices, ("dp", "mp")):
+         for _ in epochs:
+             # ======================== Training ================================
+             train_start = time.time()
+
+             # Create sampling rng
+             rng, input_rng = jax.random.split(rng)
+
+             # Generate an epoch by shuffling sampling indices from the train dataset
+             train_metrics = []
+             train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
+             steps_per_epoch = len(train_dataset) // train_batch_size
+
+             # train
+             for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+                 batch = next(train_loader)
+                 params, opt_state, dropout_rng, train_metric, global_step = p_train_step(
+                     params,
+                     opt_state,
+                     dropout_rng,
+                     batch,
+                     global_step,
+                 )
+                 train_metrics.append(train_metric)
+
+                 cur_step = global_step
+
+                 if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                     # Save metrics
+                     train_time += time.time() - train_start
+                     if has_tensorboard and jax.process_index() == 0:
+                         write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                     epochs.write(
+                         f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                     )
+
+                     train_metrics = []
+
+                 if cur_step % training_args.eval_steps == 0 and cur_step > 0:
+                     # ======================== Evaluating ==============================
+                     eval_metrics = []
+                     eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+                     eval_steps = len(eval_dataset) // eval_batch_size
+
+                     for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
+                         batch = next(eval_loader)
+                         metrics = p_eval_step(batch["input_ids"], batch["labels"], params)
+                         eval_metrics.append(metrics)
+
+                     # normalize eval metrics
+                     eval_metrics = stack_forest(eval_metrics)
+                     eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+                     try:
+                         eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
+                     except OverflowError:
+                         eval_metrics["perplexity"] = float("inf")
+
+                     logger.info(
+                         f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
+                     )
+
+                 if cur_step % training_args.save_steps == 0 and cur_step > 0:
+                     # save checkpoint after each epoch and push checkpoint to the hub
+                     if jax.process_index() == 0:
+                         params = jax.device_get(params)
+                         model.save_pretrained(
+                             training_args.output_dir,
+                             params=params,
+                             push_to_hub=training_args.push_to_hub,
+                             commit_message=f"Saving weights and logs of step {cur_step}",
+                         )
+
+
+ if __name__ == "__main__":
+     main()
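
A rough invocation sketch for the script above; the model/tokenizer paths, dataset, and hyperparameters are placeholders rather than values recorded in this commit (all flags are either the dataclass fields defined above or standard `TrainingArguments`):

    python run_clm_mp.py \
        --model_name_or_path gpt-neo-1.3B \
        --tokenizer_name EleutherAI/gpt-neo-1.3B \
        --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
        --do_train --do_eval \
        --block_size 1024 \
        --per_device_train_batch_size 2 --per_device_eval_batch_size 2 \
        --num_train_epochs 1 --learning_rate 5e-5 --warmup_steps 100 \
        --logging_steps 50 --eval_steps 200 --save_steps 200 \
        --output_dir output_dir --overwrite_output_dir
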
setup_devices.py ADDED
@@ -0,0 +1,18 @@
+ import jax
+ import jax.numpy as jnp
+ from transformers import FlaxGPTNeoForCausalLM, GPTNeoConfig
+ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
+
+ emb = jnp.zeros((50264, model.config.hidden_size))
+ # update the first 50257 weights using pre-trained weights
+ emb = jax.ops.index_update(emb, jax.ops.index[:50257, :], model.params["transformer"]["wte"]["embedding"])
+ params = model.params
+ params["transformer"]["wte"]["embedding"] = emb
+
+ # initialize a random model with the right vocab_size
+ config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B", vocab_size=50264)
+ model = FlaxGPTNeoForCausalLM(config)
+
+ # assign the pre-trained weights and save the model.
+ model.params = params
+ model.save_pretrained("gpt-neo-1.3B")
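
Note: `jax.ops.index_update` existed in the JAX versions contemporary with this commit but has since been removed; on current JAX the same functional update is written with indexed-assignment syntax:

    # Equivalent update on newer JAX releases (jax.ops.index_update was removed):
    emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])
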