aapot commited on
Commit
752f635
1 Parent(s): dd99550

Initial commit

Browse files
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/researchdisk/roberta-large-finnish-wechsel",
3
+ "architectures": [
4
+ "RobertaModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 514,
17
+ "model_type": "roberta",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "pad_token_id": 1,
21
+ "position_embedding_type": "absolute",
22
+ "torch_dtype": "float64",
23
+ "transformers_version": "4.13.0.dev0",
24
+ "type_vocab_size": 1,
25
+ "use_cache": true,
26
+ "vocab_size": 50265
27
+ }
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9a123577826ae147f24d257b4f877eaa05fd6b67d294bef5786cd5b174f7eb7
3
+ size 1421452955
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
run_mlm_flax.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=fill-mask
22
+ """
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ import time
29
+ import gc
30
+ from dataclasses import asdict, dataclass, field
31
+ from enum import Enum
32
+ from itertools import chain
33
+
34
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
35
+ from pathlib import Path
36
+ from typing import Dict, List, Optional, Tuple
37
+
38
+ import numpy as np
39
+ from datasets import load_dataset, load_from_disk
40
+ from tqdm import tqdm
41
+
42
+ import flax
43
+ import jax
44
+ import jax.numpy as jnp
45
+ import optax
46
+ from flax import jax_utils, traverse_util
47
+ from flax.training import train_state
48
+ from flax.training.common_utils import get_metrics, onehot, shard
49
+ from huggingface_hub import Repository
50
+ from transformers import (
51
+ CONFIG_MAPPING,
52
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
53
+ AutoConfig,
54
+ AutoTokenizer,
55
+ FlaxAutoModelForMaskedLM,
56
+ HfArgumentParser,
57
+ PreTrainedTokenizerBase,
58
+ TensorType,
59
+ is_tensorboard_available,
60
+ set_seed,
61
+ )
62
+ from transformers.file_utils import get_full_repo_name
63
+
64
+
65
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
66
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ adabelief: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adabelief."})
97
+ sm3: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by SM3."})
98
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
99
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
100
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
101
+ save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
102
+ eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
103
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
104
+ push_to_hub: bool = field(
105
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
106
+ )
107
+ hub_model_id: str = field(
108
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
109
+ )
110
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
111
+
112
+ def __post_init__(self):
113
+ if self.output_dir is not None:
114
+ self.output_dir = os.path.expanduser(self.output_dir)
115
+
116
+ def to_dict(self):
117
+ """
118
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
119
+ the token values by removing their value.
120
+ """
121
+ d = asdict(self)
122
+ for k, v in d.items():
123
+ if isinstance(v, Enum):
124
+ d[k] = v.value
125
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
126
+ d[k] = [x.value for x in v]
127
+ if k.endswith("_token"):
128
+ d[k] = f"<{k.upper()}>"
129
+ return d
130
+
131
+ @dataclass
132
+ class ModelArguments:
133
+ """
134
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
135
+ """
136
+
137
+ model_name_or_path: Optional[str] = field(
138
+ default=None,
139
+ metadata={
140
+ "help": "The model checkpoint for weights initialization."
141
+ "Don't set if you want to train a model from scratch."
142
+ },
143
+ )
144
+ model_type: Optional[str] = field(
145
+ default=None,
146
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
147
+ )
148
+ config_name: Optional[str] = field(
149
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
150
+ )
151
+ tokenizer_name: Optional[str] = field(
152
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
153
+ )
154
+ cache_dir: Optional[str] = field(
155
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
156
+ )
157
+ use_fast_tokenizer: bool = field(
158
+ default=True,
159
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
160
+ )
161
+ dtype: Optional[str] = field(
162
+ default="float32",
163
+ metadata={
164
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
165
+ },
166
+ )
167
+
168
+
169
+ @dataclass
170
+ class DataTrainingArguments:
171
+ """
172
+ Arguments pertaining to what data we are going to input our model for training and eval.
173
+ """
174
+
175
+ dataset_name: Optional[str] = field(
176
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
177
+ )
178
+ dataset_config_name: Optional[str] = field(
179
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
180
+ )
181
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
182
+ validation_file: Optional[str] = field(
183
+ default=None,
184
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
185
+ )
186
+ dataset_filepath: Optional[str] = field(
187
+ default=None, metadata={"help": "Filepath to locally saved HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
188
+ )
189
+ save_tokenized_dataset_filepath: Optional[str] = field(
190
+ default=None, metadata={"help": "Filepath for saving tokenized HF Dataset (with 'dataset.save_to_disk' method) to use for future trainings"}
191
+ )
192
+ tokenized_dataset_filepath: Optional[str] = field(
193
+ default=None, metadata={"help": "Filepath to locally saved tokenized HF Dataset (with 'dataset.save_to_disk' method) to use for training"}
194
+ )
195
+ train_ref_file: Optional[str] = field(
196
+ default=None,
197
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
198
+ )
199
+ validation_ref_file: Optional[str] = field(
200
+ default=None,
201
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
202
+ )
203
+ overwrite_cache: bool = field(
204
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
205
+ )
206
+ validation_split_percentage: Optional[int] = field(
207
+ default=5,
208
+ metadata={
209
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
210
+ },
211
+ )
212
+ max_seq_length: Optional[int] = field(
213
+ default=None,
214
+ metadata={
215
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
216
+ "than this will be truncated. Default to the max input length of the model."
217
+ },
218
+ )
219
+ preprocessing_num_workers: Optional[int] = field(
220
+ default=None,
221
+ metadata={"help": "The number of processes to use for the preprocessing."},
222
+ )
223
+ mlm_probability: float = field(
224
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
225
+ )
226
+ pad_to_max_length: bool = field(
227
+ default=False,
228
+ metadata={
229
+ "help": "Whether to pad all samples to `max_seq_length`. "
230
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
231
+ },
232
+ )
233
+ line_by_line: bool = field(
234
+ default=False,
235
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
236
+ )
237
+
238
+ def __post_init__(self):
239
+ if self.dataset_name is None and self.train_file is None and self.dataset_filepath is None and self.validation_file is None:
240
+ raise ValueError("Need either a dataset name or a training/validation file.")
241
+ else:
242
+ if self.train_file is not None:
243
+ extension = self.train_file.split(".")[-1]
244
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
245
+ if self.validation_file is not None:
246
+ extension = self.validation_file.split(".")[-1]
247
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
248
+
249
+
250
+ @flax.struct.dataclass
251
+ class FlaxDataCollatorForLanguageModeling:
252
+ """
253
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
254
+ are not all of the same length.
255
+
256
+ Args:
257
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
258
+ The tokenizer used for encoding the data.
259
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
260
+ The probability with which to (randomly) mask tokens in the input.
261
+
262
+ .. note::
263
+
264
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
265
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
266
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
267
+ argument :obj:`return_special_tokens_mask=True`.
268
+ """
269
+
270
+ tokenizer: PreTrainedTokenizerBase
271
+ mlm_probability: float = 0.15
272
+
273
+ def __post_init__(self):
274
+ if self.tokenizer.mask_token is None:
275
+ raise ValueError(
276
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
277
+ "You should pass `mlm=False` to train on causal language modeling instead."
278
+ )
279
+
280
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
281
+ # Handle dict or lists with proper padding and conversion to tensor.
282
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
283
+
284
+ # If special token mask has been preprocessed, pop it from the dict.
285
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
286
+
287
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
288
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
289
+ )
290
+ return batch
291
+
292
+ def mask_tokens(
293
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
294
+ ) -> Tuple[np.ndarray, np.ndarray]:
295
+ """
296
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
297
+ """
298
+ labels = inputs.copy()
299
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
300
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
301
+ special_tokens_mask = special_tokens_mask.astype("bool")
302
+
303
+ probability_matrix[special_tokens_mask] = 0.0
304
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
305
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
306
+
307
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
308
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
309
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
310
+
311
+ # 10% of the time, we replace masked input tokens with random word
312
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
313
+ indices_random &= masked_indices & ~indices_replaced
314
+
315
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
316
+ inputs[indices_random] = random_words[indices_random]
317
+
318
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
319
+ return inputs, labels
320
+
321
+
322
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
323
+ num_samples = len(samples_idx)
324
+ samples_to_remove = num_samples % batch_size
325
+
326
+ if samples_to_remove != 0:
327
+ samples_idx = samples_idx[:-samples_to_remove]
328
+ sections_split = num_samples // batch_size
329
+ batch_idx = np.split(samples_idx, sections_split)
330
+ return batch_idx
331
+
332
+
333
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
334
+ summary_writer.scalar("train_time", train_time, step)
335
+
336
+ train_metrics = get_metrics(train_metrics)
337
+ for key, vals in train_metrics.items():
338
+ tag = f"train_{key}"
339
+ for i, val in enumerate(vals):
340
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
341
+
342
+
343
+ def write_eval_metric(summary_writer, eval_metrics, step):
344
+ for metric_name, value in eval_metrics.items():
345
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
346
+
347
+
348
+ def main():
349
+ # See all possible arguments in src/transformers/training_args.py
350
+ # or by passing the --help flag to this script.
351
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
352
+
353
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
354
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
355
+ # If we pass only one argument to the script and it's the path to a json file,
356
+ # let's parse it to get our arguments.
357
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
358
+ else:
359
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
360
+
361
+ if (
362
+ os.path.exists(training_args.output_dir)
363
+ and os.listdir(training_args.output_dir)
364
+ and training_args.do_train
365
+ and not training_args.overwrite_output_dir
366
+ ):
367
+ raise ValueError(
368
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
369
+ "Use --overwrite_output_dir to overcome."
370
+ )
371
+
372
+ # Setup logging
373
+ logging.basicConfig(
374
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
375
+ level="NOTSET",
376
+ datefmt="[%X]",
377
+ )
378
+
379
+ # Log on each process the small summary:
380
+ logger = logging.getLogger(__name__)
381
+
382
+ # Set the verbosity to info of the Transformers logger (on main process only):
383
+ logger.info(f"Training/evaluation parameters {training_args}")
384
+
385
+ # Set seed before initializing model.
386
+ set_seed(training_args.seed)
387
+
388
+ # Handle the repository creation
389
+ if training_args.push_to_hub:
390
+ if training_args.hub_model_id is None:
391
+ repo_name = get_full_repo_name(
392
+ Path(training_args.output_dir).absolute().name, token=training_args.hub_token
393
+ )
394
+ else:
395
+ repo_name = training_args.hub_model_id
396
+ repo = Repository(training_args.output_dir, clone_from=repo_name)
397
+
398
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
399
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
400
+ # (the dataset will be downloaded automatically from the datasets Hub).
401
+ #
402
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
403
+ # 'text' is found. You can easily tweak this behavior (see below).
404
+ #
405
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
406
+ # download the dataset.
407
+ if data_args.dataset_name is not None:
408
+ # Downloading and loading a dataset from the hub.
409
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
410
+
411
+ if "validation" not in datasets.keys():
412
+ datasets["validation"] = load_dataset(
413
+ data_args.dataset_name,
414
+ data_args.dataset_config_name,
415
+ split=f"train[:{data_args.validation_split_percentage}%]",
416
+ cache_dir=model_args.cache_dir,
417
+ )
418
+ datasets["train"] = load_dataset(
419
+ data_args.dataset_name,
420
+ data_args.dataset_config_name,
421
+ split=f"train[{data_args.validation_split_percentage}%:]",
422
+ cache_dir=model_args.cache_dir,
423
+ )
424
+ elif data_args.dataset_filepath is not None:
425
+ # Loading a dataset from local file.
426
+ datasets = load_from_disk(data_args.dataset_filepath)
427
+ if "validation" not in datasets.keys():
428
+ datasets = datasets.train_test_split(test_size=data_args.validation_split_percentage/100)
429
+ datasets["validation"] = datasets["test"]
430
+ del datasets["test"]
431
+
432
+ else:
433
+ data_files = {}
434
+ if data_args.train_file is not None:
435
+ data_files["train"] = data_args.train_file
436
+ if data_args.validation_file is not None:
437
+ data_files["validation"] = data_args.validation_file
438
+ extension = data_args.train_file.split(".")[-1]
439
+ if extension == "txt":
440
+ extension = "text"
441
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
442
+
443
+ if "validation" not in datasets.keys():
444
+ datasets["validation"] = load_dataset(
445
+ extension,
446
+ data_files=data_files,
447
+ split=f"train[:{data_args.validation_split_percentage}%]",
448
+ cache_dir=model_args.cache_dir,
449
+ )
450
+ datasets["train"] = load_dataset(
451
+ extension,
452
+ data_files=data_files,
453
+ split=f"train[{data_args.validation_split_percentage}%:]",
454
+ cache_dir=model_args.cache_dir,
455
+ )
456
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
457
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
458
+
459
+ # Load pretrained model and tokenizer
460
+
461
+ # Distributed training:
462
+ # The .from_pretrained methods guarantee that only one local process can concurrently
463
+ # download model & vocab.
464
+ if model_args.config_name:
465
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
466
+ elif model_args.model_name_or_path:
467
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
468
+ else:
469
+ config = CONFIG_MAPPING[model_args.model_type]()
470
+ logger.warning("You are instantiating a new config instance from scratch.")
471
+
472
+ if model_args.tokenizer_name:
473
+ tokenizer = AutoTokenizer.from_pretrained(
474
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
475
+ )
476
+ elif model_args.model_name_or_path:
477
+ tokenizer = AutoTokenizer.from_pretrained(
478
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
479
+ )
480
+ else:
481
+ raise ValueError(
482
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
483
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
484
+ )
485
+
486
+ # Preprocessing the datasets.
487
+ # First we tokenize all the texts.
488
+ if training_args.do_train:
489
+ column_names = datasets["train"].column_names
490
+ else:
491
+ column_names = datasets["validation"].column_names
492
+ text_column_name = "text" if "text" in column_names else column_names[0]
493
+
494
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
495
+
496
+ if data_args.line_by_line:
497
+ # When using line_by_line, we just tokenize each nonempty line.
498
+ padding = "max_length" if data_args.pad_to_max_length else False
499
+
500
+ def tokenize_function(examples):
501
+ # Remove empty lines
502
+ examples = [line for line in examples if len(line) > 0 and not line.isspace()]
503
+ return tokenizer(
504
+ examples,
505
+ return_special_tokens_mask=True,
506
+ padding=padding,
507
+ truncation=True,
508
+ max_length=max_seq_length,
509
+ )
510
+
511
+ if data_args.tokenized_dataset_filepath is not None:
512
+ # Loading a tokenized dataset from local file.
513
+ tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
514
+ else:
515
+ tokenized_datasets = datasets.map(
516
+ tokenize_function,
517
+ input_columns=[text_column_name],
518
+ batched=True,
519
+ num_proc=data_args.preprocessing_num_workers,
520
+ remove_columns=column_names,
521
+ load_from_cache_file=not data_args.overwrite_cache,
522
+ )
523
+
524
+ else:
525
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
526
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
527
+ # efficient when it receives the `special_tokens_mask`.
528
+ def tokenize_function(examples):
529
+ return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
530
+
531
+ if data_args.tokenized_dataset_filepath is not None:
532
+ # Loading a tokenized dataset from local file.
533
+ tokenized_datasets = load_from_disk(data_args.tokenized_dataset_filepath)
534
+ else:
535
+ tokenized_datasets = datasets.map(
536
+ tokenize_function,
537
+ batched=True,
538
+ num_proc=data_args.preprocessing_num_workers,
539
+ remove_columns=column_names,
540
+ load_from_cache_file=not data_args.overwrite_cache,
541
+ )
542
+
543
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of
544
+ # max_seq_length.
545
+ def group_texts(examples):
546
+ # Concatenate all texts.
547
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
548
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
549
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
550
+ # customize this part to your needs.
551
+ if total_length >= max_seq_length:
552
+ total_length = (total_length // max_seq_length) * max_seq_length
553
+ # Split by chunks of max_len.
554
+ result = {
555
+ k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
556
+ for k, t in concatenated_examples.items()
557
+ }
558
+ return result
559
+
560
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
561
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
562
+ # might be slower to preprocess.
563
+ #
564
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
565
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
566
+ tokenized_datasets = tokenized_datasets.map(
567
+ group_texts,
568
+ batched=True,
569
+ num_proc=data_args.preprocessing_num_workers,
570
+ load_from_cache_file=not data_args.overwrite_cache,
571
+ )
572
+
573
+ # test to see that tokenization worked
574
+ detokenized_example = tokenizer.decode(tokenized_datasets["train"][0]["input_ids"])
575
+ logger.info(f"Detokenized example: {detokenized_example}")
576
+ detokenized_example = tokenizer.decode(tokenized_datasets["train"][-1]["input_ids"])
577
+ logger.info(f"Detokenized example 2: {detokenized_example}")
578
+
579
+ # save the tokenized dataset for future runs
580
+ if data_args.save_tokenized_dataset_filepath is not None:
581
+ tokenized_datasets.save_to_disk(data_args.save_tokenized_dataset_filepath)
582
+
583
+ # Enable tensorboard only on the master node
584
+ has_tensorboard = is_tensorboard_available()
585
+ if has_tensorboard and jax.process_index() == 0:
586
+ try:
587
+ from flax.metrics.tensorboard import SummaryWriter
588
+
589
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
590
+ except ImportError as ie:
591
+ has_tensorboard = False
592
+ logger.warning(
593
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
594
+ )
595
+ else:
596
+ logger.warning(
597
+ "Unable to display metrics through TensorBoard because the package is not installed: "
598
+ "Please run pip install tensorboard to enable."
599
+ )
600
+
601
+ # Data collator
602
+ # This one will take care of randomly masking the tokens.
603
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
604
+
605
+ # Initialize our training
606
+ rng = jax.random.PRNGKey(training_args.seed)
607
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
608
+
609
+ if model_args.model_name_or_path:
610
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
611
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
612
+ )
613
+ else:
614
+ model = FlaxAutoModelForMaskedLM.from_config(
615
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
616
+ )
617
+
618
+
619
+ # Store some constant
620
+ num_epochs = int(training_args.num_train_epochs)
621
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
622
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
623
+
624
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
625
+
626
+ # Create learning rate schedule
627
+ warmup_fn = optax.linear_schedule(
628
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
629
+ )
630
+ decay_fn = optax.linear_schedule(
631
+ init_value=training_args.learning_rate,
632
+ end_value=0,
633
+ transition_steps=num_train_steps - training_args.warmup_steps,
634
+ )
635
+ linear_decay_lr_schedule_fn = optax.join_schedules(
636
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
637
+ )
638
+
639
+ # We use Optax's "masking" functionality to not apply weight decay
640
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
641
+ # mask boolean with the same structure as the parameters.
642
+ # The mask is True for parameters that should be decayed.
643
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
644
+ # For other models, one should correct the layer norm parameter naming
645
+ # accordingly.
646
+ def decay_mask_fn(params):
647
+ flat_params = traverse_util.flatten_dict(params)
648
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
649
+ return traverse_util.unflatten_dict(flat_mask)
650
+
651
+ # create adam optimizer
652
+ if training_args.adafactor:
653
+ # We use the default parameters here to initialize adafactor,
654
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
655
+ optimizer = optax.adafactor(
656
+ learning_rate=linear_decay_lr_schedule_fn,
657
+ )
658
+ elif training_args.adabelief:
659
+ optimizer = optax.adabelief(
660
+ learning_rate=linear_decay_lr_schedule_fn,
661
+ )
662
+ elif training_args.sm3:
663
+ optimizer = optax.sm3(
664
+ learning_rate=training_args.learning_rate,
665
+ )
666
+ else:
667
+ optimizer = optax.adamw(
668
+ learning_rate=linear_decay_lr_schedule_fn,
669
+ b1=training_args.adam_beta1,
670
+ b2=training_args.adam_beta2,
671
+ eps=training_args.adam_epsilon,
672
+ weight_decay=training_args.weight_decay,
673
+ mask=decay_mask_fn,
674
+ )
675
+
676
+ # Setup train state
677
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
678
+
679
+ # Define gradient update step fn
680
+ def train_step(state, batch, dropout_rng):
681
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
682
+
683
+ def loss_fn(params):
684
+ labels = batch.pop("labels")
685
+
686
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
687
+
688
+ # compute loss, ignore padded input tokens
689
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
690
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
691
+
692
+ # take average
693
+ loss = loss.sum() / label_mask.sum()
694
+
695
+ return loss
696
+
697
+ grad_fn = jax.value_and_grad(loss_fn)
698
+ loss, grad = grad_fn(state.params)
699
+ grad = jax.lax.pmean(grad, "batch")
700
+ new_state = state.apply_gradients(grads=grad)
701
+
702
+ metrics = jax.lax.pmean(
703
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
704
+ )
705
+
706
+ return new_state, metrics, new_dropout_rng
707
+
708
+ # Create parallel version of the train step
709
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
710
+
711
+ # Define eval fn
712
+ def eval_step(params, batch):
713
+ labels = batch.pop("labels")
714
+
715
+ logits = model(**batch, params=params, train=False)[0]
716
+
717
+ # compute loss, ignore padded input tokens
718
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
719
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
720
+
721
+ # compute accuracy
722
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
723
+
724
+ # summarize metrics
725
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
726
+ metrics = jax.lax.psum(metrics, axis_name="batch")
727
+
728
+ return metrics
729
+
730
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
731
+
732
+ # Replicate the train state on each device
733
+ state = jax_utils.replicate(state)
734
+
735
+ train_time = 0
736
+ epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
737
+ for epoch in epochs:
738
+ # ======================== Training ================================
739
+ train_start = time.time()
740
+ train_metrics = []
741
+
742
+ # Create sampling rng
743
+ rng, input_rng = jax.random.split(rng)
744
+
745
+ # Generate an epoch by shuffling sampling indices from the train dataset
746
+ num_train_samples = len(tokenized_datasets["train"])
747
+ train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
748
+ train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
749
+
750
+ # Gather the indexes for creating the batch and do a training step
751
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
752
+ samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
753
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
754
+
755
+ # Model forward
756
+ model_inputs = shard(model_inputs.data)
757
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
758
+ train_metrics.append(train_metric)
759
+
760
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
761
+
762
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
763
+ # Save metrics
764
+ train_metric = jax_utils.unreplicate(train_metric)
765
+ train_time += time.time() - train_start
766
+ if has_tensorboard and jax.process_index() == 0:
767
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
768
+
769
+ epochs.write(
770
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
771
+ )
772
+
773
+ train_metrics = []
774
+
775
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
776
+ # ======================== Evaluating ==============================
777
+ num_eval_samples = len(tokenized_datasets["validation"])
778
+ eval_samples_idx = jnp.arange(num_eval_samples)
779
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
780
+
781
+ eval_metrics = []
782
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
783
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
784
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
785
+
786
+ # Model forward
787
+ model_inputs = shard(model_inputs.data)
788
+ metrics = p_eval_step(state.params, model_inputs)
789
+ eval_metrics.append(metrics)
790
+
791
+ # normalize eval metrics
792
+ eval_metrics = get_metrics(eval_metrics)
793
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
794
+ eval_normalizer = eval_metrics.pop("normalizer")
795
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
796
+
797
+ # Update progress bar
798
+ epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
799
+
800
+ # Save metrics
801
+ if has_tensorboard and jax.process_index() == 0:
802
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
803
+
804
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
805
+ # save checkpoint after each epoch and push checkpoint to the hub
806
+ if jax.process_index() == 0:
807
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
808
+ model.save_pretrained(training_args.output_dir, params=params)
809
+ tokenizer.save_pretrained(training_args.output_dir)
810
+ if training_args.push_to_hub:
811
+ repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
812
+
813
+ # free memory so that epoch loop won't crash to resource exhausted error
814
+ del samples
815
+ del model_inputs
816
+ del num_train_samples
817
+ del train_samples_idx
818
+ del train_batch_idx
819
+ gc.collect()
820
+
821
+ # save also at the end of epoch
822
+ try:
823
+ if jax.process_index() == 0:
824
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
825
+ model.save_pretrained(training_args.output_dir, params=params)
826
+ tokenizer.save_pretrained(training_args.output_dir)
827
+ if training_args.push_to_hub:
828
+ repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
829
+ except:
830
+ # push to hub fails the whole script if nothing new to commit
831
+ pass
832
+
833
+ # Eval after training
834
+ if training_args.do_eval:
835
+ num_eval_samples = len(tokenized_datasets["validation"])
836
+ eval_samples_idx = jnp.arange(num_eval_samples)
837
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
838
+
839
+ eval_metrics = []
840
+ for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
841
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
842
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
843
+
844
+ # Model forward
845
+ model_inputs = shard(model_inputs.data)
846
+ metrics = p_eval_step(state.params, model_inputs)
847
+ eval_metrics.append(metrics)
848
+
849
+ # normalize eval metrics
850
+ eval_metrics = get_metrics(eval_metrics)
851
+ eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
852
+ eval_normalizer = eval_metrics.pop("normalizer")
853
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
854
+
855
+ try:
856
+ perplexity = math.exp(eval_metrics["loss"])
857
+ except OverflowError:
858
+ perplexity = float("inf")
859
+ eval_metrics["perplexity"] = perplexity
860
+
861
+ if jax.process_index() == 0:
862
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
863
+ path = os.path.join(training_args.output_dir, "eval_results.json")
864
+ with open(path, "w") as f:
865
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
866
+
867
+
868
+ if __name__ == "__main__":
869
+ main()
run_wechsel.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoTokenizer, FlaxAutoModel
3
+ from datasets import load_dataset
4
+ from wechsel import WECHSEL, load_embeddings
5
+
6
+ source_tokenizer = AutoTokenizer.from_pretrained("roberta-large")
7
+ model = AutoModel.from_pretrained("roberta-large")
8
+
9
+ target_tokenizer = AutoTokenizer.from_pretrained("./")
10
+
11
+ wechsel = WECHSEL(
12
+ load_embeddings("en"),
13
+ load_embeddings("fi"),
14
+ bilingual_dictionary="finnish"
15
+ )
16
+
17
+ target_embeddings, info = wechsel.apply(
18
+ source_tokenizer,
19
+ target_tokenizer,
20
+ model.get_input_embeddings().weight.detach().numpy(),
21
+ )
22
+
23
+ model.get_input_embeddings().weight.data = torch.from_numpy(target_embeddings)
24
+
25
+ model.save_pretrained("./")
26
+
27
+ flax_model = FlaxAutoModel.from_pretrained("./", from_pt=True)
28
+ flax_model.save_pretrained("./")
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
start_train.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set train hyperparams
2
+ unset LD_PRELOAD
3
+ export HF_DATASETS_CACHE="/researchdisk/datasets_cache"
4
+ export USE_TORCH=0
5
+ python3 run_mlm_flax.py \
6
+ --output_dir="./" \
7
+ --model_name_or_path="./" \
8
+ --config_name="./" \
9
+ --tokenizer_name="./" \
10
+ --dataset_filepath="/researchdisk/training_dataset_full" \
11
+ --max_seq_length="128" \
12
+ --pad_to_max_length \
13
+ --preprocessing_num_workers="64" \
14
+ --per_device_train_batch_size="64" \
15
+ --per_device_eval_batch_size="64" \
16
+ --adam_beta1="0.9" \
17
+ --adam_beta2="0.98" \
18
+ --adam_epsilon="1e-6" \
19
+ --learning_rate="2e-4" \
20
+ --weight_decay="0.01" \
21
+ --warmup_steps="2500" \
22
+ --overwrite_output_dir \
23
+ --num_train_epochs="4" \
24
+ --save_steps="10000" \
25
+ --eval_steps="10000" \
26
+ --logging_steps="500" \
27
+ --dtype="bfloat16" \
28
+ --push_to_hub \
29
+ --hub_model_id="Finnish-NLP/roberta-large-wechsel-finnish"
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "RobertaTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff