versae commited on
Commit
b28d4f7
1 Parent(s): 36a3fa6

Testing lr 3e-5

Browse files
.gitignore ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ /model/
131
+
132
+ # VS Code
133
+ .vscode
134
+
135
+ wandb
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "roberta-base",
3
+ "architectures": [
4
+ "RobertaForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 768,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 514,
17
+ "model_type": "roberta",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "position_embedding_type": "absolute",
22
+ "transformers_version": "4.16.0.dev0",
23
+ "type_vocab_size": 1,
24
+ "use_cache": true,
25
+ "vocab_size": 50265
26
+ }
eval_results.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ {
2
+ "eval_accuracy": 0.6885889057979672,
3
+ "eval_loss": 1.430497475427537,
4
+ "eval_perplexity": 4.180778509252052
5
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6c1b729fa1c3dd64fbd625887667b363143278ab0161b5aeb411a6811b37680
3
+ size 498796983
merges.txt ADDED
The diff for this file is too large to render. See raw diff
run_mlm_flax.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ from dataclasses import asdict, dataclass, field
30
+ from enum import Enum
31
+ from itertools import chain
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import numpy as np
38
+ from datasets import load_dataset
39
+ from tqdm import tqdm
40
+
41
+ import flax
42
+ import jax
43
+ import jax.numpy as jnp
44
+ import optax
45
+ from flax import jax_utils, traverse_util
46
+ from flax.training import train_state
47
+ from flax.training.common_utils import get_metrics, onehot, shard
48
+ from huggingface_hub import Repository
49
+ from transformers import (
50
+ CONFIG_MAPPING,
51
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
52
+ AutoConfig,
53
+ AutoTokenizer,
54
+ FlaxAutoModelForMaskedLM,
55
+ HfArgumentParser,
56
+ PreTrainedTokenizerBase,
57
+ TensorType,
58
+ is_tensorboard_available,
59
+ set_seed,
60
+ )
61
+ from transformers.file_utils import get_full_repo_name
62
+
63
+
64
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
+
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
117
+ the token values by removing their value.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+
130
+ @dataclass
131
+ class ModelArguments:
132
+ """
133
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
134
+ """
135
+
136
+ model_name_or_path: Optional[str] = field(
137
+ default=None,
138
+ metadata={
139
+ "help": "The model checkpoint for weights initialization."
140
+ "Don't set if you want to train a model from scratch."
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
164
+ },
165
+ )
166
+
167
+
168
+ @dataclass
169
+ class DataTrainingArguments:
170
+ """
171
+ Arguments pertaining to what data we are going to input our model for training and eval.
172
+ """
173
+
174
+ dataset_name: Optional[str] = field(
175
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
176
+ )
177
+ dataset_config_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
179
+ )
180
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
181
+ validation_file: Optional[str] = field(
182
+ default=None,
183
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
184
+ )
185
+ train_ref_file: Optional[str] = field(
186
+ default=None,
187
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
188
+ )
189
+ validation_ref_file: Optional[str] = field(
190
+ default=None,
191
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
192
+ )
193
+ overwrite_cache: bool = field(
194
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
195
+ )
196
+ validation_split_percentage: Optional[int] = field(
197
+ default=5,
198
+ metadata={
199
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
200
+ },
201
+ )
202
+ max_seq_length: Optional[int] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
206
+ "than this will be truncated. Default to the max input length of the model."
207
+ },
208
+ )
209
+ preprocessing_num_workers: Optional[int] = field(
210
+ default=None,
211
+ metadata={"help": "The number of processes to use for the preprocessing."},
212
+ )
213
+ mlm_probability: float = field(
214
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
215
+ )
216
+ pad_to_max_length: bool = field(
217
+ default=False,
218
+ metadata={
219
+ "help": "Whether to pad all samples to `max_seq_length`. "
220
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
221
+ },
222
+ )
223
+ line_by_line: bool = field(
224
+ default=False,
225
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
226
+ )
227
+
228
+ def __post_init__(self):
229
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
230
+ raise ValueError("Need either a dataset name or a training/validation file.")
231
+ else:
232
+ if self.train_file is not None:
233
+ extension = self.train_file.split(".")[-1]
234
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
235
+ if self.validation_file is not None:
236
+ extension = self.validation_file.split(".")[-1]
237
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
238
+
239
+
240
+ @flax.struct.dataclass
241
+ class FlaxDataCollatorForLanguageModeling:
242
+ """
243
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
244
+ are not all of the same length.
245
+
246
+ Args:
247
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
248
+ The tokenizer used for encoding the data.
249
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
250
+ The probability with which to (randomly) mask tokens in the input.
251
+
252
+ .. note::
253
+
254
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
255
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
256
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
257
+ argument :obj:`return_special_tokens_mask=True`.
258
+ """
259
+
260
+ tokenizer: PreTrainedTokenizerBase
261
+ mlm_probability: float = 0.15
262
+
263
+ def __post_init__(self):
264
+ if self.tokenizer.mask_token is None:
265
+ raise ValueError(
266
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
267
+ "You should pass `mlm=False` to train on causal language modeling instead."
268
+ )
269
+
270
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
271
+ # Handle dict or lists with proper padding and conversion to tensor.
272
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
273
+
274
+ # If special token mask has been preprocessed, pop it from the dict.
275
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
276
+
277
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
278
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
279
+ )
280
+ return batch
281
+
282
+ def mask_tokens(
283
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
284
+ ) -> Tuple[np.ndarray, np.ndarray]:
285
+ """
286
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
287
+ """
288
+ labels = inputs.copy()
289
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
290
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
291
+ special_tokens_mask = special_tokens_mask.astype("bool")
292
+
293
+ probability_matrix[special_tokens_mask] = 0.0
294
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
295
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
296
+
297
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
298
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
299
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
300
+
301
+ # 10% of the time, we replace masked input tokens with random word
302
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
303
+ indices_random &= masked_indices & ~indices_replaced
304
+
305
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
306
+ inputs[indices_random] = random_words[indices_random]
307
+
308
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
309
+ return inputs, labels
310
+
311
+
312
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
313
+ num_samples = len(samples_idx)
314
+ samples_to_remove = num_samples % batch_size
315
+
316
+ if samples_to_remove != 0:
317
+ samples_idx = samples_idx[:-samples_to_remove]
318
+ sections_split = num_samples // batch_size
319
+ batch_idx = np.split(samples_idx, sections_split)
320
+ return batch_idx
321
+
322
+
323
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
324
+ summary_writer.scalar("train_time", train_time, step)
325
+
326
+ train_metrics = get_metrics(train_metrics)
327
+ for key, vals in train_metrics.items():
328
+ tag = f"train_{key}"
329
+ for i, val in enumerate(vals):
330
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
331
+
332
+
333
+ def write_eval_metric(summary_writer, eval_metrics, step):
334
+ for metric_name, value in eval_metrics.items():
335
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
336
+
337
+
338
+ def main():
339
+ # See all possible arguments in src/transformers/training_args.py
340
+ # or by passing the --help flag to this script.
341
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
342
+
343
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
344
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
345
+ # If we pass only one argument to the script and it's the path to a json file,
346
+ # let's parse it to get our arguments.
347
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
348
+ else:
349
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
350
+
351
+ if (
352
+ os.path.exists(training_args.output_dir)
353
+ and os.listdir(training_args.output_dir)
354
+ and training_args.do_train
355
+ and not training_args.overwrite_output_dir
356
+ ):
357
+ raise ValueError(
358
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
359
+ "Use --overwrite_output_dir to overcome."
360
+ )
361
+
362
+ # Setup logging
363
+ logging.basicConfig(
364
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
365
+ level=logging.INFO,
366
+ datefmt="[%X]",
367
+ )
368
+
369
+ # Log on each process the small summary:
370
+ logger = logging.getLogger(__name__)
371
+
372
+ # Set the verbosity to info of the Transformers logger (on main process only):
373
+ logger.info(f"Training/evaluation parameters {training_args}")
374
+
375
+ # Set seed before initializing model.
376
+ set_seed(training_args.seed)
377
+
378
+ # Handle the repository creation
379
+ if training_args.push_to_hub:
380
+ if training_args.hub_model_id is None:
381
+ repo_name = get_full_repo_name(
382
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
383
+ )
384
+ else:
385
+ repo_name = training_args.hub_model_id
386
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
387
+
388
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
389
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
390
+ # (the dataset will be downloaded automatically from the datasets Hub).
391
+ #
392
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
393
+ # 'text' is found. You can easily tweak this behavior (see below).
394
+ #
395
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
396
+ # download the dataset.
397
+ if data_args.dataset_name is not None:
398
+ # Downloading and loading a dataset from the hub.
399
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
400
+
401
+ if "validation" not in datasets.keys():
402
+ datasets["validation"] = load_dataset(
403
+ data_args.dataset_name,
404
+ data_args.dataset_config_name,
405
+ split=f"train[:{data_args.validation_split_percentage}%]",
406
+ cache_dir=model_args.cache_dir,
407
+ )
408
+ datasets["train"] = load_dataset(
409
+ data_args.dataset_name,
410
+ data_args.dataset_config_name,
411
+ split=f"train[{data_args.validation_split_percentage}%:]",
412
+ cache_dir=model_args.cache_dir,
413
+ )
414
+ else:
415
+ data_files = {}
416
+ if data_args.train_file is not None:
417
+ data_files["train"] = data_args.train_file
418
+ if data_args.validation_file is not None:
419
+ data_files["validation"] = data_args.validation_file
420
+ extension = data_args.train_file.split(".")[-1]
421
+ if extension == "txt":
422
+ extension = "text"
423
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
424
+
425
+ if "validation" not in datasets.keys():
426
+ datasets["validation"] = load_dataset(
427
+ extension,
428
+ data_files=data_files,
429
+ split=f"train[:{data_args.validation_split_percentage}%]",
430
+ cache_dir=model_args.cache_dir,
431
+ )
432
+ datasets["train"] = load_dataset(
433
+ extension,
434
+ data_files=data_files,
435
+ split=f"train[{data_args.validation_split_percentage}%:]",
436
+ cache_dir=model_args.cache_dir,
437
+ )
438
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
439
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
440
+
441
+ # Load pretrained model and tokenizer
442
+
443
+ # Distributed training:
444
+ # The .from_pretrained methods guarantee that only one local process can concurrently
445
+ # download model & vocab.
446
+ if model_args.config_name:
447
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
448
+ elif model_args.model_name_or_path:
449
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
450
+ else:
451
+ config = CONFIG_MAPPING[model_args.model_type]()
452
+ logger.warning("You are instantiating a new config instance from scratch.")
453
+
454
+ if model_args.tokenizer_name:
455
+ tokenizer = AutoTokenizer.from_pretrained(
456
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
457
+ )
458
+ elif model_args.model_name_or_path:
459
+ tokenizer = AutoTokenizer.from_pretrained(
460
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
461
+ )
462
+ else:
463
+ raise ValueError(
464
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
465
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
466
+ )
467
+
468
+ # Preprocessing the datasets.
469
+ # First we tokenize all the texts.
470
+ if training_args.do_train:
471
+ column_names = datasets["train"].column_names
472
+ else:
473
+ column_names = datasets["validation"].column_names
474
+ text_column_name = "text" if "text" in column_names else column_names[0]
475
+
476
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
477
+
478
+ if data_args.line_by_line:
479
+ # When using line_by_line, we just tokenize each nonempty line.
480
+ padding = "max_length" if data_args.pad_to_max_length else False
481
+
482
+ def tokenize_function(examples):
483
+ # Remove empty lines
484
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
485
+ return tokenizer(
486
+ examples,
487
+ return_special_tokens_mask=True,
488
+ padding=padding,
489
+ truncation=True,
490
+ max_length=max_seq_length,
491
+ )
492
+
493
+ tokenized_datasets = datasets.map(
494
+ tokenize_function,
495
+ input_columns=[text_column_name],
496
+ batched=True,
497
+ num_proc=data_args.preprocessing_num_workers,
498
+ remove_columns=column_names,
499
+ load_from_cache_file=not data_args.overwrite_cache,
500
+ )
501
+
502
+ else:
503
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
504
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
505
+ # efficient when it receives the `special_tokens_mask`.
506
+ def tokenize_function(examples):
507
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
508
+
509
+ tokenized_datasets = datasets.map(
510
+ tokenize_function,
511
+ batched=True,
512
+ num_proc=data_args.preprocessing_num_workers,
513
+ remove_columns=column_names,
514
+ load_from_cache_file=not data_args.overwrite_cache,
515
+ )
516
+
517
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
518
+ # max_seq_length.
519
+ def group_texts(examples):
520
+ # Concatenate all texts.
521
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
522
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
523
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
524
+ # customize this part to your needs.
525
+ if total_length >= max_seq_length:
526
+ total_length = (total_length // max_seq_length) * max_seq_length
527
+ # Split by chunks of max_len.
528
+ result = {
529
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
530
+ for k, t in concatenated_examples.items()
531
+ }
532
+ return result
533
+
534
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
535
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
536
+ # might be slower to preprocess.
537
+ #
538
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
539
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
540
+ tokenized_datasets = tokenized_datasets.map(
541
+ group_texts,
542
+ batched=True,
543
+ num_proc=data_args.preprocessing_num_workers,
544
+ load_from_cache_file=not data_args.overwrite_cache,
545
+ )
546
+
547
+ # Enable tensorboard only on the master node
548
+ has_tensorboard = is_tensorboard_available()
549
+ if has_tensorboard and jax.process_index() == 0:
550
+ try:
551
+ # Enable Weight&Biases
552
+ import wandb
553
+ wandb.init(
554
+ entity='versae',
555
+ project='roberta-base-ncc',
556
+ sync_tensorboard=False,
557
+ )
558
+ wandb.config.update(training_args)
559
+ wandb.config.update(model_args)
560
+ wandb.config.update(data_args)
561
+
562
+ from flax.metrics.tensorboard import SummaryWriter
563
+
564
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
565
+ except ImportError as ie:
566
+ has_tensorboard = False
567
+ logger.warning(
568
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
569
+ )
570
+ else:
571
+ logger.warning(
572
+ "Unable to display metrics through TensorBoard because the package is not installed: "
573
+ "Please run pip install tensorboard to enable."
574
+ )
575
+
576
+ # Data collator
577
+ # This one will take care of randomly masking the tokens.
578
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
579
+
580
+ # Initialize our training
581
+ rng = jax.random.PRNGKey(training_args.seed)
582
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
583
+
584
+ if model_args.model_name_or_path:
585
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
586
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
587
+ )
588
+ else:
589
+ model = FlaxAutoModelForMaskedLM.from_config(
590
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
591
+ )
592
+
593
+ # Store some constant
594
+ num_epochs = int(training_args.num_train_epochs)
595
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
596
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
597
+
598
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
599
+
600
+ # Create learning rate schedule
601
+ warmup_fn = optax.linear_schedule(
602
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
603
+ )
604
+ decay_fn = optax.linear_schedule(
605
+ init_value=training_args.learning_rate,
606
+ end_value=0,
607
+ transition_steps=num_train_steps - training_args.warmup_steps,
608
+ )
609
+ linear_decay_lr_schedule_fn = optax.join_schedules(
610
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
611
+ )
612
+
613
+ # We use Optax's "masking" functionality to not apply weight decay
614
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
615
+ # mask boolean with the same structure as the parameters.
616
+ # The mask is True for parameters that should be decayed.
617
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
618
+ # For other models, one should correct the layer norm parameter naming
619
+ # accordingly.
620
+ def decay_mask_fn(params):
621
+ flat_params = traverse_util.flatten_dict(params)
622
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
623
+ return traverse_util.unflatten_dict(flat_mask)
624
+
625
+ # create adam optimizer
626
+ if training_args.adafactor:
627
+ # We use the default parameters here to initialize adafactor,
628
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
629
+ optimizer = optax.adafactor(
630
+ learning_rate=linear_decay_lr_schedule_fn,
631
+ )
632
+ else:
633
+ optimizer = optax.adamw(
634
+ learning_rate=linear_decay_lr_schedule_fn,
635
+ b1=training_args.adam_beta1,
636
+ b2=training_args.adam_beta2,
637
+ eps=training_args.adam_epsilon,
638
+ weight_decay=training_args.weight_decay,
639
+ mask=decay_mask_fn,
640
+ )
641
+
642
+ # Setup train state
643
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
644
+
645
+ # Define gradient update step fn
646
+ def train_step(state, batch, dropout_rng):
647
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
648
+
649
+ def loss_fn(params):
650
+ labels = batch.pop("labels")
651
+
652
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
653
+
654
+ # compute loss, ignore padded input tokens
655
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
656
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
657
+
658
+ # take average
659
+ loss = loss.sum() / label_mask.sum()
660
+
661
+ return loss
662
+
663
+ grad_fn = jax.value_and_grad(loss_fn)
664
+ loss, grad = grad_fn(state.params)
665
+ grad = jax.lax.pmean(grad, "batch")
666
+ new_state = state.apply_gradients(grads=grad)
667
+
668
+ metrics = jax.lax.pmean(
669
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
670
+ )
671
+
672
+ return new_state, metrics, new_dropout_rng
673
+
674
+ # Create parallel version of the train step
675
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
676
+
677
+ # Define eval fn
678
+ def eval_step(params, batch):
679
+ labels = batch.pop("labels")
680
+
681
+ logits = model(**batch, params=params, train=False)[0]
682
+
683
+ # compute loss, ignore padded input tokens
684
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
685
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
686
+
687
+ # compute accuracy
688
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
689
+
690
+ # summarize metrics
691
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
692
+ metrics = jax.lax.psum(metrics, axis_name="batch")
693
+
694
+ return metrics
695
+
696
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
697
+
698
+ # Replicate the train state on each device
699
+ state = jax_utils.replicate(state)
700
+
701
+ train_time = 0
702
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
703
+ for epoch in epochs:
704
+ # ======================== Training ================================
705
+ train_start = time.time()
706
+ train_metrics = []
707
+
708
+ # Create sampling rng
709
+ rng, input_rng = jax.random.split(rng)
710
+
711
+ # Generate an epoch by shuffling sampling indices from the train dataset
712
+ num_train_samples = len(tokenized_datasets["train"])
713
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
714
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
715
+
716
+ # Gather the indexes for creating the batch and do a training step
717
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
718
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
719
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
720
+
721
+ # Model forward
722
+ model_inputs = shard(model_inputs.data)
723
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
724
+ train_metrics.append(train_metric)
725
+
726
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
727
+
728
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
729
+ # Save metrics
730
+ train_metric = jax_utils.unreplicate(train_metric)
731
+ train_time += time.time() - train_start
732
+ if has_tensorboard and jax.process_index() == 0:
733
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
734
+
735
+ epochs.write(
736
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
737
+ )
738
+
739
+ train_metrics = []
740
+
741
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
742
+ # ======================== Evaluating ==============================
743
+ num_eval_samples = len(tokenized_datasets["validation"])
744
+ eval_samples_idx = jnp.arange(num_eval_samples)
745
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
746
+
747
+ eval_metrics = []
748
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
749
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
750
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
751
+
752
+ # Model forward
753
+ model_inputs = shard(model_inputs.data)
754
+ metrics = p_eval_step(state.params, model_inputs)
755
+ eval_metrics.append(metrics)
756
+
757
+ # normalize eval metrics
758
+ eval_metrics = get_metrics(eval_metrics)
759
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
760
+ eval_normalizer = eval_metrics.pop("normalizer")
761
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
762
+
763
+ # Update progress bar
764
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
765
+
766
+ # Save metrics
767
+ if has_tensorboard and jax.process_index() == 0:
768
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
769
+
770
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
771
+ # save checkpoint after each epoch and push checkpoint to the hub
772
+ if jax.process_index() == 0:
773
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
774
+ model.save_pretrained(training_args.output_dir, params=params)
775
+ tokenizer.save_pretrained(training_args.output_dir)
776
+ if training_args.push_to_hub:
777
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
778
+
779
+ # Eval after training
780
+ if training_args.do_eval:
781
+ num_eval_samples = len(tokenized_datasets["validation"])
782
+ eval_samples_idx = jnp.arange(num_eval_samples)
783
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
784
+
785
+ eval_metrics = []
786
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
787
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
788
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
789
+
790
+ # Model forward
791
+ model_inputs = shard(model_inputs.data)
792
+ metrics = p_eval_step(state.params, model_inputs)
793
+ eval_metrics.append(metrics)
794
+
795
+ # normalize eval metrics
796
+ eval_metrics = get_metrics(eval_metrics)
797
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
798
+ eval_normalizer = eval_metrics.pop("normalizer")
799
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
800
+
801
+ try:
802
+ perplexity = math.exp(eval_metrics["loss"])
803
+ except OverflowError:
804
+ perplexity = float("inf")
805
+ eval_metrics["perplexity"] = perplexity
806
+
807
+ if jax.process_index() == 0:
808
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
809
+ path = os.path.join(training_args.output_dir, "eval_results.json")
810
+ with open(path, "w") as f:
811
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
812
+
813
+
814
+ if __name__ == "__main__":
815
+ main()
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "trim_offsets": true, "special_tokens_map_file": null, "name_or_path": "NbAiLab/nb-roberta-base", "tokenizer_class": "RobertaTokenizer"}
train.128.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python run_mlm_flax.py \
2
+ --output_dir="./" \
3
+ --model_type="roberta" \
4
+ --config_name="roberta-base" \
5
+ --tokenizer_name="NbAiLab/nb-roberta-base" \
6
+ --dataset_name="NbAiLab/NCC" \
7
+ --max_seq_length="128" \
8
+ --weight_decay="0.01" \
9
+ --per_device_train_batch_size="232" \
10
+ --per_device_eval_batch_size="232" \
11
+ --pad_to_max_length \
12
+ --learning_rate="6e-4" \
13
+ --warmup_steps="10000" \
14
+ --overwrite_output_dir \
15
+ --num_train_epochs="3" \
16
+ --adam_beta1="0.9" \
17
+ --adam_beta2="0.98" \
18
+ --adam_epsilon="1e-6" \
19
+ --logging_steps="1000" \
20
+ --save_steps="1000" \
21
+ --eval_steps="1000" \
22
+ --do_train \
23
+ --do_eval \
24
+ --dtype="bfloat16" \
25
+ --push_to_hub
train.512.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python run_mlm_flax.py \
2
+ --output_dir="./" \
3
+ --model_type="roberta" \
4
+ --config_name="./" \
5
+ --tokenizer_name="NbAiLab/nb-roberta-base" \
6
+ --dataset_name="NbAiLab/NCC" \
7
+ --max_seq_length="512" \
8
+ --weight_decay="0.01" \
9
+ --per_device_train_batch_size="46" \
10
+ --per_device_eval_batch_size="46" \
11
+ --pad_to_max_length \
12
+ --learning_rate="3e-5" \
13
+ --warmup_steps="0" \
14
+ --overwrite_output_dir \
15
+ --num_train_epochs="3" \
16
+ --adam_beta1="0.9" \
17
+ --adam_beta2="0.98" \
18
+ --adam_epsilon="1e-6" \
19
+ --logging_steps="1000" \
20
+ --save_steps="1000" \
21
+ --eval_steps="1000" \
22
+ --do_train \
23
+ --do_eval \
24
+ --dtype="bfloat16" \
25
+ --push_to_hub
vocab.json ADDED
The diff for this file is too large to render. See raw diff